From 752a4562eb078a8e32ee14ecf4cf520ce53e8c8d Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Fri, 1 Sep 2023 12:58:44 +0530 Subject: [PATCH 01/70] Update data_loader.py --- research/SpreadGNN/data/data_loader.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/research/SpreadGNN/data/data_loader.py b/research/SpreadGNN/data/data_loader.py index efbc5ba995..b85a812fec 100644 --- a/research/SpreadGNN/data/data_loader.py +++ b/research/SpreadGNN/data/data_loader.py @@ -102,6 +102,22 @@ def create_non_uniform_split(args, idxs, client_number, is_train=True): def partition_data_by_sample_size( args, path, client_number, uniform=True, compact=True ): + """ + Partition dataset into multiple clients based on sample size. + + Args: + args (list): Arguments. + path (str): Path to the dataset. + client_number (int): Number of clients to partition the dataset into. + uniform (bool, optional): If True, create uniform partitions. If False, create non-uniform partitions. + compact (bool, optional): Whether to use compact representation. + + Returns: + tuple: A tuple containing global_data_dict and partition_dicts. + + global_data_dict (dict): A dictionary containing global datasets (train, val, test). + partition_dicts (list): A list of dictionaries containing partitioned datasets for each client. + """ ( train_adj_matrices, train_feature_matrices, From f7a560e204e29f0a1a5ea4cac8c76df7352ff4bc Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Wed, 6 Sep 2023 12:15:05 +0530 Subject: [PATCH 02/70] code and docstring update --- research/SpreadGNN/data/data_loader.py | 163 ++++++++++++++++++ research/SpreadGNN/data/datasets.py | 58 ++++++- research/SpreadGNN/data/utils.py | 50 ++++++ research/SpreadGNN/model/gat_readout.py | 111 +++++++++++- research/SpreadGNN/model/sage_readout.py | 62 ++++++- .../SpreadGNN/trainer/gat_readout_trainer.py | 55 ++++++ .../trainer/gat_readout_trainer_regression.py | 65 +++++++ .../SpreadGNN/trainer/sage_readout_trainer.py | 54 ++++++ .../sage_readout_trainer_regression.py | 53 ++++++ 9 files changed, 662 insertions(+), 9 deletions(-) diff --git a/research/SpreadGNN/data/data_loader.py b/research/SpreadGNN/data/data_loader.py index b85a812fec..889da69c37 100644 --- a/research/SpreadGNN/data/data_loader.py +++ b/research/SpreadGNN/data/data_loader.py @@ -10,6 +10,21 @@ def get_data(path): + """ + Load data from the specified path. + + Args: + path (str): The path to the directory containing data files. + + Returns: + tuple: A tuple containing the following elements: + - adj_matrices (list): A list of adjacency matrices. + - feature_matrices (list): A list of feature matrices. + - labels (numpy.ndarray): An array of labels. + + Raises: + FileNotFoundError: If any of the required data files are not found. + """ with open(path + "/adjacency_matrices.pkl", "rb") as f: adj_matrices = pickle.load(f) @@ -21,6 +36,27 @@ def get_data(path): return adj_matrices, feature_matrices, labels def create_random_split(path): + """ + Create a random 80/10/10 split of data from the specified path. + + Args: + path (str): The path to the directory containing data files. + + Returns: + tuple: A tuple containing the following elements for training, validation, and testing sets: + - train_adj_matrices (list): A list of adjacency matrices for training. + - train_feature_matrices (list): A list of feature matrices for training. + - train_labels (list): A list of labels for training. + - val_adj_matrices (list): A list of adjacency matrices for validation. 
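[Editor's note] To make the 80/10/10 split described above concrete, here is a minimal standalone sketch over sample indices; random_split_80_10_10 is a hypothetical helper for illustration, not part of this patch.

import numpy as np

def random_split_80_10_10(num_samples, seed=0):
    # Shuffle all indices once, then slice into train/val/test partitions.
    rng = np.random.default_rng(seed)
    idxs = rng.permutation(num_samples)
    n_train, n_val = int(0.8 * num_samples), int(0.1 * num_samples)
    return idxs[:n_train], idxs[n_train:n_train + n_val], idxs[n_train + n_val:]

# usage: train_idx, val_idx, test_idx = random_split_80_10_10(len(labels))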
+ - val_feature_matrices (list): A list of feature matrices for validation. + - val_labels (list): A list of labels for validation. + - test_adj_matrices (list): A list of adjacency matrices for testing. + - test_feature_matrices (list): A list of feature matrices for testing. + - test_labels (list): A list of labels for testing. + + Raises: + FileNotFoundError: If any of the required data files are not found. + """ adj_matrices, feature_matrices, labels = get_data(path) # Random 80/10/10 split as suggested in the MoleculeNet whitepaper @@ -74,6 +110,27 @@ def create_random_split(path): ) def create_non_uniform_split(args, idxs, client_number, is_train=True): + """ + Create a non-uniform split of data indices among clients based on the Dirichlet distribution. + + Args: + args: An object containing relevant parameters. + idxs (list): A list of data indices to be split. + client_number (int): The number of clients. + is_train (bool): A flag indicating whether the split is for training data. + + Returns: + list: A list of lists where each sublist contains data indices assigned to a client. + + Logging: + This function logs information about the data split process. + + Note: + This function relies on the `partition_class_samples_with_dirichlet_distribution` function. + + Raises: + None + """ logging.info("create_non_uniform_split------------------------------------------") N = len(idxs) alpha = args.partition_alpha @@ -249,6 +306,25 @@ def partition_data_by_sample_size( # For centralized training def get_dataloader(path, compact=True, normalize_features=False, normalize_adj=False): + """ + Get data loaders for training, validation, and testing sets. + + Args: + path (str): The path to the directory containing data files. + compact (bool, optional): Whether to use compact data format. Defaults to True. + normalize_features (bool, optional): Whether to normalize features. Defaults to False. + normalize_adj (bool, optional): Whether to normalize adjacency matrices. Defaults to False. + + Returns: + tuple: A tuple containing data loaders for training, validation, and testing sets. + + Note: + This function utilizes the `MoleculesDataset` class and data collators to create data loaders. + Each batch size is set to 1 to ensure that each batch represents an entire molecule. + + Raises: + None + """ ( train_adj_matrices, train_feature_matrices, @@ -318,6 +394,39 @@ def load_partition_data( normalize_features=False, normalize_adj=False, ): + """ + Load and partition data for federated learning among multiple clients. + + Args: + args: An object containing relevant parameters. + path (str): The path to the directory containing data files. + client_number (int): The number of clients. + uniform (bool, optional): Whether to use uniform data partitioning. Defaults to True. + global_test (bool, optional): Whether to use a global test dataset. Defaults to True. + compact (bool, optional): Whether to use compact data format. Defaults to True. + normalize_features (bool, optional): Whether to normalize features. Defaults to False. + normalize_adj (bool, optional): Whether to normalize adjacency matrices. Defaults to False. + + Returns: + tuple: A tuple containing information about the loaded data for federated learning. The tuple includes: + - train_data_num (int): Total number of training samples in the global dataset. + - val_data_num (int): Total number of validation samples in the global dataset. + - test_data_num (int): Total number of testing samples in the global dataset. 
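[Editor's note] As a rough illustration of the Dirichlet-based split that create_non_uniform_split performs (a smaller partition_alpha yields more skewed client shares), the sketch below partitions an index list; dirichlet_partition is an invented name, shown for intuition only.

import numpy as np

def dirichlet_partition(idxs, client_number, alpha, seed=0):
    rng = np.random.default_rng(seed)
    # One Dirichlet draw fixes each client's share of the samples.
    proportions = rng.dirichlet(alpha * np.ones(client_number))
    cuts = (np.cumsum(proportions) * len(idxs)).astype(int)[:-1]
    return np.split(rng.permutation(idxs), cuts)

# usage: client_idx_lists = dirichlet_partition(np.arange(1000), 4, alpha=0.5)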
+ - train_data_global (data.DataLoader): DataLoader for the global training dataset. + - val_data_global (data.DataLoader): DataLoader for the global validation dataset. + - test_data_global (data.DataLoader): DataLoader for the global testing dataset. + - data_local_num_dict (dict): A dictionary mapping client IDs to the number of local samples. + - train_data_local_dict (dict): A dictionary mapping client IDs to their DataLoader for training. + - val_data_local_dict (dict): A dictionary mapping client IDs to their DataLoader for validation. + - test_data_local_dict (dict): A dictionary mapping client IDs to their DataLoader for testing. + + Note: + This function relies on data partitioning using the `partition_data_by_sample_size` function. + Each batch size is set to 1 to represent each molecule as a batch. + + Raises: + None + """ global_data_dict, partition_dicts = partition_data_by_sample_size( args, path, client_number, uniform, compact=compact ) @@ -415,6 +524,33 @@ def load_partition_data( def load_partition_data_distributed(process_id, path, client_number, uniform=True): + """ + Load and partition data for distributed federated learning. + + Args: + process_id (int): The ID of the current process. + path (str): The path to the directory containing data files. + client_number (int): The number of clients. + uniform (bool, optional): Whether to use uniform data partitioning. Defaults to True. + + Returns: + tuple: A tuple containing information about the loaded data for distributed federated learning. The tuple includes: + - train_data_num (int): Total number of training samples in the global dataset. + - train_data_global (data.DataLoader): DataLoader for the global training dataset (for process_id 0). + - val_data_global (data.DataLoader): DataLoader for the global validation dataset (for process_id 0). + - test_data_global (data.DataLoader): DataLoader for the global testing dataset (for process_id 0). + - local_data_num (int): Total number of local samples for the current process. + - train_data_local (data.DataLoader): DataLoader for the local training dataset (for process_id > 0). + - val_data_local (data.DataLoader): DataLoader for the local validation dataset (for process_id > 0). + - test_data_local (data.DataLoader): DataLoader for the local testing dataset (for process_id > 0). + + Note: + This function relies on data partitioning using the `partition_data_by_sample_size` function. + Each batch size is set to 1 to represent each molecule as a batch. + + Raises: + None + """ global_data_dict, partition_dicts = partition_data_by_sample_size( path, client_number, uniform ) @@ -490,6 +626,33 @@ def load_partition_data_distributed(process_id, path, client_number, uniform=Tru def load_moleculenet(args, dataset_name): + """ + Load and partition data for distributed federated learning. + + Args: + process_id (int): The ID of the current process. + path (str): The path to the directory containing data files. + client_number (int): The number of clients. + uniform (bool, optional): Whether to use uniform data partitioning. Defaults to True. + + Returns: + tuple: A tuple containing information about the loaded data for distributed federated learning. The tuple includes: + - train_data_num (int): Total number of training samples in the global dataset. + - train_data_global (data.DataLoader): DataLoader for the global training dataset (for process_id 0). + - val_data_global (data.DataLoader): DataLoader for the global validation dataset (for process_id 0). 
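[Editor's note] The process_id convention documented here can be summarized in a few lines. This is an illustrative reading of the docstring, assuming partition k serves process k + 1; it is not code from the patch.

def select_data_for_process(process_id, global_data_dict, partition_dicts):
    if process_id == 0:
        # The coordinator holds the global train/val/test loaders.
        return (global_data_dict["train"], global_data_dict["val"],
                global_data_dict["test"])
    # Every worker process only sees its own client partition.
    local = partition_dicts[process_id - 1]
    return (local["train"], local["val"], local["test"])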
+ - test_data_global (data.DataLoader): DataLoader for the global testing dataset (for process_id 0). + - local_data_num (int): Total number of local samples for the current process. + - train_data_local (data.DataLoader): DataLoader for the local training dataset (for process_id > 0). + - val_data_local (data.DataLoader): DataLoader for the local validation dataset (for process_id > 0). + - test_data_local (data.DataLoader): DataLoader for the local testing dataset (for process_id > 0). + + Note: + This function relies on data partitioning using the `partition_data_by_sample_size` function. + Each batch size is set to 1 to represent each molecule as a batch. + + Raises: + None + """ num_cats, feat_dim = 0, 0 if dataset_name not in ["sider", "tox21", "muv","qm8" ]: raise Exception("no such dataset!") diff --git a/research/SpreadGNN/data/datasets.py b/research/SpreadGNN/data/datasets.py index a76390403d..443cd76e09 100644 --- a/research/SpreadGNN/data/datasets.py +++ b/research/SpreadGNN/data/datasets.py @@ -13,12 +13,21 @@ # From GTTF, need to cite once paper is officially accepted to ICLR 2021 class CompactAdjacency: def __init__(self, adj, precomputed=None, subset=None): - """Constructs CompactAdjacency. + """ + Constructs a CompactAdjacency object. Args: - adj: scipy sparse matrix containing full adjacency. - precomputed: If given, must be a tuple (compact_adj, degrees). - In this case, adj must be None. If supplied, subset will be ignored. + adj: scipy sparse matrix containing the full adjacency. + precomputed: If given, must be a tuple (compact_adj, degrees). + In this case, adj must be None. If supplied, subset will be ignored. + subset: Optional set of node indices to consider in the adjacency matrix. + + Note: + This constructor initializes a CompactAdjacency object based on the provided arguments. + If 'precomputed' is provided, 'adj' and 'subset' will be ignored. + + Raises: + ValueError: If both 'adj' and 'precomputed' are set. """ if adj is None: return @@ -114,6 +123,25 @@ def __init__( fanouts=[2, 2], split="train", ): + """ + Constructs a dataset for molecules with adjacency matrices, feature matrices, and labels. + + Args: + adj_matrices (list): A list of adjacency matrices. + feature_matrices (list): A list of feature matrices. + labels (list): A list of labels. + path (str): The path to the directory containing data files. + compact (bool, optional): Whether to use compact adjacency matrices. Defaults to True. + fanouts (list, optional): A list of fanout values for each adjacency matrix. Defaults to [2, 2]. + split (str, optional): The dataset split ('train', 'val', or 'test'). Defaults to 'train'. + + Note: + This constructor initializes a MoleculesDataset object based on the provided arguments. + If 'compact' is set to True, it uses compact adjacency matrices. + + Raises: + None + """ if compact: # filename = path + '/train_comp_adjs.pkl' # if split == 'val': @@ -143,6 +171,19 @@ def __init__( self.fanouts = [fanouts] * len(adj_matrices) def __getitem__(self, index): + """ + Retrieves an item from the dataset. + + Args: + index (int): The index of the item to retrieve. + + Returns: + tuple: A tuple containing the following elements: + - adj_matrix: The adjacency matrix. + - feature_matrix: The feature matrix. + - label: The label. + - fanouts: The list of fanout values. + """ return ( self.adj_matrices[index], self.feature_matrices[index], @@ -151,4 +192,13 @@ def __getitem__(self, index): ) def __len__(self): + """ + Returns the total number of items in the dataset. 
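[Editor's note] For readers unfamiliar with the map-style dataset contract MoleculesDataset follows, a stripped-down equivalent is sketched below; with batch_size=1 each "batch" is a single molecule, and the collator decides how the tuples are tensorized. All names here are illustrative.

import torch
from torch.utils import data

class TinyMoleculeDataset(data.Dataset):
    def __init__(self, graphs):
        self.graphs = graphs  # list of (adj_matrix, feature_matrix, label) tuples

    def __getitem__(self, index):
        return self.graphs[index]

    def __len__(self):
        return len(self.graphs)

# loader = data.DataLoader(TinyMoleculeDataset(graphs), batch_size=1,
#                          shuffle=True, collate_fn=my_collator)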
+ + Args: + None + + Returns: + int: The number of items in the dataset. + """ return len(self.adj_matrices) diff --git a/research/SpreadGNN/data/utils.py b/research/SpreadGNN/data/utils.py index 47a41b414a..2fd21df36c 100644 --- a/research/SpreadGNN/data/utils.py +++ b/research/SpreadGNN/data/utils.py @@ -5,6 +5,17 @@ def np_uniform_sample_next(compact_adj, tree, fanout): + """ + Uniformly sample next neighbors for a given compact adjacency matrix and traversal tree. + + Args: + compact_adj (CompactAdjacency): The compact adjacency matrix. + tree (list): The traversal tree. + fanout (int): The number of neighbors to sample for each node. + + Returns: + np.ndarray: An array containing the sampled neighbor indices. + """ last_level = tree[-1] # [batch, f^depth] batch_lengths = compact_adj.degrees[last_level] nodes = np.repeat(last_level, fanout, axis=1) @@ -27,6 +38,21 @@ def np_uniform_sample_next(compact_adj, tree, fanout): def np_traverse( compact_adj, seed_nodes, fanouts=(1,), sample_fn=np_uniform_sample_next ): + """ + Traverse a compact adjacency matrix. + + Args: + compact_adj (CompactAdjacency): The compact adjacency matrix. + seed_nodes (np.ndarray): An array of seed node indices. + fanouts (tuple): A tuple of fanout values. + sample_fn (function): A function for sampling neighbors. + + Returns: + list: A list containing the traversal tree. + + Raises: + ValueError: If the input seed_nodes format is incorrect. + """ if not isinstance(seed_nodes, np.ndarray): raise ValueError("Seed must a numpy array") @@ -53,6 +79,18 @@ def np_traverse( class WalkForestCollator(object): def __init__(self, normalize_features=False): + """ + Collate function for walking forest-based data. + + Args: + molecule (tuple): A tuple containing the molecular data. + + Returns: + tuple: A tuple containing collated data. + + Raises: + None + """ self.normalize_features = normalize_features def __call__(self, molecule): @@ -88,6 +126,18 @@ def __call__(self, molecule): class DefaultCollator(object): + """ + Default collate function for data. + + Args: + molecule (tuple): A tuple containing the molecular data + + Args: + molecule (tuple): A tuple containing the molecular data. + + Returns: + tuple: A tuple containing collated data. + """ def __init__(self, normalize_features=True, normalize_adj=True): self.normalize_features = normalize_features self.normalize_adj = normalize_adj diff --git a/research/SpreadGNN/model/gat_readout.py b/research/SpreadGNN/model/gat_readout.py index d1cd3a3d0e..51fadf9fab 100644 --- a/research/SpreadGNN/model/gat_readout.py +++ b/research/SpreadGNN/model/gat_readout.py @@ -7,6 +7,22 @@ class GraphAttentionLayer(nn.Module): """ Simple GAT layer, similar to https://arxiv.org/abs/1710.10903 """ + """ + A single Graph Attention Layer (GAT) module. + + Args: + in_features (int): The number of input features. + out_features (int): The number of output features. + dropout (float): Dropout probability for attention coefficients. + alpha (float): LeakyReLU slope parameter. + concat (bool): Whether to concatenate the multi-head results or not. + + Attributes: + W (nn.Parameter): Learnable weight matrix. + a (nn.Parameter): Learnable attention parameter matrix. + leakyrelu (nn.LeakyReLU): LeakyReLU activation with slope alpha. 
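[Editor's note] The attention rule this layer implements, e_ij = LeakyReLU(a^T [W h_i || W h_j]) masked by the adjacency and softmax-normalized, can be re-derived standalone. The function below is a sketch of that computation, not the module's own code.

import torch
import torch.nn.functional as F

def gat_attention(h, adj, W, a, alpha=0.2):
    Wh = h @ W                                   # (N, out_features)
    N = Wh.size(0)
    # All pairwise concatenations [Wh_i || Wh_j] via broadcasting.
    pairs = torch.cat(
        [Wh.unsqueeze(1).expand(N, N, -1), Wh.unsqueeze(0).expand(N, N, -1)],
        dim=-1,
    )
    e = F.leaky_relu(pairs @ a, negative_slope=alpha).squeeze(-1)  # (N, N)
    e = e.masked_fill(adj == 0, float("-inf"))   # attend to neighbors only
    return F.softmax(e, dim=1) @ Wh              # aggregated node features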
+ + """ def __init__(self, in_features, out_features, dropout, alpha, concat=True): super(GraphAttentionLayer, self).__init__() @@ -24,6 +40,17 @@ def __init__(self, in_features, out_features, dropout, alpha, concat=True): self.leakyrelu = nn.LeakyReLU(self.alpha) def forward(self, h, adj): + """ + Forward pass for the GAT layer. + + Args: + h (torch.Tensor): Input feature tensor. + adj (torch.Tensor): Adjacency matrix. + + Returns: + torch.Tensor: Output feature tensor. + + """ Wh = torch.mm( h, self.W ) # h.shape: (N, in_features), Wh.shape: (N, out_features) @@ -52,6 +79,13 @@ def _prepare_attentional_mechanism_input(self, Wh): return all_combinations_matrix.view(N, N, 2 * self.out_features) def __repr__(self): + """ + String representation of the GAT layer. + + Returns: + str: A string representing the layer. + + """ return ( self.__class__.__name__ + " (" @@ -63,6 +97,23 @@ def __repr__(self): class GAT(nn.Module): + """ + Graph Attention Network (GAT) model for node classification. + + Args: + feat_dim (int): Number of input features. + hidden_dim1 (int): Number of hidden units in the first GAT layer. + hidden_dim2 (int): Number of hidden units in the second GAT layer. + dropout (float): Dropout probability for attention coefficients. + alpha (float): LeakyReLU slope parameter. + nheads (int): Number of attention heads. + + Attributes: + dropout (float): Dropout probability. + attentions (nn.ModuleList): List of GAT layers with multiple heads. + out_att (GraphAttentionLayer): Final GAT layer. + + """ def __init__(self, feat_dim, hidden_dim1, hidden_dim2, dropout, alpha, nheads): """Dense version of GAT.""" super(GAT, self).__init__() @@ -85,6 +136,17 @@ def __init__(self, feat_dim, hidden_dim1, hidden_dim2, dropout, alpha, nheads): ) def forward(self, x, adj): + """ + Forward pass for the GAT model. + + Args: + x (torch.Tensor): Input feature tensor. + adj (torch.Tensor): Adjacency matrix. + + Returns: + torch.Tensor: Node embeddings. + + """ x = F.dropout(x, self.dropout, training=self.training) x = torch.cat([att(x, adj) for att in self.attentions], dim=1) x = F.dropout(x, self.dropout, training=self.training) @@ -95,7 +157,25 @@ def forward(self, x, adj): class Readout(nn.Module): """ - This module learns a single graph level representation for a molecule given GNN generated node embeddings + This module learns a single graph-level representation for a molecule given GNN-generated node embeddings. + + Args: + attr_dim (int): Dimension of node attributes. + embedding_dim (int): Dimension of node embeddings. + hidden_dim (int): Dimension of the hidden layer. + output_dim (int): Dimension of the output layer. + num_cats (int): Number of categories for classification. + + Attributes: + attr_dim (int): Dimension of node attributes. + hidden_dim (int): Dimension of the hidden layer. + output_dim (int): Dimension of the output layer. + num_cats (int): Number of categories for classification. + layer1 (nn.Linear): First linear layer. + layer2 (nn.Linear): Second linear layer. + output (nn.Linear): Output layer. + act (nn.ReLU): ReLU activation function. + """ def __init__(self, attr_dim, embedding_dim, hidden_dim, output_dim, num_cats): @@ -111,6 +191,17 @@ def __init__(self, attr_dim, embedding_dim, hidden_dim, output_dim, num_cats): self.act = nn.ReLU() def forward(self, node_features, node_embeddings): + """ + Forward pass for the Readout module. + + Args: + node_features (torch.Tensor): Node attributes. + node_embeddings (torch.Tensor): Node embeddings. 
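[Editor's note] As a sketch of the readout idea (initial attributes concatenated with embeddings, an MLP, then logits grouped per category), the snippet below uses a mean pool over nodes as a stand-in for the module's actual pooling; treat it as illustrative rather than the patch's implementation.

import torch

def readout_logits(node_features, node_embeddings, mlp, num_cats):
    combined = torch.cat((node_features, node_embeddings), dim=1)
    per_node = mlp(combined)              # (N, output_dim)
    graph_rep = per_node.mean(dim=0)      # assumed pooling step
    return graph_rep.view(num_cats, -1)   # one block of logits per category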
+ + Returns: + torch.Tensor: Logits for multilabel classification. + + """ combined_rep = torch.cat( (node_features, node_embeddings), dim=1 ) # Concat initial node attributed with embeddings from sage @@ -128,7 +219,23 @@ def forward(self, node_features, node_embeddings): class GatMoleculeNet(nn.Module): """ - Network that consolidates GAT + Readout into a single nn.Module + Neural network that combines GAT (Graph Attention Network) and Readout into a single module for molecular data. + + Args: + feat_dim (int): Dimension of input node features. + gat_hidden_dim1 (int): Dimension of the hidden layer in the GAT model. + node_embedding_dim (int): Dimension of node embeddings. + gat_dropout (float): Dropout probability for GAT layers. + gat_alpha (float): LeakyReLU slope parameter for GAT. + gat_nheads (int): Number of attention heads in GAT. + readout_hidden_dim (int): Dimension of the hidden layer in the Readout module. + graph_embedding_dim (int): Dimension of the graph-level embedding. + num_categories (int): Number of categories for classification. + + Attributes: + gat (GAT): GAT (Graph Attention Network) module. + readout (Readout): Readout module for graph-level representation. + """ def __init__( diff --git a/research/SpreadGNN/model/sage_readout.py b/research/SpreadGNN/model/sage_readout.py index e8df8f7235..659cfcdd41 100644 --- a/research/SpreadGNN/model/sage_readout.py +++ b/research/SpreadGNN/model/sage_readout.py @@ -7,7 +7,27 @@ class GraphSage(nn.Module): GraphSAGE model (https://arxiv.org/abs/1706.02216) to learn the role of atoms in the molecules inductively. Transforms input features into a fixed length embedding in a vector space. The embedding captures the role. """ - + """ + GraphSAGE model to learn the role of atoms in molecules inductively. + + GraphSAGE (Graph Sample and Aggregated) transforms input features into a fixed-length embedding in a vector space. + The resulting embedding captures the role of atoms in molecules. + + Args: + feat_dim (int): Dimension of input node features. + hidden_dim1 (int): Dimension of the first hidden layer. + hidden_dim2 (int): Dimension of the second hidden layer. + dropout (float): Dropout probability. + + Attributes: + feat_dim (int): Dimension of input node features. + hidden_dim1 (int): Dimension of the first hidden layer. + hidden_dim2 (int): Dimension of the second hidden layer. + layer1 (nn.Linear): First linear layer for feature transformation. + layer2 (nn.Linear): Second linear layer for feature transformation. + relu (nn.ReLU): ReLU activation function. + dropout (nn.Dropout): Dropout layer. + """ def __init__(self, feat_dim, hidden_dim1, hidden_dim2, dropout): super(GraphSage, self).__init__() @@ -61,7 +81,27 @@ def forward(self, forest, feature_matrix): class Readout(nn.Module): """ - This module learns a single graph level representation for a molecule given GraphSAGE generated embeddings + This module learns a single graph-level representation for a molecule using GraphSAGE-generated embeddings. + + The Readout module combines node-level features and GraphSAGE-generated embeddings to produce a graph-level representation of a molecule. This representation can be used for various downstream tasks, such as multi-label classification. + + Args: + attr_dim (int): Dimension of initial node attributes. + embedding_dim (int): Dimension of GraphSAGE-generated node embeddings. + hidden_dim (int): Dimension of the hidden layer. + output_dim (int): Dimension of the output layer. 
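[Editor's note] One GraphSAGE aggregation step can be written in a few lines for intuition: average the sampled neighbors' features, concatenate the node's own, and apply a shared linear layer. This toy sage_step is an assumption-level sketch, not the module's forward pass.

import torch
import torch.nn.functional as F

def sage_step(x, neighbor_idx, linear):
    # x: (N, F) node features; neighbor_idx: (N, fanout) sampled neighbor ids.
    neigh_mean = x[neighbor_idx].mean(dim=1)   # (N, F)
    h = torch.cat([x, neigh_mean], dim=1)      # (N, 2F)
    return F.relu(linear(h))                   # linear: nn.Linear(2F, hidden)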
+ num_cats (int): Number of categories for classification. + + Attributes: + attr_dim (int): Dimension of initial node attributes. + hidden_dim (int): Dimension of the hidden layer. + output_dim (int): Dimension of the output layer. + num_cats (int): Number of categories for classification. + layer1 (nn.Linear): First linear layer for feature transformation. + layer2 (nn.Linear): Second linear layer for feature transformation. + output (nn.Linear): Output linear layer. + act (nn.ReLU): ReLU activation function. + """ def __init__(self, attr_dim, embedding_dim, hidden_dim, output_dim, num_cats): @@ -94,7 +134,23 @@ def forward(self, node_features, node_embeddings): class SageMoleculeNet(nn.Module): """ - Network that consolidates Sage + Readout into a single nn.Module + Network that combines Sage (GraphSAGE) and Readout into a single neural network module. + + The SageMoleculeNet module integrates GraphSAGE for node-level embedding generation and a Readout module for graph-level representation of molecules. It is designed for tasks such as multi-label classification on molecular graphs. + + Args: + feat_dim (int): Dimension of node features. + sage_hidden_dim1 (int): Dimension of the first hidden layer in GraphSAGE. + node_embedding_dim (int): Dimension of node embeddings generated by GraphSAGE. + sage_dropout (float): Dropout rate for GraphSAGE. + readout_hidden_dim (int): Dimension of the hidden layer in the Readout module. + graph_embedding_dim (int): Dimension of the graph-level embeddings. + num_categories (int): Number of categories for classification. + + Attributes: + sage (GraphSage): GraphSAGE module for node-level embedding generation. + readout (Readout): Readout module for generating graph-level representations. + """ def __init__( diff --git a/research/SpreadGNN/trainer/gat_readout_trainer.py b/research/SpreadGNN/trainer/gat_readout_trainer.py index 989a102bff..74f8895e37 100755 --- a/research/SpreadGNN/trainer/gat_readout_trainer.py +++ b/research/SpreadGNN/trainer/gat_readout_trainer.py @@ -12,14 +12,46 @@ class GatMoleculeNetTrainer(ClientTrainer): + """ + Trainer for the GatMoleculeNet model. + + This trainer is responsible for training and testing the GatMoleculeNet model on client devices. It implements methods for setting and retrieving model parameters, training the model, and evaluating its performance. + + Args: + model (GatMoleculeNet): The GatMoleculeNet model to be trained. + test_data (list of torch.Tensor): The test data for evaluating model performance. + """ def get_model_params(self): + """ + Get the model parameters. + + Returns: + dict: The model parameters. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (dict): The model parameters to be set. + """ logging.info("set_model_params") self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model. + + Args: + train_data (list of torch.Tensor): The training data. + device (torch.device): The device (CPU or GPU) to use for training. + args: Additional training arguments. + + Returns: + Tuple[float, dict]: A tuple containing the maximum test score and the best model parameters. + """ model = self.model model.to(device) @@ -85,6 +117,17 @@ def train(self, train_data, device, args): return max_test_score, best_model_params def test(self, test_data, device, args): + """ + Test the model. + + Args: + test_data (list of torch.Tensor): The test data. 
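[Editor's note] The client-side contract these trainers implement has the usual FedAvg shape: load the global weights, run local epochs, hand the updated state_dict back to the aggregator. A hypothetical minimal round (a generic model(x) stands in for the GNN-specific forward call):

import torch

def local_round(model, global_state, train_loader, device, epochs=1, lr=1e-3):
    model.load_state_dict(global_state)
    model.to(device).train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.BCEWithLogitsLoss()
    for _ in range(epochs):
        for x, y in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(x.to(device)), y.to(device))
            loss.backward()
            optimizer.step()
    return model.cpu().state_dict()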
+ device (torch.device): The device (CPU or GPU) to use for testing. + args: Additional testing arguments. + + Returns: + Tuple[float, model]: A tuple containing the test score and the model used for testing. + """ logging.info("----------test--------") model = self.model model.eval() @@ -138,6 +181,18 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Test the model on the server. + + Args: + train_data_local_dict (dict): A dictionary of training data for each client. + test_data_local_dict (dict): A dictionary of test data for each client. + device (torch.device): The device (CPU or GPU) to use for testing. + args: Additional testing arguments. + + Returns: + bool: True if testing on the server is successful. + """ logging.info("----------test_on_the_server--------") model_list, score_list = [], [] diff --git a/research/SpreadGNN/trainer/gat_readout_trainer_regression.py b/research/SpreadGNN/trainer/gat_readout_trainer_regression.py index 88b66c5418..703c9bf22e 100755 --- a/research/SpreadGNN/trainer/gat_readout_trainer_regression.py +++ b/research/SpreadGNN/trainer/gat_readout_trainer_regression.py @@ -11,14 +11,46 @@ class GatMoleculeNetTrainer(ClientTrainer): + """ + Trainer for the GatMoleculeNet model. + + This trainer is responsible for training and testing the GatMoleculeNet model on client devices. It implements methods for setting and retrieving model parameters, training the model, and evaluating its performance. + + Args: + model (GatMoleculeNet): The GatMoleculeNet model to be trained. + test_data (list of torch.Tensor): The test data for evaluating model performance. + """ def get_model_params(self): + """ + Get the model parameters. + + Returns: + dict: The model parameters. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (dict): The model parameters to be set. + """ logging.info("set_model_params") self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model. + + Args: + train_data (list of torch.Tensor): The training data. + device (torch.device): The device (CPU or GPU) to use for training. + args: Additional training arguments. + + Returns: + Tuple[float, dict]: A tuple containing the minimum test score and the best model parameters. + """ model = self.model model.to(device) @@ -93,6 +125,17 @@ def train(self, train_data, device, args): return min_score, best_model_params def test(self, test_data, device, args): + """ + Test the model. + + Args: + test_data (list of torch.Tensor): The test data. + device (torch.device): The device (CPU or GPU) to use for testing. + args: Additional testing arguments. + + Returns: + Tuple[float, model]: A tuple containing the test score and the model used for testing. + """ logging.info("----------test--------") model = self.model model.eval() @@ -129,6 +172,18 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Test the model on the server. + + Args: + train_data_local_dict (dict): A dictionary of training data for each client. + test_data_local_dict (dict): A dictionary of test data for each client. + device (torch.device): The device (CPU or GPU) to use for testing. + args: Additional testing arguments. + + Returns: + bool: True if testing on the server is successful. 
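[Editor's note] Because MoleculeNet molecules can have missing task labels, multi-label losses and metrics are typically computed under a mask. A sketch of that idea, assuming mask == 0 means "label unobserved" (semantics assumed, not taken from the patch):

import torch

def masked_bce(logits, labels, mask):
    raw = torch.nn.BCEWithLogitsLoss(reduction="none")(logits, labels)
    # Keep only observed labels, then average over their count.
    return (raw * mask).sum() / mask.sum().clamp(min=1)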
+ """ logging.info("----------test_on_the_server--------") # for client_idx in train_data_local_dict.keys(): # train_data = train_data_local_dict[client_idx] @@ -158,6 +213,16 @@ def test_on_the_server( return True def _compare_models(self, model_1, model_2): + """ + Compare two models to check if they match. + + Args: + model_1 (torch.nn.Module): The first model to compare. + model_2 (torch.nn.Module): The second model to compare. + + Raises: + Exception: If a mismatch is found between the two models. + """ models_differ = 0 for key_item_1, key_item_2 in zip( model_1.state_dict().items(), model_2.state_dict().items() diff --git a/research/SpreadGNN/trainer/sage_readout_trainer.py b/research/SpreadGNN/trainer/sage_readout_trainer.py index c5ec0a3f57..1e3bea9174 100755 --- a/research/SpreadGNN/trainer/sage_readout_trainer.py +++ b/research/SpreadGNN/trainer/sage_readout_trainer.py @@ -12,14 +12,45 @@ class SageMoleculeNetTrainer(ClientTrainer): + """ + Trainer for the MoleculeNet model. This trainer handles training and testing the model on client devices. + + Args: + model (nn.Module): The MoleculeNet model to be trained. + test_data (list): The test data used for evaluating model performance. + """ def get_model_params(self): + """ + Get the model parameters. + + Returns: + dict: The model parameters. + """ + return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (dict): The model parameters to be set. + """ logging.info("set_model_params") self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model. + + Args: + train_data (list): The training data. + device (torch.device): The device (CPU or GPU) to use for training. + args: Additional training arguments. + + Returns: + Tuple[float, dict]: A tuple containing the minimum test score and the best model parameters. + """ model = self.model model.to(device) @@ -87,6 +118,17 @@ def train(self, train_data, device, args): return max_test_score, best_model_params def test(self, test_data, device, args): + """ + Test the model. + + Args: + test_data (list): The test data. + device (torch.device): The device (CPU or GPU) to use for testing. + args: Additional testing arguments. + + Returns: + Tuple[float, model]: A tuple containing the test score and the model used for testing. + """ logging.info("----------test--------") model = self.model model.eval() @@ -138,6 +180,18 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Test the model on the server using data from client devices. + + Args: + train_data_local_dict (dict): A dictionary of training data from client devices. + test_data_local_dict (dict): A dictionary of test data from client devices. + device (torch.device): The device (CPU or GPU) to use for testing. + args: Additional testing arguments. + + Returns: + bool: True if the testing is successful. 
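[Editor's note] _compare_models above walks two state_dicts and raises on the first mismatch; the same check can be phrased as a boolean helper, sketched here for reference.

import torch

def models_match(model_1, model_2):
    return all(
        k1 == k2 and torch.equal(v1, v2)
        for (k1, v1), (k2, v2) in zip(
            model_1.state_dict().items(), model_2.state_dict().items()
        )
    )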
+ """ logging.info("----------test_on_the_server--------") model_list, score_list = [], [] diff --git a/research/SpreadGNN/trainer/sage_readout_trainer_regression.py b/research/SpreadGNN/trainer/sage_readout_trainer_regression.py index b0977b1086..4248427349 100755 --- a/research/SpreadGNN/trainer/sage_readout_trainer_regression.py +++ b/research/SpreadGNN/trainer/sage_readout_trainer_regression.py @@ -12,14 +12,44 @@ class SageMoleculeNetTrainer(ClientTrainer): + """ + Trainer for the MoleculeNet model. This trainer is responsible for training and testing the MoleculeNet model on client devices. + + Args: + model (nn.Module): The MoleculeNet model to be trained. + test_data (list): The test data for evaluating model performance. + """ def get_model_params(self): + """ + Get the model parameters. + + Returns: + dict: The model parameters. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (dict): The model parameters to be set. + """ logging.info("set_model_params") self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model. + + Args: + train_data (list): The training data. + device (torch.device): The device (CPU or GPU) to use for training. + args: Additional training arguments. + + Returns: + Tuple[float, dict]: A tuple containing the minimum test score and the best model parameters. + """ model = self.model model.to(device) @@ -94,6 +124,17 @@ def train(self, train_data, device, args): return min_score, best_model_params def test(self, test_data, device, args): + """ + Test the model. + + Args: + test_data (list): The test data. + device (torch.device): The device (CPU or GPU) to use for testing. + args: Additional testing arguments. + + Returns: + Tuple[float, model]: A tuple containing the test score and the model used for testing. + """ logging.info("----------test--------") model = self.model model.eval() @@ -131,6 +172,18 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Test the model on the server. + + Args: + train_data_local_dict (dict): A dictionary of training data for each client. + test_data_local_dict (dict): A dictionary of test data for each client. + device (torch.device): The device (CPU or GPU) to use for testing. + args: Additional testing arguments. + + Returns: + bool: True if testing on the server is successful. + """ logging.info("----------test_on_the_server--------") # for client_idx in train_data_local_dict.keys(): # train_data = train_data_local_dict[client_idx] From f5c94b8f5732c47b0bb7157e7fd6645a206be98a Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Wed, 6 Sep 2023 13:42:55 +0530 Subject: [PATCH 03/70] additon --- python/fedml/arguments.py | 49 +++++++++++++++ python/fedml/launch_simulation.py | 6 +- python/fedml/runner.py | 92 ++++++++++++++++++++++++++++ python/fedml/simulation/simulator.py | 37 +++++++++++ 4 files changed, 183 insertions(+), 1 deletion(-) diff --git a/python/fedml/arguments.py b/python/fedml/arguments.py index 6459b6c2c4..16d4ff25da 100755 --- a/python/fedml/arguments.py +++ b/python/fedml/arguments.py @@ -34,6 +34,12 @@ def add_args(): + """ + Create and parse command line arguments for FedML. + + Returns: + argparse.Namespace: A namespace containing the parsed arguments. 
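[Editor's note] The pattern implemented by add_args and the Arguments class — argparse supplies --yaml_config_file, then the YAML sections are flattened onto the namespace — condenses to the sketch below; the default path is hypothetical and this is not the module's exact behavior.

import argparse
import yaml

def parse_with_yaml(default_cfg="config.yaml"):
    parser = argparse.ArgumentParser(description="FedML")
    parser.add_argument("--yaml_config_file", type=str, default=default_cfg)
    args, _ = parser.parse_known_args()
    with open(args.yaml_config_file, "r") as stream:
        configuration = yaml.safe_load(stream)
    for _, param_family in configuration.items():  # e.g. common_args, train_args
        for key, val in param_family.items():
            setattr(args, key, val)
    return args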
+ """ parser = argparse.ArgumentParser(description="FedML") parser.add_argument( "--yaml_config_file", @@ -76,6 +82,15 @@ class Arguments: """Argument class which contains all arguments from yaml config and constructs additional arguments""" def __init__(self, cmd_args, training_type=None, comm_backend=None, override_cmd_args=True): + """ + Initialize the Arguments class. + + Args: + cmd_args (argparse.Namespace): Command line arguments. + training_type (str, optional): The training platform type. Defaults to None. + comm_backend (str, optional): The communication backend type. Defaults to None. + override_cmd_args (bool, optional): Whether to override command line arguments. Defaults to True. + """ # set the command line arguments cmd_args_dict = cmd_args.__dict__ for arg_key, arg_val in cmd_args_dict.items(): @@ -87,6 +102,16 @@ def __init__(self, cmd_args, training_type=None, comm_backend=None, override_cmd for arg_key, arg_val in cmd_args_dict.items(): setattr(self, arg_key, arg_val) def load_yaml_config(self, yaml_path): + """ + Load a YAML configuration file. + + Args: + yaml_path (str): Path to the YAML configuration file. + + Returns: + dict: Loaded configuration as a dictionary. + """ + try: with open(yaml_path, "r") as stream: try: @@ -97,6 +122,14 @@ def load_yaml_config(self, yaml_path): return None def get_default_yaml_config(self, cmd_args, training_type=None, comm_backend=None): + """ + Set default YAML configuration based on training type and communication backend. + + Args: + cmd_args (argparse.Namespace): Command line arguments. + training_type (str, optional): The training platform type. Defaults to None. + comm_backend (str, optional): The communication backend type. Defaults to None. + """ if cmd_args.yaml_config_file == "": path_current_file = path.abspath(path.dirname(__file__)) if ( @@ -191,12 +224,28 @@ def get_default_yaml_config(self, cmd_args, training_type=None, comm_backend=Non return configuration def set_attr_from_config(self, configuration): + """ + Set class attributes from a configuration dictionary. + + Args: + configuration (dict): Configuration dictionary. + """ for _, param_family in configuration.items(): for key, val in param_family.items(): setattr(self, key, val) def load_arguments(training_type=None, comm_backend=None): + """ + Load arguments from command line and YAML config file. + + Args: + training_type (str, optional): The training platform type. Defaults to None. + comm_backend (str, optional): The communication backend type. Defaults to None. + + Returns: + argparse.Namespace: Parsed arguments. + """ cmd_args = add_args() # Load all arguments from YAML config file args = Arguments(cmd_args, training_type, comm_backend) diff --git a/python/fedml/launch_simulation.py b/python/fedml/launch_simulation.py index 37335a2753..b8ca2cdfdf 100644 --- a/python/fedml/launch_simulation.py +++ b/python/fedml/launch_simulation.py @@ -7,8 +7,12 @@ def run_simulation(backend=FEDML_SIMULATION_TYPE_SP): + """ + Run a simulation of the FedML Parrot. - """FedML Parrot""" + Args: + backend (str): The communication backend to use for the simulation. Defaults to FEDML_SIMULATION_TYPE_SP. + """ fedml._global_training_type = FEDML_TRAINING_PLATFORM_SIMULATION fedml._global_comm_backend = backend diff --git a/python/fedml/runner.py b/python/fedml/runner.py index d536bab3cc..214bd8d659 100644 --- a/python/fedml/runner.py +++ b/python/fedml/runner.py @@ -17,6 +17,18 @@ class FedMLRunner: + """ + The main runner for different Federated Learning scenarios. 
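[Editor's note] FedMLRunner's job is essentially a two-level dispatch: training platform first, then role/backend inside each _init_* helper. A compressed sketch of the outer dispatch, with a hypothetical builders mapping:

def init_runner(args, device, dataset, model, builders):
    # builders: e.g. {"simulation": _init_simulation_runner, "cross_silo": _init_cross_silo_runner}
    try:
        build = builders[args.training_type]
    except KeyError:
        raise Exception("no such training type: {}".format(args.training_type))
    return build(args, device, dataset, model)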
+ + Args: + args: The command line arguments. + device: The device (CPU or GPU) to use for training. + dataset: The dataset used for training. + model: The model to be trained. + client_trainer (ClientTrainer, optional): The client trainer for training clients. Defaults to None. + server_aggregator (ServerAggregator, optional): The server aggregator for aggregating client updates. Defaults to None. + algorithm_flow (FedMLAlgorithmFlow, optional): The pre-defined algorithm flow. Defaults to None. + """ def __init__( self, args, @@ -55,6 +67,20 @@ def __init__( def _init_simulation_runner( self, args, device, dataset, model, client_trainer=None, server_aggregator=None ): + """ + Initialize the runner for the simulation-based Federated Learning. + + Args: + args: The command line arguments. + device: The device (CPU or GPU) to use for training. + dataset: The dataset used for training. + model: The model to be trained. + client_trainer (ClientTrainer, optional): The client trainer for training clients. Defaults to None. + server_aggregator (ServerAggregator, optional): The server aggregator for aggregating client updates. Defaults to None. + + Returns: + runner: The initialized simulation-based runner. + """ if hasattr(args, "backend") and args.backend == FEDML_SIMULATION_TYPE_SP: from .simulation.simulator import SimulatorSingleProcess @@ -81,6 +107,20 @@ def _init_simulation_runner( def _init_cross_silo_runner( self, args, device, dataset, model, client_trainer=None, server_aggregator=None ): + """ + Initialize the runner for the cross-silo Federated Learning. + + Args: + args: The command line arguments. + device: The device (CPU or GPU) to use for training. + dataset: The dataset used for training. + model: The model to be trained. + client_trainer (ClientTrainer, optional): The client trainer for training clients. Defaults to None. + server_aggregator (ServerAggregator, optional): The server aggregator for aggregating client updates. Defaults to None. + + Returns: + runner: The initialized cross-silo runner. + """ if args.scenario == "horizontal": if args.role == "client": from .cross_silo import Client @@ -118,6 +158,20 @@ def _init_cross_silo_runner( def _init_cheetah_runner( self, args, device, dataset, model, client_trainer=None, server_aggregator=None ): + """ + Initialize the runner for the Cheetah Federated Learning. + + Args: + args: The command line arguments. + device: The device (CPU or GPU) to use for training. + dataset: The dataset used for training. + model: The model to be trained. + client_trainer (ClientTrainer, optional): The client trainer for training clients. Defaults to None. + server_aggregator (ServerAggregator, optional): The server aggregator for aggregating client updates. Defaults to None. + + Returns: + runner: The initialized Cheetah runner. + """ if args.role == "client": from .cheetah import Client @@ -137,6 +191,20 @@ def _init_cheetah_runner( def _init_model_serving_runner( self, args, device, dataset, model, client_trainer=None, server_aggregator=None ): + """ + Initialize the runner for the model serving Federated Learning. + + Args: + args: The command line arguments. + device: The device (CPU or GPU) to use for training. + dataset: The dataset used for training. + model: The model to be trained. + client_trainer (ClientTrainer, optional): The client trainer for training clients. Defaults to None. + server_aggregator (ServerAggregator, optional): The server aggregator for aggregating client updates. Defaults to None. 
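[Editor's note] Each _init_* helper then branches on args.role in the same way; the shared shape is sketched below, with Client and Server as stand-ins for whichever pair the scenario imports.

def build_by_role(args, device, dataset, model, Client, Server):
    if args.role == "client":
        return Client(args, device, dataset, model)
    if args.role == "server":
        return Server(args, device, dataset, model)
    raise Exception('role must be "client" or "server", got {}'.format(args.role))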
+ + Returns: + runner: The initialized model serving runner. + """ if args.role == "client": from .serving import Client @@ -156,6 +224,20 @@ def _init_model_serving_runner( def _init_cross_device_runner( self, args, device, dataset, model, client_trainer=None, server_aggregator=None ): + """ + Initialize the runner for the cross-device Federated Learning. + + Args: + args: The command line arguments. + device: The device (CPU or GPU) to use for training. + dataset: The dataset used for training. + model: The model to be trained. + client_trainer (ClientTrainer, optional): The client trainer for training clients. Defaults to None. + server_aggregator (ServerAggregator, optional): The server aggregator for aggregating client updates. Defaults to None. + + Returns: + runner: The initialized cross-device runner. + """ if args.role == "server": from .cross_device import ServerMNN @@ -170,6 +252,11 @@ def _init_cross_device_runner( @staticmethod def log_runner_result(): + """ + Log the result of the runner to a file. + + This method creates a log file containing the process ID and saves it to the "fedml_trace" directory. + """ log_runner_result_dir = os.path.join(expanduser("~"), "fedml_trace") if not os.path.exists(log_runner_result_dir): os.makedirs(log_runner_result_dir, exist_ok=True) @@ -179,6 +266,11 @@ def log_runner_result(): log_file_obj.close() def run(self): + """ + Run the initialized Federated Learning runner. + + This method executes the Federated Learning process using the selected runner. + """ self.runner.run() FedMLRunner.log_runner_result() diff --git a/python/fedml/simulation/simulator.py b/python/fedml/simulation/simulator.py index abf0394869..3740e60eba 100644 --- a/python/fedml/simulation/simulator.py +++ b/python/fedml/simulation/simulator.py @@ -26,6 +26,17 @@ class SimulatorSingleProcess: def __init__(self, args, device, dataset, model, client_trainer=None, server_aggregator=None): + """ + Initialize the SimulatorSingleProcess. + + Args: + args (argparse.Namespace): Parsed command-line arguments. + device (torch.device): The device to run simulations on. + dataset (object): The dataset used for training. + model (nn.Module): The machine learning model. + client_trainer (ClientTrainer, optional): The client trainer to use. Defaults to None. + server_aggregator (ServerAggregator, optional): The server aggregator to use. Defaults to None. + """ from .sp.classical_vertical_fl.vfl_api import VflFedAvgAPI from .sp.fedavg import FedAvgAPI from .sp.fedprox.fedprox_trainer import FedProxTrainer @@ -64,6 +75,9 @@ def __init__(self, args, device, dataset, model, client_trainer=None, server_agg raise Exception("Exception") def run(self): + """ + Run the federated training simulation. + """ self.fl_trainer.train() @@ -77,6 +91,18 @@ def __init__( client_trainer: ClientTrainer = None, server_aggregator: ServerAggregator = None, ): + """ + Initialize the SimulatorMPI. + + Args: + args (argparse.Namespace): Parsed command-line arguments. + device (torch.device): The device to run simulations on. + dataset (object): The dataset used for training. + model (nn.Module): The machine learning model. + client_trainer (ClientTrainer, optional): The client trainer to use. Defaults to None. + server_aggregator (ServerAggregator, optional): The server aggregator to use. Defaults to None. 
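[Editor's note] The single-process simulator's constructor reduces to a lookup from args.federated_optimizer to an algorithm API; sketched here with a hypothetical registry, mirroring the raise Exception("Exception") fallthrough seen above.

def pick_fl_trainer(args, device, dataset, model, registry):
    # registry: e.g. {"FedAvg": FedAvgAPI, "FedProx": FedProxTrainer}
    if args.federated_optimizer not in registry:
        raise Exception("Exception")
    return registry[args.federated_optimizer](args, device, dataset, model)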
+ """ + # Import various trainer classes based on the selected federated optimizer from .mpi.base_framework.algorithm_api import FedML_Base_distributed from .mpi.decentralized_framework.algorithm_api import FedML_Decentralized_Demo_distributed from .mpi.fedavg.FedAvgAPI import FedML_FedAvg_distributed @@ -217,6 +243,17 @@ def run(self): class SimulatorNCCL: def __init__(self, args, device, dataset, model, client_trainer=None, server_aggregator=None): + """ + Initialize the SimulatorNCCL. + + Args: + args (argparse.Namespace): Parsed command-line arguments. + device (torch.device): The device to run simulations on. + dataset (object): The dataset used for training. + model (nn.Module): The machine learning model. + client_trainer (ClientTrainer, optional): The client trainer to use. Defaults to None. + server_aggregator (ServerAggregator, optional): The server aggregator to use. Defaults to None. + """ from .nccl.fedavg.FedAvgAPI import FedML_FedAvg_NCCL if args.federated_optimizer == "FedAvg": From 012fd9cdf9d555b7bf1b6b8286f208daaa9b9d7e Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Wed, 6 Sep 2023 15:22:09 +0530 Subject: [PATCH 04/70] `fedml\simulation\sp\3 folder` update --- .../sp/classical_vertical_fl/client.py | 45 +++++++++ .../sp/classical_vertical_fl/party_models.py | 49 ++++++++++ .../sp/classical_vertical_fl/vfl.py | 69 +++++++++++++ .../sp/classical_vertical_fl/vfl_api.py | 66 ++++++++++++- .../sp/classical_vertical_fl/vfl_fixture.py | 33 +++++++ .../sp/decentralized/client_dsgd.py | 98 +++++++++++++++++++ .../sp/decentralized/client_pushsum.py | 53 ++++++++++ .../sp/decentralized/decentralized_fl_api.py | 25 +++++ .../sp/decentralized/topology_manager.py | 53 ++++++++++ python/fedml/simulation/sp/fedavg/client.py | 65 ++++++++++++ .../fedml/simulation/sp/fedavg/fedavg_api.py | 55 ++++++++++- 11 files changed, 607 insertions(+), 4 deletions(-) diff --git a/python/fedml/simulation/sp/classical_vertical_fl/client.py b/python/fedml/simulation/sp/classical_vertical_fl/client.py index 1f141efe98..9d096b08fb 100644 --- a/python/fedml/simulation/sp/classical_vertical_fl/client.py +++ b/python/fedml/simulation/sp/classical_vertical_fl/client.py @@ -2,6 +2,18 @@ class Client: def __init__( self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model_trainer, ): + """ + Initialize a federated learning client. + + Args: + client_idx (int): Index of the client. + local_training_data (dataset): Local training dataset for the client. + local_test_data (dataset): Local test dataset for the client. + local_sample_number (int): Number of samples in the local dataset. + args (argparse.Namespace): Parsed command-line arguments. + device (torch.device): The device to run training and inference on. + model_trainer (ModelTrainer): Trainer for the client's machine learning model. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -12,21 +24,54 @@ def __init__( self.model_trainer = model_trainer def update_local_dataset(self, client_idx, local_training_data, local_test_data, local_sample_number): + """ + Update the local dataset and client index. + + Args: + client_idx (int): New index of the client. + local_training_data (dataset): New local training dataset for the client. + local_test_data (dataset): New local test dataset for the client. + local_sample_number (int): New number of samples in the local dataset. 
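[Editor's note] update_local_dataset exists so that a fixed pool of Client objects can be re-pointed at different sampled clients each round instead of being reallocated; an illustrative round loop using the methods defined in this file (data_by_client is a hypothetical per-client store):

def run_round(clients, sampled_idxs, data_by_client, w_global):
    w_locals = []
    for client, idx in zip(clients, sampled_idxs):
        train, test, n = data_by_client[idx]
        client.update_local_dataset(idx, train, test, n)
        w_locals.append((client.get_sample_number(), client.train(w_global)))
    return w_locals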
+ """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data self.local_sample_number = local_sample_number def get_sample_number(self): + """ + Get the number of samples in the local dataset. + + Returns: + int: Number of samples in the local dataset. + """ return self.local_sample_number def train(self, w_global): + """ + Train the client's machine learning model using global model parameters. + + Args: + w_global (list): Global model parameters. + + Returns: + list: Updated model parameters after training. + """ self.model_trainer.set_model_params(w_global) self.model_trainer.train(self.local_training_data, self.device, self.args) weights = self.model_trainer.get_model_params() return weights def local_test(self, b_use_test_dataset): + """ + Perform local testing on the client's machine learning model. + + Args: + b_use_test_dataset (bool): Whether to use the test dataset for testing. + + Returns: + dict: Metrics obtained from local testing. + """ if b_use_test_dataset: test_data = self.local_test_data else: diff --git a/python/fedml/simulation/sp/classical_vertical_fl/party_models.py b/python/fedml/simulation/sp/classical_vertical_fl/party_models.py index 5fa8237cfa..8846022829 100644 --- a/python/fedml/simulation/sp/classical_vertical_fl/party_models.py +++ b/python/fedml/simulation/sp/classical_vertical_fl/party_models.py @@ -11,6 +11,12 @@ def sigmoid(x): class VFLGuestModel(object): def __init__(self, local_model): + """ + Initialize a VFL guest model. + + Args: + local_model (torch.nn.Module): Local machine learning model. + """ super(VFLGuestModel, self).__init__() self.localModel = local_model self.feature_dim = local_model.get_output_dim() @@ -24,14 +30,29 @@ def __init__(self, local_model): self.y = None def set_dense_model(self, dense_model): + """ + Set the dense model for the guest model. + + Args: + dense_model: New dense model to set. + """ self.dense_model = dense_model def set_batch(self, X, y, global_step): + """ + Set the batch data and global step for training. + + Args: + X: Input data for training. + y: Target labels for training. + global_step: Current global step in training. + """ self.X = X self.y = y self.current_global_step = global_step def _fit(self, X, y): + self.temp_K_Z = self.localModel.forward(X) self.K_U = self.dense_model.forward(self.temp_K_Z) @@ -39,6 +60,16 @@ def _fit(self, X, y): self._update_models(X, y) def predict(self, X, component_list): + """ + Predict using the guest model. + + Args: + X: Input data for prediction. + component_list: List of components to consider in the prediction. + + Returns: + Predicted values. + """ temp_K_Z = self.localModel.forward(X) U = self.dense_model.forward(temp_K_Z) for comp in component_list: @@ -46,6 +77,12 @@ def predict(self, X, component_list): return sigmoid(np.sum(U, axis=1)) def receive_components(self, component_list): + """ + Receive and store components from other parties. + + Args: + component_list: List of components to receive and store. + """ for party_component in component_list: self.parties_grad_component_list.append(party_component) @@ -67,6 +104,12 @@ def _compute_common_gradient_and_loss(self, y): self.loss = class_loss.item() def send_gradients(self): + """ + Send gradients to other parties. + + Returns: + Gradients to send. 
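[Editor's note] The guest's prediction rule above is a sum of per-party partial logits pushed through a sigmoid; a NumPy sketch of that composition (array shapes assumed, not taken from the patch):

import numpy as np

def vfl_predict(guest_U, party_components):
    U = guest_U + sum(party_components)  # element-wise sum of partial logits
    return 1.0 / (1.0 + np.exp(-np.sum(U, axis=1)))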
+ """ return self.top_grads def _update_models(self, X, y): @@ -74,6 +117,12 @@ def _update_models(self, X, y): self.localModel.backward(X, back_grad) def get_loss(self): + """ + Get the loss value of the guest model. + + Returns: + Loss value. + """ return self.loss diff --git a/python/fedml/simulation/sp/classical_vertical_fl/vfl.py b/python/fedml/simulation/sp/classical_vertical_fl/vfl.py index dd421d32db..023377fdc2 100644 --- a/python/fedml/simulation/sp/classical_vertical_fl/vfl.py +++ b/python/fedml/simulation/sp/classical_vertical_fl/vfl.py @@ -1,5 +1,33 @@ class VerticalMultiplePartyLogisticRegressionFederatedLearning(object): + """ + Federated Learning class for logistic regression with multiple parties. + + Args: + party_A (VFLGuestModel): The party with labels (party A). + main_party_id (str, optional): The ID of the main party. Defaults to "_main". + + Methods: + set_debug(is_debug): + Set the debug mode for the federated learning. + get_main_party_id(): + Get the ID of the main party. + add_party(id, party_model): + Add a party to the federated learning. + + Attributes: + main_party_id (str): The ID of the main party. + party_a (VFLGuestModel): The party with labels (party A). + party_dict (dict): A dictionary to store other parties without labels. + is_debug (bool): Flag to enable or disable debug mode. + """ def __init__(self, party_A, main_party_id="_main"): + """ + Initialize the VerticalMultiplePartyLogisticRegressionFederatedLearning. + + Args: + party_A (VFLGuestModel): The party with labels (party A). + main_party_id (str, optional): The ID of the main party. Defaults to "_main". + """ super(VerticalMultiplePartyLogisticRegressionFederatedLearning, self).__init__() self.main_party_id = main_party_id # party A is the parity with labels @@ -9,15 +37,46 @@ def __init__(self, party_A, main_party_id="_main"): self.is_debug = False def set_debug(self, is_debug): + """ + Set the debug mode for the federated learning. + + Args: + is_debug (bool): True to enable debug mode, False to disable. + """ self.is_debug = is_debug def get_main_party_id(self): + """ + Get the ID of the main party. + + Returns: + str: The ID of the main party. + """ return self.main_party_id def add_party(self, *, id, party_model): + """ + Add a party to the federated learning. + + Args: + id (str): The ID of the party. + party_model: The model associated with the party. + """ self.party_dict[id] = party_model def fit(self, X_A, y, party_X_dict, global_step): + """ + Perform the federated learning training. + + Args: + X_A: The batch data for party A (with labels). + y: The labels for party A. + party_X_dict (dict): A dictionary of batch data for other parties. + global_step: The global training step. + + Returns: + float: The loss after training. + """ if self.is_debug: print("==> start fit") @@ -54,6 +113,16 @@ def fit(self, X_A, y, party_X_dict, global_step): return loss def predict(self, X_A, party_X_dict): + """ + Perform predictions using the federated learning model. + + Args: + X_A: The input data for party A (with labels). + party_X_dict (dict): A dictionary of input data for other parties. + + Returns: + array: Predicted labels. 
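[Editor's note] A full coordinated step, as orchestrated by fit(), roughly follows the message pattern below. The party-side method names (set_batch, send_components, receive_gradients) are assumed for illustration and may not match the host classes exactly.

def vfl_step(guest, parties, X_A, y, party_X_dict, global_step):
    guest.set_batch(X_A, y, global_step)
    components = []
    for pid, party in parties.items():
        party.set_batch(party_X_dict[pid], global_step)  # assumed API
        components.append(party.send_components())       # assumed API
    # Guest computes loss/gradients internally once all components arrive.
    guest.receive_components(components)
    grads = guest.send_gradients()
    for party in parties.values():
        party.receive_gradients(grads)                   # assumed API
    return guest.get_loss()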
+ """ comp_list = [] for id, party_X in party_X_dict.items(): comp_list.append(self.party_dict[id].predict(party_X)) diff --git a/python/fedml/simulation/sp/classical_vertical_fl/vfl_api.py b/python/fedml/simulation/sp/classical_vertical_fl/vfl_api.py index 518612f5f8..9a910d46e1 100644 --- a/python/fedml/simulation/sp/classical_vertical_fl/vfl_api.py +++ b/python/fedml/simulation/sp/classical_vertical_fl/vfl_api.py @@ -11,6 +11,15 @@ class VflFedAvgAPI(object): + """ + Federated Learning using the FedAvg algorithm. + + Args: + args (Namespace): Command-line arguments and settings. + device (str): The device (e.g., 'cpu', 'cuda') for model training. + dataset (tuple): A tuple containing dataset information. + model (torch.nn.Module): The machine learning model used for federated learning. + """ def __init__(self, args, device, dataset, model): self.device = device self.args = args @@ -46,6 +55,15 @@ def __init__(self, args, device, dataset, model): def _setup_clients( self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict, model_trainer, ): + """ + Set up client instances for federated learning. + + Args: + train_data_local_num_dict (dict): A dictionary mapping client indexes to the number of local training samples. + train_data_local_dict (dict): A dictionary mapping client indexes to local training data. + test_data_local_dict (dict): A dictionary mapping client indexes to local test data. + model_trainer (ModelTrainer): The model trainer used for local client training. + """ logging.info("############setup_clients (START)#############") for client_idx in range(self.args.client_num_per_round): c = Client( @@ -61,6 +79,9 @@ def _setup_clients( logging.info("############setup_clients (END)#############") def train(self): + """ + Perform federated learning using the FedAvg algorithm. + """ logging.info("self.model_trainer = {}".format(self.model_trainer)) w_global = self.model_trainer.get_model_params() for round_idx in range(self.args.comm_round): @@ -109,6 +130,17 @@ def train(self): self._local_test_on_all_clients(round_idx) def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Randomly sample a subset of clients for the federated learning round. + + Args: + round_idx (int): The current round index. + client_num_in_total (int): The total number of clients in the dataset. + client_num_per_round (int): The number of clients to sample per round. + + Returns: + list: List of client indexes for the current round. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -119,6 +151,12 @@ def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation dataset subset for testing. + + Args: + num_samples (int): The number of samples to include in the validation set. + """ test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) @@ -126,6 +164,15 @@ def _generate_validation_set(self, num_samples=10000): self.val_global = sample_testset def _aggregate(self, w_locals): + """ + Aggregate model weights from all clients using Federated Averaging (FedAvg). + + Args: + w_locals (list): List of local model weights and sample numbers for each client. 
+ + Returns: + dict: Averaged global model weights. + """ training_num = 0 for idx in range(len(w_locals)): (sample_num, averaged_params) = w_locals[idx] @@ -144,10 +191,13 @@ def _aggregate(self, w_locals): def _aggregate_noniid_avg(self, w_locals): """ - The old aggregate method will impact the model performance when it comes to Non-IID setting + Aggregate model weights from all clients using non-IID averaging. + Args: - w_locals: + w_locals (list): List of local model weights for each client. + Returns: + dict: Averaged global model weights. """ (_, averaged_params) = w_locals[0] for k in averaged_params.keys(): @@ -158,6 +208,12 @@ def _aggregate_noniid_avg(self, w_locals): return averaged_params def _local_test_on_all_clients(self, round_idx): + """ + Perform local testing on all clients and log the results. + + Args: + round_idx (int): The current round index. + """ logging.info("################local_test_on_all_clients : {}".format(round_idx)) @@ -213,6 +269,12 @@ def _local_test_on_all_clients(self, round_idx): logging.info(stats) def _local_test_on_validation_set(self, round_idx): + """ + Perform local testing on a validation set and log the results. + + Args: + round_idx (int): The current round index. + """ logging.info("################local_test_on_validation_set : {}".format(round_idx)) diff --git a/python/fedml/simulation/sp/classical_vertical_fl/vfl_fixture.py b/python/fedml/simulation/sp/classical_vertical_fl/vfl_fixture.py index 080a76ca52..b88c8c8b53 100644 --- a/python/fedml/simulation/sp/classical_vertical_fl/vfl_fixture.py +++ b/python/fedml/simulation/sp/classical_vertical_fl/vfl_fixture.py @@ -6,6 +6,20 @@ def compute_correct_prediction(*, y_targets, y_prob_preds, threshold=0.5): + """ + Compute correct predictions and counts based on probability predictions and threshold. + + Args: + y_targets (array-like): True labels. + y_prob_preds (array-like): Predicted probabilities. + threshold (float, optional): Threshold for binary classification. Defaults to 0.5. + + Returns: + Tuple: + - y_hat_lbls (numpy.ndarray): Predicted labels (0 or 1). + - [pred_pos_count, pred_neg_count, correct_count] (list): Counts of predicted positive, + predicted negative, and correct predictions. + """ y_hat_lbls = [] pred_pos_count = 0 pred_neg_count = 0 @@ -25,12 +39,31 @@ def compute_correct_prediction(*, y_targets, y_prob_preds, threshold=0.5): class FederatedLearningFixture(object): + """ + Fixture for performing federated learning with a specified model. + """ def __init__( self, federated_learning: VerticalMultiplePartyLogisticRegressionFederatedLearning, ): + """ + Initialize a Federated Learning Fixture. + + Args: + federated_learning (VerticalMultiplePartyLogisticRegressionFederatedLearning): + The federated learning instance to be used. + """ self.federated_learning = federated_learning def fit(self, train_data, test_data, epochs=50, batch_size=-1): + """ + Fit the federated learning model on the provided data. + + Args: + train_data (dict): Training data containing X and Y for each party. + test_data (dict): Testing data containing X and Y for each party. + epochs (int, optional): Number of training epochs. Defaults to 50. + batch_size (int, optional): Batch size for training. Defaults to -1 (no batching). 
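+
+ Illustrative usage (a sketch; the data dictionaries are assumed to be
+ keyed by party id with "X"/"Y" entries as described above):
+ >>> fixture = FederatedLearningFixture(federated_learning)
+ >>> fixture.fit(train_data, test_data, epochs=10, batch_size=256)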
+ """ main_party_id = self.federated_learning.get_main_party_id() Xa_train = train_data[main_party_id]["X"] diff --git a/python/fedml/simulation/sp/decentralized/client_dsgd.py b/python/fedml/simulation/sp/decentralized/client_dsgd.py index f9891a9273..fd9cd28a1b 100644 --- a/python/fedml/simulation/sp/decentralized/client_dsgd.py +++ b/python/fedml/simulation/sp/decentralized/client_dsgd.py @@ -4,6 +4,53 @@ class ClientDSGD(object): + """ + Client for Distributed Stochastic Gradient Descent (DSGD). + + Args: + model: The machine learning model used by the client. + model_cache: The model cache used for temporary values. + client_id (int): The unique identifier of the client. + streaming_data (list): Streaming data for training. + topology_manager: The manager for defining communication topology. + iteration_number (int): The total number of iterations. + learning_rate (float): The learning rate for gradient descent. + batch_size (int): The batch size for training. + weight_decay (float): The weight decay for regularization. + latency (float): The communication latency. + b_symmetric (bool): Flag for symmetric or asymmetric communication topology. + + Methods: + train_local(iteration_id): + Train the client's model on local data for a specified iteration. + train(iteration_id): + Train the client's model on streaming data for a specified iteration. + get_regret(): + Get the regret (loss) for each iteration. + send_local_gradient_to_neighbor(client_list): + Send local gradients to neighboring clients. + receive_neighbor_gradients(client_id, model_x, topo_weight): + Receive gradients from a neighboring client. + update_local_parameters(): + Update local model parameters based on received gradients. + + Attributes: + model: The machine learning model used by the client. + b_symmetric (bool): Flag for symmetric or asymmetric communication topology. + topology_manager: The manager for defining communication topology. + id (int): The unique identifier of the client. + streaming_data (list): Streaming data for training. + optimizer: The optimizer for training the model. + criterion: The loss criterion used for training. + learning_rate (float): The learning rate for gradient descent. + iteration_number (int): The total number of iterations. + latency (float): The communication latency. + batch_size (int): The batch size for training. + loss_in_each_iteration (list): List to store loss for each iteration. + model_x: The model cache for temporary values. + neighbors_weight_dict (dict): Dictionary to store neighboring client weights. + neighbors_topo_weight_dict (dict): Dictionary to store neighboring client topology weights. + """ def __init__( self, model, @@ -18,6 +65,22 @@ def __init__( latency, b_symmetric, ): + """ + Initialize the ClientDSGD object. + + Args: + model: The machine learning model used by the client. + model_cache: The model cache used for temporary values. + client_id (int): The unique identifier of the client. + streaming_data (list): Streaming data for training. + topology_manager: The manager for defining communication topology. + iteration_number (int): The total number of iterations. + learning_rate (float): The learning rate for gradient descent. + batch_size (int): The batch size for training. + weight_decay (float): The weight decay for regularization. + latency (float): The communication latency. + b_symmetric (bool): Flag for symmetric or asymmetric communication topology. 
+ """ # logging.info("streaming_data = %s" % streaming_data) # Since we use logistic regression, the model size is small. @@ -56,6 +119,12 @@ def __init__( self.neighbors_topo_weight_dict = dict() def train_local(self, iteration_id): + """ + Train the client's model on local data for a specified iteration. + + Args: + iteration_id (int): The current iteration. + """ self.optimizer.zero_grad() train_x = torch.from_numpy(self.streaming_data[iteration_id]["x"]) train_y = torch.FloatTensor([self.streaming_data[iteration_id]["y"]]) @@ -66,6 +135,12 @@ def train_local(self, iteration_id): self.loss_in_each_iteration.append(loss) def train(self, iteration_id): + """ + Train the client's model on streaming data for a specified iteration. + + Args: + iteration_id (int): The current iteration. + """ self.optimizer.zero_grad() if iteration_id >= self.iteration_number: @@ -86,10 +161,22 @@ def train(self, iteration_id): self.loss_in_each_iteration.append(loss) def get_regret(self): + """ + Get the regret (loss) for each iteration. + + Returns: + list: A list containing the loss for each iteration. + """ return self.loss_in_each_iteration # simulation def send_local_gradient_to_neighbor(self, client_list): + """ + Send local gradients to neighboring clients for simulation. + + Args: + client_list (list): List of client objects representing neighbors. + """ for index in range(len(self.topology)): if self.topology[index] != 0 and index != self.id: client = client_list[index] @@ -98,10 +185,21 @@ def send_local_gradient_to_neighbor(self, client_list): ) def receive_neighbor_gradients(self, client_id, model_x, topo_weight): + """ + Receive gradients from a neighboring client for simulation. + + Args: + client_id (int): The identifier of the neighboring client. + model_x: Model parameters from the neighboring client. + topo_weight (float): Topology weight associated with the neighboring client. + """ self.neighbors_weight_dict[client_id] = model_x self.neighbors_topo_weight_dict[client_id] = topo_weight def update_local_parameters(self): + """ + Update local model parameters based on received gradients. + """ # update x_{t+1/2} for x_paras in self.model_x.parameters(): x_paras.data.mul_(self.topology[self.id]) diff --git a/python/fedml/simulation/sp/decentralized/client_pushsum.py b/python/fedml/simulation/sp/decentralized/client_pushsum.py index 08da9bccf8..05e9dbc808 100644 --- a/python/fedml/simulation/sp/decentralized/client_pushsum.py +++ b/python/fedml/simulation/sp/decentralized/client_pushsum.py @@ -20,6 +20,23 @@ def __init__( b_symmetric, time_varying, ): + """ + Initialize a ClientPushsum instance. + + Args: + model: The client's model. + model_cache: Cache for the model parameters. + client_id (int): Identifier for the client. + streaming_data: Streaming data for training. + topology_manager: Topology manager for network topology. + iteration_number (int): Number of iterations. + learning_rate (float): Learning rate for optimization. + batch_size (int): Batch size for training. + weight_decay (float): Weight decay for optimization. + latency (float): Latency in communication. + b_symmetric (bool): Whether the topology is symmetric. + time_varying (bool): Whether the topology is time-varying. + """ # logging.info("streaming_data = %s" % streaming_data) # Since we use logistic regression, the model size is small. @@ -60,6 +77,12 @@ def __init__( self.neighbors_topo_weight_dict = dict() def train_local(self, iteration_id): + """ + Train the client's model using local data for a specific iteration. 
+ + Args: + iteration_id (int): The iteration index. + """ self.optimizer.zero_grad() train_x = torch.from_numpy(self.streaming_data[iteration_id]["x"]) train_y = torch.FloatTensor([self.streaming_data[iteration_id]["y"]]) @@ -70,6 +93,12 @@ def train_local(self, iteration_id): self.loss_in_each_iteration.append(loss) def train(self, iteration_id): + """ + Train the client's model using data for a specific iteration. + + Args: + iteration_id (int): The iteration index. + """ self.optimizer.zero_grad() if iteration_id >= self.iteration_number: @@ -105,10 +134,22 @@ def train(self, iteration_id): self.loss_in_each_iteration.append(loss) def get_regret(self): + """ + Get the regret (loss) for each iteration. + + Returns: + list: A list containing the loss for each iteration. + """ return self.loss_in_each_iteration # simulation def send_local_gradient_to_neighbor(self, client_list): + """ + Send local gradients to neighboring clients for simulation. + + Args: + client_list (list): List of client objects representing neighbors. + """ for index in range(len(self.topology)): if self.topology[index] != 0 and index != self.id: client = client_list[index] @@ -120,11 +161,23 @@ def send_local_gradient_to_neighbor(self, client_list): ) def receive_neighbor_gradients(self, client_id, model_x, topo_weight, omega): + """ + Receive gradients from a neighboring client for simulation. + + Args: + client_id (int): The identifier of the neighboring client. + model_x: Model parameters from the neighboring client. + topo_weight (float): Topology weight associated with the neighboring client. + omega (float): Omega value for push-sum. + """ self.neighbors_weight_dict[client_id] = model_x self.neighbors_topo_weight_dict[client_id] = topo_weight self.neighbors_omega_dict[client_id] = omega def update_local_parameters(self): + """ + Update local model parameters and omega based on received gradients. + """ # update x_{t+1/2} for x_paras in self.model_x.parameters(): x_paras.data.mul_(self.topology[self.id]) diff --git a/python/fedml/simulation/sp/decentralized/decentralized_fl_api.py b/python/fedml/simulation/sp/decentralized/decentralized_fl_api.py index 9ced125465..90cdadb53c 100644 --- a/python/fedml/simulation/sp/decentralized/decentralized_fl_api.py +++ b/python/fedml/simulation/sp/decentralized/decentralized_fl_api.py @@ -9,6 +9,17 @@ def cal_regret(client_list, client_number, t): + """ + Calculate the average regret across all clients. + + Args: + client_list (list): List of client objects. + client_number (int): Total number of clients. + t (int): Current iteration. + + Returns: + float: Average regret across all clients. + """ regret = 0 for client in client_list: regret += np.sum(client.get_regret()) @@ -20,6 +31,20 @@ def cal_regret(client_list, client_number, t): def FedML_decentralized_fl( client_number, client_id_list, streaming_data, model, model_cache, args ): + """ + Run decentralized federated learning with the specified configuration. + + Args: + client_number (int): Total number of clients. + client_id_list (list): List of client IDs. + streaming_data (list): List of streaming data for each client. + model: The federated learning model. + model_cache: Model cache for each client. + args: Additional arguments for configuration. 
+ + Returns: + None + """ iteration_number_T = args.iteration_number lr_rate = args.learning_rate batch_size = args.batch_size diff --git a/python/fedml/simulation/sp/decentralized/topology_manager.py b/python/fedml/simulation/sp/decentralized/topology_manager.py index 906f6d77e7..7f4f3565ff 100644 --- a/python/fedml/simulation/sp/decentralized/topology_manager.py +++ b/python/fedml/simulation/sp/decentralized/topology_manager.py @@ -3,6 +3,29 @@ class TopologyManager: + """ + Manages the network topology for decentralized federated learning. + + Args: + n (int): Total number of clients. + b_symmetric (bool): Flag indicating symmetric or asymmetric topology. + undirected_neighbor_num (int): Number of undirected neighbors for symmetric topology. + out_directed_neighbor (int): Number of outgoing directed neighbors for asymmetric topology. + + Attributes: + n (int): Total number of clients. + b_symmetric (bool): Flag indicating symmetric or asymmetric topology. + undirected_neighbor_num (int): Number of undirected neighbors for symmetric topology. + out_directed_neighbor (int): Number of outgoing directed neighbors for asymmetric topology. + topology_symmetric (list): Symmetric topology information. + topology_asymmetric (list): Asymmetric topology information. + b_fully_connected (bool): Flag indicating if the topology is fully connected. + + Methods: + generate_topology(): Generates the network topology. + get_symmetric_neighbor_list(client_idx): Gets symmetric neighbors for a client. + get_asymmetric_neighbor_list(client_idx): Gets asymmetric neighbors for a client. + """ def __init__( self, n, b_symmetric, undirected_neighbor_num=5, out_directed_neighbor=5 ): @@ -17,6 +40,9 @@ def __init__( self.b_fully_connected = True def generate_topology(self): + """ + Generates the network topology based on configuration. + """ if self.b_fully_connected: self.__fully_connected() return @@ -27,16 +53,37 @@ def generate_topology(self): self.__randomly_pick_neighbors_asymmetric() def get_symmetric_neighbor_list(self, client_idx): + """ + Gets the symmetric neighbor list for a client. + + Args: + client_idx (int): Index of the client. + + Returns: + list: List of symmetric neighbors for the specified client. + """ if client_idx >= self.n: return [] return self.topology_symmetric[client_idx] def get_asymmetric_neighbor_list(self, client_idx): + """ + Gets the asymmetric neighbor list for a client. + + Args: + client_idx (int): Index of the client. + + Returns: + list: List of asymmetric neighbors for the specified client. + """ if client_idx >= self.n: return [] return self.topology_asymmetric[client_idx] def __randomly_pick_neighbors_symmetric(self): + """ + Generates symmetric topology with randomly added links for each node. + """ # first generate a ring topology topology_ring = np.array( nx.to_numpy_matrix(nx.watts_strogatz_graph(self.n, 2, 0)), dtype=np.float32 @@ -74,6 +121,9 @@ def __randomly_pick_neighbors_symmetric(self): self.topology_symmetric = topology_symmetric def __randomly_pick_neighbors_asymmetric(self): + """ + Generates asymmetric topology with randomly added links for each node. + """ # randomly add some links for each node (symmetric) k = self.undirected_neighbor_num # print("neighbors = " + str(k)) @@ -134,6 +184,9 @@ def __randomly_pick_neighbors_asymmetric(self): self.topology_asymmetric = topology_ring def __fully_connected(self): + """ + Generates fully connected symmetric topology. 
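+
+ Illustrative use of the surrounding class (a sketch; the size is
+ hypothetical):
+ >>> tm = TopologyManager(n=8, b_symmetric=True)
+ >>> tm.generate_topology()
+ >>> neighbor_weights = tm.get_symmetric_neighbor_list(0)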
+ """ topology_fully_connected = np.array( nx.to_numpy_matrix(nx.watts_strogatz_graph(self.n, self.n - 1, 0)), dtype=np.float32, diff --git a/python/fedml/simulation/sp/fedavg/client.py b/python/fedml/simulation/sp/fedavg/client.py index cc74a9d932..12df31b681 100644 --- a/python/fedml/simulation/sp/fedavg/client.py +++ b/python/fedml/simulation/sp/fedavg/client.py @@ -1,4 +1,36 @@ class Client: + """ + Represents a client in a federated learning system. + + Args: + client_idx (int): The index of the client. + local_training_data (list): Local training data. + local_test_data (list): Local test data. + local_sample_number (int): Number of local samples. + args (object): Arguments for configuration. + device (str): The device (e.g., 'cpu' or 'cuda') for model training. + model_trainer (object): The model trainer object for training and testing. + + Attributes: + client_idx (int): The index of the client. + local_training_data (list): Local training data. + local_test_data (list): Local test data. + local_sample_number (int): Number of local samples. + args (object): Arguments for configuration. + device (str): The device (e.g., 'cpu' or 'cuda') for model training. + model_trainer (object): The model trainer object for training and testing. + + Methods: + update_local_dataset(client_idx, local_training_data, local_test_data, local_sample_number): + Updates the local dataset for the client. + + get_sample_number(): Gets the number of local samples. + + train(w_global): Trains the client's model using the global model weights. + + local_test(b_use_test_dataset): Tests the client's model using local or test data. + + """ def __init__( self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model_trainer, ): @@ -12,6 +44,15 @@ def __init__( self.model_trainer = model_trainer def update_local_dataset(self, client_idx, local_training_data, local_test_data, local_sample_number): + """ + Updates the local dataset for the client. + + Args: + client_idx (int): The index of the client. + local_training_data (list): Updated local training data. + local_test_data (list): Updated local test data. + local_sample_number (int): Updated number of local samples. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -19,15 +60,39 @@ def update_local_dataset(self, client_idx, local_training_data, local_test_data, self.model_trainer.set_id(client_idx) def get_sample_number(self): + """ + Gets the number of local samples. + + Returns: + int: Number of local samples. + """ return self.local_sample_number def train(self, w_global): + """ + Trains the client's model using the global model weights. + + Args: + w_global (object): Global model weights. + + Returns: + object: Updated client model weights. + """ self.model_trainer.set_model_params(w_global) self.model_trainer.train(self.local_training_data, self.device, self.args) weights = self.model_trainer.get_model_params() return weights def local_test(self, b_use_test_dataset): + """ + Tests the client's model using local or test data. + + Args: + b_use_test_dataset (bool): Flag to use test dataset for testing. + + Returns: + object: Model evaluation metrics. 
+ """ if b_use_test_dataset: test_data = self.local_test_data else: diff --git a/python/fedml/simulation/sp/fedavg/fedavg_api.py b/python/fedml/simulation/sp/fedavg/fedavg_api.py index a748186fb3..ded0d92efa 100644 --- a/python/fedml/simulation/sp/fedavg/fedavg_api.py +++ b/python/fedml/simulation/sp/fedavg/fedavg_api.py @@ -12,6 +12,33 @@ class FedAvgAPI(object): + """ + Federated Averaging API for federated learning. + + Args: + args (object): Arguments for configuration. + device (str): The device (e.g., 'cpu' or 'cuda') for model training. + dataset (tuple): A tuple containing dataset information. + + Attributes: + device (str): The device (e.g., 'cpu' or 'cuda') for model training. + args (object): Arguments for configuration. + train_global: Global training dataset. + test_global: Global test dataset. + val_global: Global validation dataset. + train_data_num_in_total (int): Total number of training samples. + test_data_num_in_total (int): Total number of test samples. + client_list (list): List of client instances. + train_data_local_num_dict (dict): Dictionary mapping client index to the number of local training samples. + train_data_local_dict (dict): Dictionary mapping client index to local training data. + test_data_local_dict (dict): Dictionary mapping client index to local test data. + model_trainer: Model trainer for federated learning. + model: The federated model. + + Methods: + train(): Train the federated model using federated averaging. + + """ def __init__(self, args, device, dataset, model): self.device = device self.args = args @@ -49,6 +76,7 @@ def __init__(self, args, device, dataset, model): def _setup_clients( self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict, model_trainer, ): + """Setup client instances.""" logging.info("############setup_clients (START)#############") for client_idx in range(self.args.client_num_per_round): c = Client( @@ -125,6 +153,7 @@ def train(self): mlops.log_aggregation_finished_status() def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """Sample clients for communication round.""" if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -135,6 +164,7 @@ def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round) return client_indexes def _generate_validation_set(self, num_samples=10000): + """Generate a validation dataset.""" test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) @@ -142,6 +172,7 @@ def _generate_validation_set(self, num_samples=10000): self.val_global = sample_testset def _aggregate(self, w_locals): + """Aggregate local model weights.""" training_num = 0 for idx in range(len(w_locals)): (sample_num, averaged_params) = w_locals[idx] @@ -160,10 +191,14 @@ def _aggregate(self, w_locals): def _aggregate_noniid_avg(self, w_locals): """ - The old aggregate method will impact the model performance when it comes to Non-IID setting + Aggregate local model weights using non-IID average method. + Args: - w_locals: + w_locals (list): List of tuples containing (sample_num, local_weights). + Returns: + dict: Averaged model parameters. 
+ """ (_, averaged_params) = w_locals[0] for k in averaged_params.keys(): @@ -174,6 +209,16 @@ def _aggregate_noniid_avg(self, w_locals): return averaged_params def _local_test_on_all_clients(self, round_idx): + """ + Aggregate local model weights using non-IID average method. + + Args: + w_locals (list): List of tuples containing (sample_num, local_weights). + + Returns: + dict: Averaged model parameters. + + """ logging.info("################local_test_on_all_clients : {}".format(round_idx)) @@ -235,7 +280,13 @@ def _local_test_on_all_clients(self, round_idx): logging.info(stats) def _local_test_on_validation_set(self, round_idx): + """ + Perform local testing on all clients and log the results. + Args: + round_idx (int): The current communication round index. + + """ logging.info("################local_test_on_validation_set : {}".format(round_idx)) if self.val_global is None: From 8ab163c86ad629719dff5a345115e3a752a35b36 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Wed, 6 Sep 2023 17:17:28 +0530 Subject: [PATCH 05/70] same as previous --- .../fedml/simulation/sp/feddyn/client copy.py | 54 +++++ .../simulation/sp/hierarchical_fl/client.py | 24 +++ .../simulation/sp/hierarchical_fl/group.py | 34 +++ .../simulation/sp/hierarchical_fl/trainer.py | 34 +++ python/fedml/simulation/sp/mime/client.py | 47 +++++ .../fedml/simulation/sp/mime/mime_trainer.py | 70 ++++++- python/fedml/simulation/sp/mime/opt_utils.py | 58 +++++- python/fedml/simulation/sp/scaffold/client.py | 46 +++++ .../sp/scaffold/scaffold_trainer.py | 59 ++++++ .../simulation/sp/turboaggregate/TA_client.py | 18 ++ .../sp/turboaggregate/TA_trainer.py | 65 ++++++ .../sp/turboaggregate/mpc_function.py | 194 +++++++++++++++++- 12 files changed, 697 insertions(+), 6 deletions(-) diff --git a/python/fedml/simulation/sp/feddyn/client copy.py b/python/fedml/simulation/sp/feddyn/client copy.py index 02d1b30333..1a7bc2bb50 100644 --- a/python/fedml/simulation/sp/feddyn/client copy.py +++ b/python/fedml/simulation/sp/feddyn/client copy.py @@ -4,6 +4,15 @@ def model_parameter_vector(model): + """ + Flatten the model's parameters into a single vector. + + Args: + model (torch.nn.Module): The neural network model. + + Returns: + torch.Tensor: A flattened vector containing all model parameters. + """ param = [p.view(-1) for p in model.parameters()] return torch.concat(param, dim=0) @@ -12,6 +21,18 @@ class Client: def __init__( self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model_trainer, ): + """ + Initialize a client for federated learning. + + Args: + client_idx (int): Index of the client. + local_training_data (torch.utils.data.DataLoader): Local training data. + local_test_data (torch.utils.data.DataLoader): Local test data. + local_sample_number (int): Number of samples in the local dataset. + args: Command-line arguments. + device (torch.device): Device for training. + model_trainer: Model trainer for training and testing. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -30,6 +51,15 @@ def __init__( def update_local_dataset(self, client_idx, local_training_data, local_test_data, local_sample_number): + """ + Update the local dataset for the client. + + Args: + client_idx (int): Index of the client. + local_training_data (torch.utils.data.DataLoader): Local training data. + local_test_data (torch.utils.data.DataLoader): Local test data. + local_sample_number (int): Number of samples in the local dataset. 
+ """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -37,15 +67,39 @@ def update_local_dataset(self, client_idx, local_training_data, local_test_data, self.model_trainer.set_id(client_idx) def get_sample_number(self): + """ + Get the number of samples in the local dataset. + + Returns: + int: Number of samples. + """ return self.local_sample_number def train(self, w_global): + """ + Train the client's model using the global model parameters. + + Args: + w_global: Global model parameters. + + Returns: + tuple: A tuple containing the updated weights and gradients. + """ self.model_trainer.set_model_params(w_global) self.old_grad = self.model_trainer.train(self.local_training_data, self.device, self.args, self.old_grad) weights = self.model_trainer.get_model_params() return weights, self.old_grad def local_test(self, b_use_test_dataset): + """ + Perform local testing on the client's dataset. + + Args: + b_use_test_dataset (bool): Whether to use the test dataset or training dataset. + + Returns: + dict: Metrics from the local test. + """ if b_use_test_dataset: test_data = self.local_test_data else: diff --git a/python/fedml/simulation/sp/hierarchical_fl/client.py b/python/fedml/simulation/sp/hierarchical_fl/client.py index 48efe5ba6c..5eed5e402e 100644 --- a/python/fedml/simulation/sp/hierarchical_fl/client.py +++ b/python/fedml/simulation/sp/hierarchical_fl/client.py @@ -7,6 +7,19 @@ class HFLClient(Client): + """ + Represents a High-Frequency Learning (HFL) client in a federated learning setting. + + Args: + client_idx (int): Index of the client. + local_training_data: Local training data for the client. + local_test_data: Local test data for the client. + local_sample_number: Number of local samples. + args: Arguments for client configuration. + device: Device (e.g., 'cuda' or 'cpu') to perform computations. + model: The client's model. + model_trainer: Trainer for the client's model. + """ def __init__(self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model, model_trainer): @@ -24,6 +37,17 @@ def __init__(self, client_idx, local_training_data, local_test_data, local_sampl self.criterion = nn.CrossEntropyLoss().to(device) def train(self, global_round_idx, group_round_idx, w): + """ + Train the client's model using High-Frequency Learning (HFL) approach. + + Args: + global_round_idx (int): Global round index. + group_round_idx (int): Group round index. + w: Model weights to initialize training. + + Returns: + list: A list of tuples containing global epoch and model state dictionaries. + """ self.model.load_state_dict(w) self.model.to(self.device) diff --git a/python/fedml/simulation/sp/hierarchical_fl/group.py b/python/fedml/simulation/sp/hierarchical_fl/group.py index adfa27d0bd..70c568fbc1 100644 --- a/python/fedml/simulation/sp/hierarchical_fl/group.py +++ b/python/fedml/simulation/sp/hierarchical_fl/group.py @@ -5,6 +5,20 @@ class Group(FedAvgAPI): + """ + Represents a group of clients in a federated learning setting. + + Args: + idx (int): Index of the group. + total_client_indexes (list): List of client indexes in the group. + train_data_local_dict: Dictionary containing local training data for each client. + test_data_local_dict: Dictionary containing local test data for each client. + train_data_local_num_dict: Dictionary containing the number of local training samples for each client. + args: Arguments for group configuration. 
+ device: Device (e.g., 'cuda' or 'cpu') to perform computations. + model: The shared model used by clients in the group. + model_trainer: Trainer for the shared model. + """ def __init__( self, idx, @@ -35,12 +49,32 @@ def __init__( ) def get_sample_number(self, sampled_client_indexes): + """ + Calculate the total number of training samples in the group. + + Args: + sampled_client_indexes (list): List of sampled client indexes. + + Returns: + int: Total number of training samples in the group. + """ self.group_sample_number = 0 for client_idx in sampled_client_indexes: self.group_sample_number += self.train_data_local_num_dict[client_idx] return self.group_sample_number def train(self, global_round_idx, w, sampled_client_indexes): + """ + Train the group of clients using federated learning. + + Args: + global_round_idx (int): Global round index. + w: Model weights to initialize training. + sampled_client_indexes (list): List of sampled client indexes. + + Returns: + list: A list of tuples containing global epoch and aggregated model weights. + """ sampled_client_list = [self.client_dict[client_idx] for client_idx in sampled_client_indexes] w_group = w w_group_list = [] diff --git a/python/fedml/simulation/sp/hierarchical_fl/trainer.py b/python/fedml/simulation/sp/hierarchical_fl/trainer.py index c0d1c05003..63085dd67e 100644 --- a/python/fedml/simulation/sp/hierarchical_fl/trainer.py +++ b/python/fedml/simulation/sp/hierarchical_fl/trainer.py @@ -8,6 +8,15 @@ class HierarchicalTrainer(FedAvgAPI): + """ + Represents a hierarchical federated learning trainer. + + Args: + train_data_local_num_dict: Dictionary containing the number of local training samples for each client. + train_data_local_dict: Dictionary containing local training data for each client. + test_data_local_dict: Dictionary containing local test data for each client. + model_trainer: Trainer for the shared model. + """ def _setup_clients( self, train_data_local_num_dict, @@ -15,6 +24,15 @@ def _setup_clients( test_data_local_dict, model_trainer, ): + """ + Set up client groups and maintain a dummy client for testing. + + Args: + train_data_local_num_dict: Dictionary containing the number of local training samples for each client. + train_data_local_dict: Dictionary containing local training data for each client. + test_data_local_dict: Dictionary containing local test data for each client. + model_trainer: Trainer for the shared model. + """ logging.info("############setup_clients (START)#############") if self.args.group_method == "random": self.group_indexes = np.random.randint( @@ -61,6 +79,17 @@ def _setup_clients( def _client_sampling( self, global_round_idx, client_num_in_total, client_num_per_round ): + """ + Sample clients for training in a hierarchical manner. + + Args: + global_round_idx (int): Global round index. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to sample per round. + + Returns: + dict: Dictionary mapping group indexes to sampled client indexes. + """ sampled_client_indexes = super()._client_sampling( global_round_idx, client_num_in_total, client_num_per_round ) @@ -76,6 +105,11 @@ def _client_sampling( return group_to_client_indexes def train(self): + """ + Train the hierarchical federated learning model. + + This method manages global communication rounds and client sampling. 
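+
+ Rough flow (an illustrative sketch, not the exact implementation):
+ for each global round:
+     sample clients and map them onto their groups
+     each sampled group trains its clients for several group rounds
+     the group models are merged into the global model and evaluated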
+ """ w_global = self.model.state_dict() for global_round_idx in range(self.args.comm_round): logging.info( diff --git a/python/fedml/simulation/sp/mime/client.py b/python/fedml/simulation/sp/mime/client.py index 00df1b004e..e0b0553175 100644 --- a/python/fedml/simulation/sp/mime/client.py +++ b/python/fedml/simulation/sp/mime/client.py @@ -1,4 +1,16 @@ class Client: + """ + Represents a client in a federated learning setting. + + Args: + client_idx (int): Index of the client. + local_training_data: Local training data for the client. + local_test_data: Local test data for the client. + local_sample_number: Number of local samples. + args: Arguments for client configuration. + device: Device (e.g., 'cuda' or 'cpu') to perform computations. + model_trainer: Trainer for the client's model. + """ def __init__( self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model_trainer, ): @@ -12,6 +24,15 @@ def __init__( self.model_trainer = model_trainer def update_local_dataset(self, client_idx, local_training_data, local_test_data, local_sample_number): + """ + Update the local dataset for the client. + + Args: + client_idx (int): Index of the client. + local_training_data: New local training data for the client. + local_test_data: New local test data for the client. + local_sample_number: New number of local samples. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -19,15 +40,41 @@ def update_local_dataset(self, client_idx, local_training_data, local_test_data, self.model_trainer.set_id(client_idx) def get_sample_number(self): + """ + Get the number of local samples. + + Returns: + int: Number of local samples. + """ return self.local_sample_number def train(self, w_global, grad_global, global_named_states): + """ + Train the client's model. + + Args: + w_global: Global model parameters. + grad_global: Global gradient. + global_named_states: Named states of the global optimizer. + + Returns: + tuple: A tuple containing local model weights and local gradients. + """ self.model_trainer.set_model_params(w_global) local_grad = self.model_trainer.train(self.local_training_data, self.device, self.args, grad_global, global_named_states) weights = self.model_trainer.get_model_params() return weights, local_grad def local_test(self, b_use_test_dataset): + """ + Perform local testing on the client's dataset. + + Args: + b_use_test_dataset (bool): Whether to use the test dataset. + + Returns: + dict: Metrics from the local test. + """ if b_use_test_dataset: test_data = self.local_test_data else: diff --git a/python/fedml/simulation/sp/mime/mime_trainer.py b/python/fedml/simulation/sp/mime/mime_trainer.py index ff7831a523..74eff35dfd 100644 --- a/python/fedml/simulation/sp/mime/mime_trainer.py +++ b/python/fedml/simulation/sp/mime/mime_trainer.py @@ -18,7 +18,20 @@ class MimeTrainer(object): + """ + Trainer for the Mime model on federated learning. + """ def __init__(self, dataset, model, device, args): + """ + Initialize the MimeTrainer. + + Args: + dataset: A list containing dataset information. + model: The Mime model. + device: The target device for training. + args: Training arguments. + """ + self.device = device self.args = args [ @@ -58,6 +71,15 @@ def __init__(self, dataset, model, device, args): def _setup_clients( self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict, model_trainer, ): + """ + Set up client instances for federated learning. 
+ + Args: + train_data_local_num_dict: Dictionary containing local training data numbers for each client. + train_data_local_dict: Dictionary containing local training data for each client. + test_data_local_dict: Dictionary containing local test data for each client. + model_trainer: Model trainer for client instances. + """ logging.info("############setup_clients (START)#############") for client_idx in range(self.args.client_num_per_round): c = Client( @@ -73,11 +95,10 @@ def _setup_clients( logging.info("############setup_clients (END)#############") - - - - def train(self): + """ + Perform federated training using the Mime model. + """ logging.info("self.model_trainer = {}".format(self.model_trainer)) w_global = self.model_trainer.get_model_params() mlops.log_training_status(mlops.ClientConstants.MSG_MLOPS_CLIENT_STATUS_TRAINING) @@ -142,6 +163,17 @@ def train(self): mlops.log_aggregation_finished_status() def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Perform client sampling for each communication round. + + Args: + round_idx: Index of the communication round. + client_num_in_total: Total number of clients. + client_num_per_round: Number of clients per round. + + Returns: + List: List of selected client indexes. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -152,6 +184,12 @@ def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set by sampling a subset of the test data. + + Args: + num_samples (int): The number of samples to include in the validation set. Default is 10,000. + """ test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) @@ -160,6 +198,9 @@ def _generate_validation_set(self, num_samples=10000): def _instanciate_opt(self): + """ + Initialize the optimizer for the MimeTrainer. + """ self.opt = OptRepo.name2cls(self.args.server_optimizer)( # self.model_global.parameters(), lr=self.args.server_lr self.model_trainer.model.parameters(), @@ -173,11 +214,26 @@ def _instanciate_opt(self): def _aggregate(self, w_locals): + """ + Aggregate the local model weights to obtain global model weights. + + Args: + w_locals: List of local model weights. + + Returns: + avg_params: Aggregated global model weights. + """ avg_params = FedMLAggOperator.agg(self.args, w_locals) return avg_params def _local_test_on_all_clients(self, round_idx): + """ + Perform local testing on all clients. + + Args: + round_idx: Index of the communication round. + """ logging.info("################local_test_on_all_clients : {}".format(round_idx)) @@ -253,6 +309,12 @@ def _local_test_on_all_clients(self, round_idx): logging.info(stats) def _local_test_on_validation_set(self, round_idx): + """ + Perform local testing on the validation set. + + Args: + round_idx: Index of the communication round. 
+ """ logging.info("################local_test_on_validation_set : {}".format(round_idx)) diff --git a/python/fedml/simulation/sp/mime/opt_utils.py b/python/fedml/simulation/sp/mime/opt_utils.py index 303b386db2..28ac32a944 100644 --- a/python/fedml/simulation/sp/mime/opt_utils.py +++ b/python/fedml/simulation/sp/mime/opt_utils.py @@ -4,6 +4,12 @@ def show_opt_state(optimizer): + """ + Display selected optimizer's state information. + + Args: + optimizer: The optimizer to display state information for. + """ i = 0 for p in optimizer.state.keys(): # print(list(optimizer.state[p].keys())) @@ -18,6 +24,12 @@ def show_opt_state(optimizer): print(key, torch.norm((optimizer.state[p][key]))) def show_named_state(named_states): + """ + Display state information for a dictionary of named states. + + Args: + named_states (dict): A dictionary containing named states to display. + """ i = 0 for name in named_states.keys(): # print(list(optimizer.state[p].keys())) @@ -34,8 +46,14 @@ def show_named_state(named_states): class OptimizerLoader(): - def __init__(self, model, optimizer): + """ + Initialize the OptimizerLoader. + + Args: + model: The model being optimized. + optimizer: The optimizer used for training. + """ self.optimizer = optimizer self.model = model self.named_states = {} @@ -50,9 +68,22 @@ def __init__(self, model, optimizer): # print(key, type(optimizer.state[p][key])) def get_opt_state(self): + """ + Get the optimizer's named states. + + Returns: + dict: A dictionary containing the optimizer's named states. + """ return self.named_states def set_opt_state(self, named_states, device="cpu"): + """ + Set the optimizer's named states. + + Args: + named_states (dict): A dictionary containing the named states to set. + device (str): The target device for the named states (default is "cpu"). + """ for p in self.optimizer.state.keys(): new_state = named_states[self.parameter_names[p]] # for key in self.optimizer.state[p].keys(): @@ -61,12 +92,25 @@ def set_opt_state(self, named_states, device="cpu"): # print(key, type(self.optimizer.state[p][key])) def get_grad(self): + """ + Get the gradients of the model's parameters. + + Returns: + dict: A dictionary containing the gradients of the model's parameters. + """ grad = {} for name, parameter in self.model.named_parameters(): grad[name] = parameter.grad return grad def set_grad(self, grad, device="cpu"): + """ + Set the gradients of the model's parameters. + + Args: + grad (dict): A dictionary containing the gradients to set. + device (str): The target device for the gradients (default is "cpu"). + """ for name, parameter in self.model.named_parameters(): # logging.info(f"parameter.grad: {type(parameter.grad)}, grad[name]: {type(grad[name])} ") # logging.info(f"parameter.grad.shape: {parameter.grad.shape}, grad[name].shape: {grad[name].shape} ") @@ -76,9 +120,21 @@ def set_grad(self, grad, device="cpu"): return grad def zero_grad(self): + """ + Zero out the gradients of the model's parameters. + """ self.optimizer.zero_grad() def update_opt_state(self, update_model=False): + """ + Update the optimizer's state after a step. + + Args: + update_model (bool): Whether to update the model's parameters as well (default is False). + + Returns: + dict: A dictionary containing the updated optimizer's named states. 
+ """ if not update_model: origin_model_params = self.model.state_dict() self.optimizer.step() diff --git a/python/fedml/simulation/sp/scaffold/client.py b/python/fedml/simulation/sp/scaffold/client.py index eed328cc24..6b610122f6 100644 --- a/python/fedml/simulation/sp/scaffold/client.py +++ b/python/fedml/simulation/sp/scaffold/client.py @@ -6,6 +6,18 @@ class Client: def __init__( self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model_trainer, ): + """ + Initialize a client for federated learning. + + Args: + client_idx (int): The index of the client. + local_training_data (torch.utils.data.DataLoader): The DataLoader for local training data. + local_test_data (torch.utils.data.DataLoader): The DataLoader for local test data. + local_sample_number (int): The number of local training samples. + args: The arguments for the client. + device (torch.device): The device to perform computations on. + model_trainer: The model trainer used for training. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -20,6 +32,15 @@ def __init__( def update_local_dataset(self, client_idx, local_training_data, local_test_data, local_sample_number): + """ + Update the local dataset for the client. + + Args: + client_idx (int): The index of the client. + local_training_data (torch.utils.data.DataLoader): The DataLoader for local training data. + local_test_data (torch.utils.data.DataLoader): The DataLoader for local test data. + local_sample_number (int): The number of local training samples. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -27,9 +48,25 @@ def update_local_dataset(self, client_idx, local_training_data, local_test_data, self.model_trainer.set_id(client_idx) def get_sample_number(self): + """ + Get the number of local training samples. + + Returns: + int: The number of local training samples. + """ return self.local_sample_number def train(self, w_global, c_model_global_param): + """ + Perform local training for the client. + + Args: + w_global: The global model parameters. + c_model_global_param: The global model parameters of the central model. + + Returns: + tuple: A tuple containing weights_delta and c_delta_para. + """ c_model_global_param = deepcopy(c_model_global_param) c_model_local_param = self.c_model_local.state_dict() @@ -56,6 +93,15 @@ def train(self, w_global, c_model_global_param): return weights_delta, c_delta_para def local_test(self, b_use_test_dataset): + """ + Perform local testing on the client's dataset. + + Args: + b_use_test_dataset (bool): If True, use the test dataset; if False, use the training dataset. + + Returns: + dict: A dictionary containing test metrics. + """ if b_use_test_dataset: test_data = self.local_test_data else: diff --git a/python/fedml/simulation/sp/scaffold/scaffold_trainer.py b/python/fedml/simulation/sp/scaffold/scaffold_trainer.py index 9a812b8a32..4caea49898 100644 --- a/python/fedml/simulation/sp/scaffold/scaffold_trainer.py +++ b/python/fedml/simulation/sp/scaffold/scaffold_trainer.py @@ -16,6 +16,15 @@ class ScaffoldTrainer(object): def __init__(self, dataset, model, device, args): + """ + Initialize the ScaffoldTrainer. + + Args: + dataset: A list of dataset components. + model: The model to be trained. + device: The computing device (e.g., 'cuda' or 'cpu'). + args: Training arguments. 
+ """ self.device = device self.args = args [ @@ -58,6 +67,15 @@ def __init__(self, dataset, model, device, args): def _setup_clients( self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict, model_trainer, ): + """ + Set up client instances for training. + + Args: + train_data_local_num_dict: A dictionary of local training dataset sizes. + train_data_local_dict: A dictionary of local training datasets. + test_data_local_dict: A dictionary of local test datasets. + model_trainer: The model trainer instance. + """ logging.info("############setup_clients (START)#############") if self.args.initialize_all_clients: num_initialized_clients = self.args.client_num_in_total @@ -77,6 +95,9 @@ def _setup_clients( logging.info("############setup_clients (END)#############") def train(self): + """ + Perform the training process. + """ logging.info("self.model_trainer = {}".format(self.model_trainer)) w_global = self.model_trainer.get_model_params() mlops.log_training_status(mlops.ClientConstants.MSG_MLOPS_CLIENT_STATUS_TRAINING) @@ -155,6 +176,17 @@ def train(self): mlops.log_aggregation_finished_status() def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Randomly sample clients for federated learning communication round. + + Args: + round_idx (int): The current communication round index. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to select per round. + + Returns: + list: List of client indexes selected for communication. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -165,6 +197,12 @@ def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation dataset subset for local validation. + + Args: + num_samples (int): Number of samples in the validation dataset subset. + """ test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) @@ -172,6 +210,15 @@ def _generate_validation_set(self, num_samples=10000): self.val_global = sample_testset def _aggregate(self, w_locals): + """ + Aggregate local model updates using FedMLAggOperator. + + Args: + w_locals: List of local model updates. + + Returns: + tuple: Total aggregated model update and total client delta parameters. + """ # training_num = 0 # for idx in range(len(w_locals)): # (sample_num, averaged_params) = w_locals[idx] @@ -185,6 +232,12 @@ def _aggregate(self, w_locals): def _local_test_on_all_clients(self, round_idx): + """ + Perform local testing on all clients for both training and test datasets. + + Args: + round_idx (int): The current communication round index. + """ logging.info("################local_test_on_all_clients : {}".format(round_idx)) @@ -246,6 +299,12 @@ def _local_test_on_all_clients(self, round_idx): logging.info(stats) def _local_test_on_validation_set(self, round_idx): + """ + Perform local testing on a validation set for the specified round. + + Args: + round_idx (int): The current communication round index. 
+ """ logging.info("################local_test_on_validation_set : {}".format(round_idx)) diff --git a/python/fedml/simulation/sp/turboaggregate/TA_client.py b/python/fedml/simulation/sp/turboaggregate/TA_client.py index 0d617c3b24..4f6356a941 100644 --- a/python/fedml/simulation/sp/turboaggregate/TA_client.py +++ b/python/fedml/simulation/sp/turboaggregate/TA_client.py @@ -4,6 +4,18 @@ class TA_Client(Client): + """ + A subclass of the Client class for a specific type of client. + + Args: + client_idx (int): The index of the client. + local_training_data: The local training data for the client. + local_test_data: The local test data for the client. + local_sample_number: The number of local samples. + args: Additional arguments. + device: The computing device (e.g., 'cuda' or 'cpu'). + model_trainer: The model trainer for this client. + """ def __init__( self, client_idx, @@ -30,4 +42,10 @@ def __init__( # self.buffer_out = np.zeros(dtype='int') def set_dropout(self, isdrop): + """ + Set the dropout flag for this client. + + Args: + isdrop (bool): Whether to enable dropout for this client. + """ self.isdrop = isdrop diff --git a/python/fedml/simulation/sp/turboaggregate/TA_trainer.py b/python/fedml/simulation/sp/turboaggregate/TA_trainer.py index 59283423be..90041e16df 100644 --- a/python/fedml/simulation/sp/turboaggregate/TA_trainer.py +++ b/python/fedml/simulation/sp/turboaggregate/TA_trainer.py @@ -10,6 +10,15 @@ class TurboAggregateTrainer(object): + """ + TurboAggregateTrainer for federated learning with Turbo-Aggregate protocol. + + Args: + dataset: A list containing dataset-related information. + model: The global model for training. + device: The computing device (e.g., 'cuda' or 'cpu'). + args: Additional training arguments. + """ def __init__(self, dataset, model, device, args): self.device = device self.args = args @@ -40,6 +49,14 @@ def __init__(self, dataset, model, device, args): self.setup_clients(data_local_num_dict, train_data_local_dict, test_data_local_dict) def setup_clients(self, data_local_num_dict, train_data_local_dict, test_data_local_dict): + """ + Set up the list of clients for federated learning. + + Args: + data_local_num_dict: A dictionary containing the number of local samples for each client. + train_data_local_dict: A dictionary containing local training data for each client. + test_data_local_dict: A dictionary containing local test data for each client. + """ logging.info("############setup_clients (START)#############") for client_idx in range(self.args.client_num_in_total): c = TA_Client( @@ -55,6 +72,9 @@ def setup_clients(self, data_local_num_dict, train_data_local_dict, test_data_lo logging.info("############setup_clients (END)#############") def train(self): + """ + Train the global model using the Turbo-Aggregate protocol. + """ for round_idx in range(self.args.comm_round): logging.info("Communication round : {}".format(round_idx)) w_global = self.model_trainer.get_model_params() @@ -94,6 +114,15 @@ def train(self): self.local_test(self.model_global, round_idx) def aggregate(self, w_locals): + """ + Aggregate the local model weights from clients using Turbo-Aggregate. + + Args: + w_locals: List of local model weights from clients. + + Returns: + Averaged global model weights. 
+ """ logging.info("################aggregate: %d" % len(w_locals)) (sample_num, averaged_params) = w_locals[0] for k in averaged_params.keys(): @@ -107,6 +136,7 @@ def aggregate(self, w_locals): return averaged_params def TA_topology_vanilla(self): + # logging.info("################aggregate: %d" % len(w_locals)) # N = self.args.client_number @@ -119,10 +149,24 @@ def TA_topology_vanilla(self): pass def local_test(self, model_global, round_idx): + """ + Perform local testing on clients. + + Args: + model_global: The global model to evaluate. + round_idx: The communication round index. + """ self.local_test_on_training_data(model_global, round_idx) self.local_test_on_test_data(model_global, round_idx) def local_test_on_training_data(self, model_global, round_idx): + """ + Perform local testing on training data for clients. + + Args: + model_global: The global model to evaluate. + round_idx: The communication round index. + """ num_samples = [] tot_corrects = [] losses = [] @@ -148,6 +192,13 @@ def local_test_on_training_data(self, model_global, round_idx): logging.info(stats) def local_test_on_test_data(self, model_global, round_idx): + """ + Perform local testing on test data for clients. + + Args: + model_global: The global model to evaluate. + round_idx: The communication round index. + """ num_samples = [] tot_corrects = [] losses = [] @@ -172,6 +223,9 @@ def local_test_on_test_data(self, model_global, round_idx): logging.info(stats) def global_test(self): + """ + Perform global testing using the global dataset and log the results. + """ logging.info("################global_test") acc_train, num_sample, loss_train = self.test_using_global_dataset( self.model_global, self.train_global, self.device @@ -190,6 +244,17 @@ def global_test(self): wandb.log({"Global Testing Accuracy": acc_test}) def test_using_global_dataset(self, model_global, global_test_data, device): + """ + Test the global model using the global test dataset. + + Args: + model_global: The global model to evaluate. + global_test_data: The global test dataset. + device: The computing device (e.g., 'cuda' or 'cpu'). + + Returns: + Tuple of testing accuracy, total samples, and testing loss. + """ model_global.eval() model_global.to(device) test_loss = test_acc = test_total = 0.0 diff --git a/python/fedml/simulation/sp/turboaggregate/mpc_function.py b/python/fedml/simulation/sp/turboaggregate/mpc_function.py index e2ab80b983..3c4e976761 100644 --- a/python/fedml/simulation/sp/turboaggregate/mpc_function.py +++ b/python/fedml/simulation/sp/turboaggregate/mpc_function.py @@ -2,6 +2,16 @@ def modular_inv(a, p): + """ + Compute the modular inverse of 'a' modulo 'p'. + + Args: + a (int): The number for which the modular inverse is calculated. + p (int): The prime modulo. + + Returns: + int: The modular inverse of 'a' modulo 'p'. + """ x, y, m = 1, 0, p while a > 1: q = a // m @@ -19,6 +29,17 @@ def modular_inv(a, p): def divmod(_num, _den, _p): + """ + Compute the result of `_num` divided by `_den` modulo prime `_p`. + + Args: + _num (int): The numerator. + _den (int): The denominator. + _p (int): The prime modulo. + + Returns: + int: The result of `_num / _den` modulo `_p`. + """ # compute num / den modulo prime p _num = np.mod(_num, _p) _den = np.mod(_den, _p) @@ -28,6 +49,16 @@ def divmod(_num, _den, _p): def PI(vals, p): # upper-case PI -- product of inputs + """ + Compute the product of a list of values modulo prime 'p'. + + Args: + vals (list): List of integers to be multiplied. + p (int): The prime modulo. 
+ + Returns: + int: The product of the values in 'vals' modulo 'p'. + """ accum = 1 for v in vals: @@ -37,6 +68,18 @@ def PI(vals, p): # upper-case PI -- product of inputs def gen_Lagrange_coeffs(alpha_s, beta_s, p, is_K1=0): + """ + Generate Lagrange coefficients for polynomial interpolation. + + Args: + alpha_s (list): List of alpha values. + beta_s (list): List of beta values. + p (int): The prime modulo. + is_K1 (int): Flag indicating if it's K1. + + Returns: + numpy.ndarray: A matrix of Lagrange coefficients. + """ if is_K1 == 1: num_alpha = 1 else: @@ -60,6 +103,18 @@ def gen_Lagrange_coeffs(alpha_s, beta_s, p, is_K1=0): def BGW_encoding(X, N, T, p): + """ + Perform BGW encoding of input data X. + + Args: + X (numpy.ndarray): The input data of shape (m, d). + N (int): The number of evaluation points. + T (int): The number of terms for encoding. + p (int): The prime modulo. + + Returns: + numpy.ndarray: The BGW encoded data of shape (N, m, d). + """ m = len(X) d = len(X[0]) @@ -76,6 +131,16 @@ def BGW_encoding(X, N, T, p): def gen_BGW_lambda_s(alpha_s, p): + """ + Generate BGW lambda values for polynomial interpolation. + + Args: + alpha_s (numpy.ndarray): The alpha values. + p (int): The prime modulo. + + Returns: + numpy.ndarray: The lambda values. + """ lambda_s = np.zeros((1, len(alpha_s)), dtype="int64") for i in range(len(alpha_s)): @@ -87,7 +152,19 @@ def gen_BGW_lambda_s(alpha_s, p): return lambda_s.astype("int64") -def BGW_decoding(f_eval, worker_idx, p): # decode the output from T+1 evaluation points +def BGW_decoding(f_eval, worker_idx, p): + """ + Decode the output from T+1 evaluation points using BGW decoding. + + Args: + f_eval (numpy.ndarray): The evaluation points of shape (RT, d). + worker_idx (numpy.ndarray): The worker indices of shape (1, RT). + p (int): The prime modulo. + + Returns: + numpy.ndarray: The decoded output of shape (1, d). + """ + # decode the output from T+1 evaluation points # f_eval : [RT X d ] # worker_idx : [ 1 X RT] # output : [ 1 X d ] @@ -109,6 +186,19 @@ def BGW_decoding(f_eval, worker_idx, p): # decode the output from T+1 evaluatio def LCC_encoding(X, N, K, T, p): + """ + Perform LCC encoding of input data X. + + Args: + X (numpy.ndarray): The input data of shape (m, d). + N (int): The number of encoding points. + K (int): The number of systematic points. + T (int): The number of redundant points. + p (int): The prime modulo. + + Returns: + numpy.ndarray: The LCC encoded data of shape (N, m//K, d). + """ m = len(X) d = len(X[0]) # print(m,d,m//K) @@ -135,6 +225,20 @@ def LCC_encoding(X, N, K, T, p): def LCC_encoding_w_Random(X, R_, N, K, T, p): + """ + Perform LCC encoding of input data X with random data R_. + + Args: + X (numpy.ndarray): The input data of shape (m, d). + R_ (numpy.ndarray): Random data of shape (T, m // K, d). + N (int): The number of encoding points. + K (int): The number of systematic points. + T (int): The number of redundant points. + p (int): The prime modulo. + + Returns: + numpy.ndarray: The LCC encoded data of shape (N, m//K, d). + """ m = len(X) d = len(X[0]) # print(m,d,m//K) @@ -165,6 +269,21 @@ def LCC_encoding_w_Random(X, R_, N, K, T, p): def LCC_encoding_w_Random_partial(X, R_, N, K, T, p, worker_idx): + """ + Perform partial LCC encoding of input data X with random data R_ for specific workers. + + Args: + X (numpy.ndarray): The input data of shape (m, d). + R_ (numpy.ndarray): Random data of shape (T, m // K, d). + N (int): The number of encoding points. + K (int): The number of systematic points. 
+ T (int): The number of redundant points. + p (int): The prime modulo. + worker_idx (numpy.ndarray): Worker indices for partial encoding. + + Returns: + numpy.ndarray: The partial LCC encoded data of shape (N_out, m//K, d). + """ m = len(X) d = len(X[0]) # print(m,d,m//K) @@ -193,6 +312,21 @@ def LCC_encoding_w_Random_partial(X, R_, N, K, T, p, worker_idx): def LCC_decoding(f_eval, f_deg, N, K, T, worker_idx, p): + """ + Perform LCC decoding of the given evaluation points and worker indices. + + Args: + f_eval (numpy.ndarray): The evaluation points of shape (RT, d). + f_deg (int): The degree of the polynomial. + N (int): The number of encoding points. + K (int): The number of systematic points. + T (int): The number of redundant points. + worker_idx (numpy.ndarray): Worker indices for decoding. + p (int): The prime modulo. + + Returns: + numpy.ndarray: The decoded output of shape (1, d). + """ # RT_LCC = f_deg * (K + T - 1) + 1 n_beta = K # +T @@ -212,6 +346,17 @@ def LCC_decoding(f_eval, f_deg, N, K, T, worker_idx, p): def Gen_Additive_SS(d, n_out, p): + """ + Generate an additive secret sharing matrix. + + Args: + d (int): The dimension of the secret. + n_out (int): The number of output shares. + p (int): The prime modulo. + + Returns: + numpy.ndarray: The additive secret sharing matrix. + """ # x_model should be one dimension temp = np.random.randint(0, p, size=(n_out - 1, d)) @@ -225,6 +370,18 @@ def Gen_Additive_SS(d, n_out, p): def LCC_encoding_with_points(X, alpha_s, beta_s, p): + """ + Perform LCC encoding of input data X using specific alpha and beta points. + + Args: + X (numpy.ndarray): The input data of shape (m, d). + alpha_s (numpy.ndarray): The alpha points. + beta_s (numpy.ndarray): The beta points. + p (int): The prime modulo. + + Returns: + numpy.ndarray: The LCC encoded data of shape (N, d). + """ m, d = np.shape(X) # print alpha_s @@ -247,6 +404,18 @@ def LCC_encoding_with_points(X, alpha_s, beta_s, p): def LCC_decoding_with_points(f_eval, eval_points, target_points, p): + """ + Perform LCC decoding of the given evaluation points and target points. + + Args: + f_eval (numpy.ndarray): The evaluation points of shape (RT, d). + eval_points (numpy.ndarray): The evaluation points for decoding. + target_points (numpy.ndarray): The target points for decoding. + p (int): The prime modulo. + + Returns: + numpy.ndarray: The decoded output of shape (1, d). + """ alpha_s_eval = eval_points beta_s = target_points @@ -261,6 +430,17 @@ def LCC_decoding_with_points(f_eval, eval_points, target_points, p): def my_pk_gen(my_sk, p, g): + """ + Generate a public key using the private key and prime modulo. + + Args: + my_sk (int): The private key. + p (int): The prime modulo. + g (int): An optional generator. + + Returns: + int: The public key. + """ # print 'my_pk_gen option: g=',g if g == 0: return my_sk @@ -269,6 +449,18 @@ def my_pk_gen(my_sk, p, g): def my_key_agreement(my_sk, u_pk, p, g): + """ + Perform key agreement using private key, public key, prime modulo, and an optional generator. + + Args: + my_sk (int): The private key. + u_pk (int): The other party's public key. + p (int): The prime modulo. + g (int): An optional generator. + + Returns: + int: The shared secret key. 
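+
+    Example:
+        A hedged toy run with small numbers (assumes the non-zero `g` branch
+        performs modular exponentiation, as in classic Diffie-Hellman)::
+
+            p, g = 23, 5
+            a_pk = my_pk_gen(6, p, g)         # 5**6 % 23 == 8
+            b_pk = my_pk_gen(15, p, g)        # 5**15 % 23 == 19
+            my_key_agreement(6, b_pk, p, g)   # == my_key_agreement(15, a_pk, p, g) == 2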
+ """ if g == 0: return np.mod(my_sk * u_pk, p) else: From e5545926ec97fdcef89553f1900483c77e15c4b9 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Wed, 6 Sep 2023 22:42:44 +0530 Subject: [PATCH 06/70] same --- python/fedml/simulation/sp/feddyn/client.py | 54 +++++++++++++ .../sp/feddyn/feddyn_trainer copy.py | 80 +++++++++++++++++++ .../simulation/sp/feddyn/feddyn_trainer.py | 80 +++++++++++++++++++ python/fedml/simulation/sp/fednova/client.py | 73 +++++++++++++++++ .../simulation/sp/fednova/comm_helpers.py | 52 +++++++----- python/fedml/simulation/sp/fednova/fednova.py | 19 ++++- .../simulation/sp/fednova/fednova_api.py | 80 ++++++++++++++++++- .../simulation/sp/fednova/fednova_trainer.py | 59 ++++++++++++++ python/fedml/simulation/sp/fedopt/client.py | 45 +++++++++++ .../fedml/simulation/sp/fedopt/fedopt_api.py | 77 ++++++++++++++++++ python/fedml/simulation/sp/fedopt/optrepo.py | 6 +- python/fedml/simulation/sp/fedprox/client.py | 49 +++++++++++- .../simulation/sp/fedprox/fedprox_trainer.py | 65 +++++++++++++++ 13 files changed, 713 insertions(+), 26 deletions(-) diff --git a/python/fedml/simulation/sp/feddyn/client.py b/python/fedml/simulation/sp/feddyn/client.py index 3b37dd73b1..6fa77ef55a 100644 --- a/python/fedml/simulation/sp/feddyn/client.py +++ b/python/fedml/simulation/sp/feddyn/client.py @@ -4,6 +4,15 @@ def model_parameter_vector(model): + """ + Flatten and concatenate the model parameters into a single vector. + + Args: + model (nn.Module): The PyTorch model. + + Returns: + torch.Tensor: The concatenated parameter vector. + """ param = [p.view(-1) for p in model.parameters()] return torch.concat(param, dim=0) @@ -12,6 +21,18 @@ class Client: def __init__( self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model_trainer, ): + """ + Initialize a client for federated learning. + + Args: + client_idx (int): The index of the client. + local_training_data (torch.utils.data.DataLoader): The local training dataset. + local_test_data (torch.utils.data.DataLoader): The local test dataset. + local_sample_number (int): The number of local samples. + args: The command-line arguments. + device (torch.device): The device (e.g., "cuda" or "cpu") for computation. + model_trainer: The model trainer responsible for training and testing. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -30,6 +51,15 @@ def __init__( def update_local_dataset(self, client_idx, local_training_data, local_test_data, local_sample_number): + """ + Update the client's local dataset. + + Args: + client_idx (int): The index of the client. + local_training_data (torch.utils.data.DataLoader): The new local training dataset. + local_test_data (torch.utils.data.DataLoader): The new local test dataset. + local_sample_number (int): The number of local samples in the new dataset. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -37,9 +67,24 @@ def update_local_dataset(self, client_idx, local_training_data, local_test_data, self.model_trainer.set_id(client_idx) def get_sample_number(self): + """ + Get the number of local samples. + + Returns: + int: The number of local samples. + """ return self.local_sample_number def train(self, w_global): + """ + Train the client's model using global weights. + + Args: + w_global: The global model weights. + + Returns: + dict: The updated client's model weights. 
+ """ self.model_trainer.set_model_params(w_global) self.old_grad = self.model_trainer.train(self.local_training_data, self.device, self.args, self.old_grad) weights = self.model_trainer.get_model_params() @@ -47,6 +92,15 @@ def train(self, w_global): return weights def local_test(self, b_use_test_dataset): + """ + Perform local testing on the client's model. + + Args: + b_use_test_dataset (bool): Whether to use the test dataset for testing. + + Returns: + dict: Test metrics including correctness, loss, and more. + """ if b_use_test_dataset: test_data = self.local_test_data else: diff --git a/python/fedml/simulation/sp/feddyn/feddyn_trainer copy.py b/python/fedml/simulation/sp/feddyn/feddyn_trainer copy.py index 60f50d2f3e..d04f6e07bb 100644 --- a/python/fedml/simulation/sp/feddyn/feddyn_trainer copy.py +++ b/python/fedml/simulation/sp/feddyn/feddyn_trainer copy.py @@ -17,6 +17,15 @@ class FedDynTrainer(object): def __init__(self, dataset, model, device, args): + """ + Initialize the FedDynTrainer. + + Args: + dataset: A tuple containing dataset information. + model: The model to be trained. + device: The device to run the training on (e.g., 'cpu' or 'cuda'). + args: Additional training configuration and hyperparameters. + """ self.device = device self.args = args [ @@ -59,6 +68,15 @@ def __init__(self, dataset, model, device, args): def _setup_clients( self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict, model_trainer, ): + """ + Set up client instances for training. + + Args: + train_data_local_num_dict: A dictionary containing the number of samples for each local training dataset. + train_data_local_dict: A dictionary containing local training datasets. + test_data_local_dict: A dictionary containing local test datasets. + model_trainer: The model trainer instance. + """ logging.info("############setup_clients (START)#############") if self.args.initialize_all_clients: num_initialized_clients = self.args.client_num_in_total @@ -78,6 +96,11 @@ def _setup_clients( logging.info("############setup_clients (END)#############") def train(self): + """ + Train the federated dynamic model using FedDyn. + + This method performs the federated training loop, including client selection, training, aggregation, and testing. + """ logging.info("self.model_trainer = {}".format(self.model_trainer)) w_global = self.model_trainer.get_model_params() mlops.log_training_status(mlops.ClientConstants.MSG_MLOPS_CLIENT_STATUS_TRAINING) @@ -175,6 +198,17 @@ def train(self): mlops.log_aggregation_finished_status() def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Select a subset of clients for communication in each round. + + Args: + round_idx: The current communication round index. + client_num_in_total: The total number of clients. + client_num_per_round: The number of clients to select in each round. + + Returns: + A list of selected client indexes. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -185,6 +219,12 @@ def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation dataset from the test dataset. + + Args: + num_samples: The number of samples to include in the validation dataset. 
+ """ test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) @@ -192,11 +232,31 @@ def _generate_validation_set(self, num_samples=10000): self.val_global = sample_testset def _aggregate(self, w_locals): + """ + Aggregate local model weights from all clients. + + Args: + w_locals: A list of tuples containing the number of samples and local model weights for each client. + + Returns: + The aggregated global model weights. + """ avg_params = FedMLAggOperator.agg(self.args, w_locals) return avg_params def _test(self, test_data, device, args): + """ + Perform testing on the test dataset. + + Args: + test_data: The test dataset. + device: The device to run the testing on (e.g., 'cpu' or 'cuda'). + args: Additional testing configuration and hyperparameters. + + Returns: + A dictionary containing testing metrics (e.g., test accuracy, test loss). + """ model = self.model_trainer.model model.to(device) @@ -222,6 +282,14 @@ def _test(self, test_data, device, args): return metrics def test(self, test_data, device, args): + """ + Perform testing on the test dataset and log testing metrics. + + Args: + test_data: The test dataset. + device: The device to run the testing on (e.g., 'cpu' or 'cuda'). + args: Additional testing configuration and hyperparameters. + """ # test data test_num_samples = [] test_tot_corrects = [] @@ -253,6 +321,12 @@ def test(self, test_data, device, args): def _local_test_on_all_clients(self, round_idx): + """ + Perform local testing on all clients and log the results. + + Args: + round_idx: The current communication round index. + """ logging.info("################local_test_on_all_clients : {}".format(round_idx)) @@ -314,6 +388,12 @@ def _local_test_on_all_clients(self, round_idx): logging.info(stats) def _local_test_on_validation_set(self, round_idx): + """ + Perform local testing on the validation set for all clients. + + Args: + round_idx: The current communication round index. + """ logging.info("################local_test_on_validation_set : {}".format(round_idx)) diff --git a/python/fedml/simulation/sp/feddyn/feddyn_trainer.py b/python/fedml/simulation/sp/feddyn/feddyn_trainer.py index 42b3eaa79c..c25251586f 100644 --- a/python/fedml/simulation/sp/feddyn/feddyn_trainer.py +++ b/python/fedml/simulation/sp/feddyn/feddyn_trainer.py @@ -17,6 +17,15 @@ class FedDynTrainer(object): def __init__(self, dataset, model, device, args): + """ + Initialize the FedDynTrainer. + + Args: + dataset: A tuple containing dataset information. + model: The model to be trained. + device: The device to run the training on (e.g., 'cpu' or 'cuda'). + args: Additional training configuration and hyperparameters. + """ self.device = device self.args = args [ @@ -59,6 +68,15 @@ def __init__(self, dataset, model, device, args): def _setup_clients( self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict, model_trainer, ): + """ + Set up client instances for training. + + Args: + train_data_local_num_dict: A dictionary containing the number of samples for each local training dataset. + train_data_local_dict: A dictionary containing local training datasets. + test_data_local_dict: A dictionary containing local test datasets. + model_trainer: The model trainer instance. 
+ """ logging.info("############setup_clients (START)#############") if self.args.initialize_all_clients: num_initialized_clients = self.args.client_num_in_total @@ -78,6 +96,11 @@ def _setup_clients( logging.info("############setup_clients (END)#############") def train(self): + """ + Train the federated dynamic model using FedDyn. + + This method performs the federated training loop, including client selection, training, aggregation, and testing. + """ logging.info("self.model_trainer = {}".format(self.model_trainer)) w_global = self.model_trainer.get_model_params() mlops.log_training_status(mlops.ClientConstants.MSG_MLOPS_CLIENT_STATUS_TRAINING) @@ -170,6 +193,17 @@ def train(self): mlops.log_aggregation_finished_status() def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Select a subset of clients for communication in each round. + + Args: + round_idx: The current communication round index. + client_num_in_total: The total number of clients. + client_num_per_round: The number of clients to select in each round. + + Returns: + A list of selected client indexes. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -180,6 +214,12 @@ def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation dataset from the test dataset. + + Args: + num_samples: The number of samples to include in the validation dataset. + """ test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) @@ -187,11 +227,31 @@ def _generate_validation_set(self, num_samples=10000): self.val_global = sample_testset def _aggregate(self, w_locals): + """ + Aggregate local model weights from all clients. + + Args: + w_locals: A list of tuples containing the number of samples and local model weights for each client. + + Returns: + The aggregated global model weights. + """ avg_params = FedMLAggOperator.agg(self.args, w_locals) return avg_params def _test(self, test_data, device, args): + """ + Perform testing on the test dataset. + + Args: + test_data: The test dataset. + device: The device to run the testing on (e.g., 'cpu' or 'cuda'). + args: Additional testing configuration and hyperparameters. + + Returns: + A dictionary containing testing metrics (e.g., test accuracy, test loss). + """ model = self.model_trainer.model model.to(device) @@ -217,6 +277,14 @@ def _test(self, test_data, device, args): return metrics def test(self, test_data, device, args): + """ + Perform testing on the test dataset and log testing metrics. + + Args: + test_data: The test dataset. + device: The device to run the testing on (e.g., 'cpu' or 'cuda'). + args: Additional testing configuration and hyperparameters. + """ # test data test_num_samples = [] test_tot_corrects = [] @@ -248,6 +316,12 @@ def test(self, test_data, device, args): def _local_test_on_all_clients(self, round_idx): + """ + Perform local testing on all clients and log the results. + + Args: + round_idx: The current communication round index. 
+ """ logging.info("################local_test_on_all_clients : {}".format(round_idx)) @@ -309,6 +383,12 @@ def _local_test_on_all_clients(self, round_idx): logging.info(stats) def _local_test_on_validation_set(self, round_idx): + """ + Perform local testing on the validation set for all clients. + + Args: + round_idx: The current communication round index. + """ logging.info("################local_test_on_validation_set : {}".format(round_idx)) diff --git a/python/fedml/simulation/sp/fednova/client.py b/python/fedml/simulation/sp/fednova/client.py index 32703acc87..0cbd2b2762 100644 --- a/python/fedml/simulation/sp/fednova/client.py +++ b/python/fedml/simulation/sp/fednova/client.py @@ -16,6 +16,17 @@ def __init__( args, device, ): + """ + Initialize a client instance. + + Args: + client_idx (int): The index of the client. + local_training_data: The local training data for this client. + local_test_data: The local test data for this client. + local_sample_number: The number of samples in the local training data. + args: Command-line arguments. + device: The device (e.g., "cpu" or "cuda") on which to perform computations. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -39,15 +50,42 @@ def __init__( def update_local_dataset( self, client_idx, local_training_data, local_test_data, local_sample_number ): + """ + Update the local datasets for the client. + + Args: + client_idx (int): The index of the client. + local_training_data: The new local training data. + local_test_data: The new local test data. + local_sample_number: The number of samples in the new local training data. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data self.local_sample_number = local_sample_number def get_sample_number(self): + """ + Get the number of samples in the local training data. + + Returns: + int: The number of samples in the local training data. + """ return self.local_sample_number def get_local_norm_grad(self, opt, cur_params, init_params, weight=0): + """ + Calculate the local normalized gradient. + + Args: + opt: The FedNova optimizer. + cur_params: The current parameters of the model. + init_params: The initial parameters of the model. + weight (float): Weight factor for the gradient calculation. + + Returns: + dict: A dictionary containing the local normalized gradients. + """ if weight == 0: weight = opt.ratio grad_dict = {} @@ -59,12 +97,27 @@ def get_local_norm_grad(self, opt, cur_params, init_params, weight=0): return grad_dict def get_local_tau_eff(self, opt): + """ + Calculate the local effective tau. + + Args: + opt: The FedNova optimizer. + + Returns: + float: The local effective tau. + """ if opt.mu != 0: return opt.local_steps * opt.ratio else: return opt.local_normalizing_vec * opt.ratio def reset_fednova_optimizer(self, opt): + """ + Reset the FedNova optimizer state for the client. + + Args: + opt: The FedNova optimizer. + """ opt.local_counter = 0 opt.local_normalizing_vec = 0 opt.local_steps = 0 @@ -77,6 +130,16 @@ def reset_fednova_optimizer(self, opt): param_state["momentum_buffer"].zero_() def train(self, net, ratio): + """ + Train the model on the local training data. + + Args: + net: The neural network model. + ratio: The ratio used in training. + + Returns: + tuple: A tuple containing the loss, gradients, and effective tau. 
+ """ net.train() # train and update init_params = copy.deepcopy(net.state_dict()) @@ -120,6 +183,16 @@ def train(self, net, ratio): return sum(epoch_loss) / len(epoch_loss), norm_grad, tau_eff def local_test(self, model_global, b_use_test_dataset=False): + """ + Evaluate the performance of the global model on the local test or training dataset. + + Args: + model_global: The global model to evaluate. + b_use_test_dataset (bool): Whether to use the local test dataset. If False, uses the local training dataset. + + Returns: + dict: A dictionary containing evaluation metrics, including accuracy, loss, precision, recall, and total samples. + """ model_global.eval() model_global.to(self.device) metrics = { diff --git a/python/fedml/simulation/sp/fednova/comm_helpers.py b/python/fedml/simulation/sp/fednova/comm_helpers.py index 94b2dfa1da..2ff37fcfd2 100644 --- a/python/fedml/simulation/sp/fednova/comm_helpers.py +++ b/python/fedml/simulation/sp/fednova/comm_helpers.py @@ -7,16 +7,21 @@ def flatten_tensors(tensors): """ + Flatten a list of dense tensors into a contiguous 1D buffer. + Reference: https://github.com/facebookresearch/stochastic_gradient_push - Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of - same dense type. + This function takes a list of dense tensors and flattens them into a single + contiguous 1D buffer. It assumes that all input tensors are of the same dense type. + Since inputs are dense, the resulting tensor will be a concatenated 1D buffer. Element-wise operation on this buffer will be equivalent to operating individually. - Arguments: - tensors (Iterable[Tensor]): dense tensors to flatten. + + Args: + tensors (Iterable[Tensor]): The list of dense tensors to flatten. + Returns: - A 1D buffer containing input tensors. + Tensor: A 1D buffer containing the flattened input tensors. """ if len(tensors) == 1: return tensors[0].view(-1).clone() @@ -27,15 +32,19 @@ def flatten_tensors(tensors): def unflatten_tensors(flat, tensors): """ Reference: https://github.com/facebookresearch/stochastic_gradient_push - View a flat buffer using the sizes of tensors. Assume that tensors are of - same dense type, and that flat is given by flatten_dense_tensors. - Arguments: - flat (Tensor): flattened dense tensors to unflatten. - tensors (Iterable[Tensor]): dense tensors whose sizes will be used to - unflatten flat. + Unflatten a flat buffer into a list of tensors using their original sizes. + + This function takes a flat buffer and unflattens it into a list of tensors using + the sizes of the original tensors. It assumes that all input tensors are of the + same dense type and that the flat buffer was generated using `flatten_tensors`. + + Args: + flat (Tensor): The flattened dense tensors to unflatten. + tensors (Iterable[Tensor]): The dense tensors whose sizes will be used to + unflatten the flat buffer. + Returns: - Unflattened dense tensors with sizes same as tensors and values from - flat. + tuple: Unflattened dense tensors with sizes same as `tensors` and values from `flat`. """ outputs = [] offset = 0 @@ -48,13 +57,18 @@ def unflatten_tensors(flat, tensors): def communicate(tensors, communication_op): """ + Communicate a list of tensors using a specified communication operation. + Reference: https://github.com/facebookresearch/stochastic_gradient_push - Communicate a list of tensors. - Arguments: - tensors (Iterable[Tensor]): list of tensors. - communication_op: a method or partial object which takes a tensor as - input and communicates it. 
It can be a partial object around - something like torch.distributed.all_reduce. + This function takes a list of tensors and communicates them using a specified + communication operation. It assumes that the communication_op can handle the + provided tensors appropriately, such as performing an all-reduce operation. + + Args: + tensors (Iterable[Tensor]): List of tensors to be communicated. + communication_op: A method or partial object which takes a tensor as input + and communicates it. It can be a partial object around something like + `torch.distributed.all_reduce`. """ flat_tensor = flatten_tensors(tensors) communication_op(tensor=flat_tensor) diff --git a/python/fedml/simulation/sp/fednova/fednova.py b/python/fedml/simulation/sp/fednova/fednova.py index 63e4662e60..a1590972c7 100644 --- a/python/fedml/simulation/sp/fednova/fednova.py +++ b/python/fedml/simulation/sp/fednova/fednova.py @@ -94,10 +94,15 @@ def __setstate__(self, state): group.setdefault("nesterov", False) def step(self, closure=None): - """Performs a single optimization step. - Arguments: + """ + Performs a single optimization step. + + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. + + Returns: + loss: The loss after the optimization step. """ loss = None @@ -169,6 +174,16 @@ def step(self, closure=None): return loss def average(self, weight=0, tau_eff=0): + """ + Averages accumulated local gradients across clients. + + Args: + weight (float, optional): Weight factor for averaging (default: 0). + tau_eff (float, optional): Effective tau value (default: 0). + + Returns: + None + """ if weight == 0: weight = self.ratio if tau_eff == 0: diff --git a/python/fedml/simulation/sp/fednova/fednova_api.py b/python/fedml/simulation/sp/fednova/fednova_api.py index 543370c9d6..214c057be0 100644 --- a/python/fedml/simulation/sp/fednova/fednova_api.py +++ b/python/fedml/simulation/sp/fednova/fednova_api.py @@ -12,6 +12,15 @@ class FedAvgAPI(object): def __init__(self, args, device, dataset, model): + """ + Initialize the FedAvgAPI. + + Args: + args (object): Arguments object containing configuration settings. + device (object): Device on which to perform computations. + dataset (list): List containing dataset information. + model (object): Machine learning model. + """ self.device = device self.args = args [ @@ -46,6 +55,15 @@ def __init__(self, args, device, dataset, model): def _setup_clients( self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict, model_trainer, ): + """ + Set up the clients for federated training. + + Args: + train_data_local_num_dict (dict): Dictionary containing the number of local training samples for each client. + train_data_local_dict (dict): Dictionary containing local training data for each client. + test_data_local_dict (dict): Dictionary containing local test data for each client. + model_trainer (object): Model trainer object. + """ logging.info("############setup_clients (START)#############") for client_idx in range(self.args.client_num_per_round): c = Client( @@ -60,6 +78,9 @@ def _setup_clients( logging.info("############setup_clients (END)#############") def train(self): + """ + Perform federated training using the FedAvg algorithm. 
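+
+        Example:
+            A hedged end-to-end sketch (the `args`, `device`, `dataset`
+            and `model` objects are assumed to be prepared by the caller)::
+
+                api = FedAvgAPI(args, device, dataset, model)
+                api.train()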
+ """ logging.info("self.model_trainer = {}".format(self.model_trainer)) w_global = self.model_trainer.get_model_params() for round_idx in range(self.args.comm_round): @@ -108,6 +129,17 @@ def train(self): self._local_test_on_all_clients(round_idx) def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Sample a subset of clients for federated training. + + Args: + round_idx (int): Index of the communication round. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to select for the current round. + + Returns: + list: List of selected client indexes. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -118,6 +150,16 @@ def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set by randomly sampling from the global test dataset. + + Args: + num_samples (int, optional): Number of samples to include in the validation set. Default is 10,000. + + Note: + This function samples `num_samples` from the global test dataset and stores it as the validation set (`self.val_global`). + + """ test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) @@ -125,6 +167,16 @@ def _generate_validation_set(self, num_samples=10000): self.val_global = sample_testset def _aggregate(self, w_locals): + """ + Aggregate local model parameters weighted by the number of samples. + + Args: + w_locals (list): List of tuples, where each tuple contains the number of local samples and local model parameters. + + Returns: + dict: Averaged global model parameters. + + """ training_num = 0 for idx in range(len(w_locals)): (sample_num, averaged_params) = w_locals[idx] @@ -143,10 +195,17 @@ def _aggregate(self, w_locals): def _aggregate_noniid_avg(self, w_locals): """ - The old aggregate method will impact the model performance when it comes to Non-IID setting + Aggregate local model parameters using a simple average, suitable for Non-IID settings. + Args: - w_locals: + w_locals (list): List of tuples, where each tuple contains the number of local samples and local model parameters. + Returns: + dict: Averaged global model parameters. + + Note: + In Non-IID settings, where the data distribution among clients is not identical, a simple average of local model parameters may be used for aggregation. This method averages the model parameters across clients for each parameter independently. + """ (_, averaged_params) = w_locals[0] for k in averaged_params.keys(): @@ -157,6 +216,16 @@ def _aggregate_noniid_avg(self, w_locals): return averaged_params def _local_test_on_all_clients(self, round_idx): + """ + Perform local testing on all clients and log the results. + + Args: + round_idx (int): The current communication round index. + + Note: + This function iterates over all clients and performs testing on both training and test datasets. It then logs the training and test accuracy along with losses. 
+ + """ logging.info("################local_test_on_all_clients : {}".format(round_idx)) @@ -212,6 +281,13 @@ def _local_test_on_all_clients(self, round_idx): logging.info(stats) def _local_test_on_validation_set(self, round_idx): + """ + Perform local testing on the validation set and log the results. + + Args: + round_idx (int): The current communication round index. + + """ logging.info("################local_test_on_validation_set : {}".format(round_idx)) diff --git a/python/fedml/simulation/sp/fednova/fednova_trainer.py b/python/fedml/simulation/sp/fednova/fednova_trainer.py index bbf72182ef..539a1e47d1 100644 --- a/python/fedml/simulation/sp/fednova/fednova_trainer.py +++ b/python/fedml/simulation/sp/fednova/fednova_trainer.py @@ -10,6 +10,15 @@ class FedNovaTrainer(object): def __init__(self, dataset, model, device, args): + """ + Initialize the FedNovaTrainer. + + Args: + dataset (tuple): A tuple containing dataset information. + model (torch.nn.Module): The global model to be trained. + device (torch.device): The target device for model training. + args (argparse.Namespace): Command-line arguments. + """ self.device = device self.args = args [ @@ -41,6 +50,17 @@ def __init__(self, dataset, model, device, args): def setup_clients( self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict ): + """ + Set up client instances for federated training. + + Args: + train_data_local_num_dict (dict): Dictionary containing local training data sizes. + train_data_local_dict (dict): Dictionary containing local training data. + test_data_local_dict (dict): Dictionary containing local test data. + + Returns: + None + """ logging.info("############setup_clients (START)#############") for client_idx in range(self.args.client_num_per_round): c = Client( @@ -55,6 +75,17 @@ def setup_clients( logging.info("############setup_clients (END)#############") def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Perform client sampling for federated training. + + Args: + round_idx (int): The current communication round index. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to sample in each round. + + Returns: + list: List of client indexes selected for the current round. + """ if client_num_in_total == client_num_per_round: client_indexes = [ client_index for client_index in range(client_num_in_total) @@ -71,6 +102,12 @@ def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): return client_indexes def train(self): + """ + Perform federated training using FedNova optimizer. + + Returns: + None + """ for round_idx in range(self.args.comm_round): logging.info("################Communication round : {}".format(round_idx)) @@ -134,6 +171,18 @@ def train(self): self.local_test_on_all_clients(self.model_global, round_idx) def aggregate(self, params, norm_grads, tau_effs, tau_eff=0): + """ + Aggregate local gradients and update global model parameters. + + Args: + params (dict): Dictionary containing global model parameters. + norm_grads (list of dict): List of dictionaries containing normalized local gradients. + tau_effs (list): List of effective tau values for each client. + tau_eff (float): Effective tau value (optional). + + Returns: + dict: Updated global model parameters. 
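+
+        Example:
+            A hedged sketch (`params` are the current global weights; the
+            other inputs come from `Client.train`)::
+
+                new_params = trainer.aggregate(params, norm_grads, tau_effs)
+
+            When `tau_eff` is left at 0 it defaults to `sum(tau_effs)`.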
+ """ # get tau_eff if tau_eff == 0: tau_eff = sum(tau_effs) @@ -164,6 +213,16 @@ def aggregate(self, params, norm_grads, tau_effs, tau_eff=0): return params def local_test_on_all_clients(self, model_global, round_idx): + """ + Perform local testing on all clients and log results. + + Args: + model (torch.nn.Module): The global model for testing. + round_idx (int): The current communication round index. + + Returns: + None + """ logging.info("################local_test_on_all_clients : {}".format(round_idx)) train_metrics = { "num_samples": [], diff --git a/python/fedml/simulation/sp/fedopt/client.py b/python/fedml/simulation/sp/fedopt/client.py index 993634a74f..856749a9b2 100644 --- a/python/fedml/simulation/sp/fedopt/client.py +++ b/python/fedml/simulation/sp/fedopt/client.py @@ -5,6 +5,18 @@ class Client: def __init__( self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model_trainer, ): + """ + Initialize a client in the federated learning system. + + Args: + client_idx (int): The unique identifier for this client. + local_training_data (torch.Dataset): The local training dataset for this client. + local_test_data (torch.Dataset): The local test dataset for this client. + local_sample_number (int): The number of samples in the local training dataset. + args: Additional arguments and settings. + device: The device (e.g., CPU or GPU) on which to perform computations. + model_trainer: The model trainer responsible for training and testing. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -16,21 +28,54 @@ def __init__( self.model_trainer = model_trainer def update_local_dataset(self, client_idx, local_training_data, local_test_data, local_sample_number): + """ + Update the local dataset for this client. + + Args: + client_idx (int): The unique identifier for this client. + local_training_data (torch.Dataset): The new local training dataset. + local_test_data (torch.Dataset): The new local test dataset. + local_sample_number (int): The number of samples in the new local training dataset. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data self.local_sample_number = local_sample_number def get_sample_number(self): + """ + Get the number of samples in the local training dataset. + + Returns: + int: The number of samples in the local training dataset. + """ return self.local_sample_number def train(self, w_global): + """ + Train the client's local model. + + Args: + w_global: The global model weights. + + Returns: + weights: The updated local model weights. + """ self.model_trainer.set_model_params(w_global) self.model_trainer.train(self.local_training_data, self.device, self.args) weights = self.model_trainer.get_model_params() return weights def local_test(self, b_use_test_dataset): + """ + Perform local testing using either the local test dataset or local training dataset. + + Args: + b_use_test_dataset (bool): If True, use the local test dataset for testing. Otherwise, use the local training dataset. + + Returns: + metrics: The evaluation metrics obtained during testing. 
+ """ if b_use_test_dataset: test_data = self.local_test_data else: diff --git a/python/fedml/simulation/sp/fedopt/fedopt_api.py b/python/fedml/simulation/sp/fedopt/fedopt_api.py index 8b0dd9e457..29cb6de340 100644 --- a/python/fedml/simulation/sp/fedopt/fedopt_api.py +++ b/python/fedml/simulation/sp/fedopt/fedopt_api.py @@ -12,6 +12,18 @@ class FedOptAPI(object): + """ + Base class for Federated Optimization. + + This class provides the foundation for federated optimization techniques. It sets up clients, + handles client sampling, and manages the global model and optimizer. + + Args: + args (object): Arguments containing configuration options. + device (str): Device (e.g., 'cpu' or 'cuda') to run computations on. + dataset (tuple): A tuple containing dataset information. + model (torch.nn.Module): The global model used for federated optimization. + """ def __init__(self, args, device, dataset, model): self.device = device self.args = args @@ -44,6 +56,14 @@ def __init__(self, args, device, dataset, model): self._setup_clients(train_data_local_num_dict, train_data_local_dict, test_data_local_dict) def _setup_clients(self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict): + """ + Set up client instances for federated optimization. + + Args: + train_data_local_num_dict (dict): A dictionary mapping client indices to the number of local training samples. + train_data_local_dict (dict): A dictionary mapping client indices to their local training datasets. + test_data_local_dict (dict): A dictionary mapping client indices to their local test datasets. + """ logging.info("############setup_clients (START)#############") for client_idx in range(self.args.client_num_per_round): c = Client( @@ -59,6 +79,17 @@ def _setup_clients(self, train_data_local_num_dict, train_data_local_dict, test_ logging.info("############setup_clients (END)#############") def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Sample a set of clients for a communication round. + + Args: + round_idx (int): The current communication round index. + client_num_in_total (int): Total number of clients in the system. + client_num_per_round (int): Number of clients to sample for the current round. + + Returns: + List[int]: A list of sampled client indices. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -69,6 +100,15 @@ def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set from the global test dataset. + + Args: + num_samples (int): Number of samples to include in the validation set. + + Notes: + This method updates the `val_global` attribute with the generated validation set. + """ test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) @@ -76,6 +116,11 @@ def _generate_validation_set(self, num_samples=10000): self.val_global = sample/home/chaoyanghe/zhtang_FedML/python/fedml/simulation/sp/fedopt/__pycache___testset def _instanciate_opt(self): + """ + Initialize the server optimizer. + + This method initializes the server optimizer based on the specified server optimizer type and learning rate. 
+ """ self.opt = OptRepo.name2cls(self.args.server_optimizer)( # self.model_global.parameters(), lr=self.args.server_lr self.model_trainer.model.parameters(), @@ -85,6 +130,11 @@ def _instanciate_opt(self): ) def train(self): + """ + Train the global model using federated optimization. + + This method trains the global model using federated optimization over multiple communication rounds. + """ for round_idx in range(self.args.comm_round): w_global = self.model_trainer.get_model_params() logging.info("################ Communication round : {}".format(round_idx)) @@ -141,6 +191,15 @@ def train(self): self._local_test_on_all_clients(round_idx) def _aggregate(self, w_locals): + """ + Aggregate local model weights to compute global model weights. + + Args: + w_locals (list): A list of tuples containing local sample numbers and local model weights. + + Returns: + dict: A dictionary containing aggregated global model weights. + """ training_num = 0 for idx in range(len(w_locals)): (sample_num, averaged_params) = w_locals[idx] @@ -158,6 +217,12 @@ def _aggregate(self, w_locals): return averaged_params def _set_model_global_grads(self, new_state): + """ + Set the gradients of the global model based on the difference between new and current model states. + + Args: + new_state (dict): The new state of the global model. + """ new_model = copy.deepcopy(self.model_trainer.model) new_model.load_state_dict(new_state) with torch.no_grad(): @@ -171,6 +236,12 @@ def _set_model_global_grads(self, new_state): self.model_trainer.set_model_params(new_model_state_dict) def _local_test_on_all_clients(self, round_idx): + """ + Perform local testing on all clients. + + Args: + round_idx (int): The current communication round index. + """ logging.info("################local_test_on_all_clients : {}".format(round_idx)) train_metrics = {"num_samples": [], "num_correct": [], "losses": []} @@ -231,6 +302,12 @@ def _local_test_on_all_clients(self, round_idx): logging.info(stats) def _local_test_on_validation_set(self, round_idx): + """ + Perform local testing on the validation set. + + Args: + round_idx (int): The current communication round index. + """ logging.info("################local_test_on_validation_set : {}".format(round_idx)) if self.val_global is None: diff --git a/python/fedml/simulation/sp/fedopt/optrepo.py b/python/fedml/simulation/sp/fedopt/optrepo.py index 50615227d7..2a82c60e1e 100644 --- a/python/fedml/simulation/sp/fedopt/optrepo.py +++ b/python/fedml/simulation/sp/fedopt/optrepo.py @@ -5,8 +5,12 @@ class OptRepo: - """Collects and provides information about the subclasses of torch.optim.Optimizer.""" + """ + Collects and provides information about the subclasses of torch.optim.Optimizer. + This class allows you to retrieve optimizer classes by name and obtain information about supported optimizers. + """ + repo = {x.__name__.lower(): x for x in torch.optim.Optimizer.__subclasses__()} @classmethod diff --git a/python/fedml/simulation/sp/fedprox/client.py b/python/fedml/simulation/sp/fedprox/client.py index cc74a9d932..ff669658bd 100644 --- a/python/fedml/simulation/sp/fedprox/client.py +++ b/python/fedml/simulation/sp/fedprox/client.py @@ -1,17 +1,38 @@ class Client: + """ + Represents a federated learning client. + + Args: + client_idx (int): Index of the client. + local_training_data (Dataset): Local training dataset for the client. + local_test_data (Dataset): Local test dataset for the client. + local_sample_number (int): Number of local training samples. 
+ args (argparse.Namespace): Command-line arguments. + device (torch.device): Device for training (e.g., "cpu" or "cuda"). + model_trainer (ModelTrainer): Trainer for the client's model. + """ + def __init__( - self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model_trainer, + self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model_trainer ): self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data self.local_sample_number = local_sample_number - self.args = args self.device = device self.model_trainer = model_trainer def update_local_dataset(self, client_idx, local_training_data, local_test_data, local_sample_number): + """ + Update the local dataset for the client. + + Args: + client_idx (int): Index of the client. + local_training_data (Dataset): New local training dataset. + local_test_data (Dataset): New local test dataset. + local_sample_number (int): Number of local training samples. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -19,15 +40,39 @@ def update_local_dataset(self, client_idx, local_training_data, local_test_data, self.model_trainer.set_id(client_idx) def get_sample_number(self): + """ + Get the number of local training samples. + + Returns: + int: Number of local training samples. + """ return self.local_sample_number def train(self, w_global): + """ + Train the client's model using the global model weights. + + Args: + w_global (dict): Global model weights. + + Returns: + dict: Updated client model weights. + """ self.model_trainer.set_model_params(w_global) self.model_trainer.train(self.local_training_data, self.device, self.args) weights = self.model_trainer.get_model_params() return weights def local_test(self, b_use_test_dataset): + """ + Test the client's model on the local test dataset. + + Args: + b_use_test_dataset (bool): Flag to indicate whether to use the test dataset. + + Returns: + dict: Evaluation metrics. + """ if b_use_test_dataset: test_data = self.local_test_data else: diff --git a/python/fedml/simulation/sp/fedprox/fedprox_trainer.py b/python/fedml/simulation/sp/fedprox/fedprox_trainer.py index df333b1f7a..0473e70ec7 100644 --- a/python/fedml/simulation/sp/fedprox/fedprox_trainer.py +++ b/python/fedml/simulation/sp/fedprox/fedprox_trainer.py @@ -15,6 +15,16 @@ class FedProxTrainer(object): + """ + Federated Proximal Trainer for a federated learning model. + + Args: + dataset (list): A list containing various dataset components. + model (nn.Module): The federated learning model. + device (torch.device): Device for training (e.g., "cpu" or "cuda"). + args (argparse.Namespace): Command-line arguments. + """ + def __init__(self, dataset, model, device, args): self.device = device self.args = args @@ -51,6 +61,15 @@ def __init__(self, dataset, model, device, args): def _setup_clients( self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict, model_trainer, ): + """ + Set up federated clients. + + Args: + train_data_local_num_dict (dict): Number of local training samples for each client. + train_data_local_dict (dict): Local training datasets for clients. + test_data_local_dict (dict): Local test datasets for clients. + model_trainer (ModelTrainer): Trainer for the client's model. 
+ """ logging.info("############setup_clients (START)#############") for client_idx in range(self.args.client_num_per_round): c = Client( @@ -66,6 +85,14 @@ def _setup_clients( logging.info("############setup_clients (END)#############") def train(self): + """ + Train the federated model using federated learning. + + This method performs federated learning by aggregating client updates. + + Returns: + None + """ logging.info("self.model_trainer = {}".format(self.model_trainer)) w_global = self.model_trainer.get_model_params() mlops.log_training_status(mlops.ClientConstants.MSG_MLOPS_CLIENT_STATUS_TRAINING) @@ -126,6 +153,17 @@ def train(self): mlops.log_aggregation_finished_status() def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Sample a subset of clients for communication in a round. + + Args: + round_idx (int): Index of the communication round. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to sample per round. + + Returns: + list: List of sampled client indexes. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -136,6 +174,12 @@ def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation dataset from the test dataset. + + Args: + num_samples (int): Number of samples to include in the validation set (default is 10,000). + """ test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) @@ -143,11 +187,26 @@ def _generate_validation_set(self, num_samples=10000): self.val_global = sample_testset def _aggregate(self, w_locals): + """ + Aggregate local model weights from multiple clients. + + Args: + w_locals (list): List of local model weights. + + Returns: + dict: Averaged global model weights. + """ avg_params = FedMLAggOperator.agg(self.args, w_locals) return avg_params def _local_test_on_all_clients(self, round_idx): + """ + Perform local testing on all clients in the federation. + + Args: + round_idx (int): Index of the communication round. + """ logging.info("################local_test_on_all_clients : {}".format(round_idx)) @@ -209,6 +268,12 @@ def _local_test_on_all_clients(self, round_idx): logging.info(stats) def _local_test_on_validation_set(self, round_idx): + """ + Perform local testing on all clients on validation set. + + Args: + round_idx (int): Index of the communication round. 
+ """ logging.info("################local_test_on_validation_set : {}".format(round_idx)) From 01381ee8b16dfdf48adc44ce8fabd34621754a46 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Thu, 7 Sep 2023 11:29:08 +0530 Subject: [PATCH 07/70] `python\fedml\utils\ `update --- python/fedml/__init__.py | 124 ++++++- python/fedml/launch_cheeath.py | 12 +- python/fedml/launch_cross_device.py | 5 +- python/fedml/launch_cross_silo_hi.py | 12 +- python/fedml/launch_cross_silo_horizontal.py | 12 +- python/fedml/launch_serving.py | 12 +- python/fedml/utils/compression.py | 326 ++++++++++++++++++- python/fedml/utils/context.py | 50 +++ python/fedml/utils/logging.py | 13 +- python/fedml/utils/model_utils.py | 133 +++++++- 10 files changed, 654 insertions(+), 45 deletions(-) diff --git a/python/fedml/__init__.py b/python/fedml/__init__.py index 18e2ec0504..99382964bb 100644 --- a/python/fedml/__init__.py +++ b/python/fedml/__init__.py @@ -32,7 +32,17 @@ def init(args=None, check_env=True, should_init_logs=True): if args is None: args = load_arguments(fedml._global_training_type, fedml._global_comm_backend) - """Initialize FedML Engine.""" + """ + Initialize the FedML Engine. + + Args: + args (argparse.Namespace, optional): Command-line arguments. Defaults to None. + check_env (bool, optional): Whether to check the environment. Defaults to True. + should_init_logs (bool, optional): Whether to initialize logs. Defaults to True. + + Returns: + argparse.Namespace: Updated command-line arguments. + """ if check_env: collect_env(args) @@ -120,6 +130,12 @@ def init(args=None, check_env=True, should_init_logs=True): def print_args(args): + """ + Print the arguments to the log, excluding sensitive paths. + + Args: + args (argparse.Namespace): Command-line arguments. + """ mqtt_config_path = None s3_config_path = None args_copy = args @@ -138,7 +154,9 @@ def print_args(args): def update_client_specific_args(args): """ - data_silo_config is used for reading specific configuration for each client + Update client-specific arguments based on data_silo_config. + + data_silo_config is used for reading specific configuration for each client Example: In fedml_config.yaml, we have the following configuration client_specific_args: data_silo_config: @@ -149,6 +167,9 @@ def update_client_specific_args(args): fedml_config/data_silo_4_config.yaml, ] data_silo_1_config.yaml contains some client client speicifc arguments. + + Args: + args (argparse.Namespace): Command-line arguments. """ if ( hasattr(args, "data_silo_config") @@ -166,7 +187,17 @@ def update_client_specific_args(args): def init_simulation_mpi(args): + from mpi4py import MPI + """ + Initialize MPI-based simulation. + + Args: + args (argparse.Namespace): Command-line arguments. + + Returns: + argparse.Namespace: Updated command-line arguments. + """ comm = MPI.COMM_WORLD process_id = comm.Get_rank() @@ -183,14 +214,35 @@ def init_simulation_mpi(args): def init_simulation_sp(args): + """ + Initialize single-process simulation. + + Args: + args (argparse.Namespace): Command-line arguments. + + Returns: + argparse.Namespace: Updated command-line arguments. + """ return args def init_simulation_nccl(args): + """ + Initialize NCCL-based simulation. + + Args: + args (argparse.Namespace): Command-line arguments. + """ return def manage_profiling_args(args): + """ + Manage profiling-related arguments and configurations. + + Args: + args (argparse.Namespace): Command-line arguments. 
+ """ if not hasattr(args, "sys_perf_profiling"): args.sys_perf_profiling = True if not hasattr(args, "sys_perf_profiling"): @@ -236,6 +288,12 @@ def manage_profiling_args(args): def manage_cuda_rpc_args(args): + """ + Manage CUDA RPC-related arguments and configurations. + + Args: + args (argparse.Namespace): Command-line arguments. + """ if (not hasattr(args, "enable_cuda_rpc")) or (not args.using_gpu): args.enable_cuda_rpc = False @@ -264,6 +322,12 @@ def manage_cuda_rpc_args(args): def manage_mpi_args(args): + """ + Manage MPI-related arguments and configurations. + + Args: + args (argparse.Namespace): Command-line arguments. + """ if hasattr(args, "backend") and args.backend == "MPI": from mpi4py import MPI @@ -282,6 +346,15 @@ def manage_mpi_args(args): args.comm = None def init_cross_silo_horizontal(args): + """ + Initialize the cross-silo training for the horizontal scenario. + + Args: + args (argparse.Namespace): Command-line arguments. + + Returns: + args (argparse.Namespace): Updated command-line arguments. + """ args.n_proc_in_silo = 1 args.proc_rank_in_silo = 0 manage_mpi_args(args) @@ -291,6 +364,15 @@ def init_cross_silo_horizontal(args): def init_cross_silo_hierarchical(args): + """ + Initialize the cross-silo training for the hierarchical scenario. + + Args: + args (argparse.Namespace): Command-line arguments. + + Returns: + args (argparse.Namespace): Updated command-line arguments. + """ manage_mpi_args(args) manage_cuda_rpc_args(args) @@ -344,6 +426,15 @@ def init_cross_silo_hierarchical(args): def init_cheetah(args): + """ + Initialize the CheetaH training scenario. + + Args: + args (argparse.Namespace): Command-line arguments. + + Returns: + args (argparse.Namespace): Updated command-line arguments. + """ args.n_proc_in_silo = 1 args.proc_rank_in_silo = 0 manage_mpi_args(args) @@ -353,6 +444,15 @@ def init_cheetah(args): def init_model_serving(args): + """ + Initialize the model serving scenario. + + Args: + args (argparse.Namespace): Command-line arguments. + + Returns: + args (argparse.Namespace): Updated command-line arguments. + """ args.n_proc_in_silo = 1 args.proc_rank_in_silo = 0 manage_cuda_rpc_args(args) @@ -361,10 +461,12 @@ def init_model_serving(args): def update_client_id_list(args): - """ - generate args.client_id_list for CLI mode where args.client_id_list is set to None - In MLOps mode, args.client_id_list will be set to real-time client id list selected by UI (not starting from 1) + Generate args.client_id_list for the CLI mode where args.client_id_list is set to None. + In MLOps mode, args.client_id_list will be set to a real-time client id list selected by the UI (not starting from 1). + + Args: + args (argparse.Namespace): Command-line arguments. """ if not hasattr(args, "using_mlops") or (hasattr(args, "using_mlops") and not args.using_mlops): if not hasattr(args, "client_id_list") or args.client_id_list is None or args.client_id_list == "None" or args.client_id_list == "[]": @@ -396,12 +498,24 @@ def update_client_id_list(args): def init_cross_device(args): + """ + Initialize the cross-device training scenario. + + Args: + args (argparse.Namespace): Command-line arguments. + + Returns: + args (argparse.Namespace): Updated command-line arguments. + """ args.rank = 0 # only server runs on Python package args.role = "server" return args def run_distributed(): + """ + Placeholder function for running distributed training. 
+ """ pass diff --git a/python/fedml/launch_cheeath.py b/python/fedml/launch_cheeath.py index d0c40f8a14..e323bf2d26 100644 --- a/python/fedml/launch_cheeath.py +++ b/python/fedml/launch_cheeath.py @@ -5,7 +5,11 @@ def run_cheetah_server(): - """FedML Cheetah""" + """ + Run the server for the FedML Cheetah platform. + + This function initializes the server, loads data, and starts training using the Cheetah server. + """ fedml._global_training_type = FEDML_TRAINING_PLATFORM_CHEETAH args = fedml.init() @@ -26,7 +30,11 @@ def run_cheetah_server(): def run_cheetah_client(): - """FedML Cheetah""" + """ + Run a client for the FedML Cheetah platform. + + This function initializes a client, loads data, and starts training using the Cheetah client. + """ fedml._global_training_type = FEDML_TRAINING_PLATFORM_CHEETAH args = fedml.init() diff --git a/python/fedml/launch_cross_device.py b/python/fedml/launch_cross_device.py index 23934bcabb..1b613f88a6 100644 --- a/python/fedml/launch_cross_device.py +++ b/python/fedml/launch_cross_device.py @@ -5,8 +5,11 @@ def run_mnn_server(): from .cross_device import ServerMNN + """ + Run the server for the FedML BeeHive platform. - """FedML BeeHive""" + This function initializes the server, loads data, and starts training using the MNN (Multi-device Neural Network) server. + """ fedml._global_training_type = FEDML_TRAINING_PLATFORM_CROSS_DEVICE args = fedml.init() diff --git a/python/fedml/launch_cross_silo_hi.py b/python/fedml/launch_cross_silo_hi.py index 140cb1718e..c3ca6499bf 100644 --- a/python/fedml/launch_cross_silo_hi.py +++ b/python/fedml/launch_cross_silo_hi.py @@ -5,7 +5,11 @@ def run_hierarchical_cross_silo_server(): - """FedML Octopus""" + """ + Run the server for the FedML Octopus platform. + + This function initializes the server, loads data, and starts training using the Cross-Silo Octopus server. + """ fedml._global_training_type = FEDML_TRAINING_PLATFORM_CROSS_SILO args = fedml.init() @@ -26,7 +30,11 @@ def run_hierarchical_cross_silo_server(): def run_hierarchical_cross_silo_client(): - """FedML Octopus""" + """ + Run a client for the FedML Octopus platform. + + This function initializes a client, loads data, and starts training using the Cross-Silo Octopus client. + """ fedml._global_training_type = FEDML_TRAINING_PLATFORM_CROSS_SILO args = fedml.init() diff --git a/python/fedml/launch_cross_silo_horizontal.py b/python/fedml/launch_cross_silo_horizontal.py index aebe72c06c..85484e18ef 100644 --- a/python/fedml/launch_cross_silo_horizontal.py +++ b/python/fedml/launch_cross_silo_horizontal.py @@ -5,7 +5,11 @@ def run_cross_silo_server(): - """FedML Octopus""" + """ + Run the server for the FedML Octopus platform using Cross-Silo training. + + This function initializes the server, loads data, and starts training for the Cross-Silo Octopus platform. + """ fedml._global_training_type = FEDML_TRAINING_PLATFORM_CROSS_SILO args = fedml.init() @@ -26,7 +30,11 @@ def run_cross_silo_server(): def run_cross_silo_client(): - """FedML Octopus""" + """ + Run a client for the FedML Octopus platform using Cross-Silo training. + + This function initializes a client, loads data, and starts training for the Cross-Silo Octopus platform. 
+ """ global _global_training_type _global_training_type = FEDML_TRAINING_PLATFORM_CROSS_SILO diff --git a/python/fedml/launch_serving.py b/python/fedml/launch_serving.py index 2d9c8bf5c4..719ce7f8f9 100644 --- a/python/fedml/launch_serving.py +++ b/python/fedml/launch_serving.py @@ -5,7 +5,11 @@ def run_model_serving_server(): - """FedML Model Serving""" + """ + Run the server for the FedML Model Serving platform. + + This function initializes the server, loads data, and starts serving the model for the Model Serving platform. + """ fedml._global_training_type = FEDML_TRAINING_PLATFORM_SERVING args = fedml.init() @@ -26,7 +30,11 @@ def run_model_serving_server(): def run_model_serving_client(): - """FedML Model Serving""" + """ + Run a client for the FedML Model Serving platform. + + This function initializes a client, loads data, and starts serving the model for the Model Serving platform. + """ fedml._global_training_type = FEDML_TRAINING_PLATFORM_SERVING args = fedml.init() diff --git a/python/fedml/utils/compression.py b/python/fedml/utils/compression.py index 8038abfc36..064fabd29e 100644 --- a/python/fedml/utils/compression.py +++ b/python/fedml/utils/compression.py @@ -7,13 +7,38 @@ class NoneCompressor(): + """ + A compressor that does not perform any compression. + + This compressor simply returns the input tensor as-is when compressing and decompressing. + """ def __init__(self): self.name = 'none' def compress(self, tensor): + """ + Compresses the input tensor. + + Args: + tensor: The input tensor to be compressed. + + Returns: + compressed_tensor: The same input tensor. + dtype: The data type of the tensor. + """ return tensor, tensor.dtype def decompress(self, tensor, ctc): + """ + Decompresses the input tensor. + + Args: + tensor: The compressed tensor. + ctc: The data type of the tensor (ignored). + + Returns: + z: The decompressed tensor, which is the same as the input tensor. + """ z = tensor return z @@ -23,6 +48,9 @@ class TopKCompressor(): Sparse Communication for Distributed Gradient Descent, Alham Fikri Aji et al., 2017 """ def __init__(self): + """ + Initialize the TopKCompressor. + """ self.residuals = {} self.sparsities = [] self.zero_conditions = {} @@ -38,9 +66,23 @@ def __init__(self): def _process_data_before_selecting(self, name, data): + """ + Perform data processing before selecting the top-k values. + + Args: + name (str): The name of the data. + data (Tensor): The input data tensor. + """ pass def _process_data_after_residual(self, name, data): + """ + Perform data processing after applying residuals. + + Args: + name (str): The name of the data. + data (Tensor): The input data tensor. + """ if name not in self.zero_conditions: self.zero_conditions[name] = torch.ones(data.numel(), dtype=torch.float32, device=data.device) zero_condition = self.zero_conditions[name] @@ -49,6 +91,9 @@ def _process_data_after_residual(self, name, data): self.zc = zero_condition def clear(self): + """ + Clear the compressor's internal state. + """ self.residuals = {} self.sparsities = [] self.zero_conditions = {} @@ -57,6 +102,20 @@ def clear(self): def compress(self, tensor, name=None, sigma_scale=2.5, ratio=0.05): + """ + Compress the input tensor using top-k selection. + + Args: + tensor (Tensor): The input tensor to be compressed. + name (str): The name of the tensor (optional). + sigma_scale (float): Scaling factor for selecting top-k values (default: 2.5). + ratio (float): Ratio of values to be retained (default: 0.05). 
+ + Returns: + tensor (Tensor): The compressed tensor. + indexes (Tensor): The indexes of the top-k values. + values (Tensor): The top-k values. + """ start = time.time() with torch.no_grad(): # top-k solution @@ -73,14 +132,32 @@ def compress(self, tensor, name=None, sigma_scale=2.5, ratio=0.05): return tensor, indexes, values def decompress(self, tensor, original_tensor_size): + """ + Decompress the input tensor. + + Args: + tensor (Tensor): The compressed tensor. + original_tensor_size: The size of the original tensor (ignored). + + Returns: + tensor (Tensor): The decompressed tensor, which is the same as the input tensor. + """ return tensor def decompress_new(self, tensor, indexes, name=None, shape=None): - ''' - Just decompress, without unflatten. - Remember to do unflatter after decompress - ''' + """ + Decompress the input tensor without unflattening. Remember to do unflatter after decompress + + Args: + tensor (Tensor): The compressed tensor. + indexes (Tensor): The indexes of the top-k values. + name (str): The name of the tensor (optional). + shape (tuple): The shape of the tensor (optional). + + Returns: + decompress_tensor (Tensor): The decompressed tensor, which may need to be unflattened. + """ if shape is None: decompress_tensor = torch.zeros( self.shapes[name], dtype=tensor.dtype, device=tensor.device).view(-1) @@ -97,30 +174,69 @@ def decompress_new(self, tensor, indexes, name=None, shape=None): return decompress_tensor def flatten(self, tensor, name=None): - ''' - flatten a tensor - ''' + """ + Flatten the input tensor. + + Args: + tensor (Tensor): The input tensor to be flattened. + name (str): The name of the tensor (optional). + + Returns: + flattened_tensor (Tensor): The flattened tensor. + """ self.shapes[name] = tensor.shape return tensor.view(-1) def unflatten(self, tensor, name=None, shape=None): - ''' - unflatten a tensor - ''' + """ + Unflatten the input tensor. + + Args: + tensor (Tensor): The input tensor to be unflattened. + name (str): The name of the tensor (optional). + shape (tuple): The desired shape for unflattening (optional). + + Returns: + unflattened_tensor (Tensor): The unflattened tensor. + """ if shape is None: return tensor.view(self.shapes[name]) else: return tensor.view(shape) def update_shapes_dict(self, tensor, name): + """ + Update the shapes dictionary with the shape of the tensor. + + Args: + tensor (Tensor): The input tensor. + name (str): The name of the tensor. + """ self.shapes[name] = tensor.shape def get_residuals(self, name, like_tensor): + """ + Get the residuals for a given tensor name. + + Args: + name (str): The name of the tensor. + like_tensor (Tensor): A tensor with the same shape and device as the residuals. + + Returns: + residuals (Tensor): The residuals tensor. + """ if name not in self.residuals: self.residuals[name] = torch.zeros_like(like_tensor.data) return self.residuals[name] def add_residuals(self, included_indexes, name): + """ + Add residuals to the tensor for specified indexes. + + Args: + included_indexes (Tensor or ndarray): The indexes to include in the residuals. + name (str): The name of the tensor. + """ with torch.no_grad(): residuals = self.residuals[name] if type(included_indexes) is np.ndarray: @@ -138,12 +254,39 @@ def add_residuals(self, included_indexes, name): class EFTopKCompressor(TopKCompressor): """ + EFTopKCompressor extends the TopKCompressor class to provide error-feedback top-k compression. + + Args: + None + + Attributes: + name (str): The name of the compressor. 
+ + Methods: + __init__(): Initializes the EFTopKCompressor instance. + compress(tensor, name=None, sigma_scale=2.5, ratio=0.05): Compresses the input tensor using error-feedback top-k compression. + _process_data_before_selecting(name, data): Helper method to process data before selecting top-k values. """ def __init__(self): + """ + Initializes a new instance of EFTopKCompressor. + """ super().__init__() self.name = 'eftopk' def compress(self, tensor, name=None, sigma_scale=2.5, ratio=0.05): + """ + Compresses the input tensor using error-feedback top-k compression. + + Args: + tensor (torch.Tensor): The input tensor to be compressed. + name (str): The name associated with the compression operation (optional). + sigma_scale (float): The scale factor for sigma used in compression (default: 2.5). + ratio (float): The compression ratio (default: 0.05). + + Returns: + tuple: A tuple containing the compressed tensor, indexes of top-k values, and the top-k values themselves. + """ start = time.time() with torch.no_grad(): if name not in self.residuals: @@ -168,11 +311,38 @@ def compress(self, tensor, name=None, sigma_scale=2.5, ratio=0.05): return tensor, indexes, values def _process_data_before_selecting(self, name, data): + """ + Helper method to process data before selecting top-k values. + + Args: + name (str): The name associated with the compression operation. + data (torch.Tensor): The data tensor to be processed. + """ data.add_(self.residuals[name].data) class QuantizationCompressor(object): + """ + Quantization Compressor. + + This class represents a compressor that performs quantization on tensors. + + Attributes: + name (str): The name of the compressor. + residuals (dict): A dictionary to store residuals. + values (dict): A dictionary to store quantized values. + zc: Not specified in the code. + current_ratio (float): The current quantization ratio. + shapes (dict): A dictionary to store tensor shapes. + + Methods: + get_naive_quantize(x, s, is_biased=False): Calculate quantized values for the input tensor. + compress(tensor, name=None, quantize_level=32, is_biased=True): Compress a tensor. + decompress_new(tensor): Decompress a tensor. + update_shapes_dict(tensor, name): Update the shapes dictionary. + + """ def __init__(self): self.name = 'quant' self.residuals = {} @@ -183,6 +353,17 @@ def __init__(self): self.shapes = {} def get_naive_quantize(self, x, s, is_biased=False): + """ + Calculate quantized values for the input tensor. + + Args: + x: Input tensor. + s: Quantization level. + is_biased (bool): Whether to use biased quantization. + + Returns: + Tensor: Quantized tensor. + """ norm = x.norm(p=2) # calculate the quantization value of tensor `x` at level `log_2 s`. level_float = s * x.abs() / norm @@ -191,6 +372,18 @@ def get_naive_quantize(self, x, s, is_biased=False): return torch.sign(x) * norm * previous_level / s def compress(self, tensor, name=None, quantize_level=32, is_biased=True): + """ + Compress a tensor. + + Args: + tensor: Input tensor. + name: Name for the tensor. + quantize_level: Quantization level. + is_biased (bool): Whether to use biased quantization. + + Returns: + Tensor: Compressed tensor. + """ if quantize_level != 32: s = 2 ** quantize_level - 1 values = self.get_naive_quantize(tensor, s, is_biased) @@ -199,15 +392,53 @@ def compress(self, tensor, name=None, quantize_level=32, is_biased=True): return values def decompress_new(self, tensor): + """ + Decompress a tensor. + + Args: + tensor: Compressed tensor. 
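+
+ Example (identity pass-through, matching the implementation below):
+ ```python
+ q = QuantizationCompressor()
+ out = q.decompress_new(values)  # the quantized values are used as-is
+ ```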
+ + Returns: + Tensor: Decompressed tensor. + """ return tensor def update_shapes_dict(self, tensor, name): + """ + Update the shapes dictionary with the shape of the given tensor. + + Args: + tensor: Input tensor. + name (str): Name for the tensor. + """ self.shapes[name] = tensor.shape +class QSGDCompressor(object): + """ + QSGD (Quantized Stochastic Gradient Descent) Compressor. + QSGD is a compression technique for gradient updates in distributed training. -class QSGDCompressor(object): + Args: + None + + Attributes: + name (str): The name of the compressor. + residuals (dict): Dictionary to store residuals. + values (dict): Dictionary to store quantized values. + zc: Not specified in the code. + current_ratio (float): Current quantization ratio. + shapes (dict): Dictionary to store tensor shapes. + + Methods: + get_qsgd(x, s, is_biased=False): Calculate quantized values for the input tensor. + qsgd_quantize_numpy(x, s, is_biased=False): Quantize a numpy array. + compress(tensor, name=None, quantize_level=32, is_biased=True): Compress a tensor. + decompress_new(tensor): Decompress a tensor. + update_shapes_dict(tensor, name): Update the shapes dictionary. + + """ def __init__(self): self.name = 'qsgd' self.residuals = {} @@ -218,6 +449,17 @@ def __init__(self): self.shapes = {} def get_qsgd(self, x, s, is_biased=False): + """ + Calculate quantized values for the input tensor. + + Args: + x: Input tensor. + s: Quantization level. + is_biased (bool): Whether to use biased quantization. + + Returns: + Tensor: Quantized tensor. + """ norm = x.norm(p=2) # calculate the quantization value of tensor `x` at level `log_2 s`. level_float = s * x.abs() / norm @@ -235,7 +477,17 @@ def get_qsgd(self, x, s, is_biased=False): return scale * torch.sign(x) * norm * new_level / s def qsgd_quantize_numpy(self, x, s, is_biased=False): - """quantize the tensor x in d level on the absolute value coef wise""" + """ + Quantize a numpy array the tensor x in d level on the absolute value coef wise. + + Args: + x: Input numpy array. + s: Quantization level. + is_biased (bool): Whether to use biased quantization. + + Returns: + ndarray: Quantized numpy array. + """ norm = np.sqrt(np.sum(np.square(x))) # calculate the quantization value of tensor `x` at level `log_2 s`. level_float = s * np.abs(x) / norm @@ -253,6 +505,18 @@ def qsgd_quantize_numpy(self, x, s, is_biased=False): return scale * np.sign(x) * norm * new_level / s def compress(self, tensor, name=None, quantize_level=32, is_biased=True): + """ + Compress a tensor. + + Args: + tensor: Input tensor. + name: Name for the tensor. + quantize_level: Quantization level. + is_biased (bool): Whether to use biased quantization. + + Returns: + Tensor: Compressed tensor. + """ if quantize_level != 32: s = 2 ** quantize_level - 1 values = self.get_qsgd(tensor, s, is_biased) @@ -261,13 +525,26 @@ def compress(self, tensor, name=None, quantize_level=32, is_biased=True): return values def decompress_new(self, tensor): - return tensor + """ + Decompress a tensor. - def update_shapes_dict(self, tensor, name): - self.shapes[name] = tensor.shape + Args: + tensor: Compressed tensor. + Returns: + Tensor: Decompressed tensor. + """ + return tensor + def update_shapes_dict(self, tensor, name): + """ + Update the shapes dictionary. + Args: + tensor: Input tensor. + name: Name for the tensor. 
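+
+ Example (shape bookkeeping for later un-flattening; a minimal sketch):
+ ```python
+ c = QSGDCompressor()
+ c.update_shapes_dict(torch.randn(4, 3), name="layer1.weight")
+ assert c.shapes["layer1.weight"] == torch.Size([4, 3])
+ ```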
+ """ + self.shapes[name] = tensor.shape compressors = { @@ -282,11 +559,30 @@ def update_shapes_dict(self, tensor, name): def gen_threshold_from_normal_distribution(p_value, mu, sigma): r"""PPF.""" + """ + Generate threshold from a normal distribution. + + Args: + p_value (float): The p-value. + mu (float): The mean of the distribution. + sigma (float): The standard deviation of the distribution. + + Returns: + left_thres (float): The left threshold value. + right_thres (float): The right threshold value. + """ zvalue = stats.norm.ppf((1-p_value)/2) return mu+zvalue*sigma, mu-zvalue*sigma def test_gaussion_thres(): + """ + Test threshold calculation for a Gaussian distribution. + + This function generates random data from a Gaussian distribution and computes various statistics + including p-value, mean, and standard deviation. It then calculates a threshold and compares it + with the threshold generated from the Gaussian distribution. + """ set_mean = 0.0; set_std = 0.5 d = np.random.normal(set_mean, set_std, 10000) k2, p = stats.normaltest(d) diff --git a/python/fedml/utils/context.py b/python/fedml/utils/context.py index 1303a3fbc1..40a6c25f79 100644 --- a/python/fedml/utils/context.py +++ b/python/fedml/utils/context.py @@ -7,6 +7,21 @@ @contextmanager def raise_MPI_error(): + """ + Context manager to catch and handle MPI-related errors. + + This context manager is used to catch exceptions and errors that may occur + during MPI (Message Passing Interface) operations and handle them gracefully. + + Usage: + ```python + with raise_MPI_error(): + # Code that may raise MPI-related errors + ``` + + Returns: + None + """ import logging logging.debug("Debugging, Enter the MPI catch error") @@ -20,6 +35,21 @@ def raise_MPI_error(): @contextmanager def raise_error_without_process(): + """ + Context manager to catch and handle errors without aborting the MPI process. + + This context manager is used to catch exceptions and errors without aborting + the MPI (Message Passing Interface) process, allowing it to continue running. + + Usage: + ```python + with raise_error_without_process(): + # Code that may raise errors + ``` + + Returns: + None + """ import logging logging.debug("Debugging, Enter the MPI catch error") @@ -32,6 +62,26 @@ def raise_error_without_process(): @contextmanager def get_lock(lock: threading.Lock()): + """ + Context manager to acquire and release a threading lock. + + This context manager is used to acquire and release a threading lock in a controlled + manner. It ensures that the lock is always released, even in the presence of exceptions. + + Args: + lock (threading.Lock): The threading lock to acquire and release. + + Usage: + ```python + my_lock = threading.Lock() + with get_lock(my_lock): + # Code that requires the lock + # The lock is automatically released after the code block + ``` + + Returns: + None + """ lock.acquire() yield if lock.locked(): diff --git a/python/fedml/utils/logging.py b/python/fedml/utils/logging.py index 8aa089d5f9..5fa886b027 100644 --- a/python/fedml/utils/logging.py +++ b/python/fedml/utils/logging.py @@ -1,5 +1,6 @@ import logging +#define log levels log_levels = { "debug": logging.DEBUG, "info": logging.INFO, @@ -12,16 +13,16 @@ class LoggerCreator: @staticmethod def create_logger(name=None, level=logging.INFO, args=None): - """create a logger + """ + Create and configure a logger. Args: - name (str): name of the logger - level: level of logger + name (str): The name of the logger. + level: The logging level for the logger. 
- Raises: - ValueError is name is None + Returns: + logger: An instance of the logger. """ - if name is None: raise ValueError("name for logger cannot be None") diff --git a/python/fedml/utils/model_utils.py b/python/fedml/utils/model_utils.py index 0c4b58421c..6961473e49 100644 --- a/python/fedml/utils/model_utils.py +++ b/python/fedml/utils/model_utils.py @@ -9,7 +9,13 @@ def get_weights(state): """ - Returns list of weights from state_dict + Returns a list of weights from a state_dict. + + Args: + state (dict or None): A PyTorch state_dict or None. + + Returns: + list or None: A list of tensor weights or None if the state is None. """ if state is not None: return list(state.values()) @@ -18,6 +24,12 @@ def get_weights(state): def clear_optim_buffer(optimizer): + """ + Clears the optimizer's momentum buffers for each parameter. + + Args: + optimizer: A PyTorch optimizer. + """ for group in optimizer.param_groups: for p in group["params"]: param_state = optimizer.state[p] @@ -30,6 +42,13 @@ def clear_optim_buffer(optimizer): def optimizer_to(optim, device): + """ + Moves the optimizer's state and associated tensors to the specified device. + + Args: + optim (torch.optim.Optimizer): A PyTorch optimizer. + device (torch.device): The target device (e.g., 'cuda' or 'cpu'). + """ for param in optim.state.values(): # Not sure there are any global tensors in the state dict if isinstance(param, torch.Tensor): @@ -45,6 +64,16 @@ def optimizer_to(optim, device): def move_to_cpu(model, optimizer): + """ + Moves a PyTorch model and its associated optimizer to the CPU device. + + Args: + model (torch.nn.Module): The PyTorch model. + optimizer (torch.optim.Optimizer): The optimizer associated with the model. + + Returns: + torch.nn.Module: The model after moving it to the CPU. + """ if str(next(model.parameters()).device) == "cpu": pass else: @@ -56,6 +85,17 @@ def move_to_cpu(model, optimizer): def move_to_gpu(model, optimizer, device): + """ + Moves a PyTorch model and its associated optimizer to the specified GPU device. + + Args: + model (torch.nn.Module): The PyTorch model. + optimizer (torch.optim.Optimizer): The optimizer associated with the model. + device (str or torch.device): The target GPU device, e.g., 'cuda:0'. + + Returns: + torch.nn.Module: The model after moving it to the GPU. + """ if str(next(model.parameters()).device) == "cpu": model = model.to(device) else: @@ -72,9 +112,15 @@ def move_to_gpu(model, optimizer, device): def get_named_data(model, mode="MODEL", use_cuda=True): """ - getting the whole model and getting the gradients can be conducted - by using different methods for reducing the communication. - `model` choices: ['MODEL', 'GRAD', 'MODEL+GRAD'] + Get various components of a PyTorch model based on the specified mode. + + Args: + model (torch.nn.Module): The PyTorch model. + mode (str): Mode for extracting components ('MODEL', 'GRAD', or 'MODEL+GRAD'). + use_cuda (bool): Whether to use CUDA (GPU) for extraction. + + Returns: + dict: A dictionary containing the requested components. """ if mode == "MODEL": own_state = model.cpu().state_dict() @@ -113,6 +159,17 @@ def get_named_data(model, mode="MODEL", use_cuda=True): def get_bn_params(prefix, module, use_cuda=True): + """ + Get batch normalization parameters with the specified prefix. + + Args: + prefix (str): Prefix for parameter names. + module (nn.BatchNorm2d): Batch normalization module. + use_cuda (bool): Whether to use CUDA (GPU) for extraction. 
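+
+ Example (illustrative; key names follow the f-strings in the body below):
+ ```python
+ bn = torch.nn.BatchNorm2d(16)
+ params = get_bn_params("features.1", bn)
+ # keys follow f"{prefix}.weight", f"{prefix}.bias", f"{prefix}.running_mean", ...
+ ```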
+ + Returns: + dict: A dictionary containing batch normalization parameters. + """ bn_params = {} if use_cuda: bn_params[f"{prefix}.weight"] = module.weight @@ -130,6 +187,16 @@ def get_bn_params(prefix, module, use_cuda=True): def get_all_bn_params(model, use_cuda=True): + """ + Get all batch normalization parameters from a PyTorch model. + + Args: + model (torch.nn.Module): The PyTorch model. + use_cuda (bool): Whether to use CUDA (GPU) for extraction. + + Returns: + dict: A dictionary containing all batch normalization parameters. + """ all_bn_params = {} for module_name, module in model.named_modules(): # print(f"key:{key}, module, {module}") @@ -146,6 +213,12 @@ def get_all_bn_params(model, use_cuda=True): def check_bn_status(bn_module): + """ + Print and log batch normalization parameters and status. + + Args: + bn_module (nn.BatchNorm2d): Batch normalization module. + """ logging.info(f"weight: {bn_module.weight[:10].mean()}") logging.info(f"bias: {bn_module.bias[:10].mean()}") logging.info(f"running_mean: {bn_module.running_mean[:10].mean()}") @@ -159,10 +232,15 @@ def check_bn_status(bn_module): def average_named_params(named_params_list, average_weights_dict_list, inplace=True): """ - This is a weighted average operation. - average_weights_dict_list: includes weights with respect to clients. Same for each param. - inplace: Whether change the first client's model inplace. - Note: This function also can be used to average gradients. + Average named parameters based on a list of parameters and their associated weights. + + Args: + named_params_list (list): List of named parameters to be averaged. + average_weights_dict_list (list): List of weights for each set of named parameters. + inplace (bool): Whether to modify the first set of parameters in-place. + + Returns: + dict: Averaged named parameters. """ # logging.info("################aggregate: %d" % len(named_params_list)) @@ -219,6 +297,15 @@ def average_named_params(named_params_list, average_weights_dict_list, inplace=T def get_average_weight(sample_num_list): + """ + Calculate average weights based on a list of sample numbers. + + Args: + sample_num_list (list): List of sample numbers. + + Returns: + list: List of average weights. + """ # balance_sample_number_list = [] average_weights_dict_list = [] sum = 0 @@ -239,6 +326,16 @@ def get_average_weight(sample_num_list): def check_device(data_src, device=None): + """ + Ensure data is on the specified device. + + Args: + data_src: Data to be moved to the device. + device (str): Device to move the data to (e.g., 'cpu' or 'cuda'). + + Returns: + Data on the specified device. + """ if device is not None: if data_src.device is not device: return data_src.to(device) @@ -252,7 +349,16 @@ def check_device(data_src, device=None): def get_diff_weights(weights1, weights2): - """ Produce a direction from 'weights1' to 'weights2'.""" + """ + Calculate the difference between two sets of weights. + + Args: + weights1: First set of weights. + weights2: Second set of weights. + + Returns: + Difference between the two sets of weights. + """ if isinstance(weights1, list) and isinstance(weights2, list): return [w2 - w1 for (w1, w2) in zip(weights1, weights2)] elif isinstance(weights1, torch.Tensor) and isinstance(weights2, torch.Tensor): @@ -263,7 +369,14 @@ def get_diff_weights(weights1, weights2): def get_name_params_difference(named_parameters1, named_parameters2): """ - return named_parameters2 - named_parameters1 + Calculate the difference between two sets of named parameters. 
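+ The result is computed as named_parameters2 - named_parameters1 over their common keys.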
+ + Args: + named_parameters1 (dict): First set of named parameters. + named_parameters2 (dict): Second set of named parameters. + + Returns: + dict: Dictionary containing the differences between common named parameters. """ common_names = list(set(named_parameters1.keys()).intersection(set(named_parameters2.keys()))) named_diff_parameters = {} From f6065fe02eefee8cdaec4728308f62345ceea36c Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Thu, 7 Sep 2023 13:28:39 +0530 Subject: [PATCH 08/70] udpate `python\fedml\simulation\nccl\base_framework` server.py remaning --- python/fedml/__init__.py | 100 +++++----- .../nccl/base_framework/LocalAggregator.py | 131 ++++++++++++- .../nccl/base_framework/algorithm_api.py | 16 ++ .../simulation/nccl/base_framework/common.py | 183 +++++++++++++++++- .../simulation/nccl/base_framework/params.py | 167 ++++++++++++++-- 5 files changed, 524 insertions(+), 73 deletions(-) diff --git a/python/fedml/__init__.py b/python/fedml/__init__.py index 99382964bb..d74491d737 100644 --- a/python/fedml/__init__.py +++ b/python/fedml/__init__.py @@ -22,6 +22,56 @@ ) from .core.common.ml_engine_backend import MLEngineBackend +from fedml import device +from fedml import data +from fedml import model +from fedml import mlops + +from .arguments import load_arguments + +from .launch_simulation import run_simulation + +from .launch_cross_silo_horizontal import run_cross_silo_server +from .launch_cross_silo_horizontal import run_cross_silo_client + +from .launch_cross_silo_hi import run_hierarchical_cross_silo_server +from .launch_cross_silo_hi import run_hierarchical_cross_silo_client + +from .launch_cheeath import run_cheetah_server +from .launch_cheeath import run_cheetah_client + +from .launch_serving import run_model_serving_client +from .launch_serving import run_model_serving_server + +from .launch_cross_device import run_mnn_server + +from .core.common.ml_engine_backend import MLEngineBackend + +from .runner import FedMLRunner + +from fedml import api + +__all__ = [ + "MLEngineBackend", + "device", + "data", + "model", + "mlops", + "FedMLRunner", + "run_simulation", + "run_cross_silo_server", + "run_cross_silo_client", + "run_hierarchical_cross_silo_server", + "run_hierarchical_cross_silo_client", + "run_cheetah_server", + "run_cheetah_client", + "run_model_serving_client", + "run_model_serving_server", + "run_mnn_server", + "api" +] + + _global_training_type = None _global_comm_backend = None @@ -517,53 +567,3 @@ def run_distributed(): Placeholder function for running distributed training. 
""" pass - - -from fedml import device -from fedml import data -from fedml import model -from fedml import mlops - -from .arguments import load_arguments - -from .launch_simulation import run_simulation - -from .launch_cross_silo_horizontal import run_cross_silo_server -from .launch_cross_silo_horizontal import run_cross_silo_client - -from .launch_cross_silo_hi import run_hierarchical_cross_silo_server -from .launch_cross_silo_hi import run_hierarchical_cross_silo_client - -from .launch_cheeath import run_cheetah_server -from .launch_cheeath import run_cheetah_client - -from .launch_serving import run_model_serving_client -from .launch_serving import run_model_serving_server - -from .launch_cross_device import run_mnn_server - -from .core.common.ml_engine_backend import MLEngineBackend - -from .runner import FedMLRunner - -from fedml import api - -__all__ = [ - "MLEngineBackend", - "device", - "data", - "model", - "mlops", - "FedMLRunner", - "run_simulation", - "run_cross_silo_server", - "run_cross_silo_client", - "run_hierarchical_cross_silo_server", - "run_hierarchical_cross_silo_client", - "run_cheetah_server", - "run_cheetah_client", - "run_model_serving_client", - "run_model_serving_server", - "run_mnn_server", - "api" -] diff --git a/python/fedml/simulation/nccl/base_framework/LocalAggregator.py b/python/fedml/simulation/nccl/base_framework/LocalAggregator.py index e6bb9d7cd5..36057a0b6c 100644 --- a/python/fedml/simulation/nccl/base_framework/LocalAggregator.py +++ b/python/fedml/simulation/nccl/base_framework/LocalAggregator.py @@ -16,11 +16,53 @@ class BaseLocalAggregator(object): """ Used to manage and aggregate results from local trainers (clients). It needs to know all datasets. - device: indicates the device of this local aggregator. + device: indicates the device of this local aggregator + + Args: + args: The command-line arguments for the aggregator. + rank (int): The rank of this local aggregator. + worker_number (int): The total number of workers, including the server and clients. + comm: The communication state. + device: The device where the aggregator is located. + dataset: The dataset used for training and testing. + model: The model used for training. + trainer: The trainer responsible for training the model. + + Attributes: + device: Indicates the device of this local aggregator. + args: The command-line arguments for the aggregator. + trainer: The trainer responsible for training the model. + train_global: The global training dataset. + test_global: The global testing dataset. + val_global: The global validation dataset (if available). + train_data_num_in_total: The total number of training data points across all clients. + test_data_num_in_total: The total number of testing data points across all clients. + train_data_local_num_dict: A dictionary mapping client indices to the number of training data points for each client. + train_data_local_dict: A dictionary mapping client indices to their local training datasets. + test_data_local_dict: A dictionary mapping client indices to their local testing datasets. + comm: The communication state. + rank: The rank of this local aggregator. + device_rank: The rank of this local aggregator as a device (GPU). + worker_number: The total number of workers, including the server and clients. + device_number: The total number of devices (GPUs) used for training. + groups: A dictionary of communication groups, where each group is associated with a specific device. 
+ + Methods: + measure_client_runtime(): Measures the runtime of client operations. + simulate_client(server_params, client_index, average_weight): Simulates a client's training process. + add_client_result(localAggregatorToServerParams, client_params): Adds client results to be aggregated and sent to the server. """ # def __init__(self, args, trainer, device, dataset, comm=None, rank=0, size=0, backend="NCCL"): def __init__(self, args, rank, worker_number, comm, device, dataset, model, trainer): + """ + Measure the runtime of client operations. + + This method measures the runtime of client operations and can be used for performance analysis. + + Returns: + None + """ self.device = device self.args = args self.trainer = trainer @@ -64,9 +106,29 @@ def __init__(self, args, rank, worker_number, comm, device, dataset, model, trai logging.info("self.trainer = {}".format(self.trainer)) def measure_client_runtime(self): + """ + Measure the runtime of client operations. + + This method measures the runtime of client operations and can be used for performance analysis. + + Returns: + None + """ pass def simulate_client(self, server_params, client_index, average_weight): + """ + Simulate a client's training process. + + Args: + server_params: Parameters received from the server. + client_index (int): The index of the simulated client. + average_weight: The average weight used in the simulation. + + Returns: + client_params: Parameters to be sent back to the server. + """ + # server_model_parameters = server_params.get("model_params") # self.trainer.set_model_params(server_model_parameters) self.trainer.id = client_index @@ -83,6 +145,16 @@ def simulate_client(self, server_params, client_index, average_weight): return client_params def add_client_result(self, localAggregatorToServerParams, client_params): + """ + Add client results to be aggregated and sent to the server. + + Args: + localAggregatorToServerParams: Parameters to be sent to the server. + client_params: Parameters received from a client. + + Returns: + None + """ # Add params that needed to be reduces from clients mean_sum_param_names = client_params.get_sum_reduce_param_names() for name in mean_sum_param_names: @@ -96,6 +168,15 @@ def add_client_result(self, localAggregatorToServerParams, client_params): ) def simulate_all_tasks(self, server_params): + """ + Simulate all tasks for this local aggregator. + + Args: + server_params: Parameters received from the server. + + Returns: + localAggregatorToServerParams: Parameters to be sent back to the server. + """ average_weight_dict = self.decode_average_weight_dict(server_params) client_indexes = server_params.get(f"client_schedule{self.device_rank}").numpy() simulated_client_indexes = [] @@ -124,7 +205,16 @@ def simulate_all_tasks(self, server_params): def client_schedule(self, round_idx, client_num_in_total, client_num_per_round, server_params): """ - This is used for receiving server schedule client indexes. + Receive server's schedule of client indexes for this local aggregator. + + Args: + round_idx: The current round index. + client_num_in_total: The total number of clients. + client_num_per_round: The number of clients to be scheduled for this round. + server_params: Parameters received from the server. 
+ + Returns: + None """ # scheduler(workloads, constraints, memory) for i in range(self.device_number): @@ -133,12 +223,31 @@ def client_schedule(self, round_idx, client_num_in_total, client_num_per_round, return None, None def get_average_weight(self, client_indexes): + """ + Get average weight for a list of client indexes. + + Args: + client_indexes: A list of client indexes. + + Returns: + average_weight_dict: A dictionary mapping client indexes to their average weights. + """ average_weight_dict = {} for client_index in client_indexes: average_weight_dict[client_index] = 0.0 return average_weight_dict def encode_average_weight_dict(self, server_params, average_weight_dict): + """ + Encode and add the average weight dictionary to server parameters. + + Args: + server_params: Parameters to be sent to the server. + average_weight_dict: A dictionary mapping client indexes to their average weights. + + Returns: + None + """ server_params.add_broadcast_param( name="average_weight_dict_keys", param=torch.tensor(list(average_weight_dict.keys())) ) @@ -147,6 +256,15 @@ def encode_average_weight_dict(self, server_params, average_weight_dict): ) def decode_average_weight_dict(self, server_params): + """ + Decode the average weight dictionary from server parameters. + + Args: + server_params: Parameters received from the server. + + Returns: + average_weight_dict: A dictionary mapping client indexes to their average weights. + """ average_weight_dict_keys = server_params.get("average_weight_dict_keys").numpy() average_weight_dict_values = server_params.get("average_weight_dict_values").numpy() average_weight_dict = {} @@ -154,6 +272,15 @@ def decode_average_weight_dict(self, server_params): return average_weight_dict def train(self): + """ + Train the federated learning model. + + This method handles the federated learning training process, including communication with the server, + scheduling clients, and aggregating local client results. + + Returns: + None + """ server_params = ServerToClientParams() server_params.add_broadcast_param(name="broadcastTest", param=torch.tensor([0, 0, 0])) server_params.broadcast() diff --git a/python/fedml/simulation/nccl/base_framework/algorithm_api.py b/python/fedml/simulation/nccl/base_framework/algorithm_api.py index e656d220e1..127efb9459 100644 --- a/python/fedml/simulation/nccl/base_framework/algorithm_api.py +++ b/python/fedml/simulation/nccl/base_framework/algorithm_api.py @@ -3,6 +3,22 @@ def FedML_Base_NCCL(args, process_id, worker_number, comm, device, dataset, model, model_trainer=None): + """ + Create an instance of either the BaseServer or BaseLocalAggregator based on the process ID. + + Args: + args: The arguments for configuring the FedML engine. + process_id (int): The ID of the current process. + worker_number (int): The total number of workers in the simulation. + comm: The communication backend (e.g., MPI communicator). + device: The device on which the model should be placed. + dataset: The dataset used for training. + model: The model to be trained. + model_trainer: An optional trainer for the model. + + Returns: + BaseServer or BaseLocalAggregator: An instance of either the server or local aggregator based on the process ID. 
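+
+ Example (sketch; rank 0 becomes the server, all other ranks become local aggregators):
+ ```python
+ node = FedML_Base_NCCL(args, process_id, worker_number, comm,
+                        device, dataset, model)
+ node.train()
+ ```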
+ """ if process_id == 0: return BaseServer(args, process_id, worker_number, comm, device, dataset, model, model_trainer) diff --git a/python/fedml/simulation/nccl/base_framework/common.py b/python/fedml/simulation/nccl/base_framework/common.py index 7011ebdda4..1b64f8a63a 100644 --- a/python/fedml/simulation/nccl/base_framework/common.py +++ b/python/fedml/simulation/nccl/base_framework/common.py @@ -11,8 +11,14 @@ def get_weights(state): """ - Returns list of weights from state_dict - """ + Returns a list of weights from the state dictionary. + + Args: + state (dict): The state dictionary containing model parameters. + + Returns: + list or None: A list of model weights or None if the state is None. + """" if state is not None: return list(state.values()) else: @@ -20,12 +26,25 @@ def get_weights(state): def set_model_params_with_list(model, new_model_params): + """ + Set the model parameters with a list of new parameters. + + Args: + model: The model whose parameters will be updated. + new_model_params (list): A list of new model parameters. + """ for model_param, model_update_param in zip(model.parameters(), new_model_params): print(f"model_param.shape: {model_param.shape}, model_update_param.shape: {model_update_param.shape}") # model_param.data = model_update_param def clear_optim_buffer(optimizer): + """ + Clear the optimization buffer for momentum. + + Args: + optimizer: The optimizer whose buffer will be cleared. + """ for group in optimizer.param_groups: for p in group["params"]: param_state = optimizer.state[p] @@ -38,6 +57,13 @@ def clear_optim_buffer(optimizer): def optimizer_to(optim, device): + """ + Move optimizer parameters to the specified device. + + Args: + optim: The optimizer whose parameters will be moved. + device (str): The target device (e.g., 'cpu' or 'cuda'). + """ for param in optim.state.values(): # Not sure there are any global tensors in the state dict if isinstance(param, torch.Tensor): @@ -53,6 +79,16 @@ def optimizer_to(optim, device): def move_to_cpu(model, optimizer): + """ + Move the model and optimizer to the CPU. + + Args: + model: The model to be moved. + optimizer: The optimizer to be moved. + + Returns: + model: The model on the CPU. + """ if str(next(model.parameters()).device) == "cpu": pass else: @@ -64,6 +100,17 @@ def move_to_cpu(model, optimizer): def move_to_gpu(model, optimizer, device): + """ + Move the model and optimizer to the specified GPU device. + + Args: + model: The model to be moved. + optimizer: The optimizer to be moved. + device (str): The target GPU device (e.g., 'cuda'). + + Returns: + model: The model on the specified GPU device. + """ if str(next(model.parameters()).device) == "cpu": model = model.to(device) else: @@ -104,6 +151,16 @@ class CommState: def init_ddp(args): + """ + Initialize Distributed Data Parallel (DDP) for training. + + Args: + args: The arguments containing DDP configuration. + + Returns: + global_rank (int): The global rank of the current process. + world_size (int): The total number of processes in the world. + """ # use InfiniBand os.environ["NCCL_DEBUG"] = "INFO" os.environ["NCCL_SOCKET_IFNAME"] = "lo" @@ -127,6 +184,15 @@ def init_ddp(args): def FedML_NCCL_Similulation_init(args): + """ + Initialize NCCL-based simulation environment. + + Args: + args: The arguments containing simulation configuration. + + Returns: + args (object): The updated arguments. 
+ """ # dist.init_process_group( # init_method='tcp://10.1.1.20:23456', # rank=args.rank, @@ -157,27 +223,74 @@ def FedML_NCCL_Similulation_init(args): def get_rank(): - return dist.get_rank() + """ + Get the rank of the current process in the distributed environment. + Returns: + int: The rank of the current process. + """ + return dist.get_rank() def get_server_rank(): - return CommState.server_rank + """ + Get the rank of the server process in the distributed environment. + Returns: + int: The rank of the server process. + """ + return CommState.server_rank def get_world_size(): - return dist.get_world_size() + """ + Get the total number of processes in the distributed environment. + Returns: + int: The total number of processes. + """ + return dist.get_world_size() def get_worker_number(): + """ + Get the number of worker processes (excluding the server) in the distributed environment. + + Returns: + int: The number of worker processes. + """ return CommState.device_size def new_group(ranks): + """ + Create a new process group with the specified ranks. + + Args: + ranks (list): A list of ranks to include in the new group. + + Returns: + dist.ProcessGroup: The new process group. + """ return dist.new_group(ranks=ranks) # dist.new_group(ranks=None, timeout=datetime.timedelta(seconds=1800), backend=None, pg_options=None) def fedml_nccl_send_to_server(tensor, src=0, group=None): + """ + Send a tensor from a device (GPU) to the server process. + + Args: + tensor (torch.Tensor): The tensor to send. + src (int): The source rank of the sending process. + group (dist.ProcessGroup, optional): The process group to use for communication. + + Note: + This function is used to send tensors from a device (GPU) to the server during communication. + + Example: + ```python + fedml_nccl_send_to_server(my_tensor, src=1, group=my_group) + ``` + + """ is_cuda = tensor.is_cuda # if not is_cuda: # logging.info("Warning: Tensor is not on GPU!!!") @@ -186,6 +299,22 @@ def fedml_nccl_send_to_server(tensor, src=0, group=None): def fedml_nccl_broadcast(tensor, src): + """ + Broadcast a tensor from the server process to all devices (GPUs). + + Args: + tensor (torch.Tensor): The tensor to broadcast. + src (int): The source rank of the broadcasting process. + + Note: + This function is used to broadcast tensors from the server to all devices during communication. + + Example: + ```python + fedml_nccl_broadcast(my_tensor, src=0) + ``` + + """ is_cuda = tensor.is_cuda # if not is_cuda: # logging.info("Warning: Tensor is not on GPU!!!") @@ -195,7 +324,21 @@ def fedml_nccl_broadcast(tensor, src): def fedml_nccl_reduce(tensor, dst, op: ReduceOp = ReduceOp.SUM): """ - :param op: Currently only supports SUM and MEAN reduction ops + Reduce a tensor across processes with the specified reduction operation. + + Args: + tensor (torch.Tensor): The tensor to reduce. + dst (int): The destination rank for the reduced tensor. + op (ReduceOp): The reduction operation (SUM or MEAN). Currently only supports SUM and MEAN reduction ops + + Note: + This function is used to perform reduction operations (SUM or MEAN) on tensors across processes. + + Example: + ```python + fedml_nccl_reduce(my_tensor, dst=0, op=ReduceOp.SUM) + ``` + """ is_cuda = tensor.is_cuda # if not is_cuda: @@ -216,10 +359,38 @@ def fedml_nccl_reduce(tensor, dst, op: ReduceOp = ReduceOp.SUM): def fedml_nccl_barrier(): + """ + Synchronize all processes in the distributed environment. 
+ + Note: + This function is used to ensure that all processes reach a barrier and synchronize their execution. + + Example: + ```python + fedml_nccl_barrier() + ``` + + """ dist.barrier() def broadcast_model_state(state_dict, src): + """ + Broadcast the model's state dictionary from the server process to all devices (GPUs). + + Args: + state_dict (dict): The model's state dictionary to broadcast. + src (int): The source rank of the broadcasting process. + + Note: + This function is used to broadcast the model's state dictionary from the server to all devices during communication. + + Example: + ```python + broadcast_model_state(my_state_dict, src=0) + ``` + + """ # for name, param in state_dict.items(): # logging.info(f"name:{name}, param.shape: {param.shape}") for param in state_dict.values(): diff --git a/python/fedml/simulation/nccl/base_framework/params.py b/python/fedml/simulation/nccl/base_framework/params.py index 1e1c39777f..9b790a8ebf 100644 --- a/python/fedml/simulation/nccl/base_framework/params.py +++ b/python/fedml/simulation/nccl/base_framework/params.py @@ -10,15 +10,20 @@ class Params(Params): """ - Unified Parameter Object for passing arguments among APIs - from the algorithm frame (e.g., client_trainer.py and server aggregator.py). + Unified Parameter Object for passing arguments among APIs. - Usage:: + This class is used for passing arguments among different parts of the algorithm framework. + You can add parameters and retrieve them using attribute access. + + Example: >> my_params = Params() - >> # add parameter + >> # Add a parameter >> my_params.add(name="w", param=model_weights) - >> # get parameter - >> my_params.w + >> # Get a parameter + >> weight = my_params.w + + Attributes: + _params (dict): A dictionary to store parameter names and values. """ def __init__(self, **kwargs): @@ -27,7 +32,20 @@ def __init__(self, **kwargs): class ServerToClientParams(Params): """ - Normally, ServerToClient only broadcast parameters, hence all devices will receive same data from server. + Parameters sent from server to clients for broadcasting. + + This class represents parameters that are broadcasted from the server to all clients. + It allows adding broadcast parameters and performing the broadcasting operation. + + Example: + >> server_params = ServerToClientParams() + >> # Add a broadcast parameter + >> server_params.add_broadcast_param(name="w", param=model_weights) + >> # Broadcast the added parameters to all clients + >> server_params.broadcast() + + Attributes: + _broadcast_params (list): A list of parameter names to be broadcasted. """ def __init__(self, **kwargs): @@ -36,14 +54,26 @@ def __init__(self, **kwargs): # self._broadcast_params = {} def add_broadcast_param(self, name, param): + """ + Add a parameter to be broadcasted to all clients. + + Args: + name (str): The name of the parameter. + param (torch.Tensor or list of torch.Tensor): The parameter to be broadcasted. + + Returns: + None + """ self.__dict__.update({name: param}) self._broadcast_params.append(name) # self._broadcast_params.update({name: param}) - def broadcast(self): + def broadcast(self): """ - Perform communication of the added parameters. - Note that this is a collective operation and all processes (server and devices) must call this function. + Perform broadcasting of the added parameters to all clients. + + Note: + This is a collective operation, and all processes (server and devices) must call this function. 
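+
+ Example (server-side pattern; mirrors the usage in BaseLocalAggregator.train):
+ ```python
+ params = ServerToClientParams()
+ params.add_broadcast_param(name="w", param=torch.zeros(10))
+ params.broadcast()  # collective: every rank must reach this call
+ ```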
""" for param_name in self._broadcast_params: @@ -56,13 +86,26 @@ def broadcast(self): class LocalAggregatorToServerParams(Params): + """ + Parameters sent from local aggregator to the server for aggregation. + + This class represents parameters that are sent from local aggregators to the server + for aggregation and communication between clients and the server. + + Attributes: + _reduce_params (dict): A dictionary containing lists of parameters to be reduced using different operations. + _gather_params (list): A list of parameter names to be gathered from clients. + client_indexes (list): List of client indexes for which this local aggregator has data. + """" # def __init__(self, client_indexes, rank, group, **kwargs): def __init__(self, client_indexes, **kwargs): """ - client_indexes and group are used to indicate client_indexes that are - simulated by currernt LocalAggregator, - This will be used for gathering data. + Initialize the LocalAggregatorToServerParams object. + + Args: + client_indexes (list): List of client indexes that are simulated by this LocalAggregator. """ + super().__init__(**kwargs) self._reduce_params = dict([(ReduceOp.SUM, []),]) self._gather_params = [] @@ -71,6 +114,17 @@ def __init__(self, client_indexes, **kwargs): # self.group = group def add_reduce_param(self, name, param, op=ReduceOp.SUM): + """ + Add a parameter to be reduced. + + Args: + name (str): The name of the parameter. + param (torch.Tensor): The parameter to be reduced. + op (ReduceOp, optional): The reduction operation (default is ReduceOp.SUM). + + Returns: + None + """ if name in self.__dict__: if isinstance(self.__dict__[name], list): for i, tensor in enumerate(param): @@ -83,8 +137,15 @@ def add_reduce_param(self, name, param, op=ReduceOp.SUM): def add_gather_params(self, client_index, name, param): """ - Server needs to add all gather param of all clients, - Then the collective communication can work. + Add parameters to be gathered from clients. + + Args: + client_index (int): The client index for which the parameter is added. + name (str): The name of the parameter. + param (torch.Tensor): The parameter to be gathered. + + Returns: + None """ # new_name = f"client{client_index}_name" # self.__dict__.update({new_name: param}) @@ -96,6 +157,17 @@ def add_gather_params(self, client_index, name, param): self.__dict__[name][client_index] = param def communicate(self, rank, groups, client_schedule=None): + """ + Perform communication between local aggregator and server. + + Args: + rank (int): The rank of the local aggregator. + groups (dict): Dictionary of communication groups. + client_schedule (list, optional): Schedule of client indexes (default is None). + + Returns: + None + """ for param_name in self._reduce_params[ReduceOp.SUM]: param = getattr(self, param_name) if isinstance(param, list): @@ -128,29 +200,94 @@ def communicate(self, rank, groups, client_schedule=None): class ClientToLocalAggregatorParams(Params): + """ + Parameters sent from a client to a local aggregator for aggregation. + + This class represents parameters that are sent from a client to a local aggregator + for aggregation and communication within a local group. + + Attributes: + client_index (int): The client index. + _reduce_params (dict): A dictionary containing lists of parameters to be reduced using different operations. + _gather_params (list): A list of parameter names to be gathered by the local aggregator. 
+    """
     def __init__(self, client_index, **kwargs):
+        """
+        Initialize the ClientToLocalAggregatorParams object.
+
+        Args:
+            client_index (int): The client index for which the parameters are intended.
+        """
         super().__init__(**kwargs)
         self.client_index = client_index
         self._reduce_params = dict([(ReduceOp.MEAN, []), (ReduceOp.SUM, []),])
         self._gather_params = []

     def add_reduce_param(self, name, param, op=ReduceOp.SUM):
+        """
+        Add a parameter to be reduced.
+
+        Args:
+            name (str): The name of the parameter.
+            param (torch.Tensor): The parameter to be reduced.
+            op (ReduceOp, optional): The reduction operation (default is ReduceOp.SUM).
+
+        Returns:
+            None
+        """
         self.__dict__.update({name: param})
         self._reduce_params[op].append(name)

     def add_gather_params(self, name, param):
+        """
+        Add parameters to be gathered by the local aggregator.
+
+        Args:
+            name (str): The name of the parameter.
+            param (torch.Tensor): The parameter to be gathered.
+
+        Returns:
+            None
+        """
         self.__dict__.update({name: param})
         self._gather_params.append(name)

     def get_mean_reduce_param_names(self):
+        """
+        Get the names of parameters to be reduced with the MEAN operation.
+
+        Returns:
+            list: A list of parameter names.
+        """
         return self._reduce_params[ReduceOp.MEAN]

     def get_sum_reduce_param_names(self):
+        """
+        Get the names of parameters to be reduced with the SUM operation.
+
+        Returns:
+            list: A list of parameter names.
+        """
         return self._reduce_params[ReduceOp.SUM]

     def get_gather_param_names(self):
+        """
+        Get the names of parameters to be gathered by the local aggregator.
+
+        Returns:
+            list: A list of parameter names.
+        """
         return self._gather_params


 def local_gather(local_gather_params):
+    """
+    Perform local gathering of parameters.
+
+    Args:
+        local_gather_params (ClientToLocalAggregatorParams): Parameters to be gathered.
+
+    Returns:
+        None
+    """
     pass

From 6d6b355dfae714c8f4f696ea4dc19485a67e91d2 Mon Sep 17 00:00:00 2001
From: rajveer43
Date: Thu, 7 Sep 2023 18:09:36 +0530
Subject: [PATCH 09/70] `python\fedml\simulation\nccl`

python\fedml\simulation\nccl done
---
 .../simulation/nccl/base_framework/Server.py  | 176 ++++++++++++++++++
 .../fedml/simulation/nccl/fedavg/FedAvgAPI.py |  20 ++
 2 files changed, 196 insertions(+)

diff --git a/python/fedml/simulation/nccl/base_framework/Server.py b/python/fedml/simulation/nccl/base_framework/Server.py
index aaae29ad63..82c8fee9c4 100644
--- a/python/fedml/simulation/nccl/base_framework/Server.py
+++ b/python/fedml/simulation/nccl/base_framework/Server.py
@@ -17,9 +17,55 @@ class BaseServer:
     Used to manage and aggregate results from local aggregators.
-    We hope users does not need to modify this code.
-    """
+    We hope users do not need to modify this code.
+
+    Attributes:
+        device (str): The device associated with this server.
+        args: Command-line arguments.
+        trainer: The trainer used for training.
+        train_global: Global training data.
+        test_global: Global test data.
+        val_global: Global validation data.
+        train_data_num_in_total (int): The total number of training data points.
+        test_data_num_in_total (int): The total number of test data points.
+        train_data_local_num_dict: A dictionary containing local training data counts.
+        train_data_local_dict: A dictionary containing local training data.
+        test_data_local_dict: A dictionary containing local test data.
+        comm: Communication object.
+        rank (int): The rank of this server.
+        worker_number (int): The total number of workers (devices).
+        device_number (int): The total number of devices excluding the server.
+ groups (dict): A dictionary of communication groups. + client_runtime_history (dict): A history of client runtimes. + + Methods: + client_sampling(round_idx, client_num_in_total, client_num_per_round): + Randomly sample clients for communication in a federated round. + + simulate_all_tasks(server_params, client_indexes): + Simulate tasks for all selected clients and create localAggregatorToServerParams. + + workload_estimate(client_indexes, mode="simulate"): + Estimate the workload of clients in a federated round. + + memory_estimate(client_indexes, mode="simulate"): + Estimate the memory usage of clients in a federated round. + """ # def __init__(self, args, trainer, device, dataset, comm=None, rank=0, size=0, backend="NCCL"): def __init__(self, args, rank, worker_number, comm, device, dataset, model, trainer): + """ + Initialize the BaseServer object. + + Args: + args: Command-line arguments. + rank (int): The rank of this server. + worker_number (int): The total number of workers (devices). + comm: Communication object. + device (str): The device associated with this server. + dataset: Dataset information. + model: The model used for federated learning. + trainer: The trainer used for training. + """ self.device = device self.args = args self.trainer = trainer @@ -56,6 +104,17 @@ def __init__(self, args, rank, worker_number, comm, device, dataset, model, trai self.client_runtime_history = {} def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Randomly sample clients for communication in a federated round. + + Args: + round_idx (int): The index of the federated round. + client_num_in_total (int): The total number of clients in the dataset. + client_num_per_round (int): The number of clients to be sampled in each round. + + Returns: + list: A list of client indexes sampled for communication. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -66,6 +125,16 @@ def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): return client_indexes def simulate_all_tasks(self, server_params, client_indexes): + """ + Simulate tasks for all selected clients and create localAggregatorToServerParams. + + Args: + server_params: Server parameters. + client_indexes (list): List of client indexes selected for communication. + + Returns: + LocalAggregatorToServerParams: Parameters to be communicated to the local aggregators. + """ localAggregatorToServerParams = LocalAggregatorToServerParams(None) # model_update = [torch.zeros_like(v) for v in get_weights(self.trainer.get_model_params())] # localAggregatorToServerParams.add_reduce_param(name="model_params", @@ -81,6 +150,16 @@ def simulate_all_tasks(self, server_params, client_indexes): return localAggregatorToServerParams def workload_estimate(self, client_indexes, mode="simulate"): + """ + Estimate the workload of clients in a federated round. + + Args: + client_indexes (list): List of client indexes. + mode (str, optional): The mode for workload estimation (default is "simulate"). + + Returns: + list: A list of estimated client workloads. 
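+
+        Example (sketch; server is a BaseServer instance):
+        ```python
+        # In "simulate" mode the workload is each client's local sample count.
+        workloads = server.workload_estimate([0, 3, 7])
+        ```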
+ """ if mode == "simulate": client_samples = [self.train_data_local_num_dict[client_index] for client_index in client_indexes] workload = client_samples @@ -91,6 +170,16 @@ def workload_estimate(self, client_indexes, mode="simulate"): return workload def memory_estimate(self, client_indexes, mode="simulate"): + """ + Estimate the memory usage of clients in a federated round. + + Args: + client_indexes (list): List of client indexes. + mode (str, optional): The mode for memory estimation (default is "simulate"). + + Returns: + np.ndarray: An array representing the estimated memory usage for each client. + """ if mode == "simulate": memory = np.ones(self.device_number) elif mode == "real": @@ -100,6 +189,15 @@ def memory_estimate(self, client_indexes, mode="simulate"): return memory def resource_estimate(self, mode="simulate"): + """ + Estimate the resource usage of clients in a federated round. + + Args: + mode (str, optional): The mode for resource estimation (default is "simulate"). + + Returns: + np.ndarray: An array representing the estimated resource usage for each client. + """ if mode == "simulate": resource = np.ones(self.device_number) elif mode == "real": @@ -109,6 +207,19 @@ def resource_estimate(self, mode="simulate"): return resource def client_schedule(self, round_idx, client_num_in_total, client_num_per_round, server_params, mode="simulate"): + """ + Schedule clients for communication in a federated round. + + Args: + round_idx (int): The index of the federated round. + client_num_in_total (int): The total number of clients in the dataset. + client_num_per_round (int): The number of clients to be scheduled in each round. + server_params: Server parameters. + mode (str, optional): The mode for scheduling (default is "simulate"). + + Returns: + tuple: A tuple containing the selected client indexes and their schedule for communication. + """ # scheduler(workloads, constraints, memory) client_indexes = self.client_sampling(round_idx, client_num_in_total, client_num_per_round) # workload = self.workload_estimate(client_indexes, mode) @@ -129,6 +240,15 @@ def client_schedule(self, round_idx, client_num_in_total, client_num_per_round, return client_indexes, client_schedule def get_average_weight(self, client_indexes): + """ + Calculate the average weight for each client based on their training data size. + + Args: + client_indexes (list): List of client indexes. + + Returns: + dict: A dictionary mapping client indexes to their average weights. + """ average_weight_dict = {} training_num = 0 for client_index in client_indexes: @@ -139,6 +259,13 @@ def get_average_weight(self, client_indexes): return average_weight_dict def encode_average_weight_dict(self, server_params, average_weight_dict): + """ + Encode the average weight dictionary into server parameters. + + Args: + server_params: Server parameters. + average_weight_dict (dict): A dictionary mapping client indexes to their average weights. + """ server_params.add_broadcast_param( name="average_weight_dict_keys", param=torch.tensor(list(average_weight_dict.keys())) ) @@ -147,12 +274,49 @@ def encode_average_weight_dict(self, server_params, average_weight_dict): ) def decode_average_weight_dict(self, server_params): + """ + Decode the average weight dictionary received from the server. + + This method is used to decode the average weight dictionary that was previously encoded and broadcasted + by the server. The average weight dictionary represents the weights assigned to each client based on + their training data size. 
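+
+        The body below is still a stub; a plausible sketch of the inverse of
+        encode_average_weight_dict (using the two tensors that method broadcasts) is:
+        ```python
+        keys = server_params.average_weight_dict_keys.tolist()
+        values = server_params.average_weight_dict_values.tolist()
+        average_weight_dict = dict(zip(keys, values))
+        ```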
+ + Args: + server_params (ServerToClientParams): The server parameters containing the average weight dictionary. + + Returns: + dict: The decoded average weight dictionary. + """ pass def record_client_runtime(self, client_runtimes): + """ + Record the runtime of each client during a training round. + + This method is used to record the runtime of each client during a training round. The client runtimes are + typically collected and communicated by the local aggregators. + + Args: + client_runtimes (list): A list of client runtimes for each client. + + Returns: + None + """ pass def train(self): + """ + Train the federated learning model using the server-client communication protocol. + + This method implements the federated learning training process by coordinating communication + between the server and clients for multiple rounds of training. + + Args: + None + + Returns: + None + """ server_params = ServerToClientParams() server_params.add_broadcast_param(name="broadcastTest", param=torch.tensor([1, 2, 3])) server_params.broadcast() @@ -198,6 +362,18 @@ def train(self): self.test_on_server_for_all_clients(round) def test_on_server_for_all_clients(self, round_idx): + """ + Perform testing on the server for all clients after a certain number of rounds. + + This method tests the federated learning model on both the training and test datasets + for all clients on the server side. + + Args: + round_idx (int): The current round index. + + Returns: + None + """ if self.trainer.test_on_the_server( self.train_data_local_dict, self.test_data_local_dict, self.device, self.args, ): diff --git a/python/fedml/simulation/nccl/fedavg/FedAvgAPI.py b/python/fedml/simulation/nccl/fedavg/FedAvgAPI.py index 094be26ee1..2e459ad72f 100644 --- a/python/fedml/simulation/nccl/fedavg/FedAvgAPI.py +++ b/python/fedml/simulation/nccl/fedavg/FedAvgAPI.py @@ -4,6 +4,26 @@ def FedML_FedAvg_NCCL(args, process_id, worker_number, comm, device, dataset, model, model_trainer=None): + """ + Create a FedAvgServer or FedAvgLocalAggregator object based on the process ID. + + This function is a factory function for creating either a FedAvgServer or a FedAvgLocalAggregator object + based on the value of the process ID. If the process ID is 0, it creates a FedAvgServer object; otherwise, + it creates a FedAvgLocalAggregator object. + + Args: + args (object): Arguments for the federated learning setup. + process_id (int): The process ID. + worker_number (int): The total number of worker processes. + comm (object): The communication backend. + device (object): The device on which the model is trained. + dataset (tuple): A tuple containing dataset-related information. + model (object): The machine learning model. + model_trainer (object, optional): The model trainer. If not provided, it will be created. + + Returns: + object: A FedAvgServer or FedAvgLocalAggregator object. 
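+
+    Example (sketch; the arguments come from the usual FedML launch sequence):
+    ```python
+    node = FedML_FedAvg_NCCL(args, process_id, worker_number, comm,
+                             device, dataset, model)
+    node.train()  # BaseServer exposes train(); the aggregator side is assumed to mirror it
+    ```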
+ """ if model_trainer is None: model_trainer = create_model_trainer(model, args) if process_id == 0: From 696f2d2b326f8bde01efbe1e01f0abd57d079fb7 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Thu, 7 Sep 2023 19:44:57 +0530 Subject: [PATCH 10/70] `python\fedml\simulation\mpi python\fedml\simulation\mpi fedseg spilt_nn --- .../fedml/simulation/mpi/fedseg/FedSegAPI.py | 59 ++++ .../simulation/mpi/fedseg/FedSegAggregator.py | 96 +++++++ .../mpi/fedseg/FedSegClientManager.py | 92 +++++++ .../mpi/fedseg/FedSegServerManager.py | 86 ++++++ .../simulation/mpi/fedseg/FedSegTrainer.py | 78 ++++++ .../simulation/mpi/fedseg/MyModelTrainer.py | 49 ++++ python/fedml/simulation/mpi/fedseg/utils.py | 260 +++++++++++++++++- .../simulation/mpi/split_nn/SplitNNAPI.py | 47 ++++ .../fedml/simulation/mpi/split_nn/client.py | 30 ++ .../simulation/mpi/split_nn/client_manager.py | 86 ++++++ .../fedml/simulation/mpi/split_nn/server.py | 46 +++- .../simulation/mpi/split_nn/server_manager.py | 46 ++++ 12 files changed, 963 insertions(+), 12 deletions(-) diff --git a/python/fedml/simulation/mpi/fedseg/FedSegAPI.py b/python/fedml/simulation/mpi/fedseg/FedSegAPI.py index 4c07213060..fead0eaa4b 100644 --- a/python/fedml/simulation/mpi/fedseg/FedSegAPI.py +++ b/python/fedml/simulation/mpi/fedseg/FedSegAPI.py @@ -10,6 +10,12 @@ def FedML_init(): + """ + Initialize the federated learning environment. + + Returns: + tuple: A tuple containing the MPI communicator (`comm`), process ID (`process_id`), and worker number (`worker_number`). + """ comm = MPI.COMM_WORLD process_id = comm.Get_rank() worker_number = comm.Get_size() @@ -29,6 +35,25 @@ def FedML_FedSeg_distributed( args, model_trainer=None, ): + """ + Initialize and run the federated Segmentation training process. + + Args: + process_id (int): The ID of the current process. + worker_number (int): The total number of workers (including the server). + device: The device on which the model is trained. + comm: The MPI communicator. + model: The neural network model. + train_data_num: The number of training data samples. + train_data_local_num_dict: A dictionary containing the number of local training data samples for each worker. + train_data_local_dict: A dictionary containing the local training data for each worker. + test_data_local_dict: A dictionary containing the local testing data for each worker. + args: Additional arguments for the federated learning setup. + model_trainer: The model trainer for training the model (optional). + + Notes: + - If `process_id` is 0, it initializes the server. Otherwise, it initializes a client. + """ if process_id == 0: init_server(args, device, comm, process_id, worker_number, model, model_trainer) @@ -49,6 +74,21 @@ def FedML_FedSeg_distributed( def init_server(args, device, comm, rank, size, model, model_trainer): + """ + Initialize the federated learning server. + + Args: + args: Additional arguments for the server initialization. + device: The device on which the model is trained. + comm: The MPI communicator. + rank (int): The rank of the current process. + size (int): The total number of processes. + model: The neural network model. + model_trainer: The model trainer for training the model (optional). + + Notes: + This function initializes the server for federated Segmentation training. + """ logging.info("Initializing Server") if model_trainer is None: @@ -78,6 +118,25 @@ def init_client( test_data_local_dict, model_trainer, ): + """ + Initialize and run a federated learning client. 
+ + Args: + args: Additional arguments for the client initialization. + device: The device on which the model is trained. + comm: The MPI communicator. + process_id (int): The ID of the current client process. + size (int): The total number of processes. + model: The neural network model. + train_data_num: The number of training data samples. + train_data_local_num_dict: A dictionary containing the number of local training data samples for each client. + train_data_local_dict: A dictionary containing the local training data for each client. + test_data_local_dict: A dictionary containing the local testing data for each client. + model_trainer: The model trainer for training the model (optional). + + Notes: + This function initializes and runs a federated learning client. + """ client_index = process_id - 1 logging.info("Initializing Client: {0}".format(client_index)) diff --git a/python/fedml/simulation/mpi/fedseg/FedSegAggregator.py b/python/fedml/simulation/mpi/fedseg/FedSegAggregator.py index bb126cae8e..a0e6e5f3d6 100644 --- a/python/fedml/simulation/mpi/fedseg/FedSegAggregator.py +++ b/python/fedml/simulation/mpi/fedseg/FedSegAggregator.py @@ -8,6 +8,44 @@ class FedSegAggregator(object): + """ + Federated Segmentation Aggregator for collecting and managing model updates and statistics from clients. + + Args: + worker_num (int): Number of worker (client) nodes. + device: The computing device (e.g., GPU) for training. + model: The segmentation model used in federated learning. + args: Additional configuration arguments. + model_trainer: Trainer for the segmentation model. + + Attributes: + trainer: The model trainer for training and evaluation. + worker_num (int): Number of worker (client) nodes. + device: The computing device for training. + args: Additional configuration arguments. + model_dict (dict): Dictionary to store model parameters received from clients. + sample_num_dict (dict): Dictionary to store the number of training samples from clients. + flag_client_model_uploaded_dict (dict): Dictionary to track whether each client has uploaded its model. + train_acc_client_dict (dict): Dictionary to store training accuracy for each client. + train_acc_class_client_dict (dict): Dictionary to store training class-wise accuracy for each client. + train_mIoU_client_dict (dict): Dictionary to store training mean Intersection over Union (mIoU) for each client. + train_FWIoU_client_dict (dict): Dictionary to store training frequency-weighted IoU (FWIoU) for each client. + train_loss_client_dict (dict): Dictionary to store training loss for each client. + test_acc_client_dict (dict): Dictionary to store test accuracy for each client. + test_acc_class_client_dict (dict): Dictionary to store test class-wise accuracy for each client. + test_mIoU_client_dict (dict): Dictionary to store test mean Intersection over Union (mIoU) for each client. + test_FWIoU_client_dict (dict): Dictionary to store test frequency-weighted IoU (FWIoU) for each client. + test_loss_client_dict (dict): Dictionary to store test loss for each client. + best_mIoU (float): Best mIoU value among all clients. + best_mIoU_clients (dict): Dictionary to store the clients with the best mIoU. + saver: Saver for saving experiment configurations and results. + + Methods: + get_global_model_params: Get the global model parameters. + set_global_model_params: Set the global model parameters. + add_local_trained_result: Add model parameters and sample count from a client. 
+ check_whether_all_receive: Check if all clients have uploaded their models. + """ def __init__(self, worker_num, device, model, args, model_trainer): self.trainer = model_trainer self.worker_num = worker_num @@ -43,18 +81,44 @@ def __init__(self, worker_num, device, model, args, model_trainer): ) def get_global_model_params(self): + """ + Get the global model parameters. + + Returns: + Global model parameters. + """ return self.trainer.get_model_params() def set_global_model_params(self, model_parameters): + """ + Set the global model parameters. + + Args: + model_parameters: Global model parameters to set. + """ self.trainer.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, sample_num): + """ + Add the model parameters and sample count from a client. + + Args: + index (int): Index or identifier of the client. + model_params: Model parameters trained by the client. + sample_num (int): Number of training samples used by the client. + """ logging.info("Add model index: {}".format(index)) self.model_dict[index] = model_params self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check whether all clients have uploaded their models. + + Returns: + True if all clients have uploaded their models, False otherwise. + """ for idx in range(self.worker_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -63,6 +127,12 @@ def check_whether_all_receive(self): return True def aggregate(self): + """ + Aggregate model updates from multiple clients. + + Returns: + Averaged model parameters after aggregation. + """ start_time = time.time() model_list = [] training_num = 0 @@ -93,6 +163,17 @@ def aggregate(self): return averaged_params def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Randomly select a subset of clients for federated learning. + + Args: + round_idx (int): Current federated learning round index. + client_num_in_total (int): Total number of available clients. + client_num_per_round (int): Number of clients to select for the current round. + + Returns: + List of selected client indexes. + """ if client_num_in_total == client_num_per_round: client_indexes = [ client_index for client_index in range(client_num_in_total) @@ -115,6 +196,15 @@ def add_client_test_result( train_eval_metrics: EvaluationMetricsKeeper, test_eval_metrics: EvaluationMetricsKeeper, ): + """ + Add evaluation metrics and results from a client. + + Args: + round_idx (int): Current federated learning round index. + client_idx (int): Index or identifier of the client. + train_eval_metrics (EvaluationMetricsKeeper): Evaluation metrics for training data. + test_eval_metrics (EvaluationMetricsKeeper): Evaluation metrics for testing data. + """ logging.info("Adding client test result : {}".format(client_idx)) # Populating Training Dictionary @@ -176,6 +266,12 @@ def add_client_test_result( self.saver.save_checkpoint(saver_state, is_best, filename) def output_global_acc_and_loss(self, round_idx): + """ + Output global accuracy and loss statistics for the current federated learning round. + + Args: + round_idx (int): Current federated learning round index. 
+ """ logging.info( "################## Output global accuracy and loss for round {} :".format( round_idx diff --git a/python/fedml/simulation/mpi/fedseg/FedSegClientManager.py b/python/fedml/simulation/mpi/fedseg/FedSegClientManager.py index de6ce9f19b..5370c557d4 100644 --- a/python/fedml/simulation/mpi/fedseg/FedSegClientManager.py +++ b/python/fedml/simulation/mpi/fedseg/FedSegClientManager.py @@ -7,16 +7,65 @@ class FedSegClientManager(FedMLCommManager): + """ + Client manager for federated segmentation. + + This class manages the client-side communication and training in a federated segmentation system. + + Args: + args: Additional configuration arguments. + trainer: Model trainer for federated segmentation. + comm: MPI communicator for distributed communication. + rank (int): Rank of the client. + size (int): Total number of processes. + backend (str): Communication backend (default: "MPI"). + + Attributes: + args: Additional configuration arguments. + trainer: Model trainer for federated segmentation. + num_rounds (int): Number of communication rounds. + + Methods: + run(): Start the client manager. + register_message_receive_handlers(): Register message handlers for receiving initialization and model synchronization messages. + handle_message_init(msg_params): Handle the initialization message from the central server. + start_training(): Start the training process. + handle_message_receive_model_from_server(msg_params): Handle received model updates from the central server. + send_model_to_server(receive_id, weights, local_sample_num, train_evaluation_metrics, test_evaluation_metrics): Send trained model updates to the central server. + """ def __init__(self, args, trainer, comm=None, rank=0, size=0, backend="MPI"): + """ + Initialize the FedSegClientManager. + + Args: + args: Additional configuration arguments. + trainer: Model trainer for federated segmentation. + comm: MPI communicator for distributed communication. + rank (int): Rank of the client. + size (int): Total number of processes. + backend (str): Communication backend (default: "MPI"). + """ super().__init__(args, comm, rank, size, backend) self.trainer = trainer self.num_rounds = args.comm_round self.args.round_idx = 0 def run(self): + """ + Start the client manager. + + Notes: + This function starts the client manager to handle communication and training. + """ super().run() def register_message_receive_handlers(self): + """ + Register message handlers for receiving initialization and model synchronization messages. + + Notes: + This function registers message handlers to process incoming messages from the central server. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -26,6 +75,15 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): + """ + Handle the initialization message from the central server. + + Args: + msg_params (dict): Parameters included in the received message. + + Notes: + This function processes the initialization message from the central server, updates the model and dataset, and starts training. + """ global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) logging.info( @@ -39,10 +97,25 @@ def handle_message_init(self, msg_params): self.__train() def start_training(self): + """ + Start the training process. + + Notes: + This function initiates the training process on the client side. 
+ """ self.args.round_idx = 0 self.__train() def handle_message_receive_model_from_server(self, msg_params): + """ + Handle received model updates from the central server. + + Args: + msg_params (dict): Parameters included in the received message. + + Notes: + This function processes received model updates from the central server, updates the model and dataset, and continues training. + """ logging.info("handle_message_receive_model_from_server.") model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -62,6 +135,19 @@ def send_model_to_server( train_evaluation_metrics, test_evaluation_metrics, ): + """ + Send trained model updates to the central server. + + Args: + receive_id (int): Receiver's ID. + weights: Trained model parameters. + local_sample_num (int): Number of local training samples. + train_evaluation_metrics: Evaluation metrics for training. + test_evaluation_metrics: Evaluation metrics for testing. + + Notes: + This function sends the trained model updates and evaluation metrics to the central server. + """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), @@ -78,6 +164,12 @@ def send_model_to_server( self.send_message(message) def __train(self): + """ + Perform training on the client side. + + Notes: + This method initiates the training process on the client side, including testing the global parameters, training the local model, and sending updates to the central server. + """ train_evaluation_metrics = test_evaluation_metrics = None logging.info( diff --git a/python/fedml/simulation/mpi/fedseg/FedSegServerManager.py b/python/fedml/simulation/mpi/fedseg/FedSegServerManager.py index d1382c6d77..677d17f757 100644 --- a/python/fedml/simulation/mpi/fedseg/FedSegServerManager.py +++ b/python/fedml/simulation/mpi/fedseg/FedSegServerManager.py @@ -7,7 +7,44 @@ class FedSegServerManager(FedMLCommManager): + """ + Server manager for federated segmentation. + + This class manages the server-side communication and aggregation of model updates in a federated segmentation system. + + Args: + args: Additional configuration arguments. + aggregator: Aggregator for federated segmentation models. + comm: MPI communicator for distributed communication. + rank (int): Rank of the server. + size (int): Total number of processes. + backend (str): Communication backend (default: "MPI"). + + Attributes: + args: Additional configuration arguments. + aggregator: Aggregator for federated segmentation models. + round_num (int): Number of communication rounds. + + Methods: + run(): Start the server manager. + send_init_msg(): Send initial configuration messages to clients. + register_message_receive_handlers(): Register message handlers for receiving model updates from clients. + handle_message_receive_model_from_client(msg_params): Handle received model updates from clients. + send_message_init_config(receive_id, global_model_params, client_index): Send initial configuration messages to clients. + send_message_sync_model_to_client(receive_id, global_model_params, client_index): Send model synchronization messages to clients. + """ def __init__(self, args, aggregator, comm=None, rank=0, size=0, backend="MPI"): + """ + Initialize the FedSegServerManager. + + Args: + args: Additional configuration arguments. + aggregator: Aggregator for federated segmentation models. + comm: MPI communicator for distributed communication. + rank (int): Rank of the server. 
+ size (int): Total number of processes. + backend (str): Communication backend (default: "MPI"). + """ super().__init__(args, comm, rank, size, backend) self.args = args self.aggregator = aggregator @@ -16,9 +53,21 @@ def __init__(self, args, aggregator, comm=None, rank=0, size=0, backend="MPI"): logging.info("Initializing Server Manager") def run(self): + """ + Start the server manager. + + Notes: + This function starts the server manager to handle communication and aggregation. + """ super().run() def send_init_msg(self): + """ + Send initial configuration messages to clients. + + Notes: + This function sends initial configuration messages to clients, including global model parameters and client indexes. + """ # sampling clients client_indexes = self.aggregator.client_sampling( self.args.round_idx, @@ -32,12 +81,27 @@ def send_init_msg(self): ) def register_message_receive_handlers(self): + """ + Register message handlers for receiving model updates from clients. + + Notes: + This function registers message handlers to process incoming messages from clients. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.handle_message_receive_model_from_client, ) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle received model updates from clients. + + Args: + msg_params (dict): Parameters included in the received message. + + Notes: + This function processes received model updates from clients, aggregates them, and initiates the next round of communication. + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) @@ -82,6 +146,17 @@ def handle_message_receive_model_from_client(self, msg_params): ) def send_message_init_config(self, receive_id, global_model_params, client_index): + """ + Send initial configuration messages to clients. + + Args: + receive_id (int): Receiver's ID. + global_model_params: Global model parameters. + client_index (int): Index of the client. + + Notes: + This function sends initial configuration messages to clients, including global model parameters and client indexes. + """ logging.info("Initial Configurations sent to client {0}".format(client_index)) message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id @@ -93,6 +168,17 @@ def send_message_init_config(self, receive_id, global_model_params, client_index def send_message_sync_model_to_client( self, receive_id, global_model_params, client_index ): + """ + Send model synchronization messages to clients. + + Args: + receive_id (int): Receiver's ID. + global_model_params: Global model parameters. + client_index (int): Index of the client. + + Notes: + This function sends model synchronization messages to clients, updating their models with the global parameters. + """ logging.info( "send_message_sync_model_to_client. receive_id {0}".format(receive_id) ) diff --git a/python/fedml/simulation/mpi/fedseg/FedSegTrainer.py b/python/fedml/simulation/mpi/fedseg/FedSegTrainer.py index f0bda08b47..d9b057f030 100644 --- a/python/fedml/simulation/mpi/fedseg/FedSegTrainer.py +++ b/python/fedml/simulation/mpi/fedseg/FedSegTrainer.py @@ -2,6 +2,40 @@ class FedSegTrainer(object): + """ + Trainer for federated segmentation models on a client. + + This class manages the training process of a federated segmentation model on a client. 
+ + Args: + client_index (int): The index of the client within the federated system. + train_data_local_dict (dict): A dictionary containing local training data for each client. + train_data_local_num_dict (dict): A dictionary containing the number of local training samples for each client. + train_data_num (int): Total number of training samples across all clients. + test_data_local_dict (dict): A dictionary containing local test data for each client. + device (torch.device): The device on which to perform training and evaluation. + model (nn.Module): The segmentation model to be trained. + args: Additional configuration arguments. + model_trainer: Trainer for the segmentation model. + + Attributes: + args: Additional configuration arguments. + trainer: Trainer for the segmentation model. + client_index (int): The index of the client. + train_data_local_dict (dict): A dictionary containing local training data for the client. + train_data_local_num_dict (dict): A dictionary containing the number of local training samples for each client. + test_data_local_dict (dict): A dictionary containing local test data for the client. + all_train_data_num (int): Total number of training samples across all clients. + train_local: Local training data for the client. + local_sample_number (int): The number of local training samples for the client. + test_local: Local test data for the client. + + Methods: + update_model(weights): Update the model with the provided weights. + update_dataset(client_index): Update the dataset for the client with the given index. + train(): Perform training on the local dataset and return trained weights and the number of local samples. + test(): Perform testing on the local test dataset and return evaluation metrics. + """ def __init__( self, client_index, @@ -14,6 +48,20 @@ def __init__( args, model_trainer, ): + """ + Initialize the FedSegTrainer for a client. + + Args: + client_index (int): The index of the client. + train_data_local_dict (dict): A dictionary containing local training data for each client. + train_data_local_num_dict (dict): A dictionary containing the number of local training samples for each client. + train_data_num (int): Total number of training samples across all clients. + test_data_local_dict (dict): A dictionary containing local test data for each client. + device (torch.device): The device on which to perform training and evaluation. + model: The segmentation model to be trained. + args: Additional configuration arguments. + model_trainer: Trainer for the segmentation model. + """ self.args = args self.trainer = model_trainer @@ -30,15 +78,39 @@ def __init__( self.device = device def update_model(self, weights): + """ + Update the model with the provided weights. + + Args: + weights: Model weights to be set. + + Notes: + This function updates the model with the provided weights. + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """ + Update the dataset for the client with the given index. + + Args: + client_index (int): The index of the client. + + Notes: + This function updates the dataset and client-related attributes for the specified client index. + """ self.client_index = client_index self.train_local = self.train_data_local_dict[client_index] self.local_sample_number = self.train_data_local_num_dict[client_index] self.test_local = self.test_data_local_dict[client_index] def train(self): + """ + Perform training on the local dataset and return trained weights and the number of local samples. 
+ + Returns: + tuple: A tuple containing trained model weights and the number of local training samples. + """ self.trainer.train(self.train_local, self.device) weights = self.trainer.get_model_params() @@ -46,6 +118,12 @@ def train(self): return weights, self.local_sample_number def test(self): + """ + Perform testing on the local test dataset and return evaluation metrics. + + Returns: + tuple: A tuple containing evaluation metrics on the local test dataset. + """ train_evaluation_metrics = None if self.args.round_idx and self.args.round_idx % self.args.evaluation_frequency == 0: diff --git a/python/fedml/simulation/mpi/fedseg/MyModelTrainer.py b/python/fedml/simulation/mpi/fedseg/MyModelTrainer.py index a4230e4e29..abec3eb5a7 100644 --- a/python/fedml/simulation/mpi/fedseg/MyModelTrainer.py +++ b/python/fedml/simulation/mpi/fedseg/MyModelTrainer.py @@ -9,7 +9,28 @@ class MyModelTrainer(ClientTrainer): + """ + A custom model trainer for federated learning clients. + + This trainer is designed for training and evaluating a segmentation model in a federated learning setting. + + Attributes: + model (nn.Module): The segmentation model to be trained and evaluated. + args: Additional configuration arguments for training and evaluation. + + Methods: + get_model_params(): Get the model parameters for the current trainer. + set_model_params(model_parameters): Set the model parameters for the current trainer. + train(train_data, device, args): Train the model on the provided training data. + test(test_data, device, args): Evaluate the model on the provided test data. + """ def get_model_params(self): + """ + Get the model parameters for the current trainer. + + Returns: + dict: A dictionary containing the model parameters. + """ if self.args.backbone_freezed: logging.info("Initializing model; Backbone Freezed") return self.model.encoder_decoder.cpu().state_dict() @@ -18,6 +39,12 @@ def get_model_params(self): return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters for the current trainer. + + Args: + model_parameters (dict): A dictionary containing the model parameters to be set. + """ if self.args.backbone_freezed: logging.info("Updating Global model; Backbone Freezed") self.model.encoder_decoder.load_state_dict(model_parameters) @@ -26,6 +53,17 @@ def set_model_params(self, model_parameters): self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model on the provided training data. + + Args: + train_data (DataLoader): DataLoader containing the training data. + device (torch.device): The device on which to perform training. + args: Additional arguments for training. + + Notes: + This function trains the model using the provided data and updates its parameters. + """ model = self.model args = self.args @@ -100,6 +138,17 @@ def train(self, train_data, device, args): ) def test(self, test_data, device, args): + """ + Evaluate the model on the provided test data. + + Args: + test_data (DataLoader): DataLoader containing the test data. + device (torch.device): The device on which to perform evaluation. + args: Additional arguments for evaluation. + + Returns: + EvaluationMetricsKeeper: An object containing various evaluation metrics. 
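+
+        Example (sketch; the loader, device, and args come from the surrounding setup):
+        ```python
+        metrics = trainer.test(test_data, device, args)
+        logging.info("acc: {} mIoU: {}".format(metrics.acc, metrics.mIoU))
+        ```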
+ """ logging.info("Evaluation on trainer ID:{}".format(self.id)) model = self.model args = self.args diff --git a/python/fedml/simulation/mpi/fedseg/utils.py b/python/fedml/simulation/mpi/fedseg/utils.py index 74cff9d830..815d4ab674 100644 --- a/python/fedml/simulation/mpi/fedseg/utils.py +++ b/python/fedml/simulation/mpi/fedseg/utils.py @@ -16,6 +16,15 @@ def transform_list_to_tensor(model_params_list): + """ + Transform a dictionary of model parameters from a list of NumPy arrays to PyTorch tensors. + + Args: + model_params_list (dict): A dictionary containing model parameters as NumPy arrays. + + Returns: + dict: A dictionary containing model parameters as PyTorch tensors. + """ for k in model_params_list.keys(): model_params_list[k] = torch.from_numpy( np.asarray(model_params_list[k]) @@ -24,27 +33,73 @@ def transform_list_to_tensor(model_params_list): def transform_tensor_to_list(model_params): + """ + Transform a dictionary of model parameters from PyTorch tensors to a list of NumPy arrays. + + Args: + model_params (dict): A dictionary containing model parameters as PyTorch tensors. + + Returns: + dict: A dictionary containing model parameters as NumPy arrays. + """ for k in model_params.keys(): model_params[k] = model_params[k].detach().numpy().tolist() return model_params def save_as_pickle_file(path, data): + """ + Save data to a pickle file. + + Args: + path (str): The file path where the data will be saved. + data (any): The data to be saved. + """ with open(path, "wb") as f: pickle.dump(data, f) f.close() def load_from_pickle_file(path): + """ + Load data from a pickle file. + + Args: + path (str): The file path from which the data will be loaded. + + Returns: + any: The loaded data. + """ return pickle.load(open(path, "rb")) def count_parameters(model): + """ + Count the number of trainable parameters in a PyTorch model. + + Args: + model (torch.nn.Module): The PyTorch model. + + Returns: + float: The number of trainable parameters in millions (M). + """ params = sum(p.numel() for p in model.parameters() if p.requires_grad) return params / 1000000 def str_to_bool(s): + """ + Convert a string to a boolean value. + + Args: + s (str): The input string. + + Returns: + bool: The boolean value corresponding to the string ("True" or "False"). + + Raises: + ValueError: If the input string is neither "True" nor "False". + """ if s == "True": return True elif s == "False": @@ -54,6 +109,23 @@ def str_to_bool(s): class EvaluationMetricsKeeper: + """ + A class to store and manage evaluation metrics. + + Args: + accuracy (float): Accuracy metric. + accuracy_class (float): Accuracy per class metric. + mIoU (float): Mean Intersection over Union (mIoU) metric. + FWIoU (float): Frequency-Weighted Intersection over Union (FWIoU) metric. + loss (float): Loss metric. + + Attributes: + acc (float): Accuracy metric. + acc_class (float): Accuracy per class metric. + mIoU (float): Mean Intersection over Union (mIoU) metric. + FWIoU (float): Frequency-Weighted Intersection over Union (FWIoU) metric. + loss (float): Loss metric. + """ def __init__(self, accuracy, accuracy_class, mIoU, FWIoU, loss): self.acc = accuracy self.acc_class = accuracy_class @@ -64,13 +136,37 @@ def __init__(self, accuracy, accuracy_class, mIoU, FWIoU, loss): # Segmentation Loss class SegmentationLosses(object): + """ + A class for managing segmentation loss functions. + + Args: + size_average (bool): Whether to compute the size-average loss. + batch_average (bool): Whether to compute the batch-average loss. 
+ ignore_index (int): The index to ignore in the loss computation. + + Attributes: + ignore_index (int): The index to ignore in the loss computation. + size_average (bool): Whether to compute the size-average loss. + batch_average (bool): Whether to compute the batch-average loss. + """ def __init__(self, size_average=True, batch_average=True, ignore_index=255): self.ignore_index = ignore_index self.size_average = size_average self.batch_average = batch_average def build_loss(self, mode="ce"): - """Choices: ['ce' or 'focal']""" + """ + Build a segmentation loss function based on the specified mode. + + Args: + mode (str): The mode of the loss function. Choices: ['ce' or 'focal'] + + Returns: + function: The selected segmentation loss function. + + Raises: + NotImplementedError: If an unsupported mode is specified. + """ if mode == "ce": return self.CrossEntropyLoss elif mode == "focal": @@ -79,6 +175,19 @@ def build_loss(self, mode="ce"): raise NotImplementedError def CrossEntropyLoss(self, logit, target): + """ + Compute the Cross Entropy loss. + + Args: + logit (torch.Tensor): The predicted logit tensor. + target (torch.Tensor): The target tensor. + + Returns: + torch.Tensor: The computed loss. + + Note: + This function uses the specified ignore_index and handles size and batch averaging. + """ n, c, h, w = logit.size() criterion = nn.CrossEntropyLoss( ignore_index=self.ignore_index, size_average=self.size_average @@ -91,6 +200,21 @@ def CrossEntropyLoss(self, logit, target): return loss def FocalLoss(self, logit, target, gamma=2, alpha=0.5): + """ + Compute the Focal loss. + + Args: + logit (torch.Tensor): The predicted logit tensor. + target (torch.Tensor): The target tensor. + gamma (float): The Focal loss gamma parameter. + alpha (float): The Focal loss alpha parameter. + + Returns: + torch.Tensor: The computed loss. + + Note: + This function uses the specified ignore_index and handles size and batch averaging. + """ n, c, h, w = logit.size() criterion = nn.CrossEntropyLoss( ignore_index=self.ignore_index, size_average=self.size_average @@ -109,16 +233,33 @@ def FocalLoss(self, logit, target, gamma=2, alpha=0.5): # LR Scheduler class LR_Scheduler(object): - """Learning Rate Scheduler + """ + Learning Rate Scheduler for adjusting the learning rate during training. + Step mode: ``lr = baselr * 0.1 ^ {floor(epoch-1 / lr_step)}`` Cosine mode: ``lr = baselr * 0.5 * (1 + cos(iter/maxiter))`` Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9`` + Args: - args: - :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`,`step`), - :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs, - :attr:`args.lr_step` - iters_per_epoch: number of iterations per epoch + mode (str): The mode of the learning rate scheduler. + Choices: ['cos', 'poly', 'step'] + - 'cos': Cosine mode. + - 'poly': Polynomial mode. + - 'step': Step mode. + base_lr (float): The base learning rate. + num_epochs (int): The total number of training epochs. + iters_per_epoch (int): The number of iterations per epoch. + lr_step (int): The step size for the 'step' mode. + warmup_epochs (int): The number of warm-up epochs. + + Attributes: + mode (str): The mode of the learning rate scheduler. + lr (float): The current learning rate. + lr_step (int): The step size for the 'step' mode. + iters_per_epoch (int): The number of iterations per epoch. + N (int): The total number of iterations over all epochs. + epoch (int): The current epoch. + warmup_iters (int): The number of warm-up iterations. 
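+
+    Example (sketch; optimizer and train_loader are assumed to exist):
+    ```python
+    scheduler = LR_Scheduler(mode="poly", base_lr=0.01, num_epochs=50,
+                             iters_per_epoch=len(train_loader))
+    for epoch in range(50):
+        for i, batch in enumerate(train_loader):
+            scheduler(optimizer, i, epoch)  # adjust the lr before each step
+    ```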
""" def __init__( @@ -136,6 +277,14 @@ def __init__( self.warmup_iters = warmup_epochs * iters_per_epoch def __call__(self, optimizer, i, epoch): + """ + Adjusts the learning rate based on the specified mode. + + Args: + optimizer: The optimizer whose learning rate will be adjusted. + i (int): The current iteration. + epoch (int): The current epoch. + """ T = epoch * self.iters_per_epoch + i if self.mode == "cos": lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi)) @@ -154,6 +303,13 @@ def __call__(self, optimizer, i, epoch): self._adjust_learning_rate(optimizer, lr) def _adjust_learning_rate(self, optimizer, lr): + """ + Adjusts the learning rate of the optimizer. + + Args: + optimizer: The optimizer whose learning rate will be adjusted. + lr (float): The new learning rate. + """ if len(optimizer.param_groups) == 1: optimizer.param_groups[0]["lr"] = lr else: @@ -165,7 +321,25 @@ def _adjust_learning_rate(self, optimizer, lr): # save model checkpoints (centralized) class Saver(object): + """ + Utility class for saving checkpoints and experiment configuration. + + Args: + args (argparse.Namespace): The command-line arguments. + + Attributes: + args (argparse.Namespace): The command-line arguments. + directory (str): The directory where experiments are stored. + runs (list): A list of existing experiment directories. + experiment_dir (str): The directory for the current experiment. + """ def __init__(self, args): + """ + Initializes a new Saver object for saving checkpoints and experiment configuration. + + Args: + args (argparse.Namespace): The command-line arguments. + """ self.args = args self.directory = os.path.join("run", args.dataset, args.model, args.checkname) self.runs = sorted(glob.glob(os.path.join(self.directory, "experiment_*"))) @@ -178,7 +352,14 @@ def __init__(self, args): os.makedirs(self.experiment_dir) def save_checkpoint(self, state, is_best, filename="checkpoint.pth.tar"): - """Saves checkpoint to disk""" + """ + Saves a checkpoint to disk. + + Args: + state (dict): The state to be saved. + is_best (bool): True if this is the best checkpoint, False otherwise. + filename (str, optional): The filename for the checkpoint. Defaults to "checkpoint.pth.tar". + """ filename = os.path.join(self.experiment_dir, filename) torch.save(state, filename) if is_best: @@ -211,6 +392,9 @@ def save_checkpoint(self, state, is_best, filename="checkpoint.pth.tar"): ) def save_experiment_config(self): + """ + Saves the experiment configuration to a text file. + """ logfile = os.path.join(self.experiment_dir, "parameters.txt") log_file = open(logfile, "w") @@ -251,20 +435,54 @@ def save_experiment_config(self): # Evaluation Metrics class Evaluator(object): + """ + Class for evaluating segmentation results. + + Args: + num_class (int): The number of classes in the segmentation task. + + Attributes: + num_class (int): The number of classes in the segmentation task. + confusion_matrix (numpy.ndarray): The confusion matrix for evaluating segmentation results. + """ def __init__(self, num_class): + """ + Initializes an Evaluator object for evaluating segmentation results. + + Args: + num_class (int): The number of classes in the segmentation task. + """ self.num_class = num_class self.confusion_matrix = np.zeros((self.num_class,) * 2) def Pixel_Accuracy(self): + """ + Computes the Pixel Accuracy for segmentation evaluation. + + Returns: + float: The Pixel Accuracy. 
+ """ Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum() return Acc def Pixel_Accuracy_Class(self): + """ + Computes the Pixel Accuracy per class for segmentation evaluation. + + Returns: + float: The mean Pixel Accuracy per class. + """ Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1) Acc = np.nanmean(Acc) return Acc def Mean_Intersection_over_Union(self): + """ + Computes the Mean Intersection over Union (IoU) for segmentation evaluation. + + Returns: + float: The Mean IoU. + """ MIoU = np.diag(self.confusion_matrix) / ( np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) @@ -274,6 +492,12 @@ def Mean_Intersection_over_Union(self): return MIoU def Frequency_Weighted_Intersection_over_Union(self): + """ + Computes the Frequency Weighted Intersection over Union (IoU) for segmentation evaluation. + + Returns: + float: The Frequency Weighted IoU. + """ freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) iu = np.diag(self.confusion_matrix) / ( np.sum(self.confusion_matrix, axis=1) @@ -285,6 +509,16 @@ def Frequency_Weighted_Intersection_over_Union(self): return FWIoU def _generate_matrix(self, gt_image, pre_image): + """ + Generates a confusion matrix for segmentation evaluation. + + Args: + gt_image (numpy.ndarray): Ground truth segmentation image. + pre_image (numpy.ndarray): Predicted segmentation image. + + Returns: + numpy.ndarray: The confusion matrix. + """ mask = (gt_image >= 0) & (gt_image < self.num_class) label = self.num_class * gt_image[mask].astype("int") + pre_image[mask] count = np.bincount(label, minlength=self.num_class**2) @@ -292,8 +526,18 @@ def _generate_matrix(self, gt_image, pre_image): return confusion_matrix def add_batch(self, gt_image, pre_image): + """ + Adds a batch of ground truth and predicted images for evaluation. + + Args: + gt_image (numpy.ndarray): Batch of ground truth segmentation images. + pre_image (numpy.ndarray): Batch of predicted segmentation images. + """ assert gt_image.shape == pre_image.shape self.confusion_matrix += self._generate_matrix(gt_image, pre_image) def reset(self): + """ + Resets the confusion matrix to zero. + """ self.confusion_matrix = np.zeros((self.num_class,) * 2) diff --git a/python/fedml/simulation/mpi/split_nn/SplitNNAPI.py b/python/fedml/simulation/mpi/split_nn/SplitNNAPI.py index 76ab846d65..26834eee89 100644 --- a/python/fedml/simulation/mpi/split_nn/SplitNNAPI.py +++ b/python/fedml/simulation/mpi/split_nn/SplitNNAPI.py @@ -10,6 +10,21 @@ def SplitNN_distributed( process_id, worker_number, device, comm, model, dataset, args, ): + """ + Initialize and distribute a Split Neural Network for training. + + Args: + process_id (int): The ID of the current process. + worker_number (int): Total number of worker processes. + device: The computing device (e.g., GPU) for training. + comm: Communication backend for distributed training. + model: The neural network model to be trained. + dataset: Dataset information including data splits. + args: Additional training configuration arguments. + + Returns: + None + """ [ train_data_num, local_data_num, @@ -47,6 +62,20 @@ def SplitNN_distributed( def init_server(comm, server_model, process_id, worker_number, device, args): + """ + Initialize and run the server-side component of Split Neural Network training. + + Args: + comm: Communication backend for distributed training. + server_model: The server-side portion of the neural network model. + process_id (int): The ID of the current process. 
+        worker_number (int): Total number of worker processes.
+        device: The computing device (e.g., GPU) for training.
+        args: Additional training configuration arguments.
+
+    Returns:
+        None
+    """
     arg_dict = {
         "comm": comm,
         "model": server_model,
@@ -63,6 +92,24 @@ def init_server(comm, server_model, process_id, worker_number, device, args):
 def init_client(
     comm, client_model, worker_number, train_data_local, test_data_local, process_id, server_rank, epochs, device, args,
 ):
+    """
+    Initialize and run the client-side component of Split Neural Network training.
+
+    Args:
+        comm: Communication backend for distributed training.
+        client_model: The client-side portion of the neural network model.
+        worker_number (int): Total number of worker processes.
+        train_data_local: Local training data for the client.
+        test_data_local: Local testing data for the client.
+        process_id (int): The ID of the current process.
+        server_rank (int): The rank of the server process.
+        epochs: Number of training epochs for the client.
+        device: The computing device (e.g., GPU) for training.
+        args: Additional training configuration arguments.
+
+    Returns:
+        None
+    """
     client_ID = process_id - 1
     arg_dict = {
         "client_index": client_ID,
diff --git a/python/fedml/simulation/mpi/split_nn/client.py b/python/fedml/simulation/mpi/split_nn/client.py
index 9fc8a816de..9a16af3a22 100644
--- a/python/fedml/simulation/mpi/split_nn/client.py
+++ b/python/fedml/simulation/mpi/split_nn/client.py
@@ -4,7 +4,19 @@


 class SplitNN_client:
+    """
+    SplitNNClient class represents a client in a Split Learning setup.
+
+    Args:
+        args (dict): Dictionary containing client-specific configuration.
+    """
     def __init__(self, args):
+        """
+        Initialize a SplitNNClient instance.
+
+        Args:
+            args (dict): Dictionary containing client-specific configuration.
+        """
         self.client_idx = args['client_index']
         self.comm = args["comm"]
         self.model = args["model"]
@@ -26,6 +38,12 @@ def __init__(self, args):
         self.device = args["device"]

     def forward_pass(self):
+        """
+        Perform a forward pass through the model.
+
+        Returns:
+            tuple: Tuple containing model activations (outputs) and labels.
+        """
         logging.info("forward_pass")
         inputs, labels = next(self.dataloader)
         inputs, labels = inputs.to(self.device), labels.to(self.device)
@@ -40,16 +58,28 @@ def forward_pass(self):
         return self.acts, labels

     def backward_pass(self, grads):
+        """
+        Perform a backward pass and update model parameters.
+
+        Args:
+            grads: Gradients used for the backward pass.
+        """
         logging.info("backward_pass")
         self.acts.backward(grads)
         self.optimizer.step()

     def eval_mode(self):
+        """
+        Switch the model to evaluation mode and prepare the test data loader.
+        """
         logging.info("eval_mode")
         self.dataloader = iter(self.testloader)
         self.model.eval()

     def train_mode(self):
+        """
+        Switch the model to training mode and prepare the training data loader.
+        """
         logging.info("train_mode")
         self.dataloader = iter(self.trainloader)
         self.model.train()
diff --git a/python/fedml/simulation/mpi/split_nn/client_manager.py b/python/fedml/simulation/mpi/split_nn/client_manager.py
index 1bb1e84580..73639f3fcb 100644
--- a/python/fedml/simulation/mpi/split_nn/client_manager.py
+++ b/python/fedml/simulation/mpi/split_nn/client_manager.py
@@ -6,7 +6,27 @@


 class SplitNNClientManager(FedMLCommManager):
+    """
+    Manages the client-side operations for Split Learning in a Federated Learning setting.
+
+    Args:
+        arg_dict (dict): A dictionary containing necessary arguments.
+ trainer (Trainer): The trainer responsible for the client's model. + backend (str): The communication backend (e.g., "MPI"). + + Attributes: + trainer (Trainer): The trainer responsible for the client's model. + args (args): Arguments for the client manager. + """ def __init__(self, arg_dict, trainer, backend="MPI"): + """ + Initialize a SplitNNClientManager. + + Args: + arg_dict (dict): A dictionary containing necessary arguments. + trainer (Trainer): The trainer responsible for the client's model. + backend (str): The communication backend (e.g., "MPI"). + """ super().__init__( arg_dict["args"], arg_dict["comm"], @@ -19,12 +39,20 @@ def __init__(self, arg_dict, trainer, backend="MPI"): self.args.round_idx = 0 def run(self): + """ + Start the client manager. + + If the trainer's rank is 1, it starts the protocol by running the forward pass. + """ if self.trainer.rank == 1: logging.info("Starting protocol from rank 1 process") self.run_forward_pass() super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for different message types. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2C_SEMAPHORE, self.handle_message_semaphore ) @@ -33,12 +61,23 @@ def register_message_receive_handlers(self): ) def handle_message_semaphore(self, msg_params): + """ + Handle the semaphore message and start the training process. + + Args: + msg_params: Parameters of the semaphore message. + """ # no point in checking the semaphore message logging.info("Starting training at node {}".format(self.trainer.rank)) self.trainer.train_mode() self.run_forward_pass() def run_forward_pass(self): + """ + Run the forward pass of the trainer. + + Sends activations and labels to the server afterward. + """ acts, labels = self.trainer.forward_pass() self.send_activations_and_labels_to_server( acts, labels, self.trainer.SERVER_RANK @@ -46,6 +85,15 @@ def run_forward_pass(self): self.trainer.batch_idx += 1 def run_eval(self): + """ + Run the evaluation process for the client. + + This method sends a validation signal to the server, switches the trainer to evaluation mode, + and performs the forward pass for validation data. After validation, it sends a validation + completion signal to the server and updates the round index. If the maximum number of + epochs per node is reached, it sends a finish signal to the server. + + """ self.send_validation_signal_to_server(self.trainer.SERVER_RANK) self.trainer.eval_mode() for i in range(len(self.trainer.testloader)): @@ -69,6 +117,12 @@ def run_eval(self): self.finish() def handle_message_gradients(self, msg_params): + """ + Handle received gradients and initiate backward pass. + + Args: + msg_params: Parameters of the received gradients message. + """ grads = msg_params.get(MyMessage.MSG_ARG_KEY_GRADS) self.trainer.backward_pass(grads) if self.trainer.batch_idx == len(self.trainer.trainloader): @@ -79,6 +133,14 @@ def handle_message_gradients(self, msg_params): self.run_forward_pass() def send_activations_and_labels_to_server(self, acts, labels, receive_id): + """ + Send activations and labels to the server. + + Args: + acts: Activations to be sent. + labels: Labels corresponding to the activations. + receive_id: ID of the receiving entity (typically, the server). 
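+
+        Example (a sketch of the call made by run_forward_pass above; the
+        argument values come from the client's trainer):
+
+            acts, labels = self.trainer.forward_pass()
+            self.send_activations_and_labels_to_server(
+                acts, labels, self.trainer.SERVER_RANK
+            )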
+ """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_ACTS, self.get_sender_id(), receive_id ) @@ -86,24 +148,48 @@ def send_activations_and_labels_to_server(self, acts, labels, receive_id): self.send_message(message) def send_semaphore_to_client(self, receive_id): + """ + Send a semaphore message to a client. + + Args: + receive_id: ID of the receiving client. + """ message = Message( MyMessage.MSG_TYPE_C2C_SEMAPHORE, self.get_sender_id(), receive_id ) self.send_message(message) def send_validation_signal_to_server(self, receive_id): + """ + Send a validation signal message to the server. + + Args: + receive_id: ID of the receiving entity (typically, the server). + """ message = Message( MyMessage.MSG_TYPE_C2S_VALIDATION_MODE, self.get_sender_id(), receive_id ) self.send_message(message) def send_validation_over_to_server(self, receive_id): + """ + Send a validation completion signal message to the server. + + Args: + receive_id: ID of the receiving entity (typically, the server). + """ message = Message( MyMessage.MSG_TYPE_C2S_VALIDATION_OVER, self.get_sender_id(), receive_id ) self.send_message(message) def send_finish_to_server(self, receive_id): + """ + Send a finish signal message to the server. + + Args: + receive_id: ID of the receiving entity (typically, the server). + """ message = Message( MyMessage.MSG_TYPE_C2S_PROTOCOL_FINISHED, self.get_sender_id(), receive_id ) diff --git a/python/fedml/simulation/mpi/split_nn/server.py b/python/fedml/simulation/mpi/split_nn/server.py index 1187d16fc4..cf9cff9c17 100644 --- a/python/fedml/simulation/mpi/split_nn/server.py +++ b/python/fedml/simulation/mpi/split_nn/server.py @@ -5,21 +5,39 @@ class SplitNN_server: + """ + SplitNN Server for managing communication and training. + """ + def __init__(self, args): + """ + Initialize the SplitNN Server. + + Args: + args (dict): A dictionary containing configuration arguments. + """ self.comm = args["comm"] self.model = args["model"] self.MAX_RANK = args["max_rank"] self.init_params() def init_params(self): + """ + Initialize training parameters and optimizer. + """ self.epoch = 0 self.log_step = 50 self.active_node = 1 self.train_mode() - self.optimizer = optim.SGD(self.model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) + self.optimizer = optim.SGD( + self.model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4 + ) self.criterion = nn.CrossEntropyLoss() def reset_local_params(self): + """ + Reset local training parameters. + """ logging.info("reset_local_params") self.total = 0 self.correct = 0 @@ -28,25 +46,38 @@ def reset_local_params(self): self.batch_idx = 0 def train_mode(self): + """ + Switch to training mode. + """ logging.info("train_mode") self.model.train() self.phase = "train" self.reset_local_params() def eval_mode(self): + """ + Switch to evaluation mode. + """ logging.info("eval_mode") self.model.eval() self.phase = "validation" self.reset_local_params() def forward_pass(self, acts, labels): + """ + Perform a forward pass of the model. + + Args: + acts: Activations. + labels: Ground truth labels. 
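+
+        Note:
+            The incoming activations are stored on self.acts and
+            retain_grad() is called on them, so after backward_pass runs,
+            self.acts.grad holds the split-layer gradient that is sent back
+            to the client.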
+        """
         logging.info("forward_pass")
         self.acts = acts
         self.optimizer.zero_grad()
         self.acts.retain_grad()
         logits = self.model(acts)
         _, predictions = logits.max(1)
-        self.loss = self.criterion(logits, labels)  # pylint: disable=E1102
+        self.loss = self.criterion(logits, labels)
         self.total += labels.size(0)
         self.correct += predictions.eq(labels).sum().item()
         if self.step % self.log_step == 0 and self.phase == "train":
@@ -61,18 +92,25 @@ def forward_pass(self, acts, labels):
         self.step += 1

     def backward_pass(self):
+        """
+        Perform a backward pass and update model weights.
+        """
         logging.info("backward_pass")
         self.loss.backward()
         self.optimizer.step()
         return self.acts.grad

     def validation_over(self):
+        """
+        Handle the end of validation and switch to the next node.
+        """
         logging.info("validation_over")
-        # not precise estimation of validation loss
         self.val_loss /= self.step
         acc = self.correct / self.total
         logging.info(
-            "phase={} acc={} loss={} epoch={} and step={}".format(self.phase, acc, self.val_loss, self.epoch, self.step)
+            "phase={} acc={} loss={} epoch={} and step={}".format(
+                self.phase, acc, self.val_loss, self.epoch, self.step
+            )
         )

         self.epoch += 1
diff --git a/python/fedml/simulation/mpi/split_nn/server_manager.py b/python/fedml/simulation/mpi/split_nn/server_manager.py
index cd7aa3ad52..683bb1e50f 100644
--- a/python/fedml/simulation/mpi/split_nn/server_manager.py
+++ b/python/fedml/simulation/mpi/split_nn/server_manager.py
@@ -4,7 +4,18 @@


 class SplitNNServerManager(FedMLCommManager):
+    """
+    Manager for the SplitNN server that handles communication.
+    """
     def __init__(self, arg_dict, trainer, backend="MPI"):
+        """
+        Initialize the SplitNNServerManager.
+
+        Args:
+            arg_dict (dict): A dictionary containing configuration arguments.
+            trainer: The trainer instance for the server.
+            backend (str): The communication backend to use (default is "MPI").
+        """
         super().__init__(
             arg_dict["args"],
             arg_dict["comm"],
@@ -19,6 +30,9 @@ def run(self):
         super().run()

     def register_message_receive_handlers(self):
+        """
+        Register message receive handlers for different message types.
+        """
         self.register_message_receive_handler(
             MyMessage.MSG_TYPE_C2S_SEND_ACTS, self.handle_message_acts
         )
@@ -34,6 +48,12 @@ def register_message_receive_handlers(self):
         )

     def send_grads_to_client(self, receive_id, grads):
+        """
+        Send the gradients of the received activations back to a client.
+
+        Args:
+            receive_id (int): ID of the receiving client.
+            grads: Gradients to be sent back for the client-side backward pass.
+        """
         message = Message(
             MyMessage.MSG_TYPE_S2C_GRADS, self.get_sender_id(), receive_id
         )
@@ -41,6 +61,12 @@ def send_grads_to_client(self, receive_id, grads):
         self.send_message(message)

     def handle_message_acts(self, msg_params):
+        """
+        Handle a message containing activations.
+
+        Args:
+            msg_params (dict): Parameters of the received message.
+        """
         acts, labels = msg_params.get(MyMessage.MSG_ARG_KEY_ACTS)
         self.trainer.forward_pass(acts, labels)
         if self.trainer.phase == "train":
@@ -48,10 +74,30 @@ def handle_message_acts(self, msg_params):
             self.send_grads_to_client(self.trainer.active_node, grads)

     def handle_message_validation_mode(self, msg_params):
+        """
+        Handle a message indicating validation mode.
+
+        Args:
+            msg_params (dict): Parameters of the received message.
+        """
+
         self.trainer.eval_mode()

     def handle_message_validation_over(self, msg_params):
+        """
+        Handle a message indicating the end of validation.
+
+        Args:
+            msg_params (dict): Parameters of the received message.
+        """
+
         self.trainer.validation_over()
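
+    # Per-batch message flow during training (a sketch of the handlers above):
+    #   1. client ---MSG_TYPE_C2S_SEND_ACTS---> handle_message_acts
+    #   2. the trainer runs forward_pass(acts, labels) and, in the train
+    #      phase, backward_pass() to compute the activation gradients
+    #   3. server ---MSG_TYPE_S2C_GRADS---> client via send_grads_to_client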

     def handle_message_finish_protocol(self):
+        """
+        Handle a message indicating that the protocol has finished.
+        """
+
         self.finish()

From c4472b1e1ee081f88278ad2a3230557778807b29 Mon Sep 17 00:00:00 2001
From: rajveer43
Date: Thu, 7 Sep 2023 21:28:52 +0530
Subject: [PATCH 11/70] python\fedml\simulation\mpi\async_fedavg
 python\fedml\simulation\mpi\async_fedavg

---
 .../mpi/async_fedavg/AsyncFedAVGAggregator.py | 138 ++++++++++++++++++
 .../mpi/async_fedavg/AsyncFedAVGTrainer.py    |  74 ++++++++++
 .../async_fedavg/AsyncFedAvgClientManager.py  | 108 ++++++++++++--
 .../mpi/async_fedavg/AsyncFedAvgSeqAPI.py     |  58 ++++++++
 .../async_fedavg/AsyncFedAvgServerManager.py  |  93 ++++++++++++
 .../mpi/async_fedavg/MyModelTrainer.py        |  60 ++++++++
 .../mpi/async_fedavg/my_model_trainer.py      |  60 ++++++++
 .../my_model_trainer_classification.py        |  60 ++++++++
 .../mpi/async_fedavg/my_model_trainer_nwp.py  |  60 ++++++++
 .../my_model_trainer_tag_prediction.py        |  60 ++++++++
 .../simulation/mpi/async_fedavg/utils.py      |  29 ++++
 11 files changed, 789 insertions(+), 11 deletions(-)

diff --git a/python/fedml/simulation/mpi/async_fedavg/AsyncFedAVGAggregator.py b/python/fedml/simulation/mpi/async_fedavg/AsyncFedAVGAggregator.py
index ce87629513..50a14adb20 100644
--- a/python/fedml/simulation/mpi/async_fedavg/AsyncFedAVGAggregator.py
+++ b/python/fedml/simulation/mpi/async_fedavg/AsyncFedAVGAggregator.py
@@ -12,6 +12,67 @@
 from ....core.schedule.runtime_estimate import t_sample_fit


 class AsyncFedAVGAggregator(object):
+    """
+    Aggregator for the asynchronous Federated Averaging server in a federated learning system.
+
+    Args:
+        train_global: Global training data.
+        test_global: Global testing data.
+        all_train_data_num: Total number of training data samples.
+        train_data_local_dict: Dictionary containing local training data for each client.
+        test_data_local_dict: Dictionary containing local testing data for each client.
+        train_data_local_num_dict: Dictionary containing the number of local training data samples for each client.
+        worker_num: Number of worker processes.
+        device: The computing device (e.g., CPU or GPU).
+        args: Command-line arguments and configurations.
+        model_trainer: Trainer for the federated learning model.
+
+    Attributes:
+        trainer: Trainer for the federated learning model.
+        args: Command-line arguments and configurations.
+        train_global: Global training data.
+        test_global: Global testing data.
+        val_global: Global validation data generated from the global training data.
+        all_train_data_num: Total number of training data samples.
+        train_data_local_dict: Dictionary containing local training data for each client.
+        test_data_local_dict: Dictionary containing local testing data for each client.
+        train_data_local_num_dict: Dictionary containing the number of local training data samples for each client.
+        worker_num: Number of worker processes.
+        device: The computing device (e.g., CPU or GPU).
+        model_dict: Dictionary containing client models indexed by client ID.
+        sample_num_dict: Dictionary containing the number of samples trained by each client.
+        flag_client_model_uploaded_dict: Dictionary tracking whether client models have been uploaded.
+        runtime_history: Dictionary containing runtime information for clients.
+        model_weights: Global model weights updated during aggregation.
+        client_running_status: Array tracking the status of running clients.
+ + Methods: + get_global_model_params(): + Get the global model parameters. + + set_global_model_params(model_parameters): + Set the global model parameters. + + add_local_trained_result(index, model_params, local_sample_number, + current_round, client_round): + Add the locally trained model results to the aggregator and update the global model. + + client_schedule(round_idx, client_indexes, mode="simulate"): + Generate a schedule for clients based on runtime information. + + get_average_weight(client_indexes): + Calculate the average weight assigned to each client based on the number of training samples. + + client_sampling(round_idx, client_num_in_total, client_num_per_round): + Sample clients for communication in a round. + + _generate_validation_set(num_samples=10000): + Generate a validation set from the global testing data. + + test_on_server_for_all_clients(round_idx): + Perform testing on the server for all clients and log the results. + """ + def __init__( self, train_global, @@ -54,14 +115,42 @@ def __init__( def get_global_model_params(self): + """ + Get the global model parameters. + + Returns: + dict: Global model parameters. + """ # return self.trainer.get_model_params() return self.model_weights def set_global_model_params(self, model_parameters): + """ + Set the global model parameters. + + Args: + model_parameters (dict): Global model parameters to be set. + + Returns: + None + """ self.trainer.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, local_sample_number, current_round, client_round): + """ + Add the locally trained model results to the aggregator and update the global model. + + Args: + index (int): Index of the client. + model_params (dict): Model parameters trained by the client. + local_sample_number (int): Number of local training data samples used by the client. + current_round (int): Current communication round. + client_round (int): Round index for the client. + + Returns: + None + """ logging.info("add_model. index = %d" % index) self.client_running_status = np.setdiff1d(self.client_running_status, @@ -76,6 +165,17 @@ def add_local_trained_result(self, index, model_params, local_sample_number, def client_schedule(self, round_idx, client_indexes, mode="simulate"): + """ + Generate a schedule for clients based on runtime information. + + Args: + round_idx (int): Current communication round. + client_indexes (list): List of client indexes. + mode (str): The scheduling mode ("simulate" or "release"). + + Returns: + list: List of client schedules. + """ self.runtime_history = {} for i in range(self.worker_num): self.runtime_history[i] = {} @@ -91,6 +191,15 @@ def client_schedule(self, round_idx, client_indexes, mode="simulate"): def get_average_weight(self, client_indexes): + """ + Calculate the average weight assigned to each client based on the number of training samples. + + Args: + client_indexes (list): List of client indexes. + + Returns: + dict: A dictionary mapping client indexes to their respective average weights. + """ average_weight_dict = {} training_num = 0 for client_index in client_indexes: @@ -102,6 +211,17 @@ def get_average_weight(self, client_indexes): def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Sample clients for communication in a round. + + Args: + round_idx (int): Current communication round. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to sample per round. 
+ + Returns: + list: List of client indexes selected for communication in the current round. + """ num_clients = min(client_num_per_round, client_num_in_total) np.random.seed( round_idx @@ -116,6 +236,15 @@ def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set from the global testing data. + + Args: + num_samples (int): Number of samples to include in the validation set. + + Returns: + torch.utils.data.DataLoader: DataLoader containing the validation set. + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) sample_indices = random.sample( @@ -130,6 +259,15 @@ def _generate_validation_set(self, num_samples=10000): return self.test_global def test_on_server_for_all_clients(self, round_idx): + """ + Perform testing on the server for all clients and log the results. + + Args: + round_idx (int): Current communication round. + + Returns: + None + """ if self.trainer.test_on_the_server( self.train_data_local_dict, self.test_data_local_dict, diff --git a/python/fedml/simulation/mpi/async_fedavg/AsyncFedAVGTrainer.py b/python/fedml/simulation/mpi/async_fedavg/AsyncFedAVGTrainer.py index 1265fb298a..589480abe7 100644 --- a/python/fedml/simulation/mpi/async_fedavg/AsyncFedAVGTrainer.py +++ b/python/fedml/simulation/mpi/async_fedavg/AsyncFedAVGTrainer.py @@ -2,6 +2,46 @@ class AsyncFedAVGTrainer(object): + """ + An asynchronous Federated Averaging trainer for client nodes in a federated learning system. + + Args: + client_index (int): The index of the client node. + train_data_local_dict (dict): A dictionary containing local training data for each client. + train_data_local_num_dict (dict): A dictionary containing the number of local training samples for each client. + test_data_local_dict (dict): A dictionary containing local testing data for each client. + train_data_num (int): The total number of training samples across all clients. + device (torch.device): The device (e.g., CPU or GPU) to perform training and testing on. + args (argparse.Namespace): Command-line arguments and configurations for training. + model_trainer (ClientTrainer): An instance of a client-side model trainer. + + Attributes: + trainer (ClientTrainer): The model trainer used for training and testing. + client_index (int): The index of the client node. + train_data_local_dict (dict): A dictionary containing local training data for each client. + train_data_local_num_dict (dict): A dictionary containing the number of local training samples for each client. + test_data_local_dict (dict): A dictionary containing local testing data for each client. + all_train_data_num (int): The total number of training samples across all clients. + train_local (Dataset): The local training dataset for the current client. + local_sample_number (int): The number of local training samples for the current client. + test_local (Dataset): The local testing dataset for the current client. + device (torch.device): The device used for training and testing. + args (argparse.Namespace): Command-line arguments and configurations for training. + + Methods: + update_model(weights): + Update the model's weights with the provided weights. + + update_dataset(client_index): + Update the local training and testing datasets for the current client. + + train(round_idx=None): + Train the model on the local training dataset. 
+ + test(): + Test the model on both the local training and testing datasets. + + """ def __init__( self, client_index, @@ -28,15 +68,42 @@ def __init__( self.args = args def update_model(self, weights): + """ + Update the model's weights with the provided weights. + + Args: + weights (dict): The model parameters as a dictionary of tensors. + + Returns: + None + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """ + Update the local training and testing datasets for the current client. + + Args: + client_index (int): The index of the current client. + + Returns: + None + """ self.client_index = client_index self.train_local = self.train_data_local_dict[client_index] self.local_sample_number = self.train_data_local_num_dict[client_index] self.test_local = self.test_data_local_dict[client_index] def train(self, round_idx=None): + """ + Train the model on the local training dataset. + + Args: + round_idx (int, optional): The current round index. Defaults to None. + + Returns: + tuple: A tuple containing the trained model's weights and the number of local training samples. + """ self.args.round_idx = round_idx self.trainer.train(self.train_local, self.device, self.args) @@ -45,6 +112,13 @@ def train(self, round_idx=None): return weights, self.local_sample_number def test(self): + """ + Test the model on both the local training and testing datasets. + + Returns: + tuple: A tuple containing various metrics, including training accuracy, training loss, the number + of training samples, testing accuracy, testing loss, and the number of testing samples. + """ # train data train_metrics = self.trainer.test(self.train_local, self.device, self.args) train_tot_correct, train_num_sample, train_loss = ( diff --git a/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgClientManager.py b/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgClientManager.py index 3e8fb9af07..d97c437f21 100644 --- a/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgClientManager.py +++ b/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgClientManager.py @@ -11,6 +11,45 @@ class AsyncFedAVGClientManager(FedMLCommManager): + """ + Client manager for asynchronous Federated Averaging in a federated learning system. + + Args: + args: Command-line arguments and configurations. + trainer: Trainer for the federated learning model. + comm: Communication backend. + rank: Rank of the client manager. + size: Total number of client managers. + backend: Communication backend type (default: "MPI"). + + Attributes: + trainer: Trainer for the federated learning model. + num_rounds: Total number of communication rounds. + round_idx: Current communication round index. + worker_id: Unique identifier for the client manager. + + Methods: + run(): + Run the client manager. + + register_message_receive_handlers(): + Register message receive handlers for communication. + + handle_message_init(msg_params): + Handle the initialization message from the server. + + start_training(): + Start the training process. + + handle_message_receive_model_from_server(msg_params): + Handle the received model from the server. + + send_result_to_server(receive_id, weights, local_sample_num, client_runtime_info): + Send training results to the server. + + __train(global_model_params, client_index): + Perform model training for a client. + """ def __init__( self, args, @@ -27,9 +66,21 @@ def __init__( self.worker_id = self.rank - 1 def run(self): + """ + Run the communication manager. 
+ + Returns: + None + """ super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for different message types. + + Returns: + None + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -39,6 +90,15 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): + """ + Handle the initialization message from the server. + + Args: + msg_params (dict): Dictionary of message parameters. + + Returns: + None + """ global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -46,10 +106,25 @@ def handle_message_init(self, msg_params): self.__train(global_model_params, client_index) def start_training(self): + """ + Start the training process. + + Returns: + None + """ self.round_idx = 0 # self.__train() def handle_message_receive_model_from_server(self, msg_params): + """ + Handle the received model from the server. + + Args: + msg_params (dict): Dictionary of message parameters. + + Returns: + None + """ logging.info("handle_message_receive_model_from_server.") global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -62,6 +137,18 @@ def handle_message_receive_model_from_server(self, msg_params): def send_result_to_server(self, receive_id, weights, local_sample_num, client_runtime_info): + """ + Send the training result to the server. + + Args: + receive_id (int): ID of the message receiver. + weights (dict): Model parameters. + local_sample_num (int): Number of local training samples. + client_runtime_info (dict): Dictionary of client runtime information. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), @@ -74,6 +161,16 @@ def send_result_to_server(self, receive_id, weights, local_sample_num, client_ru def __train(self, global_model_params, client_index): + """ + Perform the training process for a client. + + Args: + global_model_params (dict): Global model parameters. + client_index (int): Index of the client. + + Returns: + None + """ logging.info("#######training########### round_id = %d" % self.round_idx) local_agg_model_params = {} @@ -92,14 +189,3 @@ def __train(self, global_model_params, client_index): # diff_weights = get_name_params_difference(global_model_params, weights) # weights - global_model_params self.send_result_to_server(0, weights, local_sample_num, client_runtime_info) - - - - - - - - - - - diff --git a/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgSeqAPI.py b/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgSeqAPI.py index 25c0f9bfd9..bb3424f773 100644 --- a/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgSeqAPI.py +++ b/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgSeqAPI.py @@ -11,6 +11,23 @@ def FedML_Async_distributed( args, process_id, worker_number, comm, device, dataset, model, model_trainer=None, preprocessed_sampling_lists=None, ): + """ + Run the asynchronous federated learning process. + + Args: + args (object): An object containing the configuration parameters. + process_id (int): The unique ID of the current process. + worker_number (int): The total number of worker processes. + comm (object): The communication object. + device (object): The device to run the training on (e.g., GPU). + dataset (list): A list containing dataset-related information. 
+ model (object): The federated learning model. + model_trainer (object, optional): The model trainer object. Defaults to None. + preprocessed_sampling_lists (list, optional): Preprocessed sampling lists for clients. Defaults to None. + + Returns: + None + """ [ train_data_num, test_data_num, @@ -75,6 +92,28 @@ def init_server( model_trainer, preprocessed_sampling_lists=None, ): + """ + Initialize the server for asynchronous federated learning. + + Args: + args (object): An object containing the configuration parameters. + device (object): The device to run the training on (e.g., GPU). + comm (object): The communication object. + rank (int): The rank of the current process. + size (int): The total number of processes. + model (object): The federated learning model. + train_data_num (int): The number of training data samples. + train_data_global (object): The global training dataset. + test_data_global (object): The global test dataset. + train_data_local_dict (dict): A dictionary containing local training data for clients. + test_data_local_dict (dict): A dictionary containing local test data for clients. + train_data_local_num_dict (dict): A dictionary containing the number of local training data samples for clients. + model_trainer (object): The model trainer object. + preprocessed_sampling_lists (list, optional): Preprocessed sampling lists for clients. Defaults to None. + + Returns: + None + """ if model_trainer is None: model_trainer = create_model_trainer(model, args) model_trainer.set_id(-1) @@ -126,6 +165,25 @@ def init_client( test_data_local_dict, model_trainer=None, ): + """ + Initialize a client for asynchronous federated learning. + + Args: + args (object): An object containing the configuration parameters. + device (object): The device to run the training on (e.g., GPU). + comm (object): The communication object. + process_id (int): The unique ID of the current process. + size (int): The total number of processes. + model (object): The federated learning model. + train_data_num (int): The number of training data samples. + train_data_local_num_dict (dict): A dictionary containing the number of local training data samples for clients. + train_data_local_dict (dict): A dictionary containing local training data for clients. + test_data_local_dict (dict): A dictionary containing local test data for clients. + model_trainer (object, optional): The model trainer object. Defaults to None. + + Returns: + None + """ client_index = process_id - 1 if model_trainer is None: model_trainer = create_model_trainer(model, args) diff --git a/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgServerManager.py b/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgServerManager.py index 2a456df3e5..d7f89afdab 100644 --- a/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgServerManager.py +++ b/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgServerManager.py @@ -8,6 +8,49 @@ class AsyncFedAVGServerManager(FedMLCommManager): + """ + Manager for the asynchronous Federated Averaging server in a federated learning system. + + Args: + args (argparse.Namespace): Command-line arguments and configurations for the server. + aggregator: An instance of the aggregator responsible for aggregating client updates. + comm: The communication object for inter-process communication. + rank (int): The rank of the server process. + size (int): The total number of processes. + backend (str): The communication backend (e.g., "MPI"). + is_preprocessed (bool): Indicates whether the data is preprocessed. 
+ preprocessed_client_lists (list): A list of preprocessed client data. + + Attributes: + args (argparse.Namespace): Command-line arguments and configurations for the server. + aggregator: An instance of the aggregator responsible for aggregating client updates. + round_num (int): The total number of communication rounds. + round_idx (int): The current round index. + is_preprocessed (bool): Indicates whether the data is preprocessed. + preprocessed_client_lists (list): A list of preprocessed client data. + client_round_dict (dict): A dictionary to track the round index for each client. + + Methods: + run(): + Start the server and begin the federated learning process. + + send_init_msg(): + Send initialization messages to client processes to start communication. + + register_message_receive_handlers(): + Register message handlers for receiving client updates. + + handle_message_receive_model_from_client(msg_params): + Handle the received client update message, record client runtime information, + aggregate the updates, and perform testing. + + send_message_init_config(receive_id, global_model_params, client_index): + Send initialization configuration messages to clients. + + send_message_sync_model_to_client(receive_id, global_model_params, client_index): + Send synchronized model updates to clients. + + """ def __init__( self, args, @@ -32,10 +75,22 @@ def __init__( def run(self): + """ + Start the server and begin the federated learning process. + + Returns: + None + """ super().run() def send_init_msg(self): + """ + Send initialization messages to client processes to start communication. + + Returns: + None + """ # sampling clients # client_indexes = self.aggregator.client_sampling( # self.round_idx, @@ -54,12 +109,28 @@ def send_init_msg(self): def register_message_receive_handlers(self): + """ + Register message handlers for receiving client updates. + + Returns: + None + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.handle_message_receive_model_from_client, ) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle the received client update message, record client runtime information, + aggregate the updates, and perform testing. + + Args: + msg_params (dict): Message parameters containing client update information. + + Returns: + None + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) @@ -107,6 +178,17 @@ def handle_message_receive_model_from_client(self, msg_params): def send_message_init_config(self, receive_id, global_model_params, client_index): + """ + Send initialization configuration messages to clients. + + Args: + receive_id (int): The ID of the receiving client. + global_model_params (dict): Global model parameters to be sent to clients. + client_index (list): List of client indexes for the current communication round. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) @@ -117,6 +199,17 @@ def send_message_init_config(self, receive_id, global_model_params, def send_message_sync_model_to_client(self, receive_id, global_model_params, client_index): + """ + Send synchronized model updates to clients. + + Args: + receive_id (int): The ID of the receiving client. + global_model_params (dict): Global model parameters to be sent to clients. 
+ client_index (list): List of client indexes for the current communication round. + + Returns: + None + """ logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id) message = Message( MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, diff --git a/python/fedml/simulation/mpi/async_fedavg/MyModelTrainer.py b/python/fedml/simulation/mpi/async_fedavg/MyModelTrainer.py index 008582c6d3..b0cd08e362 100644 --- a/python/fedml/simulation/mpi/async_fedavg/MyModelTrainer.py +++ b/python/fedml/simulation/mpi/async_fedavg/MyModelTrainer.py @@ -7,13 +7,50 @@ class MyModelTrainer(ClientTrainer): + """ + Custom client model trainer for federated learning. + + Args: + None + + Methods: + get_model_params(): Get the model parameters. + set_model_params(model_parameters): Set the model parameters. + train(train_data, device, args): Train the model on the client. + test(test_data, device, args): Test the model on the client. + test_on_the_server(train_data_local_dict, test_data_local_dict, device, args=None): + Test the model on the server (not implemented in this class). + """ def get_model_params(self): + """ + Get the model parameters. + + Returns: + dict: The model parameters as a dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (dict): A dictionary containing the model parameters to set. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model on the client. + + Args: + train_data (torch.utils.data.DataLoader): DataLoader containing training data. + device (torch.device): The device (CPU or GPU) to train on. + args: Additional training arguments. + + Returns: + None + """ model = self.model model.to(device) @@ -50,6 +87,17 @@ def train(self, train_data, device, args): ) def test(self, test_data, device, args): + """ + Test the model on the client. + + Args: + test_data (torch.utils.data.DataLoader): DataLoader containing test data. + device (torch.device): The device (CPU or GPU) to test on. + args: Additional testing arguments. + + Returns: + dict: A dictionary containing test metrics. + """ model = self.model model.eval() @@ -91,4 +139,16 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Test the model on the server (not implemented in this class). + + Args: + train_data_local_dict (dict): Dictionary containing local training data. + test_data_local_dict (dict): Dictionary containing local test data. + device (torch.device): The device (CPU or GPU) to test on. + args: Additional testing arguments. + + Returns: + bool: Always returns False in this implementation. + """ return False diff --git a/python/fedml/simulation/mpi/async_fedavg/my_model_trainer.py b/python/fedml/simulation/mpi/async_fedavg/my_model_trainer.py index 7ed01c3703..5941065ed7 100644 --- a/python/fedml/simulation/mpi/async_fedavg/my_model_trainer.py +++ b/python/fedml/simulation/mpi/async_fedavg/my_model_trainer.py @@ -6,13 +6,50 @@ class MyModelTrainer(ClientTrainer): + """ + Custom client model trainer for federated learning. + + Args: + None + + Methods: + get_model_params(): Get the model parameters. + set_model_params(model_parameters): Set the model parameters. + train(train_data, device, args): Train the model on the client. + test(test_data, device, args): Test the model on the client. 
+ test_on_the_server(train_data_local_dict, test_data_local_dict, device, args=None): + Test the model on the server (not implemented in this class). + """ def get_model_params(self): + """ + Get the model parameters. + + Returns: + dict: The model parameters as a dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (dict): A dictionary containing the model parameters to set. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model on the client. + + Args: + train_data (torch.utils.data.DataLoader): DataLoader containing training data. + device (torch.device): The device (CPU or GPU) to train on. + args: Additional training arguments. + + Returns: + None + """ model = self.model model.to(device) @@ -61,6 +98,17 @@ def train(self, train_data, device, args): # self.client_idx, epoch, sum(epoch_loss) / len(epoch_loss))) def test(self, test_data, device, args): + """ + Test the model on the client. + + Args: + test_data (torch.utils.data.DataLoader): DataLoader containing test data. + device (torch.device): The device (CPU or GPU) to test on. + args: Additional testing arguments. + + Returns: + dict: A dictionary containing test metrics. + """ model = self.model model.to(device) @@ -115,4 +163,16 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Test the model on the server (not implemented in this class). + + Args: + train_data_local_dict (dict): Dictionary containing local training data. + test_data_local_dict (dict): Dictionary containing local test data. + device (torch.device): The device (CPU or GPU) to test on. + args: Additional testing arguments. + + Returns: + bool: Always returns False in this implementation. + """ return False diff --git a/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_classification.py b/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_classification.py index 8aa72effea..ec9aeca3c8 100644 --- a/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_classification.py +++ b/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_classification.py @@ -6,13 +6,50 @@ class MyModelTrainer(ClientTrainer): + """ + Custom client model trainer for federated learning. + + Args: + None + + Methods: + get_model_params(): Get the model parameters. + set_model_params(model_parameters): Set the model parameters. + train(train_data, device, args): Train the model on the client. + test(test_data, device, args): Test the model on the client. + test_on_the_server(train_data_local_dict, test_data_local_dict, device, args=None): + Test the model on the server (not implemented in this class). + """ def get_model_params(self): + """ + Get the model parameters. + + Returns: + dict: The model parameters as a dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (dict): A dictionary containing the model parameters to set. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model on the client. + + Args: + train_data (torch.utils.data.DataLoader): DataLoader containing training data. + device (torch.device): The device (CPU or GPU) to train on. + args: Additional training arguments. 
+ + Returns: + None + """ model = self.model model.to(device) @@ -65,6 +102,17 @@ def train(self, train_data, device, args): ) def test(self, test_data, device, args): + """ + Test the model on the client. + + Args: + test_data (torch.utils.data.DataLoader): DataLoader containing test data. + device (torch.device): The device (CPU or GPU) to test on. + args: Additional testing arguments. + + Returns: + dict: A dictionary containing test metrics. + """ model = self.model model.to(device) @@ -92,4 +140,16 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Test the model on the server (not implemented in this class). + + Args: + train_data_local_dict (dict): Dictionary containing local training data. + test_data_local_dict (dict): Dictionary containing local test data. + device (torch.device): The device (CPU or GPU) to test on. + args: Additional testing arguments. + + Returns: + bool: Always returns False in this implementation. + """ return False diff --git a/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_nwp.py b/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_nwp.py index 08d9b13f65..faf2f1690f 100644 --- a/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_nwp.py +++ b/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_nwp.py @@ -5,13 +5,50 @@ class MyModelTrainer(ClientTrainer): + """ + Custom client model trainer for federated learning. + + Args: + None + + Methods: + get_model_params(): Get the model parameters. + set_model_params(model_parameters): Set the model parameters. + train(train_data, device, args): Train the model on the client. + test(test_data, device, args): Test the model on the client. + test_on_the_server(train_data_local_dict, test_data_local_dict, device, args=None): + Test the model on the server (not implemented in this class). + """ def get_model_params(self): + """ + Get the model parameters. + + Returns: + dict: The model parameters as a dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (dict): A dictionary containing the model parameters to set. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model on the client. + + Args: + train_data (torch.utils.data.DataLoader): DataLoader containing training data. + device (torch.device): The device (CPU or GPU) to train on. + args: Additional training arguments. + + Returns: + None + """ model = self.model model.to(device) @@ -56,6 +93,17 @@ def train(self, train_data, device, args): # self.client_idx, epoch, sum(epoch_loss) / len(epoch_loss))) def test(self, test_data, device, args): + """ + Test the model on the client. + + Args: + test_data (torch.utils.data.DataLoader): DataLoader containing test data. + device (torch.device): The device (CPU or GPU) to test on. + args: Additional testing arguments. + + Returns: + dict: A dictionary containing test metrics. + """ model = self.model model.to(device) @@ -83,4 +131,16 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Test the model on the server (not implemented in this class). + + Args: + train_data_local_dict (dict): Dictionary containing local training data. + test_data_local_dict (dict): Dictionary containing local test data. 
+ device (torch.device): The device (CPU or GPU) to test on. + args: Additional testing arguments. + + Returns: + bool: Always returns False in this implementation. + """ return False diff --git a/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_tag_prediction.py b/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_tag_prediction.py index 50539b3ea5..08fd44d0ce 100644 --- a/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_tag_prediction.py +++ b/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_tag_prediction.py @@ -5,13 +5,50 @@ class MyModelTrainer(ClientTrainer): + """ + Custom client model trainer for federated learning. + + Args: + None + + Methods: + get_model_params(): Get the model parameters. + set_model_params(model_parameters): Set the model parameters. + train(train_data, device, args): Train the model on the client. + test(test_data, device, args): Test the model on the client. + test_on_the_server(train_data_local_dict, test_data_local_dict, device, args=None): + Test the model on the server (not implemented in this class). + """ def get_model_params(self): + """ + Get the model parameters. + + Returns: + dict: The model parameters as a dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (dict): A dictionary containing the model parameters to set. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model on the client. + + Args: + train_data (torch.utils.data.DataLoader): DataLoader containing training data. + device (torch.device): The device (CPU or GPU) to train on. + args: Additional training arguments. + + Returns: + None + """ model = self.model model.to(device) @@ -56,6 +93,17 @@ def train(self, train_data, device, args): # self.client_idx, epoch, sum(epoch_loss) / len(epoch_loss))) def test(self, test_data, device, args): + """ + Test the model on the client. + + Args: + test_data (torch.utils.data.DataLoader): DataLoader containing test data. + device (torch.device): The device (CPU or GPU) to test on. + args: Additional testing arguments. + + Returns: + dict: A dictionary containing test metrics. + """ model = self.model model.to(device) @@ -100,4 +148,16 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Test the model on the server (not implemented in this class). + + Args: + train_data_local_dict (dict): Dictionary containing local training data. + test_data_local_dict (dict): Dictionary containing local test data. + device (torch.device): The device (CPU or GPU) to test on. + args: Additional testing arguments. + + Returns: + bool: Always returns False in this implementation. + """ return False diff --git a/python/fedml/simulation/mpi/async_fedavg/utils.py b/python/fedml/simulation/mpi/async_fedavg/utils.py index aea2449590..63aa625d5f 100644 --- a/python/fedml/simulation/mpi/async_fedavg/utils.py +++ b/python/fedml/simulation/mpi/async_fedavg/utils.py @@ -5,6 +5,16 @@ def transform_list_to_tensor(model_params_list): + """ + Transform a dictionary of model parameters from NumPy arrays to PyTorch tensors. + + Args: + model_params_list (dict): A dictionary of model parameters with keys as parameter names + and values as NumPy arrays. + + Returns: + dict: A dictionary of model parameters with the same keys, but values as PyTorch tensors. 
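+
+    Example (illustrative; a single fake parameter named "w"):
+        >>> params = {"w": [[0.0, 1.0], [2.0, 3.0]]}
+        >>> transform_list_to_tensor(params)["w"].shape
+        torch.Size([2, 2])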
+ """ for k in model_params_list.keys(): model_params_list[k] = torch.from_numpy( np.asarray(model_params_list[k]) @@ -13,12 +23,31 @@ def transform_list_to_tensor(model_params_list): def transform_tensor_to_list(model_params): + """ + Transform a dictionary of model parameters from PyTorch tensors to NumPy arrays. + + Args: + model_params (dict): A dictionary of model parameters with keys as parameter names + and values as PyTorch tensors. + + Returns: + dict: A dictionary of model parameters with the same keys, but values as NumPy arrays. + """ for k in model_params.keys(): model_params[k] = model_params[k].detach().numpy().tolist() return model_params def post_complete_message_to_sweep_process(args): + """ + Post a completion message to a sweep process using a named pipe. + + Args: + args: Additional arguments or information to include in the completion message. + + Returns: + None + """ pipe_path = "./tmp/fedml" os.system("mkdir ./tmp/; touch ./tmp/fedml") if not os.path.exists(pipe_path): From 93231e356ee3de2aa30798ec34e6a2834e4ea9c9 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Fri, 8 Sep 2023 12:16:34 +0530 Subject: [PATCH 12/70] python\fedml\simulation\mpi base framework classical_vertical_fl decentralized framework fedavg --- .../mpi/base_framework/algorithm_api.py | 44 +++++++ .../mpi/base_framework/central_manager.py | 60 ++++++++- .../mpi/base_framework/central_worker.py | 51 ++++++++ .../mpi/base_framework/client_manager.py | 93 ++++++++++++++ .../mpi/base_framework/client_worker.py | 44 ++++++- .../mpi/base_framework/message_define.py | 2 +- .../classical_vertical_fl/guest_manager.py | 80 ++++++++++++ .../classical_vertical_fl/guest_trainer.py | 120 ++++++++++++++++++ .../mpi/classical_vertical_fl/host_manager.py | 34 ++++- .../mpi/classical_vertical_fl/host_trainer.py | 54 +++++++- .../decentralized_framework/algorithm_api.py | 18 ++- .../decentralized_worker.py | 29 +++++ .../decentralized_worker_manager.py | 57 ++++++++- .../simulation/mpi/fedavg/FedAVGAggregator.py | 85 ++++++++++++- .../simulation/mpi/fedavg/FedAVGTrainer.py | 66 ++++++++++ .../fedml/simulation/mpi/fedavg/FedAvgAPI.py | 49 +++++++ .../mpi/fedavg/FedAvgClientManager.py | 49 ++++++- .../mpi/fedavg/FedAvgServerManager.py | 87 ++++++++++++- python/fedml/simulation/mpi/fedavg/utils.py | 30 ++++- 19 files changed, 1023 insertions(+), 29 deletions(-) diff --git a/python/fedml/simulation/mpi/base_framework/algorithm_api.py b/python/fedml/simulation/mpi/base_framework/algorithm_api.py index 5df90e0187..65eb2023dc 100644 --- a/python/fedml/simulation/mpi/base_framework/algorithm_api.py +++ b/python/fedml/simulation/mpi/base_framework/algorithm_api.py @@ -7,6 +7,14 @@ def FedML_init(): + """ + Initialize the MPI communication and retrieve process information. + + Returns: + comm (object): MPI communication object. + process_id (int): Unique ID of the current process. + worker_number (int): Total number of worker processes. + """ comm = MPI.COMM_WORLD process_id = comm.Get_rank() worker_number = comm.Get_size() @@ -14,6 +22,18 @@ def FedML_init(): def FedML_Base_distributed(args, process_id, worker_number, comm): + """ + Run the base distributed federated learning process. + + Args: + args (object): An object containing the configuration parameters. + process_id (int): Unique ID of the current process. + worker_number (int): Total number of worker processes. + comm (object): MPI communication object. 
+ + Returns: + None + """ if process_id == 0: init_central_worker(args, comm, process_id, worker_number) else: @@ -21,6 +41,18 @@ def FedML_Base_distributed(args, process_id, worker_number, comm): def init_central_worker(args, comm, process_id, size): + """ + Initialize the central worker for distributed federated learning. + + Args: + args (object): An object containing the configuration parameters. + comm (object): MPI communication object. + process_id (int): Unique ID of the current process. + size (int): Total number of processes. + + Returns: + None + """ # aggregator client_num = size - 1 aggregator = BaseCentralWorker(client_num, args) @@ -31,6 +63,18 @@ def init_central_worker(args, comm, process_id, size): def init_client_worker(args, comm, process_id, size): + """ + Initialize a client worker for distributed federated learning. + + Args: + args (object): An object containing the configuration parameters. + comm (object): MPI communication object. + process_id (int): Unique ID of the current process. + size (int): Total number of processes. + + Returns: + None + """ # trainer client_ID = process_id - 1 trainer = BaseClientWorker(client_ID) diff --git a/python/fedml/simulation/mpi/base_framework/central_manager.py b/python/fedml/simulation/mpi/base_framework/central_manager.py index dc192a187c..4ccaaaafe1 100644 --- a/python/fedml/simulation/mpi/base_framework/central_manager.py +++ b/python/fedml/simulation/mpi/base_framework/central_manager.py @@ -7,6 +7,19 @@ class BaseCentralManager(FedMLCommManager): def __init__(self, args, comm, rank, size, aggregator): + """ + Initialize the BaseCentralManager. + + Args: + args (object): An object containing configuration parameters. + comm (object): MPI communication object. + rank (int): The rank of the current process. + size (int): The total number of processes. + aggregator (object): The aggregator for aggregating results. + + Returns: + None + """ super().__init__(args, comm, rank, size) self.aggregator = aggregator @@ -14,17 +27,40 @@ def __init__(self, args, comm, rank, size, aggregator): self.args.round_idx = 0 def run(self): + """ + Run the central manager. + + This method initiates the communication with client processes and aggregates their results. + + Returns: + None + """ for process_id in range(1, self.size): self.send_message_init_config(process_id) super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for the central manager. + + Returns: + None + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_INFORMATION, self.handle_message_receive_model_from_client, ) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle messages received from client processes. + + Args: + msg_params (dict): A dictionary containing message parameters. 
+ + Returns: + None + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) client_local_result = msg_params.get(MyMessage.MSG_ARG_KEY_INFORMATION) @@ -34,11 +70,12 @@ def handle_message_receive_model_from_client(self, msg_params): logging.info("b_all_received = " + str(b_all_received)) if b_all_received: logging.info( - "**********************************ROUND INDEX = " + str(self.args.round_idx) + "**********************************ROUND INDEX = " + + str(self.args.round_idx) ) global_result = self.aggregator.aggregate() - # start the next round + # Start the next round self.args.round_idx += 1 if self.args.round_idx == self.round_num: self.finish() @@ -48,12 +85,31 @@ def handle_message_receive_model_from_client(self, msg_params): self.send_message_to_client(receiver_id, global_result) def send_message_init_config(self, receive_id): + """ + Send initialization configuration message to a client process. + + Args: + receive_id (int): The ID of the receiving client process. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) self.send_message(message) def send_message_to_client(self, receive_id, global_result): + """ + Send a message to a client process containing global results. + + Args: + receive_id (int): The ID of the receiving client process. + global_result (object): The global result to be sent. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_S2C_INFORMATION, self.get_sender_id(), receive_id ) diff --git a/python/fedml/simulation/mpi/base_framework/central_worker.py b/python/fedml/simulation/mpi/base_framework/central_worker.py index c9a862251b..982d90932a 100644 --- a/python/fedml/simulation/mpi/base_framework/central_worker.py +++ b/python/fedml/simulation/mpi/base_framework/central_worker.py @@ -2,7 +2,36 @@ class BaseCentralWorker(object): + """ + Base class representing a central worker in a distributed system. + + This class is responsible for managing client local results and aggregating them. + + Attributes: + client_num (int): The number of client processes. + args (object): An object containing configuration parameters. + client_local_result_list (dict): A dictionary to store client local results. + flag_client_model_uploaded_dict (dict): A dictionary to track whether each client has uploaded results. + + Methods: + add_client_local_result(index, client_local_result): + Add client's local result to the worker. + check_whether_all_receive(): + Check if all clients have uploaded their local results. + aggregate(): + Aggregate client local results. + """ def __init__(self, client_num, args): + """ + Initialize the BaseCentralWorker. + + Args: + client_num (int): The number of client processes. + args (object): An object containing configuration parameters. + + Returns: + None + """ self.client_num = client_num self.args = args @@ -13,11 +42,27 @@ def __init__(self, client_num, args): self.flag_client_model_uploaded_dict[idx] = False def add_client_local_result(self, index, client_local_result): + """ + Add client's local result to the worker. + + Args: + index (int): The index of the client. + client_local_result (object): The local result from the client. + + Returns: + None + """ logging.info("add_client_local_result. index = %d" % index) self.client_local_result_list[index] = client_local_result self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all clients have uploaded their local results. 
+ + Returns: + bool: True if all clients have uploaded their results, False otherwise. + """ for idx in range(self.client_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -26,6 +71,12 @@ def check_whether_all_receive(self): return True def aggregate(self): + """ + Aggregate client local results. + + Returns: + object: The aggregated global result. + """ global_result = 0 for k in self.client_local_result_list.keys(): global_result += self.client_local_result_list[k] diff --git a/python/fedml/simulation/mpi/base_framework/client_manager.py b/python/fedml/simulation/mpi/base_framework/client_manager.py index 00e5fc12ca..2ff3436cdb 100644 --- a/python/fedml/simulation/mpi/base_framework/client_manager.py +++ b/python/fedml/simulation/mpi/base_framework/client_manager.py @@ -4,16 +4,72 @@ class BaseClientManager(FedMLCommManager): + """ + Base class representing a client manager in a distributed system. + + This class handles the communication between clients and the central server. + + Attributes: + args (object): An object containing configuration parameters. + comm (object): A communication object for MPI communication. + rank (int): The rank of the current process. + size (int): The total number of processes. + trainer (object): An object responsible for client-side training. + num_rounds (int): The total number of communication rounds. + + Methods: + run(): + Start the client manager. + handle_message_init(msg_params): + Handle initialization message from the server. + handle_message_receive_model_from_server(msg_params): + Handle receiving model update from the server. + send_model_to_server(receive_id, client_gradient): + Send client-side model updates to the server. + __train(): + Perform training and send updates to the server. + """ def __init__(self, args, comm, rank, size, trainer): + """ + Initialize the BaseClientManager. + + Args: + args (object): An object containing configuration parameters. + comm (object): A communication object for MPI communication. + rank (int): The rank of the current process. + size (int): The total number of processes. + trainer (object): An object responsible for client-side training. + + Returns: + None + """ super().__init__(args, comm, rank, size) self.trainer = trainer self.num_rounds = args.comm_round self.args.round_idx = 0 def run(self): + """ + Start the client manager. + + Args: + None + + Returns: + None + """ super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for different message types. + + Args: + None + + Returns: + None + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -23,11 +79,29 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): + """ + Handle initialization message from the server. + + Args: + msg_params (dict): Parameters included in the message. + + Returns: + None + """ self.trainer.update(0) self.args.round_idx = 0 self.__train() def handle_message_receive_model_from_server(self, msg_params): + """ + Handle receiving model update from the server. + + Args: + msg_params (dict): Parameters included in the message. 
+ + Returns: + None + """ global_result = msg_params.get(MyMessage.MSG_ARG_KEY_INFORMATION) self.trainer.update(global_result) self.args.round_idx += 1 @@ -36,6 +110,16 @@ def handle_message_receive_model_from_server(self, msg_params): self.finish() def send_model_to_server(self, receive_id, client_gradient): + """ + Send client-side model updates to the server. + + Args: + receive_id (int): The ID of the recipient (server). + client_gradient (object): The client-side model update. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_C2S_INFORMATION, self.get_sender_id(), receive_id ) @@ -43,6 +127,15 @@ def send_model_to_server(self, receive_id, client_gradient): self.send_message(message) def __train(self): + """ + Perform training and send updates to the server. + + Args: + None + + Returns: + None + """ # do something here (e.g., training) training_interation_result = self.trainer.train() diff --git a/python/fedml/simulation/mpi/base_framework/client_worker.py b/python/fedml/simulation/mpi/base_framework/client_worker.py index bf607d8beb..f8673b7178 100644 --- a/python/fedml/simulation/mpi/base_framework/client_worker.py +++ b/python/fedml/simulation/mpi/base_framework/client_worker.py @@ -1,12 +1,54 @@ class BaseClientWorker(object): + """ + Base class representing a client worker in a distributed system. + + This class is responsible for client-side operations, such as training and updating information. + + Attributes: + client_index (int): The index of the client worker. + updated_information (int): Information that can be updated during training. + + Methods: + update(updated_information): + Update the information associated with the client. + train(): + Perform client-specific training or operation. + + """ + def __init__(self, client_index): + """ + Initialize the BaseClientWorker. + + Args: + client_index (int): The index of the client worker. + + Returns: + None + """ self.client_index = client_index self.updated_information = 0 def update(self, updated_information): + """ + Update the information associated with the client. + + Args: + updated_information (int): The new information to be associated with the client. + + Returns: + None + """ self.updated_information = updated_information print(self.updated_information) def train(self): - # complete your own algorithm operation here, as am example, we return the client_index + """ + Perform client-specific training or operation. + + Returns: + int: An example result (client_index in this case). + """ + # Complete your own algorithm operation here. + # As an example, we return the client_index. 
return self.client_index diff --git a/python/fedml/simulation/mpi/base_framework/message_define.py b/python/fedml/simulation/mpi/base_framework/message_define.py index 27ba9f14d4..b8b52f79d5 100644 --- a/python/fedml/simulation/mpi/base_framework/message_define.py +++ b/python/fedml/simulation/mpi/base_framework/message_define.py @@ -15,6 +15,6 @@ class MyMessage(object): MSG_ARG_KEY_RECEIVER = "receiver" """ - message payload keywords definition + message payload keywords definition """ MSG_ARG_KEY_INFORMATION = "information" diff --git a/python/fedml/simulation/mpi/classical_vertical_fl/guest_manager.py b/python/fedml/simulation/mpi/classical_vertical_fl/guest_manager.py index fbcfafef43..b3f1c03469 100644 --- a/python/fedml/simulation/mpi/classical_vertical_fl/guest_manager.py +++ b/python/fedml/simulation/mpi/classical_vertical_fl/guest_manager.py @@ -4,7 +4,47 @@ class GuestManager(FedMLCommManager): + """ + Class representing the manager for a guest in a distributed system. + + This class is responsible for handling communication between the guest and other participants, + as well as coordinating training rounds. + + Attributes: + args: Arguments for the manager. + comm: The communication interface. + rank: The rank of the guest in the communication group. + size: The total number of participants in the communication group. + guest_trainer: The trainer responsible for guest-specific training. + + Methods: + run(): + Start the guest manager and run communication. + register_message_receive_handlers(): + Register message receive handlers for handling incoming messages. + handle_message_receive_logits_from_client(msg_params): + Handle the reception of logits and trigger training when all data is received. + send_message_init_config(receive_id): + Send an initialization message to a client. + send_message_to_client(receive_id, global_result): + Send a message containing global training results to a client. + + """ + def __init__(self, args, comm, rank, size, guest_trainer): + """ + Initialize the GuestManager. + + Args: + args: Arguments for the manager. + comm: The communication interface. + rank: The rank of the guest in the communication group. + size: The total number of participants in the communication group. + guest_trainer: The trainer responsible for guest-specific training. + + Returns: + None + """ super().__init__(args, comm, rank, size) self.guest_trainer = guest_trainer @@ -12,17 +52,38 @@ def __init__(self, args, comm, rank, size, guest_trainer): self.args.round_idx = 0 def run(self): + """ + Start the guest manager and run communication. + + Returns: + None + """ for process_id in range(1, self.size): self.send_message_init_config(process_id) super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for handling incoming messages. + + Returns: + None + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_LOGITS, self.handle_message_receive_logits_from_client, ) def handle_message_receive_logits_from_client(self, msg_params): + """ + Handle the reception of logits and trigger training when all data is received. + + Args: + msg_params: Parameters of the received message. 
+ + Returns: + None + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) host_train_logits = msg_params.get(MyMessage.MSG_ARG_KEY_TRAIN_LOGITS) host_test_logits = msg_params.get(MyMessage.MSG_ARG_KEY_TEST_LOGITS) @@ -44,12 +105,31 @@ def handle_message_receive_logits_from_client(self, msg_params): self.finish() def send_message_init_config(self, receive_id): + """ + Send an initialization message to a client. + + Args: + receive_id: The ID of the client to receive the message. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) self.send_message(message) def send_message_to_client(self, receive_id, global_result): + """ + Send a message containing global training results to a client. + + Args: + receive_id: The ID of the client to receive the message. + global_result: The global training result to send. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_S2C_GRADIENT, self.get_sender_id(), receive_id ) diff --git a/python/fedml/simulation/mpi/classical_vertical_fl/guest_trainer.py b/python/fedml/simulation/mpi/classical_vertical_fl/guest_trainer.py index 9be9ee9bc2..a8610b46ce 100644 --- a/python/fedml/simulation/mpi/classical_vertical_fl/guest_trainer.py +++ b/python/fedml/simulation/mpi/classical_vertical_fl/guest_trainer.py @@ -8,6 +8,43 @@ class GuestTrainer(object): + """ + Class representing the trainer for a guest in a distributed system. + + This class handles training and gradient aggregation for the guest. + + Attributes: + client_num: The number of clients in the system. + device: The device (e.g., CPU or GPU) used for training. + X_train: The training data features. + y_train: The training data labels. + X_test: The test data features. + y_test: The test data labels. + model_feature_extractor: The feature extractor model. + model_classifier: The classifier model. + args: Arguments for the trainer. + + Methods: + get_batch_num(): + Get the number of batches for training. + add_client_local_result(index, host_train_logits, host_test_logits): + Add client local results to the trainer. + check_whether_all_receive(): + Check if all client local results have been received. + train(round_idx): + Perform training for a round and return gradients to hosts. + _bp_classifier(x, grads): + Backpropagate gradients through the classifier. + _bp_feature_extractor(x, grads): + Backpropagate gradients through the feature extractor. + _test(round_idx): + Perform testing and calculate evaluation metrics. + _sigmoid(x): + Compute the sigmoid function. + _compute_correct_prediction(y_targets, y_prob_preds, threshold): + Compute correct predictions and evaluation statistics. + + """ def __init__( self, client_num, @@ -72,15 +109,38 @@ def __init__( self.loss_list = list() def get_batch_num(self): + """ + Get the number of batches for training. + + Returns: + int: The number of batches. + """ return self.n_batches def add_client_local_result(self, index, host_train_logits, host_test_logits): + """ + Add client local results to the trainer. + + Args: + index: The index of the client. + host_train_logits: Logits from the client's local training data. + host_test_logits: Logits from the client's local test data. + + Returns: + None + """ # logging.info("add_client_local_result. 
index = %d" % index) self.host_local_train_logits_list[index] = host_train_logits self.host_local_test_logits_list[index] = host_test_logits self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all client local results have been received. + + Returns: + bool: True if all results have been received, False otherwise. + """ for idx in range(self.client_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -89,6 +149,15 @@ def check_whether_all_receive(self): return True def train(self, round_idx): + """ + Perform training for a round and return gradients to hosts. + + Args: + round_idx: The index of the training round. + + Returns: + ndarray: Gradients to hosts. + """ batch_x = self.X_train[ self.batch_idx * self.batch_size : self.batch_idx * self.batch_size + self.batch_size @@ -137,6 +206,17 @@ def train(self, round_idx): return gradients_to_hosts def _bp_classifier(self, x, grads): + """ + Backpropagate gradients through the classifier. + + Args: + x: Input data. + grads: Gradients to be backpropagated. + + Returns: + ndarray: Gradients of the input data. + """ + x = x.clone().detach().requires_grad_(True) output = self.model_classifier(x) output.backward(gradient=grads) @@ -146,12 +226,31 @@ def _bp_classifier(self, x, grads): return x_grad def _bp_feature_extractor(self, x, grads): + """ + Backpropagate gradients through the feature extractor. + + Args: + x: Input data. + grads: Gradients to be backpropagated. + + Returns: + None + """ output = self.model_feature_extractor(x) output.backward(gradient=grads) self.optimizer_fe.step() self.optimizer_fe.zero_grad() def _test(self, round_idx): + """ + Perform testing and calculate evaluation metrics. + + Args: + round_idx: The index of the training round. + + Returns: + None + """ X_test = torch.tensor(self.X_test).float().to(self.device) y_test = self.y_test @@ -183,9 +282,30 @@ def _test(self, round_idx): ) def _sigmoid(self, x): + """ + Compute the sigmoid function. + + Args: + x: Input data. + + Returns: + ndarray: Sigmoid values. + """ return 1.0 / (1.0 + np.exp(-x)) def _compute_correct_prediction(self, y_targets, y_prob_preds, threshold=0.5): + """ + Compute correct predictions and evaluation statistics. + + Args: + y_targets: True labels. + y_prob_preds: Predicted probabilities. + threshold: Threshold for classification. + + Returns: + ndarray: Predicted labels. + list: Statistics (positive predictions, negative predictions, correct predictions). + """ y_hat_lbls = [] pred_pos_count = 0 pred_neg_count = 0 diff --git a/python/fedml/simulation/mpi/classical_vertical_fl/host_manager.py b/python/fedml/simulation/mpi/classical_vertical_fl/host_manager.py index 82d31aa70d..cb37de2a54 100644 --- a/python/fedml/simulation/mpi/classical_vertical_fl/host_manager.py +++ b/python/fedml/simulation/mpi/classical_vertical_fl/host_manager.py @@ -2,18 +2,29 @@ from ....core.distributed.fedml_comm_manager import FedMLCommManager from ....core.distributed.communication.message import Message - class HostManager(FedMLCommManager): def __init__(self, args, comm, rank, size, trainer): + """ + Initialize a HostManager instance. + + Args: + args: Configuration arguments. + comm: MPI communication object. + rank: Rank of the process. + size: Number of processes in the communicator. + trainer: Trainer for host-specific tasks. 
+ """ super().__init__(args, comm, rank, size) self.trainer = trainer self.num_rounds = args.comm_round self.round_idx = 0 def run(self): + """Start the HostManager.""" super().run() def register_message_receive_handlers(self): + """Register message receive handlers.""" self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -23,10 +34,22 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): + """ + Handle the initialization message. + + Args: + msg_params: Parameters from the initialization message. + """ self.round_idx = 0 self.__train() def handle_message_receive_gradient_from_server(self, msg_params): + """ + Handle the gradient message received from the server. + + Args: + msg_params: Parameters from the gradient message. + """ gradient = msg_params.get(MyMessage.MSG_ARG_KEY_GRADIENT) self.trainer.update_model(gradient) self.round_idx += 1 @@ -35,6 +58,14 @@ def handle_message_receive_gradient_from_server(self, msg_params): self.finish() def send_model_to_server(self, receive_id, host_train_logits, host_test_logits): + """ + Send host training and test logits to the server. + + Args: + receive_id: ID of the receiver. + host_train_logits: Host's training logits. + host_test_logits: Host's test logits. + """ message = Message( MyMessage.MSG_TYPE_C2S_LOGITS, self.get_sender_id(), receive_id ) @@ -43,6 +74,7 @@ def send_model_to_server(self, receive_id, host_train_logits, host_test_logits): self.send_message(message) def __train(self): + """Perform host training and send logits to the server.""" host_train_logits, host_test_logits = self.trainer.computer_logits( self.round_idx ) diff --git a/python/fedml/simulation/mpi/classical_vertical_fl/host_trainer.py b/python/fedml/simulation/mpi/classical_vertical_fl/host_trainer.py index 06af3bcd27..6d18e386e1 100644 --- a/python/fedml/simulation/mpi/classical_vertical_fl/host_trainer.py +++ b/python/fedml/simulation/mpi/classical_vertical_fl/host_trainer.py @@ -3,6 +3,20 @@ class HostTrainer(object): + """ + Trainer for host-specific tasks in a federated learning environment. + + This class manages the training and gradient update process for host-specific tasks in a federated learning system. + + Args: + client_index: Index of the host client. + device: Computing device (e.g., CPU or GPU) to perform training. + X_train: Training data for the host. + X_test: Test data for the host. + model_feature_extractor: Feature extractor model. + model_classifier: Classifier model. + args: Configuration arguments. + """ def __init__( self, client_index, @@ -13,6 +27,9 @@ def __init__( model_classifier, args, ): + """ + Initialize a HostTrainer instance. + """ # device information self.client_index = client_index self.device = device @@ -55,9 +72,19 @@ def __init__( self.cached_extracted_features = None def get_batch_num(self): + """Get the number of training batches.""" return self.n_batches def computer_logits(self, round_idx): + """ + Compute logits for host-specific tasks. + + Args: + round_idx: Current round index. + + Returns: + tuple: A tuple containing host training logits and host test logits. 
+ """ batch_x = self.X_train[ self.batch_idx * self.batch_size : self.batch_idx * self.batch_size + self.batch_size @@ -65,13 +92,12 @@ def computer_logits(self, round_idx): self.batch_x = torch.tensor(batch_x).float().to(self.device) self.extracted_feature = self.model_feature_extractor.forward(self.batch_x) logits = self.model_classifier.forward(self.extracted_feature) - # copy to CPU host memory logits_train = logits.cpu().detach().numpy() self.batch_idx += 1 if self.batch_idx == self.n_batches: self.batch_idx = 0 - # for test + # For test if (round_idx + 1) % self.args.frequency_of_the_test == 0: X_test = torch.tensor(self.X_test).float().to(self.device) extracted_feature = self.model_feature_extractor.forward(X_test) @@ -83,12 +109,27 @@ def computer_logits(self, round_idx): return logits_train, logits_test def update_model(self, gradient): - # logging.info("#######################gradient = " + str(gradient)) + """ + Update the model using the received gradient. + + Args: + gradient: Gradient received from the server. + """ gradient = torch.tensor(gradient).float().to(self.device) back_grad = self._bp_classifier(self.extracted_feature, gradient) self._bp_feature_extractor(self.batch_x, back_grad) def _bp_classifier(self, x, grads): + """ + Backpropagate gradients through the classifier model. + + Args: + x: Input data. + grads: Gradients to backpropagate. + + Returns: + x_grad: Gradients of the input data. + """ x = x.clone().detach().requires_grad_(True) output = self.model_classifier(x) output.backward(gradient=grads) @@ -98,6 +139,13 @@ def _bp_classifier(self, x, grads): return x_grad def _bp_feature_extractor(self, x, grads): + """ + Backpropagate gradients through the feature extractor model. + + Args: + x: Input data. + grads: Gradients to backpropagate. + """ output = self.model_feature_extractor(x) output.backward(gradient=grads) self.optimizer_fe.step() diff --git a/python/fedml/simulation/mpi/decentralized_framework/algorithm_api.py b/python/fedml/simulation/mpi/decentralized_framework/algorithm_api.py index c72161fad4..c4b4c7bd62 100644 --- a/python/fedml/simulation/mpi/decentralized_framework/algorithm_api.py +++ b/python/fedml/simulation/mpi/decentralized_framework/algorithm_api.py @@ -4,16 +4,28 @@ from .decentralized_worker_manager import DecentralizedWorkerManager from ....core.distributed.topology.symmetric_topology_manager import SymmetricTopologyManager - def FedML_Decentralized_Demo_distributed(args, process_id, worker_number, comm): - # initialize the topology (ring) + """ + Run the decentralized federated learning demo on a distributed system. + + This function initializes the topology (ring) for decentralized federated learning, + initializes the decentralized worker (trainer), and runs the decentralized worker manager. + + Args: + args: Configuration arguments. + process_id: The unique ID of the current process. + worker_number: The total number of workers in the distributed system. + comm: MPI communication object for distributed communication. 
+ """ + # Initialize the topology (ring) tpmgr = SymmetricTopologyManager(worker_number, 2) tpmgr.generate_topology() logging.info(tpmgr.topology) - # initialize the decentralized trainer (worker) + # Initialize the decentralized trainer (worker) worker_index = process_id trainer = DecentralizedWorker(worker_index, tpmgr) + # Initialize the decentralized worker manager client_manager = DecentralizedWorkerManager(args, comm, process_id, worker_number, trainer, tpmgr) client_manager.run() diff --git a/python/fedml/simulation/mpi/decentralized_framework/decentralized_worker.py b/python/fedml/simulation/mpi/decentralized_framework/decentralized_worker.py index 03d64e0525..e09994ee78 100644 --- a/python/fedml/simulation/mpi/decentralized_framework/decentralized_worker.py +++ b/python/fedml/simulation/mpi/decentralized_framework/decentralized_worker.py @@ -1,5 +1,15 @@ class DecentralizedWorker(object): + """ + Represents a decentralized federated learning worker. + """ def __init__(self, worker_index, topology_manager): + """ + Represents a decentralized federated learning worker. + + Args: + worker_index: The index or ID of the worker. + topology_manager: The topology manager for communication with neighboring workers. + """ self.worker_index = worker_index self.in_neighbor_idx_list = topology_manager.get_in_neighbor_idx_list( self.worker_index @@ -11,10 +21,23 @@ def __init__(self, worker_index, topology_manager): self.flag_neighbor_result_received_dict[neighbor_idx] = False def add_result(self, worker_index, updated_information): + """ + Add the result received from a neighboring worker. + + Args: + worker_index: The index or ID of the neighboring worker. + updated_information: The updated information received from the neighboring worker. + """ self.worker_result_dict[worker_index] = updated_information self.flag_neighbor_result_received_dict[worker_index] = True def check_whether_all_receive(self): + """ + Check if results have been received from all neighboring workers. + + Returns: + bool: True if results have been received from all neighbors, False otherwise. + """ for neighbor_idx in self.in_neighbor_idx_list: if not self.flag_neighbor_result_received_dict[neighbor_idx]: return False @@ -23,5 +46,11 @@ def check_whether_all_receive(self): return True def train(self): + """ + Perform the training process for the decentralized worker. + + Returns: + int: A placeholder value (0 in this case) representing the result of the training iteration. + """ self.add_result(self.worker_index, 0) return 0 diff --git a/python/fedml/simulation/mpi/decentralized_framework/decentralized_worker_manager.py b/python/fedml/simulation/mpi/decentralized_framework/decentralized_worker_manager.py index a782a5ea3b..97bac8ca67 100644 --- a/python/fedml/simulation/mpi/decentralized_framework/decentralized_worker_manager.py +++ b/python/fedml/simulation/mpi/decentralized_framework/decentralized_worker_manager.py @@ -4,9 +4,22 @@ from ....core.distributed.communication.message import Message from ....core.distributed.fedml_comm_manager import FedMLCommManager - class DecentralizedWorkerManager(FedMLCommManager): + """ + Class representing a decentralized federated learning worker in a distributed system. + """ def __init__(self, args, comm, rank, size, trainer, topology_manager): + """ + Manages decentralized federated learning workers in a distributed system. + + Args: + args: Configuration arguments. + comm: MPI communication object for distributed communication. + rank: The rank (ID) of the current worker. 
+ size: The total number of workers in the distributed system. + trainer: The decentralized worker/trainer. + topology_manager: The topology manager for communication between workers. + """ super().__init__(args, comm, rank, size) self.worker_index = rank self.trainer = trainer @@ -15,21 +28,36 @@ def __init__(self, args, comm, rank, size, trainer, topology_manager): self.round_idx = 0 def run(self): + """ + Start the training process for decentralized federated learning workers. + """ self.start_training() super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for handling incoming messages. + """ self.register_message_receive_handler(MyMessage.MSG_TYPE_SEND_MSG_TO_NEIGHBOR, self.handle_msg_from_neighbor) def start_training(self): + """ + Initialize and start the training process. + """ self.round_idx = 0 self.__train() def handle_msg_from_neighbor(self, msg_params): + """ + Handle messages received from neighboring workers. + + Args: + msg_params: Parameters included in the received message. + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) - training_interation_result = msg_params.get(MyMessage.MSG_ARG_KEY_PARAMS_1) + training_iteration_result = msg_params.get(MyMessage.MSG_ARG_KEY_PARAMS_1) logging.info("handle_msg_from_neighbor. sender_id = " + str(sender_id)) - self.trainer.add_result(sender_id, training_interation_result) + self.trainer.add_result(sender_id, training_iteration_result) if self.trainer.check_whether_all_receive(): logging.info(">>>>>>>>>>>>>>>WORKER %d, ROUND %d finished!<<<<<<<<" % (self.worker_index, self.round_idx)) self.round_idx += 1 @@ -38,17 +66,34 @@ def handle_msg_from_neighbor(self, msg_params): self.__train() def __train(self): - # do something here (e.g., training) - training_interation_result = self.trainer.train() + """ + Perform the training process and communicate with neighboring workers. + """ + # Perform the training process here (e.g., training iteration) + training_iteration_result = self.trainer.train() + # Send the training iteration result to neighboring workers for neighbor_idx in self.topology_manager.get_out_neighbor_idx_list(self.worker_index): - self.send_result_to_neighbors(neighbor_idx, training_interation_result) + self.send_result_to_neighbors(neighbor_idx, training_iteration_result) def send_message_init_config(self, receive_id): + """ + Send an initialization message to a specified worker. + + Args: + receive_id: The ID of the receiving worker. + """ message = Message(MyMessage.MSG_TYPE_INIT, self.get_sender_id(), receive_id) self.send_message(message) def send_result_to_neighbors(self, receive_id, client_params1): + """ + Send training iteration results to neighboring workers. + + Args: + receive_id: The ID of the receiving worker. + client_params1: Parameters to be sent in the message. + """ logging.info("send_result_to_neighbors. 
receive_id = " + str(receive_id)) message = Message(MyMessage.MSG_TYPE_SEND_MSG_TO_NEIGHBOR, self.get_sender_id(), receive_id) message.add_params(MyMessage.MSG_ARG_KEY_PARAMS_1, client_params1) diff --git a/python/fedml/simulation/mpi/fedavg/FedAVGAggregator.py b/python/fedml/simulation/mpi/fedavg/FedAVGAggregator.py index 0c99809001..893b2c4b03 100644 --- a/python/fedml/simulation/mpi/fedavg/FedAVGAggregator.py +++ b/python/fedml/simulation/mpi/fedavg/FedAVGAggregator.py @@ -13,6 +13,21 @@ from ....core.security.fedml_defender import FedMLDefender class FedAVGAggregator(object): + """ + Represents a Federated Averaging (FedAVG) aggregator for federated learning. + + Args: + train_global: The global training dataset. + test_global: The global testing dataset. + all_train_data_num: The total number of training data samples. + train_data_local_dict: A dictionary mapping worker indices to their local training datasets. + test_data_local_dict: A dictionary mapping worker indices to their local testing datasets. + train_data_local_num_dict: A dictionary mapping worker indices to the number of local training samples. + worker_num: The number of worker nodes participating in the federated learning. + device: The device (e.g., 'cuda' or 'cpu') used for computations. + args: Additional configuration arguments. + server_aggregator: The server-side aggregator used for communication with workers. + """ def __init__( self, train_global, @@ -47,18 +62,44 @@ def __init__( self.flag_client_model_uploaded_dict[idx] = False def get_global_model_params(self): + """ + Get the global model parameters from the aggregator. + + Returns: + dict: The global model parameters. + """ return self.aggregator.get_model_params() def set_global_model_params(self, model_parameters): + """ + Set the global model parameters in the aggregator. + + Args: + model_parameters (dict): The global model parameters to set. + """ self.aggregator.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, sample_num): + """ + Add the locally trained model result from a worker. + + Args: + index: The index or ID of the worker. + model_params (dict): The model parameters trained by the worker. + sample_num (int): The number of training samples used by the worker. + """ logging.info("add_model. index = %d" % index) self.model_dict[index] = model_params self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if model results have been received from all workers. + + Returns: + bool: True if results have been received from all workers, False otherwise. + """ logging.debug("worker_num = {}".format(self.worker_num)) for idx in range(self.worker_num): if not self.flag_client_model_uploaded_dict[idx]: @@ -68,6 +109,12 @@ def check_whether_all_receive(self): return True def aggregate(self): + """ + Aggregate the model updates from worker nodes using Federated Averaging (FedAVG). + + Returns: + dict: The averaged model parameters. + """ start_time = time.time() model_list = [] @@ -97,6 +144,15 @@ def aggregate(self): return averaged_params def _fedavg_aggregation_(self, model_list): + """ + Perform the FedAVG aggregation on a list of local model updates. + + Args: + model_list (list): A list of tuples containing local sample numbers and model parameters. + + Returns: + dict: The aggregated model parameters. 
+ """ training_num = 0 for i in range(0, len(model_list)): local_sample_number, local_model_params = model_list[i] @@ -116,6 +172,17 @@ def _fedavg_aggregation_(self, model_list): return averaged_params def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Randomly sample a subset of clients for a federated learning round. + + Args: + round_idx (int): The index of the current federated learning round. + client_num_in_total (int): The total number of clients available. + client_num_per_round (int): The number of clients to sample for the current round. + + Returns: + list: A list of client indexes selected for the current round. + """ if client_num_in_total == client_num_per_round: client_indexes = [ client_index for client_index in range(client_num_in_total) @@ -131,7 +198,17 @@ def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): logging.info("client_indexes = %s" % str(client_indexes)) return client_indexes + def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set for testing purposes. + + Args: + num_samples (int): The number of samples to include in the validation set. + + Returns: + DataLoader: A DataLoader containing the validation set. + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) sample_indices = random.sample( @@ -146,6 +223,12 @@ def _generate_validation_set(self, num_samples=10000): return self.test_global def test_on_server_for_all_clients(self, round_idx): + """ + Perform testing on the server for all clients in a federated learning round. + + Args: + round_idx (int): The index of the current federated learning round. + """ if self.aggregator.test_all( self.train_data_local_dict, self.test_data_local_dict, @@ -170,4 +253,4 @@ def test_on_server_for_all_clients(self, round_idx): metric_result_in_current_round = self.aggregator.test(self.val_global, self.device, self.args) logging.info("metric_result_in_current_round = {}".format(metric_result_in_current_round)) else: - mlops.log({"round_idx": round_idx}) \ No newline at end of file + mlops.log({"round_idx": round_idx}) diff --git a/python/fedml/simulation/mpi/fedavg/FedAVGTrainer.py b/python/fedml/simulation/mpi/fedavg/FedAVGTrainer.py index 6b1e271d09..d0cfbd3d3a 100644 --- a/python/fedml/simulation/mpi/fedavg/FedAVGTrainer.py +++ b/python/fedml/simulation/mpi/fedavg/FedAVGTrainer.py @@ -2,6 +2,42 @@ class FedAVGTrainer(object): + """ + A class that handles training and testing on a local client in the FedAVG framework. + + This class is responsible for training and testing a local model using client-specific data in a federated learning setting. + + Args: + client_index: The index or ID of the client. + train_data_local_dict: A dictionary containing local training data. + train_data_local_num_dict: A dictionary containing the number of training samples for each client. + test_data_local_dict: A dictionary containing local testing data. + train_data_num: The total number of training samples. + device: The computing device (e.g., "cuda" or "cpu") to perform training and testing. + args: An object containing configuration parameters. + model_trainer: A model trainer object responsible for training and testing. + + Attributes: + trainer: A model trainer object responsible for training and testing. + client_index: The index or ID of the client. + train_data_local_dict: A dictionary containing local training data. 
+ train_data_local_num_dict: A dictionary containing the number of training samples for each client. + test_data_local_dict: A dictionary containing local testing data. + all_train_data_num: The total number of training samples. + train_local: Local training data for the current client. + local_sample_number: The number of training samples for the current client. + test_local: Local testing data for the current client. + device: The computing device (e.g., "cuda" or "cpu") to perform training and testing. + args: An object containing configuration parameters. + + Methods: + update_model(weights): Update the model with new weights. + update_dataset(client_index): Update the local datasets and client index. + train(round_idx=None): Train the local model using the current client's data. + test(): Test the local model on both training and testing data. + + """ + def __init__( self, client_index, @@ -28,9 +64,19 @@ def __init__( self.args = args def update_model(self, weights): + """Update the model with new weights. + + Args: + weights: The new model weights to set. + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """Update the local datasets and client index. + + Args: + client_index: The index or ID of the client. + """ self.client_index = client_index if self.train_data_local_dict is not None: @@ -49,6 +95,15 @@ def update_dataset(self, client_index): self.test_local = None def train(self, round_idx=None): + """Train the local model using the current client's data. + + Args: + round_idx: The current communication round index (optional). + + Returns: + weights: The trained model weights. + local_sample_number: The number of training samples used for training. + """ self.args.round_idx = round_idx self.trainer.train(self.train_local, self.device, self.args) @@ -57,6 +112,17 @@ def train(self, round_idx=None): return weights, self.local_sample_number def test(self): + """Test the local model on both training and testing data. + + Returns: + A tuple containing the following metrics: + - train_tot_correct: The total number of correct predictions on the training data. + - train_loss: The loss on the training data. + - train_num_sample: The total number of training samples. + - test_tot_correct: The total number of correct predictions on the testing data. + - test_loss: The loss on the testing data. + - test_num_sample: The total number of testing samples. + """ # train data train_metrics = self.trainer.test(self.train_local, self.device, self.args) train_tot_correct, train_num_sample, train_loss = ( diff --git a/python/fedml/simulation/mpi/fedavg/FedAvgAPI.py b/python/fedml/simulation/mpi/fedavg/FedAvgAPI.py index bc5069ab6d..d88ce8e3d1 100644 --- a/python/fedml/simulation/mpi/fedavg/FedAvgAPI.py +++ b/python/fedml/simulation/mpi/fedavg/FedAvgAPI.py @@ -21,6 +21,20 @@ def FedML_FedAvg_distributed( client_trainer: ClientTrainer = None, server_aggregator: ServerAggregator = None, ): + """ + Run Federated Averaging (FedAvg) in a distributed setting. + + Args: + args: The command-line arguments and configuration for the FedAvg process. + process_id (int): The unique identifier for the current process. + worker_number (int): The total number of worker processes. + comm: The communication backend for inter-process communication. + device: The target device (e.g., CPU or GPU) for training. + dataset: The dataset for training and testing. + model: The machine learning model to be trained. 
+ client_trainer (ClientTrainer, optional): The client trainer responsible for local training. + server_aggregator (ServerAggregator, optional): The server aggregator for model aggregation. + """ [ train_data_num, test_data_num, @@ -83,6 +97,24 @@ def init_server( train_data_local_num_dict, server_aggregator ): + """ + Initialize the server for FedAvg. + + Args: + args: The command-line arguments and configuration for the FedAvg process. + device: The target device (e.g., CPU or GPU) for training. + comm: The communication backend for inter-process communication. + rank (int): The rank or identifier of the server process. + size (int): The total number of processes. + model: The machine learning model to be trained. + train_data_num (int): The number of training samples. + train_data_global: The global training dataset. + test_data_global: The global testing dataset. + train_data_local_dict: A dictionary mapping client IDs to their local training datasets. + test_data_local_dict: A dictionary mapping client IDs to their local testing datasets. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training samples. + server_aggregator: The server aggregator for model aggregation. + """ if server_aggregator is None: server_aggregator = create_server_aggregator(model, args) server_aggregator.set_id(-1) @@ -109,6 +141,7 @@ def init_server( server_manager.run() + def init_client( args, device, @@ -122,6 +155,22 @@ def init_client( test_data_local_dict, model_trainer=None, ): + """ + Initialize a client for FedAvg. + + Args: + args: The command-line arguments and configuration for the FedAvg process. + device: The target device (e.g., CPU or GPU) for training. + comm: The communication backend for inter-process communication. + process_id (int): The unique identifier for the client process. + size (int): The total number of processes. + model: The machine learning model to be trained. + train_data_num (int): The number of training samples. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training samples. + train_data_local_dict: A dictionary mapping client IDs to their local training datasets. + test_data_local_dict: A dictionary mapping client IDs to their local testing datasets. + model_trainer (ModelTrainer, optional): The model trainer responsible for local training. + """ client_index = process_id - 1 if model_trainer is None: model_trainer = create_model_trainer(model, args) diff --git a/python/fedml/simulation/mpi/fedavg/FedAvgClientManager.py b/python/fedml/simulation/mpi/fedavg/FedAvgClientManager.py index 2cc81658b9..8eac4ff5a3 100644 --- a/python/fedml/simulation/mpi/fedavg/FedAvgClientManager.py +++ b/python/fedml/simulation/mpi/fedavg/FedAvgClientManager.py @@ -7,6 +7,9 @@ class FedAVGClientManager(FedMLCommManager): + """ + Class representing the client manager in the FedAVG federated learning process. + """ def __init__( self, args, @@ -16,16 +19,32 @@ def __init__( size=0, backend="MPI", ): + """ + Initialize the client manager for the FedAVG federated learning process. + + Args: + args (Namespace): Command-line arguments and configuration for the FedAVG process. + trainer: The federated learning trainer responsible for local training. + comm: The communication backend for inter-process communication. + rank (int): The rank or identifier of the current client. + size (int): The total number of clients. + backend (str): The backend for distributed computing (e.g., "MPI"). 
+ """ super().__init__(args, comm, rank, size, backend) self.trainer = trainer self.num_rounds = args.comm_round self.args.round_idx = 0 - def run(self): + """ + Start the client manager to handle federated learning tasks. + """ super().run() def register_message_receive_handlers(self): + """ + Register message handlers for processing incoming messages. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -35,6 +54,12 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): + """ + Handle the initialization message from the server. + + Args: + msg_params (dict): Parameters received in the message. + """ global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -44,10 +69,19 @@ def handle_message_init(self, msg_params): self.__train() def start_training(self): + """ + Start the federated training process. + """ self.args.round_idx = 0 self.__train() def handle_message_receive_model_from_server(self, msg_params): + """ + Handle the model update message received from the server. + + Args: + msg_params (dict): Parameters received in the message. + """ logging.info("handle_message_receive_model_from_server.") model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -56,11 +90,19 @@ def handle_message_receive_model_from_server(self, msg_params): self.trainer.update_dataset(int(client_index)) self.args.round_idx += 1 self.__train() + if self.args.round_idx == self.num_rounds - 1: - # post_complete_message_to_sweep_process(self.args) self.finish() def send_model_to_server(self, receive_id, weights, local_sample_num): + """ + Send the locally trained model to the server. + + Args: + receive_id (int): The ID of the server to receive the model. + weights: The model parameters. + local_sample_num (int): The number of local training samples. + """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), @@ -71,6 +113,9 @@ def send_model_to_server(self, receive_id, weights, local_sample_num): self.send_message(message) def __train(self): + """ + Perform federated training for a round. + """ logging.info("#######training########### round_id = %d" % self.args.round_idx) weights, local_sample_num = self.trainer.train(self.args.round_idx) self.send_model_to_server(0, weights, local_sample_num) diff --git a/python/fedml/simulation/mpi/fedavg/FedAvgServerManager.py b/python/fedml/simulation/mpi/fedavg/FedAvgServerManager.py index 631db27f08..e0cc5cb096 100644 --- a/python/fedml/simulation/mpi/fedavg/FedAvgServerManager.py +++ b/python/fedml/simulation/mpi/fedavg/FedAvgServerManager.py @@ -1,3 +1,4 @@ + import logging from .message_define import MyMessage @@ -7,6 +8,38 @@ class FedAVGServerManager(FedMLCommManager): + """ + A class that manages the server-side operations in a Federated Averaging (FedAVG) framework. + + This class handles the synchronization of model parameters and training progress across multiple clients + in a federated learning setting using the FedAVG algorithm. + + Args: + args: An object containing configuration parameters. + aggregator: An aggregator object responsible for aggregating client updates. + comm: A communication object for inter-process communication. + rank: The rank or ID of this process in the communication group. + size: The total number of processes in the communication group. 
+ backend: The backend used for communication (e.g., "MPI" or "gloo"). + is_preprocessed: A flag indicating whether the client data is preprocessed. + preprocessed_client_lists: A list of preprocessed client data. + + Attributes: + args: An object containing configuration parameters. + aggregator: An aggregator object responsible for aggregating client updates. + round_num: The total number of communication rounds. + is_preprocessed: A flag indicating whether the client data is preprocessed. + preprocessed_client_lists: A list of preprocessed client data. + + Methods: + run(): Start the server manager and enter the main execution loop. + send_init_msg(): Send an initialization message to clients to start the federated learning process. + register_message_receive_handlers(): Register message handlers for message types. + handle_message_receive_model_from_client(msg_params): Handle a message received from a client containing model updates. + send_message_init_config(receive_id, global_model_params, client_index): Send an initialization message to a specific client. + send_message_sync_model_to_client(receive_id, global_model_params, client_index): Send a model synchronization message to a client. + """ + def __init__( self, args, @@ -18,6 +51,19 @@ def __init__( is_preprocessed=False, preprocessed_client_lists=None, ): + """ + Initialize the server manager for the FedAVG federated learning process. + + Args: + args (Namespace): Command-line arguments and configuration for the FedAVG process. + aggregator: The federated learning aggregator responsible for model aggregation. + comm: The communication backend for inter-process communication. + rank (int): The rank or identifier of the current server. + size (int): The total number of clients and servers. + backend (str): The backend for distributed computing (e.g., "MPI"). + is_preprocessed (bool): Whether client sampling has been preprocessed. + preprocessed_client_lists (list): Preprocessed client sampling lists. + """ super().__init__(args, comm, rank, size, backend) self.args = args self.aggregator = aggregator @@ -27,10 +73,15 @@ def __init__( self.preprocessed_client_lists = preprocessed_client_lists def run(self): + """ + Start the server manager to handle federated learning tasks. + """ super().run() def send_init_msg(self): - # sampling clients + """ + Send initialization messages to clients, including global model parameters and client indexes. + """ client_indexes = self.aggregator.client_sampling( self.args.round_idx, self.args.client_num_in_total, @@ -43,12 +94,21 @@ def send_init_msg(self): ) def register_message_receive_handlers(self): + """ + Register message handlers for processing incoming messages. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.handle_message_receive_model_from_client, ) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle the model update message received from a client. + + Args: + msg_params (dict): Parameters received in the message. 
+ """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) @@ -62,20 +122,19 @@ def handle_message_receive_model_from_client(self, msg_params): global_model_params = self.aggregator.aggregate() self.aggregator.test_on_server_for_all_clients(self.args.round_idx) - # start the next round + # Start the next round self.args.round_idx += 1 if self.args.round_idx == self.round_num: - # post_complete_message_to_sweep_process(self.args) self.finish() return if self.is_preprocessed: if self.preprocessed_client_lists is None: - # sampling has already been done in data preprocessor + # Sampling has already been done in data preprocessor client_indexes = [self.args.round_idx] * self.args.client_num_per_round else: client_indexes = self.preprocessed_client_lists[self.args.round_idx] else: - # sampling clients + # Sampling clients client_indexes = self.aggregator.client_sampling( self.args.round_idx, self.args.client_num_in_total, @@ -91,6 +150,14 @@ def handle_message_receive_model_from_client(self, msg_params): ) def send_message_init_config(self, receive_id, global_model_params, client_index): + """ + Send an initialization message to a client. + + Args: + receive_id (int): The ID of the client to receive the message. + global_model_params: The global model parameters. + client_index: The index of the client. + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) @@ -101,6 +168,14 @@ def send_message_init_config(self, receive_id, global_model_params, client_index def send_message_sync_model_to_client( self, receive_id, global_model_params, client_index ): + """ + Send a model synchronization message to a client. + + Args: + receive_id (int): The ID of the client to receive the message. + global_model_params: The global model parameters. + client_index: The index of the client. + """ logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id) message = Message( MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, @@ -110,3 +185,5 @@ def send_message_sync_model_to_client( message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) self.send_message(message) + + \ No newline at end of file diff --git a/python/fedml/simulation/mpi/fedavg/utils.py b/python/fedml/simulation/mpi/fedavg/utils.py index aea2449590..7d58689867 100644 --- a/python/fedml/simulation/mpi/fedavg/utils.py +++ b/python/fedml/simulation/mpi/fedavg/utils.py @@ -1,24 +1,46 @@ import os - import numpy as np import torch - def transform_list_to_tensor(model_params_list): + """ + Transform a dictionary of model parameters from a list format to PyTorch tensors. + + Args: + model_params_list (dict): A dictionary containing model parameters in list format. + + Returns: + dict: A dictionary containing model parameters as PyTorch tensors. + """ for k in model_params_list.keys(): model_params_list[k] = torch.from_numpy( np.asarray(model_params_list[k]) ).float() return model_params_list - def transform_tensor_to_list(model_params): + """ + Transform a dictionary of model parameters from PyTorch tensors to a list format. + + Args: + model_params (dict): A dictionary containing model parameters as PyTorch tensors. + + Returns: + dict: A dictionary containing model parameters in list format. 
+ """ for k in model_params.keys(): model_params[k] = model_params[k].detach().numpy().tolist() return model_params - def post_complete_message_to_sweep_process(args): + """ + Post a completion message to a sweep process. + + This function creates a named pipe and writes a completion message to it, along with the provided arguments. + + Args: + args: An object containing configuration parameters. + """ pipe_path = "./tmp/fedml" os.system("mkdir ./tmp/; touch ./tmp/fedml") if not os.path.exists(pipe_path): From 161b875ce4cd1ef443bcb380aebdbc1bb79e7f28 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Fri, 8 Sep 2023 13:30:38 +0530 Subject: [PATCH 13/70] python\fedml\simulation\mpi \fedavg_seq --- .../mpi/fedavg_seq/FedAVGAggregator.py | 160 +++++++++++++++--- .../mpi/fedavg_seq/FedAVGTrainer.py | 84 +++++++++ .../mpi/fedavg_seq/FedAvgClientManager.py | 90 ++++++++-- .../simulation/mpi/fedavg_seq/FedAvgSeqAPI.py | 51 ++++++ .../mpi/fedavg_seq/FedAvgServerManager.py | 58 +++++++ .../my_model_trainer_classification.py | 77 +++++++++ .../fedml/simulation/mpi/fedavg_seq/utils.py | 33 +++- 7 files changed, 511 insertions(+), 42 deletions(-) diff --git a/python/fedml/simulation/mpi/fedavg_seq/FedAVGAggregator.py b/python/fedml/simulation/mpi/fedavg_seq/FedAVGAggregator.py index 5adb0e0208..c0b4eaa458 100644 --- a/python/fedml/simulation/mpi/fedavg_seq/FedAVGAggregator.py +++ b/python/fedml/simulation/mpi/fedavg_seq/FedAVGAggregator.py @@ -15,6 +15,23 @@ class FedAVGAggregator(object): + """ + Federated Averaging Aggregator. + + This class handles the aggregation of local model updates from clients in a federated learning setup using Federated Averaging. + + Args: + train_global: The global training dataset. + test_global: The global test dataset. + all_train_data_num: The total number of training samples across all clients. + train_data_local_dict: A dictionary containing local training datasets for each client. + test_data_local_dict: A dictionary containing local test datasets for each client. + train_data_local_num_dict: A dictionary containing the number of local training samples for each client. + worker_num: The number of worker nodes (clients). + device: The device (e.g., 'cpu' or 'cuda') on which the model and data should be placed. + args: An object containing configuration parameters. + server_aggregator: An optional server aggregator object. + """ def __init__( self, train_global, @@ -57,18 +74,42 @@ def __init__( self.runtime_avg[i][j] = None def get_global_model_params(self): + """ + Get the global model parameters. + + Returns: + dict: A dictionary containing the global model parameters. + """ return self.aggregator.get_model_params() def set_global_model_params(self, model_parameters): + """ + Set the global model parameters. + + Args: + model_parameters (dict): A dictionary containing the global model parameters. + """ self.aggregator.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params): + """ + Add the local model update from a client. + + Args: + index (int): The index of the client. + model_params (dict): A dictionary containing the local model parameters. + """ logging.info("add_model. index = %d" % index) self.model_dict[index] = model_params - # self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all clients have uploaded their local model updates. + + Returns: + bool: True if all clients have uploaded their updates, False otherwise. 
+ """ logging.debug("worker_num = {}".format(self.worker_num)) for idx in range(self.worker_num): if not self.flag_client_model_uploaded_dict[idx]: @@ -78,6 +119,16 @@ def check_whether_all_receive(self): return True def workload_estimate(self, client_indexes, mode="simulate"): + """ + Estimate the workload of clients. + + Args: + client_indexes (list): A list of client indexes. + mode (str): The estimation mode, either "simulate" or "real". + + Returns: + list: A list of estimated workloads. + """ if mode == "simulate": client_samples = [ self.train_data_local_num_dict[client_index] @@ -91,6 +142,16 @@ def workload_estimate(self, client_indexes, mode="simulate"): return workload def memory_estimate(self, client_indexes, mode="simulate"): + """ + Estimate the memory usage of clients. + + Args: + client_indexes (list): A list of client indexes. + mode (str): The estimation mode, either "simulate" or "real". + + Returns: + list: A list of estimated memory usages. + """ if mode == "simulate": memory = np.ones(self.worker_num) elif mode == "real": @@ -100,6 +161,15 @@ def memory_estimate(self, client_indexes, mode="simulate"): return memory def resource_estimate(self, mode="simulate"): + """ + Estimate the resource usage of clients. + + Args: + mode (str): The estimation mode, either "simulate" or "real". + + Returns: + list: A list of estimated resource usages. + """ if mode == "simulate": resource = np.ones(self.worker_num) elif mode == "real": @@ -109,6 +179,13 @@ def resource_estimate(self, mode="simulate"): return resource def record_client_runtime(self, worker_id, client_runtimes): + """ + Record the runtime of client training. + + Args: + worker_id (int): The ID of the worker. + client_runtimes (dict): A dictionary containing client runtime information. + """ for client_id, runtime in client_runtimes.items(): self.runtime_history[worker_id][client_id].append(runtime) if hasattr(self.args, "runtime_est_mode"): @@ -117,21 +194,27 @@ def record_client_runtime(self, worker_id, client_runtimes): if self.runtime_avg[worker_id][client_id] is None: self.runtime_avg[worker_id][client_id] = runtime else: - self.runtime_avg[worker_id][client_id] += self.runtime_avg[worker_id][client_id]/2 + runtime/2 + self.runtime_avg[worker_id][client_id] += self.runtime_avg[worker_id][client_id] / 2 + runtime / 2 elif self.args.runtime_est_mode == 'time_window': for client_id, runtime in client_runtimes.items(): self.runtime_history[worker_id][client_id] = self.runtime_history[worker_id][client_id][-3:] + def generate_client_schedule(self, round_idx, client_indexes): - # self.runtime_history = {} - # for i in range(self.worker_num): - # self.runtime_history[i] = {} - # for j in range(self.args.client_num_in_total): - # self.runtime_history[i][j] = [] + """ + Generate the schedule of clients for a given round. + + Args: + round_idx (int): The index of the round. + client_indexes (list): A list of client indexes. + + Returns: + list: A list of client schedules. + """ previous_time = time.time() if hasattr(self.args, "simulation_schedule") and round_idx > 5: - # Need some rounds to record some information. + # Need some rounds to record some information. 
simulation_schedule = self.args.simulation_schedule if hasattr(self.args, "runtime_est_mode"): if self.args.runtime_est_mode == 'EMA': @@ -144,7 +227,7 @@ def generate_client_schedule(self, round_idx, client_indexes): runtime_to_fit = self.runtime_history fit_params, fit_funcs, fit_errors = t_sample_fit( - self.worker_num, self.args.client_num_in_total, runtime_to_fit, + self.worker_num, self.args.client_num_in_total, runtime_to_fit, self.train_data_local_num_dict, uniform_client=True, uniform_gpu=False) if self.args.enable_wandb: @@ -187,6 +270,15 @@ def generate_client_schedule(self, round_idx, client_indexes): return client_schedule def get_average_weight(self, client_indexes): + """ + Calculate the average weight of clients based on their data sizes. + + Args: + client_indexes (list): A list of client indexes. + + Returns: + dict: A dictionary containing the average weight for each client. + """ average_weight_dict = {} training_num = 0 for client_index in client_indexes: @@ -199,36 +291,33 @@ def get_average_weight(self, client_indexes): return average_weight_dict def aggregate(self): + """ + Aggregate the local model updates from clients and compute the global model parameters. + + Returns: + dict: A dictionary containing the global model parameters. + """ start_time = time.time() model_list = [] training_num = 0 for idx in range(self.worker_num): - # added for attack & defense; enable multiple defenses - # if FedMLDefender.get_instance().is_defense_enabled(): - # self.model_dict[idx] = FedMLDefender.get_instance().defend( - # self.model_dict[idx], self.get_global_model_params() - # ) - if len(self.model_dict[idx]) > 0: - # some workers may not have parameters + # Some workers may not have parameters model_list.append(self.model_dict[idx]) # training_num += self.sample_num_dict[idx] logging.info("len of self.model_dict[idx] = " + str(len(self.model_dict))) - # logging.info("################aggregate: %d" % len(model_list)) - # (num0, averaged_params) = model_list[0] averaged_params = model_list[0] for k in averaged_params.keys(): for i in range(0, len(model_list)): local_model_params = model_list[i] - # w = local_sample_number / training_num if i == 0: averaged_params[k] = local_model_params[k] else: averaged_params[k] += local_model_params[k] - # update the global model which is cached at the server side + # Update the global model which is cached at the server side self.set_global_model_params(averaged_params) end_time = time.time() @@ -236,6 +325,17 @@ def aggregate(self): return averaged_params def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Randomly select a subset of clients for training in a round. + + Args: + round_idx (int): The index of the round. + client_num_in_total (int): The total number of clients. + client_num_per_round (int): The number of clients to select per round. + + Returns: + list: A list of selected client indexes. 
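
A standalone sketch of the data-size weighting computed in get_average_weight() above: each sampled client's weight is its local sample count divided by the total over the sampled set, so the weights sum to one (sample counts illustrative):

    train_data_local_num_dict = {0: 100, 1: 300, 2: 600}  # illustrative counts
    client_indexes = [0, 1, 2]

    training_num = sum(train_data_local_num_dict[c] for c in client_indexes)
    average_weight_dict = {
        c: train_data_local_num_dict[c] / training_num for c in client_indexes
    }
    assert abs(sum(average_weight_dict.values()) - 1.0) < 1e-9
    print(average_weight_dict)  # {0: 0.1, 1: 0.3, 2: 0.6}
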
+ """ if client_num_in_total == client_num_per_round: client_indexes = [ client_index for client_index in range(client_num_in_total) @@ -244,7 +344,7 @@ def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): num_clients = min(client_num_per_round, client_num_in_total) np.random.seed( round_idx - ) # make sure for each comparison, we are selecting the same clients each round + ) # Make sure for each comparison, we are selecting the same clients each round client_indexes = np.random.choice( range(client_num_in_total), num_clients, replace=False ) @@ -252,6 +352,15 @@ def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set for testing. + + Args: + num_samples (int, optional): The number of samples in the validation set. Defaults to 10000. + + Returns: + torch.utils.data.DataLoader: A DataLoader containing the validation set. + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) sample_indices = random.sample( @@ -266,6 +375,12 @@ def _generate_validation_set(self, num_samples=10000): return self.test_global def test_on_server_for_all_clients(self, round_idx): + """ + Test the global model on all clients. + + Args: + round_idx (int): The index of the current round. + """ if ( round_idx % self.args.frequency_of_the_test == 0 or round_idx == self.args.comm_round - 1 @@ -278,6 +393,8 @@ def test_on_server_for_all_clients(self, round_idx): train_num_samples = [] train_tot_corrects = [] train_losses = [] + + # Note: The following code is commented out, so it doesn't affect the execution. # for client_idx in range(self.args.client_num_in_total): # # train data # metrics = self.trainer.test( @@ -312,6 +429,7 @@ def test_on_server_for_all_clients(self, round_idx): else: metrics = self.aggregator.test(self.val_global, self.device, self.args) + # Note: The following code is commented out, so it doesn't affect the execution. # test_tot_correct, test_num_sample, test_loss = ( # metrics["test_correct"], # metrics["test_total"], diff --git a/python/fedml/simulation/mpi/fedavg_seq/FedAVGTrainer.py b/python/fedml/simulation/mpi/fedavg_seq/FedAVGTrainer.py index 148eb9b7c3..cbbf31181d 100644 --- a/python/fedml/simulation/mpi/fedavg_seq/FedAVGTrainer.py +++ b/python/fedml/simulation/mpi/fedavg_seq/FedAVGTrainer.py @@ -2,6 +2,49 @@ class FedAVGTrainer(object): + """ + Trainer class for federated learning clients using the FedAVG algorithm. + + Args: + client_index (int): The index of the client. + train_data_local_dict (dict): A dictionary containing local training datasets. + train_data_local_num_dict (dict): A dictionary containing the number of samples for each local dataset. + test_data_local_dict (dict): A dictionary containing local testing datasets. + train_data_num (int): The total number of training samples. + device (str): The device (e.g., "cpu" or "cuda") for training. + args (Namespace): Command-line arguments and configuration. + model_trainer (object): An instance of the model trainer used for training. + + Attributes: + trainer (object): The model trainer instance. + client_index (int): The index of the client. + train_data_local_dict (dict): A dictionary containing local training datasets. + train_data_local_num_dict (dict): A dictionary containing the number of samples for each local dataset. + test_data_local_dict (dict): A dictionary containing local testing datasets. 
+ all_train_data_num (int): The total number of training samples. + train_local (Dataset): The local training dataset. + local_sample_number (int): The number of local training samples. + test_local (Dataset): The local testing dataset. + device (str): The device for training (e.g., "cpu" or "cuda"). + args (Namespace): Command-line arguments and configuration. + + Methods: + update_model(weights): + Update the model with given weights. + + update_dataset(client_index): + Update the current dataset for training and testing. + + get_lr(progress): + Calculate the learning rate based on the training progress. + + train(round_idx=None): + Train the model on the local dataset for a given round. + + test(): + Evaluate the trained model on both local training and testing datasets. + + """ def __init__( self, client_index, @@ -28,15 +71,39 @@ def __init__( self.args = args def update_model(self, weights): + """ + Update the model with the provided weights. + + Args: + weights (dict): The model parameters to set. + + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """ + Update the current dataset for training and testing. + + Args: + client_index (int): The index of the client representing the dataset to be used. + + """ self.client_index = client_index self.train_local = self.train_data_local_dict[client_index] self.local_sample_number = self.train_data_local_num_dict[client_index] self.test_local = self.test_data_local_dict[client_index] def get_lr(self, progress): + """ + Calculate the learning rate based on the training progress. + + Args: + progress (int): The training progress, typically the round index. + + Returns: + float: The calculated learning rate. + + """ # This aims to make a float step_size work. if self.args.lr_schedule == "StepLR": exp_num = progress / self.args.lr_step_size @@ -56,6 +123,16 @@ def get_lr(self, progress): return lr def train(self, round_idx=None): + """ + Train the model on the local dataset for a given round. + + Args: + round_idx (int, optional): The current round index. Defaults to None. + + Returns: + tuple: A tuple containing the trained model weights and the number of local samples used. + + """ self.args.round_idx = round_idx # lr = self.get_lr(round_idx) # self.trainer.train(self.train_local, self.device, self.args, lr=lr) @@ -65,6 +142,13 @@ def train(self, round_idx=None): return weights, self.local_sample_number def test(self): + """ + Evaluate the trained model on both local training and testing datasets. + + Returns: + tuple: A tuple containing training and testing metrics, including correct predictions, loss, and sample counts. + + """ # train data train_metrics = self.trainer.test(self.train_local, self.device, self.args) train_tot_correct, train_num_sample, train_loss = ( diff --git a/python/fedml/simulation/mpi/fedavg_seq/FedAvgClientManager.py b/python/fedml/simulation/mpi/fedavg_seq/FedAvgClientManager.py index 7cfaa43056..edaca7d713 100644 --- a/python/fedml/simulation/mpi/fedavg_seq/FedAvgClientManager.py +++ b/python/fedml/simulation/mpi/fedavg_seq/FedAvgClientManager.py @@ -8,10 +8,37 @@ from ....core.distributed.fedml_comm_manager import FedMLCommManager + + class FedAVGClientManager(FedMLCommManager): + """ + Manager for federated learning clients using the Federated Averaging (FedAvg) algorithm. + + This class handles communication between the server and clients, as well as the training + process on each client. + + Args: + args (Namespace): Command-line arguments and configuration. 
+ trainer (object): An instance of the model trainer used for local training on clients. + comm (object, optional): The communication backend (e.g., MPI). Defaults to None. + rank (int, optional): The rank of the client. Defaults to 0. + size (int, optional): The total number of clients. Defaults to 0. + backend (str, optional): The communication backend type (e.g., MPI). Defaults to "MPI". + """ def __init__( self, args, trainer, comm=None, rank=0, size=0, backend="MPI", ): + """ + Initialize the FedAVGClientManager. + + Args: + args: The command-line arguments. + trainer: The trainer for client-side training. + comm: The communication backend. + rank: The rank of the client. + size: The total number of clients. + backend: The communication backend (e.g., "MPI"). + """ super().__init__(args, comm, rank, size, backend) self.trainer = trainer self.num_rounds = args.comm_round @@ -19,18 +46,28 @@ def __init__( self.worker_id = self.rank - 1 def run(self): + """ + Run the FedAVGClientManager. + """ super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers. + """ self.register_message_receive_handler(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init) self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.handle_message_receive_model_from_server, ) def handle_message_init(self, msg_params): - global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) - # client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) + """ + Handle initialization message from the server. + Args: + msg_params: The message parameters. + """ + global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) average_weight_dict = msg_params.get(MyMessage.MSG_ARG_KEY_AVG_WEIGHTS) client_schedule = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_SCHEDULE) client_indexes = client_schedule[self.worker_id] @@ -39,14 +76,20 @@ def handle_message_init(self, msg_params): self.__train(global_model_params, client_indexes, average_weight_dict) def start_training(self): + """ + Start the training process. + """ self.round_idx = 0 - # self.__train() def handle_message_receive_model_from_server(self, msg_params): + """ + Handle the received model from the server. + + Args: + msg_params: The message parameters. + """ logging.info("handle_message_receive_model_from_server.") global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) - # client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) - average_weight_dict = msg_params.get(MyMessage.MSG_ARG_KEY_AVG_WEIGHTS) client_schedule = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_SCHEDULE) client_indexes = client_schedule[self.worker_id] @@ -54,18 +97,31 @@ def handle_message_receive_model_from_server(self, msg_params): self.round_idx += 1 self.__train(global_model_params, client_indexes, average_weight_dict) if self.round_idx == self.num_rounds - 1: - # post_complete_message_to_sweep_process(self.args) self.finish() def send_result_to_server(self, receive_id, weights, client_runtime_info): + """ + Send the training results to the server. + + Args: + receive_id: The ID of the recipient (server). + weights: The model weights. + client_runtime_info: Information about client runtime. 
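
A minimal sketch of the schedule lookup performed in handle_message_init() above: the server ships one list of client indexes per worker, and each worker takes the slice at worker_id = rank - 1 (values illustrative):

    client_schedule = [[0, 3], [1, 4], [2, 5]]  # one client list per worker
    rank = 2                                    # MPI rank of this client process
    worker_id = rank - 1
    client_indexes = client_schedule[worker_id]
    assert client_indexes == [1, 4]
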
+ """ message = Message(MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), receive_id,) message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights) - # message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_RUNTIME_INFO, client_runtime_info) self.send_message(message) def add_client_model(self, local_agg_model_params, model_params, weight=1.0): - # Add params that needed to be reduces from clients + """ + Add the client model parameters to the local aggregation. + + Args: + local_agg_model_params: The local aggregation of model parameters. + model_params: The model parameters. + weight: The weight for averaging. + """ for name, param in model_params.items(): if name not in local_agg_model_params: local_agg_model_params[name] = param * weight @@ -73,32 +129,34 @@ def add_client_model(self, local_agg_model_params, model_params, weight=1.0): local_agg_model_params[name] += param * weight def __train(self, global_model_params, client_indexes, average_weight_dict): + """ + Train the client model. + + Args: + global_model_params: The global model parameters. + client_indexes: The indexes of clients. + average_weight_dict: The dictionary of average weights. + """ logging.info("#######training########### round_id = %d" % self.round_idx) if hasattr(self.args, "simulation_gpu_hetero"): - # runtime_speed_ratio - # runtime_speed_ratio * t_train - t_train - # time.sleep(runtime_speed_ratio * t_train - t_train) simulation_gpu_hetero = self.args.simulation_gpu_hetero runtime_speed_ratio = self.args.gpu_hetero_ratio * self.worker_id / self.args.worker_num if hasattr(self.args, "simulation_environment_hetero"): - # runtime_speed_ratio - # runtime_speed_ratio * t_train - t_train - # time.sleep(runtime_speed_ratio * t_train - t_train) if self.args.simulation_environment_hetero == "cos": runtime_speed_ratio = self.args.environment_hetero_ratio * \ (1 + cos(self.round_idx / self.num_rounds*3.1415926 + self.worker_id)) else: raise NotImplementedError - local_agg_model_params = {} client_runtime_info = {} for client_index in client_indexes: logging.info( "#######training########### Simulating client_index = %d, average weight: %f " - % (client_index, average_weight_dict[client_index]) + % (client_index, + average_weight_dict[client_index]) ) start_time = time.time() self.trainer.update_model(global_model_params) diff --git a/python/fedml/simulation/mpi/fedavg_seq/FedAvgSeqAPI.py b/python/fedml/simulation/mpi/fedavg_seq/FedAvgSeqAPI.py index c9aa2dbfb1..a0afb8e868 100644 --- a/python/fedml/simulation/mpi/fedavg_seq/FedAvgSeqAPI.py +++ b/python/fedml/simulation/mpi/fedavg_seq/FedAvgSeqAPI.py @@ -12,6 +12,21 @@ def FedML_FedAvgSeq_distributed( args, process_id, worker_number, comm, device, dataset, model, client_trainer=None, server_aggregator=None ): + """ + Function to initialize and run federated learning in a distributed environment using the FedAvg algorithm. + + Args: + args (Namespace): Command-line arguments and configuration. + process_id (int): The unique identifier for the current process. + worker_number (int): The total number of worker processes. + comm (object): The communication backend (e.g., MPI). + device (str): The device (e.g., "cpu" or "cuda") for training. + dataset (list): List containing dataset information. + model (nn.Module): The federated learning model. + client_trainer (object, optional): An instance of the client model trainer. Defaults to None. 
+ server_aggregator (object, optional): An instance of the server aggregator. Defaults to None. + """ + [ train_data_num, test_data_num, @@ -58,6 +73,7 @@ def FedML_FedAvgSeq_distributed( client_trainer, ) +# Rest of the code... def init_server( args, @@ -74,6 +90,25 @@ def init_server( train_data_local_num_dict, server_aggregator, ): + """ + Initialize and run the federated learning server. + + Args: + args (Namespace): Command-line arguments and configuration. + device (str): The device (e.g., "cpu" or "cuda") for training. + comm (object): The communication backend (e.g., MPI). + rank (int): The rank of the server process. + size (int): The total number of processes. + model (nn.Module): The federated learning model. + train_data_num (int): The total number of training samples. + train_data_global (Dataset): The global training dataset. + test_data_global (Dataset): The global test dataset. + train_data_local_dict (dict): A dictionary of local training datasets. + test_data_local_dict (dict): A dictionary of local test datasets. + train_data_local_num_dict (dict): A dictionary of the number of samples in each local training dataset. + server_aggregator (object): An instance of the server aggregator. + """ + if server_aggregator is None: server_aggregator = create_server_aggregator(model, args) server_aggregator.set_id(-1) @@ -113,6 +148,22 @@ def init_client( test_data_local_dict, client_trainer=None, ): + """ + Initialize and run a federated learning client. + + Args: + args (Namespace): Command-line arguments and configuration. + device (str): The device (e.g., "cpu" or "cuda") for training. + comm (object): The communication backend (e.g., MPI). + process_id (int): The unique identifier for the client process. + size (int): The total number of processes. + model (nn.Module): The federated learning model. + train_data_num (int): The total number of training samples. + train_data_local_num_dict (dict): A dictionary of the number of samples in each local training dataset. + train_data_local_dict (dict): A dictionary of local training datasets. + test_data_local_dict (dict): A dictionary of local test datasets. + client_trainer (object, optional): An instance of the client model trainer. Defaults to None. + """ client_index = process_id - 1 if client_trainer is None: client_trainer = create_model_trainer(model, args) diff --git a/python/fedml/simulation/mpi/fedavg_seq/FedAvgServerManager.py b/python/fedml/simulation/mpi/fedavg_seq/FedAvgServerManager.py index f9269a50b9..b59bb6b64e 100644 --- a/python/fedml/simulation/mpi/fedavg_seq/FedAvgServerManager.py +++ b/python/fedml/simulation/mpi/fedavg_seq/FedAvgServerManager.py @@ -10,6 +10,19 @@ class FedAVGServerManager(FedMLCommManager): + """ + Class responsible for managing the server in the FedAVG federated learning system. + + Args: + args (Namespace): Command-line arguments and configuration. + aggregator (object): An instance of the aggregator used for federated learning. + comm (object, optional): The communication backend (e.g., MPI). Defaults to None. + rank (int, optional): The rank of the server process. Defaults to 0. + size (int, optional): The total number of processes. Defaults to 0. + backend (str, optional): The backend used for communication. Defaults to "MPI". + is_preprocessed (bool, optional): Indicates whether client lists are preprocessed. Defaults to False. + preprocessed_client_lists (list, optional): Preprocessed client lists. Defaults to None. 
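
A hedged sketch of the rank-based role split used by FedML_FedAvgSeq_distributed(): rank 0 hosts the server and every other rank becomes one client; the helper name is hypothetical:

    def dispatch_role(process_id):
        # Rank 0 runs init_server(); ranks 1..N-1 run init_client() with
        # client_index = process_id - 1, matching the code above.
        if process_id == 0:
            return ("server", None)
        return ("client", process_id - 1)

    assert dispatch_role(0) == ("server", None)
    assert dispatch_role(3) == ("client", 2)
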
+ """ def __init__( self, args, @@ -30,9 +43,19 @@ def __init__( self.preprocessed_client_lists = preprocessed_client_lists def run(self): + """ + Run the server manager to coordinate federated learning. + + This method runs the server manager to coordinate the federated learning process. + """ super().run() def send_init_msg(self): + """ + Send the initialization message to clients. + + This method sends an initialization message to client processes to begin federated learning. + """ # sampling clients self.previous_time = time.time() client_indexes = self.aggregator.client_sampling( @@ -48,11 +71,24 @@ def send_init_msg(self): self.send_message_init_config(process_id, global_model_params, average_weight_dict, client_schedule) def register_message_receive_handlers(self): + """ + Register message receive handlers. + + This method registers message receive handlers for processing incoming messages. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.handle_message_receive_model_from_client, ) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle the received model update from a client. + + Args: + msg_params (dict): The parameters of the received message. + + This method handles the model update received from a client during federated learning. + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) # local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) @@ -112,6 +148,17 @@ def handle_message_receive_model_from_client(self, msg_params): ) def send_message_init_config(self, receive_id, global_model_params, average_weight_dict, client_schedule): + """ + Send the initialization configuration message to a client. + + Args: + receive_id (int): The ID of the receiving client. + global_model_params (dict): Global model parameters. + average_weight_dict (dict): Average weight dictionary for clients. + client_schedule (list): The schedule of clients for the current round. + + This method sends an initialization configuration message to a client process. + """ message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id) message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) # message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) @@ -120,6 +167,17 @@ def send_message_init_config(self, receive_id, global_model_params, average_weig self.send_message(message) def send_message_sync_model_to_client(self, receive_id, global_model_params, average_weight_dict, client_schedule): + """ + Send the model synchronization message to a client. + + Args: + receive_id (int): The ID of the receiving client. + global_model_params (dict): Global model parameters. + average_weight_dict (dict): Average weight dictionary for clients. + client_schedule (list): The schedule of clients for the current round. + + This method sends a model synchronization message to a client process. + """ logging.info("send_message_sync_model_to_client. 
receive_id = %d" % receive_id) message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id,) message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) diff --git a/python/fedml/simulation/mpi/fedavg_seq/my_model_trainer_classification.py b/python/fedml/simulation/mpi/fedavg_seq/my_model_trainer_classification.py index 20ce167511..19da2853aa 100644 --- a/python/fedml/simulation/mpi/fedavg_seq/my_model_trainer_classification.py +++ b/python/fedml/simulation/mpi/fedavg_seq/my_model_trainer_classification.py @@ -6,13 +6,65 @@ class MyModelTrainer(ClientTrainer): + """ + Custom model trainer for federated learning clients. + + Args: + model (nn.Module): The PyTorch model to be trained. + id (int): The identifier of the client. + + Attributes: + model (nn.Module): The PyTorch model being trained. + id (int): The identifier of the client. + + Methods: + get_model_params(): + Get the model parameters as a dictionary. + + set_model_params(model_parameters): + Set the model parameters using a dictionary. + + train(train_data, device, args, lr=None): + Train the model on the provided training data. + + test(test_data, device, args): + Evaluate the model on the provided test data. + + test_on_the_server(train_data_local_dict, test_data_local_dict, device, args=None): + Perform testing on the server (not implemented in this class). + + """ def get_model_params(self): + """ + Get the model parameters as a dictionary. + + Returns: + dict: A dictionary containing the model's state dictionary. + + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters using a dictionary. + + Args: + model_parameters (dict): A dictionary containing the model's state dictionary. + + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args, lr=None): + """ + Train the model on the provided training data. + + Args: + train_data (DataLoader): The DataLoader containing the training data. + device (str): The device (e.g., "cpu" or "cuda") for training. + args (Namespace): Command-line arguments and configuration. + lr (float, optional): The learning rate. Defaults to None. + + """ model = self.model model.to(device) @@ -66,6 +118,18 @@ def train(self, train_data, device, args, lr=None): ) def test(self, test_data, device, args): + """ + Evaluate the model on the provided test data. + + Args: + test_data (DataLoader): The DataLoader containing the test data. + device (str): The device (e.g., "cpu" or "cuda") for testing. + args (Namespace): Command-line arguments and configuration. + + Returns: + dict: A dictionary containing test metrics, including correct predictions, loss, and total samples. + + """ model = self.model model.to(device) @@ -93,4 +157,17 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Perform testing on the server (not implemented in this class). + + Args: + train_data_local_dict (dict): A dictionary containing local training datasets. + test_data_local_dict (dict): A dictionary containing local testing datasets. + device (str): The device (e.g., "cpu" or "cuda") for testing. + args (Namespace, optional): Command-line arguments and configuration. Defaults to None. + + Returns: + bool: Always returns False as this method is not implemented in this class. 
+ + """ return False diff --git a/python/fedml/simulation/mpi/fedavg_seq/utils.py b/python/fedml/simulation/mpi/fedavg_seq/utils.py index aea2449590..479e91c857 100644 --- a/python/fedml/simulation/mpi/fedavg_seq/utils.py +++ b/python/fedml/simulation/mpi/fedavg_seq/utils.py @@ -1,24 +1,47 @@ +import torch +import numpy as np import os -import numpy as np -import torch +def transform_list_to_tensor(model_params_list): + """ + Transform a dictionary of model parameters from NumPy arrays to PyTorch tensors. + Args: + model_params_list (dict): A dictionary of model parameters as NumPy arrays. -def transform_list_to_tensor(model_params_list): + Returns: + dict: A dictionary of model parameters as PyTorch tensors. + + """ for k in model_params_list.keys(): model_params_list[k] = torch.from_numpy( np.asarray(model_params_list[k]) ).float() return model_params_list - def transform_tensor_to_list(model_params): + """ + Transform a dictionary of model parameters from PyTorch tensors to NumPy arrays. + + Args: + model_params (dict): A dictionary of model parameters as PyTorch tensors. + + Returns: + dict: A dictionary of model parameters as NumPy arrays. + + """ for k in model_params.keys(): model_params[k] = model_params[k].detach().numpy().tolist() return model_params - def post_complete_message_to_sweep_process(args): + """ + Post a completion message to a named pipe for communication. + + Args: + args: Additional information or arguments (usually configuration). + + """ pipe_path = "./tmp/fedml" os.system("mkdir ./tmp/; touch ./tmp/fedml") if not os.path.exists(pipe_path): From 5c44b0174ce5d306a7deea205152b46bff3e1ab4 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Fri, 8 Sep 2023 19:33:23 +0530 Subject: [PATCH 14/70] python\fedml\simulation\mpi fedgan fedgkt --- .../simulation/mpi/fedgan/FedGANAggregator.py | 119 +++++++++++- .../simulation/mpi/fedgan/FedGANTrainer.py | 70 ++++++- .../fedml/simulation/mpi/fedgan/FedGanAPI.py | 58 ++++++ .../mpi/fedgan/FedGanClientManager.py | 49 +++++ .../mpi/fedgan/FedGanServerManager.py | 61 +++++- .../simulation/mpi/fedgan/gan_trainer.py | 62 +++++- python/fedml/simulation/mpi/fedgan/utils.py | 29 ++- .../fedml/simulation/mpi/fedgkt/FedGKTAPI.py | 58 +++++- .../simulation/mpi/fedgkt/GKTClientManager.py | 107 +++++++++- .../simulation/mpi/fedgkt/GKTClientTrainer.py | 74 +++++++ .../simulation/mpi/fedgkt/GKTServerManager.py | 73 ++++++- .../simulation/mpi/fedgkt/GKTServerTrainer.py | 113 ++++++++++- python/fedml/simulation/mpi/fedgkt/utils.py | 182 +++++++++++++++--- 13 files changed, 998 insertions(+), 57 deletions(-) diff --git a/python/fedml/simulation/mpi/fedgan/FedGANAggregator.py b/python/fedml/simulation/mpi/fedgan/FedGANAggregator.py index 826b2da7ec..745bcc1882 100644 --- a/python/fedml/simulation/mpi/fedgan/FedGANAggregator.py +++ b/python/fedml/simulation/mpi/fedgan/FedGANAggregator.py @@ -13,6 +13,35 @@ class FedGANAggregator(object): + """ + A class for aggregating and managing local models in a Federated Generative Adversarial Network (FedGAN) setup. + + Attributes: + trainer: Model trainer object for training and testing. + args: Configuration arguments. + train_global: Global training dataset. + test_global: Global testing dataset. + val_global: Validation dataset for testing. + all_train_data_num: Total number of training samples. + train_data_local_dict: Dictionary of local training datasets for each worker. + test_data_local_dict: Dictionary of local testing datasets for each worker. 
+ train_data_local_num_dict: Dictionary of the number of local training samples for each worker. + worker_num: Number of worker nodes. + device: Torch device for computation (e.g., 'cuda' or 'cpu'). + model_dict: Dictionary to store local models from client workers. + sample_num_dict: Dictionary to store the number of training samples from client workers. + flag_client_model_uploaded_dict: Dictionary to track whether client models have been uploaded. + + Methods: + get_global_model_params(): Get the global model parameters. + set_global_model_params(model_parameters): Set the global model parameters. + add_local_trained_result(index, model_params, sample_num): Add local trained model results to the aggregator. + check_whether_all_receive(): Check if all client workers have uploaded their local models. + aggregate(): Aggregate local models from client workers. + client_sampling(round_idx, client_num_in_total, client_num_per_round): Randomly sample a subset of clients for communication in a round. + _generate_validation_set(num_samples): Generate a validation dataset for testing. + test_on_server_for_all_clients(round_idx): Perform testing on the server side for all clients. + """ def __init__( self, train_global, @@ -26,6 +55,22 @@ def __init__( args, model_trainer, ): + """ + Initialize the FedGANAggregator. + + Args: + train_global: Global training dataset. + test_global: Global testing dataset. + all_train_data_num: Total number of training samples. + train_data_local_dict: Dictionary of local training datasets for each worker. + test_data_local_dict: Dictionary of local testing datasets for each worker. + train_data_local_num_dict: Dictionary of the number of local training samples for each worker. + worker_num: Number of worker nodes. + device: Torch device for computation (e.g., 'cuda' or 'cpu'). + args: Configuration arguments. + model_trainer: Model trainer object for training and testing. + + """ self.trainer = model_trainer self.args = args @@ -47,18 +92,48 @@ def __init__( self.flag_client_model_uploaded_dict[idx] = False def get_global_model_params(self): + """ + Get the global model parameters. + + Returns: + dict: Global model parameters. + + """ return self.trainer.get_model_params() def set_global_model_params(self, model_parameters): + """ + Set the global model parameters. + + Args: + model_parameters (dict): Global model parameters to set. + + """ self.trainer.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, sample_num): + """ + Add local trained model results to the aggregator. + + Args: + index: Index of the client worker. + model_params (dict): Local model parameters. + sample_num (int): Number of local training samples. + + """ logging.info("add_model. index = %d" % index) self.model_dict[index] = model_params self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all client workers have uploaded their local models. + + Returns: + bool: True if all clients have uploaded their models, False otherwise. + + """ for idx in range(self.worker_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -67,6 +142,13 @@ def check_whether_all_receive(self): return True def aggregate(self): + """ + Aggregate local models from client workers. + + Returns: + dict: Averaged global model parameters. 
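
A standalone sketch of the two-level weighted average that aggregate() performs below: parameters are grouped per network ("netg", "netd"), so the loop walks nets first and parameter names second (tensors and sample counts illustrative):

    import torch

    def aggregate_gan(model_list, total_num):
        (num0, averaged) = model_list[0]
        for net in averaged.keys():
            for k in averaged[net].keys():
                for i, (sample_num, params) in enumerate(model_list):
                    w = sample_num / total_num
                    if i == 0:
                        averaged[net][k] = params[net][k] * w
                    else:
                        averaged[net][k] += params[net][k] * w
        return averaged

    client_a = (1, {"netg": {"w": torch.tensor([0.0])},
                    "netd": {"w": torch.tensor([2.0])}})
    client_b = (3, {"netg": {"w": torch.tensor([4.0])},
                    "netd": {"w": torch.tensor([2.0])}})
    out = aggregate_gan([client_a, client_b], total_num=4)
    assert torch.allclose(out["netg"]["w"], torch.tensor([3.0]))
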
+ + """ start_time = time.time() model_list = [] training_num = 0 @@ -77,7 +159,6 @@ def aggregate(self): logging.info("len of self.model_dict[idx] = " + str(len(self.model_dict))) - # logging.info("################aggregate: %d" % len(model_list)) (num0, averaged_params) = model_list[0] for net in averaged_params.keys(): for k in averaged_params[net].keys(): @@ -89,7 +170,7 @@ def aggregate(self): else: averaged_params[net][k] += local_model_params[net][k] * w - # update the global model which is cached at the server side + # Update the global model which is cached at the server side self.set_global_model_params(averaged_params) end_time = time.time() @@ -97,6 +178,18 @@ def aggregate(self): return averaged_params def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Randomly sample a subset of clients for communication in a round. + + Args: + round_idx (int): Current communication round. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to sample per round. + + Returns: + list: List of client indexes selected for communication in the current round. + + """ if client_num_in_total == client_num_per_round: client_indexes = [ client_index for client_index in range(client_num_in_total) @@ -105,7 +198,7 @@ def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): num_clients = min(client_num_per_round, client_num_in_total) np.random.seed( round_idx - ) # make sure for each comparison, we are selecting the same clients each round + ) # Make sure for each comparison, we are selecting the same clients each round client_indexes = np.random.choice( range(client_num_in_total), num_clients, replace=False ) @@ -113,6 +206,16 @@ def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation dataset for testing. + + Args: + num_samples (int): Number of samples to include in the validation set. + + Returns: + DataLoader: Validation dataset loader. + + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) sample_indices = random.sample( @@ -127,6 +230,16 @@ def _generate_validation_set(self, num_samples=10000): return self.test_global def test_on_server_for_all_clients(self, round_idx): + """ + Perform testing on the server side for all clients. + + Args: + round_idx (int): Current communication round. + + Returns: + bool: True if testing on the server side is performed, False otherwise. + + """ if self.trainer.test_on_the_server( self.train_data_local_dict, self.test_data_local_dict, diff --git a/python/fedml/simulation/mpi/fedgan/FedGANTrainer.py b/python/fedml/simulation/mpi/fedgan/FedGANTrainer.py index d9caae6958..69f7d9324d 100644 --- a/python/fedml/simulation/mpi/fedgan/FedGANTrainer.py +++ b/python/fedml/simulation/mpi/fedgan/FedGANTrainer.py @@ -2,6 +2,33 @@ class FedGANTrainer(object): + """ + Trainer for a federated GAN client. + + Args: + client_index (int): Index of the client. + train_data_local_dict (dict): Dictionary of local training datasets. + train_data_local_num_dict (dict): Dictionary of local training dataset sizes. + test_data_local_dict (dict): Dictionary of local test datasets. + train_data_num (int): Number of samples in the global training dataset. + device: Device for training (e.g., 'cuda' or 'cpu'). + args: Configuration arguments. + model_trainer: Trainer for the GAN model. 
+ + Attributes: + trainer: Trainer for the GAN model. + client_index (int): Index of the client. + train_data_local_dict (dict): Dictionary of local training datasets. + train_data_local_num_dict (dict): Dictionary of local training dataset sizes. + test_data_local_dict (dict): Dictionary of local test datasets. + all_train_data_num (int): Number of samples in the global training dataset. + train_local: Local training dataset. + local_sample_number: Number of samples in the local training dataset. + test_local: Local test dataset. + device: Device for training (e.g., 'cuda' or 'cpu'). + args: Configuration arguments. + """ + def __init__( self, client_index, @@ -14,7 +41,6 @@ def __init__( model_trainer, ): self.trainer = model_trainer - self.client_index = client_index self.train_data_local_dict = train_data_local_dict self.train_data_local_num_dict = train_data_local_num_dict @@ -23,28 +49,59 @@ def __init__( self.train_local = None self.local_sample_number = None self.test_local = None - self.device = device self.args = args def update_model(self, weights): + """ + Update the model with new weights. + + Args: + weights: New model weights. + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """ + Update the client's dataset. + + Args: + client_index (int): Index of the client. + """ self.client_index = client_index self.train_local = self.train_data_local_dict[client_index] self.local_sample_number = self.train_data_local_num_dict[client_index] - # self.test_local = self.test_data_local_dict[client_index] def train(self, round_idx=None): + """ + Train the client's GAN model. + + Args: + round_idx: Index of the training round (optional). + + Returns: + weights: Updated model weights. + local_sample_number: Number of samples in the local dataset. + """ self.args.round_idx = round_idx self.trainer.train(self.train_local, self.device, self.args) - weights = self.trainer.get_model_params() return weights, self.local_sample_number def test(self): - # train data + """ + Test the client's GAN model on both training and test datasets. + + Returns: + Tuple containing: + - train_tot_correct: Total correct predictions on the training dataset. + - train_loss: Loss on the training dataset. + - train_num_sample: Number of samples in the training dataset. + - test_tot_correct: Total correct predictions on the test dataset. + - test_loss: Loss on the test dataset. + - test_num_sample: Number of samples in the test dataset. + """ + # Train data metrics train_metrics = self.trainer.test(self.train_local, self.device, self.args) train_tot_correct, train_num_sample, train_loss = ( train_metrics["test_correct"], @@ -52,7 +109,7 @@ def test(self): train_metrics["test_loss"], ) - # test data + # Test data metrics test_metrics = self.trainer.test(self.test_local, self.device, self.args) test_tot_correct, test_num_sample, test_loss = ( test_metrics["test_correct"], @@ -68,3 +125,4 @@ def test(self): test_loss, test_num_sample, ) + diff --git a/python/fedml/simulation/mpi/fedgan/FedGanAPI.py b/python/fedml/simulation/mpi/fedgan/FedGanAPI.py index eda4b5ff75..8f498adb41 100644 --- a/python/fedml/simulation/mpi/fedgan/FedGanAPI.py +++ b/python/fedml/simulation/mpi/fedgan/FedGanAPI.py @@ -8,6 +8,12 @@ def FedML_init(): + """ + Initialize the MPI communication and return necessary information. + + Returns: + tuple: A tuple containing the MPI communication object, process ID, and worker number. 
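
A usage sketch for FedML_init(): launched under, e.g., mpirun -np 4, rank 0 typically becomes the server and ranks 1..3 the clients:

    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    process_id = comm.Get_rank()
    worker_number = comm.Get_size()
    role = "server" if process_id == 0 else "client %d" % (process_id - 1)
    print("rank %d of %d -> %s" % (process_id, worker_number, role))
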
+ """ comm = MPI.COMM_WORLD process_id = comm.Get_rank() worker_number = comm.Get_size() @@ -27,6 +33,21 @@ def FedML_FedGan_distributed( model_trainer=None, preprocessed_sampling_lists=None, ): + """ + Initialize and run the Federated GAN distributed training. + + Args: + args: Configuration arguments. + process_id (int): The process ID of the current worker. + worker_number (int): Total number of workers. + device: Torch device for computation (e.g., 'cuda' or 'cpu'). + comm: MPI communication object. + model: GAN model to be trained. + dataset: Dataset information including training and testing data. + model_trainer: Model trainer object for training and testing. + preprocessed_sampling_lists: Preprocessed client sampling lists. + + """ [ train_data_num, test_data_num, @@ -92,6 +113,26 @@ def init_server( model_trainer, preprocessed_sampling_lists=None, ): + """ + Initialize the server for Federated GAN training. + + Args: + args: Configuration arguments. + device: Torch device for computation (e.g., 'cuda' or 'cpu'). + comm: MPI communication object. + rank (int): Rank of the current process. + size (int): Total number of processes. + model: GAN model to be trained. + train_data_num: Total number of training samples. + train_data_global: Global training dataset. + test_data_global: Global testing dataset. + train_data_local_dict: Dictionary of local training datasets for each worker. + test_data_local_dict: Dictionary of local testing datasets for each worker. + train_data_local_num_dict: Dictionary of the number of local training samples for each worker. + model_trainer: Model trainer object for training and testing. + preprocessed_sampling_lists: Preprocessed client sampling lists. + + """ if model_trainer is None: pass @@ -148,6 +189,23 @@ def init_client( test_data_local_dict, model_trainer=None, ): + """ + Initialize a client for Federated GAN training. + + Args: + args: Configuration arguments. + device: Torch device for computation (e.g., 'cuda' or 'cpu'). + comm: MPI communication object. + process_id (int): The process ID of the current client. + size (int): Total number of processes. + model: GAN model to be trained. + train_data_num: Total number of training samples. + train_data_local_num_dict: Dictionary of the number of local training samples for each worker. + train_data_local_dict: Dictionary of local training datasets for each worker. + test_data_local_dict: Dictionary of local testing datasets for each worker. + model_trainer: Model trainer object for training and testing. + + """ client_index = process_id - 1 model_trainer.set_id(client_index) diff --git a/python/fedml/simulation/mpi/fedgan/FedGanClientManager.py b/python/fedml/simulation/mpi/fedgan/FedGanClientManager.py index df8dcc55bd..ad5385b561 100644 --- a/python/fedml/simulation/mpi/fedgan/FedGanClientManager.py +++ b/python/fedml/simulation/mpi/fedgan/FedGanClientManager.py @@ -7,6 +7,23 @@ class FedGANClientManager(FedMLCommManager): + """ + Manager for Federated GAN client-side operations. + + Args: + args: Configuration arguments. + trainer: Model trainer for local training. + comm: MPI communication object. + rank (int): Rank of the current process. + size (int): Total number of processes. + backend (str): Backend for communication (e.g., 'MPI'). + + Attributes: + trainer: Model trainer for local training. + num_rounds: Number of communication rounds. + args.round_idx: Current communication round index. 
+ """ + def __init__(self, args, trainer, comm=None, rank=0, size=0, backend="MPI"): super().__init__(args, comm, rank, size, backend) self.trainer = trainer @@ -14,9 +31,15 @@ def __init__(self, args, trainer, comm=None, rank=0, size=0, backend="MPI"): self.args.round_idx = 0 def run(self): + """ + Start the client manager's execution. + """ super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for initialization and model updates. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -26,6 +49,12 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): + """ + Handle the initialization message from the server. + + Args: + msg_params (dict): Message parameters containing model parameters and client index. + """ global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -35,10 +64,19 @@ def handle_message_init(self, msg_params): self.__train() def start_training(self): + """ + Start the client training. + """ self.args.round_idx = 0 self.__train() def handle_message_receive_model_from_server(self, msg_params): + """ + Handle the received model update message from the server. + + Args: + msg_params (dict): Message parameters containing model parameters and client index. + """ logging.info("handle_message_receive_model_from_server.") model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -52,6 +90,14 @@ def handle_message_receive_model_from_server(self, msg_params): self.finish() def send_model_to_server(self, receive_id, weights, local_sample_num): + """ + Send the local model to the server. + + Args: + receive_id (int): ID of the server receiving the model. + weights: Model weights to be sent. + local_sample_num: Number of local samples used for training. + """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), @@ -62,6 +108,9 @@ def send_model_to_server(self, receive_id, weights, local_sample_num): self.send_message(message) def __train(self): + """ + Perform the local training and send the updated model to the server. + """ logging.info("#######training########### round_id = %d" % self.args.round_idx) weights, local_sample_num = self.trainer.train(self.args.round_idx) self.send_model_to_server(0, weights, local_sample_num) diff --git a/python/fedml/simulation/mpi/fedgan/FedGanServerManager.py b/python/fedml/simulation/mpi/fedgan/FedGanServerManager.py index 15b8dc7390..5088f24cf1 100644 --- a/python/fedml/simulation/mpi/fedgan/FedGanServerManager.py +++ b/python/fedml/simulation/mpi/fedgan/FedGanServerManager.py @@ -7,6 +7,28 @@ class FedGANServerManager(FedMLCommManager): + """ + Manager for Federated GAN server-side operations. + + Args: + args: Configuration arguments. + aggregator: Aggregator for model updates. + comm: MPI communication object. + rank (int): Rank of the current process. + size (int): Total number of processes. + backend (str): Backend for communication (e.g., 'MPI'). + is_preprocessed (bool): Indicates if client sampling is preprocessed. + preprocessed_client_lists (list): Preprocessed client sampling lists. + + Attributes: + args: Configuration arguments. + aggregator: Aggregator for model updates. + round_num: Number of communication rounds. + args.round_idx: Current communication round index. 
+ is_preprocessed: Indicates if client sampling is preprocessed. + preprocessed_client_lists: Preprocessed client sampling lists. + """ + def __init__( self, args, @@ -27,10 +49,16 @@ def __init__( self.preprocessed_client_lists = preprocessed_client_lists def run(self): + """ + Start the server manager's execution. + """ super().run() def send_init_msg(self): - # sampling clients + """ + Send initialization message to clients, including global model parameters and client indexes. + """ + # Sampling clients client_indexes = self.aggregator.client_sampling( self.args.round_idx, self.args.client_num_in_total, @@ -43,12 +71,21 @@ def send_init_msg(self): ) def register_message_receive_handlers(self): + """ + Register message receive handlers for receiving model updates from clients. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.handle_message_receive_model_from_client, ) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle the received model update message from a client. + + Args: + msg_params (dict): Message parameters containing sender ID, model parameters, and local sample count. + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) @@ -62,7 +99,7 @@ def handle_message_receive_model_from_client(self, msg_params): global_model_params = self.aggregator.aggregate() # self.aggregator.test_on_server_for_all_clients(self.args.round_idx) - # start the next round + # Start the next round self.args.round_idx += 1 if self.args.round_idx == self.round_num: # post_complete_message_to_sweep_process(self.args) @@ -71,12 +108,12 @@ def handle_message_receive_model_from_client(self, msg_params): return if self.is_preprocessed: if self.preprocessed_client_lists is None: - # sampling has already been done in data preprocessor + # Sampling has already been done in data preprocessor client_indexes = [self.args.round_idx] * self.args.client_num_per_round else: client_indexes = self.preprocessed_client_lists[self.args.round_idx] else: - # sampling clients + # Sampling clients client_indexes = self.aggregator.client_sampling( self.args.round_idx, self.args.client_num_in_total, @@ -92,6 +129,14 @@ def handle_message_receive_model_from_client(self, msg_params): ) def send_message_init_config(self, receive_id, global_model_params, client_index): + """ + Send initialization configuration message to a client. + + Args: + receive_id (int): ID of the client receiving the configuration. + global_model_params: Global model parameters. + client_index: Index of the client. + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) @@ -102,6 +147,14 @@ def send_message_init_config(self, receive_id, global_model_params, client_index def send_message_sync_model_to_client( self, receive_id, global_model_params, client_index ): + """ + Send a model synchronization message to a client. + + Args: + receive_id (int): ID of the client receiving the model. + global_model_params: Global model parameters. + client_index: Index of the client. + """ logging.info("send_message_sync_model_to_client. 
receive_id = %d" % receive_id) message = Message( MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, diff --git a/python/fedml/simulation/mpi/fedgan/gan_trainer.py b/python/fedml/simulation/mpi/fedgan/gan_trainer.py index 3ecb788a26..bad160fcb0 100644 --- a/python/fedml/simulation/mpi/fedgan/gan_trainer.py +++ b/python/fedml/simulation/mpi/fedgan/gan_trainer.py @@ -8,22 +8,57 @@ class GANTrainer(ClientTrainer): + """ + Trainer for a Generative Adversarial Network (GAN) client. + + Args: + netd: Discriminator network. + netg: Generator network. + + Attributes: + netg: Generator network. + netd: Discriminator network. + """ + def __init__(self, netd, netg): self.netg = netg self.netd = netd super(GANTrainer, self).__init__(model=None, args=None) def get_model_params(self): + """ + Get the parameters of the generator and discriminator networks. + + Returns: + dict: Dictionary containing the state dictionaries of the generator and discriminator networks. + """ weights_d = self.netd.cpu().state_dict() weights_g = self.netg.cpu().state_dict() weights = {"netg": weights_g, "netd": weights_d} return weights def set_model_params(self, model_parameters): + """ + Set the parameters of the generator and discriminator networks. + + Args: + model_parameters (dict): Dictionary containing the state dictionaries of the generator and discriminator networks. + """ self.netg.load_state_dict(model_parameters["netg"]) self.netd.load_state_dict(model_parameters["netd"]) def train(self, train_data, device, args): + """ + Train the generator and discriminator networks of the GAN. + + Args: + train_data: Training data for the GAN. + device: Device for training (e.g., 'cuda' or 'cpu'). + args: Configuration arguments for training. + + Returns: + None + """ netg = self.netg netd = self.netd @@ -32,7 +67,7 @@ def train(self, train_data, device, args): netd.to(device) netd.train() - criterion = nn.BCELoss() # pylint: disable=E1102 + criterion = nn.BCELoss() # Binary Cross-Entropy Loss optimizer_g = torch.optim.Adam(netg.parameters(), lr=args.lr) optimizer_d = torch.optim.Adam(netd.parameters(), lr=args.lr) @@ -43,29 +78,28 @@ def train(self, train_data, device, args): batch_d_loss = [] batch_g_loss = [] for batch_idx, (x, _) in enumerate(train_data): - # logging.info(batch_idx) - # logging.info(x.shape) if len(x) < 2: continue x = x.to(device) real_labels = torch.ones(x.size(0), 1).to(device) fake_labels = torch.zeros(x.size(0), 1).to(device) optimizer_d.zero_grad() - d_real_loss = criterion(netd(x), real_labels) # pylint: disable=E1102 + d_real_loss = criterion(netd(x), real_labels) noise = torch.randn(x.size(0), 100).to(device) - d_fake_loss = criterion(netd(netg(noise)), fake_labels) # pylint: disable=E1102 + d_fake_loss = criterion(netd(netg(noise)), fake_labels) d_loss = d_real_loss + d_fake_loss d_loss.backward() optimizer_d.step() noise = torch.randn(x.size(0), 100).to(device) optimizer_g.zero_grad() - g_loss = criterion(netd(netg(noise)), real_labels) # pylint: disable=E1102 + g_loss = criterion(netd(netg(noise)), real_labels) g_loss.backward() optimizer_g.step() batch_d_loss.append(d_loss.item()) batch_g_loss.append(g_loss.item()) + if len(batch_g_loss) > 0: epoch_g_loss.append(sum(batch_g_loss) / len(batch_g_loss)) epoch_d_loss.append(sum(batch_d_loss) / len(batch_d_loss)) @@ -81,7 +115,7 @@ def train(self, train_data, device, args): ) netg.eval() z = torch.randn(100, 100).to(device) - y_hat = netg(z).view(100, 28, 28) # (100, 28, 28) + y_hat = netg(z).view(100, 28, 28) result = y_hat.cpu().data.numpy() img = 
np.zeros([280, 280]) for j in range(10): @@ -89,8 +123,20 @@ def train(self, train_data, device, args): [x for x in result[j * 10: (j + 1) * 10]], axis=-1 ) + # Save generated images if needed # imsave("samples/{}_{}.jpg".format(self.id, epoch), img, cmap="gray") netg.train() def test(self, test_data, device, args): - pass + """ + Test the GAN model. + + Args: + test_data: Test data for the GAN. + device: Device for testing (e.g., 'cuda' or 'cpu'). + args: Configuration arguments for testing. + + Returns: + None + """ + pass # Testing is not implemented in this trainer diff --git a/python/fedml/simulation/mpi/fedgan/utils.py b/python/fedml/simulation/mpi/fedgan/utils.py index 195d130aea..e5edfa8ed9 100644 --- a/python/fedml/simulation/mpi/fedgan/utils.py +++ b/python/fedml/simulation/mpi/fedgan/utils.py @@ -5,6 +5,15 @@ def transform_list_to_tensor(model_params_list): + """ + Convert a dictionary of model parameters from NumPy arrays to PyTorch tensors. + + Args: + model_params_list (dict): A dictionary containing model parameters. + + Returns: + dict: A dictionary with model parameters converted to PyTorch tensors. + """ for net in model_params_list.keys(): for k in model_params_list[net].keys(): model_params_list[net][k] = torch.from_numpy( @@ -14,6 +23,15 @@ def transform_list_to_tensor(model_params_list): def transform_tensor_to_list(model_params): + """ + Convert a dictionary of model parameters from PyTorch tensors to NumPy arrays. + + Args: + model_params (dict): A dictionary containing model parameters as PyTorch tensors. + + Returns: + dict: A dictionary with model parameters converted to NumPy arrays. + """ for net in model_params.keys(): for k in model_params[net].keys(): model_params[net][k] = model_params[net][k].detach().numpy().tolist() @@ -21,10 +39,19 @@ def transform_tensor_to_list(model_params): def post_complete_message_to_sweep_process(args): + """ + Post a completion message to a named pipe. + + Args: + args: Information or data to be included in the completion message. + + Returns: + None + """ pipe_path = "./tmp/fedml" if not os.path.exists(pipe_path): os.mkfifo(pipe_path) pipe_fd = os.open(pipe_path, os.O_WRONLY) with os.fdopen(pipe_fd, "w") as pipe: - pipe.write("training is finished! \n%s\n" % (str(args))) + pipe.write("Training is finished! \n%s\n" % (str(args))) diff --git a/python/fedml/simulation/mpi/fedgkt/FedGKTAPI.py b/python/fedml/simulation/mpi/fedgkt/FedGKTAPI.py index 9c4916a337..8d1fb71caa 100644 --- a/python/fedml/simulation/mpi/fedgkt/FedGKTAPI.py +++ b/python/fedml/simulation/mpi/fedgkt/FedGKTAPI.py @@ -7,6 +7,14 @@ def FedML_init(): + """ + Initialize the Federated Learning environment. + + Returns: + comm: The MPI communication object. + process_id: The ID of the current process. + worker_number: The total number of worker processes. + """ comm = MPI.COMM_WORLD process_id = comm.Get_rank() worker_number = comm.Get_size() @@ -22,6 +30,21 @@ def FedML_FedGKT_distributed( dataset, args, ): + """ + Perform Federated Knowledge Transfer (FedGKT) in a distributed setting. + + Args: + process_id: The ID of the current process. + worker_number: The total number of worker processes. + device: The device (e.g., CPU or GPU) for training. + comm: The MPI communication object. + model: A tuple containing client and server models. + dataset: A list containing dataset-related information. + args: Additional arguments and settings. 
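A minimal usage sketch of the two transform helpers above, assuming `trainer` is a GANTrainer instance and that the import path mirrors this file's location in the tree; the nested {"netg": ..., "netd": ...} dict is flattened to plain lists for transport and restored to tensors on the receiving side:

from fedml.simulation.mpi.fedgan.utils import (
    transform_list_to_tensor,
    transform_tensor_to_list,
)

weights = trainer.get_model_params()          # {"netg": ..., "netd": ...} of tensors
payload = transform_tensor_to_list(weights)   # nested lists, safe to serialize
restored = transform_list_to_tensor(payload)  # back to torch tensors
trainer.set_model_params(restored)            # load into netg / netd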
+ + Returns: + None + """ [ train_data_num, test_data_num, @@ -50,12 +73,26 @@ def FedML_FedGKT_distributed( def init_server(args, device, comm, rank, size, model): + """ + Initialize the Federated Knowledge Transfer (FedGKT) server. + + Args: + args: Additional arguments and settings. + device: The device (e.g., CPU or GPU) for training. + comm: The MPI communication object. + rank: The rank of the current process. + size: The total number of processes. + model: The server model for FedGKT. + + Returns: + None + """ # aggregator client_num = size - 1 server_trainer = GKTServerTrainer(client_num, device, model, args) # start the distributed training - server_manager = GKTServerMananger(args, server_trainer, comm, rank, size) + server_manager = GKTServerManager(args, server_trainer, comm, rank, size) server_manager.run() @@ -70,6 +107,23 @@ def init_client( test_data_local_dict, train_data_local_num_dict, ): + """ + Initialize a FedGKT client. + + Args: + args: Additional arguments and settings. + device: The device (e.g., CPU or GPU) for training. + comm: The MPI communication object. + process_id: The ID of the current process. + size: The total number of processes. + model: The client model for FedGKT. + train_data_local_dict: A dictionary of local training data. + test_data_local_dict: A dictionary of local testing data. + train_data_local_num_dict: A dictionary of the number of local training samples. + + Returns: + None + """ client_ID = process_id - 1 # 2. initialize the trainer @@ -84,5 +138,5 @@ def init_client( ) # 3. start the distributed training - client_manager = GKTClientMananger(args, trainer, comm, process_id, size) + client_manager = GKTClientManager(args, trainer, comm, process_id, size) client_manager.run() diff --git a/python/fedml/simulation/mpi/fedgkt/GKTClientManager.py b/python/fedml/simulation/mpi/fedgkt/GKTClientManager.py index befc3a5618..77c76b706d 100644 --- a/python/fedml/simulation/mpi/fedgkt/GKTClientManager.py +++ b/python/fedml/simulation/mpi/fedgkt/GKTClientManager.py @@ -5,18 +5,80 @@ from ....core.distributed.communication.message import Message -class GKTClientMananger(FedMLCommManager): +class GKTClientManager(FedMLCommManager): + """ + A class representing the client-side manager for Global Knowledge Transfer (GKT). + + This manager is responsible for coordinating communication between the client and the server + during the GKT training process. + + Args: + args (argparse.Namespace): Additional arguments and settings. + trainer (GKTClientTrainer): The client-side trainer responsible for training the client model. + comm (MPI.Comm): MPI communication object. + rank (int): The rank or identifier of the client process. + size (int): The total number of processes in the communication group. + backend (str): The MPI backend for communication (default is "MPI"). + + Attributes: + args (argparse.Namespace): Additional arguments and settings. + trainer (GKTClientTrainer): The client-side trainer responsible for training the client model. + num_rounds (int): The total number of communication rounds. + device (torch.device): The device (e.g., GPU) used for training. + args.round_idx (int): The current round index. + + Methods: + run(): Start the client manager to initiate communication and training. + register_message_receive_handlers(): Register message receive handlers for communication. + handle_message_init(msg_params): Handle the initialization message from the server. 
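A hedged sketch of how the three entry points above compose: FedML_init() gives every MPI rank its identity, and by the conventions visible in this file rank 0 hosts the server while ranks 1..N-1 become clients (client_ID = process_id - 1). The names `args`, `device`, `server_model`, `client_model`, and the data dicts are assumed to be prepared by the caller:

comm, process_id, worker_number = FedML_init()
if process_id == 0:
    # rank 0 hosts the GKT server (client_num = worker_number - 1, as above)
    init_server(args, device, comm, process_id, worker_number, server_model)
else:
    # ranks 1..N-1 run clients
    init_client(args, device, comm, process_id, worker_number, client_model,
                train_data_local_dict, test_data_local_dict,
                train_data_local_num_dict)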
+ handle_message_receive_logits_from_server(msg_params): Handle logits received from the server. + send_model_to_server(extracted_feature_dict, logits_dict, labels_dict, extracted_feature_dict_test, labels_dict_test): + Send extracted features, logits, and labels to the server for knowledge transfer. + __train(): Start the client model training process. + + """ def __init__(self, args, trainer, comm=None, rank=0, size=0, backend="MPI"): - super().__init__(args, comm, rank, size, backend) + """ + Initialize the GKT (Global Knowledge Transfer) client manager. + + Args: + args: Additional arguments and settings. + trainer: The GKT client trainer instance. + comm: The MPI communication object. + rank: The rank of the current process. + size: The total number of processes. + backend: The communication backend (default: "MPI"). + Returns: + None + """ + super().__init__(args, comm, rank, size, backend) self.trainer = trainer self.num_rounds = args.comm_round self.args.round_idx = 0 def run(self): + """ + Start the GKT client manager. + + Args: + None + + Returns: + None + """ super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for the GKT client manager. + + Args: + None + + Returns: + None + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -26,11 +88,29 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): + """ + Handle the initialization message from the server. + + Args: + msg_params: Parameters from the received message. + + Returns: + None + """ logging.info("handle_message_init. Rank = " + str(self.rank)) self.args.round_idx = 0 self.__train() def handle_message_receive_logits_from_server(self, msg_params): + """ + Handle the message containing logits from the server. + + Args: + msg_params: Parameters from the received message. + + Returns: + None + """ logging.info( "handle_message_receive_logits_from_server. Rank = " + str(self.rank) ) @@ -50,6 +130,20 @@ def send_model_to_server( extracted_feature_dict_test, labels_dict_test, ): + """ + Send extracted features, logits, and labels to the server. + + Args: + receive_id: The ID of the recipient (usually the server). + extracted_feature_dict: A dictionary of extracted features. + logits_dict: A dictionary of logits. + labels_dict: A dictionary of labels. + extracted_feature_dict_test: A dictionary of extracted features for testing. + labels_dict_test: A dictionary of labels for testing. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_FEATURE_AND_LOGITS, self.get_sender_id(), @@ -65,6 +159,15 @@ def send_model_to_server( self.send_message(message) def __train(self): + """ + Perform the training process for the GKT client. + + Args: + None + + Returns: + None + """ logging.info("#######training########### round_id = %d" % self.args.round_idx) ( extracted_feature_dict, diff --git a/python/fedml/simulation/mpi/fedgkt/GKTClientTrainer.py b/python/fedml/simulation/mpi/fedgkt/GKTClientTrainer.py index 927cb0b63f..fc969119f5 100644 --- a/python/fedml/simulation/mpi/fedgkt/GKTClientTrainer.py +++ b/python/fedml/simulation/mpi/fedgkt/GKTClientTrainer.py @@ -7,6 +7,40 @@ class GKTClientTrainer(object): + """ + A class representing the client-side trainer for Global Knowledge Transfer (GKT). + + This trainer is responsible for training a client model and exchanging knowledge with the server. + + Args: + client_index (int): The index of the client. 
+ local_training_data (list): Local training data for the client. + local_test_data (list): Local test data for the client. + local_sample_number (int): The number of local training samples. + device (torch.device): The device (e.g., GPU) on which the client model is located. + client_model (torch.nn.Module): The client model. + args (argparse.Namespace): Additional arguments and settings. + + Attributes: + client_index (int): The index of the client. + local_training_data (list): Local training data for the client. + local_test_data (list): Local test data for the client. + local_sample_number (int): The number of local training samples. + args (argparse.Namespace): Additional arguments and settings. + device (torch.device): The device (e.g., GPU) on which the client model is located. + client_model (torch.nn.Module): The client model. + model_params (iterable): The parameters of the client model. + master_params (iterable): The master parameters of the client model. + optimizer (torch.optim.Optimizer): The optimizer used for training. + criterion_CE (torch.nn.CrossEntropyLoss): The cross-entropy loss criterion. + criterion_KL (KL_Loss): The KL divergence loss criterion for knowledge distillation. + server_logits_dict (dict): A dictionary to store logits received from the server. + + Methods: + get_sample_number(): Get the number of local training samples. + update_large_model_logits(logits): Update the logits received from the server. + train(): Train the client model and return extracted features, logits, and labels for training and test data. + """ def __init__( self, client_index, @@ -17,6 +51,21 @@ def __init__( client_model, args, ): + """ + Initialize the GKT (Global Knowledge Transfer) client trainer. + + Args: + client_index (int): The index of the client. + local_training_data (list): Local training data for the client. + local_test_data (list): Local test data for the client. + local_sample_number (int): The number of local training samples. + device (torch.device): The device (e.g., GPU) on which the client model is located. + client_model (torch.nn.Module): The client model. + args (argparse.Namespace): Additional arguments and settings. + + Returns: + None + """ self.client_index = client_index self.local_training_data = local_training_data[client_index] self.local_test_data = local_test_data[client_index] @@ -60,12 +109,37 @@ def __init__( self.server_logits_dict = dict() def get_sample_number(self): + """ + Get the number of local training samples. + + Returns: + int: The number of local training samples. + """ return self.local_sample_number def update_large_model_logits(self, logits): + """ + Update the logits received from the server. + + Args: + logits (dict): Logits received from the server. + + Returns: + None + """ self.server_logits_dict = logits def train(self): + """ + Train the client model. + + Returns: + dict: Extracted features for training data. + dict: Logits for training data. + dict: Labels for training data. + dict: Extracted features for test data. + dict: Labels for test data. 
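A minimal sketch of the dict layout train() returns, keyed by batch index as the "key: batch_index" comment in the body indicates. The split into a feature extractor call and a logits call is an assumption for illustration (`extract_features` is a hypothetical accessor; the real client model may return both in one forward pass):

def collect_batch_dicts(client_model, loader, device):
    extracted_feature_dict, logits_dict, labels_dict = {}, {}, {}
    for batch_idx, (x, labels) in enumerate(loader):
        x, labels = x.to(device), labels.to(device)
        features = client_model.extract_features(x)   # hypothetical accessor
        logits = client_model(x)
        extracted_feature_dict[batch_idx] = features.detach().cpu().numpy()
        logits_dict[batch_idx] = logits.detach().cpu().numpy()
        labels_dict[batch_idx] = labels.cpu().numpy()
    return extracted_feature_dict, logits_dict, labels_dict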
+ """ # key: batch_index; value: extracted_feature_map extracted_feature_dict = dict() diff --git a/python/fedml/simulation/mpi/fedgkt/GKTServerManager.py b/python/fedml/simulation/mpi/fedgkt/GKTServerManager.py index bad4f0be1c..309771d803 100644 --- a/python/fedml/simulation/mpi/fedgkt/GKTServerManager.py +++ b/python/fedml/simulation/mpi/fedgkt/GKTServerManager.py @@ -5,13 +5,51 @@ class GKTServerMananger(FedMLCommManager): + """ + Manager class for the server in the Global Knowledge Transfer (GKT) framework. + + This class handles communication and coordination between the server and clients in the GKT framework. + + Args: + args: Additional arguments and settings. + server_trainer: The server trainer responsible for aggregating client updates. + comm: MPI communication object. + rank (int): Rank of the server process. + size (int): Total number of processes. + backend (str): Backend used for communication. + + Attributes: + server_trainer: The server trainer instance. + round_num: The total number of communication rounds. + args: Additional arguments and settings. + count: A counter used for tracking communication rounds. + + Methods: + run(): Start the server manager to handle communication with clients. + register_message_receive_handlers(): Register message handlers for message types. + handle_message_receive_feature_and_logits_from_client(msg_params): Handle client messages containing feature maps, logits, and labels. + send_message_init_config(receive_id, global_model_params): Send an initialization message to a client. + send_message_sync_model_to_client(receive_id, global_logits): Send a synchronization message with global logits to a client. + """ def __init__(self, args, server_trainer, comm=None, rank=0, size=0, backend="MPI"): - super().__init__(args, comm, rank, size, backend) + """ + Initialize the GKT (Global Knowledge Transfer) server manager. + + Args: + args: Additional arguments and settings. + server_trainer: The server trainer. + comm: MPI communication object. + rank (int): Rank of the server process. + size (int): Total number of processes. + backend (str): Backend used for communication. + Returns: + None + """ + super().__init__(args, comm, rank, size, backend) self.server_trainer = server_trainer self.round_num = args.comm_round self.args.round_idx = 0 - self.count = 0 def run(self): @@ -27,6 +65,15 @@ def register_message_receive_handlers(self): ) def handle_message_receive_feature_and_logits_from_client(self, msg_params): + """ + Handle the message received from a client containing feature maps, logits, and labels. + + Args: + msg_params: Parameters received in the message. + + Returns: + None + """ logging.info("handle_message_receive_feature_and_logits_from_client") sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) extracted_feature_dict = msg_params.get(MyMessage.MSG_ARG_KEY_FEATURE) @@ -48,7 +95,7 @@ def handle_message_receive_feature_and_logits_from_client(self, msg_params): if b_all_received: self.server_trainer.train(self.args.round_idx) - # start the next round + # Start the next round self.args.round_idx += 1 if self.args.round_idx == self.round_num: self.finish() @@ -59,6 +106,16 @@ def handle_message_receive_feature_and_logits_from_client(self, msg_params): self.send_message_sync_model_to_client(receiver_id, global_logits) def send_message_init_config(self, receive_id, global_model_params): + """ + Send an initialization message to a client. + + Args: + receive_id: ID of the client to receive the message. 
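The upload-gating logic above is a simple barrier: each client upload sets a per-client flag, and only once every flag is True does the server train and advance the round. A self-contained distillation of that pattern, mirroring check_whether_all_receive() including its flag reset on success:

class UploadBarrier:
    def __init__(self, client_num):
        self.flags = {idx: False for idx in range(client_num)}
        self.results = {}

    def add(self, idx, payload):
        self.results[idx] = payload
        self.flags[idx] = True

    def all_received(self):
        # True once every client has reported; flags reset for the next round
        if not all(self.flags.values()):
            return False
        for idx in self.flags:
            self.flags[idx] = False
        return True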
+ global_model_params: Global model parameters. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) @@ -66,6 +123,16 @@ def send_message_init_config(self, receive_id, global_model_params): logging.info("send_message_init_config. Receive_id: " + str(receive_id)) def send_message_sync_model_to_client(self, receive_id, global_logits): + """ + Send a synchronization message with global logits to a client. + + Args: + receive_id: ID of the client to receive the message. + global_logits: Global logits. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_S2C_SYNC_TO_CLIENT, self.get_sender_id(), receive_id ) diff --git a/python/fedml/simulation/mpi/fedgkt/GKTServerTrainer.py b/python/fedml/simulation/mpi/fedgkt/GKTServerTrainer.py index 5bca610f5a..310cb648ba 100644 --- a/python/fedml/simulation/mpi/fedgkt/GKTServerTrainer.py +++ b/python/fedml/simulation/mpi/fedgkt/GKTServerTrainer.py @@ -10,6 +10,15 @@ class GKTServerTrainer(object): + """ + Server-side trainer for Global Knowledge Transfer (GKT) in federated learning. + + Args: + client_num (int): Number of client devices. + device (str): The device on which to perform training (e.g., 'cuda' or 'cpu'). + server_model (nn.Module): The global server model. + args (argparse.Namespace): Command-line arguments and configurations. + """ def __init__(self, client_num, device, server_model, args): self.client_num = client_num self.device = device @@ -97,6 +106,17 @@ def add_local_trained_result( extracted_feature_dict_test, labels_dict_test, ): + """ + Add local training results from a client. + + Args: + index (int): Index of the client. + extracted_feature_dict (dict): Extracted feature maps from the client model. + logits_dict (dict): Logits from the client model. + labels_dict (dict): Labels from the client model. + extracted_feature_dict_test (dict): Extracted feature maps from the client model for testing. + labels_dict_test (dict): Labels from the client model for testing. + """ logging.info("add_model. index = %d" % index) self.client_extracted_feauture_dict[index] = extracted_feature_dict self.client_logits_dict[index] = logits_dict @@ -107,6 +127,12 @@ def add_local_trained_result( self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check whether all client models have uploaded updates. + + Returns: + bool: True if all clients have uploaded updates, False otherwise. + """ for idx in range(self.client_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -115,9 +141,24 @@ def check_whether_all_receive(self): return True def get_global_logits(self, client_index): + """ + Get global logits for a specific client. + + Args: + client_index (int): Index of the client. + + Returns: + dict: Global logits for the client. + """ return self.server_logits_dict[client_index] def train(self, round_idx): + """ + Train the server model using client updates. + + Args: + round_idx (int): Current communication round index. + """ if self.args.sweep == 1: self.sweep(round_idx) else: @@ -127,6 +168,12 @@ def train(self, round_idx): self.do_not_train_on_client(round_idx) def train_and_distill_on_client(self, round_idx): + """ + Train the server model on the client using distillation from client logits. + + Args: + round_idx (int): Current communication round index. 
+ """ if self.args.test: epochs_server, whether_distill_back = self.get_server_epoch_strategy_test() else: @@ -146,21 +193,48 @@ def train_and_distill_on_client(self, round_idx): self.scheduler.step(self.best_acc, epoch=round_idx) def do_not_train_on_client(self, round_idx): + """ + Perform no training on the client model, only evaluation. + + Args: + round_idx (int): Current communication round index. + """ self.train_and_eval(round_idx, 1) self.scheduler.step(self.best_acc, epoch=round_idx) def sweep(self, round_idx): + """ + Perform sweeping training on the client model. + + Args: + round_idx (int): Current communication round index. + """ # train according to the logits from the client self.train_and_eval(round_idx, self.args.epochs_server) self.scheduler.step(self.best_acc, epoch=round_idx) def get_server_epoch_strategy_test(self): + """ + Get the training strategy for server epoch in the test mode. + + Returns: + tuple: Tuple containing the number of epochs (1) and whether to distill back (True). + """ return 1, True # ResNet56 def get_server_epoch_strategy_reset56(self, round_idx): + """ + Get the training strategy for server epoch in the ResNet56 client model. + + Args: + round_idx (int): Current communication round index. + + Returns: + tuple: Tuple containing the number of epochs and whether to distill back (True/False). + """ whether_distill_back = True - # set the training strategy + # set the training strategy based on round index if round_idx < 20: epochs = 20 elif 20 <= round_idx < 30: @@ -183,6 +257,15 @@ def get_server_epoch_strategy_reset56(self, round_idx): # ResNet56-2 def get_server_epoch_strategy_reset56_2(self, round_idx): + """ + Get the training strategy for server epoch in the ResNet56-2 client model. + + Args: + round_idx (int): Current communication round index. + + Returns: + tuple: Tuple containing the number of epochs and whether to distill back (True/False). + """ whether_distill_back = True # set the training strategy epochs = self.args.epochs_server @@ -190,6 +273,15 @@ def get_server_epoch_strategy_reset56_2(self, round_idx): # not increase after 40 epochs def get_server_epoch_strategy2(self, round_idx): + """ + Determine the training strategy (number of epochs and distillation) for the server model. + + Args: + round_idx (int): Current communication round index. + + Returns: + tuple: Tuple containing the number of epochs and whether to distill back (True/False). + """ whether_distill_back = True # set the training strategy if round_idx < 20: @@ -213,6 +305,13 @@ def get_server_epoch_strategy2(self, round_idx): return epochs, whether_distill_back def train_and_eval(self, round_idx, epochs): + """ + Train and evaluate the server model for a specified number of epochs. + + Args: + round_idx (int): Current communication round index. + epochs (int): Number of epochs to train for. + """ for epoch in range(epochs): logging.info( "train_and_eval. round_idx = %d, epoch = %d" % (round_idx, epoch) @@ -295,6 +394,12 @@ def train_and_eval(self, round_idx, epochs): ) def train_large_model_on_the_server(self): + """ + Train the server model using client features and logits. + + Returns: + dict: Dictionary containing training metrics (loss, accuracy). + """ # clear the server side logits for key in self.server_logits_dict.keys(): @@ -371,6 +476,12 @@ def train_large_model_on_the_server(self): return train_metrics def eval_large_model_on_the_server(self): + """ + Evaluate the server model on the test dataset provided by clients. 
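The epoch-strategy methods above implement a piecewise-constant schedule over the round index: more server epochs early, fewer later. A minimal sketch of that shape; only the first breakpoint (20 epochs for round_idx < 20) is visible in this hunk, so the later values here are illustrative placeholders, not the elided originals:

def get_server_epoch_strategy(round_idx):
    whether_distill_back = True
    if round_idx < 20:
        epochs = 20
    elif round_idx < 30:
        epochs = 15   # illustrative placeholder
    else:
        epochs = 10   # illustrative placeholder
    return epochs, whether_distill_back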
+ + Returns: + dict: Dictionary containing test metrics (loss, accuracy). + """ # set model to evaluation mode self.model_global.eval() diff --git a/python/fedml/simulation/mpi/fedgkt/utils.py b/python/fedml/simulation/mpi/fedgkt/utils.py index fe2ae83878..5f7feb88de 100644 --- a/python/fedml/simulation/mpi/fedgkt/utils.py +++ b/python/fedml/simulation/mpi/fedgkt/utils.py @@ -7,6 +7,15 @@ def get_state_dict(file): + """ + Load a PyTorch state dictionary from a file. + + Args: + file (str): The path to the file containing the state dictionary. + + Returns: + dict: The loaded state dictionary. + """ try: pretrain_state_dict = torch.load(file) except AssertionError: @@ -15,8 +24,16 @@ def get_state_dict(file): ) return pretrain_state_dict - def get_flat_params_from(model): + """ + Get a flat tensor containing all the parameters of a PyTorch model. + + Args: + model (nn.Module): The PyTorch model. + + Returns: + torch.Tensor: A 1D tensor containing the flattened parameters. + """ params = [] for param in model.parameters(): params.append(param.data.view(-1)) @@ -24,8 +41,14 @@ def get_flat_params_from(model): flat_params = torch.cat(params) return flat_params - def set_flat_params_to(model, flat_params): + """ + Set the parameters of a PyTorch model using a flat tensor of parameters. + + Args: + model (nn.Module): The PyTorch model. + flat_params (torch.Tensor): A 1D tensor containing the flattened parameters. + """ prev_ind = 0 for param in model.parameters(): flat_size = int(np.prod(list(param.size()))) @@ -35,32 +58,59 @@ def set_flat_params_to(model, flat_params): prev_ind += flat_size + class RunningAverage: - """A simple class that maintains the running average of a quantity + """ + A simple class that maintains the running average of a quantity Example: - ``` - loss_avg = RunningAverage() - loss_avg.update(2) - loss_avg.update(4) - loss_avg() = 3 - ``` + ``` + loss_avg = RunningAverage() + loss_avg.update(2) + loss_avg.update(4) + loss_avg() = 3.0 + ``` + + Attributes: + steps (int): The number of updates made to the running average. + total (float): The cumulative sum of values for the running average. """ def __init__(self): + """ + Initialize a RunningAverage object. + """ self.steps = 0 self.total = 0 def update(self, val): + """Update the running average with a new value. + + Args: + val (float): The new value to update the running average. + """ self.total += val self.steps += 1 def value(self): - return self.total / float(self.steps) + """Get the current value of the running average. + Returns: + float: The current running average value. + """ + return self.total / float(self.steps) def accuracy(output, target, topk=(1,)): - """Computes the precision@k for the specified values of k""" + """Computes the precision@k for the specified values of k. + + Args: + output (torch.Tensor): The model's output tensor. + target (torch.Tensor): The target tensor. + topk (tuple): A tuple of integers specifying the top-k values to compute. + + Returns: + list: A list of accuracy values for each k in topk. + """ maxk = max(topk) batch_size = target.size(0) @@ -76,17 +126,44 @@ def accuracy(output, target, topk=(1,)): class KL_Loss(nn.Module): + """ + Kullback-Leibler (KL) Divergence Loss with Temperature Scaling. + + This class represents the KL divergence loss with an optional temperature + scaling parameter for softening the logits. It is commonly used in knowledge + distillation between a student and a teacher model. 
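A hedged usage sketch combining RunningAverage and accuracy() in an evaluation loop. Note that the class exposes value() rather than __call__, so the docstring's `loss_avg()` form would not run as written; `model`, `criterion`, and `test_loader` are assumed to exist:

from fedml.simulation.mpi.fedgkt.utils import RunningAverage, accuracy

loss_avg, top1_avg = RunningAverage(), RunningAverage()
for images, target in test_loader:            # assumed DataLoader
    output = model(images)
    loss_avg.update(float(criterion(output, target)))
    acc1, acc5 = accuracy(output, target, topk=(1, 5))
    top1_avg.update(float(acc1))
print("loss %.4f  top-1 %.2f" % (loss_avg.value(), top1_avg.value()))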
+ + Args: + temperature (float, optional): The temperature parameter for softening + the logits (default is 1). + + Attributes: + T (float): The temperature parameter for temperature scaling. + + """ + def __init__(self, temperature=1): + """ + Initialize the KL Divergence Loss. + + Args: + temperature (float, optional): The temperature parameter for softening + the logits (default is 1). + + """ super(KL_Loss, self).__init__() self.T = temperature def forward(self, output_batch, teacher_outputs): - # output_batch -> B X num_classes - # teacher_outputs -> B X num_classes + """Compute the KL divergence loss between output_batch and teacher_outputs. - # loss_2 = -torch.sum(torch.sum(torch.mul(F.log_softmax(teacher_outputs,dim=1), F.softmax(teacher_outputs,dim=1)+10**(-7))))/teacher_outputs.size(0) - # print('loss H:',loss_2) + Args: + output_batch (torch.Tensor): The output tensor from the student model. + teacher_outputs (torch.Tensor): The output tensor from the teacher model. + Returns: + torch.Tensor: The computed KL divergence loss. + """ output_batch = F.log_softmax(output_batch / self.T, dim=1) teacher_outputs = F.softmax(teacher_outputs / self.T, dim=1) + 10 ** (-7) @@ -96,20 +173,48 @@ def forward(self, output_batch, teacher_outputs): * nn.KLDivLoss(reduction="batchmean")(output_batch, teacher_outputs) ) - # Same result KL-loss implementation - # loss = T * T * torch.sum(torch.sum(torch.mul(teacher_outputs, torch.log(teacher_outputs) - output_batch)))/teacher_outputs.size(0) return loss + class CE_Loss(nn.Module): + """ + Cross-Entropy Loss with Temperature Scaling. + + This class represents the cross-entropy loss with an optional temperature + scaling parameter for softening the logits. It is commonly used in knowledge + distillation between a student and a teacher model. + + Args: + temperature (float, optional): The temperature parameter for softening + the logits (default is 1). + + Attributes: + T (float): The temperature parameter for temperature scaling. + + """ + def __init__(self, temperature=1): + """ + Initialize the Cross-Entropy (CE) Loss. + + Args: + temperature (float): The temperature parameter for softening the logits (default is 1). + + """ super(CE_Loss, self).__init__() self.T = temperature def forward(self, output_batch, teacher_outputs): - # output_batch -> B X num_classes - # teacher_outputs -> B X num_classes + """Compute the cross-entropy loss between output_batch and teacher_outputs. + + Args: + output_batch (torch.Tensor): The output tensor from the student model. + teacher_outputs (torch.Tensor): The output tensor from the teacher model. + Returns: + torch.Tensor: The computed cross-entropy loss. + """ output_batch = F.log_softmax(output_batch / self.T, dim=1) teacher_outputs = F.softmax(teacher_outputs / self.T, dim=1) @@ -123,28 +228,51 @@ def forward(self, output_batch, teacher_outputs): return loss - def save_dict_to_json(d, json_path): - """Saves dict of floats in json file + """Saves a dictionary of floats in a JSON file. Args: - d: (dict) of float-castable values (np.float, int, float, etc.) - json_path: (string) path to json file + d (dict): A dictionary of float-castable values (np.float, int, float, etc.). + json_path (str): Path to the JSON file where the dictionary will be saved. """ with open(json_path, "w") as f: - # We need to convert the values to float for json (it doesn't accept np.array, np.float, ) - d = {k: v for k, v in d.items()} + # We need to convert the values to float for JSON (it doesn't accept np.array, np.float, etc.) 
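A self-contained numeric check of what KL_Loss.forward computes: both distributions are softened by the temperature T, a small epsilon keeps the teacher probabilities strictly positive, and the T*T rescaling keeps gradient magnitudes comparable across temperatures (standard distillation practice):

import torch
import torch.nn as nn
import torch.nn.functional as F

T = 3.0
student_logits = torch.randn(4, 10)               # B x num_classes
teacher_logits = torch.randn(4, 10)
log_p = F.log_softmax(student_logits / T, dim=1)
q = F.softmax(teacher_logits / T, dim=1) + 1e-7   # epsilon as in forward()
loss = T * T * nn.KLDivLoss(reduction="batchmean")(log_p, q)
print(float(loss))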
+ d = {k: float(v) for k, v in d.items()} json.dump(d, f, indent=4) - # Filter out batch norm parameters and remove them from weight decay - gets us higher accuracy 93.2 -> 93.48 # https://arxiv.org/pdf/1807.11205.pdf def bnwd_optim_params(model, model_params, master_params): + """Split model parameters into two groups for optimization. + + This function separates model parameters into two groups: batch normalization parameters + and remaining parameters. It sets the weight decay for batch normalization parameters to 0. + + Args: + model (nn.Module): The neural network model. + model_params (list): List of model parameters. + master_params (list): List of master parameters. + + Returns: + list: List of dictionaries specifying parameter groups for optimization. + """ bn_params, remaining_params = split_bn_params(model, model_params, master_params) return [{"params": bn_params, "weight_decay": 0}, {"params": remaining_params}] - def split_bn_params(model, model_params, master_params): + """Split model parameters into batch normalization and remaining parameters. + + This function separates model parameters into two groups: batch normalization parameters + and remaining parameters. + + Args: + model (nn.Module): The neural network model. + model_params (list): List of model parameters. + master_params (list): List of master parameters. + + Returns: + tuple: Two lists containing batch normalization parameters and remaining parameters. + """ def get_bn_params(module): if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): return module.parameters() From b3997692cc6af9fd95bb41c624ad60d55d3c2153 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Sat, 9 Sep 2023 13:02:02 +0530 Subject: [PATCH 15/70] python\fedml\simulation\mpi\ fednas fednova fedopt --- .../fedml/simulation/mpi/fednas/FedNASAPI.py | 52 +++++- .../simulation/mpi/fednas/FedNASAggregator.py | 112 ++++++++++++ .../mpi/fednas/FedNASClientManager.py | 50 +++++- .../mpi/fednas/FedNASServerManager.py | 45 ++++- .../simulation/mpi/fednas/FedNASTrainer.py | 123 +++++++++++++ .../simulation/mpi/fednova/FedNovaAPI.py | 49 +++++ .../mpi/fednova/FedNovaAggregator.py | 167 ++++++++++++++++++ .../mpi/fednova/FedNovaClientManager.py | 113 ++++++++---- .../mpi/fednova/FedNovaServerManager.py | 79 ++++++++- .../simulation/mpi/fednova/FedNovaTrainer.py | 78 +++++++- .../my_model_trainer_classification.py | 66 ++++++- python/fedml/simulation/mpi/fednova/utils.py | 35 +++- .../fedml/simulation/mpi/fedopt/FedOptAPI.py | 63 ++++++- .../simulation/mpi/fedopt/FedOptAggregator.py | 138 +++++++++++++-- .../mpi/fedopt/FedOptClientManager.py | 31 ++++ .../mpi/fedopt/FedOptServerManager.py | 44 ++++- .../simulation/mpi/fedopt/FedOptTrainer.py | 48 ++++- python/fedml/simulation/mpi/fedopt/optrepo.py | 25 ++- python/fedml/simulation/mpi/fedopt/utils.py | 29 ++- .../mpi/fedopt_seq/FedOptAggregator.py | 167 ++++++++++++++++++ 20 files changed, 1419 insertions(+), 95 deletions(-) diff --git a/python/fedml/simulation/mpi/fednas/FedNASAPI.py b/python/fedml/simulation/mpi/fednas/FedNASAPI.py index d213473b75..1615f5e43f 100644 --- a/python/fedml/simulation/mpi/fednas/FedNASAPI.py +++ b/python/fedml/simulation/mpi/fednas/FedNASAPI.py @@ -8,12 +8,17 @@ def FedML_init(): + """ + Initialize the Federated Machine Learning environment using MPI (Message Passing Interface). + + Returns: + Tuple: A tuple containing the MPI communicator (`comm`), process ID (`process_id`), and worker number (`worker_number`). 
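A hedged sketch of wiring bnwd_optim_params() into an optimizer. In the simple single-precision case model_params and master_params are the same list (they differ only when fp16 master copies are kept); the model and hyperparameters here are placeholders:

import torch
import torch.nn as nn
from fedml.simulation.mpi.fedgkt.utils import bnwd_optim_params

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
params = list(model.parameters())
groups = bnwd_optim_params(model, params, params)
optimizer = torch.optim.SGD(groups, lr=0.1, momentum=0.9, weight_decay=5e-4)
# The BN group carries weight_decay=0, which overrides the optimizer-level
# default of 5e-4, so batch-norm parameters are exempt from decay.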
+ """ comm = MPI.COMM_WORLD process_id = comm.Get_rank() worker_number = comm.Get_size() return comm, process_id, worker_number - def FedML_FedNAS_distributed( args, process_id, @@ -25,6 +30,20 @@ def FedML_FedNAS_distributed( client_trainer: ClientTrainer = None, server_aggregator: ServerAggregator = None, ): + """ + Initialize and run the Federated NAS (Neural Architecture Search) distributed training process. + + Args: + args: Command-line arguments and configurations. + process_id (int): The process ID of the current worker. + worker_number (int): The total number of workers. + comm: The MPI communicator. + device: The device (e.g., GPU) to run the training on. + dataset: A list containing dataset information. + model: The neural network model. + client_trainer (ClientTrainer, optional): The client trainer instance. + server_aggregator (ServerAggregator, optional): The server aggregator instance. + """ [ train_data_num, test_data_num, @@ -53,10 +72,23 @@ def FedML_FedNAS_distributed( test_data_local_dict, ) - def init_server( args, device, comm, process_id, worker_number, model, train_data_num, train_data_global, test_data_global, ): + """ + Initialize and run the server component of the Federated NAS distributed training. + + Args: + args: Command-line arguments and configurations. + device: The device (e.g., GPU) to run the training on. + comm: The MPI communicator. + process_id (int): The process ID of the current worker. + worker_number (int): The total number of workers. + model: The neural network model. + train_data_num: The number of training data samples. + train_data_global: The global training data. + test_data_global: The global testing data. + """ # aggregator client_num = worker_number - 1 aggregator = FedNASAggregator(train_data_global, test_data_global, train_data_num, client_num, model, device, args,) @@ -65,7 +97,6 @@ def init_server( server_manager = FedNASServerManager(args, comm, process_id, worker_number, aggregator) server_manager.run() - def init_client( args, device, @@ -78,6 +109,21 @@ def init_client( train_data_local, test_data_local, ): + """ + Initialize and run the client component of the Federated NAS distributed training. + + Args: + args: Command-line arguments and configurations. + device: The device (e.g., GPU) to run the training on. + comm: The MPI communicator. + process_id (int): The process ID of the current worker. + worker_number (int): The total number of workers. + model: The neural network model. + train_data_num: The number of training data samples. + local_data_num: The number of local training data samples. + train_data_local: The local training data. + test_data_local: The local testing data. + """ # trainer client_ID = process_id - 1 trainer = FedNASTrainer( diff --git a/python/fedml/simulation/mpi/fednas/FedNASAggregator.py b/python/fedml/simulation/mpi/fednas/FedNASAggregator.py index 988ae59e23..301bb8b7e8 100644 --- a/python/fedml/simulation/mpi/fednas/FedNASAggregator.py +++ b/python/fedml/simulation/mpi/fednas/FedNASAggregator.py @@ -7,6 +7,38 @@ class FedNASAggregator(object): + """ + A class responsible for aggregating model parameters and architectures from multiple clients. + + Args: + train_global (Dataset): The global training dataset. + test_global (Dataset): The global testing dataset. + all_train_data_num (int): The total number of training data samples. + client_num (int): The number of clients participating in federated learning. + model (nn.Module): The neural network model to be aggregated. 
+ device (str): The device (e.g., 'cuda' or 'cpu') on which the model is trained. + args (argparse.Namespace): Command-line arguments and configurations. + + Attributes: + train_global (Dataset): The global training dataset. + test_global (Dataset): The global testing dataset. + all_train_data_num (int): The total number of training data samples. + client_num (int): The number of clients participating in federated learning. + device (str): The device (e.g., 'cuda' or 'cpu') on which the model is trained. + args (argparse.Namespace): Command-line arguments and configurations. + model (nn.Module): The neural network model to be aggregated. + model_dict (dict): A dictionary to store client model parameters. + arch_dict (dict): A dictionary to store client model architectures. + sample_num_dict (dict): A dictionary to store the number of samples from each client. + train_acc_dict (dict): A dictionary to store training accuracy from each client. + train_loss_dict (dict): A dictionary to store training loss from each client. + train_acc_avg (float): The average training accuracy. + test_acc_avg (float): The average testing accuracy. + test_loss_avg (float): The average testing loss. + flag_client_model_uploaded_dict (dict): A dictionary to track whether client models have been uploaded. + best_accuracy (float): The best accuracy achieved during aggregation. + best_accuracy_different_cnn_counts (dict): A dictionary to store the best accuracy with different CNN counts. + """ def __init__( self, train_global, @@ -17,6 +49,19 @@ def __init__( device, args, ): + """ + Initialize a FedNASAggregator object. + + Args: + train_global (Dataset): The global training dataset. + test_global (Dataset): The global testing dataset. + all_train_data_num (int): The total number of training data samples. + client_num (int): The number of clients participating in federated learning. + model (nn.Module): The neural network model to be aggregated. + device (str): The device (e.g., 'cuda' or 'cpu') on which the model is trained. + args (argparse.Namespace): Command-line arguments and configurations. + """ + self.train_global = train_global self.test_global = test_global self.all_train_data_num = all_train_data_num @@ -43,11 +88,29 @@ def __init__( self.wandb_table = wandb.Table(columns=["Epoch", "Searched Architecture"]) def get_model(self): + """ + Get the aggregated model. + + Returns: + nn.Module: The aggregated neural network model. + """ return self.model def add_local_trained_result( self, index, model_params, arch_params, sample_num, train_acc, train_loss ): + """ + Add the results from a locally trained model to the aggregator. + + Args: + index (int): The index of the client. + model_params (dict): The model parameters from the client. + arch_params (dict): The model architecture parameters from the client. + sample_num (int): The number of samples used for training by the client. + train_acc (float): The training accuracy achieved by the client. + train_loss (torch.Tensor): The training loss from the client. + """ + logging.info("add_model. index = %d" % index) self.model_dict[index] = model_params self.arch_dict[index] = arch_params @@ -57,6 +120,12 @@ def add_local_trained_result( self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all client models have been received by the aggregator. + + Returns: + bool: True if all client models have been received, False otherwise. 
+ """ for idx in range(self.client_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -65,6 +134,12 @@ def check_whether_all_receive(self): return True def aggregate(self): + """ + Aggregate model parameters and architectures from multiple clients. + + Returns: + dict: The aggregated model parameters and architectures. + """ averaged_weights = self.__aggregate_weight() self.model.load_state_dict(averaged_weights) if self.args.stage == "search": @@ -75,11 +150,23 @@ def aggregate(self): return averaged_weights def __update_arch(self, alphas): + """ + Update the architecture parameters of the aggregator's model. + + Args: + alphas (list): A list of architecture parameters. + """ logging.info("update_arch. server.") for a_g, model_arch in zip(alphas, self.model.arch_parameters()): model_arch.data.copy_(a_g.data) def __aggregate_weight(self): + """ + Aggregate model weights from multiple clients. + + Returns: + dict: The aggregated model weights. + """ logging.info("################aggregate weights############") start_time = time.time() model_list = [] @@ -104,6 +191,12 @@ def __aggregate_weight(self): return averaged_params def __aggregate_alpha(self): + """ + Calculate and log statistics including training accuracy, training loss, validation accuracy, and validation loss. + + Args: + round_idx (int): The current round index. + """ logging.info("################aggregate alphas############") start_time = time.time() alpha_list = [] @@ -124,6 +217,12 @@ def __aggregate_alpha(self): return averaged_alphas def statistics(self, round_idx): + """ + Calculate and log statistics including training accuracy, training loss, validation accuracy, and validation loss. + + Args: + round_idx (int): The current round index. + """ # train acc train_acc_list = self.train_acc_dict.values() self.train_acc_avg = sum(train_acc_list) / len(train_acc_list) @@ -175,6 +274,12 @@ def statistics(self, round_idx): ) def infer(self, round_idx): + """ + Perform model inference and calculate test accuracy and loss. + + Args: + round_idx (int): The current round index. + """ self.model.eval() self.model.to(self.device) if ( @@ -217,6 +322,13 @@ def infer(self, round_idx): logging.info("server_infer time cost: %d" % (end_time - start_time)) def record_model_global_architecture(self, round_idx): + """ + Record and log the architecture information of the global model, including genotype, CNN count, + and best accuracy for different CNN structures. + + Args: + round_idx (int): The current round index. + """ # save the structure genotype, normal_cnn_count, reduce_cnn_count = self.model.genotype() cnn_count = normal_cnn_count + reduce_cnn_count diff --git a/python/fedml/simulation/mpi/fednas/FedNASClientManager.py b/python/fedml/simulation/mpi/fednas/FedNASClientManager.py index 369e7677fa..11716a4826 100644 --- a/python/fedml/simulation/mpi/fednas/FedNASClientManager.py +++ b/python/fedml/simulation/mpi/fednas/FedNASClientManager.py @@ -7,6 +7,17 @@ class FedNASClientManager(FedMLCommManager): + """ + Manager class for the client in the Federated NAS (Neural Architecture Search) distributed training. + + Args: + args: Command-line arguments and configurations. + comm: The MPI communicator. + rank: The process rank of the current worker. + size: The total number of workers. + trainer: The client trainer instance. 
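A distilled version of the weighted average that __aggregate_weight() computes: each client's state_dict is weighted by its share of the total sample count (classic FedAvg weighting). As in the original, the first client's dict is reused in place as the accumulator:

def aggregate_weights(model_list):
    # model_list: [(local_sample_number, state_dict), ...]
    training_num = sum(n for n, _ in model_list)
    _, averaged_params = model_list[0]
    for k in averaged_params.keys():
        averaged_params[k] = sum(
            (local_n / training_num) * params[k]
            for local_n, params in model_list
        )
    return averaged_params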
+ """ + def __init__(self, args, comm, rank, size, trainer): super().__init__(args, comm, rank, size) @@ -15,9 +26,15 @@ def __init__(self, args, comm, rank, size, trainer): self.args.round_idx = 0 def run(self): + """ + Start the client manager. + """ super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for different message types. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.__handle_msg_client_receive_config ) @@ -27,6 +44,12 @@ def register_message_receive_handlers(self): ) def __handle_msg_client_receive_config(self, msg_params): + """ + Handle the received configuration message from the server. + + Args: + msg_params (dict): The message parameters containing model and architecture information. + """ logging.info("__handle_msg_client_receive_config") global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) arch_params = msg_params.get(MyMessage.MSG_ARG_KEY_ARCH_PARAMS) @@ -35,10 +58,16 @@ def __handle_msg_client_receive_config(self, msg_params): self.trainer.update_arch(arch_params) self.args.round_idx = 0 - # start to train + # Start training self.__train() def __handle_msg_client_receive_model_from_server(self, msg_params): + """ + Handle the received model message from the server. + + Args: + msg_params (dict): The message parameters containing model and architecture information. + """ process_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) arch_params = msg_params.get(MyMessage.MSG_ARG_KEY_ARCH_PARAMS) @@ -54,6 +83,9 @@ def __handle_msg_client_receive_model_from_server(self, msg_params): self.finish() def __train(self): + """ + Perform the local training for the client. + """ logging.info("#######training########### round_id = %d" % self.args.round_idx) start_time = time.time() if self.args.stage == "search": @@ -68,7 +100,7 @@ def __train(self): weights, local_sample_num, train_acc, train_loss = self.trainer.train() alphas = [] train_finished_time = time.time() - # for one epoch, the local searching time cost is: 75s (based on RTX2080Ti) + # For one epoch, the local searching time cost is approximately 75s (based on RTX2080Ti) logging.info( "local searching time cost: %d" % (train_finished_time - start_time) ) @@ -77,7 +109,7 @@ def __train(self): weights, alphas, local_sample_num, train_acc, train_loss ) communication_finished_time = time.time() - # for one epoch, the local communication time cost is: < 1s (based o n RTX2080Ti) + # For one epoch, the local communication time cost is less than 1s (based on RTX2080Ti) logging.info( "local communication time cost: %d" % (communication_finished_time - train_finished_time) @@ -86,10 +118,20 @@ def __train(self): def __send_msg_fedavg_send_model_to_server( self, weights, alphas, local_sample_num, valid_acc, valid_loss ): + """ + Send the model updates and training results to the server. + + Args: + weights: The updated model weights. + alphas: The updated architecture parameters (only in the search stage). + local_sample_num: The number of local training samples. + valid_acc: The local training accuracy. + valid_loss: The local training loss. 
+ """ message = Message(MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.rank, 0) message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num) message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights) message.add_params(MyMessage.MSG_ARG_KEY_ARCH_PARAMS, alphas) message.add_params(MyMessage.MSG_ARG_KEY_LOCAL_TRAINING_ACC, valid_acc) message.add_params(MyMessage.MSG_ARG_KEY_LOCAL_TRAINING_LOSS, valid_loss) - self.send_message(message) + self.send_message(message) \ No newline at end of file diff --git a/python/fedml/simulation/mpi/fednas/FedNASServerManager.py b/python/fedml/simulation/mpi/fednas/FedNASServerManager.py index 9f5b1c94c4..921bf28bae 100644 --- a/python/fedml/simulation/mpi/fednas/FedNASServerManager.py +++ b/python/fedml/simulation/mpi/fednas/FedNASServerManager.py @@ -8,6 +8,17 @@ class FedNASServerManager(FedMLCommManager): + """ + Manager class for the server in the Federated NAS (Neural Architecture Search) distributed training. + + Args: + args: Command-line arguments and configurations. + comm: The MPI communicator. + rank: The process rank of the current worker. + size: The total number of workers. + aggregator: The aggregator for collecting client updates. + """ + def __init__(self, args, comm, rank, size, aggregator): super().__init__(args, comm, rank, size) @@ -17,6 +28,9 @@ def __init__(self, args, comm, rank, size, aggregator): self.aggregator = aggregator def run(self): + """ + Start the server manager. + """ global_model = self.aggregator.get_model() global_model_params = global_model.state_dict() global_arch_params = None @@ -29,6 +43,9 @@ def run(self): super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for different message types. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.__handle_msg_server_receive_model_from_client_opt_send, @@ -37,6 +54,14 @@ def register_message_receive_handlers(self): def __send_initial_config_to_client( self, process_id, global_model_params, global_arch_params ): + """ + Send the initial configuration to a client. + + Args: + process_id: The ID of the target client. + global_model_params: The global model parameters. + global_arch_params: The global architecture parameters (only in the search stage). + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), process_id ) @@ -46,6 +71,12 @@ def __send_initial_config_to_client( self.send_message(message) def __handle_msg_server_receive_model_from_client_opt_send(self, msg_params): + """ + Handle the received model message from a client and optionally send updated models to clients. + + Args: + msg_params (dict): The message parameters containing model and architecture information. 
+ """ process_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) arch_params = msg_params.get(MyMessage.MSG_ARG_KEY_ARCH_PARAMS) @@ -69,15 +100,15 @@ def __handle_msg_server_receive_model_from_client_opt_send(self, msg_params): else: global_model_params = self.aggregator.aggregate() global_arch_params = [] - self.aggregator.infer(self.args.round_idx) # for NAS, it cost 151 seconds + self.aggregator.infer(self.args.round_idx) # For NAS, it takes approximately 151 seconds self.aggregator.statistics(self.args.round_idx) if self.args.stage == "search": self.aggregator.record_model_global_architecture(self.args.round_idx) - # free all teh GPU memory cache + # Free all GPU memory cache torch.cuda.empty_cache() - # start the next round + # Start the next round self.args.round_idx += 1 if self.args.round_idx == self.round_num: self.finish() @@ -91,6 +122,14 @@ def __handle_msg_server_receive_model_from_client_opt_send(self, msg_params): def __send_model_to_client_message( self, process_id, global_model_params, global_arch_params ): + """ + Send the updated model to a client. + + Args: + process_id: The ID of the target client. + global_model_params: The updated global model parameters. + global_arch_params: The updated global architecture parameters (only in the search stage). + """ message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, 0, process_id) message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) message.add_params(MyMessage.MSG_ARG_KEY_ARCH_PARAMS, global_arch_params) diff --git a/python/fedml/simulation/mpi/fednas/FedNASTrainer.py b/python/fedml/simulation/mpi/fednas/FedNASTrainer.py index 837162f24b..296744b2bd 100644 --- a/python/fedml/simulation/mpi/fednas/FedNASTrainer.py +++ b/python/fedml/simulation/mpi/fednas/FedNASTrainer.py @@ -8,6 +8,44 @@ class FedNASTrainer(object): + """ + Federated NAS Trainer for local model training and inference. + + This class is responsible for performing local training and inference on client devices during federated NAS. + + Args: + client_index (int): Index of the client within the federated system. + train_data_local_dict (dict): Dictionary containing local training datasets for each client. + test_data_local_dict (dict): Dictionary containing local test/validation datasets for each client. + train_data_local_num (int): Number of training samples on the local client. + train_data_num (int): Total number of training samples across all clients. + model (nn.Module): The neural network model to be trained. + device: The computing device (e.g., GPU) to perform training and inference. + args: Additional configuration and hyperparameters for training and inference. + + Methods: + update_model(weights): + Update the model's weights with global model weights. + + update_arch(alphas): + Update the model's architecture with global architecture parameters. + + search(): + Perform local architecture search and training. + + train(): + Perform local training. + + local_train(train_queue, valid_queue, model, criterion, optimizer): + Perform local training on a batch of data. + + local_infer(valid_queue, model, criterion): + Perform local inference on a batch of data. + + infer(): + Perform inference using the trained model. + + """ def __init__( self, client_index, @@ -33,16 +71,39 @@ def __init__( self.test_local = test_data_local_dict[client_index] def update_model(self, weights): + """ + Update the model with new weights. 
+ + Args: + weights (dict): The model weights to update. + """ logging.info("update_model. client_index = %d" % self.client_index) self.model.load_state_dict(weights) def update_arch(self, alphas): + """ + Update the model architecture parameters (only used in the search stage). + + Args: + alphas (list): The architecture parameters to update. + """ logging.info("update_arch. client_index = %d" % self.client_index) for a_g, model_arch in zip(alphas, self.model.arch_parameters()): model_arch.data.copy_(a_g.data) # local search def search(self): + """ + Perform local neural architecture search. + + Returns: + tuple: A tuple containing the following elements: + - weights (dict): The updated model weights. + - alphas (list): The updated architecture parameters (only in the search stage). + - local_sample_number (int): The number of local training samples. + - local_avg_train_acc (float): The average training accuracy. + - local_avg_train_loss (float): The average training loss. + """ self.model.to(self.device) self.model.train() @@ -108,6 +169,22 @@ def search(self): def local_search( self, train_queue, valid_queue, model, architect, criterion, optimizer ): + """ + Perform local neural architecture search. + + Args: + train_queue (DataLoader): DataLoader for the training dataset. + valid_queue (DataLoader): DataLoader for the validation dataset. + model (nn.Module): The neural network model. + architect (Architect): The architect responsible for architecture search. + criterion: The loss criterion for optimization. + optimizer: The optimizer for weight updates. + + Returns: + tuple: A tuple containing the following elements: + - top1_accuracy (float): Top-1 accuracy achieved during local search. + - loss (float): Average loss during local search. + """ objs = utils.AvgrageMeter() top1 = utils.AvgrageMeter() top5 = utils.AvgrageMeter() @@ -168,6 +245,16 @@ def local_search( return top1.avg / 100.0, objs.avg / 100.0, loss def train(self): + """ + Perform local training. + + Returns: + tuple: A tuple containing the following elements: + - weights (dict): The updated model weights. + - local_sample_number (int): The number of local training samples. + - local_avg_train_acc (float): The average training accuracy. + - local_avg_train_loss (float): The average training loss. + """ self.model.to(self.device) self.model.train() @@ -213,6 +300,21 @@ def train(self): ) def local_train(self, train_queue, valid_queue, model, criterion, optimizer): + """ + Perform local training. + + Args: + train_queue (DataLoader): DataLoader for the training dataset. + valid_queue (DataLoader): DataLoader for the validation dataset. + model (nn.Module): The neural network model. + criterion: The loss criterion for optimization. + optimizer: The optimizer for weight updates. + + Returns: + tuple: A tuple containing the following elements: + - top1_accuracy (float): Top-1 accuracy achieved during local training. + - loss (float): Average loss during local training. + """ objs = utils.AvgrageMeter() top1 = utils.AvgrageMeter() top5 = utils.AvgrageMeter() @@ -249,6 +351,19 @@ def local_train(self, train_queue, valid_queue, model, criterion, optimizer): return top1.avg, objs.avg, loss def local_infer(self, valid_queue, model, criterion): + """ + Perform local inference. + + Args: + valid_queue (DataLoader): DataLoader for the validation dataset. + model (nn.Module): The neural network model. + criterion: The loss criterion for evaluation. 
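A hedged skeleton of the bilevel update inside local_search: the architect updates the architecture parameters (alphas) on held-out data, then the weight optimizer updates network weights on the training batch, DARTS-style. The architect.step signature and the gradient-clip value are assumptions, not the exact FedNAS API:

import torch.nn as nn

def search_one_epoch(model, architect, criterion, optimizer,
                     train_queue, valid_queue, grad_clip=5.0):
    valid_iter = iter(valid_queue)
    for x_train, y_train in train_queue:
        try:
            x_val, y_val = next(valid_iter)
        except StopIteration:
            valid_iter = iter(valid_queue)
            x_val, y_val = next(valid_iter)
        # 1) update architecture parameters on validation data
        architect.step(x_train, y_train, x_val, y_val)   # assumed signature
        # 2) update network weights on training data
        optimizer.zero_grad()
        loss = criterion(model(x_train), y_train)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()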
+ + Returns: + tuple: A tuple containing the following elements: + - top1_accuracy (float): Top-1 accuracy achieved during local inference. + - loss (float): Average loss during local inference. + """ objs = utils.AvgrageMeter() top1 = utils.AvgrageMeter() top5 = utils.AvgrageMeter() @@ -281,6 +396,14 @@ def local_infer(self, valid_queue, model, criterion): # after searching, infer() function is used to infer the searched architecture def infer(self): + """ + Perform inference using the trained model. + + Returns: + tuple: A tuple containing the following elements: + - test_accuracy (float): Test accuracy achieved using the trained model. + - test_loss (float): Test loss using the trained model. + """ self.model.to(self.device) self.model.eval() diff --git a/python/fedml/simulation/mpi/fednova/FedNovaAPI.py b/python/fedml/simulation/mpi/fednova/FedNovaAPI.py index 16b48a5e17..2dd230e1ef 100644 --- a/python/fedml/simulation/mpi/fednova/FedNovaAPI.py +++ b/python/fedml/simulation/mpi/fednova/FedNovaAPI.py @@ -11,6 +11,20 @@ def FedML_FedNova_distributed( args, process_id, worker_number, comm, device, dataset, model, client_trainer=None, server_aggregator=None ): + """ + Initialize and run the FedNova distributed training process. + + Args: + args: Command-line arguments. + process_id (int): ID of the current process. + worker_number (int): Total number of worker processes. + comm: Communication backend for distributed training. + device: PyTorch device (CPU or GPU) to run computations. + dataset: Dataset information including data loaders and other data-related details. + model: The model used for training. + client_trainer: Client-specific trainer (if applicable). + server_aggregator: Server aggregator for model updates (if provided). + """ [ train_data_num, test_data_num, @@ -72,6 +86,25 @@ def init_server( train_data_local_num_dict, server_aggregator, ): + """ + Initialize the server for FedNova federated learning. + + Args: + args: Command-line arguments. + device: PyTorch device (CPU or GPU) to run computations. + comm: Communication backend for distributed training. + rank (int): Rank of the current process. + size (int): Total number of processes. + model: The model used for training. + train_data_num: Total number of training samples. + train_data_global: Global training dataset. + test_data_global: Global test dataset. + train_data_local_dict: Dictionary of local training datasets for clients. + test_data_local_dict: Dictionary of local test datasets for clients. + train_data_local_num_dict: Dictionary of the number of local training samples for clients. + server_aggregator: Server aggregator for model updates. + """ + if server_aggregator is None: server_aggregator = create_server_aggregator(model, args) server_aggregator.set_id(-1) @@ -111,6 +144,22 @@ def init_client( test_data_local_dict, client_trainer=None, ): + """ + Initialize a client for FedNova federated learning. + + Args: + args: Command-line arguments. + device: PyTorch device (CPU or GPU) to run computations. + comm: Communication backend for distributed training. + process_id (int): ID of the current client process. + size (int): Total number of processes. + model: The model used for training. + train_data_num: Total number of training samples. + train_data_local_num_dict: Dictionary of the number of local training samples for clients. + train_data_local_dict: Dictionary of local training datasets for clients. + test_data_local_dict: Dictionary of local test datasets for clients. 
+ client_trainer: Client-specific trainer (if applicable). + """ client_index = process_id - 1 if client_trainer is None: # client_trainer = create_model_trainer(model, args) diff --git a/python/fedml/simulation/mpi/fednova/FedNovaAggregator.py b/python/fedml/simulation/mpi/fednova/FedNovaAggregator.py index 71fb4743c0..4748409d09 100644 --- a/python/fedml/simulation/mpi/fednova/FedNovaAggregator.py +++ b/python/fedml/simulation/mpi/fednova/FedNovaAggregator.py @@ -13,6 +13,63 @@ class FedNovaAggregator(object): + """ + Federated Nova Aggregator for aggregating local model updates in a federated learning setup. + + This class manages the aggregation of local model updates from multiple clients in a federated learning system + using the Federated Nova (FedNova) approach. + + Args: + train_global (Dataset): Global training dataset. + test_global (Dataset): Global test/validation dataset. + all_train_data_num (int): Total number of training samples across all clients. + train_data_local_dict (dict): Dictionary containing local training datasets for each client. + test_data_local_dict (dict): Dictionary containing local test/validation datasets for each client. + train_data_local_num_dict (dict): Dictionary containing the number of local training samples for each client. + worker_num (int): Number of worker nodes/clients. + device: The computing device (e.g., GPU) to perform aggregation and computations. + args: Additional configuration and hyperparameters. + server_aggregator: Server-side aggregator for aggregation methods. + + Methods: + get_global_model_params(): + Get the global model parameters. + + set_global_model_params(model_parameters): + Set the global model parameters. + + add_local_trained_result(index, local_result): + Add the local training results from a client. + + check_whether_all_receive(): + Check if all clients have uploaded their local models. + + record_client_runtime(worker_id, client_runtimes): + Record client runtime information for scheduling. + + generate_client_schedule(round_idx, client_indexes): + Generate a schedule for client training in the federated round. + + get_average_weight(client_indexes): + Calculate the average weight for client selection. + + fednova_aggregate(params, norm_grads, tau_effs, tau_eff=0): + Perform FedNova aggregation of local model updates. + + aggregate(): + Aggregate local model updates using the FedNova aggregation method. + + client_sampling(round_idx, client_num_in_total, client_num_per_round): + Perform client sampling for a federated round. + + _generate_validation_set(num_samples=10000): + Generate a validation dataset for testing. + + test_on_server_for_all_clients(round_idx): + Test the global model on all clients' datasets. + + """ + def __init__( self, train_global, @@ -26,6 +83,21 @@ def __init__( args, server_aggregator, ): + """ + Initialize the FedNova manager. + + Args: + train_global: Global training dataset. + test_global: Global test dataset. + all_train_data_num: Total number of training samples. + train_data_local_dict: Dictionary containing local training datasets for clients. + test_data_local_dict: Dictionary containing local test datasets for clients. + train_data_local_num_dict: Dictionary containing the number of local training samples for clients. + worker_num: Number of worker nodes (clients). + device: PyTorch device (CPU or GPU) to run computations. + args: Command-line arguments. + server_aggregator: Aggregator for model updates. 
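+
+        Note:
+            Beyond storing these arguments, initialization also sets up the
+            per-client bookkeeping used during a round (result and
+            upload-flag dictionaries, runtime history) and an empty global
+            momentum buffer.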
+ """ self.aggregator = server_aggregator self.args = args @@ -55,18 +127,43 @@ def __init__( self.global_momentum_buffer = dict() def get_global_model_params(self): + """ + Get the global model parameters. + + Returns: + dict: Global model parameters. + """ return self.aggregator.get_model_params() def set_global_model_params(self, model_parameters): + """ + Set the global model parameters. + + Args: + model_parameters (dict): Global model parameters to be set. + """ self.aggregator.set_model_params(model_parameters) def add_local_trained_result(self, index, local_result): + """ + Add the local training result for a client. + + Args: + index (int): Index of the client. + local_result (dict): Local training result. + """ logging.info("add_model. index = %d" % index) self.result_dict[index] = local_result # self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all clients have uploaded their local model updates. + + Returns: + bool: True if all clients have uploaded, False otherwise. + """ logging.debug("worker_num = {}".format(self.worker_num)) for idx in range(self.worker_num): if not self.flag_client_model_uploaded_dict[idx]: @@ -76,10 +173,27 @@ def check_whether_all_receive(self): return True def record_client_runtime(self, worker_id, client_runtimes): + """ + Record client runtime information. + + Args: + worker_id (int): Index of the worker. + client_runtimes (dict): Dictionary containing client runtime information. + """ for client_id, runtime in client_runtimes.items(): self.runtime_history[worker_id][client_id].append(runtime) def generate_client_schedule(self, round_idx, client_indexes): + """ + Generate a schedule for selecting clients in the current round. + + Args: + round_idx (int): Current round index. + client_indexes (list): List of client indexes. + + Returns: + list: List of client schedules for each worker. + """ # self.runtime_history = {} # for i in range(self.worker_num): # self.runtime_history[i] = {} @@ -128,6 +242,15 @@ def generate_client_schedule(self, round_idx, client_indexes): return client_schedule def get_average_weight(self, client_indexes): + """ + Get the average weight for clients based on the number of local samples. + + Args: + client_indexes (list): List of client indexes. + + Returns: + dict: Dictionary mapping client index to average weight. + """ average_weight_dict = {} training_num = 0 for client_index in client_indexes: @@ -138,6 +261,18 @@ def get_average_weight(self, client_indexes): return average_weight_dict def fednova_aggregate(self, params, norm_grads, tau_effs, tau_eff=0): + """ + Perform FedNova aggregation. + + Args: + params (dict): Model parameters to be aggregated. + norm_grads (list): List of normalized gradients from clients. + tau_effs (list): List of effective tau values. + tau_eff (int): Effective tau for aggregation (optional). + + Returns: + dict: Aggregated model parameters. + """ # get tau_eff if tau_eff == 0: tau_eff = sum(tau_effs) @@ -166,6 +301,12 @@ def fednova_aggregate(self, params, norm_grads, tau_effs, tau_eff=0): return params def aggregate(self): + """ + Aggregate model updates from clients. + + Returns: + dict: Aggregated model parameters. + """ start_time = time.time() grad_results = [] t_eff_results = [] @@ -191,6 +332,17 @@ def aggregate(self): return w_global def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Sample clients for the current round. 
+ + Args: + round_idx (int): Current round index. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to sample per round. + + Returns: + list: List of sampled client indexes. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -201,6 +353,15 @@ def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set for testing. + + Args: + num_samples (int): Number of samples to include in the validation set (optional). + + Returns: + DataLoader: DataLoader for the validation set. + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) @@ -211,6 +372,12 @@ def _generate_validation_set(self, num_samples=10000): return self.test_global def test_on_server_for_all_clients(self, round_idx): + """ + Perform testing on the server for all clients. + + Args: + round_idx (int): Current round index. + """ if round_idx % self.args.frequency_of_the_test == 0 or round_idx == self.args.comm_round - 1: logging.info("################test_on_server_for_all_clients : {}".format(round_idx)) train_num_samples = [] diff --git a/python/fedml/simulation/mpi/fednova/FedNovaClientManager.py b/python/fedml/simulation/mpi/fednova/FedNovaClientManager.py index ad2a99b2ed..8ea484c125 100644 --- a/python/fedml/simulation/mpi/fednova/FedNovaClientManager.py +++ b/python/fedml/simulation/mpi/fednova/FedNovaClientManager.py @@ -9,6 +9,29 @@ class FedNovaClientManager(FedMLCommManager): + """ + Manager for the client-side of the FedNova federated learning process. + + Parameters: + args: Command-line arguments. + trainer: Client trainer responsible for local training. + comm: Communication backend for distributed training. + rank (int): Rank of the client process. + size (int): Total number of processes. + backend (str): Communication backend (e.g., "MPI"). + + Methods: + __init__: Initialize the FedNovaClientManager. + run: Start the client manager. + register_message_receive_handlers: Register message receive handlers for handling incoming messages. + handle_message_init: Handle the initialization message received from the server. + start_training: Start the training process. + handle_message_receive_model_from_server: Handle the received model from the server. + send_result_to_server: Send training results to the server. + add_client_model: Add client model parameters to the aggregation. + __train: Perform the training process for the specified clients. + """ + def __init__( self, args, @@ -18,6 +41,17 @@ def __init__( size=0, backend="MPI", ): + """ + Initialize the FedNovaClientManager. + + Args: + args: Command-line arguments. + trainer: Client trainer responsible for local training. + comm: Communication backend for distributed training. + rank (int): Rank of the client process. + size (int): Total number of processes. + backend (str): Communication backend (e.g., "MPI"). + """ super().__init__(args, comm, rank, size, backend) self.trainer = trainer self.num_rounds = args.comm_round @@ -25,9 +59,15 @@ def __init__( self.worker_id = self.rank - 1 def run(self): + """ + Start the client manager. + """ super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for handling incoming messages. 
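+
+        Two handlers are wired up below: one for the server's init-config
+        message and one for the per-round model-sync message.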
+ """ self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -37,9 +77,13 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): - global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) - # client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) + """ + Handle the initialization message received from the server. + Args: + msg_params: Parameters included in the received message. + """ + global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) average_weight_dict = msg_params.get(MyMessage.MSG_ARG_KEY_AVG_WEIGHTS) client_schedule = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_SCHEDULE) client_indexes = client_schedule[self.worker_id] @@ -48,14 +92,19 @@ def handle_message_init(self, msg_params): self.__train(global_model_params, client_indexes, average_weight_dict) def start_training(self): + """ + Start the training process. + """ self.round_idx = 0 - # self.__train() def handle_message_receive_model_from_server(self, msg_params): - logging.info("handle_message_receive_model_from_server.") - global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) - # client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) + """ + Handle the received model from the server. + Args: + msg_params: Parameters included in the received message. + """ + global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) average_weight_dict = msg_params.get(MyMessage.MSG_ARG_KEY_AVG_WEIGHTS) client_schedule = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_SCHEDULE) client_indexes = client_schedule[self.worker_id] @@ -63,40 +112,52 @@ def handle_message_receive_model_from_server(self, msg_params): self.round_idx += 1 self.__train(global_model_params, client_indexes, average_weight_dict) if self.round_idx == self.num_rounds - 1: - # post_complete_message_to_sweep_process(self.args) self.finish() - def send_result_to_server(self, receive_id, weights, client_runtime_info): + """ + Send training results to the server. + + Args: + receive_id: ID of the recipient (e.g., the server). + weights: Model weights or parameters. + client_runtime_info: Information about client runtime. + """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), receive_id, ) message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights) - # message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_RUNTIME_INFO, client_runtime_info) self.send_message(message) - def add_client_model(self, local_agg_model_params, client_index, grad, t_eff, weight=1.0): - # Add params that needed to be reduces from clients - # for name, param in model_params.items(): - # if name not in local_agg_model_params: - # local_agg_model_params[name] = param * weight - # else: - # local_agg_model_params[name] += param * weight - # local_agg_model_params[client_index]["grad"] = grad - # local_agg_model_params[client_index]["t_eff"] = t_eff + """ + Add client model parameters to the aggregation. + + Args: + local_agg_model_params: Local aggregation of model parameters. + client_index: Index or ID of the client. + grad: Gradients computed during training. + t_eff: Efficiency factor. + weight: Weight assigned to the client's contribution. 
+ """ local_agg_model_params.append({ "grad": grad, "t_eff": t_eff, }) - def __train(self, global_model_params, client_indexes, average_weight_dict): + """ + Perform the training process for the specified clients. + + Args: + global_model_params: Global model parameters. + client_indexes: Indexes of the clients to train. + average_weight_dict: Dictionary of average weights for clients. + """ logging.info("#######training########### round_id = %d" % self.round_idx) - # local_agg_model_params = {} local_agg_model_params = [] client_runtime_info = {} for client_index in client_indexes: @@ -105,7 +166,6 @@ def __train(self, global_model_params, client_indexes, average_weight_dict): start_time = time.time() self.trainer.update_model(global_model_params) self.trainer.update_dataset(int(client_index)) - # weights, local_sample_num = self.trainer.train(self.round_idx) loss, grad, t_eff = self.trainer.train(self.round_idx) self.add_client_model(local_agg_model_params, client_index, grad, t_eff, weight=average_weight_dict[client_index]) @@ -116,14 +176,3 @@ def __train(self, global_model_params, client_indexes, average_weight_dict): logging.info("#######training########### End Simulating client_index = %d, consuming time: %f" % \ (client_index, client_runtime)) self.send_result_to_server(0, local_agg_model_params, client_runtime_info) - - - - - - - - - - - diff --git a/python/fedml/simulation/mpi/fednova/FedNovaServerManager.py b/python/fedml/simulation/mpi/fednova/FedNovaServerManager.py index 97d257dbe2..d423d03d99 100644 --- a/python/fedml/simulation/mpi/fednova/FedNovaServerManager.py +++ b/python/fedml/simulation/mpi/fednova/FedNovaServerManager.py @@ -8,6 +8,28 @@ class FedNovaServerManager(FedMLCommManager): + """ + Manager for the server-side of the FedNova federated learning process. + + Methods: + __init__: Initialize the FedNovaServerManager. + run: Start the server manager. + send_init_msg: Send initialization messages to clients. + register_message_receive_handlers: Register message receive handlers for handling incoming messages. + handle_message_receive_model_from_client: Handle the received model from a client. + send_message_init_config: Send initialization configuration message to a client. + send_message_sync_model_to_client: Send model synchronization message to a client. + + Parameters: + args: Command-line arguments. + aggregator: Server aggregator responsible for aggregating client updates. + comm: Communication backend for distributed training. + rank (int): Rank of the server process. + size (int): Total number of processes. + backend (str): Communication backend (e.g., "MPI"). + is_preprocessed (bool): Indicates whether clients have been preprocessed. + preprocessed_client_lists (list): Lists of preprocessed clients for each round. + """ def __init__( self, args, @@ -19,6 +41,19 @@ def __init__( is_preprocessed=False, preprocessed_client_lists=None, ): + """ + Initialize the FedNovaServerManager. + + Args: + args: Command-line arguments. + aggregator: Server aggregator responsible for aggregating client updates. + comm: Communication backend for distributed training. + rank (int): Rank of the server process. + size (int): Total number of processes. + backend (str): Communication backend (e.g., "MPI"). + is_preprocessed (bool): Indicates whether clients have been preprocessed. + preprocessed_client_lists (list): Lists of preprocessed clients for each round. 
+ """ super().__init__(args, comm, rank, size, backend) self.args = args self.aggregator = aggregator @@ -28,12 +63,18 @@ def __init__( self.preprocessed_client_lists = preprocessed_client_lists def run(self): + """ + Start the server manager. + """ super().run() def send_init_msg(self): + """ + Send initialization messages to clients. + """ # sampling clients client_indexes = self.aggregator.client_sampling( self.round_idx, @@ -53,12 +94,18 @@ def send_init_msg(self): ) def register_message_receive_handlers(self): + """ + Register message receive handlers for handling incoming messages. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.handle_message_receive_model_from_client, ) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle the received model from a client. + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) # local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) @@ -112,8 +159,18 @@ def handle_message_receive_model_from_client(self, msg_params): average_weight_dict, client_schedule ) - def send_message_init_config(self, receive_id, global_model_params, - average_weight_dict, client_schedule): + def send_message_init_config( + self, receive_id, global_model_params, average_weight_dict, client_schedule + ): + """ + Send initialization configuration message to a client. + + Args: + receive_id: Receiver's process ID. + global_model_params: Global model parameters. + average_weight_dict: Dictionary of average weights for clients. + client_schedule: Schedule of clients for the current round. + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) @@ -123,8 +180,22 @@ def send_message_init_config(self, receive_id, global_model_params, message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_SCHEDULE, client_schedule) self.send_message(message) - def send_message_sync_model_to_client(self, receive_id, global_model_params, - average_weight_dict, client_schedule): + def send_message_sync_model_to_client( + self, + receive_id, + global_model_params, + average_weight_dict, + client_schedule + ): + """ + Send model synchronization message to a client. + + Args: + receive_id: Receiver's process ID. + global_model_params: Global model parameters. + average_weight_dict: Dictionary of average weights for clients. + client_schedule: Schedule of clients for the current round. + """ logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id) message = Message( MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, diff --git a/python/fedml/simulation/mpi/fednova/FedNovaTrainer.py b/python/fedml/simulation/mpi/fednova/FedNovaTrainer.py index d55e6d9822..420be4e3ac 100644 --- a/python/fedml/simulation/mpi/fednova/FedNovaTrainer.py +++ b/python/fedml/simulation/mpi/fednova/FedNovaTrainer.py @@ -2,6 +2,27 @@ class FedNovaTrainer(object): + """ + Trainer class for FedNova federated learning. + + Methods: + __init__: Initialize the FedNovaTrainer. + update_model: Update the model with global weights. + update_dataset: Update the local dataset for training. + get_lr: Calculate the learning rate for the current round. + train: Train the model on the local dataset. + test: Evaluate the model on the local training and test datasets. + + Parameters: + client_index (int): Index of the client. + train_data_local_dict (dict): Local training dataset for each client. 
+ train_data_local_num_dict (dict): Number of samples in the local training dataset for each client. + test_data_local_dict (dict): Local test dataset for each client. + train_data_num (int): Total number of training samples across all clients. + device: Device (e.g., GPU or CPU) for model training. + args: Command-line arguments. + model_trainer: Trainer for the machine learning model. + """ def __init__( self, client_index, @@ -13,6 +34,19 @@ def __init__( args, model_trainer, ): + """ + Initialize the FedNovaTrainer. + + Args: + client_index (int): Index of the client. + train_data_local_dict (dict): Local training dataset for each client. + train_data_local_num_dict (dict): Number of samples in the local training dataset for each client. + test_data_local_dict (dict): Local test dataset for each client. + train_data_num (int): Total number of training samples across all clients. + device: Device (e.g., GPU or CPU) for model training. + args: Command-line arguments. + model_trainer: Trainer for the machine learning model. + """ self.trainer = model_trainer self.client_index = client_index @@ -29,15 +63,36 @@ def __init__( self.args = args def update_model(self, weights): + """ + Update the model with global weights. + + Args: + weights: Global model weights. + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """ + Update the local dataset for training. + + Args: + client_index (int): Index of the client. + """ self.client_index = client_index self.train_local = self.train_data_local_dict[client_index] self.local_sample_number = self.train_data_local_num_dict[client_index] self.test_local = self.test_data_local_dict[client_index] def get_lr(self, progress): + """ + Calculate the learning rate for the current round. + + Args: + progress (int): Current round index. + + Returns: + float: Learning rate. + """ # This aims to make a float step_size work. if self.args.lr_schedule == "StepLR": exp_num = progress / self.args.lr_step_size @@ -57,18 +112,28 @@ def get_lr(self, progress): return lr def train(self, round_idx=None): + """ + Train the model on the local dataset. + + Args: + round_idx (int): Current round index. + + Returns: + tuple: A tuple containing average loss, normalized gradient, and effective tau. + """ self.args.round_idx = round_idx - # lr = self.get_lr(round_idx) - # self.trainer.train(self.train_local, self.device, self.args, lr=lr) avg_loss, norm_grad, tau_eff = self.trainer.train(self.train_local, self.device, self.args, ratio=self.local_sample_number / self.total_train_num) - # weights = self.trainer.get_model_params() - - # return weights, self.local_sample_number return avg_loss, norm_grad, tau_eff - def test(self): + """ + Evaluate the model on the local training and test datasets. + + Returns: + tuple: A tuple containing training accuracy, training loss, training sample count, + test accuracy, test loss, and test sample count. 
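+
+        Example (illustrative unpacking, following the order documented
+        above):
+
+            >>> (train_acc, train_loss, n_train,
+            ...  test_acc, test_loss, n_test) = trainer.test()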
+ """ # train data train_metrics = self.trainer.test(self.train_local, self.device, self.args) train_tot_correct, train_num_sample, train_loss = ( @@ -93,3 +158,4 @@ def test(self): test_loss, test_num_sample, ) + \ No newline at end of file diff --git a/python/fedml/simulation/mpi/fednova/my_model_trainer_classification.py b/python/fedml/simulation/mpi/fednova/my_model_trainer_classification.py index bf56731b5b..698730fd07 100644 --- a/python/fedml/simulation/mpi/fednova/my_model_trainer_classification.py +++ b/python/fedml/simulation/mpi/fednova/my_model_trainer_classification.py @@ -1,18 +1,55 @@ import torch from torch import nn - from ....core.alg_frame.client_trainer import ClientTrainer import logging - class MyModelTrainer(ClientTrainer): + """ + Custom client trainer for federated learning using PyTorch. + + Methods: + get_model_params: Get the model parameters as a state dictionary. + set_model_params: Set the model parameters from a state dictionary. + train: Train the model on the given training data. + test: Evaluate the model on the given test data. + test_on_the_server: Perform server-side testing (not implemented). + + Parameters: + model: The PyTorch model to be trained. + id (int): The identifier of the client. + """ + def get_model_params(self): + """ + Get the model parameters as a state dictionary. + + Returns: + dict: Model parameters as a state dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters from a state dictionary. + + Args: + model_parameters (dict): Model parameters as a state dictionary. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args, lr=None): + """ + Train the model on the given training data. + + Args: + train_data: Training data for the client. + device: Device (e.g., GPU or CPU) for model training. + args: Command-line arguments for training configuration. + lr (float): Learning rate for optimization (optional). + + Returns: + None + """ model = self.model model.to(device) @@ -44,7 +81,7 @@ def train(self, train_data, device, args, lr=None): loss = criterion(log_probs, labels) # pylint: disable=E1102 loss.backward() - # Uncommet this following line to avoid nan loss + # Uncomment this following line to avoid nan loss # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) optimizer.step() @@ -66,6 +103,17 @@ def train(self, train_data, device, args, lr=None): ) def test(self, test_data, device, args): + """ + Evaluate the model on the given test data. + + Args: + test_data: Test data for the client. + device: Device (e.g., GPU or CPU) for model evaluation. + args: Command-line arguments for evaluation configuration. + + Returns: + dict: Evaluation metrics, including test_correct, test_loss, and test_total. + """ model = self.model model.to(device) @@ -93,4 +141,16 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Perform server-side testing (not implemented). + + Args: + train_data_local_dict: Local training data for all clients. + test_data_local_dict: Local test data for all clients. + device: Device (e.g., GPU or CPU) for testing. + args: Command-line arguments for testing configuration (not used). + + Returns: + bool: Always returns False (not implemented). 
+ """ return False diff --git a/python/fedml/simulation/mpi/fednova/utils.py b/python/fedml/simulation/mpi/fednova/utils.py index aea2449590..f19b0adbab 100644 --- a/python/fedml/simulation/mpi/fednova/utils.py +++ b/python/fedml/simulation/mpi/fednova/utils.py @@ -1,24 +1,47 @@ import os - -import numpy as np import torch - +import numpy as np def transform_list_to_tensor(model_params_list): + """ + Transform a dictionary of model parameters from a list of NumPy arrays to PyTorch tensors. + + Args: + model_params_list (dict): Dictionary of model parameters represented as NumPy arrays. + + Returns: + dict: Dictionary of model parameters with tensors as values. + """ for k in model_params_list.keys(): model_params_list[k] = torch.from_numpy( np.asarray(model_params_list[k]) ).float() return model_params_list - def transform_tensor_to_list(model_params): + """ + Transform a dictionary of model parameters from PyTorch tensors to lists. + + Args: + model_params (dict): Dictionary of model parameters represented as PyTorch tensors. + + Returns: + dict: Dictionary of model parameters with lists as values. + """ for k in model_params.keys(): model_params[k] = model_params[k].detach().numpy().tolist() return model_params - def post_complete_message_to_sweep_process(args): + """ + Post a completion message to a named pipe for communication with another process. + + Args: + args: Additional information or configuration to include in the message. + + Returns: + None + """ pipe_path = "./tmp/fedml" os.system("mkdir ./tmp/; touch ./tmp/fedml") if not os.path.exists(pipe_path): @@ -26,4 +49,4 @@ def post_complete_message_to_sweep_process(args): pipe_fd = os.open(pipe_path, os.O_WRONLY) with os.fdopen(pipe_fd, "w") as pipe: - pipe.write("training is finished! \n%s\n" % (str(args))) + pipe.write("Training is finished! \n%s\n" % (str(args))) diff --git a/python/fedml/simulation/mpi/fedopt/FedOptAPI.py b/python/fedml/simulation/mpi/fedopt/FedOptAPI.py index dd1ec50208..81b48950f4 100644 --- a/python/fedml/simulation/mpi/fedopt/FedOptAPI.py +++ b/python/fedml/simulation/mpi/fedopt/FedOptAPI.py @@ -10,6 +10,11 @@ def FedML_init(): + """Initialize the Federated Learning environment using MPI. + + Returns: + tuple: A tuple containing MPI communication object, process ID, and worker number. + """ comm = MPI.COMM_WORLD process_id = comm.Get_rank() worker_number = comm.Get_size() @@ -24,9 +29,23 @@ def FedML_FedOpt_distributed( device, dataset, model, - client_trainer: ClientTrainer = None, - server_aggregator: ServerAggregator = None, + client_trainer=None, + server_aggregator=None, ): + """Initialize and run the Federated Optimization process. + + Args: + args: A configuration object containing federated optimization parameters. + process_id: The process ID. + worker_number: The total number of workers. + comm: MPI communication object. + device: The device (e.g., CPU or GPU) for training. + dataset: A list containing dataset information. + model: The machine learning model. + client_trainer: An optional client trainer object. + server_aggregator: An optional server aggregator object. + + """ [ train_data_num, test_data_num, @@ -37,6 +56,7 @@ def FedML_FedOpt_distributed( test_data_local_dict, class_num, ] = dataset + if process_id == 0: init_server( args, @@ -84,10 +104,28 @@ def init_server( train_data_local_num_dict, server_aggregator, ): + """Initialize the server-side components for federated optimization. + + Args: + args: A configuration object containing server parameters. 
+ device: The device (e.g., CPU or GPU) for training. + comm: MPI communication object. + rank: The rank of the server process. + size: The total number of processes. + model: The machine learning model. + train_data_num: The number of training data samples. + train_data_global: Global training data. + test_data_global: Global test data. + train_data_local_dict: Dictionary of local training data. + test_data_local_dict: Dictionary of local test data. + train_data_local_num_dict: Dictionary of the number of local training data samples. + server_aggregator: The server aggregator object. + + """ if server_aggregator is None: server_aggregator = create_server_aggregator(model, args) server_aggregator.set_id(-1) - # aggregator + worker_num = size - 1 aggregator = FedOptAggregator( train_data_global, @@ -102,7 +140,7 @@ def init_server( server_aggregator, ) - # start the distributed training + server_manager = FedOptServerManager(args, aggregator, comm, rank, size) server_manager.send_init_msg() server_manager.run() @@ -121,6 +159,22 @@ def init_client( test_data_local_dict, model_trainer=None, ): + """Initialize the client-side components for federated optimization. + + Args: + args: A configuration object containing client parameters. + device: The device (e.g., CPU or GPU) for training. + comm: MPI communication object. + process_id: The process ID. + size: The total number of processes. + model: The machine learning model. + train_data_num: The number of training data samples. + train_data_local_num_dict: Dictionary of the number of local training data samples. + train_data_local_dict: Dictionary of local training data. + test_data_local_dict: Dictionary of local test data. + model_trainer: An optional client trainer object. + + """ client_index = process_id - 1 if model_trainer is None: model_trainer = create_model_trainer(model, args) @@ -129,5 +183,6 @@ def init_client( trainer = FedOptTrainer( client_index, train_data_local_dict, train_data_local_num_dict, train_data_num, device, args, model_trainer, ) + client_manager = FedOptClientManager(args, trainer, comm, process_id, size) client_manager.run() diff --git a/python/fedml/simulation/mpi/fedopt/FedOptAggregator.py b/python/fedml/simulation/mpi/fedopt/FedOptAggregator.py index e86172ec2c..5d589f1c76 100644 --- a/python/fedml/simulation/mpi/fedopt/FedOptAggregator.py +++ b/python/fedml/simulation/mpi/fedopt/FedOptAggregator.py @@ -12,6 +12,39 @@ class FedOptAggregator(object): + """Aggregator for Federated Optimization. + + This class manages the aggregation of model updates from client devices in a federated optimization setting. + + Args: + train_global: The global training dataset. + test_global: The global testing dataset. + all_train_data_num: The total number of training samples across all clients. + train_data_local_dict: A dictionary mapping client indices to their local training datasets. + test_data_local_dict: A dictionary mapping client indices to their local testing datasets. + train_data_local_num_dict: A dictionary mapping client indices to the number of samples in their local training datasets. + worker_num: The number of worker (client) devices. + device: The device (CPU or GPU) to use for model aggregation. + args: An argparse.Namespace object containing various configuration options. + server_aggregator: An optional ServerAggregator object used for model aggregation. + + Attributes: + aggregator: The server aggregator for model aggregation. + args: An argparse.Namespace object containing various configuration options. 
+ train_global: The global training dataset. + test_global: The global testing dataset. + val_global: A subset of the testing dataset used for validation. + all_train_data_num: The total number of training samples across all clients. + train_data_local_dict: A dictionary mapping client indices to their local training datasets. + test_data_local_dict: A dictionary mapping client indices to their local testing datasets. + train_data_local_num_dict: A dictionary mapping client indices to the number of samples in their local training datasets. + worker_num: The number of worker (client) devices. + device: The device (CPU or GPU) to use for model aggregation. + model_dict: A dictionary mapping client indices to their local model updates. + sample_num_dict: A dictionary mapping client indices to the number of samples used for their local updates. + flag_client_model_uploaded_dict: A dictionary tracking whether each client has uploaded its local model update. + opt: The server optimizer used for model aggregation. + """ def __init__( self, train_global, @@ -25,6 +58,20 @@ def __init__( args, server_aggregator, ): + """Initialize the FedOptAggregator. + + Args: + train_global: Global training data. + test_global: Global test data. + all_train_data_num: Total number of training data samples. + train_data_local_dict: Dictionary of local training data. + test_data_local_dict: Dictionary of local test data. + train_data_local_num_dict: Dictionary of the number of local training data samples. + worker_num: Number of worker clients. + device: The device (e.g., CPU or GPU) for training. + args: A configuration object containing aggregator parameters. + server_aggregator: The server aggregator object. + """ self.aggregator = server_aggregator self.args = args @@ -47,6 +94,11 @@ def __init__( self.flag_client_model_uploaded_dict[idx] = False def _instantiate_opt(self): + """Instantiate the optimizer. + + Returns: + torch.optim.Optimizer: The instantiated optimizer. + """ return OptRepo.name2cls(self.args.server_optimizer)( filter(lambda p: p.requires_grad, self.get_model_params()), lr=self.args.server_lr, @@ -54,23 +106,48 @@ def _instantiate_opt(self): ) def get_model_params(self): - # return model parameters in type of generator + """Get model parameters. + + Returns: + generator: Generator of model parameters. + """ return self.aggregator.model.parameters() def get_global_model_params(self): - # return model parameters in type of ordered_dict + """Get global model parameters. + + Returns: + OrderedDict: Global model parameters. + """ return self.aggregator.get_model_params() def set_global_model_params(self, model_parameters): + """Set global model parameters. + + Args: + model_parameters: New global model parameters. + """ self.aggregator.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, sample_num): + """Add locally trained model results. + + Args: + index: Index of the client. + model_params: Model parameters. + sample_num: Number of training samples. + """ logging.info("add_model. index = %d" % index) self.model_dict[index] = model_params self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """Check if all clients have uploaded their models. + + Returns: + bool: True if all clients have uploaded their models, False otherwise. 
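+
+        A typical server-side polling pattern (illustrative):
+
+            >>> if aggregator.check_whether_all_receive():
+            ...     global_params = aggregator.aggregate()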
+ """ for idx in range(self.worker_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -79,6 +156,11 @@ def check_whether_all_receive(self): return True def aggregate(self): + """Aggregate locally trained models. + + Returns: + OrderedDict: Aggregated global model parameters. + """ start_time = time.time() model_list = [] training_num = 0 @@ -89,8 +171,9 @@ def aggregate(self): logging.info("len of self.model_dict[idx] = " + str(len(self.model_dict))) - # logging.info("################aggregate: %d" % len(model_list)) + (num0, averaged_params) = model_list[0] + for k in averaged_params.keys(): for i in range(0, len(model_list)): local_sample_number, local_model_params = model_list[i] @@ -100,14 +183,13 @@ def aggregate(self): else: averaged_params[k] += local_model_params[k] * w - # server optimizer - # save optimizer state + # Server optimizer self.opt.zero_grad() opt_state = self.opt.state_dict() - # set new aggregated grad + self.set_model_global_grads(averaged_params) self.opt = self._instantiate_opt() - # load optimizer state + self.opt.load_state_dict(opt_state) self.opt.step() @@ -116,30 +198,53 @@ def aggregate(self): return self.get_global_model_params() def set_model_global_grads(self, new_state): + """Set global model gradients. + + Args: + new_state: New global model parameters. + """ new_model = copy.deepcopy(self.aggregator.model) new_model.load_state_dict(new_state) with torch.no_grad(): for parameter, new_parameter in zip(self.aggregator.model.parameters(), new_model.parameters()): parameter.grad = parameter.data - new_parameter.data - # because we go to the opposite direction of the gradient + model_state_dict = self.aggregator.model.state_dict() new_model_state_dict = new_model.state_dict() for k in dict(self.aggregator.model.named_parameters()).keys(): new_model_state_dict[k] = model_state_dict[k] - # self.trainer.model.load_state_dict(new_model_state_dict) + self.set_global_model_params(new_model_state_dict) def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """Sample clients for communication. + + Args: + round_idx: The current communication round. + client_num_in_total: Total number of clients. + client_num_per_round: Number of clients to sample per round. + + Returns: + list: List of sampled client indexes. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: num_clients = min(client_num_per_round, client_num_in_total) - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False) logging.info("client_indexes = %s" % str(client_indexes)) return client_indexes def _generate_validation_set(self, num_samples=10000): + """Generate a validation dataset. + + Args: + num_samples: Number of samples in the validation dataset. + + Returns: + DataLoader: DataLoader for the validation dataset. + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) @@ -150,6 +255,11 @@ def _generate_validation_set(self, num_samples=10000): return self.test_global def test_on_server_for_all_clients(self, round_idx): + """Test on the server for all clients. + + Args: + round_idx: The current communication round. 
+ """ if self.aggregator.test_all( self.train_data_local_dict, self.test_data_local_dict, @@ -169,7 +279,7 @@ def test_on_server_for_all_clients(self, round_idx): train_tot_corrects = [] train_losses = [] for client_idx in range(self.args.client_num_in_total): - # train data + # Train data metrics = self.aggregator.test( self.train_data_local_dict[client_idx], self.device, self.args ) @@ -182,7 +292,7 @@ def test_on_server_for_all_clients(self, round_idx): train_num_samples.append(copy.deepcopy(train_num_sample)) train_losses.append(copy.deepcopy(train_loss)) - # test on training dataset + # Test on training dataset train_acc = sum(train_tot_corrects) / sum(train_num_samples) train_loss = sum(train_losses) / sum(train_num_samples) if self.args.enable_wandb: @@ -191,7 +301,7 @@ def test_on_server_for_all_clients(self, round_idx): stats = {"training_acc": train_acc, "training_loss": train_loss} logging.info(stats) - # test data + # Test data test_num_samples = [] test_tot_corrects = [] test_losses = [] @@ -210,7 +320,7 @@ def test_on_server_for_all_clients(self, round_idx): test_num_samples.append(copy.deepcopy(test_num_sample)) test_losses.append(copy.deepcopy(test_loss)) - # test on test dataset + # Test on test dataset test_acc = sum(test_tot_corrects) / sum(test_num_samples) test_loss = sum(test_losses) / sum(test_num_samples) if self.args.enable_wandb: diff --git a/python/fedml/simulation/mpi/fedopt/FedOptClientManager.py b/python/fedml/simulation/mpi/fedopt/FedOptClientManager.py index 63222972ea..bc48cd9040 100644 --- a/python/fedml/simulation/mpi/fedopt/FedOptClientManager.py +++ b/python/fedml/simulation/mpi/fedopt/FedOptClientManager.py @@ -7,6 +7,30 @@ class FedOptClientManager(FedMLCommManager): + """Manages client-side operations for federated optimization. + + This class is responsible for managing client-side operations during federated optimization. + It handles communication with the server, updates model parameters, and performs training rounds. + + Attributes: + args: A configuration object containing client parameters. + trainer: An instance of the federated optimizer trainer. + comm: The communication backend. + rank: The rank of the client in the communication group. + size: The total number of processes in the communication group. + backend: The communication backend (e.g., "MPI"). + + Methods: + run(): Runs the client manager to participate in federated optimization. + register_message_receive_handlers(): Registers message handlers for receiving updates from the server. + handle_message_init(msg_params): Handles initialization messages from the server. + start_training(): Starts the federated training process. + handle_message_receive_model_from_server(msg_params): Handles received model updates from the server. + send_model_to_server(receive_id, weights, local_sample_num): Sends updated model to the server. + __train(): Performs the training process. 
+ + """ + def __init__(self, args, trainer, comm=None, rank=0, size=0, backend="MPI"): super().__init__(args, comm, rank, size, backend) self.trainer = trainer @@ -14,9 +38,11 @@ def __init__(self, args, trainer, comm=None, rank=0, size=0, backend="MPI"): self.args.round_idx = 0 def run(self): + """Runs the client manager to participate in federated optimization.""" super().run() def register_message_receive_handlers(self): + """Registers message handlers for receiving updates from the server.""" self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -26,6 +52,7 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): + """Handles initialization messages from the server.""" global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -35,10 +62,12 @@ def handle_message_init(self, msg_params): self.__train() def start_training(self): + """Starts the federated training process.""" self.args.round_idx = 0 self.__train() def handle_message_receive_model_from_server(self, msg_params): + """Handles received model updates from the server.""" logging.info("handle_message_receive_model_from_server.") model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -52,6 +81,7 @@ def handle_message_receive_model_from_server(self, msg_params): self.finish() def send_model_to_server(self, receive_id, weights, local_sample_num): + """Sends updated model to the server.""" message = Message( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), @@ -62,6 +92,7 @@ def send_model_to_server(self, receive_id, weights, local_sample_num): self.send_message(message) def __train(self): + """Performs the training process.""" logging.info("#######training########### round_id = %d" % self.args.round_idx) weights, local_sample_num = self.trainer.train(self.args.round_idx) self.send_model_to_server(0, weights, local_sample_num) diff --git a/python/fedml/simulation/mpi/fedopt/FedOptServerManager.py b/python/fedml/simulation/mpi/fedopt/FedOptServerManager.py index febbb4ac39..a1fa6e85d6 100644 --- a/python/fedml/simulation/mpi/fedopt/FedOptServerManager.py +++ b/python/fedml/simulation/mpi/fedopt/FedOptServerManager.py @@ -5,8 +5,32 @@ from ....core.distributed.fedml_comm_manager import FedMLCommManager from ....core.distributed.communication.message import Message - class FedOptServerManager(FedMLCommManager): + """Manages the server-side operations for federated optimization. + + This class is responsible for managing the server-side operations during federated optimization. + It handles communication with clients, aggregation of model updates, and coordination of training rounds. + + Attributes: + args: A configuration object containing server parameters. + aggregator: An aggregator for collecting and aggregating model updates from clients. + comm: The communication backend. + rank: The rank of the server in the communication group. + size: The total number of processes in the communication group. + backend: The communication backend (e.g., "MPI"). + is_preprocessed: A boolean flag indicating whether data preprocessing has been applied. + preprocessed_client_lists: A list of preprocessed client data (optional). + + Methods: + run(): Runs the server manager to coordinate federated optimization. 
+ send_init_msg(): Sends initialization messages to clients at the start of each round. + register_message_receive_handlers(): Registers message handlers for receiving updates from clients. + handle_message_receive_model_from_client(msg_params): Handles received model updates from clients. + send_message_init_config(receive_id, global_model_params, client_index): Sends initialization messages to clients. + send_message_sync_model_to_client(receive_id, global_model_params, client_index): Sends updated models to clients. + + """ + def __init__( self, args, @@ -27,10 +51,12 @@ def __init__( self.preprocessed_client_lists = preprocessed_client_lists def run(self): + """Runs the server manager to coordinate federated optimization.""" super().run() def send_init_msg(self): - # sampling clients + """Sends initialization messages to clients at the start of each round.""" + # Sampling clients client_indexes = self.aggregator.client_sampling( self.args.round_idx, self.args.client_num_in_total, @@ -43,12 +69,14 @@ def send_init_msg(self): ) def register_message_receive_handlers(self): + """Registers message handlers for receiving updates from clients.""" self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.handle_message_receive_model_from_client, ) def handle_message_receive_model_from_client(self, msg_params): + """Handles received model updates from clients.""" sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) @@ -62,36 +90,35 @@ def handle_message_receive_model_from_client(self, msg_params): global_model_params = self.aggregator.aggregate() self.aggregator.test_on_server_for_all_clients(self.args.round_idx) - # start the next round + # Start the next round self.args.round_idx += 1 if self.args.round_idx == self.round_num: post_complete_message_to_sweep_process(self.args) self.finish() return - # sampling clients + # Sampling clients if self.is_preprocessed: if self.preprocessed_client_lists is None: - # sampling has already been done in data preprocessor + # Sampling has already been done in data preprocessor client_indexes = [self.args.round_idx] * self.args.client_num_per_round else: client_indexes = self.preprocessed_client_lists[self.args.round_idx] else: - # # sampling clients + # Sampling clients client_indexes = self.aggregator.client_sampling( self.args.round_idx, self.args.client_num_in_total, self.args.client_num_per_round, ) - print("size = %d" % self.size) - for receiver_id in range(1, self.size): self.send_message_sync_model_to_client( receiver_id, global_model_params, client_indexes[receiver_id - 1] ) def send_message_init_config(self, receive_id, global_model_params, client_index): + """Sends initialization messages to clients.""" message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) @@ -102,6 +129,7 @@ def send_message_init_config(self, receive_id, global_model_params, client_index def send_message_sync_model_to_client( self, receive_id, global_model_params, client_index ): + """Sends updated models to clients.""" logging.info("send_message_sync_model_to_client. 
receive_id = %d" % receive_id) message = Message( MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, diff --git a/python/fedml/simulation/mpi/fedopt/FedOptTrainer.py b/python/fedml/simulation/mpi/fedopt/FedOptTrainer.py index 00661f35b0..8f99915857 100644 --- a/python/fedml/simulation/mpi/fedopt/FedOptTrainer.py +++ b/python/fedml/simulation/mpi/fedopt/FedOptTrainer.py @@ -2,6 +2,28 @@ class FedOptTrainer(object): + """Trains a federated optimizer on a client's local data. + + This class is responsible for training a federated optimizer on a client's + local data. It updates the model using the federated optimization technique + and returns the updated model weights. + + Attributes: + trainer: The model trainer used for local training. + client_index: The index of the client. + train_data_local_dict: A dictionary containing local training data. + train_data_local_num_dict: A dictionary containing the number of samples for each client. + all_train_data_num: The total number of training samples across all clients. + device: The device (e.g., CPU or GPU) for training. + args: A configuration object containing training parameters. + + Methods: + update_model(weights): Updates the model with the provided weights. + update_dataset(client_index): Updates the dataset for the given client. + train(round_idx=None): Trains the federated optimizer on the local data. + + """ + def __init__( self, client_index, @@ -17,22 +39,44 @@ def __init__( self.client_index = client_index self.train_data_local_dict = train_data_local_dict self.train_data_local_num_dict = train_data_local_num_dict + self.all_train_data_num = train_data_num - # self.train_local = self.train_data_local_dict[client_index] - # self.local_sample_number = self.train_data_local_num_dict[client_index] + self.device = device self.args = args def update_model(self, weights): + """Update the model with the provided weights. + + Args: + weights: The updated model weights. + + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """Update the dataset for the given client. + + Args: + client_index: The index of the client. + + """ self.client_index = client_index self.train_local = self.train_data_local_dict[client_index] self.local_sample_number = self.train_data_local_num_dict[client_index] def train(self, round_idx=None): + """Train the federated optimizer on the local data. + + Args: + round_idx: The index of the training round (optional). + + Returns: + weights: The updated model weights. + local_sample_number: The number of local training samples. + + """ self.args.round_idx = round_idx self.trainer.train(self.train_local, self.device, self.args) diff --git a/python/fedml/simulation/mpi/fedopt/optrepo.py b/python/fedml/simulation/mpi/fedopt/optrepo.py index 50615227d7..14a7e1e135 100644 --- a/python/fedml/simulation/mpi/fedopt/optrepo.py +++ b/python/fedml/simulation/mpi/fedopt/optrepo.py @@ -1,17 +1,29 @@ import logging from typing import List, Union - import torch - class OptRepo: - """Collects and provides information about the subclasses of torch.optim.Optimizer.""" + """Collects and provides information about the subclasses of torch.optim.Optimizer. + + This class allows you to access and retrieve information about different PyTorch + optimizer classes. + + Attributes: + repo (dict): A dictionary containing optimizer class names as keys and the + corresponding optimizer classes as values. + + Methods: + get_opt_names(): Returns a list of supported optimizer names. 
+ name2cls(name: str): Returns the optimizer class based on its name. + supported_parameters(opt: Union[str, torch.optim.Optimizer]): Returns a list of + __init__ function parameters of an optimizer. + """ repo = {x.__name__.lower(): x for x in torch.optim.Optimizer.__subclasses__()} @classmethod def get_opt_names(cls) -> List[str]: - """Returns a list of supported optimizers. + """Returns a list of supported optimizer names. Returns: List[str]: Names of optimizers. @@ -29,6 +41,9 @@ def name2cls(cls, name: str) -> torch.optim.Optimizer: Returns: torch.optim.Optimizer: The class corresponding to the name. + + Raises: + KeyError: If the provided optimizer name is invalid. """ try: return cls.repo[name.lower()] @@ -39,7 +54,7 @@ def name2cls(cls, name: str) -> torch.optim.Optimizer: @classmethod def supported_parameters(cls, opt: Union[str, torch.optim.Optimizer]) -> List[str]: - """Returns a lost of __init__ function parametrs of an optimizer. + """Returns a list of __init__ function parameters of an optimizer. Args: opt (Union[str, torch.optim.Optimizer]): The name or class of the optimizer. diff --git a/python/fedml/simulation/mpi/fedopt/utils.py b/python/fedml/simulation/mpi/fedopt/utils.py index aea2449590..5bcbb1954a 100644 --- a/python/fedml/simulation/mpi/fedopt/utils.py +++ b/python/fedml/simulation/mpi/fedopt/utils.py @@ -5,6 +5,15 @@ def transform_list_to_tensor(model_params_list): + """ + Transform a dictionary of model parameters from a list of NumPy arrays to PyTorch tensors. + + Args: + model_params_list (dict): Dictionary of model parameters represented as NumPy arrays. + + Returns: + dict: Dictionary of model parameters with tensors as values. + """ for k in model_params_list.keys(): model_params_list[k] = torch.from_numpy( np.asarray(model_params_list[k]) @@ -13,12 +22,30 @@ def transform_list_to_tensor(model_params_list): def transform_tensor_to_list(model_params): + """ + Transform a dictionary of model parameters from PyTorch tensors to lists. + + Args: + model_params (dict): Dictionary of model parameters represented as PyTorch tensors. + + Returns: + dict: Dictionary of model parameters with lists as values. + """ for k in model_params.keys(): model_params[k] = model_params[k].detach().numpy().tolist() return model_params def post_complete_message_to_sweep_process(args): + """ + Post a completion message to a named pipe for communication with another process. + + Args: + args: Additional information or configuration to include in the message. + + Returns: + None + """ pipe_path = "./tmp/fedml" os.system("mkdir ./tmp/; touch ./tmp/fedml") if not os.path.exists(pipe_path): @@ -26,4 +53,4 @@ def post_complete_message_to_sweep_process(args): pipe_fd = os.open(pipe_path, os.O_WRONLY) with os.fdopen(pipe_fd, "w") as pipe: - pipe.write("training is finished! \n%s\n" % (str(args))) + pipe.write("Training is finished! \n%s\n" % (str(args))) diff --git a/python/fedml/simulation/mpi/fedopt_seq/FedOptAggregator.py b/python/fedml/simulation/mpi/fedopt_seq/FedOptAggregator.py index dc91017c3b..70920e7688 100644 --- a/python/fedml/simulation/mpi/fedopt_seq/FedOptAggregator.py +++ b/python/fedml/simulation/mpi/fedopt_seq/FedOptAggregator.py @@ -15,6 +15,42 @@ class FedOptAggregator(object): + """Aggregator for Federated Optimization. + + This class manages the aggregation of model updates from client devices in a federated optimization setting. + + Args: + train_global: The global training dataset. + test_global: The global testing dataset. 
+        all_train_data_num: The total number of training samples across all clients.
+        train_data_local_dict: A dictionary mapping client indices to their local training datasets.
+        test_data_local_dict: A dictionary mapping client indices to their local testing datasets.
+        train_data_local_num_dict: A dictionary mapping client indices to the number of samples in their local training datasets.
+        worker_num: The number of worker (client) devices.
+        device: The device (CPU or GPU) to use for model aggregation.
+        args: An argparse.Namespace object containing various configuration options.
+        server_aggregator: An optional ServerAggregator object used for model aggregation.
+
+    Attributes:
+        aggregator: The server aggregator for model aggregation.
+        args: An argparse.Namespace object containing various configuration options.
+        train_global: The global training dataset.
+        test_global: The global testing dataset.
+        val_global: A subset of the testing dataset used for validation.
+        all_train_data_num: The total number of training samples across all clients.
+        train_data_local_dict: A dictionary mapping client indices to their local training datasets.
+        test_data_local_dict: A dictionary mapping client indices to their local testing datasets.
+        train_data_local_num_dict: A dictionary mapping client indices to the number of samples in their local training datasets.
+        worker_num: The number of worker (client) devices.
+        device: The device (CPU or GPU) to use for model aggregation.
+        model_dict: A dictionary mapping client indices to their local model updates.
+        sample_num_dict: A dictionary mapping client indices to the number of samples used for their local updates.
+        flag_client_model_uploaded_dict: A dictionary tracking whether each client has uploaded its local model update.
+        opt: The server optimizer used for model aggregation.
+        runtime_history: A dictionary to track the runtime history of clients.
+        runtime_avg: A dictionary to track the average runtime of clients.
+    """
+
     def __init__(
         self,
         train_global,
@@ -28,6 +64,11 @@ def __init__(
         args,
         server_aggregator,
     ):
+        """Initialize the FedOptAggregator.
+
+        Args:
+            See the class docstring above for a description of each argument.
+        """
         self.aggregator = server_aggregator
 
         self.args = args
@@ -59,6 +100,12 @@ def __init__(
 
     def _instantiate_opt(self):
+        """
+        Instantiate the server optimizer based on configuration options.
+
+        Returns:
+            torch.optim.Optimizer: The server optimizer.
+        """
         return OptRepo.name2cls(self.args.server_optimizer)(
             filter(lambda p: p.requires_grad, self.get_model_params()),
             lr=self.args.server_lr,
@@ -66,23 +113,55 @@
         )
 
     def get_model_params(self):
+        """
+        Get the model parameters in the form of a generator.
+
+        Returns:
+            generator: A generator of model parameters.
+        """
         # return model parameters in type of generator
         return self.aggregator.model.parameters()
 
     def get_global_model_params(self):
+        """
+        Get the global model parameters as an ordered dictionary.
+
+        Returns:
+            collections.OrderedDict: The global model parameters.
+        """
+
+        # return model parameters in type of ordered_dict
         return self.aggregator.get_model_params()
 
     def set_global_model_params(self, model_parameters):
+        """
+        Set the global model parameters based on a provided dictionary.
+
+        Args:
+            model_parameters (dict): A dictionary containing global model parameters.
+ """ self.aggregator.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params): + """ + Add the local trained model update for a client. + + Args: + index (int): The index of the client. + model_params (dict): The model parameters of the local trained model. + """ logging.info("add_model. index = %d" % index) self.model_dict[index] = model_params # self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check whether all clients have uploaded their local model updates. + + Returns: + bool: True if all clients have uploaded their updates, False otherwise. + """ for idx in range(self.worker_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -93,6 +172,16 @@ def check_whether_all_receive(self): def workload_estimate(self, client_indexes, mode="simulate"): + """ + Estimate the workload for selected clients. + + Args: + client_indexes (list): The indices of selected clients. + mode (str): The mode for workload estimation ("simulate" or "real"). + + Returns: + list: Workload estimates for the selected clients. + """ if mode == "simulate": client_samples = [ self.train_data_local_num_dict[client_index] @@ -106,6 +195,16 @@ def workload_estimate(self, client_indexes, mode="simulate"): return workload def memory_estimate(self, client_indexes, mode="simulate"): + """ + Estimate the memory requirements for selected clients. + + Args: + client_indexes (list): The indices of selected clients. + mode (str): The mode for memory estimation ("simulate" or "real"). + + Returns: + numpy.ndarray: Memory estimates for the selected clients. + """ if mode == "simulate": memory = np.ones(self.worker_num) elif mode == "real": @@ -115,6 +214,15 @@ def memory_estimate(self, client_indexes, mode="simulate"): return memory def resource_estimate(self, mode="simulate"): + """ + Estimate the resource requirements for clients. + + Args: + mode (str): The mode for resource estimation ("simulate" or "real"). + + Returns: + numpy.ndarray: Resource estimates for clients. + """ if mode == "simulate": resource = np.ones(self.worker_num) elif mode == "real": @@ -124,6 +232,13 @@ def resource_estimate(self, mode="simulate"): return resource def record_client_runtime(self, worker_id, client_runtimes): + """ + Record the runtime of clients. + + Args: + worker_id (int): The ID of the worker (client). + client_runtimes (dict): A dictionary mapping client IDs to their runtimes. + """ for client_id, runtime in client_runtimes.items(): self.runtime_history[worker_id][client_id].append(runtime) if hasattr(self.args, "runtime_est_mode"): @@ -140,6 +255,15 @@ def record_client_runtime(self, worker_id, client_runtimes): def generate_client_schedule(self, round_idx, client_indexes): + """Generate a schedule of clients for training. + + Args: + round_idx (int): The current communication round index. + client_indexes (list): The indices of selected clients. + + Returns: + list: A schedule of clients for training. + """ # self.runtime_history = {} # for i in range(self.worker_num): # self.runtime_history[i] = {} @@ -195,6 +319,14 @@ def generate_client_schedule(self, round_idx, client_indexes): def get_average_weight(self, client_indexes): + """Calculate the average weight for selected clients. + + Args: + client_indexes (list): The indices of selected clients. + + Returns: + dict: A dictionary mapping client indices to their average weights. 
+ """ average_weight_dict = {} training_num = 0 for client_index in client_indexes: @@ -208,6 +340,12 @@ def get_average_weight(self, client_indexes): def aggregate(self): + """ + Aggregate the model updates from clients. + + Returns: + collections.OrderedDict: The aggregated global model parameters. + """ start_time = time.time() model_list = [] training_num = 0 @@ -246,6 +384,12 @@ def aggregate(self): return self.get_global_model_params() def set_model_global_grads(self, new_state): + """ + Set the global model gradients based on a provided dictionary. + + Args: + new_state (dict): A dictionary containing the new global model gradients. + """ new_model = copy.deepcopy(self.aggregator.model) new_model.load_state_dict(new_state) with torch.no_grad(): @@ -260,6 +404,16 @@ def set_model_global_grads(self, new_state): self.set_global_model_params(new_model_state_dict) def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """Randomly sample a subset of clients for a communication round. + + Args: + round_idx (int): The current communication round index. + client_num_in_total (int): The total number of clients. + client_num_per_round (int): The number of clients to sample for the round. + + Returns: + list: A list of indices representing the selected clients for the round. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -270,6 +424,14 @@ def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): return client_indexes def _generate_validation_set(self, num_samples=10000): + """Generate a subset of the testing dataset for validation. + + Args: + num_samples (int): The number of samples to include in the validation set. + + Returns: + torch.utils.data.DataLoader: A DataLoader containing the validation subset. + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) @@ -280,6 +442,11 @@ def _generate_validation_set(self, num_samples=10000): return self.test_global def test_on_server_for_all_clients(self, round_idx): + """Test the global model on all clients. + + Args: + round_idx (int): The current communication round index. + """ if ( round_idx % self.args.frequency_of_the_test == 0 or round_idx == self.args.comm_round - 1 From 6356bfbf6890fa96d0ccd7081e5f841ef5a53ada Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Sat, 9 Sep 2023 13:14:08 +0530 Subject: [PATCH 16/70] g --- .../simulation/mpi/fedopt_seq/optrepo.py | 7 ++++- .../fedml/simulation/mpi/fedopt_seq/utils.py | 30 +++++++++++++++---- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/python/fedml/simulation/mpi/fedopt_seq/optrepo.py b/python/fedml/simulation/mpi/fedopt_seq/optrepo.py index 50615227d7..6942b78b85 100644 --- a/python/fedml/simulation/mpi/fedopt_seq/optrepo.py +++ b/python/fedml/simulation/mpi/fedopt_seq/optrepo.py @@ -3,6 +3,7 @@ import torch +from typing import List, Union class OptRepo: """Collects and provides information about the subclasses of torch.optim.Optimizer.""" @@ -29,6 +30,9 @@ def name2cls(cls, name: str) -> torch.optim.Optimizer: Returns: torch.optim.Optimizer: The class corresponding to the name. + + Raises: + KeyError: If the provided optimizer name is invalid. 
""" try: return cls.repo[name.lower()] @@ -39,7 +43,7 @@ def name2cls(cls, name: str) -> torch.optim.Optimizer: @classmethod def supported_parameters(cls, opt: Union[str, torch.optim.Optimizer]) -> List[str]: - """Returns a lost of __init__ function parametrs of an optimizer. + """Returns a list of __init__ function parameters of an optimizer. Args: opt (Union[str, torch.optim.Optimizer]): The name or class of the optimizer. @@ -60,4 +64,5 @@ def supported_parameters(cls, opt: Union[str, torch.optim.Optimizer]) -> List[st @classmethod def _update_repo(cls): + """Updates the optimizer repository with the latest subclasses.""" cls.repo = {x.__name__: x for x in torch.optim.Optimizer.__subclasses__()} diff --git a/python/fedml/simulation/mpi/fedopt_seq/utils.py b/python/fedml/simulation/mpi/fedopt_seq/utils.py index aea2449590..8c2ff95a19 100644 --- a/python/fedml/simulation/mpi/fedopt_seq/utils.py +++ b/python/fedml/simulation/mpi/fedopt_seq/utils.py @@ -1,24 +1,44 @@ +import torch +import numpy as np import os -import numpy as np -import torch +def transform_list_to_tensor(model_params_list): + """ + Convert a dictionary of model parameters from NumPy arrays in a list to PyTorch tensors. + Args: + model_params_list (dict): A dictionary of model parameters, where values are lists of NumPy arrays. -def transform_list_to_tensor(model_params_list): + Returns: + dict: A dictionary of model parameters with values as PyTorch tensors. + """ for k in model_params_list.keys(): model_params_list[k] = torch.from_numpy( np.asarray(model_params_list[k]) ).float() return model_params_list - def transform_tensor_to_list(model_params): + """ + Convert a dictionary of model parameters from PyTorch tensors to lists of NumPy arrays. + + Args: + model_params (dict): A dictionary of model parameters, where values are PyTorch tensors. + + Returns: + dict: A dictionary of model parameters with values as lists of NumPy arrays. + """ for k in model_params.keys(): model_params[k] = model_params[k].detach().numpy().tolist() return model_params - def post_complete_message_to_sweep_process(args): + """ + Send a completion message to a sweep process using a named pipe. + + Args: + args (str): A string containing information about the training completion status or other relevant details. + """ pipe_path = "./tmp/fedml" os.system("mkdir ./tmp/; touch ./tmp/fedml") if not os.path.exists(pipe_path): From 73a4289c7fb1c2f707293369066422beaf3ad021 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Sat, 9 Sep 2023 13:17:05 +0530 Subject: [PATCH 17/70] gg --- .../mpi/fedopt_seq/FedOptTrainer.py | 33 +++++++++++++++++-- .../simulation/mpi/fedopt_seq/optrepo.py | 17 +++++++--- 2 files changed, 43 insertions(+), 7 deletions(-) diff --git a/python/fedml/simulation/mpi/fedopt_seq/FedOptTrainer.py b/python/fedml/simulation/mpi/fedopt_seq/FedOptTrainer.py index 00661f35b0..162fbf1ef9 100644 --- a/python/fedml/simulation/mpi/fedopt_seq/FedOptTrainer.py +++ b/python/fedml/simulation/mpi/fedopt_seq/FedOptTrainer.py @@ -2,6 +2,8 @@ class FedOptTrainer(object): + """Trains a federated learning model for a specific client.""" + def __init__( self, client_index, @@ -12,27 +14,54 @@ def __init__( args, model_trainer, ): + """Initialize the FedOptTrainer. + + Args: + client_index (int): The index of the client. + train_data_local_dict (dict): A dictionary mapping client indexes to their local training datasets. + train_data_local_num_dict (dict): A dictionary mapping client indexes to the number of samples in their local datasets. 
+ train_data_num (int): The total number of training samples. + device (str): The device (e.g., 'cuda' or 'cpu') on which to perform training. + args (object): Configuration parameters for training. + model_trainer (object): An instance of the model trainer for this client. + """ self.trainer = model_trainer self.client_index = client_index self.train_data_local_dict = train_data_local_dict self.train_data_local_num_dict = train_data_local_num_dict self.all_train_data_num = train_data_num - # self.train_local = self.train_data_local_dict[client_index] - # self.local_sample_number = self.train_data_local_num_dict[client_index] self.device = device self.args = args def update_model(self, weights): + """Update the model parameters. + + Args: + weights (dict): A dictionary containing the updated model parameters. + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """Update the local dataset for the client. + + Args: + client_index (int): The index of the client whose dataset should be updated. + """ self.client_index = client_index self.train_local = self.train_data_local_dict[client_index] self.local_sample_number = self.train_data_local_num_dict[client_index] def train(self, round_idx=None): + """Train the federated learning model for the client. + + Args: + round_idx (int, optional): The current federated learning round index. Defaults to None. + + Returns: + tuple: A tuple containing the updated model weights and the number of local samples used for training. + """ self.args.round_idx = round_idx self.trainer.train(self.train_local, self.device, self.args) diff --git a/python/fedml/simulation/mpi/fedopt_seq/optrepo.py b/python/fedml/simulation/mpi/fedopt_seq/optrepo.py index 6942b78b85..df6ec80985 100644 --- a/python/fedml/simulation/mpi/fedopt_seq/optrepo.py +++ b/python/fedml/simulation/mpi/fedopt_seq/optrepo.py @@ -6,13 +6,16 @@ from typing import List, Union class OptRepo: - """Collects and provides information about the subclasses of torch.optim.Optimizer.""" + """ + Collects and provides information about the subclasses of torch.optim.Optimizer. + """ repo = {x.__name__.lower(): x for x in torch.optim.Optimizer.__subclasses__()} @classmethod def get_opt_names(cls) -> List[str]: - """Returns a list of supported optimizers. + """ + Returns a list of supported optimizers. Returns: List[str]: Names of optimizers. @@ -23,7 +26,8 @@ def get_opt_names(cls) -> List[str]: @classmethod def name2cls(cls, name: str) -> torch.optim.Optimizer: - """Returns the optimizer class belonging to the name. + """ + Returns the optimizer class belonging to the name. Args: name (str): Name of the optimizer. @@ -43,7 +47,8 @@ def name2cls(cls, name: str) -> torch.optim.Optimizer: @classmethod def supported_parameters(cls, opt: Union[str, torch.optim.Optimizer]) -> List[str]: - """Returns a list of __init__ function parameters of an optimizer. + """ + Returns a list of __init__ function parameters of an optimizer. Args: opt (Union[str, torch.optim.Optimizer]): The name or class of the optimizer. @@ -64,5 +69,7 @@ def supported_parameters(cls, opt: Union[str, torch.optim.Optimizer]) -> List[st @classmethod def _update_repo(cls): - """Updates the optimizer repository with the latest subclasses.""" + """ + Updates the optimizer repository with the latest subclasses. 
+ """ cls.repo = {x.__name__: x for x in torch.optim.Optimizer.__subclasses__()} From 6ae730dbdd13ba4de894fb0be867e3f59b4a1a2a Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Sat, 9 Sep 2023 17:30:30 +0530 Subject: [PATCH 18/70] uodtae --- .../mpi/fedopt_seq/FedOptClientManager.py | 89 ++++++++++++++++- .../simulation/mpi/fedopt_seq/FedOptSeqAPI.py | 63 ++++++++++++ .../mpi/fedopt_seq/FedOptServerManager.py | 77 ++++++++++++++- .../simulation/mpi/fedopt_seq/optrepo.py | 3 +- .../simulation/mpi/fedprox/FedProxAPI.py | 54 ++++++++++ .../mpi/fedprox/FedProxAggregator.py | 99 ++++++++++++++++--- .../mpi/fedprox/FedProxClientManager.py | 48 +++++++++ .../mpi/fedprox/FedProxServerManager.py | 62 +++++++++++- .../simulation/mpi/fedprox/FedProxTrainer.py | 61 +++++++++++- .../simulation/mpi/fedprox/message_define.py | 11 +-- python/fedml/simulation/mpi/fedprox/utils.py | 24 +++++ 11 files changed, 554 insertions(+), 37 deletions(-) diff --git a/python/fedml/simulation/mpi/fedopt_seq/FedOptClientManager.py b/python/fedml/simulation/mpi/fedopt_seq/FedOptClientManager.py index 3ec4cdf370..531352fdf0 100644 --- a/python/fedml/simulation/mpi/fedopt_seq/FedOptClientManager.py +++ b/python/fedml/simulation/mpi/fedopt_seq/FedOptClientManager.py @@ -8,6 +8,24 @@ class FedOptClientManager(FedMLCommManager): + """ + Manager for Federated Optimization Clients. + + Args: + args (object): Arguments for configuration. + trainer (object): Trainer for client-side training. + comm (object, optional): Communication module (default: None). + rank (int, optional): Client's rank (default: 0). + size (int, optional): Number of clients (default: 0). + backend (str, optional): Backend for communication (default: "MPI"). + + Attributes: + trainer (object): Trainer for client-side training. + num_rounds (int): Number of communication rounds. + round_idx (int): Current communication round index. + worker_id (int): Worker's unique identifier within the communication group. + """ + def __init__(self, args, trainer, comm=None, rank=0, size=0, backend="MPI"): super().__init__(args, comm, rank, size, backend) self.trainer = trainer @@ -19,6 +37,9 @@ def run(self): super().run() def register_message_receive_handlers(self): + """ + Register handlers for receiving messages. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -28,6 +49,16 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): + """ + Handle initialization message from the server. + + Args: + msg_params (dict): Message parameters. + + Notes: + This method handles the initialization message from the server, including + model parameters, average weights, and client schedule. + """ global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) # client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -39,10 +70,25 @@ def handle_message_init(self, msg_params): self.__train(global_model_params, client_indexes, average_weight_dict) def start_training(self): + """ + Start the training process for a new round. + """ self.round_idx = 0 def handle_message_receive_model_from_server(self, msg_params): + """ + Handle the received model from the server. + + Args: + msg_params (dict): Message parameters. + + Notes: + This method handles the received model from the server, including model + parameters, average weights, and client schedule. It triggers the training + process and completes communication rounds. 
+ """ logging.info("handle_message_receive_model_from_server.") + global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) # client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -57,28 +103,61 @@ def handle_message_receive_model_from_server(self, msg_params): self.finish() def send_model_to_server(self, receive_id, weights, client_runtime_info): + """ + Send the client's model to the server. + + Args: + receive_id (int): Receiver's ID. + weights (dict): Model parameters. + client_runtime_info (dict): Information about client runtime. + + Notes: + This method constructs and sends a message containing the client's model + and runtime information to the server. + """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), receive_id, ) message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights) - # message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_RUNTIME_INFO, client_runtime_info) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_RUNTIME_INFO, client_runtime_info + ) self.send_message(message) def add_client_model(self, local_agg_model_params, model_params, weight=1.0): - # Add params that needed to be reduces from clients + """ + Add a client's model to the local aggregated model. + + Args: + local_agg_model_params (dict): Local aggregated model parameters. + model_params (dict): Client's model parameters. + weight (float, optional): Weight for the client's model (default: 1.0). + + Notes: + This method adds client model parameters to the local aggregated model. + """ for name, param in model_params.items(): if name not in local_agg_model_params: local_agg_model_params[name] = param * weight else: local_agg_model_params[name] += param * weight - - def __train(self, global_model_params, client_indexes, average_weight_dict): + """ + Train the client's model. + + Args: + global_model_params (dict): Global model parameters. + client_indexes (list): List of client indexes. + average_weight_dict (dict): Dictionary of average weights for clients. + + Notes: + This method simulates client-side training, updating the local aggregated + model with the client's contributions. + """ logging.info("#######training########### round_id = %d" % self.round_idx) local_agg_model_params = {} diff --git a/python/fedml/simulation/mpi/fedopt_seq/FedOptSeqAPI.py b/python/fedml/simulation/mpi/fedopt_seq/FedOptSeqAPI.py index f771a78598..6383688af2 100644 --- a/python/fedml/simulation/mpi/fedopt_seq/FedOptSeqAPI.py +++ b/python/fedml/simulation/mpi/fedopt_seq/FedOptSeqAPI.py @@ -10,6 +10,12 @@ def FedML_init(): + """ + Initialize the Federated Learning environment. + + Returns: + tuple: A tuple containing the MPI communicator, process ID, and worker number. + """ comm = MPI.COMM_WORLD process_id = comm.Get_rank() worker_number = comm.Get_size() @@ -27,6 +33,23 @@ def FedML_FedOptSeq_distributed( client_trainer: ClientTrainer = None, server_aggregator: ServerAggregator = None, ): + """ + Run the Federated Optimization (FedOpt) distributed training. + + Args: + args (object): Arguments for configuration. + process_id (int): Process ID or rank. + worker_number (int): Total number of workers. + comm (object): MPI communicator. + device (object): Device for computation. + dataset (list): List of dataset elements. + model (object): Model for training. + client_trainer (ClientTrainer, optional): Client trainer (default: None). 
+ server_aggregator (ServerAggregator, optional): Server aggregator (default: None). + + Notes: + This function orchestrates the FedOpt distributed training process. + """ [ train_data_num, test_data_num, @@ -84,6 +107,27 @@ def init_server( train_data_local_num_dict, server_aggregator, ): + """ + Initialize the server for FedOpt distributed training. + + Args: + args (object): Arguments for configuration. + device (object): Device for computation. + comm (object): MPI communicator. + rank (int): Server's rank. + size (int): Total number of workers. + model (object): Model for training. + train_data_num (int): Number of training data samples. + train_data_global (object): Global training data. + test_data_global (object): Global test data. + train_data_local_dict (dict): Local training data per client. + test_data_local_dict (dict): Local test data per client. + train_data_local_num_dict (dict): Number of local training data per client. + server_aggregator (ServerAggregator, optional): Server aggregator (default: None). + + Notes: + This function initializes the server and starts distributed training. + """ if server_aggregator is None: server_aggregator = create_server_aggregator(model, args) server_aggregator.set_id(-1) @@ -121,6 +165,25 @@ def init_client( test_data_local_dict, model_trainer=None, ): + """ + Initialize a client for FedOpt distributed training. + + Args: + args (object): Arguments for configuration. + device (object): Device for computation. + comm (object): MPI communicator. + process_id (int): Client's process ID. + size (int): Total number of workers. + model (object): Model for training. + train_data_num (int): Number of training data samples. + train_data_local_num_dict (dict): Number of local training data per client. + train_data_local_dict (dict): Local training data per client. + test_data_local_dict (dict): Local test data per client. + model_trainer (object, optional): Model trainer (default: None). + + Notes: + This function initializes a client and runs the training process. + """ client_index = process_id - 1 if model_trainer is None: model_trainer = create_model_trainer(model, args) diff --git a/python/fedml/simulation/mpi/fedopt_seq/FedOptServerManager.py b/python/fedml/simulation/mpi/fedopt_seq/FedOptServerManager.py index 207fcf37ed..a089018a96 100644 --- a/python/fedml/simulation/mpi/fedopt_seq/FedOptServerManager.py +++ b/python/fedml/simulation/mpi/fedopt_seq/FedOptServerManager.py @@ -8,6 +8,28 @@ class FedOptServerManager(FedMLCommManager): + """ + Manager for the Federated Optimization (FedOpt) Server. + + Args: + args (object): Arguments for configuration. + aggregator (object): Aggregator for Federated Optimization. + comm (object, optional): Communication module (default: None). + rank (int, optional): Server's rank (default: 0). + size (int, optional): Total number of workers (default: 0). + backend (str, optional): Backend for communication (default: "MPI"). + is_preprocessed (bool, optional): Flag indicating preprocessed data (default: False). + preprocessed_client_lists (list, optional): Preprocessed client lists (default: None). + + Attributes: + args (object): Arguments for configuration. + aggregator (object): Aggregator for Federated Optimization. + round_num (int): Number of communication rounds. + round_idx (int): Current communication round index. + is_preprocessed (bool): Flag indicating preprocessed data. + preprocessed_client_lists (list): Preprocessed client lists. 
+ """ + def __init__( self, args, @@ -31,6 +53,13 @@ def run(self): super().run() def send_init_msg(self): + """ + Send initialization messages to clients. + + Notes: + This method initializes and sends configuration messages to clients for the + start of a new communication round. + """ # sampling clients client_indexes = self.aggregator.client_sampling( self.round_idx, @@ -49,12 +78,30 @@ def send_init_msg(self): ) def register_message_receive_handlers(self): + """ + Register handlers for receiving messages. + + Notes: + This method registers message handlers for the server to process incoming + messages from clients. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.handle_message_receive_model_from_client, ) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle the received model from a client. + + Args: + msg_params (dict): Message parameters. + + Notes: + This method handles the received model from a client, records client + runtime information, adds local trained results, and checks whether all + clients have sent their updates to proceed to the next round. + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) # local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) @@ -106,6 +153,19 @@ def handle_message_receive_model_from_client(self, msg_params): def send_message_init_config(self, receive_id, global_model_params, average_weight_dict, client_schedule): + """ + Send initialization configuration message to a client. + + Args: + receive_id (int): Receiver's ID. + global_model_params (dict): Global model parameters. + average_weight_dict (dict): Dictionary of average weights for clients. + client_schedule (list): Schedule of clients for the round. + + Notes: + This method constructs and sends an initialization configuration message to + a client. + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) @@ -117,6 +177,18 @@ def send_message_init_config(self, receive_id, global_model_params, def send_message_sync_model_to_client(self, receive_id, global_model_params, average_weight_dict, client_schedule): + """ + Send model synchronization message to a client. + + Args: + receive_id (int): Receiver's ID. + global_model_params (dict): Global model parameters. + average_weight_dict (dict): Dictionary of average weights for clients. + client_schedule (list): Schedule of clients for the round. + + Notes: + This method constructs and sends a model synchronization message to a client. + """ logging.info("send_message_sync_model_to_client. 
receive_id = %d" % receive_id)
         message = Message(
             MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT,
@@ -127,7 +199,4 @@ def send_message_sync_model_to_client(self, receive_id, global_model_params,
         # message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index))
         message.add_params(MyMessage.MSG_ARG_KEY_AVG_WEIGHTS, average_weight_dict)
         message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_SCHEDULE, client_schedule)
-        self.send_message(message)
-
-
-
+        self.send_message(message)
diff --git a/python/fedml/simulation/mpi/fedopt_seq/optrepo.py b/python/fedml/simulation/mpi/fedopt_seq/optrepo.py
index df6ec80985..a6c07959b3 100644
--- a/python/fedml/simulation/mpi/fedopt_seq/optrepo.py
+++ b/python/fedml/simulation/mpi/fedopt_seq/optrepo.py
@@ -1,3 +1,4 @@
+import torch
 import logging
 
 from typing import List, Union
@@ -72,4 +73,4 @@ def _update_repo(cls):
         """
         Updates the optimizer repository with the latest subclasses.
         """
-        cls.repo = {x.__name__: x for x in torch.optim.Optimizer.__subclasses__()}
+        cls.repo = {x.__name__.lower(): x for x in torch.optim.Optimizer.__subclasses__()}
diff --git a/python/fedml/simulation/mpi/fedprox/FedProxAPI.py b/python/fedml/simulation/mpi/fedprox/FedProxAPI.py
index 4ab1af38da..9be6f03cfe 100644
--- a/python/fedml/simulation/mpi/fedprox/FedProxAPI.py
+++ b/python/fedml/simulation/mpi/fedprox/FedProxAPI.py
@@ -10,6 +10,12 @@
 
 def FedML_init():
+    """
+    Initialize the Federated Machine Learning environment.
+
+    Returns:
+        tuple: A tuple containing the MPI communication object, process ID, and worker number.
+    """
     comm = MPI.COMM_WORLD
     process_id = comm.Get_rank()
     worker_number = comm.Get_size()
@@ -27,6 +33,20 @@ def FedML_FedProx_distributed(
     client_trainer: ClientTrainer = None,
     server_aggregator: ServerAggregator = None,
 ):
+    """
+    Run the Federated Proximal training process.
+
+    Args:
+        args (object): Arguments for configuration.
+        process_id (int): The process ID of the current worker.
+        worker_number (int): The total number of workers.
+        comm (object): Communication object.
+        device (object): Device for computation.
+        dataset (list): List containing dataset information.
+        model (object): Model for training.
+        client_trainer (object): Trainer for client-side training (default: None).
+        server_aggregator (object): Server aggregator for aggregation (default: None).
+    """
     [
         train_data_num,
         test_data_num,
@@ -84,6 +104,24 @@ def init_server(
     train_data_local_num_dict,
     server_aggregator,
 ):
+    """
+    Initialize the server for Federated Proximal training.
+
+    Args:
+        args (object): Arguments for configuration.
+        device (object): Device for computation.
+        comm (object): Communication object.
+        rank (int): Rank of the server.
+        size (int): Total number of participants.
+        model (object): Model for training.
+        train_data_num (int): Number of training data samples.
+        train_data_global (object): Global training data.
+        test_data_global (object): Global testing data.
+        train_data_local_dict (dict): Dictionary of local training data.
+        test_data_local_dict (dict): Dictionary of local testing data.
+        train_data_local_num_dict (dict): Dictionary of local training data sizes.
+        server_aggregator (object): Server aggregator for aggregation.
+    """
     if server_aggregator is None:
         server_aggregator = create_server_aggregator(model, args)
     server_aggregator.set_id(-1)
@@ -123,6 +161,22 @@ def init_client(
     test_data_local_dict,
     model_trainer=None,
 ):
+    """
+    Initialize a client for Federated Proximal training.
+
+    Args:
+        args (object): Arguments for configuration.
+ device (object): Device for computation. + comm (object): Communication object. + process_id (int): Process ID of the client. + size (int): Total number of participants. + model (object): Model for training. + train_data_num (int): Number of training data samples. + train_data_local_num_dict (dict): Dictionary of local training data sizes. + train_data_local_dict (dict): Dictionary of local training data. + test_data_local_dict (dict): Dictionary of local testing data. + model_trainer (object): Trainer for the model (default: None). + """ client_index = process_id - 1 if model_trainer is None: model_trainer = create_model_trainer(model, args) diff --git a/python/fedml/simulation/mpi/fedprox/FedProxAggregator.py b/python/fedml/simulation/mpi/fedprox/FedProxAggregator.py index 026729e264..e5da91e8d4 100644 --- a/python/fedml/simulation/mpi/fedprox/FedProxAggregator.py +++ b/python/fedml/simulation/mpi/fedprox/FedProxAggregator.py @@ -10,6 +10,10 @@ class FedProxAggregator(object): + """ + Aggregator for Federated Proximal training. + """ + def __init__( self, train_global, @@ -23,39 +27,78 @@ def __init__( args, server_aggregator, ): + """ + Initialize the FedProxAggregator. + + Args: + train_global (object): Global training data. + test_global (object): Global testing data. + all_train_data_num (int): Number of training data samples. + train_data_local_dict (dict): Dictionary of local training data. + test_data_local_dict (dict): Dictionary of local testing data. + train_data_local_num_dict (dict): Dictionary of local training data sizes. + worker_num (int): Number of workers. + device (object): Device for computation. + args (object): Arguments for configuration. + server_aggregator (object): Server aggregator for aggregation. + """ self.aggregator = server_aggregator - self.args = args self.train_global = train_global self.test_global = test_global self.val_global = self._generate_validation_set() self.all_train_data_num = all_train_data_num - self.train_data_local_dict = train_data_local_dict self.test_data_local_dict = test_data_local_dict self.train_data_local_num_dict = train_data_local_num_dict - self.worker_num = worker_num self.device = device self.model_dict = dict() self.sample_num_dict = dict() self.flag_client_model_uploaded_dict = dict() + for idx in range(self.worker_num): self.flag_client_model_uploaded_dict[idx] = False def get_global_model_params(self): + """ + Get the global model parameters. + + Returns: + dict: Global model parameters. + """ return self.aggregator.get_model_params() def set_global_model_params(self, model_parameters): + """ + Set the global model parameters. + + Args: + model_parameters (dict): Global model parameters. + """ self.aggregator.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, sample_num): + """ + Add local trained model results to the aggregator. + + Args: + index (int): Index of the client. + model_params (dict): Local model parameters. + sample_num (int): Number of local samples. + """ logging.info("add_model. index = %d" % index) self.model_dict[index] = model_params self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all clients have uploaded their models. + + Returns: + bool: True if all models have been received, False otherwise. 
+ """ for idx in range(self.worker_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -64,6 +107,12 @@ def check_whether_all_receive(self): return True def aggregate(self): + """ + Aggregate local models from clients and calculate the global model. + + Returns: + dict: Averaged global model parameters. + """ start_time = time.time() model_list = [] training_num = 0 @@ -74,7 +123,6 @@ def aggregate(self): logging.info("len of self.model_dict[idx] = " + str(len(self.model_dict))) - # logging.info("################aggregate: %d" % len(model_list)) (num0, averaged_params) = model_list[0] for k in averaged_params.keys(): for i in range(0, len(model_list)): @@ -85,7 +133,7 @@ def aggregate(self): else: averaged_params[k] += local_model_params[k] * w - # update the global model which is cached at the server side + # Update the global model which is cached at the server side self.set_global_model_params(averaged_params) end_time = time.time() @@ -93,16 +141,36 @@ def aggregate(self): return averaged_params def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Randomly select clients for participation in each round of training. + + Args: + round_idx (int): Current training round index. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to select per round. + + Returns: + list: List of client indexes selected for the current training round. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: num_clients = min(client_num_per_round, client_num_in_total) - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) # Ensure consistent client selection for each round client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False) logging.info("client_indexes = %s" % str(client_indexes)) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set for testing. + + Args: + num_samples (int): Number of samples in the validation set. + + Returns: + object: Validation dataset. + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) @@ -113,6 +181,12 @@ def _generate_validation_set(self, num_samples=10000): return self.test_global def test_on_server_for_all_clients(self, round_idx): + """ + Perform testing on the server for all clients. + + Args: + round_idx (int): Current training round index. + """ if self.aggregator.test_all(self.train_data_local_dict, self.test_data_local_dict, self.device, self.args,): return @@ -122,7 +196,7 @@ def test_on_server_for_all_clients(self, round_idx): train_tot_corrects = [] train_losses = [] for client_idx in range(self.args.client_num_in_total): - # train data + # Train data metrics = self.aggregator.test(self.train_data_local_dict[client_idx], self.device, self.args) train_tot_correct, train_num_sample, train_loss = ( metrics["test_correct"], @@ -135,12 +209,13 @@ def test_on_server_for_all_clients(self, round_idx): """ Note: CI environment is CPU-based computing. - The training speed for RNN training is to slow in this setting, so we only test a client to make sure there is no programming error. 
+ The training speed for RNN training is too slow in this setting, + so we only test a client to make sure there is no programming error. """ if self.args.ci == 1: break - # test on training dataset + # Test on training dataset train_acc = sum(train_tot_corrects) / sum(train_num_samples) train_loss = sum(train_losses) / sum(train_num_samples) # wandb.log({"Train/Acc": train_acc, "round": round_idx}) @@ -148,7 +223,7 @@ def test_on_server_for_all_clients(self, round_idx): stats = {"training_acc": train_acc, "training_loss": train_loss} logging.info(stats) - # test data + # Test data test_num_samples = [] test_tot_corrects = [] test_losses = [] @@ -167,10 +242,10 @@ def test_on_server_for_all_clients(self, round_idx): test_num_samples.append(copy.deepcopy(test_num_sample)) test_losses.append(copy.deepcopy(test_loss)) - # test on test dataset + # Test on test dataset test_acc = sum(test_tot_corrects) / sum(test_num_samples) test_loss = sum(test_losses) / sum(test_num_samples) # wandb.log({"Test/Acc": test_acc, "round": round_idx}) # wandb.log({"Test/Loss": test_loss, "round": round_idx}) stats = {"test_acc": test_acc, "test_loss": test_loss} - logging.info(stats) + logging.info(stats) \ No newline at end of file diff --git a/python/fedml/simulation/mpi/fedprox/FedProxClientManager.py b/python/fedml/simulation/mpi/fedprox/FedProxClientManager.py index 860fe336b0..cc13142319 100644 --- a/python/fedml/simulation/mpi/fedprox/FedProxClientManager.py +++ b/python/fedml/simulation/mpi/fedprox/FedProxClientManager.py @@ -7,6 +7,18 @@ class FedProxClientManager(FedMLCommManager): + """ + Client manager for Federated Proximal training. + + Args: + args (object): Arguments for configuration. + trainer (object): Trainer for the client. + comm (object): Communication object. + rank (int): Rank of the client. + size (int): Total number of participants. + backend (str): Backend for communication (default: "MPI"). + """ + def __init__(self, args, trainer, comm=None, rank=0, size=0, backend="MPI"): super().__init__(args, comm, rank, size, backend) self.trainer = trainer @@ -17,6 +29,16 @@ def run(self): super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for the client manager. + + This method registers message handlers for receiving initialization + and model synchronization messages. + + Message Types: + - MyMessage.MSG_TYPE_S2C_INIT_CONFIG + - MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -26,6 +48,12 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): + """ + Handle initialization message from the server. + + Args: + msg_params (dict): Message parameters. + """ global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -37,10 +65,19 @@ def handle_message_init(self, msg_params): self.__train() def start_training(self): + """ + Start the training process. + """ self.args.round_idx = 0 self.__train() def handle_message_receive_model_from_server(self, msg_params): + """ + Handle model synchronization message from the server. + + Args: + msg_params (dict): Message parameters. 
+ """ logging.info("handle_message_receive_model_from_server.") model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -56,6 +93,14 @@ def handle_message_receive_model_from_server(self, msg_params): self.finish() def send_model_to_server(self, receive_id, weights, local_sample_num): + """ + Send the trained model to the server. + + Args: + receive_id (int): Receiver ID (typically the server). + weights (object): Model weights. + local_sample_num (int): Number of local training samples. + """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), @@ -66,6 +111,9 @@ def send_model_to_server(self, receive_id, weights, local_sample_num): self.send_message(message) def __train(self): + """ + Execute the training process. + """ logging.info("#######training########### round_id = %d" % self.args.round_idx) weights, local_sample_num = self.trainer.train(self.args.round_idx) self.send_model_to_server(0, weights, local_sample_num) diff --git a/python/fedml/simulation/mpi/fedprox/FedProxServerManager.py b/python/fedml/simulation/mpi/fedprox/FedProxServerManager.py index ccf9f087cf..3e20185317 100644 --- a/python/fedml/simulation/mpi/fedprox/FedProxServerManager.py +++ b/python/fedml/simulation/mpi/fedprox/FedProxServerManager.py @@ -7,6 +7,20 @@ class FedProxServerManager(FedMLCommManager): + """ + Server manager for Federated Proximal training. + + Args: + args (object): Arguments for configuration. + aggregator (object): Aggregator for model updates. + comm (object): Communication object. + rank (int): Rank of the server. + size (int): Total number of participants. + backend (str): Backend for communication (default: "MPI"). + is_preprocessed (bool): Flag indicating if data is preprocessed (default: False). + preprocessed_client_lists (list): Preprocessed client lists (default: None). + """ + def __init__( self, args, @@ -30,7 +44,13 @@ def run(self): super().run() def send_init_msg(self): - # sampling clients + """ + Send initialization messages to clients. + + Initializes the communication with clients by sending initial model parameters + and client indexes. + """ + # Sampling clients client_indexes = self.aggregator.client_sampling( self.args.round_idx, self.args.client_num_in_total, @@ -44,12 +64,30 @@ def send_init_msg(self): ) def register_message_receive_handlers(self): + """ + Register message receive handlers for the server manager. + + This method registers the message receive handler for receiving model updates from clients. + + Message Types: + - MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER + + Message Handler: + - self.handle_message_receive_model_from_client + + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.handle_message_receive_model_from_client, ) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle received model updates from clients. + + Args: + msg_params (dict): Message parameters. 
+ """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) @@ -63,7 +101,7 @@ def handle_message_receive_model_from_client(self, msg_params): global_model_params = self.aggregator.aggregate() self.aggregator.test_on_server_for_all_clients(self.args.round_idx) - # start the next round + # Start the next round self.args.round_idx += 1 if self.args.round_idx == self.round_num: post_complete_message_to_sweep_process(self.args) @@ -72,12 +110,12 @@ def handle_message_receive_model_from_client(self, msg_params): return if self.is_preprocessed: if self.preprocessed_client_lists is None: - # sampling has already been done in data preprocessor + # Sampling has already been done in data preprocessor client_indexes = [self.args.round_idx] * self.args.client_num_per_round else: client_indexes = self.preprocessed_client_lists[self.args.round_idx] else: - # sampling clients + # Sampling clients client_indexes = self.aggregator.client_sampling( self.args.round_idx, self.args.client_num_in_total, @@ -93,6 +131,14 @@ def handle_message_receive_model_from_client(self, msg_params): ) def send_message_init_config(self, receive_id, global_model_params, client_index): + """ + Send initialization configuration message to a client. + + Args: + receive_id (int): Receiver ID. + global_model_params (object): Global model parameters. + client_index (int): Index of the client. + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) @@ -103,6 +149,14 @@ def send_message_init_config(self, receive_id, global_model_params, client_index def send_message_sync_model_to_client( self, receive_id, global_model_params, client_index ): + """ + Send model synchronization message to a client. + + Args: + receive_id (int): Receiver ID. + global_model_params (object): Global model parameters. + client_index (int): Index of the client. + """ logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id) message = Message( MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, diff --git a/python/fedml/simulation/mpi/fedprox/FedProxTrainer.py b/python/fedml/simulation/mpi/fedprox/FedProxTrainer.py index e77096b452..18ababf006 100644 --- a/python/fedml/simulation/mpi/fedprox/FedProxTrainer.py +++ b/python/fedml/simulation/mpi/fedprox/FedProxTrainer.py @@ -2,6 +2,33 @@ class FedProxTrainer(object): + """ + Federated Proximal Trainer for model training. + + Args: + client_index (int): Index of the client. + train_data_local_dict (dict): Dictionary of local training data. + train_data_local_num_dict (dict): Dictionary of local training data counts. + test_data_local_dict (dict): Dictionary of local testing data. + train_data_num (int): Total number of training data samples. + device (object): Device for training (e.g., CPU or GPU). + args (object): Arguments for configuration. + model_trainer (object): Model trainer for training. + + Attributes: + trainer (object): Model trainer for training. + client_index (int): Index of the client. + train_data_local_dict (dict): Dictionary of local training data. + train_data_local_num_dict (dict): Dictionary of local training data counts. + test_data_local_dict (dict): Dictionary of local testing data. + all_train_data_num (int): Total number of training data samples. + train_local (object): Local training data for the client. + local_sample_number (int): Number of local training data samples. 
+ test_local (object): Local testing data for the client. + device (object): Device for training. + args (object): Arguments for configuration. + """ + def __init__( self, client_index, @@ -20,9 +47,6 @@ def __init__( self.train_data_local_num_dict = train_data_local_num_dict self.test_data_local_dict = test_data_local_dict self.all_train_data_num = train_data_num - # self.train_local = self.train_data_local_dict[client_index] - # self.local_sample_number = self.train_data_local_num_dict[client_index] - # self.test_local = self.test_data_local_dict[client_index] self.train_local = None self.local_sample_number = None self.test_local = None @@ -31,15 +55,36 @@ def __init__( self.args = args def update_model(self, weights): + """ + Update the model with new weights. + + Args: + weights (object): New model weights. + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """ + Update the dataset for training and testing. + + Args: + client_index (int): Index of the client. + """ self.client_index = client_index self.train_local = self.train_data_local_dict[client_index] self.local_sample_number = self.train_data_local_num_dict[client_index] self.test_local = self.test_data_local_dict[client_index] def train(self, round_idx=None): + """ + Train the model. + + Args: + round_idx (int, optional): Index of the training round (default: None). + + Returns: + tuple: Tuple containing trained model weights and local sample count. + """ self.args.round_idx = round_idx self.trainer.train(self.train_local, self.device, self.args) @@ -48,7 +93,13 @@ def train(self, round_idx=None): return weights, self.local_sample_number def test(self): - # train data + """ + Test the trained model. + + Returns: + tuple: Tuple containing training and testing metrics. + """ + # Train data metrics train_metrics = self.trainer.test(self.train_local, self.device, self.args) train_tot_correct, train_num_sample, train_loss = ( train_metrics["test_correct"], @@ -56,7 +107,7 @@ def test(self): train_metrics["test_loss"], ) - # test data + # Test data metrics test_metrics = self.trainer.test(self.test_local, self.device, self.args) test_tot_correct, test_num_sample, test_loss = ( test_metrics["test_correct"], diff --git a/python/fedml/simulation/mpi/fedprox/message_define.py b/python/fedml/simulation/mpi/fedprox/message_define.py index 092e2ba618..57a51e3b1c 100644 --- a/python/fedml/simulation/mpi/fedprox/message_define.py +++ b/python/fedml/simulation/mpi/fedprox/message_define.py @@ -1,23 +1,22 @@ class MyMessage(object): """ - message type definition + Message type definition. 
""" - # server to client + # Server to client messages MSG_TYPE_S2C_INIT_CONFIG = 1 MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT = 2 - # client to server + # Client to server messages MSG_TYPE_C2S_SEND_MODEL_TO_SERVER = 3 MSG_TYPE_C2S_SEND_STATS_TO_SERVER = 4 + # Message argument keys MSG_ARG_KEY_TYPE = "msg_type" MSG_ARG_KEY_SENDER = "sender" MSG_ARG_KEY_RECEIVER = "receiver" - """ - message payload keywords definition - """ + # Message payload keywords MSG_ARG_KEY_NUM_SAMPLES = "num_samples" MSG_ARG_KEY_MODEL_PARAMS = "model_params" MSG_ARG_KEY_CLIENT_INDEX = "client_idx" diff --git a/python/fedml/simulation/mpi/fedprox/utils.py b/python/fedml/simulation/mpi/fedprox/utils.py index aea2449590..932ca053de 100644 --- a/python/fedml/simulation/mpi/fedprox/utils.py +++ b/python/fedml/simulation/mpi/fedprox/utils.py @@ -5,6 +5,15 @@ def transform_list_to_tensor(model_params_list): + """ + Transform a dictionary of model parameters from lists to tensors. + + Args: + model_params_list (dict): Dictionary of model parameters with lists. + + Returns: + dict: Dictionary of model parameters with tensors. + """ for k in model_params_list.keys(): model_params_list[k] = torch.from_numpy( np.asarray(model_params_list[k]) @@ -13,12 +22,27 @@ def transform_list_to_tensor(model_params_list): def transform_tensor_to_list(model_params): + """ + Transform a dictionary of model parameters from tensors to lists. + + Args: + model_params (dict): Dictionary of model parameters with tensors. + + Returns: + dict: Dictionary of model parameters with lists. + """ for k in model_params.keys(): model_params[k] = model_params[k].detach().numpy().tolist() return model_params def post_complete_message_to_sweep_process(args): + """ + Post a completion message to a sweep process. + + Args: + args (object): Arguments for configuration. 
+ """ pipe_path = "./tmp/fedml" os.system("mkdir ./tmp/; touch ./tmp/fedml") if not os.path.exists(pipe_path): From 3e80913203392fb26b5608fc350b1513a8261edf Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Sat, 9 Sep 2023 19:41:38 +0530 Subject: [PATCH 19/70] docs --- python/fedml/model/cv/vgg.py | 52 ++- python/fedml/model/finance/vfl_classifier.py | 29 ++ .../model/finance/vfl_feature_extractor.py | 36 ++ .../model/finance/vfl_models_standalone.py | 110 +++++- python/fedml/model/linear/lr.py | 34 ++ python/fedml/model/linear/lr_cifar10.py | 30 ++ python/fedml/model/mobile/mnn_lenet.py | 21 +- python/fedml/model/mobile/mnn_resnet.py | 180 ++++++---- python/fedml/model/mobile/torch_lenet.py | 37 ++ python/fedml/model/model_hub.py | 18 + python/fedml/model/nlp/model_args.py | 334 +++++++++++++++++- python/fedml/model/nlp/rnn.py | 74 +++- .../serving/client/client_initializer.py | 60 ++++ .../fedml/serving/client/client_launcher.py | 32 ++ .../client/fedml_client_master_manager.py | 99 ++++++ .../client/fedml_client_slave_manager.py | 38 +- python/fedml/serving/client/fedml_trainer.py | 48 ++- .../client/fedml_trainer_dist_adapter.py | 63 +++- .../serving/client/process_group_manager.py | 24 +- python/fedml/serving/client/utils.py | 40 ++- .../llm/src/app/pipe/instruct_pipeline.py | 98 ++++- .../serving/example/llm/src/main_entry.py | 58 ++- .../example/mnist/src/mnist_serve_main.py | 36 ++ .../example/mnist/src/model/minist_model.py | 45 ++- python/fedml/serving/fedml_client.py | 50 ++- .../fedml/serving/fedml_inference_runner.py | 47 ++- python/fedml/serving/fedml_predictor.py | 48 ++- python/fedml/serving/fedml_server.py | 30 ++ .../fedml/serving/server/fedml_aggregator.py | 160 +++++++-- .../serving/server/fedml_server_manager.py | 200 ++++++++++- python/fedml/serving/server/message_define.py | 19 +- .../serving/server/server_initializer.py | 24 +- 32 files changed, 1976 insertions(+), 198 deletions(-) diff --git a/python/fedml/model/cv/vgg.py b/python/fedml/model/cv/vgg.py index 303a804137..3a2088369b 100644 --- a/python/fedml/model/cv/vgg.py +++ b/python/fedml/model/cv/vgg.py @@ -18,6 +18,25 @@ class VGG(nn.Module): + """ + VGG model implementation. + + Args: + features (nn.Module): The feature extractor module. + num_classes (int): Number of output classes. + init_weights (bool): Whether to initialize the model weights. + + Attributes: + features (nn.Module): The feature extractor module. + avgpool (nn.AdaptiveAvgPool2d): Adaptive average pooling layer. + classifier (nn.Sequential): Classifier module. + + Methods: + forward(x): Forward pass of the VGG model. + _initialize_weights(): Initialize model weights. + + """ + def __init__( self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True ) -> None: @@ -37,6 +56,16 @@ def __init__( self._initialize_weights() def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the VGG model. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + + """ x = self.features(x) x = self.avgpool(x) x = torch.flatten(x, 1) @@ -44,6 +73,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x def _initialize_weights(self) -> None: + """ + Initialize model weights. 
+ + """ for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") @@ -58,19 +91,34 @@ def _initialize_weights(self) -> None: def make_layers(cfg, batch_norm=False): + """ + Create a list of layers for a VGG network based on the provided configuration. + + Args: + cfg (list): List of layer configurations where each element represents + the number of filters or "M" for max-pooling. + batch_norm (bool): If True, apply batch normalization after convolution. + + Returns: + nn.Sequential: A sequential container of layers. + + """ layers = [] - in_channels = 3 + in_channels = 3 # Input channel for RGB images for v in cfg: if v == "M": + # Max-pooling layer layers += [nn.MaxPool2d(kernel_size=2, stride=2)] else: v = int(v) conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) if batch_norm: + # Add convolution, batch normalization, and ReLU activation layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] else: + # Add convolution and ReLU activation layers += [conv2d, nn.ReLU(inplace=True)] - in_channels = v + in_channels = v # Update the input channels for the next layer return nn.Sequential(*layers) diff --git a/python/fedml/model/finance/vfl_classifier.py b/python/fedml/model/finance/vfl_classifier.py index 2359e42209..c4f80065e8 100644 --- a/python/fedml/model/finance/vfl_classifier.py +++ b/python/fedml/model/finance/vfl_classifier.py @@ -2,6 +2,25 @@ class VFLClassifier(nn.Module): + """ + Virtual Federated Learning (VFL) Classifier Model. + + Args: + input_dim (int): The input dimension, typically the number of features in each input sample. + output_dim (int): The output dimension, representing the number of classes. + + Input: + - Input tensor of shape (batch_size, input_dim), where batch_size is the number of input samples. + + Output: + - Output tensor of shape (batch_size, output_dim), representing class predictions or scores. + + Architecture: + - Linear Layer: + - Input: input_dim neurons + - Output: output_dim neurons (typically the number of classes) + + """ def __init__(self, input_dim, output_dim, bias=True): super(VFLClassifier, self).__init__() self.classifier = nn.Sequential( @@ -9,4 +28,14 @@ def __init__(self, input_dim, output_dim, bias=True): ) def forward(self, x): + """ + Forward pass of the VFL Classifier model. + + Args: + x (Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + predictions (Tensor): Output tensor of shape (batch_size, output_dim) with class predictions or scores. + + """ return self.classifier(x) diff --git a/python/fedml/model/finance/vfl_feature_extractor.py b/python/fedml/model/finance/vfl_feature_extractor.py index 95a17c171f..c1bcccee73 100644 --- a/python/fedml/model/finance/vfl_feature_extractor.py +++ b/python/fedml/model/finance/vfl_feature_extractor.py @@ -2,6 +2,25 @@ class VFLFeatureExtractor(nn.Module): + """ + Virtual Federated Learning (VFL) Feature Extractor Model. + + Args: + input_dim (int): The input dimension, typically the number of features in each input sample. + output_dim (int): The output dimension, representing the desired feature dimension. + + Input: + - Input tensor of shape (batch_size, input_dim), where batch_size is the number of input samples. + + Output: + - Output tensor of shape (batch_size, output_dim), representing extracted features. 
+ + Architecture: + - Linear Layer followed by Leaky ReLU activation: + - Input: input_dim neurons + - Output: output_dim neurons (representing feature dimension) + + """ def __init__(self, input_dim, output_dim): super(VFLFeatureExtractor, self).__init__() self.classifier = nn.Sequential( @@ -10,7 +29,24 @@ def __init__(self, input_dim, output_dim): self.output_dim = output_dim def forward(self, x): + """ + Forward pass of the VFL Feature Extractor model. + + Args: + x (Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + features (Tensor): Output tensor of shape (batch_size, output_dim) with extracted features. + + """ return self.classifier(x) def get_output_dim(self): + """ + Get the output dimension of the feature extractor. + + Returns: + int: The output dimension (feature dimension). + + """ return self.output_dim diff --git a/python/fedml/model/finance/vfl_models_standalone.py b/python/fedml/model/finance/vfl_models_standalone.py index 89640c8453..46ab393090 100644 --- a/python/fedml/model/finance/vfl_models_standalone.py +++ b/python/fedml/model/finance/vfl_models_standalone.py @@ -4,6 +4,27 @@ class DenseModel(nn.Module): + """ + Dense Model with Linear Classifier. + + Args: + input_dim (int): The input dimension, typically the number of features in each input sample. + output_dim (int): The output dimension, representing the number of classes or features. + learning_rate (float, optional): The learning rate for the optimizer. Default is 0.01. + bias (bool, optional): Whether to include bias terms in the linear layer. Default is True. + + Input: + - Input tensor of shape (batch_size, input_dim), where batch_size is the number of input samples. + + Output: + - Output tensor of shape (batch_size, output_dim) representing the model's predictions. + + Methods: + - forward(x): Forward pass of the model to make predictions. + - backward(x, grads): Backward pass to compute gradients and update model parameters. + + """ + def __init__(self, input_dim, output_dim, learning_rate=0.01, bias=True): super(DenseModel, self).__init__() self.classifier = nn.Sequential( @@ -15,20 +36,42 @@ def __init__(self, input_dim, output_dim, learning_rate=0.01, bias=True): ) def forward(self, x): + """ + Forward pass of the Dense Model. + + Args: + x (Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + predictions (Tensor): Output tensor of shape (batch_size, output_dim) with model predictions. + + """ if self.is_debug: print("[DEBUG] DenseModel.forward") x = torch.tensor(x).float() - return self.classifier(x).detach().numpy() + return self.classifier(x) def backward(self, x, grads): + """ + Backward pass of the Dense Model. + + Args: + x (array-like): Input data of shape (batch_size, input_dim). + grads (array-like): Gradients of the loss with respect to the model's output. + + Returns: + x_grad (array-like): Gradients of the loss with respect to the input data. + + """ if self.is_debug: print("[DEBUG] DenseModel.backward") x = torch.tensor(x, requires_grad=True).float() grads = torch.tensor(grads).float() output = self.classifier(x) - output.backward(gradient=grads) + loss = torch.sum(output * grads) # Compute dot product for backward pass + loss.backward() x_grad = x.grad.numpy() self.optimizer.step() @@ -38,6 +81,25 @@ def backward(self, x, grads): class LocalModel(nn.Module): + """ + Local Model with a Linear Classifier. + + Args: + input_dim (int): The input dimension, typically the number of features in each input sample. 
+ output_dim (int): The output dimension, representing the number of classes or features. + learning_rate (float): The learning rate for the optimizer. + + Attributes: + output_dim (int): The output dimension of the model. + + Methods: + forward(x): Forward pass of the model to make predictions. + predict(x): Make predictions using the model. + backward(x, grads): Backward pass to compute gradients and update model parameters. + get_output_dim(): Get the output dimension of the model. + + """ + def __init__(self, input_dim, output_dim, learning_rate): super(LocalModel, self).__init__() self.classifier = nn.Sequential( @@ -51,30 +113,66 @@ def __init__(self, input_dim, output_dim, learning_rate): ) def forward(self, x): + """ + Forward pass of the Local Model. + + Args: + x (Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + predictions (array-like): Output predictions as a numpy array. + + """ if self.is_debug: - print("[DEBUG] DenseModel.forward") + print("[DEBUG] LocalModel.forward") x = torch.tensor(x).float() return self.classifier(x).detach().numpy() def predict(self, x): + """ + Make predictions using the Local Model. + + Args: + x (Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + predictions (array-like): Output predictions as a numpy array. + + """ if self.is_debug: - print("[DEBUG] DenseModel.predict") + print("[DEBUG] LocalModel.predict") x = torch.tensor(x).float() return self.classifier(x).detach().numpy() def backward(self, x, grads): + """ + Backward pass of the Local Model. + + Args: + x (array-like): Input data of shape (batch_size, input_dim). + grads (array-like): Gradients of the loss with respect to the model's output. + + """ if self.is_debug: - print("[DEBUG] DenseModel.backward") + print("[DEBUG] LocalModel.backward") x = torch.tensor(x).float() grads = torch.tensor(grads).float() output = self.classifier(x) - output.backward(gradient=grads) + loss = torch.sum(output * grads) # Compute dot product for backward pass + loss.backward() self.optimizer.step() self.optimizer.zero_grad() def get_output_dim(self): + """ + Get the output dimension of the Local Model. + + Returns: + output_dim (int): The output dimension of the model. + + """ return self.output_dim diff --git a/python/fedml/model/linear/lr.py b/python/fedml/model/linear/lr.py index 53b5ce0c09..d5bca7fde2 100644 --- a/python/fedml/model/linear/lr.py +++ b/python/fedml/model/linear/lr.py @@ -2,11 +2,45 @@ class LogisticRegression(torch.nn.Module): + """ + Logistic Regression Model. + + Args: + input_dim (int): The input dimension, typically the number of features in each input sample. + output_dim (int): The output dimension, representing the number of classes or a single output. + + Input: + - Input tensor of shape (batch_size, input_dim), where batch_size is the number of input samples. + + Output: + - Output tensor of shape (batch_size, output_dim), representing class probabilities or a single output. + + Architecture: + - Linear Layer: + - Input: input_dim neurons + - Output: output_dim neurons + - Activation: Sigmoid (for binary classification) or Softmax (for multi-class classification) + + Note: + - For binary classification, output_dim is typically set to 1. + - For multi-class classification, output_dim is the number of classes. + + """ def __init__(self, input_dim, output_dim): super(LogisticRegression, self).__init__() self.linear = torch.nn.Linear(input_dim, output_dim) def forward(self, x): + """ + Forward pass of the Logistic Regression model. 
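+        Applies the linear layer followed by a sigmoid activation.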
+ + Args: + x (Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + outputs (Tensor): Output tensor of shape (batch_size, output_dim) with class probabilities or a single output. + + """ # try: outputs = torch.sigmoid(self.linear(x)) # except: diff --git a/python/fedml/model/linear/lr_cifar10.py b/python/fedml/model/linear/lr_cifar10.py index 762b9c2c3a..87d593a547 100644 --- a/python/fedml/model/linear/lr_cifar10.py +++ b/python/fedml/model/linear/lr_cifar10.py @@ -2,11 +2,41 @@ class LogisticRegression_Cifar10(torch.nn.Module): + """ + Logistic Regression Model for CIFAR-10 Image Classification. + + Args: + input_dim (int): The input dimension, typically the number of features in each input sample. + output_dim (int): The output dimension, representing the number of classes in CIFAR-10. + + Input: + - Input tensor of shape (batch_size, input_dim), where batch_size is the number of input samples. + + Output: + - Output tensor of shape (batch_size, output_dim), representing class probabilities for CIFAR-10 classes. + + Architecture: + - Linear Layer: + - Input: input_dim neurons (flattened image vectors) + - Output: output_dim neurons (class probabilities) + - Activation: Sigmoid (to produce class probabilities) + + """ def __init__(self, input_dim, output_dim): super(LogisticRegression_Cifar10, self).__init__() self.linear = torch.nn.Linear(input_dim, output_dim) def forward(self, x): + """ + Forward pass of the Logistic Regression model. + + Args: + x (Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + outputs (Tensor): Output tensor of shape (batch_size, output_dim) with class probabilities. + + """ # Flatten images into vectors # print(f"size = {x.size()}") x = x.view(x.size(0), -1) diff --git a/python/fedml/model/mobile/mnn_lenet.py b/python/fedml/model/mobile/mnn_lenet.py index a803da7fea..2378fc9695 100644 --- a/python/fedml/model/mobile/mnn_lenet.py +++ b/python/fedml/model/mobile/mnn_lenet.py @@ -5,7 +5,17 @@ class Lenet5(nn.Module): - """construct a lenet 5 model""" + """ + LeNet-5 convolutional neural network model. + + This class defines the LeNet-5 architecture for image classification. + + Args: + None + + Returns: + torch.Tensor: Model predictions. + """ def __init__(self): super(Lenet5, self).__init__() @@ -15,6 +25,15 @@ def __init__(self): self.fc2 = nn.linear(500, 10) def forward(self, x): + """ + Forward pass of the LeNet-5 model. + + Args: + x (torch.Tensor): Input image tensor. + + Returns: + torch.Tensor: Model predictions. + """ x = F.relu(self.conv1(x)) x = F.max_pool(x, [2, 2], [2, 2]) x = F.relu(self.conv2(x)) diff --git a/python/fedml/model/mobile/mnn_resnet.py b/python/fedml/model/mobile/mnn_resnet.py index 9ae9703bb3..4f9cf53744 100644 --- a/python/fedml/model/mobile/mnn_resnet.py +++ b/python/fedml/model/mobile/mnn_resnet.py @@ -5,93 +5,126 @@ class ResBlock(nn.Module): + """ + Residual Block for a ResNet-like architecture. + + This class defines a basic residual block with two convolutional layers and batch normalization. + + Args: + in_planes (int): Number of input channels. + planes (int): Number of output channels (number of filters in the convolutional layers). + stride (int): Stride value for the first convolutional layer (default is 1). + + Returns: + torch.Tensor: Output tensor from the residual block. 
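+
+    Note:
+        The identity shortcut (out += x) assumes in_planes == planes and
+        stride == 1; ResBlock_conv_shortcut covers the shape-changing case.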
+ """ + def __init__(self, in_planes, planes, stride=1): super(ResBlock, self).__init__() - self.conv1 = nn.conv( - in_planes, - planes, - kernel_size=[3, 3], - stride=[stride, stride], - padding=[1, 1], - bias=False, - padding_mode=MNN.expr.Padding_Mode.SAME, + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False ) - self.bn1 = nn.batch_norm(planes) - self.conv2 = nn.conv( - planes, - planes, - kernel_size=[3, 3], - stride=[1, 1], - padding=[1, 1], - bias=False, - padding_mode=MNN.expr.Padding_Mode.SAME, + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=1, padding=1, bias=False ) - self.bn2 = nn.batch_norm(planes) + self.bn2 = nn.BatchNorm2d(planes) def forward(self, x): + """ + Forward pass of the Residual Block. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after passing through the residual block. + """ + out = F.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) - out += x + out += x # Skip connection out = F.relu(out) return out class ResBlock_conv_shortcut(nn.Module): + """ + Residual Block with Convolutional Shortcuts for a ResNet-like architecture. + + This class defines a residual block with convolutional shortcuts. It consists of two convolutional layers + with batch normalization and a convolutional shortcut connection. + + Args: + in_planes (int): Number of input channels. + planes (int): Number of output channels (number of filters in the convolutional layers). + stride (int): Stride value for the first convolutional layer (default is 1). + + Returns: + torch.Tensor: Output tensor from the residual block. + """ + def __init__(self, in_planes, planes, stride=1): super(ResBlock_conv_shortcut, self).__init__() - self.conv1 = nn.conv( - in_planes, - planes, - kernel_size=[3, 3], - stride=[stride, stride], - padding=[1, 1], - bias=False, - padding_mode=MNN.expr.Padding_Mode.SAME, + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False ) - self.bn1 = nn.batch_norm(planes) - self.conv2 = nn.conv( - planes, - planes, - kernel_size=[3, 3], - stride=[1, 1], - padding=[1, 1], - bias=False, - padding_mode=MNN.expr.Padding_Mode.SAME, + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=1, padding=1, bias=False ) - self.bn2 = nn.batch_norm(planes) + self.bn2 = nn.BatchNorm2d(planes) - self.conv_shortcut = nn.conv( - in_planes, - planes, - kernel_size=[1, 1], - stride=[stride, stride], - bias=False, - padding_mode=MNN.expr.Padding_Mode.SAME, + self.conv_shortcut = nn.Conv2d( + in_planes, planes, kernel_size=1, stride=stride, bias=False ) - self.bn_shortcut = nn.batch_norm(planes) + self.bn_shortcut = nn.BatchNorm2d(planes) def forward(self, x): + """ + Forward pass of the Residual Block. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after passing through the residual block. + """ + out = F.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) - out += self.bn_shortcut(self.conv_shortcut(x)) + shortcut = self.bn_shortcut(self.conv_shortcut(x)) + out += shortcut # Skip connection with convolutional shortcut out = F.relu(out) return out class Resnet20(nn.Module): + """ + ResNet-20 implementation for image classification. + + This class defines a ResNet-20 architecture with convolutional blocks and shortcuts. + It consists of four stages, each containing convolutional blocks. 
+ + Args: + num_classes (int): Number of output classes. + + Returns: + torch.Tensor: Output tensor representing class probabilities. + """ + def __init__(self, num_classes=10): super(Resnet20, self).__init__() - self.conv1 = nn.conv( + self.conv1 = nn.Conv2d( 3, 16, - kernel_size=[3, 3], - stride=[1, 1], - padding=[1, 1], + kernel_size=3, + stride=1, + padding=1, bias=False, - padding_mode=MNN.expr.Padding_Mode.SAME, ) - self.bn1 = nn.batch_norm(16) + self.bn1 = nn.BatchNorm2d(16) self.layer1 = ResBlock(16, 16, 1) self.layer2 = ResBlock(16, 16, 1) @@ -105,28 +138,37 @@ def __init__(self, num_classes=10): self.layer8 = ResBlock(64, 64, 1) self.layer9 = ResBlock(64, 64, 1) - self.fc = nn.linear(64, num_classes) + self.fc = nn.Linear(64, num_classes) def forward(self, x): + """ + Forward pass of the ResNet-20 model. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor representing class probabilities. + """ + x = F.relu(self.bn1(self.conv1(x))) - x = self.layer1.forward(x) - x = self.layer2.forward(x) - x = self.layer3.forward(x) - # print(x.shape) - x = self.layer4.forward(x) - x = self.layer5.forward(x) - x = self.layer6.forward(x) - # print(x.shape) - x = self.layer7.forward(x) - x = self.layer8.forward(x) - x = self.layer9.forward(x) - # print(x.shape) - x = F.avg_pool(x, kernel=[8, 8], stride=[8, 8]) - x = F.convert(x, F.NCHW) - x = F.reshape(x, [0, -1]) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.layer4(x) + x = self.layer5(x) + x = self.layer6(x) + + x = self.layer7(x) + x = self.layer8(x) + x = self.layer9(x) + + x = F.avg_pool2d(x, kernel_size=8, stride=8) + x = x.view(x.size(0), -1) x = self.fc(x) - out = F.softmax(x, 1) + out = F.softmax(x, dim=1) return out diff --git a/python/fedml/model/mobile/torch_lenet.py b/python/fedml/model/mobile/torch_lenet.py index fc1a64b457..ee3f30241f 100644 --- a/python/fedml/model/mobile/torch_lenet.py +++ b/python/fedml/model/mobile/torch_lenet.py @@ -3,6 +3,43 @@ class LeNet(nn.Module): + """ + LeNet-5 Convolutional Neural Network model for image classification. + + Args: + None + + Input: + - Input tensor of shape (batch_size, 1, 32, 32), where batch_size is the number of input samples. + + Output: + - Output tensor of shape (batch_size, 10), representing class probabilities for 10 classes. + + Architecture: + - Convolutional Layer 1: + - Input: 1 channel (grayscale image) + - Output: 20 feature maps + - Kernel size: 5x5 + - Activation: ReLU + - Max Pooling: 2x2 + - Convolutional Layer 2: + - Input: 20 feature maps + - Output: 50 feature maps + - Kernel size: 5x5 + - Activation: ReLU + - Max Pooling: 2x2 + - Fully Connected Layer 1: + - Input: 800 neurons (flattened 50x4x4 from previous layer) + - Output: 500 neurons + - Activation: ReLU + - Dropout: 50% dropout rate + - Fully Connected Layer 2: + - Input: 500 neurons + - Output: 10 neurons (class probabilities) + - Activation: Softmax + + """ + def __init__(self): super(LeNet, self).__init__() self.fc2 = nn.Linear(500, 10) diff --git a/python/fedml/model/model_hub.py b/python/fedml/model/model_hub.py index e4ccc3acfc..fefa775055 100644 --- a/python/fedml/model/model_hub.py +++ b/python/fedml/model/model_hub.py @@ -17,6 +17,24 @@ def create(args, output_dim): + """ + Create a deep learning model based on the provided arguments and dataset. + + Args: + args (Namespace): Command-line arguments containing model and dataset information. + output_dim (int): Dimension of the model's output. 
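+            For classification models this is typically the number of classes.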
+ + Returns: + torch.nn.Module or Tuple[torch.nn.Module, torch.nn.Module] or None: The created model(s). + + Raises: + Exception: If the specified model or dataset is not supported. + + Example: + >>> import argparse + >>> args = argparse.Namespace(model="cnn", dataset="mnist") + >>> model = create(args, 10) + """ global model model_name = args.model logging.info("create_model. model_name = %s, output_dim = %s" % (model_name, output_dim)) diff --git a/python/fedml/model/nlp/model_args.py b/python/fedml/model/nlp/model_args.py index 2aaaa7319f..871f56ccd4 100644 --- a/python/fedml/model/nlp/model_args.py +++ b/python/fedml/model/nlp/model_args.py @@ -9,18 +9,78 @@ def get_default_process_count(): + """ + Get the default number of processes to use for multi-processing tasks. + + Returns: + int: The default process count. + + Example: + >>> process_count = get_default_process_count() + """ process_count = int(cpu_count() / 2) if cpu_count() > 2 else 1 if sys.platform == "win32": process_count = min(process_count, 61) return process_count - def get_special_tokens(): + """ + Get a list of special tokens commonly used in natural language processing tasks. + + Returns: + List[str]: A list of special tokens. + + Example: + >>> special_tokens = get_special_tokens() + """ return ["", "", "", "", ""] @dataclass class ModelArgs: + """ + Configuration class for model training and evaluation. + + Attributes: + adam_epsilon (float): Epsilon value for Adam optimizer. Default is 1e-8. + best_model_dir (str): Directory to save the best model checkpoints. Default is "outputs/best_model". + cache_dir (str): Directory for caching data. Default is "cache_dir/". + config (dict): Additional configuration settings as a dictionary. Default is an empty dictionary. + custom_layer_parameters (list): List of custom layer parameters. Default is an empty list. + custom_parameter_groups (list): List of custom parameter groups. Default is an empty list. + dataloader_num_workers (int): Number of workers for data loading. Default is determined by `get_default_process_count`. + do_lower_case (bool): Whether to convert input text to lowercase. Default is False. + dynamic_quantize (bool): Whether to dynamically quantize the model. Default is False. + early_stopping_consider_epochs (bool): Whether to consider epochs for early stopping. Default is False. + early_stopping_delta (float): Minimum change in metric value to consider for early stopping. Default is 0. + early_stopping_metric (str): Metric to monitor for early stopping. Default is "eval_loss". + early_stopping_metric_minimize (bool): Whether to minimize the early stopping metric. Default is True. + early_stopping_patience (int): Number of epochs with no improvement to wait before early stopping. Default is 3. + encoding (str): Encoding for input text. Default is None. + eval_batch_size (int): Batch size for evaluation. Default is 8. + evaluate_during_training (bool): Whether to perform evaluation during training. Default is False. + evaluate_during_training_silent (bool): Whether to silence evaluation logs during training. Default is True. + evaluate_during_training_steps (int): Frequency of evaluation steps during training. Default is 2000. + evaluate_during_training_verbose (bool): Whether to print evaluation results during training. Default is False. + evaluate_each_epoch (bool): Whether to perform evaluation after each epoch. Default is True. + fp16 (bool): Whether to use mixed-precision training (FP16). Default is True. 
+ gradient_accumulation_steps (int): Number of gradient accumulation steps. Default is 1. + learning_rate (float): Learning rate for training. Default is 4e-5. + local_rank (int): Local rank for distributed training. Default is -1. + logging_steps (int): Frequency of logging training steps. Default is 50. + manual_seed (int): Seed for random number generation. Default is None. + max_grad_norm (float): Maximum gradient norm for clipping gradients. Default is 1.0. + max_seq_length (int): Maximum sequence length for input data. Default is 128. + model_name (str): Name of the model being used. Default is None. + model_type (str): Type of the model being used. Default is None. + ... (other attributes) + + Methods: + update_from_dict(new_values): Update attribute values from a dictionary. + get_args_for_saving(): Get a dictionary of attributes suitable for saving. + save(output_dir): Save the model configuration to a JSON file in the specified output directory. + load(input_dir): Load the model configuration from a JSON file in the specified input directory. + """ adam_epsilon: float = 1e-8 best_model_dir: str = "outputs/best_model" cache_dir: str = "cache_dir/" @@ -84,6 +144,20 @@ class ModelArgs: skip_special_tokens: bool = True def update_from_dict(self, new_values): + """ + Update attributes of the ModelArgs instance from a dictionary. + + Args: + new_values (dict): A dictionary containing attribute-value pairs to update. + + Raises: + TypeError: If the input `new_values` is not a Python dictionary. + + Example: + model_args = ModelArgs() + new_values = {'learning_rate': 0.01, 'train_batch_size': 16} + model_args.update_from_dict(new_values) + """ if isinstance(new_values, dict): for key, value in new_values.items(): setattr(self, key, value) @@ -91,6 +165,16 @@ def update_from_dict(self, new_values): raise (TypeError(f"{new_values} is not a Python dict.")) def get_args_for_saving(self): + """ + Get a dictionary of model arguments suitable for saving. + + Returns: + dict: A dictionary containing model arguments, excluding those specified in `not_saved_args`. + + Example: + model_args = ModelArgs() + args_to_save = model_args.get_args_for_saving() + """ args_for_saving = { key: value for key, value in asdict(self).items() @@ -99,11 +183,31 @@ def get_args_for_saving(self): return args_for_saving def save(self, output_dir): + """ + Save the model configuration to a JSON file in the specified output directory. + + Args: + output_dir (str): The directory where the model configuration JSON file will be saved. + + Example: + model_args = ModelArgs() + model_args.save("output_directory") + """ os.makedirs(output_dir, exist_ok=True) with open(os.path.join(output_dir, "model_args.json"), "w") as f: json.dump(self.get_args_for_saving(), f) def load(self, input_dir): + """ + Load the model configuration from a JSON file in the specified input directory. + + Args: + input_dir (str): The directory where the model configuration JSON file is located. + + Example: + model_args = ModelArgs() + model_args.load("input_directory") + """ if input_dir: model_args_file = os.path.join(input_dir, "model_args.json") if os.path.isfile(model_args_file): @@ -120,41 +224,151 @@ class ClassificationArgs(ModelArgs): """ model_class: str = "ClassificationModel" + """ + (str) The name of the classification model class. Defaults to "ClassificationModel". + """ + labels_list: list = field(default_factory=list) + """ + (list) A list of labels used for classification. Defaults to an empty list. 
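+    Example (illustrative): ["negative", "positive"] for a binary task.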
+ """ + labels_map: dict = field(default_factory=dict) + """ + (dict) A dictionary that maps labels to their corresponding indices. Defaults to an empty dictionary. + """ + lazy_delimiter: str = "\t" + """ + (str) The delimiter used for lazy loading of data. Defaults to the tab character ("\t"). + """ + lazy_labels_column: int = 1 + """ + (int) The column index (1-based) containing labels when using lazy loading. Defaults to 1. + """ + lazy_loading: bool = False + """ + (bool) Whether to use lazy loading of data. Defaults to False. + """ + lazy_loading_start_line: int = 1 + """ + (int) The line number (1-based) to start reading data when using lazy loading. Defaults to 1. + """ + lazy_text_a_column: bool = None + """ + (bool) Whether the lazy loading data contains a text column for input "text_a". Defaults to None. + """ + lazy_text_b_column: bool = None + """ + (bool) Whether the lazy loading data contains a text column for input "text_b". Defaults to None. + """ + lazy_text_column: int = 0 + """ + (int) The column index (0-based) containing text data when using lazy loading. Defaults to 0. + """ + onnx: bool = False + """ + (bool) Whether to use ONNX format for the model. Defaults to False. + """ + regression: bool = False + """ + (bool) Whether the task is regression (True) or classification (False). Defaults to False. + """ + sliding_window: bool = False + """ + (bool) Whether to use a sliding window approach for long documents. Defaults to False. + """ + stride: float = 0.8 + """ + (float) The stride used in the sliding window approach. Defaults to 0.8. + """ + tie_value: int = 1 + """ + (int) The value used for tied tokens in the dataset. Defaults to 1. + """ + evaluate_during_training_steps: int = 20 + """ + (int) The number of steps between evaluations during training. Defaults to 20. + """ + evaluate_during_training: bool = True + """ + (bool) Whether to perform evaluations during training. Defaults to True. + """ @dataclass class SeqTaggingArgs(ModelArgs): """ - Model args for a SeqTaggingArgs + Model args for a SeqTaggingModel """ model_class: str = "SeqTaggingModel" + """ + (str) The name of the SeqTagging model class. Defaults to "SeqTaggingModel". + """ + labels_list: list = field(default_factory=list) + """ + (list) A list of labels used for sequence tagging. Defaults to an empty list. + """ + lazy_delimiter: str = "\t" + """ + (str) The delimiter used for lazy loading of data. Defaults to the tab character ("\t"). + """ + lazy_labels_column: int = 1 + """ + (int) The column index (1-based) containing labels when using lazy loading. Defaults to 1. + """ + lazy_loading: bool = False + """ + (bool) Whether to use lazy loading of data. Defaults to False. + """ + lazy_loading_start_line: int = 1 + """ + (int) The line number (1-based) to start reading data when using lazy loading. Defaults to 1. + """ + onnx: bool = False + """ + (bool) Whether to use ONNX format for the model. Defaults to False. + """ + evaluate_during_training_steps: int = 20 + """ + (int) The number of steps between evaluations during training. Defaults to 20. + """ + evaluate_during_training: bool = True + """ + (bool) Whether to perform evaluations during training. Defaults to True. + """ + classification_report: bool = True + """ + (bool) Whether to generate a classification report. Defaults to True. + """ + pad_token_label_id: int = CrossEntropyLoss().ignore_index + """ + (int) The ID of the pad token label used for padding. Defaults to CrossEntropyLoss().ignore_index. 
+ """ @dataclass @@ -164,16 +378,60 @@ class SpanExtractionArgs(ModelArgs): """ model_class: str = "QuestionAnsweringModel" + """ + (str) The name of the SpanExtraction model class. Defaults to "QuestionAnsweringModel". + """ + doc_stride: int = 384 + """ + (int) The document stride for span extraction. Defaults to 384. + """ + early_stopping_metric: str = "correct" + """ + (str) The early stopping metric. Defaults to "correct". + """ + early_stopping_metric_minimize: bool = False + """ + (bool) Whether to minimize the early stopping metric. Defaults to False. + """ + lazy_loading: bool = False + """ + (bool) Whether to use lazy loading of data. Defaults to False. + """ + max_answer_length: int = 100 + """ + (int) The maximum answer length. Defaults to 100. + """ + max_query_length: int = 64 + """ + (int) The maximum query length. Defaults to 64. + """ + n_best_size: int = 20 + """ + (int) The number of best answers to consider. Defaults to 20. + """ + null_score_diff_threshold: float = 0.0 + """ + (float) The null score difference threshold. Defaults to 0.0. + """ + evaluate_during_training_steps: int = 20 + """ + (int) The number of steps between evaluations during training. Defaults to 20. + """ + evaluate_during_training: bool = True + """ + (bool) Whether to perform evaluations during training. Defaults to True. + """ + @dataclass @@ -183,20 +441,92 @@ class Seq2SeqArgs(ModelArgs): """ model_class: str = "Seq2SeqModel" + """ + (str) The name of the Seq2Seq model class. Defaults to "Seq2SeqModel". + """ + base_marian_model_name: str = None + """ + (str) The base Marian model name. Defaults to None. + """ + dataset_class: Dataset = None + """ + (Dataset) The dataset class. Defaults to None. + """ + do_sample: bool = False + """ + (bool) Whether to perform sampling during decoding. Defaults to False. + """ + early_stopping: bool = True + """ + (bool) Whether to use early stopping during training. Defaults to True. + """ + evaluate_generated_text: bool = False + """ + (bool) Whether to evaluate generated text. Defaults to False. + """ + length_penalty: float = 2.0 + """ + (float) The length penalty factor during decoding. Defaults to 2.0. + """ + max_length: int = 20 + """ + (int) The maximum length of generated text. Defaults to 20. + """ + max_steps: int = -1 + """ + (int) The maximum number of training steps. Defaults to -1 (unlimited). + """ + num_beams: int = 4 + """ + (int) The number of beams used during decoding. Defaults to 4. + """ + num_return_sequences: int = 1 + """ + (int) The number of generated sequences to return. Defaults to 1. + """ + repetition_penalty: float = 1.0 + """ + (float) The repetition penalty factor during decoding. Defaults to 1.0. + """ + top_k: float = None + """ + (float) The top-k value used during decoding. Defaults to None. + """ + top_p: float = None + """ + (float) The top-p value used during decoding. Defaults to None. + """ + use_multiprocessed_decoding: bool = False + """ + (bool) Whether to use multiprocessed decoding. Defaults to False. + """ + evaluate_during_training: bool = True + """ + (bool) Whether to perform evaluations during training. Defaults to True. + """ + src_lang: str = "en_XX" + """ + (str) The source language for translation. Defaults to "en_XX". + """ + tgt_lang: str = "ro_RO" + """ + (str) The target language for translation. Defaults to "ro_RO". 
+ """ + diff --git a/python/fedml/model/nlp/rnn.py b/python/fedml/model/nlp/rnn.py index 7af13e4618..b46c57f018 100644 --- a/python/fedml/model/nlp/rnn.py +++ b/python/fedml/model/nlp/rnn.py @@ -3,17 +3,20 @@ class RNN_OriginalFedAvg(nn.Module): - """Creates a RNN model using LSTM layers for Shakespeare language models (next character prediction task). + """ + Creates a RNN model using LSTM layers for Shakespeare language models (next character prediction task). This replicates the model structure in the paper: Communication-Efficient Learning of Deep Networks from Decentralized Data - H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, Blaise Agueray Arcas. AISTATS 2017. - https://arxiv.org/abs/1602.05629 + H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, Blaise Agueray Arcas. AISTATS 2017. + https://arxiv.org/abs/1602.05629 This is also recommended model by "Adaptive Federated Optimization. ICML 2020" (https://arxiv.org/pdf/2003.00295.pdf) + Args: - vocab_size: the size of the vocabulary, used as a dimension in the input embedding. - sequence_length: the length of input sequences. + embedding_dim: The dimension of word embeddings. Default is 8. + vocab_size: The size of the vocabulary, used as a dimension in the input embedding. Default is 90. + hidden_size: The size of the hidden state in the LSTM layers. Default is 256. Returns: - An uncompiled `torch.nn.Module`. + An uncompiled `torch.nn.Module`. """ def __init__(self, embedding_dim=8, vocab_size=90, hidden_size=256): @@ -30,6 +33,14 @@ def __init__(self, embedding_dim=8, vocab_size=90, hidden_size=256): self.fc = nn.Linear(hidden_size, vocab_size) def forward(self, input_seq): + """ + Forward pass of the model. + + Args: + input_seq: Input sequence of character indices. + Returns: + output: Model predictions. + """ embeds = self.embeddings(input_seq) # Note that the order of mini-batch is random so there is no hidden relationship among batches. # So we do not input the previous batch's hidden state, @@ -45,6 +56,20 @@ def forward(self, input_seq): class RNN_FedShakespeare(nn.Module): + """ + RNN model for Shakespeare language modeling (next character prediction task). + + This class defines an RNN model for predicting the next character in a sequence of text, + specifically tailored for the "fed_shakespeare" task. + + Args: + embedding_dim (int): Dimension of the character embeddings. + vocab_size (int): Size of the vocabulary (number of unique characters). + hidden_size (int): Size of the hidden state of the LSTM layers. + + Returns: + torch.Tensor: The model's output predictions. + """ def __init__(self, embedding_dim=8, vocab_size=90, hidden_size=256): super(RNN_FedShakespeare, self).__init__() self.embeddings = nn.Embedding( @@ -59,6 +84,14 @@ def __init__(self, embedding_dim=8, vocab_size=90, hidden_size=256): self.fc = nn.Linear(hidden_size, vocab_size) def forward(self, input_seq): + """ + Forward pass of the model. + + Args: + input_seq: Input sequence of character indices. + Returns: + output: Model predictions. + """ embeds = self.embeddings(input_seq) # Note that the order of mini-batch is random so there is no hidden relationship among batches. # So we do not input the previous batch's hidden state, @@ -74,15 +107,22 @@ def forward(self, input_seq): class RNN_StackOverFlow(nn.Module): - """Creates a RNN model using LSTM layers for StackOverFlow (next word prediction task). 
- This replicates the model structure in the paper: + """ + RNN model for StackOverflow language modeling (next word prediction task). + + This class defines an RNN model for predicting the next word in a sequence of text, specifically tailored + for the "stackoverflow_nwp" task. "Adaptive Federated Optimization. ICML 2020" (https://arxiv.org/pdf/2003.00295.pdf) - Table 9 + Args: - vocab_size: the size of the vocabulary, used as a dimension in the input embedding. - sequence_length: the length of input sequences. + vocab_size (int): Size of the vocabulary (number of unique words). + num_oov_buckets (int): Number of out-of-vocabulary (OOV) buckets. + embedding_size (int): Dimension of the word embeddings. + latent_size (int): Size of the LSTM hidden state. + num_layers (int): Number of LSTM layers. + Returns: - An uncompiled `torch.nn.Module`. + torch.Tensor: The model's output predictions. """ def __init__( @@ -107,6 +147,16 @@ def __init__( self.fc2 = nn.Linear(embedding_size, extended_vocab_size) def forward(self, input_seq, hidden_state=None): + """ + Forward pass of the model. + + Args: + input_seq (torch.Tensor): Input sequence of word indices. + hidden_state (tuple): Initial hidden state of the LSTM. + + Returns: + torch.Tensor: Model predictions. + """ embeds = self.word_embeddings(input_seq) lstm_out, hidden_state = self.lstm(embeds, hidden_state) fc1_output = self.fc1(lstm_out[:, :]) diff --git a/python/fedml/serving/client/client_initializer.py b/python/fedml/serving/client/client_initializer.py index 37791b80de..b26e727937 100644 --- a/python/fedml/serving/client/client_initializer.py +++ b/python/fedml/serving/client/client_initializer.py @@ -16,6 +16,25 @@ def init_client( test_data_local_dict, model_trainer=None, ): + """ + Initialize and run a federated learning client. + + Args: + args: Arguments and configuration for the client. + device: The device on which the client should run (e.g., 'cpu' or 'cuda'). + comm: The communication backend for distributed training. + client_rank: The rank or identifier of this client. + client_num: The total number of clients in the federated learning scenario. + model: The machine learning model to be trained. + train_data_num: The number of training data points. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training data points. + train_data_local_dict: A dictionary mapping client IDs to their local training data. + test_data_local_dict: A dictionary mapping client IDs to their local testing data. + model_trainer: An optional custom model trainer. + + Returns: + None + """ backend = args.backend trainer_dist_adapter = get_trainer_dist_adapter( @@ -60,6 +79,23 @@ def get_trainer_dist_adapter( test_data_local_dict, model_trainer, ): + """ + Get a distributed trainer adapter for the federated learning client. + + Args: + args: Arguments and configuration for the client. + device: The device on which the client should run (e.g., 'cpu' or 'cuda'). + client_rank: The rank or identifier of this client. + model: The machine learning model to be trained. + train_data_num: The number of training data points. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training data points. + train_data_local_dict: A dictionary mapping client IDs to their local training data. + test_data_local_dict: A dictionary mapping client IDs to their local testing data. + model_trainer: An optional custom model trainer. + + Returns: + TrainerDistAdapter: A distributed trainer adapter. 
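+
+    Note:
+        In the hierarchical cross-silo scenario the adapter also manages the
+        process group used for intra-silo communication.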
+ """ return TrainerDistAdapter( args, device, @@ -74,10 +110,34 @@ def get_trainer_dist_adapter( def get_client_manager_master(args, trainer_dist_adapter, comm, client_rank, client_num, backend): + """ + Get a federated learning client manager for the master client in the hierarchical scenario. + + Args: + args: Arguments and configuration for the client. + trainer_dist_adapter: A distributed trainer adapter. + comm: The communication backend for distributed training. + client_rank: The rank or identifier of this client. + client_num: The total number of clients in the federated learning scenario. + backend: The backend for distributed training (e.g., 'nccl' or 'gloo'). + + Returns: + ClientMasterManager: A federated learning client manager for the master client. + """ return ClientMasterManager(args, trainer_dist_adapter, comm, client_rank, client_num, backend) def get_client_manager_salve(args, trainer_dist_adapter): + """ + Get a federated learning client manager for a slave client in the hierarchical scenario. + + Args: + args: Arguments and configuration for the client. + trainer_dist_adapter: A distributed trainer adapter. + + Returns: + ClientSlaveManager: A federated learning client manager for a slave client. + """ from .fedml_client_slave_manager import ClientSlaveManager return ClientSlaveManager(args, trainer_dist_adapter) diff --git a/python/fedml/serving/client/client_launcher.py b/python/fedml/serving/client/client_launcher.py index 1a4831b11e..1034ab2cb7 100644 --- a/python/fedml/serving/client/client_launcher.py +++ b/python/fedml/serving/client/client_launcher.py @@ -27,6 +27,16 @@ class CrossSiloLauncher: @staticmethod def launch_dist_trainers(torch_client_filename, inputs): + """ + Launch distributed trainers based on the specified scenario. + + Args: + torch_client_filename (str): The filename of the torch client script to be launched. + inputs (List[str]): List of input arguments to be passed to the torch client script. + + Returns: + None + """ # this is only used by the client (DDP or single process), so there is no need to specify the backend. args = load_arguments(FEDML_TRAINING_PLATFORM_CROSS_SILO) if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL: @@ -38,12 +48,34 @@ def launch_dist_trainers(torch_client_filename, inputs): @staticmethod def _run_cross_silo_horizontal(args, torch_client_filename, inputs): + """ + Run distributed training in a horizontal federated learning scenario. + + Args: + args: Arguments and configuration for the client. + torch_client_filename (str): The filename of the torch client script to be launched. + inputs (List[str]): List of input arguments to be passed to the torch client script. + + Returns: + None + """ python_path = subprocess.run(["which", "python"], capture_output=True, text=True).stdout.strip() process_arguments = [python_path, torch_client_filename] + inputs subprocess.run(process_arguments) @staticmethod def _run_cross_silo_hierarchical(args, torch_client_filename, inputs): + """ + Run distributed training in a hierarchical federated learning scenario. + + Args: + args: Arguments and configuration for the client. + torch_client_filename (str): The filename of the torch client script to be launched. + inputs (List[str]): List of input arguments to be passed to the torch client script. 
+ + Returns: + None + """ def get_torchrun_arguments(node_rank): torchrun_path = subprocess.run(["which", "torchrun"], capture_output=True, text=True).stdout.strip() diff --git a/python/fedml/serving/client/fedml_client_master_manager.py b/python/fedml/serving/client/fedml_client_master_manager.py index 6e4d2b7495..c14720b214 100644 --- a/python/fedml/serving/client/fedml_client_master_manager.py +++ b/python/fedml/serving/client/fedml_client_master_manager.py @@ -19,6 +19,17 @@ class ClientMasterManager(FedMLCommManager): RUN_FINISHED_STATUS_FLAG = "FINISHED" def __init__(self, args, trainer_dist_adapter, comm=None, rank=0, size=0, backend="MPI"): + """ + Initialize the ClientMasterManager. + + Args: + args: Arguments and configuration for the client manager. + trainer_dist_adapter: Trainer distribution adapter for distributed training. + comm: Communication backend (MPI, etc.). + rank: Rank of the client. + size: Size of the client group. + backend: Backend for distributed training (MPI, etc.). + """ super().__init__(args, comm, rank, size, backend) self.trainer_dist_adapter = trainer_dist_adapter self.args = args @@ -35,6 +46,9 @@ def __init__(self, args, trainer_dist_adapter, comm=None, rank=0, size=0, backen self.is_inited = False def register_message_receive_handlers(self): + """ + Register message receive handlers for handling various types of messages. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_message_connection_ready ) @@ -53,6 +67,12 @@ def register_message_receive_handlers(self): ) def handle_message_connection_ready(self, msg_params): + """ + Handle the "connection ready" message. + + Args: + msg_params: Parameters of the message. + """ if not self.has_sent_online_msg: self.has_sent_online_msg = True self.send_client_status(0) @@ -60,9 +80,21 @@ def handle_message_connection_ready(self, msg_params): mlops.log_sys_perf(self.args) def handle_message_check_status(self, msg_params): + """ + Handle the "check client status" message. + + Args: + msg_params: Parameters of the message. + """ self.send_client_status(0) def handle_message_init(self, msg_params): + """ + Handle the "initialize" message and prepare for training. + + Args: + msg_params: Parameters of the message. + """ if self.is_inited: return @@ -88,6 +120,12 @@ def handle_message_init(self, msg_params): self.round_idx += 1 def handle_message_receive_model_from_server(self, msg_params): + """ + Handle the "receive model from server" message. + + Args: + msg_params: Parameters of the message. + """ logging.info("handle_message_receive_model_from_server.") model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -108,15 +146,35 @@ def handle_message_receive_model_from_server(self, msg_params): self.finish() def handle_message_finish(self, msg_params): + """ + Handle the "finish" message and perform cleanup. + + Args: + msg_params: Parameters of the message. + """ logging.info(" ====================cleanup ====================") self.cleanup() def cleanup(self): + """ + Perform cleanup operations at the end of training. + """ self.send_client_status(0, ClientMasterManager.RUN_FINISHED_STATUS_FLAG) mlops.log_training_finished_status() self.finish() def send_model_to_server(self, receive_id, weights, local_sample_num): + """ + Send the model to the server. + + Args: + receive_id: ID of the recipient (usually the server). + weights: Model weights to be sent. 
+ local_sample_num: Number of local training samples. + + Note: + This method sends model parameters to the server for aggregation. + """ tick = time.time() mlops.event("comm_c2s", event_started=True, event_value=str(self.round_idx)) message = Message(MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.client_real_id, receive_id,) @@ -130,6 +188,17 @@ def send_model_to_server(self, receive_id, weights, local_sample_num): ) def send_client_status(self, receive_id, status=ONLINE_STATUS_FLAG): + """ + Send the client status message to the specified recipient. + + Args: + receive_id: ID of the recipient. + status: Status flag to be sent (default is ONLINE_STATUS_FLAG). + + Note: + This method sends information about the client's status, including the operating system. + + """ logging.info("send_client_status") logging.info("self.client_real_id = {}".format(self.client_real_id)) message = Message(MyMessage.MSG_TYPE_C2S_CLIENT_STATUS, self.client_real_id, receive_id) @@ -149,9 +218,32 @@ def send_client_status(self, receive_id, status=ONLINE_STATUS_FLAG): self.send_message(message) def report_training_status(self, status): + """ + Report the training status to MLOps. + + Args: + status: Training status to be reported. + + Note: + This method logs the training status using MLOps. + + """ mlops.log_training_status(status) def sync_process_group(self, round_idx, model_params=None, client_index=None, src=0): + """ + Synchronize the process group with information about the current training round. + + Args: + round_idx: The current training round index. + model_params: Model parameters (default is None). + client_index: Client index (default is None). + src: Source of the synchronization (default is 0). + + Note: + This method broadcasts information about the current training round to the process group. + + """ logging.info("sending round number to pg") round_number = [round_idx, model_params, client_index] dist.broadcast_object_list( @@ -160,6 +252,13 @@ def sync_process_group(self, round_idx, model_params=None, client_index=None, sr logging.info("round number %d broadcast to process group" % round_number[0]) def __train(self): + """ + Perform the training for the current round. + + Note: + This method initiates the training process and sends the updated model to the server. + + """ logging.info("#######training########### round_id = %d" % self.round_idx) mlops.event("train", event_started=True, event_value=str(self.round_idx)) diff --git a/python/fedml/serving/client/fedml_client_slave_manager.py b/python/fedml/serving/client/fedml_client_slave_manager.py index 48f30d8263..5b817e34fa 100644 --- a/python/fedml/serving/client/fedml_client_slave_manager.py +++ b/python/fedml/serving/client/fedml_client_slave_manager.py @@ -5,6 +5,14 @@ class ClientSlaveManager: def __init__(self, args, trainer_dist_adapter): + """ + Initialize the ClientSlaveManager. + + Args: + args: Command-line arguments. + trainer_dist_adapter: Trainer distributed adapter. + + """ self.trainer_dist_adapter = trainer_dist_adapter self.args = args self.round_idx = 0 @@ -12,6 +20,13 @@ def __init__(self, args, trainer_dist_adapter): self.finished = False def train(self): + """ + Perform training for the current round. + + This method synchronizes with the process group, updates the model and dataset if necessary, and initiates training + for the current round. 
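+
+        Note:
+            The round index, model parameters, and client index are received
+            from the master process via await_sync_process_group().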
+ + """ [round_idx, model_params, client_index] = self.await_sync_process_group() if round_idx: self.round_idx = round_idx @@ -28,7 +43,12 @@ def train(self): self.trainer_dist_adapter.train(self.round_idx) def finish(self): - # pass + """ + Finish the training process. + + This method performs cleanup operations and logs the completion of training. + + """ self.trainer_dist_adapter.cleanup_pg() logging.info( "Training finished for slave client rank %s in silo %s" @@ -37,6 +57,16 @@ def finish(self): self.finished = True def await_sync_process_group(self, src=0): + """ + Wait for synchronization with the process group. + + Args: + src: Source rank for synchronization (default is 0). + + Returns: + List: A list containing round_idx, model_params, and client_index. + + """ logging.info("process %d waiting for round number" % dist.get_rank()) objects = [None, None, None] dist.broadcast_object_list( @@ -46,5 +76,11 @@ def await_sync_process_group(self, src=0): return objects def run(self): + """ + Start the client manager's main execution loop. + + This method continuously trains the client while it is not finished. + + """ while not self.finished: self.train() diff --git a/python/fedml/serving/client/fedml_trainer.py b/python/fedml/serving/client/fedml_trainer.py index 827644cc42..ae6d9e9a7f 100755 --- a/python/fedml/serving/client/fedml_trainer.py +++ b/python/fedml/serving/client/fedml_trainer.py @@ -17,8 +17,21 @@ def __init__( args, model_trainer, ): + """ + Initialize the Federated Learning Trainer. + + Args: + client_index: Index of the client. + train_data_local_dict: Dictionary mapping client IDs to local training datasets. + train_data_local_num_dict: Dictionary mapping client IDs to local training data counts. + test_data_local_dict: Dictionary mapping client IDs to local test datasets. + train_data_num: Number of training data samples. + device: Torch device for training. + args: Command-line arguments. + model_trainer: Trainer for the model. + + """ self.trainer = model_trainer - self.client_index = client_index if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL: @@ -32,15 +45,28 @@ def __init__( self.train_local = None self.local_sample_number = None self.test_local = None - self.device = device self.args = args self.args.device = device def update_model(self, weights): + """ + Update the model with new parameters. + + Args: + weights: Updated model parameters. + + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """ + Update the local dataset for training. + + Args: + client_index: Index of the client to update the dataset for. + + """ self.client_index = client_index if self.train_data_local_dict is not None: @@ -64,6 +90,16 @@ def update_dataset(self, client_index): self.trainer.update_dataset(self.train_local, self.test_local, self.local_sample_number) def train(self, round_idx=None): + """ + Perform federated training for the specified round. + + Args: + round_idx (Optional): Index of the current training round (default is None). + + Returns: + Tuple: A tuple containing the updated model weights and the number of local training samples. + + """ self.args.round_idx = round_idx tick = time.time() @@ -77,6 +113,14 @@ def train(self, round_idx=None): return weights, self.local_sample_number def test(self): + """ + Test the model on local data. + + Returns: + Tuple: A tuple containing training accuracy, training loss, number of training samples, + test accuracy, test loss, and number of test samples. 
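+
+        Note:
+            Both sets of metrics are computed by the underlying model trainer,
+            on the local training and local test splits respectively.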
+ + """ # train data train_metrics = self.trainer.test(self.train_local, self.device, self.args) train_tot_correct, train_num_sample, train_loss = ( diff --git a/python/fedml/serving/client/fedml_trainer_dist_adapter.py b/python/fedml/serving/client/fedml_trainer_dist_adapter.py index 60383d31cf..36db5d3bb3 100644 --- a/python/fedml/serving/client/fedml_trainer_dist_adapter.py +++ b/python/fedml/serving/client/fedml_trainer_dist_adapter.py @@ -19,7 +19,21 @@ def __init__( test_data_local_dict, model_trainer, ): + """ + Initialize the TrainerDistAdapter. + Args: + args: Command-line arguments. + device: Torch device for training. + client_rank: Rank of the client. + model: The neural network model. + train_data_num: Number of training data samples. + train_data_local_num_dict: Dictionary mapping client IDs to local training data counts. + train_data_local_dict: Dictionary mapping client IDs to local training datasets. + test_data_local_dict: Dictionary mapping client IDs to local test datasets. + model_trainer: Trainer for the model. + + """ ml_engine_adapter.model_to_device(args, model, device) if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL: @@ -62,6 +76,23 @@ def get_trainer( args, model_trainer, ): + """ + Create and return a trainer for the federated learning process. + + Args: + client_index: Index of the client. + train_data_local_dict: Dictionary mapping client IDs to local training datasets. + train_data_local_num_dict: Dictionary mapping client IDs to local training data counts. + test_data_local_dict: Dictionary mapping client IDs to local test datasets. + train_data_num: Number of training data samples. + device: Torch device for training. + args: Command-line arguments. + model_trainer: Trainer for the model. + + Returns: + FedMLTrainer: Trainer instance for federated learning. + + """ return FedMLTrainer( client_index, train_data_local_dict, @@ -74,20 +105,50 @@ def get_trainer( ) def train(self, round_idx): + """ + Perform federated training for the specified round. + + Args: + round_idx: Index of the current training round. + + Returns: + Tuple: A tuple containing the updated model weights and the number of local training samples. + + """ weights, local_sample_num = self.trainer.train(round_idx) return weights, local_sample_num def update_model(self, model_params): + """ + Update the model with new parameters. + + Args: + model_params: Updated model parameters. + + """ self.trainer.update_model(model_params) def update_dataset(self, client_index=None): + """ + Update the local dataset for training. + + Args: + client_index (Optional): Index of the client to update the dataset for (default is None, uses client's index). + + """ _client_index = client_index or self.client_index self.trainer.update_dataset(int(_client_index)) def cleanup_pg(self): + """ + Clean up the process group if using distributed training. + + This method is called to clean up the process group when hierarchical federated learning is used. 
+ + """ if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL: logging.info( - "Cleaningup process group for client %s in silo %s" + "Cleaning up process group for client %s in silo %s" % (self.args.proc_rank_in_silo, self.args.rank_in_node) ) self.process_group_manager.cleanup() diff --git a/python/fedml/serving/client/process_group_manager.py b/python/fedml/serving/client/process_group_manager.py index 92519c6cc4..06da2e9738 100644 --- a/python/fedml/serving/client/process_group_manager.py +++ b/python/fedml/serving/client/process_group_manager.py @@ -7,6 +7,17 @@ class ProcessGroupManager: def __init__(self, rank, world_size, master_address, master_port, only_gpu): + """ + Initialize a process group manager for distributed training. + + Args: + rank (int): The rank of the current process. + world_size (int): The total number of processes in the group. + master_address (str): The address of the master process. + master_port (int): The port for communication with the master process. + only_gpu (bool): Whether to use NCCL backend for GPU communication. + + """ logging.info("Start process group") logging.info( "rank: %d, world_size: %d, master_address: %s, master_port: %s" @@ -31,7 +42,18 @@ def __init__(self, rank, world_size, master_address, master_port, only_gpu): logging.info("Initiated") def cleanup(self): + """ + Cleanup the process group. + + """ dist.destroy_process_group() def get_process_group(self): - return self.messaging_pg + """ + Get the messaging process group. + + Returns: + torch.distributed.ProcessGroup: The process group for communication. + + """ + return self.messaging_pg \ No newline at end of file diff --git a/python/fedml/serving/client/utils.py b/python/fedml/serving/client/utils.py index 38f4a169d1..4d8657fe1c 100644 --- a/python/fedml/serving/client/utils.py +++ b/python/fedml/serving/client/utils.py @@ -3,16 +3,42 @@ # ref: https://discuss.pytorch.org/t/failed-to-load-model-trained-by-ddp-for-inference/84841/2?u=amir_zsh def convert_model_params_from_ddp(ddp_model_params): - model_params = OrderedDict() - for k, v in ddp_model_params.items(): - name = k[7:] # remove 'module.' of DataParallel/DistributedDataParallel - model_params[name] = v - return model_params + """ + Convert model parameters from DistributedDataParallel (DDP) format to a regular format. + + Args: + ddp_model_params (OrderedDict): Model parameters in DDP format. + Returns: + OrderedDict: Model parameters in regular format. -def convert_model_params_to_ddp(ddp_model_params): + Example: + >>> ddp_params = OrderedDict([('module.conv1.weight', tensor), ('module.fc1.weight', tensor)]) + >>> regular_params = convert_model_params_from_ddp(ddp_params) + """ model_params = OrderedDict() for k, v in ddp_model_params.items(): - name = f"module.{k}" # add 'module.' of DataParallel/DistributedDataParallel + name = k[7:] # Remove 'module.' of DataParallel/DistributedDataParallel model_params[name] = v return model_params + + +def convert_model_params_to_ddp(model_params): + """ + Convert model parameters from a regular format to DistributedDataParallel (DDP) format. + + Args: + model_params (OrderedDict): Model parameters in regular format. + + Returns: + OrderedDict: Model parameters in DDP format. + + Example: + >>> regular_params = OrderedDict([('conv1.weight', tensor), ('fc1.weight', tensor)]) + >>> ddp_params = convert_model_params_to_ddp(regular_params) + """ + ddp_model_params = OrderedDict() + for k, v in model_params.items(): + name = f"module.{k}" # Add 'module.' 
for DataParallel/DistributedDataParallel + ddp_model_params[name] = v + return ddp_model_params diff --git a/python/fedml/serving/example/llm/src/app/pipe/instruct_pipeline.py b/python/fedml/serving/example/llm/src/app/pipe/instruct_pipeline.py index edcc1a643b..f3ae5a4089 100644 --- a/python/fedml/serving/example/llm/src/app/pipe/instruct_pipeline.py +++ b/python/fedml/serving/example/llm/src/app/pipe/instruct_pipeline.py @@ -2,10 +2,8 @@ Adapted from https://github.com/databrickslabs/dolly/blob/master/training/generate.py """ from typing import List, Optional, Tuple - import logging import re - import torch from transformers import ( AutoModelForCausalLM, @@ -27,13 +25,17 @@ def load_model_tokenizer_for_generate( pretrained_model_name_or_path: str, ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: - """Loads the model and tokenizer so that it can be used for generating responses. + """ + Load the model and tokenizer for generating responses. Args: - pretrained_model_name_or_path (str): name or path for model + pretrained_model_name_or_path (str): Name or path for the pretrained model. Returns: - Tuple[PreTrainedModel, PreTrainedTokenizer]: model and tokenizer + Tuple[PreTrainedModel, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer. + + Example: + model, tokenizer = load_model_tokenizer_for_generate("gpt2") """ tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, padding_side="left") model = AutoModelForCausalLM.from_pretrained( @@ -41,22 +43,25 @@ def load_model_tokenizer_for_generate( ) return model, tokenizer - def get_special_token_id(tokenizer: PreTrainedTokenizer, key: str) -> int: - """Gets the token ID for a given string that has been added to the tokenizer as a special token. + """ + Get the token ID for a given string that has been added to the tokenizer as a special token. - When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are - treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to. + When training, we configure the tokenizer so that sequences like "### Instruction:" and "### End" are + treated specially and converted to a single, new token. This function retrieves the token ID for a given key. Args: - tokenizer (PreTrainedTokenizer): the tokenizer - key (str): the key to convert to a single token + tokenizer (PreTrainedTokenizer): The tokenizer. + key (str): The key to convert to a single token. Raises: - ValueError: if more than one ID was generated + ValueError: If more than one ID was generated for the key. Returns: - int: the token ID for the given key + int: The token ID for the given key. + + Example: + special_token_id = get_special_token_id(tokenizer, "### Instruction:") """ token_ids = tokenizer.encode(key) if len(token_ids) > 1: @@ -64,6 +69,7 @@ def get_special_token_id(tokenizer: PreTrainedTokenizer, key: str) -> int: return token_ids[0] + class InstructionTextGenerationPipeline(Pipeline): def __init__( self, @@ -98,6 +104,18 @@ def _sanitize_parameters( return_full_text: bool = None, **generate_kwargs ): + """ + Sanitize and configure parameters for text generation. + + Args: + return_full_text (bool, optional): Whether to return the full text. Defaults to None. + + Returns: + Tuple[Dict, Dict, Dict]: A tuple containing preprocess_params, forward_params, and postprocess_params. + + Raises: + ValueError: If the response key token is not found. 
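A quick round-trip check of the two converters above (a sketch, assuming convert_model_params_to_ddp and convert_model_params_from_ddp are imported from this module):

from collections import OrderedDict
import torch

params = OrderedDict([("conv1.weight", torch.zeros(3)), ("fc1.weight", torch.zeros(3))])

ddp_params = convert_model_params_to_ddp(params)
assert list(ddp_params) == ["module.conv1.weight", "module.fc1.weight"]

restored = convert_model_params_from_ddp(ddp_params)
assert list(restored) == list(params)  # the "module." prefix round-trips cleanly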
+ """ preprocess_params = {} # newer versions of the tokenizer configure the response key as a special token. newer versions still may @@ -130,6 +148,18 @@ def _sanitize_parameters( return preprocess_params, forward_params, postprocess_params def preprocess(self, instruction_text, **generate_kwargs): + """ + Preprocess the input text for text generation. + + Args: + instruction_text (str): The instruction text. + + Returns: + Dict: Preprocessed inputs for text generation. + + Example: + inputs = preprocess("Write a summary of a book.") + """ prompt_text = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction_text) inputs = self.tokenizer( prompt_text, @@ -140,6 +170,15 @@ def preprocess(self, instruction_text, **generate_kwargs): return inputs def _forward(self, model_inputs, **generate_kwargs): + """ + Forward pass for text generation. + + Args: + model_inputs (Dict): Inputs for text generation. + + Returns: + Dict: Model outputs for text generation. + """ input_ids = model_inputs["input_ids"] attention_mask = model_inputs.get("attention_mask", None) @@ -173,6 +212,21 @@ def postprocess( end_key_token_id: Optional[int] = None, return_full_text: bool = False ): + """ + Postprocess the model outputs for text generation. + + Args: + model_outputs (Dict): Model outputs for text generation. + response_key_token_id (int, optional): Token ID for the response key. Defaults to None. + end_key_token_id (int, optional): Token ID for the end key. Defaults to None. + return_full_text (bool, optional): Whether to return the full text. Defaults to False. + + Returns: + List[Dict]: List of generated text records. + + Example: + generated_text = postprocess(model_outputs) + """ generated_sequence: torch.Tensor = model_outputs["generated_sequence"][0] instruction_text = model_outputs["instruction_text"] @@ -236,6 +290,9 @@ def postprocess( records.append(rec) return records + + + def generate_response( @@ -245,16 +302,21 @@ def generate_response( tokenizer: PreTrainedTokenizer, **kwargs, ) -> str: - """Given an instruction, uses the model and tokenizer to generate a response. This formats the instruction in + """ + Given an instruction, uses the model and tokenizer to generate a response. This formats the instruction in the instruction format that the model was fine-tuned on. Args: - instruction (str): _description_ - model (PreTrainedModel): the model to use - tokenizer (PreTrainedTokenizer): the tokenizer to use + instruction (str): The instruction for text generation. + model (PreTrainedModel): The pretrained model to use for text generation. + tokenizer (PreTrainedTokenizer): The tokenizer associated with the pretrained model. + **kwargs: Additional keyword arguments for text generation. Returns: - str: response + str: The generated response based on the provided instruction. 
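The response and end keys resolve to single token ids only after they are registered as special tokens, which is what the ValueError above guards against. A hedged illustration, assuming the transformers package is installed and the "gpt2" checkpoint is available:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"additional_special_tokens": ["### Instruction:", "### End"]})

# Each registered key now encodes to exactly one id, so get_special_token_id succeeds.
ids = tokenizer.encode("### End")
assert len(ids) == 1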
+
+    Example:
+        response = generate_response("Write a summary of a book.", model=my_model, tokenizer=my_tokenizer)
     """
     generation_pipeline = InstructionTextGenerationPipeline(model=model, tokenizer=tokenizer, **kwargs)
 
diff --git a/python/fedml/serving/example/llm/src/main_entry.py b/python/fedml/serving/example/llm/src/main_entry.py
index e90dd1b2a3..19869aa7fb 100644
--- a/python/fedml/serving/example/llm/src/main_entry.py
+++ b/python/fedml/serving/example/llm/src/main_entry.py
@@ -2,15 +2,37 @@
 from fedml.serving import FedMLPredictor
 from fedml.serving import FedMLInferenceRunner
 
-# DATA_CACHE_DIR is a LOCAL folder that contains the model and config files if
-# you do NOT want to transfer the model and config files to MLOps
-# Not to also metion DATA_CACHE_DIR in the fedml_model_config.yaml
+# Define the local data cache directory for model and config files
 DATA_CACHE_DIR = "~/fedml_serving/model_and_config"
-DATA_CACHE_DIR = os.path.expanduser(DATA_CACHE_DIR) # Use absolute path
+DATA_CACHE_DIR = os.path.expanduser(DATA_CACHE_DIR)  # Use an absolute path
+
+class Chatbot(FedMLPredictor):
+    """
+    A chatbot powered by language models for generating text-based responses.
+
+    This chatbot uses Hugging Face Transformers to generate text-based responses to user inputs.
+
+    Attributes:
+        chatbot (LLMChain): The language model-based chatbot.
+
+    Methods:
+        predict(request: dict) -> dict:
+            Generate a response to a user's input text.
+
+    Example:
+        chatbot = Chatbot()
+        fedml_inference_runner = FedMLInferenceRunner(chatbot)
+        fedml_inference_runner.run()
+    """
 
-class Chatbot(FedMLPredictor): # Inherit FedMLClientPredictor
     def __init__(self):
-        super().__init__() # Will excecute the bootstrap shell script
+        """
+        Initialize the Chatbot with a language model-based chatbot.
+
+        This constructor initializes the chatbot by loading a pre-trained language model
+        and setting up the necessary components for text generation.
+        """
+        super().__init__()  # Executes the bootstrap shell script
         from langchain import PromptTemplate, LLMChain
         from langchain.llms import HuggingFacePipeline
         import torch
@@ -21,7 +43,8 @@ def __init__(self):
             TextGenerationPipeline,
         )
 
-        PROMPT_FOR_GENERATION_FORMAT = f""""Below is an instruction that describes a task. Write a response that appropriately completes the request."
+        PROMPT_FOR_GENERATION_FORMAT = f"""
+        "Below is an instruction that describes a task. Write a response that appropriately completes the request."
 
 ### Instruction:
 {{instruction}}
@@ -37,7 +60,7 @@ def __init__(self):
         config = AutoConfig.from_pretrained(DATA_CACHE_DIR)
         model = AutoModelForCausalLM.from_pretrained(
             DATA_CACHE_DIR,
-            torch_dtype=torch.float32, # float 16 not supported on CPU
+            torch_dtype=torch.float32,  # float 16 not supported on CPU
             trust_remote_code=True,
             device_map="auto"
         )
@@ -56,8 +79,21 @@ def __init__(self):
             )
         )
         self.chatbot = LLMChain(llm=hf_pipeline, prompt=prompt, verbose=True)
-        
-    def predict(self, request:dict):
+
+    def predict(self, request: dict) -> dict:
+        """
+        Generate a response to a user's input text.
+
+        Args:
+            request (dict): A dictionary containing user input text.
+
+        Returns:
+            dict: A dictionary containing the generated text-based response.
+ + Example: + input_request = {"text": "Tell me a joke."} + response = chatbot.predict(input_request) + """ input_dict = request question: str = input_dict.get("text", "").strip() @@ -71,4 +107,4 @@ def predict(self, request:dict): if __name__ == "__main__": chatbot = Chatbot() fedml_inference_runner = FedMLInferenceRunner(chatbot) - fedml_inference_runner.run() \ No newline at end of file + fedml_inference_runner.run() diff --git a/python/fedml/serving/example/mnist/src/mnist_serve_main.py b/python/fedml/serving/example/mnist/src/mnist_serve_main.py index 6367ea487f..8efcbc5ca9 100644 --- a/python/fedml/serving/example/mnist/src/mnist_serve_main.py +++ b/python/fedml/serving/example/mnist/src/mnist_serve_main.py @@ -11,7 +11,24 @@ # DATA_CACHE_DIR = "" class MnistPredictor(FedMLPredictor): + """ + A custom predictor for MNIST digit classification using a logistic regression model. + + This class loads a pretrained logistic regression model and provides a predict method to make predictions + on input data. + + Args: + None + + Example: + predictor = MnistPredictor() + input_data = {"arr": [0.1, 0.2, 0.3, ..., 0.9]} + prediction = predictor.predict(input_data) + """ def __init__(self): + """ + Initialize the MnistPredictor by loading a pretrained logistic regression model. + """ import pickle import torch @@ -27,6 +44,25 @@ def __init__(self): self.list_to_tensor_func = torch.tensor def predict(self, request): + """ + Perform predictions on input data using the pretrained logistic regression model. + + Args: + request (dict): A dictionary containing input data for prediction. + The dictionary should have the following key: + - "arr" (list): A list of float values representing the input features for a MNIST digit image. + + Returns: + torch.Tensor: A tensor representing the model's prediction. + + Example: + predictor = MnistPredictor() + input_data = {"arr": [0.1, 0.2, 0.3, ..., 0.9]} + prediction = predictor.predict(input_data) + + Note: + The input data should be a list of float values with the same dimensionality as the model's input. + """ arr = request["arr"] input_tensor = self.list_to_tensor_func(arr) return self.model(input_tensor) diff --git a/python/fedml/serving/example/mnist/src/model/minist_model.py b/python/fedml/serving/example/mnist/src/model/minist_model.py index 25789d4e1c..1aed515cd6 100644 --- a/python/fedml/serving/example/mnist/src/model/minist_model.py +++ b/python/fedml/serving/example/mnist/src/model/minist_model.py @@ -1,11 +1,54 @@ import torch class LogisticRegression(torch.nn.Module): + """ + Logistic Regression model for binary classification. + + This class defines a logistic regression model with a single linear layer followed by a sigmoid activation function + for binary classification tasks. + + Args: + input_dim (int): The dimensionality of the input features. + output_dim (int): The number of output classes, which should be 1 for binary classification. + + Example: + # Create a logistic regression model for binary classification + input_dim = 10 + output_dim = 1 + model = LogisticRegression(input_dim, output_dim) + + Forward Method: + The forward method computes the output of the model for a given input. + + Example: + # Forward pass with input tensor 'x' + input_tensor = torch.tensor([0.1, 0.2, 0.3, ..., 0.9]) + output = model(input_tensor) + + Note: + - For binary classification, the `output_dim` should be set to 1. + - The `forward` method applies a sigmoid activation to the linear output, producing values in the range [0, 1]. 
+ + """ + def __init__(self, input_dim, output_dim): super(LogisticRegression, self).__init__() self.linear = torch.nn.Linear(input_dim, output_dim) def forward(self, x): - import torch + """ + Forward pass of the logistic regression model. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + torch.Tensor: Output tensor of shape (batch_size, output_dim). + + Example: + # Forward pass with input tensor 'x' + input_tensor = torch.tensor([0.1, 0.2, 0.3, ..., 0.9]) + output = model(input_tensor) + """ outputs = torch.sigmoid(self.linear(x)) return outputs diff --git a/python/fedml/serving/fedml_client.py b/python/fedml/serving/fedml_client.py index 69ca57a923..0c88d1a4ee 100644 --- a/python/fedml/serving/fedml_client.py +++ b/python/fedml/serving/fedml_client.py @@ -3,9 +3,49 @@ class FedMLModelServingClient: + """ + Client for Federated Machine Learning Model Serving. + + This class is responsible for initializing and running the client for federated machine learning model serving. + + Args: + args: An instance of arguments containing configuration settings. + end_point_name: The name of the model serving endpoint. + model_name: The name of the machine learning model. + model_version: The version of the machine learning model. + inference_request: An optional inference request configuration. + device: The device (e.g., 'cuda:0') to run the client on. + dataset: The dataset used for training and testing the model. + model: The machine learning model to be used. + model_trainer: An optional client trainer for model training. + + Attributes: + end_point_name: The name of the model serving endpoint. + model_name: The name of the machine learning model. + model_version: The version of the machine learning model. + inference_request: An optional inference request configuration. + + Methods: + run(): Start the client for federated machine learning model serving. + """ + def __init__(self, args, end_point_name, model_name, model_version, inference_request=None, device=None, dataset=None, model=None, model_trainer: ClientTrainer = None): + """ + Initializes the FedMLModelServingClient. + + Args: + args: An instance of arguments containing configuration settings. + end_point_name: The name of the model serving endpoint. + model_name: The name of the machine learning model. + model_version: The version of the machine learning model. + inference_request: An optional inference request configuration. + device: The device (e.g., 'cuda:0') to run the client on. + dataset: The dataset used for training and testing the model. + model: The machine learning model to be used. + model_trainer: An optional client trainer for model training. + """ self.end_point_name = end_point_name self.model_name = model_name self.model_version = model_version @@ -36,7 +76,15 @@ def __init__(self, args, end_point_name, model_name, model_version, model_trainer, ) else: - raise Exception("Exception") + raise Exception("Unsupported federated optimizer") def run(self): + """ + Start the client for federated machine learning model serving. + + This method initializes and runs the client for federated machine learning model serving. 
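A minimal smoke test for the LogisticRegression module defined above (a sketch; the 784-feature input dimension matches a flattened 28x28 MNIST image):

import torch

model = LogisticRegression(input_dim=784, output_dim=1)
x = torch.rand(4, 784)            # a batch of four flattened 28x28 images
probs = model(x)                  # sigmoid keeps every value in [0, 1]
assert probs.shape == (4, 1)
preds = (probs > 0.5).long()      # threshold at 0.5 for hard labels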
+
+        Returns:
+            None
+        """
        pass
diff --git a/python/fedml/serving/fedml_inference_runner.py b/python/fedml/serving/fedml_inference_runner.py
index 5257bf75e7..d87309f347 100644
--- a/python/fedml/serving/fedml_inference_runner.py
+++ b/python/fedml/serving/fedml_inference_runner.py
@@ -2,22 +2,65 @@
 from fastapi import FastAPI, Request
 
 class FedMLInferenceRunner(ABC):
+    """
+    HTTP inference runner for federated machine learning predictors.
+
+    Wraps a client predictor and serves its `predict` method through a FastAPI app.
+
+    Attributes:
+        client_predictor: An instance of a client predictor class that implements the `predict` method.
+
+    Methods:
+        run(): Start the FastAPI server to handle prediction requests.
+    """
+
     def __init__(self, client_predictor):
+        """
+        Initializes the FedMLInferenceRunner.
+
+        Args:
+            client_predictor: An instance of a client predictor class that implements the `predict` method.
+        """
         self.client_predictor = client_predictor
 
     def run(self):
+        """
+        Start the FastAPI server to handle prediction requests.
+
+        This method creates an HTTP server using FastAPI and defines two routes: '/predict' for making predictions
+        and '/ready' to check the server's readiness.
+
+        Returns:
+            None
+        """
         api = FastAPI()
+
         @api.post("/predict")
         async def predict(request: Request):
+            """
+            Handle POST requests to the '/predict' route for making predictions.
+
+            Args:
+                request: The HTTP request object containing the input data.
+
+            Returns:
+                dict: A JSON response containing the generated text.
+            """
             input_json = await request.json()
             response_text = self.client_predictor.predict(input_json)
-            
+
             return {"generated_text": str(response_text)}
 
         @api.get("/ready")
         async def ready():
+            """
+            Handle GET requests to the '/ready' route to check the server's readiness.
+
+            Returns:
+                dict: A JSON response indicating the server's readiness status.
+            """
             return {"status": "Success"}
 
         import uvicorn
         port = 2345
-        uvicorn.run(api, host="0.0.0.0", port=port)
\ No newline at end of file
+        uvicorn.run(api, host="0.0.0.0", port=port)
diff --git a/python/fedml/serving/fedml_predictor.py b/python/fedml/serving/fedml_predictor.py
index 4d435bbed8..84f26e138a 100644
--- a/python/fedml/serving/fedml_predictor.py
+++ b/python/fedml/serving/fedml_predictor.py
@@ -8,20 +8,60 @@
 from ..computing.scheduler.comm_utils import sys_utils
 
 class FedMLPredictor(ABC):
+    """
+    Abstract base class for federated machine learning predictors.
+
+    Subclasses should implement the `predict` method for making predictions.
+
+    Attributes:
+        None
+
+    Methods:
+        predict(*args, **kwargs): Abstract method for making predictions.
+    """
+
     def __init__(self):
+        """
+        Initializes the FedMLPredictor.
+
+        This constructor can be extended by subclasses as needed.
+        """
         build_dynamic_args()
 
     @abstractmethod
     def predict(self, *args, **kwargs):
+        """
+        Abstract method for making predictions.
+
+        Subclasses should implement this method to define the prediction logic.
+
+        Args:
+            *args: Variable-length arguments.
+            **kwargs: Keyword arguments.
+
+        Returns:
+            None
+        """
        pass
 
 
def build_dynamic_args():
+    """
+    Builds dynamic arguments based on environment variables.
+
+    This function checks for environment variables related to a bootstrap script and executes it if found.
+
+    Args:
+        None
+
+    Returns:
+        bool: True if the bootstrap script runs successfully, False otherwise.
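Once run() is listening on its default port 2345, the two routes can be exercised with any HTTP client. A hedged client-side sketch, assuming the requests package and a runner already serving locally:

import requests

base = "http://127.0.0.1:2345"
assert requests.get(f"{base}/ready").json() == {"status": "Success"}

resp = requests.post(f"{base}/predict", json={"text": "What is federated learning?"})
print(resp.json()["generated_text"])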
+ """ DEFAULT_BOOTSTRAP_FULL_DIR = os.environ.get("BOOTSTRAP_DIR", None) if DEFAULT_BOOTSTRAP_FULL_DIR is None or DEFAULT_BOOTSTRAP_FULL_DIR == "": return print("DEFAULT_BOOTSTRAP_FULL_DIR: {}".format(DEFAULT_BOOTSTRAP_FULL_DIR)) - + DEFAULT_BOOTSTRAP_SCRIPT_DIR = os.path.dirname(DEFAULT_BOOTSTRAP_FULL_DIR) DEFAULT_BOOTSTRAP_SCRIPT_PATH = os.path.dirname(DEFAULT_BOOTSTRAP_FULL_DIR) DEFAULT_BOOTSTRAP_SCRIPT_FILE = os.path.basename(DEFAULT_BOOTSTRAP_FULL_DIR) @@ -47,12 +87,12 @@ def build_dynamic_args(): bootstrap_scripts = "cd {}; sh {}".format(bootstrap_script_dir, # Use sh over ./ to avoid permission denied error os.path.basename(bootstrap_script_file)) bootstrap_scripts = str(bootstrap_scripts).replace('\\', os.sep).replace('/', os.sep) - + process = ClientConstants.exec_console_with_script(bootstrap_scripts, should_capture_stdout=True, should_capture_stderr=True) # ClientConstants.save_bootstrap_process(run_id, process.pid) ret_code, out, err = ClientConstants.get_console_pipe_out_err_results(process) - + if ret_code is None or ret_code <= 0: if out is not None: out_str = sys_utils.decode_our_err_result(out) @@ -75,4 +115,4 @@ def build_dynamic_args(): logging.error("Bootstrap script error: {}".format(traceback.format_exc())) is_bootstrap_run_ok = False - return is_bootstrap_run_ok \ No newline at end of file + return is_bootstrap_run_ok diff --git a/python/fedml/serving/fedml_server.py b/python/fedml/serving/fedml_server.py index 3663755102..1ba0cfc682 100644 --- a/python/fedml/serving/fedml_server.py +++ b/python/fedml/serving/fedml_server.py @@ -2,9 +2,36 @@ class FedMLModelServingServer: + """ + Represents a server for serving federated machine learning models. + + This class initializes and manages the server-side functionality for serving federated models + in a federated learning system. + + Args: + args (object): Configuration arguments for the server. + end_point_name (str): The name of the endpoint for serving the model. + model_name (str): The name of the federated model. + model_version (str): The version of the federated model. + inference_request (object, optional): An inference request object for making predictions. + device (str, optional): The hardware device to use for inference (e.g., 'cpu' or 'cuda'). + dataset (list, optional): A list containing dataset-related information. + model (object, optional): The federated machine learning model. + server_aggregator (ServerAggregator, optional): The server aggregator for model aggregation. + + Methods: + run(): Starts the server and serves the federated model for inference. + + Note: + This class is designed for serving federated models in a federated learning system. + """ + def __init__(self, args, end_point_name, model_name, model_version, inference_request=None, device=None, dataset=None, model=None, server_aggregator: ServerAggregator = None): + """ + Initializes a Federated Model Serving Server instance. + """ self.end_point_name = end_point_name self.model_name = model_name self.model_version = model_version @@ -42,4 +69,7 @@ def __init__(self, args, end_point_name, model_name, model_version, raise Exception("Exception") def run(self): + """ + Starts the server and serves the federated model for inference. 
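A simplified sketch of the bootstrap flow that build_dynamic_args implements; the script path is purely illustrative, and the retry and logging details of the real implementation are omitted:

import os
import subprocess

bootstrap = os.environ.get("BOOTSTRAP_DIR", "")  # e.g. "/tmp/bootstrap.sh" (hypothetical path)
if bootstrap:
    result = subprocess.run(
        ["sh", os.path.basename(bootstrap)],     # `sh` avoids permission errors on non-executable files
        cwd=os.path.dirname(bootstrap),
        capture_output=True,
        text=True,
    )
    is_bootstrap_run_ok = result.returncode == 0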
+ """ pass diff --git a/python/fedml/serving/server/fedml_aggregator.py b/python/fedml/serving/server/fedml_aggregator.py index 08f4ead226..9cb2caa5ec 100644 --- a/python/fedml/serving/server/fedml_aggregator.py +++ b/python/fedml/serving/server/fedml_aggregator.py @@ -11,6 +11,22 @@ class FedMLAggregator(object): + """ + A class for federated machine learning aggregation and related tasks. + + Args: + train_global: Global training data. + test_global: Global testing data. + all_train_data_num: Number of samples in the entire training dataset. + train_data_local_dict: Local training data dictionary. + test_data_local_dict: Local testing data dictionary. + train_data_local_num_dict: Number of local samples for each client. + client_num: Number of clients. + device: Device to run computations (e.g., 'cuda' or 'cpu'). + args: Additional configuration arguments. + server_aggregator: Aggregator for server-side operations. + """ + def __init__( self, train_global, @@ -49,15 +65,35 @@ def __init__( self.flag_client_model_uploaded_dict[idx] = False def get_global_model_params(self): + """ + Get the global model parameters. + + Returns: + dict: Global model parameters. + """ return self.aggregator.get_model_params() def set_global_model_params(self, model_parameters): + """ + Set the global model parameters. + + Args: + model_parameters (dict): Global model parameters. + """ self.aggregator.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, sample_num): + """ + Add locally trained model parameters for aggregation. + + Args: + index (int): Index of the client. + model_params (dict): Local model parameters. + sample_num (int): Number of local samples used for training. + """ logging.info("add_model. index = %d" % index) - # for dictionary model_params, we let the user level code to control the device + # For dictionary model_params, let the user-level code control the device if type(model_params) is not dict: model_params = ml_engine_adapter.model_params_to_device(self.args, model_params, self.device) @@ -66,6 +102,12 @@ def add_local_trained_result(self, index, model_params, sample_num): self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all clients have uploaded their models. + + Returns: + bool: True if all clients have uploaded their models, False otherwise. + """ logging.debug("client_num = {}".format(self.client_num)) for idx in range(self.client_num): if not self.flag_client_model_uploaded_dict[idx]: @@ -75,20 +117,29 @@ def check_whether_all_receive(self): return True def aggregate(self): + """ + Aggregate local models from clients to obtain a global model. + + Returns: + tuple: A tuple containing: + - dict: Averaged global model parameters. + - list: List of model tuples before aggregation. + - list: List of indices corresponding to selected models for aggregation. 
+ """ start_time = time.time() model_list = [] for idx in range(self.client_num): model_list.append((self.sample_num_dict[idx], self.model_dict[idx])) - # model_list is the list after outlier removal + # Model list is the list after outlier removal model_list, model_list_idxes = self.aggregator.on_before_aggregation(model_list) Context().add(Context.KEY_CLIENT_MODEL_LIST, model_list) averaged_params = self.aggregator.aggregate(model_list) if type(averaged_params) is dict: - if len(averaged_params) == self.client_num + 1: # aggregator pass extra {-1 : global_parms_dict} as global_params - itr_count = len(averaged_params) - 1 # do not apply on_after_aggregation to client -1 + if len(averaged_params) == self.client_num + 1: # Aggregator passes extra {-1: global_parms_dict} as global_params + itr_count = len(averaged_params) - 1 # Do not apply on_after_aggregation to client -1 else: itr_count = len(averaged_params) @@ -104,23 +155,24 @@ def aggregate(self): return averaged_params, model_list, model_list_idxes def assess_contribution(self): + """ + Assess the contribution of clients to the global model. + """ if hasattr(self.args, "enable_contribution") and \ self.args.enable_contribution is not None and self.args.enable_contribution: self.aggregator.assess_contribution() def data_silo_selection(self, round_idx, client_num_in_total, client_num_per_round): """ + Select a subset of data silos (clients) for a federated learning round. Args: - round_idx: round index, starting from 0 - client_num_in_total: this is equal to the users in a synthetic data, - e.g., in synthetic_1_1, this value is 30 - client_num_per_round: the number of edge devices that can train + round_idx (int): Round index, starting from 0. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to select for the current round. Returns: - data_silo_index_list: e.g., when client_num_in_total = 30, client_num_in_total = 3, - this value is the form of [0, 11, 20] - + list: List of selected data silo indices. """ logging.info( "client_num_in_total = %d, client_num_per_round = %d" % (client_num_in_total, client_num_per_round) @@ -130,39 +182,59 @@ def data_silo_selection(self, round_idx, client_num_in_total, client_num_per_rou if client_num_in_total == client_num_per_round: return [i for i in range(client_num_per_round)] else: - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) # Make sure for each comparison, we are selecting the same clients each round data_silo_index_list = np.random.choice(range(client_num_in_total), client_num_per_round, replace=False) return data_silo_index_list def client_selection(self, round_idx, client_id_list_in_total, client_num_per_round): """ + Select a subset of clients for a federated learning round. + Args: - round_idx: round index, starting from 0 - client_id_list_in_total: this is the real edge IDs. - In MLOps, its element is real edge ID, e.g., [64, 65, 66, 67]; - in simulated mode, its element is client index starting from 1, e.g., [1, 2, 3, 4] - client_num_per_round: + round_idx (int): Round index, starting from 0. + client_id_list_in_total (list): List of real edge IDs or client indices. + client_num_per_round (int): Number of clients to select for the current round. Returns: - client_id_list_in_this_round: sampled real edge ID list, e.g., [64, 66] + list: List of selected client IDs or indices. 
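The sample-weighted averaging that a FedAvg-style aggregator applies to model_list can be summarized in a few lines. This is a sketch of the general technique, not the exact FedML implementation:

import torch

def fedavg(model_list):
    # model_list: [(sample_num, state_dict), ...], as assembled in aggregate() above.
    total = sum(num for num, _ in model_list)
    keys = model_list[0][1].keys()
    return {k: sum(params[k] * (num / total) for num, params in model_list) for k in keys}

clients = [(10, {"w": torch.ones(2)}), (30, {"w": torch.zeros(2)})]
print(fedavg(clients)["w"])  # tensor([0.2500, 0.2500])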
""" if client_num_per_round == len(client_id_list_in_total): return client_id_list_in_total - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) # Make sure for each comparison, we are selecting the same clients each round client_id_list_in_this_round = np.random.choice(client_id_list_in_total, client_num_per_round, replace=False) return client_id_list_in_this_round def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Sample a subset of clients for a federated learning round. + + Args: + round_idx (int): Round index, starting from 0. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to sample for the current round. + + Returns: + list: List of sampled client indices. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: num_clients = min(client_num_per_round, client_num_in_total) - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) # Make sure for each comparison, we are selecting the same clients each round client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False) logging.info("client_indexes = %s" % str(client_indexes)) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set for model evaluation. + + Args: + num_samples (int, optional): Number of samples to include in the validation set. Defaults to 10000. + + Returns: + DataLoader: DataLoader containing the validation set. + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) @@ -173,6 +245,12 @@ def _generate_validation_set(self, num_samples=10000): return self.test_global def test_on_server_for_all_clients(self, round_idx): + """ + Perform model testing on the server for all clients in the current round. + + Args: + round_idx (int): Round index. + """ if round_idx % self.args.frequency_of_the_test == 0 or round_idx == self.args.comm_round - 1: logging.info("################test_on_server_for_all_clients : {}".format(round_idx)) self.aggregator.test_all( @@ -183,7 +261,7 @@ def test_on_server_for_all_clients(self, round_idx): ) if round_idx == self.args.comm_round - 1: - # we allow to return four metrics, such as accuracy, AUC, loss, etc. + # Allow returning multiple metrics (e.g., accuracy, AUC, loss, etc.) in the final round metric_result_in_current_round = self.aggregator.test(self.test_global, self.device, self.args) else: metric_result_in_current_round = self.aggregator.test(self.val_global, self.device, self.args) @@ -200,25 +278,39 @@ def test_on_server_for_all_clients(self, round_idx): mlops.log({"round_idx": round_idx}) def get_dummy_input_tensor(self): + """ + Get a dummy input tensor from the test data. + + Returns: + list: List of dummy input tensors. 
+ """ test_data = None if self.test_global: test_data = self.test_global - else: # if test_global is None, then we use the first non-empty test_data_local_dict + else: # If test_global is None, use the first non-empty test_data_local_dict for k, v in self.test_data_local_dict.items(): if v: test_data = v break with torch.no_grad(): - batch_idx, features_label_tensors = next(enumerate(test_data)) # test_data -> dataloader obj + batch_idx, features_label_tensors = next(enumerate(test_data)) # test_data -> DataLoader object dummy_list = [] for tensor in features_label_tensors: - dummy_tensor = tensor[:1] # only take the first element as dummy input + dummy_tensor = tensor[:1] # Only take the first element as dummy input dummy_list.append(dummy_tensor) - features = dummy_list[:-1] # Can adapt Process Multi-Label + features = dummy_list[:-1] # Can adapt to process multi-label data return features def get_input_shape_type(self): + """ + Get the shapes and types of input features in the test data. + + Returns: + tuple: A tuple containing: + - list: List of input feature shapes. + - list: List of input feature types ('int' or 'float'). + """ test_data = None if self.test_global: test_data = self.test_global @@ -248,10 +340,24 @@ def get_input_shape_type(self): return input_shape, input_type + def save_dummy_input_tensor(self): + """ + Save the dummy input tensor information to a file. + + This function saves the input shape and type information to a file named 'dummy_input_tensor.pkl'. + The saved file can be used for reference or documentation purposes. + + Note: To save the file to a specific location (e.g., S3), additional implementation is required. + + Example: + To save to a specific location (e.g., S3), you can modify this function to upload the file accordingly. + + """ import pickle - features = self.get_input_size_type() + features = self.get_input_shape_type() with open('dummy_input_tensor.pkl', 'wb') as handle: pickle.dump(features, handle) - # TODO: save the dummy_input_tensor.pkl to s3, and transfer when click "Create Model Card" + # TODO: Save the 'dummy_input_tensor.pkl' to S3 or another desired location, and transfer it when needed. + \ No newline at end of file diff --git a/python/fedml/serving/server/fedml_server_manager.py b/python/fedml/serving/server/fedml_server_manager.py index 9e871c6ff6..93eb32380a 100644 --- a/python/fedml/serving/server/fedml_server_manager.py +++ b/python/fedml/serving/server/fedml_server_manager.py @@ -13,11 +13,26 @@ class FedMLServerManager(FedMLCommManager): + """ + Manages the server-side operations for federated machine learning. + + This class handles communication with clients, aggregation of model updates, + and the overall server-side federated learning process. + + Args: + args: Configuration arguments for the server. + aggregator: Aggregator for model updates from clients. + comm: Communication backend (e.g., MQTT, S3). + client_rank: Rank of the client. + client_num: Total number of clients. + backend: Communication backend (default is "MQTT_S3"). + """ + ONLINE_STATUS_FLAG = "ONLINE" RUN_FINISHED_STATUS_FLAG = "FINISHED" def __init__( - self, args, aggregator, comm=None, client_rank=0, client_num=0, backend="MQTT_S3", + self, args, aggregator, comm=None, client_rank=0, client_num=0, backend="MQTT_S3", ): super().__init__(args, comm, client_rank, client_num, backend) self.args = args @@ -35,9 +50,15 @@ def __init__( self.data_silo_index_list = None def run(self): + """ + Start the federated server manager. 
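The dummy-input extraction above can be reproduced against any DataLoader. A self-contained sketch with synthetic data shaped like the logged example (torch.Size([1, 24])):

import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.rand(8, 24), torch.randint(0, 2, (8,))), batch_size=4)

with torch.no_grad():
    _, tensors = next(enumerate(loader))
    dummy = [t[:1] for t in tensors]   # first element of each batched tensor
    features = dummy[:-1]              # drop the trailing label tensor, as above
    shapes = [t.shape for t in features]
    types = ["float" if t.dtype.is_floating_point else "int" for t in features]

print(shapes, types)  # [torch.Size([1, 24])] ['float']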
+ """ super().run() def send_init_msg(self): + """ + Send initialization messages to clients to start the training process. + """ global_model_params = self.aggregator.get_global_model_params() global_model_url = None @@ -54,26 +75,29 @@ def send_init_msg(self): mlops.event("server.wait", event_started=True, event_value=str(self.args.round_idx)) try: - # get input type and shape for inference + # Get input type and shape for inference dummy_input_tensor = self.aggregator.get_dummy_input_tensor() logging.info(f"dummy tensor: {dummy_input_tensor}") # sample tensor for ONNX if not getattr(self.args, "skip_log_model_net", False): model_net_url = mlops.log_training_model_net_info(self.aggregator.aggregator.model, dummy_input_tensor) - # type and shape for later configuration + # Type and shape for later configuration input_shape, input_type = self.aggregator.get_input_shape_type() logging.info(f"input shape: {input_shape}") # [torch.Size([1, 24]), torch.Size([1, 2])] logging.info(f"input type: {input_type}") # [torch.int64, torch.float32] - # Send output input size and type (saved as json) to s3, - # and transfer when click "Create Model Card" + # Send output input size and type (saved as json) to S3, + # and transfer when clicking "Create Model Card" model_input_url = mlops.log_training_model_input_info(list(input_shape), list(input_type)) except Exception as e: logging.info("exception when processing model net and model input info: {}".format( traceback.format_exc())) def register_message_receive_handlers(self): + """ + Register message handlers for different message types. + """ logging.info("register_message_receive_handlers------") self.register_message_receive_handler( MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_message_connection_ready @@ -88,17 +112,33 @@ def register_message_receive_handlers(self): ) def handle_message_connection_ready(self, msg_params): + """ + Handles the connection readiness message from clients and initiates the federated learning process. + + This method is called when the server receives a message indicating that clients are ready to connect. + It selects the clients for the current round and checks their status. + + Args: + msg_params: Parameters of the received message. + + Returns: + None + """ if not self.is_initialized: self.client_id_list_in_this_round = self.aggregator.client_selection( - self.args.round_idx, self.client_real_ids, self.args.client_num_per_round + self.args.round_idx, + self.client_real_ids, + self.args.client_num_per_round ) self.data_silo_index_list = self.aggregator.data_silo_selection( - self.args.round_idx, self.args.client_num_in_total, len(self.client_id_list_in_this_round), + self.args.round_idx, + self.args.client_num_in_total, + len(self.client_id_list_in_this_round), ) mlops.log_round_info(self.round_num, -1) - # check client status in case that some clients start earlier than the server + # Check client status in case that some clients start earlier than the server client_idx_in_this_round = 0 for client_id in self.client_id_list_in_this_round: try: @@ -111,6 +151,19 @@ def handle_message_connection_ready(self, msg_params): client_idx_in_this_round += 1 def process_online_status(self, client_status, msg_params): + """ + Processes online status messages from clients. + + This method is called when the server receives an online status message from a client. + It updates the client online mapping and checks if all clients are online. + + Args: + client_status: The client's online status. 
+ msg_params: Parameters of the received message. + + Returns: + None + """ self.client_online_mapping[str(msg_params.get_sender_id())] = True logging.info("self.client_online_mapping = {}".format(self.client_online_mapping)) @@ -128,11 +181,24 @@ def process_online_status(self, client_status, msg_params): if all_client_is_online: mlops.log_aggregation_status(MyMessage.MSG_MLOPS_SERVER_STATUS_RUNNING) - # send initialization message to all clients to start training + # Send initialization message to all clients to start training self.send_init_msg() self.is_initialized = True def process_finished_status(self, client_status, msg_params): + """ + Processes finished status messages from clients. + + This method is called when the server receives a finished status message from a client. + It updates the client finished mapping and checks if all clients have finished. + + Args: + client_status: The client's finished status. + msg_params: Parameters of the received message. + + Returns: + None + """ self.client_finished_mapping[str(msg_params.get_sender_id())] = True all_client_is_finished = True @@ -151,6 +217,18 @@ def process_finished_status(self, client_status, msg_params): self.finish() def handle_message_client_status_update(self, msg_params): + """ + Handles client status update messages. + + This method is called when the server receives a client status update message. + It processes the received client status and takes appropriate actions. + + Args: + msg_params: Parameters of the received message. + + Returns: + None + """ client_status = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_STATUS) logging.info(f"received client status {client_status}") if client_status == FedMLServerManager.ONLINE_STATUS_FLAG: @@ -159,6 +237,18 @@ def handle_message_client_status_update(self, msg_params): self.process_finished_status(client_status, msg_params) def handle_message_receive_model_from_client(self, msg_params): + """ + Handles messages containing trained models received from clients. + + This method is called when the server receives a message containing a trained model from a client. + It processes the received model, performs aggregation, and sends updated models to clients. + + Args: + msg_params: Parameters of the received message. 
+ + Returns: + None + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) mlops.event("comm_c2s", event_started=False, event_value=str(self.args.round_idx), event_edge_id=sender_id) @@ -191,7 +281,7 @@ def handle_message_receive_model_from_client(self, msg_params): mlops.event("server.agg_and_eval", event_started=False, event_value=str(self.args.round_idx)) - # send round info to the MQTT backend + # Send round info to the MQTT backend mlops.log_round_info(self.round_num, self.args.round_idx) self.client_id_list_in_this_round = self.aggregator.client_selection( @@ -211,7 +301,7 @@ def handle_message_receive_model_from_client(self, msg_params): for receiver_id in self.client_id_list_in_this_round: client_index = self.data_silo_index_list[client_idx_in_this_round] if type(global_model_params) is dict: - # compatible with the old version that, user did not give {-1 : global_parms_dict} + # Compatible with the old version that user did not give {-1 : global_parms_dict} global_model_url, global_model_key = self.send_message_diff_sync_model_to_client( receiver_id, global_model_params[client_index], client_index ) @@ -221,7 +311,7 @@ def handle_message_receive_model_from_client(self, msg_params): ) client_idx_in_this_round += 1 - # if user give {-1 : global_parms_dict}, then record global_model url separately + # If the user gives {-1 : global_parms_dict}, then record global_model url separately if type(global_model_params) is dict and (-1 in global_model_params.keys()): global_model_url, global_model_key = self.send_message_diff_sync_model_to_client( -1, global_model_params[-1], -1 @@ -230,13 +320,21 @@ def handle_message_receive_model_from_client(self, msg_params): self.args.round_idx += 1 mlops.log_aggregated_model_info( self.args.round_idx, model_url=global_model_url, - ) + ) logging.info("\n\n==========end {}-th round training===========\n".format(self.args.round_idx)) if self.args.round_idx < self.round_num: mlops.event("server.wait", event_started=True, event_value=str(self.args.round_idx)) def cleanup(self): + """ + Cleans up after a round of federated learning. + + This method is called to clean up resources and send finish messages to clients. + + Returns: + None + """ client_idx_in_this_round = 0 for client_id in self.client_id_list_in_this_round: self.send_message_finish( @@ -245,7 +343,23 @@ def cleanup(self): client_idx_in_this_round += 1 def send_message_init_config(self, receive_id, global_model_params, datasilo_index, - global_model_url=None, global_model_key=None): + global_model_url=None, global_model_key=None): + """ + Sends initialization configuration message to a client. + + This method constructs and sends an initialization configuration message to a specified client. + + Args: + receive_id: The ID of the client to receive the message. + global_model_params: Global model parameters to be sent. + datasilo_index: Index of the data silo associated with the client. + global_model_url: URL of the global model (optional). + global_model_key: Key of the global model (optional). + + Returns: + global_model_url: URL of the global model (if provided). + global_model_key: Key of the global model (if provided). 
+ """ tick = time.time() message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id) if global_model_url is not None: @@ -262,11 +376,35 @@ def send_message_init_config(self, receive_id, global_model_params, datasilo_ind return global_model_url, global_model_key def send_message_check_client_status(self, receive_id, datasilo_index): + """ + Sends a message to check the status of a client. + + This method constructs and sends a message to check the status of a specified client. + + Args: + receive_id: The ID of the client to receive the message. + datasilo_index: Index of the data silo associated with the client. + + Returns: + None + """ message = Message(MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, self.get_sender_id(), receive_id) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) def send_message_finish(self, receive_id, datasilo_index): + """ + Sends a finish message to a client. + + This method constructs and sends a finish message to a specified client. + + Args: + receive_id: The ID of the client to receive the message. + datasilo_index: Index of the data silo associated with the client. + + Returns: + None + """ message = Message(MyMessage.MSG_TYPE_S2C_FINISH, self.get_sender_id(), receive_id) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) @@ -275,7 +413,23 @@ def send_message_finish(self, receive_id, datasilo_index): logging.info(" ====================send cleanup message to {}====================".format(str(datasilo_index))) def send_message_sync_model_to_client(self, receive_id, global_model_params, client_index, - global_model_url=None, global_model_key=None): + global_model_url=None, global_model_key=None): + """ + Sends a synchronized model to a client. + + This method constructs and sends a message containing synchronized model parameters to a specified client. + + Args: + receive_id: The ID of the client to receive the message. + global_model_params: The synchronized global model parameters to be sent. + client_index: Index of the client associated with the model. + global_model_url: URL for the global model parameters (optional). + global_model_key: Key for the global model parameters (optional). + + Returns: + global_model_url: URL for the global model parameters. + global_model_key: Key for the global model parameters. + """ tick = time.time() logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id) message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id, ) @@ -296,6 +450,20 @@ def send_message_sync_model_to_client(self, receive_id, global_model_params, cli return global_model_url, global_model_key def send_message_diff_sync_model_to_client(self, receive_id, client_model_params, client_index): + """ + Sends a differentiated synchronized model to a client. + + This method constructs and sends a message containing differentiated synchronized model parameters to a specified client. + + Args: + receive_id: The ID of the client to receive the message. + client_model_params: The differentiated synchronized model parameters to be sent. + client_index: Index of the client associated with the model. + + Returns: + global_model_url: URL for the global model parameters. + global_model_key: Key for the global model parameters. + """ tick = time.time() logging.info("send_message_sync_model_to_client. 
receive_id = %d" % receive_id) message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id, ) @@ -309,4 +477,4 @@ def send_message_diff_sync_model_to_client(self, receive_id, client_model_params global_model_url = message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL) global_model_key = message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY) - return global_model_url, global_model_key \ No newline at end of file + return global_model_url, global_model_key diff --git a/python/fedml/serving/server/message_define.py b/python/fedml/serving/server/message_define.py index 1c1db66741..e10b68f760 100644 --- a/python/fedml/serving/server/message_define.py +++ b/python/fedml/serving/server/message_define.py @@ -1,30 +1,29 @@ class MyMessage(object): """ - message type definition + Defines message types and their associated constants for communication between server and clients. """ - # connection info + # Connection Info MSG_TYPE_CONNECTION_IS_READY = 0 - # server to client + # Server to Client Messages MSG_TYPE_S2C_INIT_CONFIG = 1 MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT = 2 MSG_TYPE_S2C_CHECK_CLIENT_STATUS = 6 MSG_TYPE_S2C_FINISH = 7 - # client to server + # Client to Server Messages MSG_TYPE_C2S_SEND_MODEL_TO_SERVER = 3 MSG_TYPE_C2S_SEND_STATS_TO_SERVER = 4 MSG_TYPE_C2S_CLIENT_STATUS = 5 MSG_TYPE_C2S_FINISHED = 8 + # Message Argument Keys MSG_ARG_KEY_TYPE = "msg_type" MSG_ARG_KEY_SENDER = "sender" MSG_ARG_KEY_RECEIVER = "receiver" - """ - message payload keywords definition - """ + # Message Payload Keywords MSG_ARG_KEY_NUM_SAMPLES = "num_samples" MSG_ARG_KEY_MODEL_PARAMS = "model_params" MSG_ARG_KEY_MODEL_PARAMS_URL = "model_params_url" @@ -41,14 +40,12 @@ class MyMessage(object): MSG_ARG_KEY_CLIENT_STATUS = "client_status" MSG_ARG_KEY_CLIENT_OS = "client_os" - + MSG_ARG_KEY_EVENT_NAME = "event_name" MSG_ARG_KEY_EVENT_VALUE = "event_value" MSG_ARG_KEY_EVENT_MSG = "event_msg" - """ - MLOps related message - """ + # MLOps Related Messages # Client Status MSG_MLOPS_CLIENT_STATUS_IDLE = "IDLE" MSG_MLOPS_CLIENT_STATUS_UPGRADING = "UPGRADING" diff --git a/python/fedml/serving/server/server_initializer.py b/python/fedml/serving/server/server_initializer.py index 5877d96fea..941e486fd8 100644 --- a/python/fedml/serving/server/server_initializer.py +++ b/python/fedml/serving/server/server_initializer.py @@ -16,13 +16,31 @@ def init_server( train_data_local_dict, test_data_local_dict, train_data_local_num_dict, - server_aggregator, + server_aggregator=None, ): + """ + Initialize and start the server for federated machine learning. + + Args: + args: Configuration arguments for the server. + device: The device (e.g., GPU) to be used for computation. + comm: Communication module for distributed computing. + rank: The rank of the server in the communication group. + worker_num: The number of worker nodes in the federated setup. + model: The machine learning model to be used. + train_data_num: The number of training data samples. + train_data_global: The global training dataset. + test_data_global: The global test dataset. + train_data_local_dict: Dictionary of local training datasets for workers. + test_data_local_dict: Dictionary of local test datasets for workers. + train_data_local_num_dict: Dictionary of the number of local training samples for workers. + server_aggregator: The aggregator responsible for aggregating model updates (default: None). 
+ """ if server_aggregator is None: server_aggregator = create_server_aggregator(model, args) server_aggregator.set_id(0) - # aggregator + # Create the aggregator aggregator = FedMLAggregator( train_data_global, test_data_global, @@ -36,7 +54,7 @@ def init_server( server_aggregator, ) - # start the distributed training + # Start the distributed training backend = args.backend server_manager = FedMLServerManager(args, aggregator, comm, rank, worker_num, backend) server_manager.run() From d817a675e3380bf42dcecd0aa4e4884811865a01 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Sat, 9 Sep 2023 23:21:48 +0530 Subject: [PATCH 20/70] j --- python/fedml/model/cv/resnet.py | 123 ++++++++++++++++-- .../serving/server/fedml_server_manager.py | 1 + 2 files changed, 111 insertions(+), 13 deletions(-) diff --git a/python/fedml/model/cv/resnet.py b/python/fedml/model/cv/resnet.py index 17cf6a622c..d833e3b762 100644 --- a/python/fedml/model/cv/resnet.py +++ b/python/fedml/model/cv/resnet.py @@ -17,7 +17,18 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): - """3x3 convolution with padding""" + """3x3 convolution with padding. + + Args: + in_planes (int): Number of input channels. + out_planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + dilation (int, optional): Dilation factor for convolution. Default is 1. + + Returns: + nn.Conv2d: 3x3 convolutional layer. + """ return nn.Conv2d( in_planes, out_planes, @@ -31,11 +42,22 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): def conv1x1(in_planes, out_planes, stride=1): - """1x1 convolution""" + """1x1 convolution. + + Args: + in_planes (int): Number of input channels. + out_planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + + Returns: + nn.Conv2d: 1x1 convolutional layer. + """ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class BasicBlock(nn.Module): + """Basic residual block used in ResNet architectures.""" + expansion = 1 def __init__( @@ -49,6 +71,18 @@ def __init__( dilation=1, norm_layer=None, ): + """Initialize a BasicBlock instance. + + Args: + inplanes (int): Number of input channels. + planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + downsample (nn.Module, optional): Downsample layer for shortcut connections. Default is None. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + base_width (int, optional): Base width for the convolution. Default is 64. + dilation (int, optional): Dilation factor for convolution. Default is 1. + norm_layer (nn.Module, optional): Normalization layer. Default is None. 
+ """ super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d @@ -56,16 +90,27 @@ def __init__( raise ValueError("BasicBlock only supports groups=1 and base_width=64") if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") - # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + + # Define the convolutional layers self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) self.bn2 = norm_layer(planes) + + # Downsample layer for shortcut connections self.downsample = downsample self.stride = stride def forward(self, x): + """Forward pass through the BasicBlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ identity = x out = self.conv1(x) @@ -98,22 +143,47 @@ def __init__( dilation=1, norm_layer=None, ): + """Initialize a Bottleneck block used in ResNet architectures. + + Args: + inplanes (int): Number of input channels. + planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + downsample (nn.Module, optional): Downsample layer for shortcut connections. Default is None. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + base_width (int, optional): Base width for the convolution. Default is 64. + dilation (int, optional): Dilation factor for convolution. Default is 1. + norm_layer (nn.Module, optional): Normalization layer. Default is None. + """ super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d width = int(planes * (base_width / 64.0)) * groups - # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + + # Define the three convolutional layers self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) self.conv2 = conv3x3(width, width, stride, groups, dilation) self.bn2 = norm_layer(width) self.conv3 = conv1x1(width, planes * self.expansion) self.bn3 = norm_layer(planes * self.expansion) + + # ReLU activation function self.relu = nn.ReLU(inplace=True) + + # Downsample layer for shortcut connections self.downsample = downsample self.stride = stride def forward(self, x): + """Forward pass through the Bottleneck block. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ identity = x out = self.conv1(x) @@ -135,7 +205,6 @@ def forward(self, x): return out - class ResNet(nn.Module): def __init__( self, @@ -149,6 +218,19 @@ def __init__( norm_layer=None, KD=False, ): + """Initialize a ResNet model. + + Args: + block (nn.Module): The residual block type, either BasicBlock or Bottleneck. + layers (list): List of integers indicating the number of blocks in each layer. + num_classes (int, optional): Number of output classes. Default is 10. + zero_init_residual (bool, optional): If True, zero-initialize the last BN in each residual branch. Default is False. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + width_per_group (int, optional): Base width for the convolution. Default is 64. + replace_stride_with_dilation (tuple, optional): Replace stride with dilation in certain stages. Default is None. + norm_layer (nn.Module, optional): Normalization layer. Default is None. + KD (bool, optional): Whether to perform Knowledge Distillation. Default is False. 
+ """ super(ResNet, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d @@ -173,7 +255,6 @@ def __init__( ) self.bn1 = nn.BatchNorm2d(self.inplanes) self.relu = nn.ReLU(inplace=True) - # self.maxpool = nn.MaxPool2d() self.layer1 = self._make_layer(block, 16, layers[0]) self.layer2 = self._make_layer(block, 32, layers[1], stride=2) self.layer3 = self._make_layer(block, 64, layers[2], stride=2) @@ -197,6 +278,19 @@ def __init__( nn.init.constant_(m.bn2.weight, 0) def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + """ + Create a layer in the ResNet model. + + Args: + block (nn.Module): The residual block type, either BasicBlock or Bottleneck. + planes (int): Number of output channels for the layer. + blocks (int): Number of residual blocks in the layer. + stride (int, optional): The stride for the convolutional layers. Default is 1. + dilate (bool, optional): Whether to apply dilation to the convolutional layers. Default is False. + + Returns: + nn.Sequential: A sequential layer containing the specified number of residual blocks. + """ norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation @@ -238,6 +332,14 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False): return nn.Sequential(*layers) def forward(self, x): + """Forward pass through the ResNet model. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ x = self.conv1(x) x = self.bn1(x) x = self.relu(x) # B x 16 x 32 x 32 @@ -245,13 +347,8 @@ def forward(self, x): x = self.layer2(x) # B x 32 x 16 x 16 x = self.layer3(x) # B x 64 x 8 x 8 - x = self.avgpool(x) # B x 64 x 1 x 1 - x_f = x.view(x.size(0), -1) # B x 64 - x = self.fc(x_f) # B x num_classes - if self.KD == True: - return x_f, x - else: - return x + x = self.avgpool(x) # B + def resnet20(class_num, pretrained=False, path=None, **kwargs): diff --git a/python/fedml/serving/server/fedml_server_manager.py b/python/fedml/serving/server/fedml_server_manager.py index 93eb32380a..b2fb7cc9c0 100644 --- a/python/fedml/serving/server/fedml_server_manager.py +++ b/python/fedml/serving/server/fedml_server_manager.py @@ -249,6 +249,7 @@ def handle_message_receive_model_from_client(self, msg_params): Returns: None """ + sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) mlops.event("comm_c2s", event_started=False, event_value=str(self.args.round_idx), event_edge_id=sender_id) From e09341cb4fd364ab63e176e2c89ed9da008d89f1 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Sun, 10 Sep 2023 11:06:07 +0530 Subject: [PATCH 21/70] qdd ds --- python/fedml/model/cv/common.py | 338 +++++++++++++++++++++++++- python/fedml/model/cv/resnet.py | 16 +- python/fedml/model/cv/resnet_torch.py | 165 ++++++++++++- 3 files changed, 498 insertions(+), 21 deletions(-) diff --git a/python/fedml/model/cv/common.py b/python/fedml/model/cv/common.py index bcd3e452ff..267bb4494d 100644 --- a/python/fedml/model/cv/common.py +++ b/python/fedml/model/cv/common.py @@ -37,60 +37,190 @@ def round_channels(channels, ------- int Weighted number of channels. 
+ + Examples: + -------- + >>> channels = 64 + >>> rounded_channels = round_channels(channels) + >>> print(rounded_channels) + 64 + + >>> channels = 57 + >>> rounded_channels = round_channels(channels) + >>> print(rounded_channels) + 56 """ rounded_channels = max(int(channels + divisor / 2.0) // divisor * divisor, divisor) if float(rounded_channels) < 0.9 * channels: rounded_channels += divisor - return rounded_channels + return rounded_channel class Identity(nn.Module): """ Identity block. + + This block represents the identity function, which means it does not perform any + operations on the input and simply returns it unchanged. It is commonly used in + residual neural networks (ResNets) to create skip connections. + + Attributes: + None + + Methods: + forward(x): Performs a forward pass of the identity block. + __repr__(): Returns a string representation of the Identity block. + + Examples: + >>> identity_block = Identity() + >>> x = torch.randn(1, 64, 32, 32) + >>> output = identity_block(x) + >>> assert torch.allclose(x, output) # The output should be the same as the input. + """ def __init__(self): super(Identity, self).__init__() def forward(self, x): + """ + Forward pass of the identity block. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The input tensor unchanged. + """ return x def __repr__(self): + """ + String representation of the Identity block. + + Returns: + str: A string representing the Identity block. + """ return '{name}()'.format(name=self.__class__.__name__) class BreakBlock(nn.Module): """ - Break coonnection block for hourglass. + Break connection block for hourglass network. + + This block serves as a break in the network's connections. It takes an input and returns None. + It is commonly used in hourglass-style networks to create skips in the network flow. + + Attributes: + ---------- + None + + Methods: + ------- + forward(x): + Forward pass through the block. + + __repr__(): + Returns a string representation of the block. """ def __init__(self): super(BreakBlock, self).__init__() def forward(self, x): + """ + Forward pass through the block. + + Parameters: + ---------- + x : torch.Tensor + Input tensor. + + Returns: + ------- + None + The block returns None, effectively breaking the connection. + """ return None def __repr__(self): + """ + Returns a string representation of the block. + + Returns: + ------- + str + A string representation of the block, indicating its name. + """ return '{name}()'.format(name=self.__class__.__name__) + class Swish(nn.Module): """ Swish activation function from 'Searching for Activation Functions,' https://arxiv.org/abs/1710.05941. + + This activation function is defined as Swish(x) = x * sigmoid(x). + + Attributes: + ---------- + None + + Methods: + ------- + forward(x): + Forward pass through the Swish activation function. """ def forward(self, x): + """ + Forward pass through the Swish activation function. + + Parameters: + ---------- + x : torch.Tensor + Input tensor. + + Returns: + ------- + torch.Tensor + Output tensor after applying the Swish activation function. + """ return x * torch.sigmoid(x) class HSigmoid(nn.Module): """ - Approximated sigmoid function, so-called hard-version of sigmoid from 'Searching for MobileNetV3,' + Approximated sigmoid function, the hard version of sigmoid, from 'Searching for MobileNetV3,' https://arxiv.org/abs/1905.02244. + + This activation function is defined as HSigmoid(x) = relu6(x + 3.0) / 6.0. 
+ + Attributes: + ---------- + None + + Methods: + ------- + forward(x): + Forward pass through the HSigmoid activation function. """ def forward(self, x): + """ + Forward pass through the HSigmoid activation function. + + Parameters: + ---------- + x : torch.Tensor + Input tensor. + + Returns: + ------- + torch.Tensor + Output tensor after applying the HSigmoid activation function. + """ return F.relu6(x + 3.0, inplace=True) / 6.0 @@ -98,10 +228,22 @@ class HSwish(nn.Module): """ H-Swish activation function from 'Searching for MobileNetV3,' https://arxiv.org/abs/1905.02244. + This activation function is defined as HSwish(x) = x * relu6(x + 3.0) / 6.0. + Parameters: ---------- + inplace : bool, optional (default=False) + Whether to use the inplace version of the module. + + Attributes: + ---------- inplace : bool - Whether to use inplace version of the module. + Indicates whether the inplace version is used. + + Methods: + ------- + forward(x): + Forward pass through the H-Swish activation function. """ def __init__(self, inplace=False): @@ -109,24 +251,42 @@ def __init__(self, inplace=False): self.inplace = inplace def forward(self, x): + """ + Forward pass through the H-Swish activation function. + + Parameters: + ---------- + x : torch.Tensor + Input tensor. + + Returns: + ------- + torch.Tensor + Output tensor after applying the H-Swish activation function. + """ return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0 def get_activation_layer(activation): """ - Create activation layer from string/function. + Create an activation layer from a string/function. Parameters: ---------- - activation : function, or str, or nn.Module - Activation function or name of activation function. + activation : function, str, or nn.Module + Activation function or name of the activation function. Returns: ------- nn.Module Activation layer. + + Raises: + ------- + NotImplementedError: + If the specified activation function is not supported. """ - assert (activation is not None) + assert activation is not None if isfunction(activation): return activation() elif isinstance(activation, str): @@ -145,9 +305,9 @@ def get_activation_layer(activation): elif activation == "identity": return Identity() else: - raise NotImplementedError() + raise NotImplementedError("Unsupported activation function: {}".format(activation)) else: - assert (isinstance(activation, nn.Module)) + assert isinstance(activation, nn.Module) return activation @@ -165,6 +325,21 @@ class SelectableDense(nn.Module): Whether the layer uses a bias vector. num_options : int, default 1 Number of selectable options. + + Attributes: + ---------- + in_features : int + Number of input features. + out_features : int + Number of output features. + use_bias : bool + Whether the layer uses a bias vector. + num_options : int + Number of selectable options. + weight : torch.nn.Parameter + Learnable weight parameter. + bias : torch.nn.Parameter + Learnable bias parameter (if use_bias=True). """ def __init__(self, @@ -185,6 +360,21 @@ def __init__(self, self.register_parameter("bias", None) def forward(self, x, indices): + """ + Forward pass through the SelectableDense layer. + + Parameters: + ---------- + x : torch.Tensor + Input tensor. + indices : torch.Tensor + Tensor containing the indices of the selected options. + + Returns: + ------- + torch.Tensor + Output tensor after applying the SelectableDense layer. 
+ """ weight = torch.index_select(self.weight, dim=0, index=indices) x = x.unsqueeze(-1) x = weight.bmm(x) @@ -195,6 +385,14 @@ def forward(self, x, indices): return x def extra_repr(self): + """ + Extra representation of the SelectableDense layer. + + Returns: + ------- + str + String representation of the layer's attributes. + """ return "in_features={}, out_features={}, bias={}, num_options={}".format( self.in_features, self.out_features, self.use_bias, self.num_options) @@ -242,6 +440,19 @@ def __init__(self, self.activ = get_activation_layer(activation) def forward(self, x): + """ + Forward pass of the dense block. + + Parameters: + ---------- + x : torch.Tensor + Input tensor. + + Returns: + ------- + torch.Tensor + Output tensor. + """ x = self.fc(x) if self.use_bn: x = self.bn(x) @@ -313,6 +524,19 @@ def __init__(self, self.activ = get_activation_layer(activation) def forward(self, x): + """ + Forward pass of the 1D convolution block. + + Parameters: + ---------- + x : torch.Tensor + Input tensor. + + Returns: + ------- + torch.Tensor + Output tensor. + """ x = self.conv(x) if self.use_bn: x = self.bn(x) @@ -341,6 +565,11 @@ def conv1x1(in_channels, Number of groups. bias : bool, default False Whether the layer uses a bias vector. + + Returns: + ------- + nn.Conv2d + 1x1 convolutional layer. """ return nn.Conv2d( in_channels=in_channels, @@ -377,6 +606,11 @@ def conv3x3(in_channels, Number of groups. bias : bool, default False Whether the layer uses a bias vector. + + Returns: + ------- + nn.Conv2d + 3x3 convolutional layer. """ return nn.Conv2d( in_channels=in_channels, @@ -409,6 +643,11 @@ def depthwise_conv3x3(channels, Dilation value for convolution layer. bias : bool, default False Whether the layer uses a bias vector. + + Returns: + ------- + nn.Conv2d + Depthwise 3x3 convolutional layer. """ return nn.Conv2d( in_channels=channels, @@ -449,6 +688,15 @@ class ConvBlock(nn.Module): Small float added to variance in Batch norm. activation : function or str or None, default nn.ReLU(inplace=True) Activation function or name of activation function. + + Examples: + -------- + An example of using the ConvBlock: + + >>> import torch + >>> x = torch.randn(1, 3, 64, 64) # Input tensor with shape (batch_size, channels, height, width) + >>> conv_block = ConvBlock(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1) + >>> output = conv_block(x) # Forward pass through the ConvBlock """ def __init__(self, @@ -489,6 +737,19 @@ def __init__(self, self.activ = get_activation_layer(activation) def forward(self, x): + """ + Forward pass of the ConvBlock. + + Parameters: + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, in_channels, height, width). + + Returns: + ------- + torch.Tensor + Output tensor after applying convolution, batch normalization, and activation. + """ if self.use_pad: x = self.pad(x) x = self.conv(x) @@ -531,6 +792,11 @@ def conv1x1_block(in_channels, Small float added to variance in Batch norm. activation : function or str or None, default nn.ReLU(inplace=True) Activation function or name of activation function. + + Returns: + ------- + nn.Module + 1x1 Convolutional Block. """ return ConvBlock( in_channels=in_channels, @@ -580,6 +846,11 @@ def conv3x3_block(in_channels, Small float added to variance in Batch norm. activation : function or str or None, default nn.ReLU(inplace=True) Activation function or name of activation function. + + Returns: + ------- + nn.Module + 3x3 Convolutional Block. 
""" return ConvBlock( in_channels=in_channels, @@ -594,7 +865,6 @@ def conv3x3_block(in_channels, bn_eps=bn_eps, activation=activation) - def conv5x5_block(in_channels, out_channels, stride=1, @@ -630,6 +900,11 @@ def conv5x5_block(in_channels, Small float added to variance in Batch norm. activation : function or str or None, default nn.ReLU(inplace=True) Activation function or name of activation function. + + Returns: + ------- + nn.Module + 5x5 Convolutional Block. """ return ConvBlock( in_channels=in_channels, @@ -664,7 +939,7 @@ def conv7x7_block(in_channels, Number of input channels. out_channels : int Number of output channels. - padding : int, or tuple/list of 2 int, or tuple/list of 4 int, default 1 + stride : int or tuple/list of 2 int, default 1 Strides of the convolution. padding : int or tuple/list of 2 int, default 3 Padding value for convolution layer. @@ -680,6 +955,11 @@ def conv7x7_block(in_channels, Small float added to variance in Batch norm. activation : function or str or None, default nn.ReLU(inplace=True) Activation function or name of activation function. + + Returns: + ------- + nn.Module + 7x7 Convolutional Block. """ return ConvBlock( in_channels=in_channels, @@ -730,6 +1010,11 @@ def dwconv_block(in_channels, Small float added to variance in Batch norm. activation : function or str or None, default nn.ReLU(inplace=True) Activation function or name of activation function. + + Returns: + ------- + nn.Module + Depthwise Convolutional Block. """ return ConvBlock( in_channels=in_channels, @@ -774,6 +1059,11 @@ def dwconv3x3_block(in_channels, Small float added to variance in Batch norm. activation : function or str or None, default nn.ReLU(inplace=True) Activation function or name of activation function. + + Returns: + ------- + nn.Module + 3x3 Depthwise Convolutional Block. """ return dwconv_block( in_channels=in_channels, @@ -816,6 +1106,11 @@ def dwconv5x5_block(in_channels, Small float added to variance in Batch norm. activation : function or str or None, default nn.ReLU(inplace=True) Activation function or name of activation function. + + Returns: + ------- + nn.Module + 5x5 Depthwise Convolutional Block. """ return dwconv_block( in_channels=in_channels, @@ -829,6 +1124,7 @@ def dwconv5x5_block(in_channels, activation=activation) + class DwsConvBlock(nn.Module): """ Depthwise separable convolution block with BatchNorms and activations at each convolution layers. @@ -859,6 +1155,11 @@ class DwsConvBlock(nn.Module): Activation function after the depthwise convolution block. pw_activation : function or str or None, default nn.ReLU(inplace=True) Activation function after the pointwise convolution block. + + Returns: + ---------- + torch.Tensor + The output tensor after applying depthwise separable convolution block. """ def __init__(self, @@ -895,6 +1196,19 @@ def __init__(self, activation=pw_activation) def forward(self, x): + """ + Forward pass of the depthwise separable convolution block. + + Parameters: + ---------- + x : torch.Tensor + Input tensor. + + Returns: + ---------- + torch.Tensor + The output tensor after applying depthwise separable convolution block. + """ x = self.dw_conv(x) x = self.pw_conv(x) return x diff --git a/python/fedml/model/cv/resnet.py b/python/fedml/model/cv/resnet.py index d833e3b762..9a4106c6e3 100644 --- a/python/fedml/model/cv/resnet.py +++ b/python/fedml/model/cv/resnet.py @@ -17,7 +17,8 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): - """3x3 convolution with padding. + """ + 3x3 convolution with padding. 
Args: in_planes (int): Number of input channels. @@ -42,7 +43,8 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): def conv1x1(in_planes, out_planes, stride=1): - """1x1 convolution. + """ + 1x1 convolution. Args: in_planes (int): Number of input channels. @@ -56,7 +58,9 @@ def conv1x1(in_planes, out_planes, stride=1): class BasicBlock(nn.Module): - """Basic residual block used in ResNet architectures.""" + """ + Basic residual block used in ResNet architectures. + """ expansion = 1 @@ -71,7 +75,8 @@ def __init__( dilation=1, norm_layer=None, ): - """Initialize a BasicBlock instance. + """ + Initialize a BasicBlock instance. Args: inplanes (int): Number of input channels. @@ -103,7 +108,8 @@ def __init__( self.stride = stride def forward(self, x): - """Forward pass through the BasicBlock. + """ + Forward pass through the BasicBlock. Args: x (torch.Tensor): Input tensor. diff --git a/python/fedml/model/cv/resnet_torch.py b/python/fedml/model/cv/resnet_torch.py index 5008edf8c4..bdf50e468b 100644 --- a/python/fedml/model/cv/resnet_torch.py +++ b/python/fedml/model/cv/resnet_torch.py @@ -9,7 +9,16 @@ # from .._internally_replaced_utils import load_state_dict_from_url # from ..utils import _log_api_usage_once +""" +This module provides pre-trained ResNet models and their URLs for download. +- `ResNet`: The main ResNet model class. +- `resnet18`, `resnet34`, `resnet50`, `resnet101`, `resnet152`: Pre-trained ResNet models. +- `resnext50_32x4d`, `resnext101_32x8d`: Pre-trained ResNeXt models. +- `wide_resnet50_2`, `wide_resnet101_2`: Pre-trained Wide ResNet models. + +You can use these models for various computer vision tasks. +""" __all__ = [ "ResNet", "resnet18", @@ -38,7 +47,19 @@ def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: - """3x3 convolution with padding""" + """ + 3x3 convolution with padding. + + Args: + in_planes (int): Number of input channels. + out_planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + dilation (int, optional): Dilation rate for convolution. Default is 1. + + Returns: + nn.Conv2d: Convolutional layer. + """ return nn.Conv2d( in_planes, out_planes, @@ -52,11 +73,38 @@ def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, d def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: - """1x1 convolution""" + """ + 1x1 convolution. + + Args: + in_planes (int): Number of input channels. + out_planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + + Returns: + nn.Conv2d: Convolutional layer. + """ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class BasicBlock(nn.Module): + """ + Basic ResNet block. + + Args: + inplanes (int): Number of input channels. + planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + downsample (nn.Module, optional): Downsample layer. Default is None. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + base_width (int, optional): Base width for grouped convolution. Default is 64. + dilation (int, optional): Dilation rate for convolution. Default is 1. + norm_layer (Callable[..., nn.Module], optional): Normalization layer. Default is None. + + Attributes: + expansion (int): Expansion factor. 
+ """ + expansion: int = 1 def __init__( @@ -87,6 +135,15 @@ def __init__( self.stride = stride def forward(self, x: Tensor) -> Tensor: + """ + Forward pass through the BasicBlock. + + Args: + x (Tensor): Input tensor. + + Returns: + Tensor: Output tensor. + """ identity = x out = self.conv1(x) @@ -112,6 +169,23 @@ class Bottleneck(nn.Module): # This variant is also known as ResNet V1.5 and improves accuracy according to # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + """ + Bottleneck block for ResNet. + + Args: + inplanes (int): Number of input channels. + planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + downsample (nn.Module, optional): Downsample layer. Default is None. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + base_width (int, optional): Base width for grouped convolution. Default is 64. + dilation (int, optional): Dilation rate for convolution. Default is 1. + norm_layer (Callable[..., nn.Module], optional): Normalization layer. Default is None. + + Attributes: + expansion (int): Expansion factor. + """ + expansion: int = 4 def __init__( @@ -141,6 +215,15 @@ def __init__( self.stride = stride def forward(self, x: Tensor) -> Tensor: + """ + Forward pass through the Bottleneck block. + + Args: + x (Tensor): Input tensor. + + Returns: + Tensor: Output tensor. + """ identity = x out = self.conv1(x) @@ -177,6 +260,26 @@ def __init__( args = None, in_channels = 3, ) -> None: + """ + Residual Neural Network (ResNet) model. + + Args: + block (Type[Union[BasicBlock, Bottleneck]]): Type of residual block to use (BasicBlock or Bottleneck). + layers (List[int]): List specifying the number of blocks in each layer. + num_classes (int, optional): Number of output classes. Default is 1000. + zero_init_residual (bool, optional): If True, zero-initializes the last BN in each residual branch. + Default is False. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + width_per_group (int, optional): Base width for grouped convolution. Default is 64. + replace_stride_with_dilation (Optional[List[bool]], optional): List specifying if strides should be replaced + with dilations in each layer. Default is None. + norm_layer (Optional[Callable[..., nn.Module]], optional): Normalization layer. Default is None. + args: Additional arguments (not used in the model). + in_channels: Number of input channels. Default is 3. + + Attributes: + expansion (int): Expansion factor for bottleneck blocks. + """ super().__init__() # _log_api_usage_once(self) if norm_layer is None: @@ -240,6 +343,19 @@ def _make_layer( stride: int = 1, dilate: bool = False, ) -> nn.Sequential: + """ + Create a layer consisting of multiple blocks. + + Args: + block (Type[Union[BasicBlock, Bottleneck]]): Type of residual block to use (BasicBlock or Bottleneck). + planes (int): Number of output channels for the layer. + blocks (int): Number of blocks in the layer. + stride (int, optional): Stride for the first block. Default is 1. + dilate (bool, optional): If True, use dilation in the layer. Default is False. + + Returns: + nn.Sequential: A sequential module containing the blocks. + """ norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation @@ -274,6 +390,15 @@ def _make_layer( return nn.Sequential(*layers) def _forward_impl(self, x: Tensor) -> Tensor: + """ + Forward pass of the model. + + Args: + x (Tensor): Input tensor. 
+ + Returns: + Tensor: Output tensor. + """ # See note [TorchScript super()] if len(x.shape) < 4: x = torch.unsqueeze(x, 1) @@ -295,6 +420,15 @@ def _forward_impl(self, x: Tensor) -> Tensor: return x def forward(self, x: Tensor) -> Tensor: + """ + Forward pass of the model. + + Args: + x (Tensor): Input tensor. + + Returns: + Tensor: Output tensor. + """ return self._forward_impl(x) @@ -308,6 +442,20 @@ def _resnet( progress: bool, **kwargs: Any, ) -> ResNet: + """ + Constructs a ResNet model. + + Args: + arch (str): Architecture name. + block (Type[Union[BasicBlock, Bottleneck]]): Type of residual block to use (BasicBlock or Bottleneck). + layers (List[int]): List specifying the number of blocks in each layer. + pretrained (bool): If True, loads pre-trained weights. + progress (bool): If True, displays download progress for pre-trained weights. + **kwargs: Additional keyword arguments to pass to the ResNet constructor. + + Returns: + ResNet: ResNet model. + """ model = ResNet(block, layers, **kwargs) # if pretrained: # state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) @@ -318,9 +466,18 @@ def _resnet( def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: r"""ResNet-18 model from `"Deep Residual Learning for Image Recognition" `_. + Constructs a ResNet model. + Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr + arch (str): Architecture name. + block (Type[Union[BasicBlock, Bottleneck]]): Type of residual block to use (BasicBlock or Bottleneck). + layers (List[int]): List specifying the number of blocks in each layer. + pretrained (bool): If True, loads pre-trained weights. + progress (bool): If True, displays download progress for pre-trained weights. + **kwargs: Additional keyword arguments to pass to the ResNet constructor. + + Returns: + ResNet: ResNet model. 
""" return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) From a3bc61a82e40cde80e63557f3333214a6fa53a5b Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Wed, 13 Sep 2023 13:30:47 +0530 Subject: [PATCH 22/70] model done --- python/fedml/model/cv/batchnorm_utils.py | 406 ++++++++++++++++-- python/fedml/model/cv/common.py | 3 +- python/fedml/model/cv/darts/architect.py | 185 ++++++++ python/fedml/model/cv/darts/model.py | 123 +++++- python/fedml/model/cv/darts/model_search.py | 315 +++++++++++++- .../fedml/model/cv/darts/model_search_gdas.py | 72 +++- python/fedml/model/cv/darts/operations.py | 100 +++++ python/fedml/model/cv/darts/train.py | 25 ++ python/fedml/model/cv/darts/train_search.py | 34 ++ python/fedml/model/cv/darts/utils.py | 128 +++++- python/fedml/model/cv/darts/visualize.py | 13 + python/fedml/model/cv/efficientnet_utils.py | 265 +++++++++--- python/fedml/model/cv/group_normalization.py | 173 +++++++- .../fedml/model/cv/resnet56/resnet_client.py | 143 +++++- .../model/cv/resnet56/resnet_pretrained.py | 153 ++++++- .../fedml/model/cv/resnet56/resnet_server.py | 178 +++++++- python/fedml/model/linear/lr.py | 39 +- python/fedml/model/linear/lr_cifar10.py | 23 +- python/fedml/model/mobile/mnn_lenet.py | 23 + python/fedml/model/mobile/mnn_resnet.py | 22 + python/fedml/model/mobile/torch_lenet.py | 12 +- 21 files changed, 2259 insertions(+), 176 deletions(-) diff --git a/python/fedml/model/cv/batchnorm_utils.py b/python/fedml/model/cv/batchnorm_utils.py index 876454b55a..7ecb309ebb 100644 --- a/python/fedml/model/cv/batchnorm_utils.py +++ b/python/fedml/model/cv/batchnorm_utils.py @@ -24,20 +24,84 @@ class FutureResult(object): - """A thread-safe future implementation. Used only as one-to-one pipe.""" + """A thread-safe future implementation used for one-to-one communication. + This class provides a thread-safe mechanism for transferring results between threads, + typically in a producer-consumer pattern. It is designed for one-to-one communication + and ensures that the result is safely passed from one thread to another. + + Args: + None + + Attributes: + _result: The result value stored in the future. + _lock: A lock to ensure thread safety. + _cond: A condition variable associated with the lock for waiting and notifying. + + Methods: + put(result): + Puts a result value into the future. If a result already exists, it raises an + assertion error. + + get(): + Retrieves the result value from the future. If the result is not available yet, + it blocks until the result is put into the future. + + Example: + Here's an example of using `FutureResult` for communication between two threads: + + ```python + import threading + + def producer(future): + result = 42 # Some computation or value to produce + future.put(result) + + def consumer(future): + result = future.get() + print(f"Received result: {result}") + + future = FutureResult() + + # Start the producer and consumer threads + producer_thread = threading.Thread(target=producer, args=(future,)) + consumer_thread = threading.Thread(target=consumer, args=(future,)) + + producer_thread.start() + consumer_thread.start() + + producer_thread.join() + consumer_thread.join() + ``` + + Note: + This class is intended for one-to-one communication between threads. + """ def __init__(self): self._result = None self._lock = threading.Lock() self._cond = threading.Condition(self._lock) def put(self, result): + """Put a result into the future. + + Args: + result: The result value to be stored in the future. 
+ + Raises: + AssertionError: If a result is already present in the future. + """ with self._lock: assert self._result is None, "Previous result has't been fetched." self._result = result self._cond.notify() def get(self): + """Get the result from the future, blocking if necessary. + + Returns: + The result value stored in the future. + """ with self._lock: if self._result is None: self._cond.wait() @@ -54,9 +118,69 @@ def get(self): class SlavePipe(_SlavePipeBase): - """Pipe for master-slave communication.""" + """Pipe for master-slave communication in a multi-threaded environment. + + This class represents a pipe used for communication between a master thread and one + or more slave threads. It is designed for multi-threaded applications where the + master thread delegates tasks to the slave threads and waits for their results. + + Args: + queue (Queue): A queue for sending messages from the slave thread to the master. + result (FutureResult): A FutureResult object for receiving results from the slave. + identifier (int): An identifier for the slave thread. + + Attributes: + queue (Queue): A queue for sending messages from the slave thread to the master. + result (FutureResult): A FutureResult object for receiving results from the slave. + identifier (int): An identifier for the slave thread. + + Methods: + run_slave(msg): + Executes a task in the slave thread and sends a message to the master thread. + It waits for the master to acknowledge the completion of the task and returns + the result. + + Example: + Here's an example of using `SlavePipe` for master-slave communication: + + ```python + import threading + + def slave_function(pipe): + # Perform some computation and send the result to the master + result = 42 # Placeholder for the result + pipe.run_slave(result) + + # Create a SlavePipe for communication + slave_pipe = SlavePipe(queue, result, 1) + + # Start the slave thread + slave_thread = threading.Thread(target=slave_function, args=(slave_pipe,)) + slave_thread.start() + + # Master thread can send tasks and receive results using the slave_pipe + task_result = slave_pipe.run_slave(task_data) + # Wait for the slave thread to finish + slave_thread.join() + + # Use the task_result received from the slave + print(f"Received result from slave: {task_result}") + ``` + + Note: + This class is intended for use in multi-threaded applications where a master + thread communicates with one or more slave threads. + """ def run_slave(self, msg): + """Execute a task in the slave thread and communicate with the master. + + Args: + msg: The message or task to be sent to the master. + + Returns: + The result of the task received from the master. + """ self.queue.put((self.identifier, msg)) ret = self.result.get() self.queue.put(True) @@ -64,13 +188,67 @@ def run_slave(self, msg): class SyncMaster(object): - """An abstract `SyncMaster` object. + """An abstract `SyncMaster` object for coordinating communication between master and slave devices. + + In a data parallel setting, the `SyncMaster` object manages the communication between the master device + and multiple slave devices. It provides a mechanism for slave devices to register and communicate with + the master during forward and backward passes. + - During the replication, as the data parallel will trigger an callback of each module, all slave devices should call `register(id)` and obtain an `SlavePipe` to communicate with the master. 
- During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, and passed to a registered callback. - After receiving the messages, the master device should gather the information and determine to message passed back to each slave devices. + + Args: + master_callback (callable): A callback function to be invoked after collecting messages from slave devices. + + Attributes: + _master_callback (callable): A callback function to be invoked after collecting messages from slave devices. + _queue (queue.Queue): A queue for exchanging messages between master and slave devices. + _registry (collections.OrderedDict): A registry of slave devices and their associated communication pipes. + _activated (bool): A flag indicating whether the SyncMaster is activated for communication. + + Methods: + register_slave(identifier): + Register a slave device and obtain a `SlavePipe` object for communication with the master device. + + run_master(master_msg): + Main entry for the master device during each forward pass. Collects messages from all devices, + invokes the master callback to compute a response message, and sends messages back to each device. + + nr_slaves: + Property that returns the number of registered slave devices. + + Example: + Here's an example of using `SyncMaster` for coordinating communication in a data parallel setting: + + ```python + def master_callback(messages): + # Compute the master message based on received messages + master_msg = messages[0][1] + return [(0, master_msg)] # Send the same message back to the master + + sync_master = SyncMaster(master_callback) + + # Register slave devices and obtain communication pipes + slave_pipe1 = sync_master.register_slave(1) + slave_pipe2 = sync_master.register_slave(2) + + # During the forward pass, master device runs run_master to coordinate communication + master_msg = "Hello from master" + response_msg = sync_master.run_master(master_msg) + + # Use the response_msg and coordinate further actions + + # Get the number of registered slave devices + num_slaves = sync_master.nr_slaves + ``` + + Note: + This class is intended for use in multi-device data parallel applications where a master device + coordinates communication with multiple slave devices. """ def __init__(self, master_callback): @@ -90,11 +268,27 @@ def __setstate__(self, state): self.__init__(state["master_callback"]) def register_slave(self, identifier): - """ - Register an slave device. + """Register a slave device with the SyncMaster. + Args: - identifier: an identifier, usually is the device id. - Returns: a `SlavePipe` object which can be used to communicate with the master device. + identifier (int): An identifier, usually the device ID. + + Returns: + SlavePipe: A `SlavePipe` object for communicating with the master device. + + Raises: + AssertionError: If the SyncMaster is already activated and the queue is not empty. + + Notes: + This method should be called by slave devices to register themselves with the SyncMaster. + The returned `SlavePipe` object can be used for communication with the master device. + + Example: + ```python + sync_master = SyncMaster(master_callback) + slave_pipe = sync_master.register_slave(1) + ``` + """ if self._activated: assert self._queue.empty(), "Queue is not clean before next initialization." 
@@ -105,15 +299,30 @@ def register_slave(self, identifier): return SlavePipe(identifier, self._queue, future) def run_master(self, master_msg): - """ - Main entry for the master device in each forward pass. + """Run the master device during each forward pass. + The messages were first collected from each devices (including the master device), and then an callback will be invoked to compute the message to be sent back to each devices (including the master device). + Args: - master_msg: the message that the master want to send to itself. This will be placed as the first - message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. - Returns: the message to be sent back to the master device. + master_msg: The message that the master wants to send to itself. + This message will be placed as the first message when calling `master_callback`. + + Returns: + Any: The message to be sent back to the master device. + + Notes: + This method is the main entry for the master device during each forward pass. + It collects messages from all devices, invokes the master callback to compute a response message, + and sends messages back to each device. + + Example: + ```python + master_msg = "Hello from master" + response_msg = sync_master.run_master(master_msg) + ``` + """ self._activated = True @@ -136,16 +345,57 @@ def run_master(self, master_msg): @property def nr_slaves(self): + """Get the number of registered slave devices. + + Returns: + int: The number of registered slave devices. + + Example: + ```python + num_slaves = sync_master.nr_slaves + ``` + + """ return len(self._registry) def _sum_ft(tensor): - """sum over the first and last dimention""" + """Sum over the first and last dimensions of a tensor. + + Args: + tensor (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: A tensor with the sum of values over the first and last dimensions. + + Example: + ```python + input_tensor = torch.tensor([[1, 2], [3, 4]]) + result = _sum_ft(input_tensor) + # Result: tensor([10]) + ``` + + """ return tensor.sum(dim=0).sum(dim=-1) def _unsqueeze_ft(tensor): - """add new dementions at the front and the tail""" + """Add new dimensions at the front and the tail of a tensor. + + Args: + tensor (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: A tensor with new dimensions added at the front and the tail. + + Example: + ```python + input_tensor = torch.tensor([1, 2, 3]) + result = _unsqueeze_ft(input_tensor) + # Result: tensor([[[1]], [[2]], [[3]]]) + ``` + + """ return tensor.unsqueeze(0).unsqueeze(-1) @@ -154,6 +404,21 @@ def _unsqueeze_ft(tensor): class _SynchronizedBatchNorm(_BatchNorm): + """Synchronized Batch Normalization for parallel computation. + + This class extends PyTorch's BatchNorm2d to support synchronization for data parallelism. + It uses a master-slave communication pattern to compute batch statistics efficiently. + + Args: + num_features (int): Number of features in the input tensor. + eps (float): Small constant added to the denominator for numerical stability. Default: 1e-5 + momentum (float): Momentum factor for the running statistics. Default: 0.1 + affine (bool): If True, apply learned affine transformation. Default: True + + Note: + This class is typically used in a data parallel setup where multiple GPUs work together. 
+ + """ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): super(_SynchronizedBatchNorm, self).__init__( num_features, eps=eps, momentum=momentum, affine=affine @@ -166,6 +431,15 @@ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): self._slave_pipe = None def forward(self, input): + """Forward pass through the synchronized batch normalization layer. + + Args: + input (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized and optionally affine-transformed tensor. + + """ # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. if not (self._is_parallel and self.training): return F.batch_norm( @@ -221,7 +495,15 @@ def __data_parallel_replicate__(self, ctx, copy_id): self._slave_pipe = ctx.sync_master.register_slave(copy_id) def _data_parallel_master(self, intermediates): - """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" + """Replicate the synchronized batch normalization layer for data parallelism. + + This method is called during data parallel replication to prepare the layer for parallel computation. + + Args: + ctx: The context object. + copy_id (int): Identifier for the replica. + + """ # Always using same "device order" makes the ReduceAdd operation faster. # Thanks to:: Tete Xiao (http://tetexiao.com/) @@ -244,8 +526,17 @@ def _data_parallel_master(self, intermediates): return outputs def _compute_mean_std(self, sum_, ssum, size): - """Compute the mean and standard-deviation with sum and square-sum. This method - also maintains the moving average on the master device.""" + """Compute the mean and standard-deviation with sum and square-sum. + + Args: + sum_ (torch.Tensor): Sum of values. + ssum (torch.Tensor): Sum of squared values. + size (int): Size of the input batch. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Mean and standard-deviation. + + """ assert ( size > 1 ), "BatchNorm computes unbiased standard-deviation, which requires size > 1." @@ -288,25 +579,30 @@ class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): During evaluation, this running mean/variance is used for normalization. Because the BatchNorm is done over the `C` dimension, computing statistics on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm + or Instance Norm. + + Note: + This layer behaves like the built-in PyTorch BatchNorm1d when used on a single GPU or CPU. + Args: - num_features: num_features from an expected input of size - `batch_size x num_features [x width]` - eps: a value added to the denominator for numerical stability. - Default: 1e-5 - momentum: the value used for the running_mean and running_var - computation. Default: 0.1 - affine: a boolean value that when set to ``True``, gives the layer learnable - affine parameters. Default: ``True`` + num_features (int): Number of features in the input tensor. `batch_size x num_features [x width]` + eps (float): A small constant added to the denominator for numerical stability. Default: 1e-5 + momentum (float): The momentum factor used for computing running statistics. Default: 0.1 + affine (bool): If True, learnable affine parameters (gamma and beta) are applied. 
Default: True + Shape: - - Input: :math:`(N, C)` or :math:`(N, C, L)` - - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + - Input: (N, C) or (N, C, L) + - Output: (N, C) or (N, C, L) (same shape as input) + Examples: >>> # With Learnable Parameters >>> m = SynchronizedBatchNorm1d(100) >>> # Without Learnable Parameters >>> m = SynchronizedBatchNorm1d(100, affine=False) - >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> input = torch.randn(20, 100) # 2D input >>> output = m(input) + >>> input_3d = torch.randn(20, 100, 30) # 3D input + >>> output_3d = m(input_3d) """ def _check_input_dim(self, input): @@ -426,14 +722,24 @@ class CallbackContext(object): def execute_replication_callbacks(modules): """ - Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. - The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` - Note that, as all modules are isomorphism, we assign each sub-module with a context + Execute a replication callback `__data_parallel_replicate__` on each module created by original replication. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`. + Note that, as all modules are isomorphic, we assign each sub-module with a context (shared among multiple copies of this module on different devices). Through this context, different copies can share some information. We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback of any slave copies. + + Args: + modules (list): List of replicated modules. + + Examples: + >>> # Replicate a module and execute replication callbacks + >>> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + >>> replicated_sync_bn = DataParallelWithCallback(replicate(sync_bn, device_ids=[0, 1])) + >>> # sync_bn.__data_parallel_replicate__ will be invoked. """ + master_copy = modules[0] nr_modules = len(list(master_copy.modules())) ctxs = [CallbackContext() for _ in range(nr_modules)] @@ -447,13 +753,19 @@ def execute_replication_callbacks(modules): class DataParallelWithCallback(DataParallel): """ Data Parallel with a replication callback. - An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by - original `replicate` function. - The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + the original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`. + + Args: + module (Module): The module to be parallelized. + device_ids (list): List of device IDs to use for parallelization. + Examples: - > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) - > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) - # sync_bn.__data_parallel_replicate__ will be invoked. + >>> # Parallelize a module with a replication callback + >>> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + >>> replicated_sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + >>> # sync_bn.__data_parallel_replicate__ will be invoked. """ def replicate(self, module, device_ids): @@ -466,13 +778,21 @@ def patch_replication_callback(data_parallel): """ Monkey-patch an existing `DataParallel` object. Add the replication callback. Useful when you have customized `DataParallel` implementation. 
+ + Args: + data_parallel (DataParallel): The existing DataParallel object to be patched. + Examples: - > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) - > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) - > patch_replication_callback(sync_bn) - # this is equivalent to - > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) - > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + >>> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + >>> sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + >>> patch_replication_callback(sync_bn) + # This is equivalent to: + >>> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + >>> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + Note: + This function monkey-patches the `DataParallel` object to add the replication callback + without the need to create a new `DataParallelWithCallback` object. """ assert isinstance(data_parallel, DataParallel) diff --git a/python/fedml/model/cv/common.py b/python/fedml/model/cv/common.py index 267bb4494d..1f01c89022 100644 --- a/python/fedml/model/cv/common.py +++ b/python/fedml/model/cv/common.py @@ -1811,8 +1811,7 @@ def __repr__(self): groups=self.groups) -def channel_shuffle2(x, - groups): +def channel_shuffle2(x, groups): """ Channel shuffle operation from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices,' https://arxiv.org/abs/1707.01083. The alternative version. diff --git a/python/fedml/model/cv/darts/architect.py b/python/fedml/model/cv/darts/architect.py index 27cca11c19..25fcf2883d 100644 --- a/python/fedml/model/cv/darts/architect.py +++ b/python/fedml/model/cv/darts/architect.py @@ -11,6 +11,64 @@ def _concat(xs): class Architect(object): + """ + The Architect class is responsible for architecture optimization in neural architecture search (NAS). + It adapts the architecture of a neural network to improve its performance on a specific task using gradient-based methods. + + Attributes: + network_momentum (float): The momentum term for the network weights. + network_weight_decay (float): The weight decay term for the network weights. + model (nn.Module): The neural network model for which the architecture is optimized. + criterion (nn.Module): The loss criterion used for training. + optimizer (torch.optim.Optimizer): The optimizer for architecture parameters. + device (torch.device): The device on which the operations are performed. + is_multi_gpu (bool): Flag indicating if the model is trained on multiple GPUs. + + Args: + model (nn.Module): The neural network model being optimized. + criterion (nn.Module): The loss criterion for training. + args (object): A configuration object containing hyperparameters. + device (torch.device): The device (e.g., 'cuda' or 'cpu') on which to perform computations. + + Methods: + step(input_train, target_train, input_valid, target_valid, eta, network_optimizer, unrolled): + Perform a single step of architecture optimization. + + step_v2(input_train, target_train, input_valid, target_valid, lambda_train_regularizer, lambda_valid_regularizer): + Perform a single step of architecture optimization with custom regularization terms. + + step_single_level(input_train, target_train): + Perform a single step of architecture optimization for a single level. + + step_wa(input_train, target_train, input_valid, target_valid, lambda_regularizer): + Perform a single step of architecture optimization with weight adaptation. 
+ + step_AOS(input_train, target_train, input_valid, target_valid): + Perform a single step of architecture optimization using the AOS method. + + _backward_step(input_valid, target_valid): + Perform the backward step during optimization. + + _backward_step_unrolled(input_train, target_train, input_valid, target_valid, eta, network_optimizer): + Perform the unrolled backward step during optimization. + + _construct_model_from_theta(theta): + Construct a new model using architecture parameters. + + _hessian_vector_product(vector, input, target, r=1e-2): + Compute the product of the Hessian matrix and a vector. + + _compute_unrolled_model(input, target, eta, network_optimizer): + Compute the unrolled model with updated weights. + + Example: + # Create an Architect instance + architect = Architect(model, criterion, args, device) + + # Perform architecture optimization + architect.step(input_train, target_train, input_valid, target_valid, eta, network_optimizer, unrolled=True) + """ + def __init__(self, model, criterion, args, device): self.network_momentum = args.momentum self.network_weight_decay = args.weight_decay @@ -34,6 +92,18 @@ def __init__(self, model, criterion, args, device): # W_j = V_j + W_jx x # https://www.youtube.com/watch?v=k8fTYJPd3_I def _compute_unrolled_model(self, input, target, eta, network_optimizer): + """ + Compute the unrolled model with respect to the architecture parameters. + + Args: + input: Input data. + target: Target data. + eta (float): Learning rate. + network_optimizer: The network optimizer. + + Returns: + unrolled_model: The unrolled model. + """ logits = self.model(input) loss = self.criterion(logits, target) # pylint: disable=E1102 @@ -65,6 +135,18 @@ def step( network_optimizer, unrolled, ): + """ + Perform one optimization step for architecture search. + + Args: + input_train: Training input data. + target_train: Training target data. + input_valid: Validation input data. + target_valid: Validation target data. + eta (float): Learning rate. + network_optimizer: The network optimizer. + unrolled (bool): Whether to compute an unrolled model. + """ self.optimizer.zero_grad() if unrolled: # logging.info("first order") @@ -91,6 +173,17 @@ def step_v2( lambda_train_regularizer, lambda_valid_regularizer, ): + """ + Perform one optimization step for architecture search (variant 2). + + Args: + input_train: Training input data. + target_train: Training target data. + input_valid: Validation input data. + target_valid: Validation target data. + lambda_train_regularizer (float): Regularization weight for training. + lambda_valid_regularizer (float): Regularization weight for validation. + """ self.optimizer.zero_grad() # grads_alpha_with_train_dataset @@ -143,6 +236,13 @@ def step_v2( # ours def step_single_level(self, input_train, target_train): + """ + Perform one optimization step for architecture search (single level). + + Args: + input_train: Training input data. + target_train: Training target data. + """ self.optimizer.zero_grad() # grads_alpha_with_train_dataset @@ -174,6 +274,16 @@ def step_single_level(self, input_train, target_train): def step_wa( self, input_train, target_train, input_valid, target_valid, lambda_regularizer ): + """ + Perform one optimization step for architecture search (weighted average). + + Args: + input_train: Training input data. + target_train: Training target data. + input_valid: Validation input data. + target_valid: Validation target data. + lambda_regularizer (float): Regularization weight. 
+ """ self.optimizer.zero_grad() # grads_alpha_with_train_dataset @@ -220,6 +330,15 @@ def step_wa( self.optimizer.step() def step_AOS(self, input_train, target_train, input_valid, target_valid): + """ + Perform one optimization step for architecture search (AOS). + + Args: + input_train: Training input data. + target_train: Training target data. + input_valid: Validation input data. + target_valid: Validation target data. + """ self.optimizer.zero_grad() output_search = self.model(input_valid) arch_loss = self.criterion(output_search, target_valid) # pylint: disable=E1102 @@ -227,6 +346,13 @@ def step_AOS(self, input_train, target_train, input_valid, target_valid): self.optimizer.step() def _backward_step(self, input_valid, target_valid): + """ + Perform a backward step for the architecture optimization. + + Args: + input_valid: Validation input data. + target_valid: Validation target data. + """ logits = self.model(input_valid) loss = self.criterion(logits, target_valid) # pylint: disable=E1102 @@ -241,6 +367,17 @@ def _backward_step_unrolled( eta, network_optimizer, ): + """ + Perform a backward step for the architecture optimization with unrolled training. + + Args: + input_train: Training input data. + target_train: Training target data. + input_valid: Validation input data. + target_valid: Validation target data. + eta: Learning rate for unrolled training. + network_optimizer: The optimizer for the network weights. + """ # calculate w' in equation (7): # approximate w(*) by adapting w using only a single training step and enable momentum. unrolled_model = self._compute_unrolled_model( @@ -277,6 +414,15 @@ def _backward_step_unrolled( v.grad.data.copy_(g.data) def _construct_model_from_theta(self, theta): + """ + Construct a new model from the given theta. + + Args: + theta: A flattened parameter tensor. + + Returns: + model_new: A new model constructed using the provided theta. + """ model_new = self.model.new() model_dict = self.model.state_dict() @@ -311,6 +457,18 @@ def _construct_model_from_theta(self, theta): return model_new.to(self.device) def _hessian_vector_product(self, vector, input, target, r=1e-2): + """ + Calculate the Hessian-vector product. + + Args: + vector: A list of gradient vectors. + input: Input data. + target: Target data. + r: Regularization term. + + Returns: + List of Hessian-vector products. + """ # vector is (gradient of w' on validation dataset) R = r / _concat(vector).norm() parameters = ( @@ -374,6 +532,19 @@ def step_v2_2ndorder( lambda_train_regularizer, lambda_valid_regularizer, ): + """ + Perform a step for architecture optimization using the second-order method. + + Args: + input_train: Training input data. + target_train: Training target data. + input_valid: Validation input data. + target_valid: Validation target data. + eta: Learning rate for unrolled training. + network_optimizer: The optimizer for the network weights. + lambda_train_regularizer: Regularization term for training dataset. + lambda_valid_regularizer: Regularization term for validation dataset. + """ self.optimizer.zero_grad() # approximate w(*) by adapting w using only a single training step and enable momentum. @@ -465,6 +636,20 @@ def step_v2_2ndorder2( lambda_train_regularizer, lambda_valid_regularizer, ): + """ + Perform a step for architecture optimization using the second-order method with modifications. + + Args: + input_train: Training input data. + target_train: Training target data. + input_valid: Validation input data. + target_valid: Validation target data. 
+ eta: Learning rate for unrolled training. + network_optimizer: The optimizer for the network weights. + lambda_train_regularizer: Regularization term for training dataset. + lambda_valid_regularizer: Regularization term for validation dataset. + """ + self.optimizer.zero_grad() # approximate w(*) by adapting w using only a single training step and enable momentum. diff --git a/python/fedml/model/cv/darts/model.py b/python/fedml/model/cv/darts/model.py index 62c11388b1..5f5f3badb8 100644 --- a/python/fedml/model/cv/darts/model.py +++ b/python/fedml/model/cv/darts/model.py @@ -6,9 +6,29 @@ class Cell(nn.Module): + """ + Cell in a neural architecture described by a genotype. + + Args: + genotype (Genotype): Genotype describing the cell's architecture. + C_prev_prev (int): Number of input channels from two steps back. + C_prev (int): Number of input channels from the previous step. + C (int): Number of output channels. + reduction (bool): Whether the cell is a reduction cell. + reduction_prev (bool): Whether the previous cell was a reduction cell. + + Input: + - s0 (Tensor): Input tensor from two steps back, shape (batch_size, C_prev_prev, H, W). + - s1 (Tensor): Input tensor from the previous step, shape (batch_size, C_prev, H, W). + - drop_prob (float): Dropout probability for drop path regularization during training. + + Output: + - Output tensor of the cell, shape (batch_size, C, H, W). + + """ + def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev): super(Cell, self).__init__() - print(C_prev_prev, C_prev, C) if reduction_prev: self.preprocess0 = FactorizedReduce(C_prev_prev, C) @@ -25,6 +45,17 @@ def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev): self._compile(C, op_names, indices, concat, reduction) def _compile(self, C, op_names, indices, concat, reduction): + """ + Compiles the operations for the cell based on the given genotype. + + Args: + C (int): Number of output channels for the cell. + op_names (list of str): Names of the operations for each edge in the cell. + indices (list of int): Indices of the operations for each edge in the cell. + concat (list of int): Concatenation points for the cell. + reduction (bool): Whether the cell is a reduction cell. + + """ assert len(op_names) == len(indices) self._steps = len(op_names) // 2 self._concat = concat @@ -38,6 +69,18 @@ def _compile(self, C, op_names, indices, concat, reduction): self._indices = indices def forward(self, s0, s1, drop_prob): + """ + Forward pass through the cell. + + Args: + s0 (Tensor): Input tensor from two steps back, shape (batch_size, C_prev_prev, H, W). + s1 (Tensor): Input tensor from the previous step, shape (batch_size, C_prev, H, W). + drop_prob (float): Dropout probability for drop path regularization during training. + + Returns: + Tensor: Output tensor of the cell, shape (batch_size, C, H, W). + + """ s0 = self.preprocess0(s0) s1 = self.preprocess1(s1) @@ -59,15 +102,40 @@ def forward(self, s0, s1, drop_prob): return torch.cat([states[i] for i in self._concat], dim=1) + class AuxiliaryHeadCIFAR(nn.Module): + """ + Auxiliary head for CIFAR classification in the DARTS model. + + Args: + C (int): Number of input channels. + num_classes (int): Number of classes for classification. + + Input: + - Input tensor of shape (batch_size, C, 8, 8), assuming an input size of 8x8. + + Output: + - Output tensor of shape (batch_size, num_classes), representing class scores. 
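+
+    Example:
+        A minimal shape check (a sketch; the channel count is illustrative and the
+        final classifier layer is assumed to match the architecture listed below):
+
+        >>> head = AuxiliaryHeadCIFAR(C=576, num_classes=10)
+        >>> scores = head(torch.randn(2, 576, 8, 8))
+        >>> scores.shape
+        torch.Size([2, 10])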
+ + Architecture: + - ReLU activation + - Average pooling with 5x5 kernel and stride 3 (resulting in an image size of 2x2) + - 1x1 convolution with 128 output channels + - Batch normalization + - ReLU activation + - 2x2 convolution with 768 output channels + - Batch normalization + - ReLU activation + - Linear layer with num_classes output units for classification. + + """ + def __init__(self, C, num_classes): - """assuming input size 8x8""" + super(AuxiliaryHeadCIFAR, self).__init__() self.features = nn.Sequential( nn.ReLU(inplace=True), - nn.AvgPool2d( - 5, stride=3, padding=0, count_include_pad=False - ), # image size = 2 x 2 + nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), nn.Conv2d(C, 128, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(inplace=True), @@ -108,6 +176,32 @@ def forward(self, x): class NetworkCIFAR(nn.Module): + """ + DARTS network architecture for CIFAR dataset. + + Args: + C (int): Initial number of channels. + num_classes (int): Number of classes for classification. + layers (int): Number of layers. + auxiliary (bool): Whether to use auxiliary heads. + genotype (Genotype): Genotype specifying the cell structure. + + Input: + - Input tensor of shape (batch_size, 3, 32, 32), where 3 is for RGB channels. + + Output: + - Main network output tensor of shape (batch_size, num_classes). + - Auxiliary head output tensor if auxiliary is True and during training. + + Architecture: + - Stem: Initial convolution layer followed by batch normalization. + - Cells: Stack of cells with specified genotype. + - Auxiliary Head: Optional auxiliary head for training stability. + - Global Pooling: Adaptive average pooling to 1x1 size. + - Classifier: Linear layer for classification. + + """ + def __init__(self, C, num_classes, layers, auxiliary, genotype): super(NetworkCIFAR, self).__init__() self._layers = layers @@ -158,6 +252,25 @@ def forward(self, input): class NetworkImageNet(nn.Module): + """ + Network architecture for ImageNet dataset. + + Args: + C (int): Initial number of channels. + num_classes (int): Number of classes for classification. + layers (int): Number of layers. + auxiliary (bool): Whether to include an auxiliary head. + genotype (Genotype): Genotype specifying the cell structure. + + Input: + - Input tensor of shape (batch_size, 3, height, width). + + Output: + - Main classifier logits tensor of shape (batch_size, num_classes). + - Auxiliary classifier logits tensor if auxiliary is True, otherwise None. + + """ + def __init__(self, C, num_classes, layers, auxiliary, genotype): super(NetworkImageNet, self).__init__() self._layers = layers diff --git a/python/fedml/model/cv/darts/model_search.py b/python/fedml/model/cv/darts/model_search.py index 75c5a504dd..dba356e6ea 100644 --- a/python/fedml/model/cv/darts/model_search.py +++ b/python/fedml/model/cv/darts/model_search.py @@ -7,7 +7,44 @@ from .utils import count_parameters_in_MB +import torch.nn as nn + class MixedOp(nn.Module): + """ + Mixed Operation Module for Neural Architecture Search (NAS). + + This module represents a mixture of different operations and allows for dynamic selection of one + of these operations based on a set of weights. + + Args: + C (int): Number of input channels. + stride (int): The stride for the operations. + + Input: + - Input tensor `x` of shape (batch_size, C, H, W), where `C` is the number of input channels, + and `H` and `W` are the spatial dimensions. 
+ + Output: + - Output tensor of shape (batch_size, C, H', W'), where `C` is the number of output channels, + and `H'` and `W'` are the spatial dimensions after applying the selected operation. + + Attributes: + - _ops (nn.ModuleList): A list of operations to be mixed based on weights. + + Note: + - This module is typically used in Neural Architecture Search (NAS) to create a mixed operation + that combines different operations (e.g., convolution, pooling) and allows the architecture + search algorithm to learn which operations to use. + + Example: + To create an instance of the MixedOp module and use it in a NAS cell: + >>> mixed_op = MixedOp(C=64, stride=1) + >>> input_tensor = torch.randn(1, 64, 32, 32) # Example input tensor + >>> weights = torch.randn(5) # Example operation mixing weights + >>> output = mixed_op(input_tensor, weights) # Apply the mixed operation to the input + + """ + def __init__(self, C, stride): super(MixedOp, self).__init__() self._ops = nn.ModuleList() @@ -18,11 +55,67 @@ def __init__(self, C, stride): self._ops.append(op) def forward(self, x, weights): - # w is the operation mixing weights. see equation 2 in the original paper. + """ + Forward pass of the MixedOp module. + + Args: + x (Tensor): Input tensor of shape (batch_size, C, H, W). + weights (Tensor): Operation mixing weights of shape (num_operations,). + + Returns: + output (Tensor): Output tensor of shape (batch_size, C, H', W'). + + """ + # Apply the selected operation based on the given weights return sum(w * op(x) for w, op in zip(weights, self._ops)) class Cell(nn.Module): + """ + Cell Module for Neural Architecture Search (NAS). + + This module represents a cell in a neural network architecture designed for NAS. It contains a sequence + of mixed operations and is used to create the architecture search space. + + Args: + steps (int): The number of steps (operations) in the cell. + multiplier (int): The multiplier for the number of output channels. + C_prev_prev (int): Number of input channels from two steps back. + C_prev (int): Number of input channels from the previous step. + C (int): Number of output channels. + reduction (bool): Whether the cell performs reduction (downsampling). + reduction_prev (bool): Whether the previous cell performs reduction. + + Input: + - Two input tensors `s0` and `s1` of shape (batch_size, C_prev_prev, H, W) and (batch_size, C_prev, H, W), + where `C_prev_prev` is the number of input channels from two steps back, `C_prev` is the number of input + channels from the previous step, and `H` and `W` are the spatial dimensions. + + Output: + - Output tensor of shape (batch_size, C, H', W'), where `C` is the number of output channels, + and `H'` and `W'` are the spatial dimensions after applying the cell operations. + + Attributes: + - preprocess0 (nn.Module): Preprocessing layer for input `s0`. + - preprocess1 (nn.Module): Preprocessing layer for input `s1`. + - _steps (int): The number of steps (operations) in the cell. + - _multiplier (int): The multiplier for the number of output channels. + - _ops (nn.ModuleList): List of mixed operations to be applied in the cell. + + Note: + - This module is typically used in Neural Architecture Search (NAS) to create cells with different + combinations of operations, which are then combined to form a complete neural network architecture. 
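+        - With the default steps=4, the cell has 2 + 3 + 4 + 5 = 14 edges, so
+          `weights` is expected to provide one entry (a vector of operation mixing
+          weights) per edge, as in the example below.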
+ + Example: + To create an instance of the Cell module and use it in an NAS network: + >>> cell = Cell(steps=4, multiplier=4, C_prev_prev=48, C_prev=48, C=192, reduction=False, reduction_prev=True) + >>> input_s0 = torch.randn(1, 48, 32, 32) # Example input tensor s0 + >>> input_s1 = torch.randn(1, 48, 32, 32) # Example input tensor s1 + >>> weights = torch.randn(14) # Example operation mixing weights + >>> output = cell(input_s0, input_s1, weights) # Apply the cell operations to the inputs + + """ + def __init__( self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev ): @@ -38,7 +131,7 @@ def __init__( self._multiplier = multiplier self._ops = nn.ModuleList() - self._bns = nn.ModuleList() + for i in range(self._steps): for j in range(2 + i): stride = 2 if reduction and j < 2 else 1 @@ -46,6 +139,18 @@ def __init__( self._ops.append(op) def forward(self, s0, s1, weights): + """ + Forward pass of the Cell module. + + Args: + s0 (Tensor): Input tensor s0 of shape (batch_size, C_prev_prev, H, W). + s1 (Tensor): Input tensor s1 of shape (batch_size, C_prev, H, W). + weights (Tensor): Operation mixing weights of shape (num_operations,). + + Returns: + output (Tensor): Output tensor of shape (batch_size, C, H', W'). + + """ s0 = self.preprocess0(s0) s1 = self.preprocess1(s1) @@ -63,6 +168,52 @@ def forward(self, s0, s1, weights): class InnerCell(nn.Module): + """ + InnerCell Module for Neural Architecture Search (NAS). + + This module represents an inner cell in a neural network architecture designed for NAS. It contains a sequence + of mixed operations and is used to create the architecture search space. + + Args: + steps (int): The number of steps (operations) in the inner cell. + multiplier (int): The multiplier for the number of output channels. + C_prev_prev (int): Number of input channels from two steps back. + C_prev (int): Number of input channels from the previous step. + C (int): Number of output channels. + reduction (bool): Whether the inner cell performs reduction (downsampling). + reduction_prev (bool): Whether the previous cell performs reduction. + weights (Tensor): Operation mixing weights for the inner cell. + + Input: + - Two input tensors `s0` and `s1` of shape (batch_size, C_prev_prev, H, W) and (batch_size, C_prev, H, W), + where `C_prev_prev` is the number of input channels from two steps back, `C_prev` is the number of input + channels from the previous step, and `H` and `W` are the spatial dimensions. + + Output: + - Output tensor of shape (batch_size, C, H', W'), where `C` is the number of output channels, + and `H'` and `W'` are the spatial dimensions after applying the inner cell operations. + + Attributes: + - preprocess0 (nn.Module): Preprocessing layer for input `s0`. + - preprocess1 (nn.Module): Preprocessing layer for input `s1`. + - _steps (int): The number of steps (operations) in the inner cell. + - _multiplier (int): The multiplier for the number of output channels. + - _ops (nn.ModuleList): List of mixed operations to be applied in the inner cell. + + Note: + - This module is typically used in Neural Architecture Search (NAS) to create inner cells with different + combinations of operations, which are then combined to form a complete neural network architecture. + + Example: + To create an instance of the InnerCell module and use it in an NAS network: + >>> inner_cell = InnerCell(steps=4, multiplier=4, C_prev_prev=48, C_prev=48, C=192, reduction=False, + ... 
reduction_prev=True, weights=weights) + >>> input_s0 = torch.randn(1, 48, 32, 32) # Example input tensor s0 + >>> input_s1 = torch.randn(1, 48, 32, 32) # Example input tensor s1 + >>> output = inner_cell(input_s0, input_s1) # Apply the inner cell operations to the inputs + + """ + def __init__( self, steps, @@ -86,8 +237,7 @@ def __init__( self._multiplier = multiplier self._ops = nn.ModuleList() - self._bns = nn.ModuleList() - # len(self._ops)=2+3+4+5=14 + offset = 0 keys = list(OPS.keys()) for i in range(self._steps): @@ -102,6 +252,17 @@ def __init__( offset += i + 2 def forward(self, s0, s1): + """ + Forward pass of the InnerCell module. + + Args: + s0 (Tensor): Input tensor s0 of shape (batch_size, C_prev_prev, H, W). + s1 (Tensor): Input tensor s1 of shape (batch_size, C_prev, H, W). + + Returns: + output (Tensor): Output tensor of shape (batch_size, C, H', W'). + + """ s0 = self.preprocess0(s0) s1 = self.preprocess1(s1) @@ -117,16 +278,51 @@ def forward(self, s0, s1): class ModelForModelSizeMeasure(nn.Module): """ - This class is used only for calculating the size of the generated model. - The choices of opeartions are made using the current alpha value of the DARTS model. - The main difference between this model and DARTS model are the following: - 1. The __init__ takes one more parameter "alphas_normal" and "alphas_reduce" - 2. The new Cell module is rewriten to contain the functionality of both Cell and MixedOp - 3. To be more specific, MixedOp is replaced with a fixed choice of operation based on - the argmax(alpha_values) - 4. The new Cell class is redefined as an Inner Class. The name is the same, so please be - very careful when you change the code later - 5. + Model used solely for measuring the size of the generated model. + + This class is designed to calculate the size of a model based on specific choices of operations determined by + the alpha values of the DARTS model. It serves the purpose of estimating the model size without performing + actual training or inference. + + Differences from the DARTS model: + 1. Additional parameters "alphas_normal" and "alphas_reduce" are required in the constructor. + 2. The Cell module combines the functionality of both Cell and MixedOp. + 3. MixedOp is replaced with a fixed choice of operation based on the argmax(alpha_values). + 4. The Cell class is redefined as an inner class with the same name. + + Args: + C (int): The number of channels in the input data. + num_classes (int): The number of output classes. + layers (int): The number of layers in the model. + criterion: The loss criterion used for training. + alphas_normal (Tensor): Alpha values for normal cells. + alphas_reduce (Tensor): Alpha values for reduction cells. + steps (int, optional): The number of steps (operations) in each cell. Default is 4. + multiplier (int, optional): The multiplier for the number of output channels. Default is 4. + stem_multiplier (int, optional): The multiplier for the number of channels in the stem. Default is 3. + + Input: + - Input tensor of shape (batch_size, 3, H, W), where `batch_size` is the number of input samples, + `H` and `W` are the spatial dimensions, and `3` represents the RGB channels. + + Output: + - Output tensor of shape (batch_size, num_classes), representing class predictions. + + Attributes: + - stem (nn.Sequential): Stem layer consisting of a convolutional layer and batch normalization. + - cells (nn.ModuleList): List of inner cells that make up the model. 
+ - global_pooling (nn.AdaptiveAvgPool2d): Global pooling layer for spatial aggregation. + - classifier (nn.Linear): Fully connected layer for class prediction. + + Note: + - This class is primarily used for measuring the size of a model and does not perform training or inference. + + Example: + To create an instance of the ModelForModelSizeMeasure and use it to measure the model size: + >>> model = ModelForModelSizeMeasure(C=16, num_classes=10, layers=8, criterion=nn.CrossEntropyLoss(), + ... alphas_normal=alphas_normal, alphas_reduce=alphas_reduce) + >>> input_data = torch.randn(1, 3, 32, 32) # Example input tensor + >>> model_size = get_model_size(model, input_data) # Get the estimated model size """ @@ -159,7 +355,7 @@ def __init__( self.cells = nn.ModuleList() reduction_prev = False - # for layers = 8, when layer_i = 2, 5, the cell is reduction cell. + for i in range(layers): if i in [layers // 3, 2 * layers // 3]: C_curr *= 2 @@ -207,6 +403,52 @@ def forward(self, input_data): class Network(nn.Module): + """ + DARTS-based neural network model for image classification. + + Args: + C (int): The number of channels in the input data. + num_classes (int): The number of output classes. + layers (int): The number of layers in the model. + criterion: The loss criterion used for training. + steps (int, optional): The number of steps (operations) in each cell. Default is 4. + multiplier (int, optional): The multiplier for the number of output channels. Default is 4. + stem_multiplier (int, optional): The multiplier for the number of channels in the stem. Default is 3. + + Input: + - Input tensor of shape (batch_size, 3, H, W), where `batch_size` is the number of input samples, + `H` and `W` are the spatial dimensions, and `3` represents the RGB channels. + + Output: + - Output tensor of shape (batch_size, num_classes), representing class predictions. + + Attributes: + - stem (nn.Sequential): Stem layer consisting of a convolutional layer and batch normalization. + - cells (nn.ModuleList): List of inner cells that make up the model. + - global_pooling (nn.AdaptiveAvgPool2d): Global pooling layer for spatial aggregation. + - classifier (nn.Linear): Fully connected layer for class prediction. + - alphas_normal (nn.Parameter): Learnable alpha values for normal cells. + - alphas_reduce (nn.Parameter): Learnable alpha values for reduction cells. + + Methods: + - new(self): Create a new instance of the network with the same architecture and initialize alpha values. + - new_arch_parameters(self): Generate new architecture parameters (alphas) for the network. + - arch_parameters(self): Get the current architecture parameters (alphas) of the network. + - genotype(self): Get the genotype of the network, which describes the architecture. + - get_current_model_size(self): Estimate the current model size in megabytes. + + Note: + - This class is based on the DARTS (Differentiable Architecture Search) architecture and is used for + neural architecture search (NAS) experiments. 
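+        - The architecture parameters (alphas) have one row per cell edge; with
+          k = sum(2 + i for i in range(steps)) edges (k = 14 for the default
+          steps=4), each row holds one weight per candidate operation in PRIMITIVES.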
+ + Example: + To create an instance of the Network class and use it for architecture search: + >>> model = Network(C=16, num_classes=10, layers=8, criterion=nn.CrossEntropyLoss()) + >>> input_data = torch.randn(1, 3, 32, 32) # Example input tensor + >>> genotype, normal_count, reduce_count = model.genotype() # Get the architecture genotype + >>> model_size = model.get_current_model_size() # Get the estimated model size + """ + def __init__( self, C, @@ -263,6 +505,12 @@ def __init__( self._initialize_alphas() def new(self): + """ + Create a new instance of the network with the same architecture and initialize alpha values. + + Returns: + Network: A new instance of the Network class with the same architecture. + """ model_new = Network( self._C, self._num_classes, self._layers, self._criterion, self.device ).to(self.device) @@ -271,6 +519,16 @@ def new(self): return model_new def forward(self, input): + """ + Forward pass of the neural network. + + Args: + input (Tensor): Input tensor of shape (batch_size, 3, H, W), where `batch_size` is the number of + input samples, `H` and `W` are the spatial dimensions, and `3` represents the RGB channels. + + Returns: + Tensor: Output tensor of shape (batch_size, num_classes), representing class predictions. + """ s0 = s1 = self.stem(input) for i, cell in enumerate(self.cells): if cell.reduction: @@ -283,6 +541,9 @@ def forward(self, input): return logits def _initialize_alphas(self): + """ + Initialize alpha values for normal and reduction cells. + """ k = sum(1 for i in range(self._steps) for n in range(2 + i)) num_ops = len(PRIMITIVES) @@ -294,6 +555,12 @@ def _initialize_alphas(self): ] def new_arch_parameters(self): + """ + Generate new architecture parameters (alphas) for the network. + + Returns: + List[nn.Parameter]: List of architecture parameters (alphas). + """ k = sum(1 for i in range(self._steps) for n in range(2 + i)) num_ops = len(PRIMITIVES) @@ -306,9 +573,21 @@ def new_arch_parameters(self): return _arch_parameters def arch_parameters(self): + """ + Get the current architecture parameters (alphas) of the network. + + Returns: + List[nn.Parameter]: List of architecture parameters (alphas). + """ return self._arch_parameters def genotype(self): + """ + Get the genotype of the network, which describes the architecture. + + Returns: + Genotype: The genotype of the network. + """ def _isCNNStructure(k_best): return k_best >= 4 @@ -360,6 +639,12 @@ def _parse(weights): return genotype, cnn_structure_count_normal, cnn_structure_count_reduce def get_current_model_size(self): + """ + Estimate the current model size in megabytes. + + Returns: + float: The estimated model size in megabytes. + """ model = ModelForModelSizeMeasure( self._C, self._num_classes, diff --git a/python/fedml/model/cv/darts/model_search_gdas.py b/python/fedml/model/cv/darts/model_search_gdas.py index 144c4af567..756c477516 100644 --- a/python/fedml/model/cv/darts/model_search_gdas.py +++ b/python/fedml/model/cv/darts/model_search_gdas.py @@ -8,6 +8,17 @@ class MixedOp(nn.Module): def __init__(self, C, stride): + """ + Initialize a MixedOp module. + + Args: + C (int): The number of input channels. + stride (int): The stride for the operation. + + Note: + PRIMITIVES: a list of operation primitives. + OPS: a dictionary mapping operation primitives to corresponding operation classes. 
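+
+        Example:
+            A sketch of driving the op directly (the weights here are illustrative;
+            GDAS normally supplies Gumbel-softmax samples):
+
+            >>> op = MixedOp(C=16, stride=1)
+            >>> w = torch.softmax(torch.randn(len(PRIMITIVES)), dim=-1)
+            >>> out = op(torch.randn(1, 16, 32, 32), w, w.tolist())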
+ """ super(MixedOp, self).__init__() self._ops = nn.ModuleList() for primitive in PRIMITIVES: @@ -17,6 +28,17 @@ def __init__(self, C, stride): self._ops.append(op) def forward(self, x, weights, cpu_weights): + """ + Perform a forward pass through the MixedOp module. + + Args: + x (Tensor): Input tensor. + weights (Tensor): Weights for the operations. + cpu_weights (list): Weights converted to CPU. + + Returns: + Tensor: Output tensor after applying the mixed operations. + """ clist = [] for j, cpu_weight in enumerate(cpu_weights): if abs(cpu_weight) > 1e-10: @@ -31,6 +53,18 @@ class Cell(nn.Module): def __init__( self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev ): + """ + Initialize a Cell module. + + Args: + steps (int): The number of steps in the cell. + multiplier (int): Multiplier for the number of output channels. + C_prev_prev (int): Number of input channels from two steps back. + C_prev (int): Number of input channels from the previous step. + C (int): Number of output channels for the cell. + reduction (bool): Whether it's a reduction cell. + reduction_prev (bool): Whether the previous cell was a reduction cell. + """ super(Cell, self).__init__() self.reduction = reduction @@ -51,6 +85,17 @@ def __init__( self._ops.append(op) def forward(self, s0, s1, weights): + """ + Perform a forward pass through the Cell module. + + Args: + s0 (Tensor): Input tensor from two steps back. + s1 (Tensor): Input tensor from the previous step. + weights (Tensor): Weights for the operations. + + Returns: + Tensor: Output tensor after applying the cell operations. + """ s0 = self.preprocess0(s0) s1 = self.preprocess1(s1) @@ -64,7 +109,7 @@ def forward(self, s0, s1, weights): ) offset += len(states) states.append(s) - # logging.info(states) + return torch.cat(states[-self._multiplier :], dim=1) @@ -80,6 +125,19 @@ def __init__( multiplier=4, stem_multiplier=3, ): + """ + Initialize a Network_GumbelSoftmax model. + + Args: + C (int): Number of initial channels. + num_classes (int): Number of output classes. + layers (int): Number of layers. + criterion: Loss criterion. + device: The device to run the model on. + steps (int): Number of steps in each cell. + multiplier (int): Multiplier for the number of output channels. + stem_multiplier (int): Multiplier for the number of initial channels in the stem. + """ super(Network_GumbelSoftmax, self).__init__() self._C = C self._num_classes = num_classes @@ -89,7 +147,7 @@ def __init__( self._multiplier = multiplier self.device = device - C_curr = stem_multiplier * C # 3*16 + C_curr = stem_multiplier * C self.stem = nn.Sequential( nn.Conv2d(3, C_curr, 3, padding=1, bias=False), nn.BatchNorm2d(C_curr) ) @@ -98,7 +156,7 @@ def __init__( self.cells = nn.ModuleList() reduction_prev = False - # for layers = 8, when layer_i = 2, 5, the cell is reduction cell. + for i in range(layers): if i in [layers // 3, 2 * layers // 3]: C_curr *= 2 @@ -166,6 +224,14 @@ def arch_parameters(self): return self._arch_parameters def genotype(self): + """ + Get the architecture genotype of the model. + + Returns: + Genotype: The architecture genotype. + cnn_structure_count_normal (int): Count of CNN structures in normal cells. + cnn_structure_count_reduce (int): Count of CNN structures in reduction cells. 
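+
+        Example:
+            A sketch of a typical call (`model` is a Network_GumbelSoftmax instance):
+
+            >>> genotype, n_normal, n_reduce = model.genotype()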
+ """ def _isCNNStructure(k_best): return k_best >= 4 diff --git a/python/fedml/model/cv/darts/operations.py b/python/fedml/model/cv/darts/operations.py index 1827b2c7d1..5a8cd9ab49 100644 --- a/python/fedml/model/cv/darts/operations.py +++ b/python/fedml/model/cv/darts/operations.py @@ -35,6 +35,25 @@ class ReLUConvBN(nn.Module): + """ + A composite module that applies ReLU activation, followed by a 2D convolution, and then batch normalization. + + Args: + C_in (int): Number of input channels. + C_out (int): Number of output channels. + kernel_size (int): Size of the convolution kernel. + stride (int): Stride of the convolution operation. + padding (int): Padding for the convolution operation. + affine (bool): Whether to apply affine transformation in batch normalization. + + Input: + - Input tensor of shape (batch_size, C_in, height, width). + + Output: + - Output tensor of shape (batch_size, C_out, new_height, new_width). + + """ + def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): super(ReLUConvBN, self).__init__() self.op = nn.Sequential( @@ -50,6 +69,26 @@ def forward(self, x): class DilConv(nn.Module): + """ + A composite module that applies dilated convolution followed by batch normalization. + + Args: + C_in (int): Number of input channels. + C_out (int): Number of output channels. + kernel_size (int): Size of the convolution kernel. + stride (int): Stride of the convolution operation. + padding (int): Padding for the convolution operation. + dilation (int): Dilation factor for the convolution operation. + affine (bool): Whether to apply affine transformation in batch normalization. + + Input: + - Input tensor of shape (batch_size, C_in, height, width). + + Output: + - Output tensor of shape (batch_size, C_out, new_height, new_width). + + """ + def __init__( self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True ): @@ -75,6 +114,25 @@ def forward(self, x): class SepConv(nn.Module): + """ + A composite module that applies separable convolution followed by batch normalization. + + Args: + C_in (int): Number of input channels. + C_out (int): Number of output channels. + kernel_size (int): Size of the convolution kernel. + stride (int): Stride of the convolution operation. + padding (int): Padding for the convolution operation. + affine (bool): Whether to apply affine transformation in batch normalization. + + Input: + - Input tensor of shape (batch_size, C_in, height, width). + + Output: + - Output tensor of shape (batch_size, C_out, new_height, new_width). + + """ + def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): super(SepConv, self).__init__() self.op = nn.Sequential( @@ -108,7 +166,19 @@ def forward(self, x): return self.op(x) + class Identity(nn.Module): + """ + A module that represents the identity operation (no change). + + Input: + - Input tensor of any shape. + + Output: + - Output tensor with the same shape as the input. + + """ + def __init__(self): super(Identity, self).__init__() @@ -117,6 +187,20 @@ def forward(self, x): class Zero(nn.Module): + """ + A module that represents the zero operation (sets the tensor to zero). + + Args: + stride (int): Stride for selecting elements in the tensor. + + Input: + - Input tensor of any shape. + + Output: + - Output tensor with the same shape as the input, but with selected elements set to zero. 
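+
+    Note:
+        - When stride > 1, the output is additionally subsampled by the stride
+          (the reference DARTS implementation returns
+          x[:, :, ::stride, ::stride].mul(0.)), so "same shape as the input"
+          holds only for stride = 1.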
+ + """ + def __init__(self, stride): super(Zero, self).__init__() self.stride = stride @@ -128,6 +212,22 @@ def forward(self, x): class FactorizedReduce(nn.Module): + """ + A module that applies factorized reduction to reduce spatial dimensions. + + Args: + C_in (int): Number of input channels. + C_out (int): Number of output channels. + affine (bool): Whether to apply affine transformation in batch normalization. + + Input: + - Input tensor of shape (batch_size, C_in, height, width). + + Output: + - Output tensor of shape (batch_size, C_out, new_height, new_width). + + """ + def __init__(self, C_in, C_out, affine=True): super(FactorizedReduce, self).__init__() assert C_out % 2 == 0 diff --git a/python/fedml/model/cv/darts/train.py b/python/fedml/model/cv/darts/train.py index 95ef17d38b..9d5acd7da2 100644 --- a/python/fedml/model/cv/darts/train.py +++ b/python/fedml/model/cv/darts/train.py @@ -184,6 +184,19 @@ def main(): def train(train_queue, model, criterion, optimizer): + """ + Perform training on the training dataset. + + Args: + train_queue (DataLoader): DataLoader for the training dataset. + model (nn.Module): The neural network model. + criterion (nn.Module): The loss function. + optimizer (Optimizer): The optimizer for updating model parameters. + + Returns: + float: Top-1 accuracy on the training dataset. + float: Average loss on the training dataset. + """ global is_multi_gpu objs = utils.AvgrageMeter() top1 = utils.AvgrageMeter() @@ -221,6 +234,18 @@ def train(train_queue, model, criterion, optimizer): def infer(valid_queue, model, criterion): + """ + Perform inference on the validation dataset using the trained model. + + Args: + valid_queue (DataLoader): DataLoader for the validation dataset. + model (nn.Module): The trained neural network model. + criterion (nn.Module): The loss function used for validation. + + Returns: + float: Top-1 accuracy on the validation dataset. + float: Average loss on the validation dataset. + """ global is_multi_gpu objs = utils.AvgrageMeter() diff --git a/python/fedml/model/cv/darts/train_search.py b/python/fedml/model/cv/darts/train_search.py index 1bdc7e8d90..52f43dabf6 100644 --- a/python/fedml/model/cv/darts/train_search.py +++ b/python/fedml/model/cv/darts/train_search.py @@ -352,6 +352,26 @@ def main(): def train(epoch, train_queue, valid_queue, model, architect, criterion, optimizer, lr): + """ + Train the neural network for one epoch. + + Args: + epoch (int): Current epoch number. + train_queue (DataLoader): DataLoader for the training dataset. + valid_queue (DataLoader): DataLoader for the validation dataset. + model (nn.Module): The neural network model to be trained. + architect (Architect): The architect responsible for updating architecture weights. + criterion (nn.Module): The loss function used for training. + optimizer (torch.optim.Optimizer): The optimizer for updating model weights. + lr (float): Learning rate. + + Returns: + float: Top-1 accuracy on the training dataset. + float: Average loss on the training dataset. + float: Loss value. + + """ + global is_multi_gpu objs = utils.AvgrageMeter() @@ -407,6 +427,20 @@ def train(epoch, train_queue, valid_queue, model, architect, criterion, optimize def infer(valid_queue, model, criterion): + """ + Perform inference on the validation dataset using the trained model. + + Args: + valid_queue (DataLoader): DataLoader for the validation dataset. + model (nn.Module): The trained neural network model. + criterion (nn.Module): The loss function used for validation. 
+ + Returns: + float: Top-1 accuracy on the validation dataset. + float: Average loss on the validation dataset. + float: Loss value. + + """ global is_multi_gpu objs = utils.AvgrageMeter() diff --git a/python/fedml/model/cv/darts/utils.py b/python/fedml/model/cv/darts/utils.py index 0f024b4614..696f121d30 100644 --- a/python/fedml/model/cv/darts/utils.py +++ b/python/fedml/model/cv/darts/utils.py @@ -7,23 +7,60 @@ from torch.autograd import Variable -class AvgrageMeter(object): +class AverageMeter(object): + """ + Computes and stores the average and sum of values over time. + + Attributes: + avg (float): The current average value. + sum (float): The current sum of values. + cnt (int): The current count of values. + + Methods: + reset(): Reset the average, sum, and count to zero. + update(val, n=1): Update the meter with a new value and count. + + """ def __init__(self): + """ + Initializes an AverageMeter object with initial values of zero. + """ self.reset() def reset(self): + """ + Reset the average, sum, and count to zero. + """ self.avg = 0 self.sum = 0 self.cnt = 0 def update(self, val, n=1): + """ + Update the meter with a new value and count. + + Args: + val (float): The new value to update the meter with. + n (int): The count associated with the new value. Default is 1. + """ self.sum += val * n self.cnt += n self.avg = self.sum / self.cnt def accuracy(output, target, topk=(1,)): + """ + Computes the accuracy of model predictions given the output and target labels. + + Args: + output (Tensor): The model's output predictions. + target (Tensor): The ground truth labels. + topk (tuple of int): The top-k accuracy values to compute. Default is (1,). + + Returns: + list of float: A list of top-k accuracy values. + """ maxk = max(topk) batch_size = target.size(0) @@ -39,10 +76,35 @@ def accuracy(output, target, topk=(1,)): class Cutout(object): + """ + Apply cutout augmentation to an image. + + Args: + length (int): The size of the cutout square region. + + """ + def __init__(self, length): + """ + Initializes the Cutout object with a specified cutout length. + + Args: + length (int): The size of the cutout square region. + + """ self.length = length def __call__(self, img): + """ + Apply cutout augmentation to an image. + + Args: + img (PIL.Image): The input image. + + Returns: + PIL.Image: The augmented image with cutout applied. + + """ h, w = img.size(1), img.size(2) mask = np.ones((h, w), np.float32) y = np.random.randint(h) @@ -61,6 +123,16 @@ def __call__(self, img): def _data_transforms_cifar10(args): + """ + Define data transformations for CIFAR-10 dataset. + + Args: + args (argparse.Namespace): Command line arguments. + + Returns: + tuple: A tuple of train and validation data transforms. + + """ CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124] CIFAR_STD = [0.24703233, 0.24348505, 0.26158768] @@ -81,10 +153,29 @@ def _data_transforms_cifar10(args): def count_parameters_in_MB(model): + """ + Count the number of parameters in a model in megabytes (MB). + + Args: + model (nn.Module): The model for which to count parameters. + + Returns: + float: The number of parameters in megabytes (MB). + + """ return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6 def save_checkpoint(state, is_best, save): + """ + Save a checkpoint of the model's state. + + Args: + state (dict): The model's state dictionary. + is_best (bool): True if this is the best checkpoint, False otherwise. + save (str): The directory where the checkpoint will be saved. 
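+
+    Example:
+        A sketch of a typical end-of-epoch call (the state keys are illustrative,
+        and `save` must be an existing directory):
+
+        >>> save_checkpoint({'epoch': epoch, 'state_dict': model.state_dict()},
+        ...                 is_best=True, save='./checkpoints')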
+ + """ filename = os.path.join(save, 'checkpoint.pth.tar') torch.save(state, filename) if is_best: @@ -93,14 +184,41 @@ def save_checkpoint(state, is_best, save): def save(model, model_path): + """ + Save the model's state dictionary to a file. + + Args: + model (nn.Module): The PyTorch model to be saved. + model_path (str): The path to the file where the model state will be saved. + + """ torch.save(model.state_dict(), model_path) def load(model, model_path): + """ + Load a model's state dictionary from a file into the model. + + Args: + model (nn.Module): The PyTorch model to which the state will be loaded. + model_path (str): The path to the file containing the model state. + + """ model.load_state_dict(torch.load(model_path)) def drop_path(x, drop_prob): + """ + Apply dropout to a tensor. + + Args: + x (Tensor): The input tensor to which dropout will be applied. + drop_prob (float): The probability of dropping out a value. + + Returns: + Tensor: The tensor after dropout has been applied. + + """ if drop_prob > 0.: keep_prob = 1. - drop_prob mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)) @@ -110,6 +228,14 @@ def drop_path(x, drop_prob): def create_exp_dir(path, scripts_to_save=None): + """ + Create an experiment directory and optionally save scripts. + + Args: + path (str): The directory path for the experiment. + scripts_to_save (list of str, optional): List of script file paths to save in the directory. + + """ if not os.path.exists(path): os.mkdir(path) print('Experiment dir : {}'.format(path)) diff --git a/python/fedml/model/cv/darts/visualize.py b/python/fedml/model/cv/darts/visualize.py index df539289e2..79ffbf5970 100644 --- a/python/fedml/model/cv/darts/visualize.py +++ b/python/fedml/model/cv/darts/visualize.py @@ -4,6 +4,19 @@ def plot(genotype, filename): + """ + Generate a visualization of a given genotype and save it as a PDF file. + + Args: + genotype (list of tuples): The genotype to visualize, specifying operations and connections. + filename (str): The name of the PDF file to save the visualization. + + Example usage: + ```python + >>> genotype = [("conv3x3", 0), ("conv3x3", 1), ("maxpool3x3", 0), ("conv1x1", 2), ...] + >>> plot(genotype, "genotype_visualization.pdf") + ``` + """ g = Digraph( format="pdf", edge_attr=dict(fontsize="20", fontname="times"), diff --git a/python/fedml/model/cv/efficientnet_utils.py b/python/fedml/model/cv/efficientnet_utils.py index c95de26259..da000d7fa1 100644 --- a/python/fedml/model/cv/efficientnet_utils.py +++ b/python/fedml/model/cv/efficientnet_utils.py @@ -76,6 +76,15 @@ # An ordinary implementation of Swish function class Swish(nn.Module): def forward(self, x): + """ + Applies the Swish activation function to the input tensor. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the Swish activation. + """ return x * torch.sigmoid(x) @@ -83,12 +92,32 @@ def forward(self, x): class SwishImplementation(torch.autograd.Function): @staticmethod def forward(ctx, i): + """ + Forward pass for the memory-efficient Swish function. + + Args: + ctx: Context object to save tensors for backward pass. + i (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the memory-efficient Swish activation. + """ result = i * torch.sigmoid(i) ctx.save_for_backward(i) return result @staticmethod def backward(ctx, grad_output): + """ + Backward pass for the memory-efficient Swish function. 
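+
+        The gradient recomputes sigmoid(i) from the single saved input tensor rather
+        than storing intermediate activations, using
+        d/dx[x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x))).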
+ + Args: + ctx: Context object containing saved tensors from forward pass. + grad_output (torch.Tensor): Gradient of the loss with respect to the output. + + Returns: + torch.Tensor: Gradient of the loss with respect to the input tensor. + """ i = ctx.saved_tensors[0] sigmoid_i = torch.sigmoid(i) return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) @@ -96,17 +125,33 @@ def backward(ctx, grad_output): class MemoryEfficientSwish(nn.Module): def forward(self, x): + """ + Applies the memory-efficient Swish activation function to the input tensor. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the memory-efficient Swish activation. + """ + return SwishImplementation.apply(x) def round_filters(filters, global_params): - """Calculate and round number of filters based on width multiplier. - Use width_coefficient, depth_divisor and min_depth of global_params. + """ + Calculate and round the number of filters based on the width multiplier. + Args: - filters (int): Filters number to be calculated. - global_params (namedtuple): Global params of the model. + filters (int): Number of filters to be calculated. + global_params (namedtuple): Global parameters of the model. + Returns: - new_filters: New filters number after calculating. + int: New number of filters after rounding. + + Example: + # Calculate and round filters based on width multiplier and global parameters. + new_filters = round_filters(64, global_params) """ multiplier = global_params.width_coefficient if not multiplier: @@ -126,13 +171,19 @@ def round_filters(filters, global_params): def round_repeats(repeats, global_params): - """Calculate module's repeat number of a block based on depth multiplier. - Use depth_coefficient of global_params. + """ + Calculate module's repeat number of a block based on the depth multiplier. + Args: - repeats (int): num_repeat to be calculated. - global_params (namedtuple): Global params of the model. + repeats (int): Number of repeats to be calculated. + global_params (namedtuple): Global parameters of the model. + Returns: - new repeat: New repeat number after calculating. + int: New number of repeats after calculation. + + Example: + # Calculate repeats based on depth multiplier and global parameters. + new_repeats = round_repeats(5, global_params) """ multiplier = global_params.depth_coefficient if not multiplier: @@ -142,13 +193,20 @@ def round_repeats(repeats, global_params): def drop_connect(inputs, p, training): - """Drop connect. + """ + Apply drop connect to the input tensor. + Args: - input (tensor: BCWH): Input of this structure. - p (float: 0.0~1.0): Probability of drop connection. - training (bool): The running mode. + inputs (torch.Tensor): Input tensor to which drop connect will be applied. + p (float): Probability of drop connection (0.0 <= p <= 1.0). + training (bool): The running mode (True for training, False for inference). + Returns: - output: Output after drop connection. + torch.Tensor: Output tensor after applying drop connect. + + Example: + # Apply drop connect with a probability of 0.5 during training. + output = drop_connect(inputs, 0.5, training=True) """ assert 0 <= p <= 1, "p must be in range of [0,1]" @@ -170,11 +228,22 @@ def drop_connect(inputs, p, training): def get_width_and_height_from_size(x): - """Obtain height and width from x. + """ + Obtain height and width from a size value. + Args: - x (int, tuple or list): Data size. + x (int, tuple, or list): Data size. + Returns: - size: A tuple or list (H,W). 
+ tuple: A tuple (height, width). + + Raises: + TypeError: If the input is not an int, tuple, or list. + + Example: + # Get height and width from an integer size. + size = get_width_and_height_from_size(32) + # Result: (32, 32) """ if isinstance(x, int): return x, x @@ -185,13 +254,20 @@ def get_width_and_height_from_size(x): def calculate_output_image_size(input_image_size, stride): - """Calculates the output image size when using Conv2dSamePadding with a stride. - Necessary for static padding. Thanks to mannatsingh for pointing this out. + """ + Calculate the output image size when using Conv2dSamePadding with a given stride. + Args: - input_image_size (int, tuple or list): Size of input image. - stride (int, tuple or list): Conv2d operation's stride. + input_image_size (int, tuple, or list): Size of the input image. + stride (int, tuple, or list): Conv2d operation's stride. + Returns: - output_image_size: A list [H,W]. + list: A list [height, width] representing the output image size. + + Example: + # Calculate the output size for an input image of size 128x128 with a stride of 2. + output_size = calculate_output_image_size((128, 128), 2) + # Result: [64, 64] """ if input_image_size is None: return None @@ -209,12 +285,18 @@ def calculate_output_image_size(input_image_size, stride): def get_same_padding_conv2d(image_size=None): - """Chooses static padding if you have specified an image size, and dynamic padding otherwise. - Static padding is necessary for ONNX exporting of models. + """ + Choose dynamic padding if no image size is specified, otherwise choose static padding. + Args: image_size (int or tuple): Size of the image. + Returns: - Conv2dDynamicSamePadding or Conv2dStaticSamePadding. + Conv2dDynamicSamePadding or Conv2dStaticSamePadding: The appropriate Conv2d class. + + Example: + # Get the Conv2d class with dynamic padding based on image size. + conv2d_class = get_same_padding_conv2d((128, 128)) """ if image_size is None: return Conv2dDynamicSamePadding @@ -223,8 +305,9 @@ def get_same_padding_conv2d(image_size=None): class Conv2dDynamicSamePadding(nn.Conv2d): - """2D Convolutions like TensorFlow, for a dynamic image size. - The padding is operated in forward function by calculating dynamically. + """ + 2D Convolution with dynamic padding based on the input image size. + The padding is calculated dynamically during the forward pass. """ # Tips for 'SAME' mode padding. @@ -279,8 +362,23 @@ def forward(self, x): class Conv2dStaticSamePadding(nn.Conv2d): - """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size. - The padding mudule is calculated in construction function, then used in forward. + """ + 2D Convolutions with static padding similar to TensorFlow's 'SAME' mode, + using the provided input image size for padding calculation. + + This module calculates the padding during construction and applies it during the forward pass. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (int or tuple): Size of the convolutional kernel. + stride (int or tuple, optional): Stride of the convolution. Default is 1. + image_size (int or tuple, optional): Size of the input image. Must be provided for padding calculation. + **kwargs: Additional arguments for nn.Conv2d. + + Example: + # Create a Conv2dStaticSamePadding layer with an input image size of 128x128. 
+ conv_layer = Conv2dStaticSamePadding(in_channels=3, out_channels=64, kernel_size=3, image_size=(128, 128)) """ # With the same calculation as Conv2dDynamicSamePadding @@ -327,12 +425,15 @@ def forward(self, x): def get_same_padding_maxPool2d(image_size=None): - """Chooses static padding if you have specified an image size, and dynamic padding otherwise. - Static padding is necessary for ONNX exporting of models. + """ + Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. + Args: - image_size (int or tuple): Size of the image. + image_size (int or tuple, optional): Size of the image. If provided, static padding will be used. + Returns: - MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding. + MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding: A MaxPooling layer with the chosen padding. """ if image_size is None: return MaxPool2dDynamicSamePadding @@ -341,8 +442,21 @@ def get_same_padding_maxPool2d(image_size=None): class MaxPool2dDynamicSamePadding(nn.MaxPool2d): - """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size. - The padding is operated in forward function by calculating dynamically. + """ + 2D MaxPooling with dynamic padding, similar to TensorFlow's 'SAME' mode, for a dynamic image size. + The padding is calculated dynamically during the forward pass. + + Args: + kernel_size (int or tuple): Size of the max-pooling kernel. + stride (int or tuple): Stride of the max-pooling operation. + padding (int or tuple, optional): Padding to be added. Default is 0. + dilation (int or tuple, optional): Dilation rate. Default is 1. + return_indices (bool, optional): Whether to return the indices. Default is False. + ceil_mode (bool, optional): Whether to use 'ceil' mode for output size. Default is False. + + Example: + # Create a MaxPool2dDynamicSamePadding layer. + maxpool_layer = MaxPool2dDynamicSamePadding(kernel_size=3, stride=2) """ def __init__( @@ -390,8 +504,19 @@ def forward(self, x): class MaxPool2dStaticSamePadding(nn.MaxPool2d): - """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size. - The padding mudule is calculated in construction function, then used in forward. + """ + 2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size. + The padding module is calculated during construction and then applied in the forward pass. + + Args: + kernel_size (int or tuple): Size of the max-pooling kernel. + stride (int or tuple): Stride of the max-pooling operation. + image_size (int or tuple): Size of the input image. Required to calculate static padding. + **kwargs: Additional keyword arguments for MaxPool2d. + + Example: + # Create a MaxPool2dStaticSamePadding layer with a specified image size. + maxpool_layer = MaxPool2dStaticSamePadding(kernel_size=3, stride=2, image_size=(224, 224)) """ def __init__(self, kernel_size, stride, image_size=None, **kwargs): @@ -448,16 +573,22 @@ def forward(self, x): class BlockDecoder(object): - """Block Decoder for readability, - straight from the official TensorFlow repository. + """ + Block Decoder for readability, straight from the official TensorFlow repository. + + This class provides methods to decode and encode block configurations represented as strings. + These strings define the arguments of each block in a neural network architecture. """ @staticmethod def _decode_block_string(block_string): - """Get a block through a string notation of arguments. 
+ """ + Get a block through a string notation of arguments. + Args: block_string (str): A string notation of arguments. Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'. + Returns: BlockArgs: The namedtuple defined at the top of this file. """ @@ -489,9 +620,12 @@ def _decode_block_string(block_string): @staticmethod def _encode_block_string(block): - """Encode a block to a string. + """ + Encode a block to a string. + Args: block (namedtuple): A BlockArgs type argument. + Returns: block_string: A String form of BlockArgs. """ @@ -511,9 +645,12 @@ def _encode_block_string(block): @staticmethod def decode(string_list): - """Decode a list of string notations to specify blocks inside the network. + """ + Decode a list of string notations to specify blocks inside the network. + Args: string_list (list[str]): A list of strings, each string is a notation of block. + Returns: blocks_args: A list of BlockArgs namedtuples of block args. """ @@ -525,12 +662,16 @@ def decode(string_list): @staticmethod def encode(blocks_args): - """Encode a list of BlockArgs to a list of strings. + """ + Encode a list of BlockArgs to a list of strings. + Args: blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args. + Returns: block_strings: A list of strings, each string is a notation of block. """ + block_strings = [] for block in blocks_args: block_strings.append(BlockDecoder._encode_block_string(block)) @@ -538,9 +679,12 @@ def encode(blocks_args): def efficientnet_params(model_name): - """Map EfficientNet model name to parameter coefficients. + """ + Map EfficientNet model name to parameter coefficients. + Args: model_name (str): Model name to be queried. + Returns: params_dict[model_name]: A (width,depth,res,dropout) tuple. """ @@ -569,7 +713,9 @@ def efficientnet( num_classes=1000, include_top=True, ): - """Create BlockArgs and GlobalParams for efficientnet model. + """ + Create BlockArgs and GlobalParams for the EfficientNet model. + Args: width_coefficient (float) depth_coefficient (float) @@ -577,7 +723,8 @@ def efficientnet( dropout_rate (float) drop_connect_rate (float) num_classes (int) - Meaning as the name suggests. + include_top (bool) + Returns: blocks_args, global_params. """ @@ -613,10 +760,13 @@ def efficientnet( def get_model_params(model_name, override_params): - """Get the block args and global params for a given model name. + """ + Get the block args and global params for a given model name. + Args: model_name (str): Model's name. override_params (dict): A dict to modify global_params. + Returns: blocks_args, global_params """ @@ -669,16 +819,17 @@ def get_model_params(model_name, override_params): def load_pretrained_weights( model, model_name, weights_path=None, load_fc=True, advprop=False ): - """Loads pretrained weights from weights path or download using url. + """ + Loads pretrained weights from weights path or download using URL. + Args: - model (Module): The whole model of efficientnet. - model_name (str): Model name of efficientnet. - weights_path (None or str): - str: path to pretrained weights file on the local disk. - None: use pretrained weights downloaded from the Internet. - load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model. - advprop (bool): Whether to load pretrained weights - trained with advprop (valid when weights_path is None). + model (Module): The whole model of EfficientNet. + model_name (str): Model name of EfficientNet. + weights_path (str or None): + - str: Path to pretrained weights file on the local disk. 
+ - None: Use pretrained weights downloaded from the Internet. + load_fc (bool): Whether to load pretrained weights for the fully connected (fc) layer at the end of the model. + advprop (bool): Whether to load pretrained weights trained with advprop (valid when weights_path is None). """ if isinstance(weights_path, str): state_dict = torch.load(weights_path) diff --git a/python/fedml/model/cv/group_normalization.py b/python/fedml/model/cv/group_normalization.py index 0081e37e3f..5444fcdafe 100644 --- a/python/fedml/model/cv/group_normalization.py +++ b/python/fedml/model/cv/group_normalization.py @@ -15,11 +15,39 @@ def group_norm( momentum=0.1, eps=1e-5, ): - """Applies Group Normalization for channels in the same group in each data sample in a - batch. - See :class:`~torch.nn.GroupNorm1d`, :class:`~torch.nn.GroupNorm2d`, - :class:`~torch.nn.GroupNorm3d` for details. """ + Applies Group Normalization for channels in the same group in each data sample in a batch. + + Args: + input (Tensor): The input tensor of shape (N, C, *), where N is the batch size, + C is the number of channels, and * represents any number of additional dimensions. + group (int): The number of groups to divide the channels into. + running_mean (Tensor or None): A tensor of running means for each group, typically + from previous batches. Set to None if `use_input_stats` is True. + running_var (Tensor or None): A tensor of running variances for each group, typically + from previous batches. Set to None if `use_input_stats` is True. + weight (Tensor or None): A tensor to scale the normalized values for each channel. + bias (Tensor or None): A tensor to add an offset to the normalized values for each channel. + use_input_stats (bool): If True, batch statistics (mean and variance) are computed + from the input tensor for normalization. If False, `running_mean` and `running_var` + are used for normalization. + momentum (float): The momentum factor for updating running statistics. + eps (float): A small value added to the denominator for numerical stability. + + Returns: + Tensor: The normalized output tensor with the same shape as the input. + + Note: + Group Normalization is applied to the channels of the input tensor separately within each group. + If `use_input_stats` is True, running statistics (mean and variance) will not be used for + normalization, and batch statistics will be computed from the input tensor. + + See Also: + - :class:`~torch.nn.GroupNorm1d` for 1D input (sequence data). + - :class:`~torch.nn.GroupNorm2d` for 2D input (image data). + - :class:`~torch.nn.GroupNorm3d` for 3D input (volumetric data). + """ + if not use_input_stats and (running_mean is None or running_var is None): raise ValueError( "Expected running_mean and running_var to be not None when use_input_stats=False" @@ -42,6 +70,38 @@ def _instance_norm( momentum=None, eps=None, ): + """ + Applies Instance Normalization for channels within each group in the input tensor. + + Args: + input (Tensor): The input tensor of shape (N, C, *), where N is the batch size, + C is the number of channels, and * represents any number of additional dimensions. + group (int): The number of groups to divide the channels into. + running_mean (Tensor or None): A tensor of running means for each group, typically + from previous batches. Set to None if `use_input_stats` is True. + running_var (Tensor or None): A tensor of running variances for each group, typically + from previous batches. Set to None if `use_input_stats` is True. 
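# A from-scratch reference for the per-group statistics described above
# (the use_input_stats=True path, no running stats, no affine transform):
# reshape to (N, G, -1), normalize each group with its own mean/variance,
# and compare against PyTorch's built-in functional.
import torch

def group_norm_reference(x, num_groups, eps=1e-5):
    n = x.shape[0]
    g = x.reshape(n, num_groups, -1)
    mean = g.mean(dim=-1, keepdim=True)
    var = g.var(dim=-1, unbiased=False, keepdim=True)
    return ((g - mean) / torch.sqrt(var + eps)).reshape_as(x)

x = torch.randn(2, 8, 4, 4)
print(torch.allclose(group_norm_reference(x, 4),
                     torch.nn.functional.group_norm(x, 4), atol=1e-5))  # True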
+ weight (Tensor or None): A tensor to scale the normalized values for each channel. + bias (Tensor or None): A tensor to add an offset to the normalized values for each channel. + use_input_stats (bool or None): If True, batch statistics (mean and variance) are computed + from the input tensor for normalization. If False, `running_mean` and `running_var` + are used for normalization. If None, it defaults to True during training and False during inference. + momentum (float): The momentum factor for updating running statistics. + eps (float): A small value added to the denominator for numerical stability. + + Returns: + Tensor: The normalized output tensor with the same shape as the input. + + Note: + Instance Normalization is applied to the channels of the input tensor separately within each group. + If `use_input_stats` is True, running statistics (mean and variance) will not be used for + normalization, and batch statistics will be computed from the input tensor. + + See Also: + - :class:`~torch.nn.InstanceNorm1d` for 1D input (sequence data). + - :class:`~torch.nn.InstanceNorm2d` for 2D input (image data). + - :class:`~torch.nn.InstanceNorm3d` for 3D input (volumetric data). + """ # Repeat stored stats and affine transform params if necessary if running_mean is not None: running_mean_orig = running_mean @@ -94,6 +154,36 @@ def _instance_norm( class _GroupNorm(_BatchNorm): + """ + Applies Group Normalization over a mini-batch of inputs. + + Group Normalization divides the channels into groups and computes statistics + (mean and variance) separately for each group, normalizing each group independently. + It can be used as a normalization layer in various neural network architectures. + + Args: + num_features (int): Number of channels in the input tensor. + num_groups (int): Number of groups to divide the channels into. + eps (float): A small value added to the denominator for numerical stability. + momentum (float): The momentum factor for updating running statistics. + affine (bool): If True, learnable affine parameters (weight and bias) are applied to + the normalized output. Default is False. + track_running_stats (bool): If True, running statistics (mean and variance) are tracked + during training. Default is False. + + Attributes: + num_groups (int): Number of groups the channels are divided into. + track_running_stats (bool): If True, running statistics (mean and variance) are tracked + during training. + + Note: + The input tensor should have shape (N, C, *), where N is the batch size, C is the + number of channels, and * represents any number of additional dimensions. + + See Also: + - :class:`~torch.nn.GroupNorm` for a user-friendly interface. + - :class:`~torch.nn.BatchNorm2d` for standard Batch Normalization. + """ def __init__( self, num_features, @@ -129,25 +219,27 @@ def forward(self, input): class GroupNorm2d(_GroupNorm): - r"""Applies Group Normalization over a 4D input (a mini-batch of 2D inputs - with additional channel dimension) as described in the paper - https://arxiv.org/pdf/1803.08494.pdf - `Group Normalization`_ . + """Applies Group Normalization over a 4D input (a mini-batch of 2D inputs + with an additional channel dimension) as described in the paper + "Group Normalization" (https://arxiv.org/pdf/1803.08494.pdf). + Args: - num_features: :math:`C` from an expected input of size - :math:`(N, C, H, W)` - num_groups: - eps: a value added to the denominator for numerical stability. 
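# Because the statistics are computed per sample (and per group), the
# output for one input does not depend on the rest of the batch -- a quick
# check using torch.nn.GroupNorm, the built-in these classes mirror.
import torch

gn = torch.nn.GroupNorm(num_groups=4, num_channels=16)
x = torch.randn(8, 16, 5, 5)
print(torch.allclose(gn(x)[:1], gn(x[:1]), atol=1e-6))  # True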
Default: 1e-5 - momentum: the value used for the running_mean and running_var computation. Default: 0.1 - affine: a boolean value that when set to ``True``, this module has - learnable affine parameters. Default: ``True`` - track_running_stats: a boolean value that when set to ``True``, this - module tracks the running mean and variance, and when set to ``False``, - this module does not track such statistics and always uses batch - statistics in both training and eval modes. Default: ``False`` + num_features (int): Number of channels in the input tensor. + num_groups (int): Number of groups to divide the channels into. + eps (float): A small value added to the denominator for numerical stability. + Default: 1e-5. + momentum (float): The value used for computing running statistics (mean and variance). + Default: 0.1. + affine (bool): If True, learnable affine parameters (weight and bias) are applied to + the normalized output. Default: True. + track_running_stats (bool): If True, this module tracks running statistics + (mean and variance) during training. If False, it uses batch statistics in both + training and evaluation modes. Default: False. + Shape: - - Input: :math:`(N, C, H, W)` - - Output: :math:`(N, C, H, W)` (same shape as input) + - Input: (N, C, H, W) + - Output: (N, C, H, W) (same shape as input) + Examples: >>> # Without Learnable Parameters >>> m = GroupNorm2d(100, 4) @@ -155,8 +247,17 @@ class GroupNorm2d(_GroupNorm): >>> m = GroupNorm2d(100, 4, affine=True) >>> input = torch.randn(20, 100, 35, 45) >>> output = m(input) + + Note: + The input tensor should have shape (N, C, H, W), where N is the batch size, + C is the number of channels, H is the height, and W is the width. + + See Also: + - :class:`~torch.nn.GroupNorm` for a user-friendly interface. + - :class:`~torch.nn.BatchNorm2d` for standard Batch Normalization for 2D data. """ + def _check_input_dim(self, input): if input.dim() != 4: raise ValueError("expected 4D input (got {}D input)".format(input.dim())) @@ -164,7 +265,35 @@ def _check_input_dim(self, input): class GroupNorm3d(_GroupNorm): """ - Assume the data format is (B, C, D, H, W) + Applies 3D Group Normalization over a mini-batch of 3D inputs. + + Group Normalization divides the channels into groups and computes statistics + (mean and variance) separately for each group, normalizing each group independently. + It is designed for 3D data with the format (B, C, D, H, W), where B is the batch size, + C is the number of channels, D is the depth, H is the height, and W is the width. + + Args: + num_features (int): Number of channels in the input tensor. + num_groups (int): Number of groups to divide the channels into. + eps (float): A small value added to the denominator for numerical stability. + momentum (float): The momentum factor for updating running statistics. + affine (bool): If True, learnable affine parameters (weight and bias) are applied to + the normalized output. Default is False. + track_running_stats (bool): If True, running statistics (mean and variance) are tracked + during training. Default is False. + + Attributes: + num_groups (int): Number of groups the channels are divided into. + track_running_stats (bool): If True, running statistics (mean and variance) are tracked + during training. + + Note: + The input tensor should have shape (N, C, D, H, W), where N is the batch size, C is the + number of channels, D is the depth, H is the height, and W is the width. + + See Also: + - :class:`~torch.nn.GroupNorm` for a user-friendly interface. 
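# Shape check for the 3D case documented above. torch.nn.GroupNorm stands
# in for GroupNorm3d here since it accepts any (B, C, *) layout; treat the
# substitution as an assumption about equivalent behavior.
import torch

m = torch.nn.GroupNorm(num_groups=8, num_channels=64)
x = torch.randn(2, 64, 8, 16, 16)  # (B, C, D, H, W)
print(m(x).shape)  # torch.Size([2, 64, 8, 16, 16])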
+ - :class:`~torch.nn.BatchNorm3d` for standard Batch Normalization for 3D data. """ def _check_input_dim(self, input): diff --git a/python/fedml/model/cv/resnet56/resnet_client.py b/python/fedml/model/cv/resnet56/resnet_client.py index 7e26488005..37d5dbe311 100644 --- a/python/fedml/model/cv/resnet56/resnet_client.py +++ b/python/fedml/model/cv/resnet56/resnet_client.py @@ -16,7 +16,19 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): - """3x3 convolution with padding""" + """ + 3x3 convolution with padding. + + Args: + in_planes (int): Number of input channels. + out_planes (int): Number of output channels. + stride (int): Stride for the convolution operation. + groups (int): Number of groups for grouped convolution. + dilation (int): Dilation factor for the convolution operation. + + Returns: + nn.Conv2d: A 3x3 convolutional layer. + """ return nn.Conv2d( in_planes, out_planes, @@ -30,11 +42,41 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): def conv1x1(in_planes, out_planes, stride=1): - """1x1 convolution""" + """ + 1x1 convolution. + + Args: + in_planes (int): Number of input channels. + out_planes (int): Number of output channels. + stride (int): Stride for the convolution operation. + + Returns: + nn.Conv2d: A 1x1 convolutional layer. + """ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class BasicBlock(nn.Module): + """ + Basic building block for ResNet. + + Args: + inplanes (int): Number of input channels. + planes (int): Number of output channels. + stride (int, optional): Stride for the convolutional layers. Default is 1. + downsample (nn.Module, optional): Downsample layer for shortcut connection. Default is None. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + base_width (int, optional): Width of each group. Default is 64. + dilation (int, optional): Dilation factor for convolution. Default is 1. + norm_layer (nn.Module, optional): Normalization layer. Default is None. + + Attributes: + expansion (int): Expansion factor for the block. + + Example: + block = BasicBlock(64, 128, stride=2) + """ + expansion = 1 def __init__( @@ -65,6 +107,15 @@ def __init__( self.stride = stride def forward(self, x): + """ + Forward pass through the BasicBlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ identity = x out = self.conv1(x) @@ -84,6 +135,26 @@ def forward(self, x): class Bottleneck(nn.Module): + """ + Bottleneck building block for ResNet. + + Args: + inplanes (int): Number of input channels. + planes (int): Number of output channels. + stride (int, optional): Stride for the convolutional layers. Default is 1. + downsample (nn.Module, optional): Downsample layer for shortcut connection. Default is None. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + base_width (int, optional): Width of each group. Default is 64. + dilation (int, optional): Dilation factor for convolution. Default is 1. + norm_layer (nn.Module, optional): Normalization layer. Default is None. + + Attributes: + expansion (int): Expansion factor for the block. + + Example: + block = Bottleneck(256, 512, stride=2) + """ + expansion = 4 def __init__( @@ -113,6 +184,15 @@ def __init__( self.stride = stride def forward(self, x): + """ + Forward pass through the Bottleneck. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. 
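# One caveat for the Example snippets in these blocks: when inplanes !=
# planes (or stride != 1), the residual add needs a matching `downsample`
# module, otherwise `out += identity` fails on shape. A runnable sketch
# using torchvision's BasicBlock, which shares the constructor documented
# above (the torchvision import is a stand-in assumption):
import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock

downsample = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(128),
)
block = BasicBlock(64, 128, stride=2, downsample=downsample)
print(block(torch.randn(1, 64, 32, 32)).shape)  # torch.Size([1, 128, 16, 16])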
+ """ identity = x out = self.conv1(x) @@ -135,6 +215,7 @@ def forward(self, x): return out + class ResNet(nn.Module): def __init__( self, @@ -148,6 +229,29 @@ def __init__( norm_layer=None, KD=False, ): + """ + ResNet model architecture. + + Args: + block (nn.Module): The block type to use for constructing layers (e.g., BasicBlock or Bottleneck). + layers (list of int): List specifying the number of blocks in each layer. + num_classes (int, optional): Number of output classes. Default is 10. + zero_init_residual (bool, optional): Whether to initialize the last BN in each residual branch to zero. Default is False. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + width_per_group (int, optional): Width of each group. Default is 64. + replace_stride_with_dilation (list of bool, optional): List indicating if stride should be replaced with dilation. Default is None. + norm_layer (nn.Module, optional): Normalization layer. Default is None. + KD (bool, optional): Knowledge distillation flag. Default is False. + + Attributes: + expansion (int): Expansion factor for the blocks. + + Example: + # Example architecture for a ResNet-18 model with 2 blocks in each layer. + model = ResNet(BasicBlock, [2, 2, 2, 2]) + # Alternatively, for a ResNet-50 model with 3 blocks in each layer. + model = ResNet(Bottleneck, [3, 4, 6, 3]) + """ super(ResNet, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d @@ -242,6 +346,15 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False): return nn.Sequential(*layers) def forward(self, x): + """ + Forward pass through the ResNet model. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output logits and extracted features. + """ x = self.conv1(x) x = self.bn1(x) x = self.relu(x) # B x 16 x 32 x 32 @@ -260,12 +373,22 @@ def forward(self, x): def resnet5_56(c, pretrained=False, path=None, **kwargs): """ - Constructs a ResNet-32 model. + Constructs a ResNet-5-56 model. Args: + c (int): Number of output classes. pretrained (bool): If True, returns a model pre-trained. - """ + path (str, optional): Path to a pre-trained checkpoint. Default is None. + **kwargs: Additional keyword arguments to pass to the ResNet constructor. + + Returns: + nn.Module: A ResNet-5-56 model. + Example: + # Create a ResNet-5-56 model with 10 output classes. + model = resnet5_56(10) + """ + model = ResNet(BasicBlock, [1, 2, 2], num_classes=c, **kwargs) if pretrained: checkpoint = torch.load(path) @@ -285,10 +408,20 @@ def resnet5_56(c, pretrained=False, path=None, **kwargs): def resnet8_56(c, pretrained=False, path=None, **kwargs): """ - Constructs a ResNet-32 model. + Constructs a ResNet-8-56 model. Args: + c (int): Number of output classes. pretrained (bool): If True, returns a model pre-trained. + path (str, optional): Path to a pre-trained checkpoint. Default is None. + **kwargs: Additional keyword arguments to pass to the ResNet constructor. + + Returns: + nn.Module: A ResNet-8-56 model. + + Example: + # Create a ResNet-8-56 model with 10 output classes. 
+ model = resnet8_56(10) """ model = ResNet(Bottleneck, [2, 2, 2], num_classes=c, **kwargs) diff --git a/python/fedml/model/cv/resnet56/resnet_pretrained.py b/python/fedml/model/cv/resnet56/resnet_pretrained.py index b1c6d93666..356db9fa5b 100644 --- a/python/fedml/model/cv/resnet56/resnet_pretrained.py +++ b/python/fedml/model/cv/resnet56/resnet_pretrained.py @@ -15,7 +15,23 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): - """3x3 convolution with padding""" + """ + Create a 3x3 convolution layer with padding. + + Args: + in_planes (int): Number of input channels. + out_planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + dilation (int, optional): Dilation rate for the convolution. Default is 1. + + Returns: + nn.Conv2d: A 3x3 convolution layer. + + Example: + # Create a 3x3 convolution layer with 64 input channels and 128 output channels. + conv_layer = conv3x3(64, 128) + """ return nn.Conv2d( in_planes, out_planes, @@ -29,11 +45,45 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): def conv1x1(in_planes, out_planes, stride=1): - """1x1 convolution""" + """ + Create a 1x1 convolution layer. + + Args: + in_planes (int): Number of input channels. + out_planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + + Returns: + nn.Conv2d: A 1x1 convolution layer. + + Example: + # Create a 1x1 convolution layer with 64 input channels and 128 output channels. + conv_layer = conv1x1(64, 128) + """ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class BasicBlock(nn.Module): + """ + Basic building block for ResNet. + + Args: + inplanes (int): Number of input channels. + planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + downsample (nn.Module, optional): Downsample layer. Default is None. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + base_width (int, optional): Base width for grouped convolution. Default is 64. + dilation (int, optional): Dilation rate for the convolution. Default is 1. + norm_layer (nn.Module, optional): Normalization layer. Default is None. + + Attributes: + expansion (int): The expansion factor of the block. + + Example: + # Create a BasicBlock with 64 input channels and 128 output channels. + block = BasicBlock(64, 128) + """ expansion = 1 def __init__( @@ -83,6 +133,26 @@ def forward(self, x): class Bottleneck(nn.Module): + """ + Bottleneck building block for ResNet. + + Args: + inplanes (int): Number of input channels. + planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + downsample (nn.Module, optional): Downsample layer. Default is None. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + base_width (int, optional): Base width for grouped convolution. Default is 64. + dilation (int, optional): Dilation rate for the convolution. Default is 1. + norm_layer (nn.Module, optional): Normalization layer. Default is None. + + Attributes: + expansion (int): The expansion factor of the block (default is 4). + + Example: + # Create a Bottleneck block with 64 input channels and 128 output channels. 
+ block = Bottleneck(64, 128) + """ expansion = 4 def __init__( @@ -135,6 +205,29 @@ def forward(self, x): class ResNet(nn.Module): + """ + ResNet model architecture for image classification. + + Args: + block (nn.Module): The building block for the network (e.g., BasicBlock or Bottleneck). + layers (list): List of integers specifying the number of blocks in each layer. + num_classes (int, optional): Number of classes for classification. Default is 10. + zero_init_residual (bool, optional): If True, zero-initialize the last BN in each residual branch. + Default is False. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + width_per_group (int, optional): Base width for grouped convolution. Default is 64. + replace_stride_with_dilation (list or None, optional): List of booleans specifying if the 2x2 stride + should be replaced with dilated convolution in each layer. Default is None. + norm_layer (nn.Module, optional): Normalization layer. Default is None. + KD (bool, optional): Knowledge distillation flag. Default is False. + + Attributes: + expansion (int): The expansion factor of the building block (default is 4). + + Example: + # Create a ResNet-56 model with 10 output classes. + model = ResNet(Bottleneck, [6, 6, 6], num_classes=10) + """ def __init__( self, block, @@ -200,6 +293,23 @@ def __init__( nn.init.constant_(m.bn2.weight, 0) def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + """ + Create a layer of blocks for the ResNet model. + + Args: + block (nn.Module): The building block for the layer (e.g., BasicBlock or Bottleneck). + planes (int): The number of output channels for the layer. + blocks (int): The number of blocks to stack in the layer. + stride (int, optional): The stride for the layer's convolutional operations. Default is 1. + dilate (bool, optional): If True, apply dilated convolutions in the layer. Default is False. + + Returns: + nn.Sequential: A sequential container of blocks representing the layer. + + Example: + # Create a layer of 2 Bottleneck blocks with 64 output channels and stride 1. + layer = self._make_layer(Bottleneck, 64, 2, stride=1) + """ norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation @@ -241,6 +351,17 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False): return nn.Sequential(*layers) def forward(self, x): + """ + Forward pass of the ResNet model. + + Args: + x (torch.Tensor): Input tensor of shape (B, C, H, W), where B is the batch size, C is the number of + channels, H is the height, and W is the width. + + Returns: + torch.Tensor: The output tensor of shape (B, num_classes) representing class logits. + torch.Tensor: Extracted features before the classification layer, of shape (B, C, H, W). + """ x = self.conv1(x) x = self.bn1(x) @@ -260,10 +381,20 @@ def forward(self, x): def resnet32_pretrained(c, pretrained=False, path=None, **kwargs): """ - Constructs a ResNet-32 model. + Constructs a pre-trained ResNet-32 model. Args: - pretrained (bool): If True, returns a model pre-trained. + c (int): The number of output classes. + pretrained (bool): If True, returns a model pre-trained on a given path. + path (str, optional): The path to the pre-trained model checkpoint. Default is None. + **kwargs: Additional keyword arguments to pass to the ResNet model. + + Returns: + nn.Module: A pre-trained ResNet-32 model. + + Example: + # Create a pre-trained ResNet-32 model with 10 output classes. 
+ model = resnet32_pretrained(10, pretrained=True, path='pretrained_resnet32.pth') """ model = ResNet(BasicBlock, [5, 5, 5], num_classes=c, **kwargs) @@ -285,10 +416,20 @@ def resnet32_pretrained(c, pretrained=False, path=None, **kwargs): def resnet56_pretrained(c, pretrained=False, path=None, **kwargs): """ - Constructs a ResNet-110 model. + Constructs a pre-trained ResNet-56 model. Args: - pretrained (bool): If True, returns a model pre-trained. + c (int): The number of output classes. + pretrained (bool): If True, returns a model pre-trained on a given path. + path (str, optional): The path to the pre-trained model checkpoint. Default is None. + **kwargs: Additional keyword arguments to pass to the ResNet model. + + Returns: + nn.Module: A pre-trained ResNet-56 model. + + Example: + # Create a pre-trained ResNet-56 model with 10 output classes. + model = resnet56_pretrained(10, pretrained=True, path='pretrained_resnet56.pth') """ logging.info("path = " + str(path)) model = ResNet(Bottleneck, [6, 6, 6], num_classes=c, **kwargs) diff --git a/python/fedml/model/cv/resnet56/resnet_server.py b/python/fedml/model/cv/resnet56/resnet_server.py index a481461b1a..7ca1bf738c 100644 --- a/python/fedml/model/cv/resnet56/resnet_server.py +++ b/python/fedml/model/cv/resnet56/resnet_server.py @@ -17,7 +17,24 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): - """3x3 convolution with padding""" + """ + 3x3 convolution with padding. + + Args: + in_planes (int): Number of input channels. + out_planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + dilation (int, optional): Dilation rate for the convolution. Default is 1. + + Returns: + nn.Conv2d: A 3x3 convolutional layer. + + Example: + # Create a 3x3 convolution with 64 input channels, 128 output channels, and a stride of 2. + conv_layer = conv3x3(64, 128, stride=2) + """ + return nn.Conv2d( in_planes, out_planes, @@ -31,11 +48,47 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): def conv1x1(in_planes, out_planes, stride=1): - """1x1 convolution""" + """ + 1x1 convolution. + + Args: + in_planes (int): Number of input channels. + out_planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + + Returns: + nn.Conv2d: A 1x1 convolutional layer. + + Example: + # Create a 1x1 convolution with 64 input channels and 128 output channels. + conv_layer = conv1x1(64, 128) + """ + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class BasicBlock(nn.Module): + """ + Basic building block for a ResNet. + + Args: + inplanes (int): Number of input channels. + planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + downsample (nn.Module, optional): Downsample layer for shortcut connection. Default is None. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + base_width (int, optional): Base width for width calculation. Default is 64. + dilation (int, optional): Dilation rate for the convolution. Default is 1. + norm_layer (nn.Module, optional): Normalization layer. Default is None. + + Attributes: + expansion (int): Expansion factor for the number of output channels. + + Example: + # Create a BasicBlock with 64 input channels and 128 output channels. 
+ block = BasicBlock(64, 128) + """ + expansion = 1 def __init__( @@ -58,6 +111,19 @@ def __init__( self.stride = stride def forward(self, x): + """ + Forward pass of the BasicBlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + + Example: + # Forward pass through a BasicBlock. + output = block(input_tensor) + """ identity = x out = self.conv1(x) @@ -77,6 +143,27 @@ def forward(self, x): class Bottleneck(nn.Module): + """ + Bottleneck building block for a ResNet. + + Args: + inplanes (int): Number of input channels. + planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + downsample (nn.Module, optional): Downsample layer for shortcut connection. Default is None. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + base_width (int, optional): Base width for width calculation. Default is 64. + dilation (int, optional): Dilation rate for the convolution. Default is 1. + norm_layer (nn.Module, optional): Normalization layer. Default is None. + + Attributes: + expansion (int): Expansion factor for the number of output channels. + + Example: + # Create a Bottleneck with 64 input channels and 128 output channels. + bottleneck = Bottleneck(64, 128) + """ + expansion = 4 def __init__( @@ -98,6 +185,19 @@ def __init__( self.stride = stride def forward(self, x): + """ + Forward pass of the Bottleneck. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + + Example: + # Forward pass through a Bottleneck block. + output = bottleneck(input_tensor) + """ identity = x out = self.conv1(x) @@ -133,6 +233,42 @@ def __init__( norm_layer=None, KD=False, ): + """ + ResNet model implementation. + + Args: + block (nn.Module): The type of block to use in the network (e.g., BasicBlock or Bottleneck). + layers (list of int): Number of blocks in each layer of the network. + num_classes (int): Number of output classes. Default is 10. + zero_init_residual (bool): Whether to zero-init the last BN in each residual branch. Default is False. + groups (int): Number of groups for grouped convolution. Default is 1. + width_per_group (int): Number of channels per group for grouped convolution. Default is 64. + replace_stride_with_dilation (list of bool): List indicating whether to replace 2x2 stride with dilation. + norm_layer (nn.Module): Normalization layer. Default is None. + KD (bool): Whether to enable knowledge distillation. Default is False. + + Attributes: + block (nn.Module): The type of block used in the network. + layers (list of int): Number of blocks in each layer of the network. + num_classes (int): Number of output classes. + zero_init_residual (bool): Whether to zero-init the last BN in each residual branch. + groups (int): Number of groups for grouped convolution. + base_width (int): Base width for width calculation. + dilation (int): Dilation rate for the convolution. + conv1 (nn.Conv2d): The initial convolutional layer. + bn1 (nn.BatchNorm2d): Batch normalization layer after the initial convolution. + relu (nn.ReLU): ReLU activation function. + layer1 (nn.Sequential): The first layer of the network. + layer2 (nn.Sequential): The second layer of the network. + layer3 (nn.Sequential): The third layer of the network. + avgpool (nn.AdaptiveAvgPool2d): Adaptive average pooling layer. + fc (nn.Linear): Fully connected layer for classification. + KD (bool): Whether knowledge distillation is enabled. 
+ + Example: + # Create a ResNet-18 model with 10 output classes. + resnet = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=10) + """ super(ResNet, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d @@ -179,6 +315,19 @@ def __init__( nn.init.constant_(m.bn2.weight, 0) def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + """ + Helper function to create a layer of blocks. + + Args: + block (nn.Module): The type of block to use. + planes (int): Number of output channels for the layer. + blocks (int): Number of blocks in the layer. + stride (int, optional): Stride for the convolution. Default is 1. + dilate (bool, optional): Whether to use dilation. Default is False. + + Returns: + nn.Sequential: A sequential container of blocks. + """ norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation @@ -212,6 +361,19 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False): return nn.Sequential(*layers) def forward(self, x): + """ + Forward pass of the ResNet model. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + + Example: + # Forward pass through a ResNet model. + output = resnet(input_tensor) + """ # x = self.conv1(x) # x = self.bn1(x) # x = self.relu(x) # B x 16 x 32 x 32 @@ -230,8 +392,20 @@ def resnet56_server(c, pretrained=False, path=None, **kwargs): """ Constructs a ResNet-110 model. + This function creates a ResNet-110 model for server-side applications with the specified number of output classes. + Args: + c (int): Number of output classes. pretrained (bool): If True, returns a model pre-trained. + path (str, optional): Path to a pre-trained model checkpoint. Default is None. + **kwargs: Additional keyword arguments to pass to the ResNet model constructor. + + Returns: + nn.Module: A ResNet-110 model. + + Example: + # Create a ResNet-110 model with 10 output classes. + model = resnet56_server(10) """ logging.info("path = " + str(path)) model = ResNet(Bottleneck, [6, 6, 6], num_classes=c, **kwargs) diff --git a/python/fedml/model/linear/lr.py b/python/fedml/model/linear/lr.py index d5bca7fde2..d22e3dc4af 100644 --- a/python/fedml/model/linear/lr.py +++ b/python/fedml/model/linear/lr.py @@ -5,25 +5,42 @@ class LogisticRegression(torch.nn.Module): """ Logistic Regression Model. + This class implements a simple logistic regression model for binary or multi-class classification tasks. + Args: - input_dim (int): The input dimension, typically the number of features in each input sample. - output_dim (int): The output dimension, representing the number of classes or a single output. + input_dim (int): The input dimension, typically representing the number of features in each input sample. + output_dim (int): The output dimension, representing the number of classes (for multi-class) or 1 (for binary). Input: - Input tensor of shape (batch_size, input_dim), where batch_size is the number of input samples. Output: - - Output tensor of shape (batch_size, output_dim), representing class probabilities or a single output. + - Output tensor of shape (batch_size, output_dim), representing class probabilities (for multi-class) + or a single output (for binary). Architecture: - Linear Layer: - Input: input_dim neurons - Output: output_dim neurons - Activation: Sigmoid (for binary classification) or Softmax (for multi-class classification) - + Note: - - For binary classification, output_dim is typically set to 1. - - For multi-class classification, output_dim is the number of classes. 
+ - For binary classification, set output_dim to 1. + - For multi-class classification, output_dim should be set to the number of classes. + + Example: + To create a binary logistic regression model with 10 input features: + >>> model = LogisticRegression(input_dim=10, output_dim=1) + + Forward Method: + The forward method computes the forward pass of the Logistic Regression model. + + Args: + x (Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + outputs (Tensor): Output tensor of shape (batch_size, output_dim) with class probabilities (for multi-class) + or a single output (for binary). """ def __init__(self, input_dim, output_dim): @@ -38,13 +55,11 @@ def forward(self, x): x (Tensor): Input tensor of shape (batch_size, input_dim). Returns: - outputs (Tensor): Output tensor of shape (batch_size, output_dim) with class probabilities or a single output. + outputs (Tensor): Output tensor of shape (batch_size, output_dim) with class probabilities (for multi-class) + or a single output (for binary). """ - # try: + outputs = torch.sigmoid(self.linear(x)) - # except: - # print(x.size()) - # import pdb - # pdb.set_trace() + return outputs diff --git a/python/fedml/model/linear/lr_cifar10.py b/python/fedml/model/linear/lr_cifar10.py index 87d593a547..3c99818edb 100644 --- a/python/fedml/model/linear/lr_cifar10.py +++ b/python/fedml/model/linear/lr_cifar10.py @@ -5,8 +5,11 @@ class LogisticRegression_Cifar10(torch.nn.Module): """ Logistic Regression Model for CIFAR-10 Image Classification. + This class implements a logistic regression model for classifying images in the CIFAR-10 dataset. + Args: - input_dim (int): The input dimension, typically the number of features in each input sample. + input_dim (int): The input dimension, typically representing the number of features in each input sample + (flattened image vectors). output_dim (int): The output dimension, representing the number of classes in CIFAR-10. Input: @@ -21,6 +24,19 @@ class LogisticRegression_Cifar10(torch.nn.Module): - Output: output_dim neurons (class probabilities) - Activation: Sigmoid (to produce class probabilities) + Example: + To create a CIFAR-10 logistic regression model with 3072 input features (32x32x3 images): + >>> model = LogisticRegression_Cifar10(input_dim=3072, output_dim=10) + + Forward Method: + The forward method computes the forward pass of the Logistic Regression model. + + Args: + x (Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + outputs (Tensor): Output tensor of shape (batch_size, output_dim) with class probabilities. + """ def __init__(self, input_dim, output_dim): super(LogisticRegression_Cifar10, self).__init__() @@ -39,10 +55,13 @@ def forward(self, x): """ # Flatten images into vectors # print(f"size = {x.size()}") + x = x.view(x.size(0), -1) - outputs = torch.sigmoid(self.linear(x)) + # except: # print(x.size()) # import pdb # pdb.set_trace() + + outputs = torch.sigmoid(self.linear(x)) return outputs diff --git a/python/fedml/model/mobile/mnn_lenet.py b/python/fedml/model/mobile/mnn_lenet.py index 2378fc9695..e8527a67f5 100644 --- a/python/fedml/model/mobile/mnn_lenet.py +++ b/python/fedml/model/mobile/mnn_lenet.py @@ -48,7 +48,30 @@ def forward(self, x): def create_mnn_lenet5_model(mnn_file_path): + """ + Create and save a LeNet-5 model in the MNN format. + + Args: + mnn_file_path (str): The path to save the MNN model file. + + Note: + This function assumes you have a LeNet-5 model class defined in a 'lenet5' module. 
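# End-to-end check of the logistic-regression layout above: one BCE
# gradient step on random data. The local class mirrors lr.py's module
# one-for-one; the optimizer-free single step is just for illustration.
import torch

class _LogisticRegression(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return torch.sigmoid(self.linear(x))

model = _LogisticRegression(input_dim=10, output_dim=1)
x = torch.randn(32, 10)
y = torch.randint(0, 2, (32, 1)).float()
loss = torch.nn.functional.binary_cross_entropy(model(x), y)
loss.backward()
print(loss.item())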
+ The LeNet-5 model class should have a 'forward' method that takes an input tensor and returns predictions. + + Example: + To create and save a LeNet-5 model to 'lenet5.mnn': + >>> create_mnn_lenet5_model('lenet5.mnn') + + """ + # Create an instance of the LeNet-5 model net = Lenet5() + + # Define an input tensor with the desired shape (1 batch, 1 channel, 28x28) input_var = MNN.expr.placeholder([1, 1, 28, 28], MNN.expr.NCHW) + + # Perform a forward pass to generate predictions predicts = net.forward(input_var) + + # Save the model to the specified file path F.save([predicts], mnn_file_path) + \ No newline at end of file diff --git a/python/fedml/model/mobile/mnn_resnet.py b/python/fedml/model/mobile/mnn_resnet.py index 4f9cf53744..d265ddf94c 100644 --- a/python/fedml/model/mobile/mnn_resnet.py +++ b/python/fedml/model/mobile/mnn_resnet.py @@ -173,7 +173,29 @@ def forward(self, x): def create_mnn_resnet20_model(mnn_file_path): + """ + Create and save a ResNet-20 model in the MNN format. + + Args: + mnn_file_path (str): The path to save the MNN model file. + + Note: + This function assumes you have a ResNet-20 model class defined in a 'resnet20' module. + The ResNet-20 model class should have a 'forward' method that takes an input tensor and returns predictions. + + Example: + To create and save a ResNet-20 model to 'resnet20.mnn': + >>> create_mnn_resnet20_model('resnet20.mnn') + + """ + # Create an instance of the ResNet-20 model net = Resnet20() + + # Define an input tensor with the desired shape (1 batch, 3 channels, 32x32) input_var = MNN.expr.placeholder([1, 3, 32, 32], MNN.expr.NCHW) + + # Perform a forward pass to generate predictions predicts = net.forward(input_var) + + # Save the model to the specified file path F.save([predicts], mnn_file_path) diff --git a/python/fedml/model/mobile/torch_lenet.py b/python/fedml/model/mobile/torch_lenet.py index ee3f30241f..f72f8bee65 100644 --- a/python/fedml/model/mobile/torch_lenet.py +++ b/python/fedml/model/mobile/torch_lenet.py @@ -29,7 +29,7 @@ class LeNet(nn.Module): - Activation: ReLU - Max Pooling: 2x2 - Fully Connected Layer 1: - - Input: 800 neurons (flattened 50x4x4 from previous layer) + - Input: 800 neurons (flattened 50x4x4 from the previous layer) - Output: 500 neurons - Activation: ReLU - Dropout: 50% dropout rate @@ -38,6 +38,16 @@ class LeNet(nn.Module): - Output: 10 neurons (class probabilities) - Activation: Softmax + Note: + - LeNet-5 is a classic convolutional neural network architecture designed for image classification tasks. + - This implementation follows the original LeNet-5 architecture. 
+ + Example: + To create an instance of the LeNet model: + >>> model = LeNet() + >>> input_tensor = torch.randn(1, 1, 32, 32) # Example input tensor + >>> output = model(input_tensor) # Forward pass to obtain class probabilities + """ def __init__(self): From 0884292c14b95468e44e93df0654a2ff8a6c8532 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Wed, 13 Sep 2023 14:29:35 +0530 Subject: [PATCH 23/70] 23 --- python/fedml/ml/engine/ml_engine_adapter.py | 252 +++++++++++++++++- .../ml/engine/torch_process_group_manager.py | 23 ++ .../fedml/ml/trainer/feddyn_trainer copy.py | 78 +++++- python/fedml/ml/trainer/mime_trainer.py | 64 +++++ python/fedml/ml/trainer/my_model_trainer.py | 75 +++++- .../my_model_trainer_classification.py | 58 ++++ .../fedml/ml/trainer/my_model_trainer_nwp.py | 53 +++- .../my_model_trainer_tag_prediction.py | 52 +++- python/fedml/ml/trainer/scaffold_trainer.py | 57 +++- 9 files changed, 687 insertions(+), 25 deletions(-) diff --git a/python/fedml/ml/engine/ml_engine_adapter.py b/python/fedml/ml/engine/ml_engine_adapter.py index dbae852142..4ec919964c 100644 --- a/python/fedml/ml/engine/ml_engine_adapter.py +++ b/python/fedml/ml/engine/ml_engine_adapter.py @@ -5,11 +5,23 @@ from .torch_process_group_manager import TorchProcessGroupManager from ...core.common.ml_engine_backend import MLEngineBackend +import tensorflow as tf +import numpy as np +from mxnet import np as mx_np def convert_numpy_to_torch_data_format(args, batched_x, batched_y): - import torch - import numpy as np - + """ + Convert batched data from NumPy format to PyTorch format. + + Args: + args: Model-specific arguments or configuration. + batched_x (numpy.ndarray): Batched input data. + batched_y (numpy.ndarray): Batched output data. + + Returns: + torch.Tensor: Batched input data in PyTorch format. + torch.Tensor: Batched output data in PyTorch format. + """ if args.model == "cnn": batched_x = torch.from_numpy(np.asarray(batched_x)).float().reshape(-1, 28, 28) # CNN_MINST else: @@ -20,10 +32,18 @@ def convert_numpy_to_torch_data_format(args, batched_x, batched_y): def convert_numpy_to_tf_data_format(args, batched_x, batched_y): - # https://www.tensorflow.org/api_docs/python/tf/convert_to_tensor - import tensorflow as tf - import numpy as np - + """ + Convert batched data from NumPy format to TensorFlow format. + + Args: + args: Model-specific arguments or configuration. + batched_x (numpy.ndarray): Batched input data. + batched_y (numpy.ndarray): Batched output data. + + Returns: + tf.Tensor: Batched input data in TensorFlow format. + tf.Tensor: Batched output data in TensorFlow format. + """ if args.model == "cnn": batched_x = tf.convert_to_tensor(np.asarray(batched_x), dtype=tf.float32) # CNN_MINST batched_x = tf.reshape(batched_x, [-1, 28, 28]) @@ -35,8 +55,18 @@ def convert_numpy_to_tf_data_format(args, batched_x, batched_y): def convert_numpy_to_jax_data_format(args, batched_x, batched_y): - import numpy as np - + """ + Convert batched data from NumPy format to JAX format. + + Args: + args: Model-specific arguments or configuration. + batched_x (numpy.ndarray): Batched input data. + batched_y (numpy.ndarray): Batched output data. + + Returns: + numpy.ndarray: Batched input data in JAX format. + numpy.ndarray: Batched output data in JAX format. 
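# The CNN/MNIST branch of the torch converter above, in isolation: the
# flat NumPy batch becomes a float tensor reshaped to (-1, 28, 28) and the
# labels become int64.
import numpy as np
import torch

batched_x = np.random.rand(4, 784)
batched_y = np.random.randint(0, 10, size=(4,))
x = torch.from_numpy(np.asarray(batched_x)).float().reshape(-1, 28, 28)
y = torch.from_numpy(np.asarray(batched_y)).long()
print(x.shape, y.dtype)  # torch.Size([4, 28, 28]) torch.int64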
+ """ if args.model == "cnn": batched_x = np.asarray(batched_x, dtype=np.float32) # CNN_MINST batched_x = np.reshape(batched_x, [-1, 28, 28]) @@ -48,8 +78,18 @@ def convert_numpy_to_jax_data_format(args, batched_x, batched_y): def convert_numpy_to_mxnet_data_format(args, batched_x, batched_y): - from mxnet import np as mx_np - + """ + Convert batched data from NumPy format to MXNet format. + + Args: + args: Model-specific arguments or configuration. + batched_x (numpy.ndarray): Batched input data. + batched_y (numpy.ndarray): Batched output data. + + Returns: + mxnet.numpy.ndarray: Batched input data in MXNet format. + mxnet.numpy.ndarray: Batched output data in MXNet format. + """ if args.model == "cnn": batched_x = mx_np.array(batched_x) batched_x = mx_np.reshape(batched_x, [-1, 28, 28]) # pylint: disable=E1101 @@ -61,6 +101,17 @@ def convert_numpy_to_mxnet_data_format(args, batched_x, batched_y): def convert_numpy_to_ml_engine_data_format(args, batched_x, batched_y): + """ + Convert batched data from NumPy format to the format required by a specified machine learning engine. + + Args: + args: Model-specific arguments or configuration. + batched_x (numpy.ndarray): Batched input data. + batched_y (numpy.ndarray): Batched output data. + + Returns: + Data in the format required by the specified machine learning engine. + """ if hasattr(args, MLEngineBackend.ml_engine_args_flag): if args.ml_engine == MLEngineBackend.ml_engine_backend_tf: return convert_numpy_to_tf_data_format(args, batched_x, batched_y) @@ -75,6 +126,16 @@ def convert_numpy_to_ml_engine_data_format(args, batched_x, batched_y): def is_torch_device_available(args, device_type): + """ + Check if a Torch device of the specified type is available. + + Args: + args: Model-specific arguments or configuration. + device_type (str): The type of Torch device to check (e.g., "gpu", "mps", "cpu"). + + Returns: + bool: True if the Torch device is available, False otherwise. + """ if device_type == MLEngineBackend.ml_device_type_gpu: if torch.cuda.is_available(): return True @@ -99,6 +160,16 @@ def is_torch_device_available(args, device_type): def is_mxnet_device_available(args, device_type): + """ + Check if a MXNet device of the specified type is available. + + Args: + args: Model-specific arguments or configuration. + device_type (str): The type of MXNet device to check (e.g., "cpu", "gpu"). + + Returns: + bool: True if the MXNet device is available, False otherwise. + """ if device_type == MLEngineBackend.ml_device_type_cpu: return True elif device_type == MLEngineBackend.ml_device_type_gpu: @@ -116,6 +187,16 @@ def is_mxnet_device_available(args, device_type): def is_device_available(args, device_type=MLEngineBackend.ml_device_type_gpu): + """ + Check if a specified device type is available based on the provided arguments and ML engine. + + Args: + args: Model-specific arguments or configuration. + device_type (str): The type of device to check (e.g., "gpu", "mps", "cpu"). + + Returns: + bool: True if the device is available, False otherwise. + """ if hasattr(args, MLEngineBackend.ml_engine_args_flag): if args.ml_engine == MLEngineBackend.ml_engine_backend_tf: import tensorflow as tf @@ -144,6 +225,19 @@ def is_device_available(args, device_type=MLEngineBackend.ml_device_type_gpu): def get_torch_device(args, using_gpu, device_id, device_type): + """ + Get a Torch device based on the provided arguments and configuration. + + Args: + args: Model-specific arguments or configuration. 
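# The core of the torch device-selection logic above, reduced to a sketch
# (MPS handling omitted): prefer an indexed CUDA device when requested and
# available, otherwise fall back to CPU.
import torch

def pick_device(using_gpu, device_id=0):
    if using_gpu and torch.cuda.is_available():
        return torch.device("cuda:{}".format(device_id))
    return torch.device("cpu")

print(pick_device(using_gpu=True))   # cuda:0 if a GPU is visible, else cpu
print(pick_device(using_gpu=False))  # cpu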
+ using_gpu (bool): Indicates whether a GPU should be used. + device_id (int): The ID of the GPU device. + device_type (str): The type of device (e.g., "gpu", "mps", "cpu"). + + Returns: + torch.device: The Torch device. + """ + logging.info( "args = {}, using_gpu = {}, device_id = {}, device_type = {}".format(args, using_gpu, device_id, device_type) ) @@ -165,6 +259,18 @@ def get_torch_device(args, using_gpu, device_id, device_type): def get_tf_device(args, using_gpu, device_id, device_type): + """ + Get a TensorFlow device based on the provided arguments and configuration. + + Args: + args: Model-specific arguments or configuration. + using_gpu (bool): Indicates whether a GPU should be used. + device_id (int): The ID of the GPU device. + device_type (str): The type of device (e.g., "gpu", "mps", "cpu"). + + Returns: + tf.device: The TensorFlow device. + """ import tensorflow as tf if using_gpu: @@ -174,6 +280,18 @@ def get_tf_device(args, using_gpu, device_id, device_type): def get_jax_device(args, using_gpu, device_id, device_type): + """ + Get a JAX device based on the provided arguments and configuration. + + Args: + args: Model-specific arguments or configuration. + using_gpu (bool): Indicates whether a GPU should be used. + device_id (int): The ID of the GPU device. + device_type (str): The type of device (e.g., "gpu", "mps", "cpu"). + + Returns: + jax.devices.Device: The JAX device. + """ import jax devices = jax.devices(None) @@ -187,6 +305,18 @@ def get_jax_device(args, using_gpu, device_id, device_type): def get_mxnet_device(args, using_gpu, device_id, device_type): + """ + Get an MXNet device based on the provided arguments and configuration. + + Args: + args: Model-specific arguments or configuration. + using_gpu (bool): Indicates whether a GPU should be used. + device_id (int): The ID of the GPU device. + device_type (str): The type of device (e.g., "gpu", "mps", "cpu"). + + Returns: + mxnet.context.Context: The MXNet device. + """ import mxnet as mx if using_gpu: @@ -196,6 +326,17 @@ def get_mxnet_device(args, using_gpu, device_id, device_type): def get_device(args, device_id=None, device_type="cpu"): + """ + Get the appropriate device based on the provided arguments and configuration. + + Args: + args: Model-specific arguments or configuration. + device_id (int, optional): The ID of the GPU device. Defaults to None. + device_type (str, optional): The type of device (e.g., "cpu"). Defaults to "cpu". + + Returns: + torch.device, tf.device, jax.devices.Device, mxnet.context.Context: The selected device. + """ using_gpu = True if (hasattr(args, "using_gpu") and args.using_gpu is True) else False if hasattr(args, MLEngineBackend.ml_engine_args_flag): @@ -212,6 +353,17 @@ def get_device(args, device_id=None, device_type="cpu"): def dict_to_device(args, dict_obj, device): + """ + Move a dictionary of objects to the specified device. + + Args: + args: Model-specific arguments or configuration. + dict_obj (dict): A dictionary of objects. + device (torch.device, tf.device, jax.devices.Device, mxnet.context.Context): The target device. + + Returns: + dict: The dictionary with objects on the target device. + """ if hasattr(args, MLEngineBackend.ml_engine_args_flag): if args.ml_engine == MLEngineBackend.ml_engine_backend_tf: with device: @@ -232,6 +384,17 @@ def dict_to_device(args, dict_obj, device): def model_params_to_device(args, params_obj, device): + """ + Move model parameters to the specified device. + + Args: + args: Model-specific arguments or configuration. 
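# What the torch branch of dict_to_device amounts to (a sketch; the full
# torch-path body is not shown in this hunk): move every tensor value in
# the dictionary onto the target device.
import torch

tensors = {"w": torch.randn(2, 2), "b": torch.randn(2)}
device = torch.device("cpu")
tensors = {k: v.to(device) for k, v in tensors.items()}
print({k: v.device.type for k, v in tensors.items()})  # {'w': 'cpu', 'b': 'cpu'}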
+ params_obj (dict): A dictionary of model parameters. + device (torch.device, tf.device, jax.devices.Device, mxnet.context.Context): The target device. + + Returns: + dict: The dictionary of model parameters on the target device. + """ if hasattr(args, MLEngineBackend.ml_engine_args_flag): if args.ml_engine == MLEngineBackend.ml_engine_backend_tf: with device: @@ -255,6 +418,17 @@ def model_params_to_device(args, params_obj, device): def model_to_device(args, model_obj, device): + """ + Move a model to the specified device. + + Args: + args: Model-specific arguments or configuration. + model_obj: The model to be moved to the device. + device: The target device (e.g., torch.device, tf.device, jax.devices.Device, mxnet.context.Context). + + Returns: + The model on the target device. + """ if hasattr(args, MLEngineBackend.ml_engine_args_flag): if args.ml_engine == MLEngineBackend.ml_engine_backend_tf: with device: @@ -271,6 +445,17 @@ def model_to_device(args, model_obj, device): def torch_model_ddp(args, model_obj, device): + """ + Create a Distributed Data Parallel (DDP) model for PyTorch. + + Args: + args: Model-specific arguments or configuration. + model_obj: The PyTorch model. + device: The target device (e.g., torch.device). + + Returns: + TorchProcessGroupManager, torch.nn.parallel.DistributedDataParallel: The process group manager and DDP model. + """ from torch.nn.parallel import DistributedDataParallel as DDP only_gpu = args.using_gpu @@ -283,23 +468,68 @@ def torch_model_ddp(args, model_obj, device): # Todo: add tf ddp def tf_model_ddp(args, model_obj, device): + """ + Create a Distributed Data Parallel (DDP) model for TensorFlow. + + Args: + args: Model-specific arguments or configuration. + model_obj: The TensorFlow model. + device: The target device (e.g., tf.device). + + Returns: + None, Model: The process group manager (None for TensorFlow) and DDP model. + """ process_group_manager, model = None, model_obj return process_group_manager, model # Todo: add jax ddp def jax_model_ddp(args, model_obj, device): + """ + Create a Distributed Data Parallel (DDP) model for JAX. + + Args: + args: Model-specific arguments or configuration. + model_obj: The JAX model. + device: The target device (e.g., jax.devices.Device). + + Returns: + None, Model: The process group manager (None for JAX) and DDP model. + """ process_group_manager, model = None, model_obj return process_group_manager, model # Todo: add mxnet ddp def mxnet_model_ddp(args, model_obj, device): + """ + Create a Distributed Data Parallel (DDP) model for MXNet. + + Args: + args: Model-specific arguments or configuration. + model_obj: The MXNet model. + device: The target device (e.g., mxnet.context.Context). + + Returns: + None, Model: The process group manager (None for MXNet) and DDP model. + """ process_group_manager, model = None, model_obj return process_group_manager, model def model_ddp(args, model_obj, device): + """ + Create a Distributed Data Parallel (DDP) model based on the selected ML engine. + + Args: + args: Model-specific arguments or configuration. + model_obj: The model to be wrapped with DDP. + device: The target device (e.g., torch.device, tf.device, jax.devices.Device, mxnet.context.Context). + + Returns: + TorchProcessGroupManager, torch.nn.parallel.DistributedDataParallel or + None, Model: The process group manager and DDP model (or None for non-Torch engines). 
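# Single-process sketch of the torch DDP path: bring up a gloo group of
# world_size 1, then wrap the model. Real runs derive rank, world size,
# and the master address from args; the values below are placeholders.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
ddp_model = DDP(torch.nn.Linear(4, 2))
print(type(ddp_model).__name__)  # DistributedDataParallel
dist.destroy_process_group()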
+ """ process_group_manager, model = None, model_obj if hasattr(args, MLEngineBackend.ml_engine_args_flag): if args.ml_engine == MLEngineBackend.ml_engine_backend_tf: diff --git a/python/fedml/ml/engine/torch_process_group_manager.py b/python/fedml/ml/engine/torch_process_group_manager.py index 931ff386c6..fb8fa534ef 100644 --- a/python/fedml/ml/engine/torch_process_group_manager.py +++ b/python/fedml/ml/engine/torch_process_group_manager.py @@ -7,6 +7,18 @@ class TorchProcessGroupManager: def __init__(self, rank, world_size, master_address, master_port, only_gpu): + """ + Initialize the TorchProcessGroupManager. + + Args: + rank (int): The rank of the current process. + world_size (int): The total number of processes in the distributed training. + master_address (str): The address of the master process for communication. + master_port (int): The port for communication with the master process. + only_gpu (bool): Flag indicating whether only GPUs are used for communication. + + Initializes the process group and creates a messaging process group for communication. + """ logging.info("Start process group") logging.info( "rank: %d, world_size: %d, master_address: %s, master_port: %s" @@ -38,7 +50,18 @@ def __init__(self, rank, world_size, master_address, master_port, only_gpu): logging.info("Initiated") def cleanup(self): + """ + Clean up the process group. + + Destroys the process group and performs cleanup. + """ dist.destroy_process_group() def get_process_group(self): + """ + Get the messaging process group. + + Returns: + torch.distributed.ProcessGroup: The messaging process group for communication. + """ return self.messaging_pg diff --git a/python/fedml/ml/trainer/feddyn_trainer copy.py b/python/fedml/ml/trainer/feddyn_trainer copy.py index bfabace0bd..f3f63fce06 100644 --- a/python/fedml/ml/trainer/feddyn_trainer copy.py +++ b/python/fedml/ml/trainer/feddyn_trainer copy.py @@ -7,24 +7,84 @@ +import torch + def model_parameter_vector(model): - param = [p.view(-1) for p in model.parameters()] - return torch.concat(param, dim=0) + """ + Flatten the parameters of a PyTorch model into a single 1D tensor. + Args: + model (torch.nn.Module): The PyTorch model. + + Returns: + torch.Tensor: A 1D tensor containing all the flattened model parameters. + """ + param = [p.view(-1) for p in model.parameters()] + return torch.cat(param, dim=0) # Use torch.cat to concatenate tensors def parameter_vector(parameters): + """ + Flatten a dictionary of PyTorch parameters into a single 1D tensor. + + Args: + parameters (dict): A dictionary of PyTorch parameters. + + Returns: + torch.Tensor: A 1D tensor containing all the flattened parameters. + """ param = [p.view(-1) for p in parameters.values()] - return torch.concat(param, dim=0) + return torch.cat(param, dim=0) # Use torch.cat to concatenate tensors + class FedDynModelTrainer(ClientTrainer): + """ + A federated dynamic model trainer that implements training and testing methods. + + Args: + ClientTrainer: The base class for client trainers. + + Attributes: + model (torch.nn.Module): The PyTorch model to be trained. + id (int): The identifier of the client. + + Methods: + get_model_params(): Get the model parameters as a state dictionary. + set_model_params(model_parameters): Set the model parameters from a state dictionary. + train(train_data, device, args, old_grad): Train the model with federated dynamic regularization. + test(test_data, device, args): Evaluate the model on test data and return evaluation metrics. 
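# Quick check of model_parameter_vector above: the flattened vector's
# length equals the model's total parameter count (3*2 weights + 2 biases).
import torch

def model_parameter_vector(model):
    return torch.cat([p.view(-1) for p in model.parameters()], dim=0)

print(model_parameter_vector(torch.nn.Linear(3, 2)).shape)  # torch.Size([8])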
+ """ def get_model_params(self): + """ + Get the model parameters as a state dictionary. + + Returns: + dict: The model parameters as a state dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters from a state dictionary. + + Args: + model_parameters (dict): The model parameters as a state dictionary. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args, old_grad): + """ + Train the model with federated dynamic regularization. + + Args: + train_data: The training data. + device (torch.device): The device (CPU or GPU) to use for training. + args: Training arguments. + old_grad (torch.Tensor): The previous gradient. + + Returns: + torch.Tensor: The updated gradient. + """ model = self.model for params in model.parameters(): params.requires_grad = True @@ -137,6 +197,18 @@ def train(self, train_data, device, args, old_grad): def test(self, test_data, device, args): + """ + Evaluate the model on test data and return evaluation metrics. + + Args: + test_data: The test data. + device (torch.device): The device (CPU or GPU) to use for evaluation. + args: Training arguments. + + Returns: + dict: Evaluation metrics including test accuracy and test loss. + """ + model = self.model model.to(device) diff --git a/python/fedml/ml/trainer/mime_trainer.py b/python/fedml/ml/trainer/mime_trainer.py index 8f5f9d4078..a4fe013be2 100644 --- a/python/fedml/ml/trainer/mime_trainer.py +++ b/python/fedml/ml/trainer/mime_trainer.py @@ -13,6 +13,19 @@ def clip_norm(tensors, device, max_norm=1.0, norm_type=2.): + """ + Clip the gradients of a list of tensors to have a maximum norm. + + Args: + tensors (list of torch.Tensor): The list of tensors whose gradients need to be clipped. + device (torch.device): The device (CPU or GPU) on which the tensors are located. + max_norm (float): The maximum norm value for gradient clipping. + norm_type (float): The type of norm to use for computing the gradient norm. + + Returns: + float: The total gradient norm after clipping. + + """ total_norm = torch.norm(torch.stack( [torch.norm(p.detach(), norm_type).to(device) for p in tensors]), norm_type) clip_coef = max_norm / (total_norm + 1e-6) @@ -23,14 +36,55 @@ def clip_norm(tensors, device, max_norm=1.0, norm_type=2.): class MimeModelTrainer(ClientTrainer): + """ + A custom model trainer for Mime-based federated learning. + + Args: + model (torch.nn.Module): The PyTorch model to be trained. + args: Training arguments. + + Attributes: + model (torch.nn.Module): The PyTorch model to be trained. + args: Training arguments. + + Methods: + get_model_params(): Get the model parameters as a state dictionary. + set_model_params(model_parameters): Set the model parameters from a state dictionary. + accumulate_data_grad(train_data, device, args): Accumulate the gradients of the local data. + train(train_data, device, args, grad_global, global_named_states): Train the model with Mime-based federated learning. + test(test_data, device, args): Evaluate the model on test data and return evaluation metrics. + """ def get_model_params(self): + """ + Get the model parameters as a state dictionary. + + Returns: + dict: The model parameters as a state dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters from a state dictionary. + + Args: + model_parameters (dict): The model parameters as a state dictionary. 
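+
+        Example (illustrative; ``global_params`` stands in for a state dict
+        received from the server)::
+
+            trainer.set_model_params(global_params)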
+        """
         self.model.load_state_dict(model_parameters)

     def accumulate_data_grad(self, train_data, device, args):
+        """
+        Accumulate the gradients of the local data.
+
+        Args:
+            train_data: The training data.
+            device (torch.device): The device (CPU or GPU) to use for gradient computation.
+            args: Training arguments.
+
+        Returns:
+            dict: A dictionary containing the accumulated gradients for each parameter.
+        """
         model = self.model

         model.to(device)
@@ -58,6 +112,16 @@ def accumulate_data_grad(self, train_data, device, args):


     def train(self, train_data, device, args, grad_global, global_named_states):
+        """
+        Train the model with Mime-based federated learning.
+
+        Args:
+            train_data: The training data.
+            device (torch.device): The device (CPU or GPU) to use for training.
+            args: Training arguments.
+            grad_global: Global gradients.
+            global_named_states: Global model states.
+        """
         model = self.model

         model.to(device)
diff --git a/python/fedml/ml/trainer/my_model_trainer.py b/python/fedml/ml/trainer/my_model_trainer.py
index f746a72379..e353018db6 100644
--- a/python/fedml/ml/trainer/my_model_trainer.py
+++ b/python/fedml/ml/trainer/my_model_trainer.py
@@ -6,23 +6,70 @@


 class MyModelTrainer(ClientTrainer):
+    """
+    A custom model trainer that implements training and testing methods.
+
+    Args:
+        model (torch.nn.Module): The PyTorch model to be trained.
+        args: Training arguments.
+
+    Attributes:
+        model (torch.nn.Module): The PyTorch model to be trained.
+        args: Training arguments.
+
+    Methods:
+        __init__(self, model, args): Initialize the trainer.
+        get_model_params(self): Get the model parameters as a state dictionary.
+        set_model_params(self, model_parameters): Set the model parameters from a state dictionary.
+        on_before_local_training(self, train_data, device, args): Perform actions before local training (optional).
+        train(self, train_data, device, args): Train the model.
+        test(self, test_data, device, args): Evaluate the model on test data and return evaluation metrics.
+    """
     def __init__(self, model, args):
         super().__init__(model, args)
-        self.cpu_transfer = False if not hasattr(self.args, "cpu_transfer") else self.args.cpu_transfer
+        self.cpu_transfer = False if not hasattr(
+            self.args, "cpu_transfer") else self.args.cpu_transfer

     def get_model_params(self):
+        """
+        Get the model parameters as a state dictionary.
+
+        Returns:
+            dict: The model parameters as a state dictionary.
+        """
         if self.cpu_transfer:
             return self.model.cpu().state_dict()
         return self.model.state_dict()

     def set_model_params(self, model_parameters):
+        """
+        Set the model parameters from a state dictionary.
+
+        Args:
+            model_parameters (dict): The model parameters as a state dictionary.
+        """
         self.model.load_state_dict(model_parameters)

     def on_before_local_training(self, train_data, device, args):
+        """
+        Execute code before local training (optional).
+
+        Args:
+            train_data: The training data.
+            device (torch.device): The device (CPU or GPU) to use for training.
+            args: Training arguments.
+        """
         pass

     def train(self, train_data, device, args):
+        """
+        Train the model.
+
+        Args:
+            train_data: The training data.
+            device (torch.device): The device (CPU or GPU) to use for training.
+            args: Training arguments.
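+
+        Example (illustrative; ``train_loader`` is an assumed DataLoader)::
+
+            trainer.train(train_loader, torch.device("cpu"), args)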
+ """ model = self.model model.to(device) @@ -31,7 +78,8 @@ def train(self, train_data, device, args): # train and update criterion = nn.CrossEntropyLoss().to(device) # pylint: disable=E1102 if args.client_optimizer == "sgd": - optimizer = torch.optim.SGD(self.model.parameters(), lr=args.learning_rate) + optimizer = torch.optim.SGD( + self.model.parameters(), lr=args.learning_rate) else: optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, self.model.parameters()), @@ -71,6 +119,17 @@ def train(self, train_data, device, args): # self.client_idx, epoch, sum(epoch_loss) / len(epoch_loss))) def test(self, test_data, device, args): + """ + Evaluate the model on test data and return evaluation metrics. + + Args: + test_data: The test data. + device (torch.device): The device (CPU or GPU) to use for evaluation. + args: Training arguments. + + Returns: + dict: Evaluation metrics including test accuracy, test loss, precision, and recall (if applicable). + """ model = self.model model.to(device) @@ -91,7 +150,8 @@ def test(self, test_data, device, args): https://github.com/google-research/federated/blob/49a43456aa5eaee3e1749855eed89c0087983541/optimization/stackoverflow_lr/federated_stackoverflow_lr.py#L131 """ if args.dataset == "stackoverflow_lr": - criterion = nn.BCELoss(reduction="sum").to(device) # pylint: disable=E1102 + criterion = nn.BCELoss(reduction="sum").to( + device) # pylint: disable=E1102 else: criterion = nn.CrossEntropyLoss().to(device) # pylint: disable=E1102 @@ -104,9 +164,12 @@ def test(self, test_data, device, args): if args.dataset == "stackoverflow_lr": predicted = (pred > 0.5).int() - correct = predicted.eq(target).sum(axis=-1).eq(target.size(1)).sum() - true_positive = ((target * predicted) > 0.1).int().sum(axis=-1) - precision = true_positive / (predicted.sum(axis=-1) + 1e-13) + correct = predicted.eq(target).sum( + axis=-1).eq(target.size(1)).sum() + true_positive = ((target * predicted) > + 0.1).int().sum(axis=-1) + precision = true_positive / \ + (predicted.sum(axis=-1) + 1e-13) recall = true_positive / (target.sum(axis=-1) + 1e-13) metrics["test_precision"] += precision.sum().item() metrics["test_recall"] += recall.sum().item() diff --git a/python/fedml/ml/trainer/my_model_trainer_classification.py b/python/fedml/ml/trainer/my_model_trainer_classification.py index a4251b4c3c..0aa9beea25 100644 --- a/python/fedml/ml/trainer/my_model_trainer_classification.py +++ b/python/fedml/ml/trainer/my_model_trainer_classification.py @@ -12,13 +12,51 @@ class ModelTrainerCLS(ClientTrainer): + """ + A custom model trainer for classification tasks. + + Args: + model (torch.nn.Module): The PyTorch model to be trained. + args: Training arguments. + + Attributes: + model (torch.nn.Module): The PyTorch model to be trained. + args: Training arguments. + + Methods: + get_model_params(): Get the model parameters as a state dictionary. + set_model_params(model_parameters): Set the model parameters from a state dictionary. + train(train_data, device, args): Train the model. + train_iterations(train_data, device, args): Train the model for a specified number of iterations. + test(test_data, device, args): Evaluate the model on test data and return evaluation metrics. + """ def get_model_params(self): + """ + Get the model parameters as a state dictionary. + + Returns: + dict: The model parameters as a state dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters from a state dictionary. 
+
+        Args:
+            model_parameters (dict): The model parameters as a state dictionary.
+        """
         self.model.load_state_dict(model_parameters)

     def train(self, train_data, device, args):
+        """
+        Train the model.
+
+        Args:
+            train_data: The training data.
+            device (torch.device): The device (CPU or GPU) to use for training.
+            args: Training arguments.
+        """
         model = self.model

         model.to(device)
@@ -77,6 +115,15 @@ def train(self, train_data, device, args):
         )

     def train_iterations(self, train_data, device, args):
+        """
+        Train the model for a specified number of iterations.
+
+        Args:
+            train_data: The training data.
+            device (torch.device): The device (CPU or GPU) to use for training.
+            args: Training arguments.
+        """
+
         model = self.model

         model.to(device)
@@ -137,6 +184,17 @@ def train_iterations(self, train_data, device, args):
         )

     def test(self, test_data, device, args):
+        """
+        Evaluate the model on test data and return evaluation metrics.
+
+        Args:
+            test_data: The test data.
+            device (torch.device): The device (CPU or GPU) to use for evaluation.
+            args: Training arguments.
+
+        Returns:
+            dict: Evaluation metrics including test accuracy, test loss, and the total number of test samples.
+        """
         model = self.model

         model.to(device)
diff --git a/python/fedml/ml/trainer/my_model_trainer_nwp.py b/python/fedml/ml/trainer/my_model_trainer_nwp.py
index a613e1d987..1077ff5e7b 100644
--- a/python/fedml/ml/trainer/my_model_trainer_nwp.py
+++ b/python/fedml/ml/trainer/my_model_trainer_nwp.py
@@ -8,20 +8,58 @@


 class ModelTrainerNWP(ClientTrainer):
+    """
+    A custom model trainer for Next Word Prediction (NWP) tasks.
+
+    Args:
+        model (torch.nn.Module): The PyTorch model to be trained.
+        args: Training arguments.
+
+    Attributes:
+        model (torch.nn.Module): The PyTorch model to be trained.
+        args: Training arguments.
+
+    Methods:
+        get_model_params(): Get the model parameters as a state dictionary.
+        set_model_params(model_parameters): Set the model parameters from a state dictionary.
+        train(train_data, device, args): Train the model.
+        test(test_data, device, args): Evaluate the model on test data and return evaluation metrics.
+    """
     def get_model_params(self):
+        """
+        Get the model parameters as a state dictionary.
+
+        Returns:
+            dict: The model parameters as a state dictionary.
+        """
         return self.model.cpu().state_dict()

     def set_model_params(self, model_parameters):
+        """
+        Set the model parameters from a state dictionary.
+
+        Args:
+            model_parameters (dict): The model parameters as a state dictionary.
+        """
         self.model.load_state_dict(model_parameters)

     def train(self, train_data, device, args):
+        """
+        Train the model.
+
+        Args:
+            train_data: The training data.
+            device (torch.device): The device (CPU or GPU) to use for training.
+            args: Training arguments.
+        """
         model = self.model

         model.to(device)
         model.train()

         # train and update
-        criterion = nn.CrossEntropyLoss(ignore_index=0).to(device)  # pylint: disable=E1102
+        criterion = nn.CrossEntropyLoss(ignore_index=0).to(
+            device)  # pylint: disable=E1102
         if args.client_optimizer == "sgd":
             optimizer = torch.optim.SGD(
                 filter(lambda p: p.requires_grad, self.model.parameters()),
@@ -34,7 +72,7 @@ def train(self, train_data, device, args):
                 weight_decay=args.weight_decay,
                 amsgrad=True,
             )
-
+
         epoch_loss = []
         for epoch in range(args.epochs):
             # begin_time = time.time()
@@ -66,6 +104,17 @@ def train(self, train_data, device, args):
             #         self.client_idx, epoch, sum(epoch_loss) / len(epoch_loss)))

     def test(self, test_data, device, args):
+        """
+        Evaluate the model on test data and return evaluation metrics.
+
+        Args:
+            test_data: The test data.
+            device (torch.device): The device (CPU or GPU) to use for evaluation.
+            args: Training arguments.
+
+        Returns:
+            dict: Evaluation metrics including test accuracy, test loss, and the total number of test samples.
+        """
         model = self.model

         model.to(device)
diff --git a/python/fedml/ml/trainer/my_model_trainer_tag_prediction.py b/python/fedml/ml/trainer/my_model_trainer_tag_prediction.py
index 6cb07a7274..4a367cf679 100644
--- a/python/fedml/ml/trainer/my_model_trainer_tag_prediction.py
+++ b/python/fedml/ml/trainer/my_model_trainer_tag_prediction.py
@@ -5,13 +5,51 @@


 class ModelTrainerTAGPred(ClientTrainer):
+    """
+    A custom model trainer for TAG prediction tasks.
+
+    Args:
+        model (torch.nn.Module): The PyTorch model to be trained.
+        args: Training arguments.
+
+    Attributes:
+        model (torch.nn.Module): The PyTorch model to be trained.
+        args: Training arguments.
+
+    Methods:
+        get_model_params(): Get the model parameters as a state dictionary.
+        set_model_params(model_parameters): Set the model parameters from a state dictionary.
+        train(train_data, device, args): Train the model.
+        test(test_data, device, args): Evaluate the model on test data and return evaluation metrics.
+    """
+
     def get_model_params(self):
+        """
+        Get the model parameters as a state dictionary.
+
+        Returns:
+            dict: The model parameters as a state dictionary.
+        """
         return self.model.cpu().state_dict()

     def set_model_params(self, model_parameters):
+        """
+        Set the model parameters from a state dictionary.
+
+        Args:
+            model_parameters (dict): The model parameters as a state dictionary.
+        """
         self.model.load_state_dict(model_parameters)

     def train(self, train_data, device, args):
+        """
+        Train the model.
+
+        Args:
+            train_data: The training data.
+            device (torch.device): The device (CPU or GPU) to use for training.
+            args: Training arguments.
+        """
         model = self.model

         model.to(device)
@@ -56,6 +94,17 @@ def train(self, train_data, device, args):
             #         self.client_idx, epoch, sum(epoch_loss) / len(epoch_loss)))

     def test(self, test_data, device, args):
+        """
+        Evaluate the model on test data and return evaluation metrics.
+
+        Args:
+            test_data: The test data.
+            device (torch.device): The device (CPU or GPU) to use for evaluation.
+            args: Training arguments.
+
+        Returns:
+            dict: Evaluation metrics including test accuracy, test loss, precision, and recall (if applicable).
+        """
         model = self.model

         model.to(device)
@@ -85,7 +134,8 @@ def test(self, test_data, device, args):
                 loss = criterion(pred, target)  # pylint: disable=E1102

                 predicted = (pred > 0.5).int()
-                correct = predicted.eq(target).sum(axis=-1).eq(target.size(1)).sum()
+                correct = predicted.eq(target).sum(
+                    axis=-1).eq(target.size(1)).sum()
                 true_positive = ((target * predicted) > 0.1).int().sum(axis=-1)
                 precision = true_positive / (predicted.sum(axis=-1) + 1e-13)
                 recall = true_positive / (target.sum(axis=-1) + 1e-13)
diff --git a/python/fedml/ml/trainer/scaffold_trainer.py b/python/fedml/ml/trainer/scaffold_trainer.py
index ea1f592064..e552946725 100644
--- a/python/fedml/ml/trainer/scaffold_trainer.py
+++ b/python/fedml/ml/trainer/scaffold_trainer.py
@@ -7,13 +7,55 @@


 class ScaffoldModelTrainer(ClientTrainer):
+    """
+    A scaffold model trainer that implements training and testing methods.
+
+    Args:
+        model (torch.nn.Module): The PyTorch model to be trained.
+        args: Training arguments.
+
+    Attributes:
+        model (torch.nn.Module): The PyTorch model to be trained.
+        id (int): The identifier of the client.
+ + Methods: + get_model_params(): Get the model parameters as a state dictionary. + set_model_params(model_parameters): Set the model parameters from a state dictionary. + train(train_data, device, args, c_model_global_params, c_model_local_params): Train the model. + test(test_data, device, args): Evaluate the model on test data and return evaluation metrics. + """ + def get_model_params(self): + """ + Get the model parameters as a state dictionary. + + Returns: + dict: The model parameters as a state dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters from a state dictionary. + + Args: + model_parameters (dict): The model parameters as a state dictionary. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args, c_model_global_params, c_model_local_params): + """ + Train the model. + + Args: + train_data: The training data. + device (torch.device): The device (CPU or GPU) to use for training. + args: Training arguments. + c_model_global_params (dict): Global model parameters. + c_model_local_params (dict): Local model parameters. + + Returns: + int: The number of training iterations. + """ model = self.model model.to(device) @@ -63,7 +105,8 @@ def train(self, train_data, device, args, c_model_global_params, c_model_local_p # logging.debug(f"c_model_global[name].device : {c_model_global[name].device}, \ # c_model_global_params[name].device : {c_model_local_params[name].device}") param.data = param.data - current_lr * \ - check_device((c_model_global_params[name] - c_model_local_params[name]), param.data.device) + check_device( + (c_model_global_params[name] - c_model_local_params[name]), param.data.device) iteration_cnt += 1 batch_loss.append(loss.item()) if len(batch_loss) == 0: @@ -77,8 +120,18 @@ def train(self, train_data, device, args, c_model_global_params, c_model_local_p ) return iteration_cnt - def test(self, test_data, device, args): + """ + Evaluate the model on test data and return evaluation metrics. + + Args: + test_data: The test data. + device (torch.device): The device (CPU or GPU) to use for evaluation. + args: Training arguments. + + Returns: + dict: Evaluation metrics including test accuracy and test loss. 
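+
+        Example (illustrative; ``test_loader`` is an assumed DataLoader)::
+
+            metrics = trainer.test(test_loader, device, args)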
+ """ model = self.model model.to(device) From d8ae5616cd9316c15c915e0b6125a85fb28a7874 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Fri, 15 Sep 2023 22:04:25 +0530 Subject: [PATCH 24/70] n --- python/fedml/device/device.py | 30 ++- python/fedml/device/gpu_mapping_cross_silo.py | 27 ++- python/fedml/device/gpu_mapping_mpi.py | 45 +++- python/fedml/device/ip_config_utils.py | 9 + python/fedml/fa/__init__.py | 48 +++- python/fedml/fa/aggregator/avg_aggregator.py | 34 +++ .../frequency_estimation_aggregator.py | 46 +++- .../fa/aggregator/global_analyzer_creator.py | 10 + .../heavy_hitter_triehh_aggregator.py | 49 +++- .../fa/aggregator/intersection_aggregator.py | 71 ++++-- .../k_percentile_element_aggregator.py | 43 +++- .../fedml/fa/aggregator/union_aggregator.py | 48 +++- python/fedml/fa/base_frame/client_analyzer.py | 83 +++++++ .../fedml/fa/base_frame/server_aggregator.py | 60 ++++- .../cross_silo/client/client_initializer.py | 54 +++++ .../fa/cross_silo/client/client_launcher.py | 37 ++- .../fa/cross_silo/client/fa_local_analyzer.py | 110 ++++++++- .../client/fedml_client_master_manager.py | 177 ++++++++++++++- .../client/fedml_client_slave_manager.py | 63 +++++- .../client/fedml_trainer_dist_adapter.py | 108 ++++++++- .../client/process_group_manager.py | 50 ++++- python/fedml/fa/cross_silo/fa_client.py | 38 ++++ python/fedml/fa/cross_silo/fa_server.py | 38 ++++ .../fa/cross_silo/server/fedml_aggregator.py | 134 +++++++++-- .../cross_silo/server/fedml_server_manager.py | 172 ++++++++++++++ .../cross_silo/server/server_initializer.py | 16 ++ python/fedml/fa/data/data_loader.py | 24 +- .../fa/data/fake_numeric_data/data_loader.py | 32 ++- .../fa/data/self_defined_data/data_loader.py | 30 ++- .../data/twitter_Sentiment140/data_loader.py | 34 ++- .../twitter_data_processing.py | 57 ++++- python/fedml/fa/data/utils.py | 50 ++++- python/fedml/fa/local_analyzer/avg.py | 25 ++- .../local_analyzer/client_analyzer_creator.py | 11 +- .../fa/local_analyzer/frequency_estimation.py | 32 ++- .../fa/local_analyzer/heavy_hitter_triehh.py | 81 ++++++- .../fedml/fa/local_analyzer/intersection.py | 24 +- .../fa/local_analyzer/k_percentage_element.py | 25 ++- python/fedml/fa/local_analyzer/union.py | 23 +- python/fedml/fa/runner.py | 63 +++++- python/fedml/fa/simulation/sp/client.py | 65 +++++- python/fedml/fa/simulation/sp/simulator.py | 45 ++++ python/fedml/fa/simulation/utils.py | 13 +- python/fedml/fa/utils/trie.py | 212 +++++++++++++++++- python/fedml/ml/aggregator/agg_operator.py | 73 ++++++ .../fedml/ml/aggregator/aggregator_creator.py | 10 + .../fedml/ml/aggregator/default_aggregator.py | 41 ++++ .../ml/aggregator/my_server_aggregator.py | 53 +++++ .../my_server_aggregator_classification.py | 34 +++ .../ml/aggregator/my_server_aggregator_nwp.py | 34 +++ .../my_server_aggregator_prediction.py | 34 +++ python/fedml/ml/trainer/feddyn_trainer.py | 80 +++++++ python/fedml/ml/trainer/fednova_trainer.py | 90 ++++++++ python/fedml/ml/trainer/fedprox_trainer.py | 69 ++++++ python/fedml/ml/trainer/trainer_creator.py | 12 +- 55 files changed, 2831 insertions(+), 145 deletions(-) diff --git a/python/fedml/device/device.py b/python/fedml/device/device.py index 1085a19412..7892dbde35 100644 --- a/python/fedml/device/device.py +++ b/python/fedml/device/device.py @@ -10,6 +10,18 @@ def get_device_type(args): + """ + Determine the type of device (CPU, GPU, or MPS) based on the provided arguments. 
+ + Args: + args (object): An object containing arguments, including 'device_type', 'using_gpu', 'gpu_id', and 'training_type'. + + Returns: + str: The type of device to use (e.g., 'cpu', 'gpu', or 'mps'). + + Raises: + Exception: If the provided 'device_type' is not supported. + """ if hasattr(args, "device_type"): if args.device_type == "cpu": device_type = "cpu" @@ -40,6 +52,18 @@ def get_device_type(args): def get_device(args): + """ + Get the device for training based on the provided arguments. + + Args: + args (object): An object containing arguments, including 'training_type', 'backend', 'gpu_id', 'using_gpu', 'process_id', and others. + + Returns: + str: The device (CPU or GPU) assigned to the current process. + + Raises: + Exception: If the 'training_type' is not defined. + """ if args.training_type == "simulation" and args.backend == "sp": if not hasattr(args, "gpu_id"): args.gpu_id = 0 @@ -104,14 +128,14 @@ def get_device(args): gpu_mapping_key = ( args.gpu_mapping_key if hasattr(args, "gpu_mapping_key") else None ) - gpu_id = args.gpu_id if hasattr(args, "gpu_id") else None # no no need to set gpu_id + gpu_id = args.gpu_id if hasattr(args, "gpu_id") else None # no need to set gpu_id else: gpu_mapping_file = None gpu_mapping_key = None gpu_id = None logging.info( - "devide_type = {}, gpu_mapping_file = {}, " + "device_type = {}, gpu_mapping_file = {}, " "gpu_mapping_key = {}, gpu_id = {}".format( device_type, gpu_mapping_file, gpu_mapping_key, gpu_id ) @@ -138,7 +162,7 @@ def get_device(args): if args.enable_cuda_rpc and is_master_process: assert ( device.index == args.cuda_rpc_gpu_mapping[args.rank] - ), f"GPU assignemnt inconsistent with cuda_rpc_gpu_mapping. Assigned to GPU {device.index} while expecting {args.cuda_rpc_gpu_mapping[args.rank]}" + ), f"GPU assignment inconsistent with cuda_rpc_gpu_mapping. Assigned to GPU {device.index} while expecting {args.cuda_rpc_gpu_mapping[args.rank]}" return device elif args.training_type == "cross_device": diff --git a/python/fedml/device/gpu_mapping_cross_silo.py b/python/fedml/device/gpu_mapping_cross_silo.py index f1fd8f9948..4c9c46c6ad 100644 --- a/python/fedml/device/gpu_mapping_cross_silo.py +++ b/python/fedml/device/gpu_mapping_cross_silo.py @@ -10,6 +10,27 @@ def mapping_processes_to_gpu_device_from_yaml_file_cross_silo( process_id, worker_number, gpu_util_file, gpu_util_key, device_type, scenario, gpu_id=None, args=None ): + """ + Map processes to GPU devices based on GPU utilization information from a YAML file in a cross-silo setting. + + Args: + process_id (int): The ID of the current process. + worker_number (int): The total number of worker processes. + gpu_util_file (str): The path to the GPU utilization YAML file. + gpu_util_key (str): The key to retrieve GPU utilization information from the YAML file. + device_type (str): The type of device to use (e.g., "gpu" or "cpu"). + scenario (str): The cross-silo training scenario (e.g., hierarchical or non-hierarchical). + gpu_id (int, optional): The GPU ID to use for the current process. Defaults to None. + args (object, optional): An object containing additional arguments (e.g., device settings). + + Returns: + str: The GPU or CPU device assigned to the current process. + + Raises: + Exception: If there is an issue with GPU device mapping, such as exceeding PyTorch DDP limits. + AssertionError: If the number of mapped processes does not match the worker number. 
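+
+    Example (hypothetical; the mapping file name, key, and scenario value are
+    assumed)::
+
+        device = mapping_processes_to_gpu_device_from_yaml_file_cross_silo(
+            process_id=0, worker_number=2, gpu_util_file="gpu_mapping.yaml",
+            gpu_util_key="mapping_config", device_type="gpu",
+            scenario="horizontal", args=args)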
+ + """ if device_type != "gpu": args.using_gpu = False device = ml_engine_adapter.get_device(args, device_id=gpu_id, device_type=device_type) @@ -27,8 +48,7 @@ def mapping_processes_to_gpu_device_from_yaml_file_cross_silo( with open(gpu_util_file, "r") as f: gpu_util_yaml = yaml.load(f, Loader=yaml.FullLoader) - # gpu_util_num_process = 'gpu_util_' + str(worker_number) - # gpu_util = gpu_util_yaml[gpu_util_num_process] + gpu_util = gpu_util_yaml[gpu_util_key] logging.info("gpu_util = {}".format(gpu_util)) gpu_util_map = {} @@ -38,7 +58,7 @@ def mapping_processes_to_gpu_device_from_yaml_file_cross_silo( # validate DDP gpu mapping if unique_gpu and num_process_on_gpu > 1: raise Exception( - "Cannot put {num_process_on_gpu} processes on GPU {gpu_j} of {host}." + f"Cannot put {num_process_on_gpu} processes on GPU {gpu_j} of {host}. " "PyTorch DDP supports up to one process on each GPU." ) for _ in range(num_process_on_gpu): @@ -57,3 +77,4 @@ def mapping_processes_to_gpu_device_from_yaml_file_cross_silo( logging.info("process_id = {}, GPU device = {}".format(process_id, device)) return device + \ No newline at end of file diff --git a/python/fedml/device/gpu_mapping_mpi.py b/python/fedml/device/gpu_mapping_mpi.py index 5bff4c3e26..568f790666 100644 --- a/python/fedml/device/gpu_mapping_mpi.py +++ b/python/fedml/device/gpu_mapping_mpi.py @@ -9,6 +9,23 @@ def mapping_processes_to_gpu_device_from_yaml_file_mpi( process_id, worker_number, gpu_util_file, gpu_util_key, args=None ): + """ + Map processes to GPU devices based on GPU utilization information from a YAML file. + + Args: + process_id (int): The ID of the current process. + worker_number (int): The total number of worker processes. + gpu_util_file (str): The path to the GPU utilization YAML file. + gpu_util_key (str): The key to retrieve GPU utilization information from the YAML file. + args (object, optional): An object containing additional arguments (e.g., device settings). + + Returns: + str: The GPU device assigned to the current process. + + Raises: + AssertionError: If the number of mapped processes does not match the worker number. + + """ if gpu_util_file is None: logging.info(" !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") logging.info(" ################## You do not indicate gpu_util_file, will use CPU training #################") @@ -16,10 +33,10 @@ def mapping_processes_to_gpu_device_from_yaml_file_mpi( logging.info(device) return device else: + # Load GPU utilization information from the YAML file with open(gpu_util_file, "r") as f: gpu_util_yaml = yaml.load(f, Loader=yaml.FullLoader) - # gpu_util_num_process = 'gpu_util_' + str(worker_number) - # gpu_util = gpu_util_yaml[gpu_util_num_process] + gpu_util = gpu_util_yaml[gpu_util_key] logging.info("gpu_util = {}".format(gpu_util)) gpu_util_map = {} @@ -43,15 +60,31 @@ def mapping_processes_to_gpu_device_from_yaml_file_mpi( def mapping_processes_to_gpu_device_from_gpu_util_parse(process_id, worker_number, gpu_util_parse, args=None): - if gpu_util_parse == None: + """ + Map processes to GPU devices based on parsed GPU utilization information. + + Args: + process_id (int): The ID of the current process. + worker_number (int): The total number of worker processes. + gpu_util_parse (str): The parsed GPU utilization information in string format. + args (object, optional): An object containing additional arguments (e.g., device settings). + + Returns: + str: The GPU device assigned to the current process. 
+ + Raises: + AssertionError: If the number of mapped processes does not match the worker number. + + """ + if gpu_util_parse is None: device = ml_engine_adapter.get_device(args, device_type="cpu") logging.info(" !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") logging.info(" ################## Not Indicate gpu_util_file, using cpu #################") logging.info(device) - # return gpu_util_map[process_id][1] + return device else: - # example parse str `gpu_util_parse`: + # Example parse str `gpu_util_parse`: # "gpu1:0,1,1,2;gpu2:3,3,3;gpu3:0,0,0,1,2,4,4,0" gpu_util_parse_temp = gpu_util_parse.split(";") gpu_util_parse_temp = [(item.split(":")[0], item.split(":")[1]) for item in gpu_util_parse_temp] @@ -68,7 +101,7 @@ def mapping_processes_to_gpu_device_from_gpu_util_parse(process_id, worker_numbe gpu_util_map[i] = (host, gpu_j) i += 1 logging.info( - "Process %d running on host: %s,gethostname: %s, gpu: %d ..." + "Process %d running on host: %s, gethostname: %s, gpu: %d ..." % (process_id, gpu_util_map[process_id][0], socket.gethostname(), gpu_util_map[process_id][1]) ) assert i == worker_number diff --git a/python/fedml/device/ip_config_utils.py b/python/fedml/device/ip_config_utils.py index 1ebedfd73a..fad2cf126a 100644 --- a/python/fedml/device/ip_config_utils.py +++ b/python/fedml/device/ip_config_utils.py @@ -2,6 +2,15 @@ def build_ip_table(path): + """ + Build an IP table from a CSV file containing receiver IDs and their corresponding IP addresses. + + Args: + path (str): The path to the CSV file. + + Returns: + dict: A dictionary mapping receiver IDs to their respective IP addresses. + """ ip_config = dict() with open(path, newline="") as csv_file: csv_reader = csv.reader(csv_file) diff --git a/python/fedml/fa/__init__.py b/python/fedml/fa/__init__.py index e216fe2d20..56fd61b7cc 100644 --- a/python/fedml/fa/__init__.py +++ b/python/fedml/fa/__init__.py @@ -4,8 +4,25 @@ from .. import load_arguments, run_simulation, FEDML_TRAINING_PLATFORM_SIMULATION, FEDML_TRAINING_PLATFORM_CROSS_SILO, \ collect_env, mlops +from .runner import FARunner + +__all__ = [ + "FARunner", + "run_simulation", + "init" +] + def init(args=None): + """ + Initialize FedML Engine. + + Args: + args (object, optional): Arguments for initialization. If None, load default arguments. + + Returns: + object: Initialized arguments. + """ print(f"args={args}") if args is None: args = load_arguments(training_type=None, comm_backend=None) @@ -31,6 +48,12 @@ def init(args=None): return args def manage_mpi_args(args): + """ + Manage MPI-related arguments. + + Args: + args (object): Initialized arguments. + """ if hasattr(args, "backend") and args.backend == "MPI": from mpi4py import MPI @@ -48,6 +71,15 @@ def manage_mpi_args(args): args.comm = None def init_cross_silo(args): + """ + Initialize arguments for cross-silo training. + + Args: + args (object): Initialized arguments. + + Returns: + object: Updated arguments. + """ manage_mpi_args(args) # Set intra-silo arguments @@ -82,13 +114,13 @@ def init_cross_silo(args): def init_simulation_sp(args): - return args + """ + Initialize arguments for simulation with SP backend. + Args: + args (object): Initialized arguments. -from .runner import FARunner - -__all__ = [ - "FARunner", - "run_simulation", - "init" -] + Returns: + object: Updated arguments. 
+ """ + return args diff --git a/python/fedml/fa/aggregator/avg_aggregator.py b/python/fedml/fa/aggregator/avg_aggregator.py index a7493a42d1..db829d0ac0 100644 --- a/python/fedml/fa/aggregator/avg_aggregator.py +++ b/python/fedml/fa/aggregator/avg_aggregator.py @@ -3,12 +3,45 @@ class AVGAggregatorFA(FAServerAggregator): + """ + Aggregator for Federated Learning with Averaging. + + Args: + args (object): An object containing aggregator configuration parameters. + + Attributes: + total_sample_num (int): The total number of training samples aggregated. + server_data (float): The aggregated server data. + + Methods: + aggregate(local_submission_list): + Aggregate local submissions from clients and compute the weighted average. + + """ def __init__(self, args): + """ + Initialize the AVGAggregatorFA. + + Args: + args (object): An object containing aggregator configuration parameters. + + Returns: + None + """ super().__init__(args) self.total_sample_num = 0 self.set_server_data(server_data=0) def aggregate(self, local_submission_list: List[Tuple[float, Any]]): + """ + Aggregate local submissions from clients and compute the weighted average. + + Args: + local_submission_list (list): A list of tuples containing local sample number and local submissions. + + Returns: + float: The computed weighted average. + """ print(f"local_submission_list={local_submission_list}") training_num = 0 for idx in range(len(local_submission_list)): @@ -30,6 +63,7 @@ def aggregate(self, local_submission_list: List[Tuple[float, Any]]): return avg + """ todo: Mode 1: (online mode) each client stores its AVG result and the total number of data being sampled so far; later computation will use this result. diff --git a/python/fedml/fa/aggregator/frequency_estimation_aggregator.py b/python/fedml/fa/aggregator/frequency_estimation_aggregator.py index 1e7c1bbeed..1325c87c52 100644 --- a/python/fedml/fa/aggregator/frequency_estimation_aggregator.py +++ b/python/fedml/fa/aggregator/frequency_estimation_aggregator.py @@ -4,14 +4,51 @@ class FrequencyEstimationAggregatorFA(FAServerAggregator): + """ + Aggregator for Federated Learning with Frequency Estimation. + + Args: + args (object): An object containing aggregator configuration parameters. + + Attributes: + total_sample_num (int): The total number of training samples aggregated. + server_data (dict): Dictionary to store aggregated data. + round_idx (int): The current training round index. + total_round (int): The total number of training rounds. + + Methods: + aggregate(local_submission_list): + Aggregate local submissions from clients. + print_frequency_estimation_results(): + Print and display frequency estimation results as a histogram. + + """ def __init__(self, args): + """ + Initialize the FrequencyEstimationAggregatorFA. + + Args: + args (object): An object containing aggregator configuration parameters. + + Returns: + None + """ super().__init__(args) self.total_sample_num = 0 - self.set_server_data(server_data=[]) + self.set_server_data(server_data={}) self.round_idx = 0 self.total_round = args.comm_round def aggregate(self, local_submission_list: List[Tuple[float, Any]]): + """ + Aggregate local submissions from clients. + + Args: + local_submission_list (list): A list of tuples containing local sample number and local submissions. + + Returns: + dict: The aggregated server data. 
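+
+        Example (illustrative; each tuple is ``(sample_num, local_counts)``)::
+
+            agg.aggregate([(3, {"a": 2, "b": 1}), (2, {"a": 2})])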
+ """ training_num = 0 (sample_num, averaged_params) = local_submission_list[0] for i in range(0, len(local_submission_list)): @@ -33,6 +70,12 @@ def aggregate(self, local_submission_list: List[Tuple[float, Any]]): return self.server_data def print_frequency_estimation_results(self): + """ + Print and display frequency estimation results as a histogram. + + Returns: + None + """ print("frequency estimation: ") for key in self.server_data: print(f"key = {key}, freq = {self.server_data[key] / self.total_sample_num}") @@ -41,3 +84,4 @@ def print_frequency_estimation_results(self): plt.ylabel('Occurrence # ') plt.title('Histogram') plt.show() + \ No newline at end of file diff --git a/python/fedml/fa/aggregator/global_analyzer_creator.py b/python/fedml/fa/aggregator/global_analyzer_creator.py index 142e8b14b6..55a6cc9898 100644 --- a/python/fedml/fa/aggregator/global_analyzer_creator.py +++ b/python/fedml/fa/aggregator/global_analyzer_creator.py @@ -9,6 +9,16 @@ def create_global_analyzer(args, train_data_num): + """ + Create a global analyzer based on the specified federated aggregation task. + + Args: + args: Additional arguments for creating the global analyzer. + train_data_num (int): The number of training data samples. + + Returns: + FAServerAggregator: An instance of a global analyzer based on the specified task. + """ task_type = args.fa_task if task_type == FA_TASK_AVG: return AVGAggregatorFA(args) diff --git a/python/fedml/fa/aggregator/heavy_hitter_triehh_aggregator.py b/python/fedml/fa/aggregator/heavy_hitter_triehh_aggregator.py index 6cfbdc1764..d228075390 100644 --- a/python/fedml/fa/aggregator/heavy_hitter_triehh_aggregator.py +++ b/python/fedml/fa/aggregator/heavy_hitter_triehh_aggregator.py @@ -12,6 +12,16 @@ class HeavyHitterTriehhAggregatorFA(FAServerAggregator): def __init__(self, args, train_data_num): + """ + Initialize the HeavyHitterTriehhAggregatorFA. + + Args: + args: Additional arguments for initialization. + train_data_num (int): The number of training data samples. + + Returns: + None + """ super().__init__(args) if hasattr(args, "max_word_len"): self.MAX_L = args.max_word_len @@ -43,7 +53,7 @@ def __init__(self, args, train_data_num): self.batch_size = int(train_data_num * (np.e ** (self.epsilon / self.MAX_L) - 1) / ( self.theta * np.e ** (self.epsilon / self.MAX_L))) self.init_msg = int(math.ceil(self.batch_size * 1.0 / args.client_num_per_round)) - self.w_global = {} # self.trie = {} + self.w_global = {} def get_init_msg(self): return self.init_msg @@ -52,6 +62,15 @@ def set_init_msg(self, init_msg): self.init_msg = init_msg def aggregate(self, local_submission_list: List[Tuple[float, Any]]): + """ + Aggregate local submissions. + + Args: + local_submission_list (List[Tuple[float, Any]]): A list of local submissions. + + Returns: + Dict: The aggregated data. + """ votes = {} for (num, local_vote_dict) in local_submission_list: for key in local_vote_dict.keys(): @@ -70,6 +89,12 @@ def aggregate(self, local_submission_list: List[Tuple[float, Any]]): return self.w_global def _set_theta(self): + """ + Calculate and set the value of theta. + + Returns: + int: The calculated theta value. 
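+
+        Note: per the loop below, theta is the smallest integer (starting from
+        the initial guess of 5) satisfying
+        ``((theta - 3) / (theta - 2)) * theta! >= 1 / delta``.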
+ """ theta = 5 # initial guess delta_inverse = 1 / self.delta while ((theta - 3) / (theta - 2)) * math.factorial(theta) < delta_inverse: @@ -80,11 +105,15 @@ def _set_theta(self): return theta def server_update(self, votes): - # It might make more sense to define a small class called server_state - # server_state can track 2 things: 1) updated trie, and 2) quit_sign - # server_state can be initialized in the constructor of SimulateTrieHH - # and server_update would just update server_state - # (i.e, it would update self.server_state.trie & self.server_state.quit_sign) + """ + Update the server based on received votes. + + Args: + votes (Dict): A dictionary of votes. + + Returns: + None + """ self.quit_sign = True for prefix in votes: if votes[prefix] >= self.theta: @@ -92,11 +121,17 @@ def server_update(self, votes): self.quit_sign = False def print_heavy_hitters(self): + """ + Print the discovered heavy hitters. + + Returns: + None + """ heavy_hitters = [] print(f"self.w_global = {self.w_global}") raw_result = self.w_global.keys() for word in raw_result: if word[-1:] == '$': heavy_hitters.append(word.rstrip('$')) - # print(f'Discovered {len(heavy_hitters)} heavy hitters in run #{self.round_counter + 1}: {heavy_hitters}') + print(f'Discovered {len(heavy_hitters)} heavy hitters: {heavy_hitters}') diff --git a/python/fedml/fa/aggregator/intersection_aggregator.py b/python/fedml/fa/aggregator/intersection_aggregator.py index c7f0b1e559..c85a24b191 100644 --- a/python/fedml/fa/aggregator/intersection_aggregator.py +++ b/python/fedml/fa/aggregator/intersection_aggregator.py @@ -3,48 +3,77 @@ from fedml.fa.base_frame.server_aggregator import FAServerAggregator -def get_intersection_of_two_lists_keep_duplicates(list1, list2): +def get_intersection_of_two_lists_keep_duplicates(list1: List[Any], list2: List[Any]) -> List[Any]: """ - Keep duplicates in the intersection, e.g., list1=[1,2,3,2,3], list2=[2,3,2,3]. intersect(list1, list2) = [2,3,2,3] - :param list1: first list - :param list2: second list - :return: intersection of the 2 lists + Return the intersection of two lists while keeping duplicates. + + Args: + list1 (List): The first list. + list2 (List): The second list. + + Returns: + List: The intersection of the two lists, keeping duplicates. """ intersection = [] - for i in range(len(list1)): - for j in range(len(list2) - 1, -1, -1): - if list1[i] == list2[j]: - intersection.append(list2[j]) - list2.remove(j) + for item in list1: + if item in list2: + intersection.append(item) + list2.remove(item) return intersection -def get_intersection_of_two_lists_remove_duplicates(list1, list2): +def get_intersection_of_two_lists_remove_duplicates(list1: List[Any], list2: List[Any]) -> List[Any]: """ - Remove duplicates in the intersection, e.g., list1=[1,2,3,2,3], list2=[2,3,2,3]. intersect(list1, list2) = [2,3] - :param list1: first list - :param list2: second list - :return: intersection of the 2 lists + Return the intersection of two lists and remove duplicate values. + + Args: + list1 (List): The first list. + list2 (List): The second list. + + Returns: + List: The intersection of the two lists with duplicates removed. """ return list(set(list1) & set(list2)) class IntersectionAggregatorFA(FAServerAggregator): def __init__(self, args): + """ + Initialize the IntersectionAggregatorFA. + + Args: + args: Additional arguments for initialization. 
+ + Returns: + None + """ super().__init__(args) self.set_server_data(server_data=[]) - def aggregate(self, local_submission_list: List[Tuple[float, Any]]): - for i in range(0, len(local_submission_list)): - _, local_submission = local_submission_list[i] + def aggregate(self, local_submission_list: List[Tuple[float, Any]]) -> List[Any]: + """ + Aggregate local submissions while maintaining intersection. + + Args: + local_submission_list (List[Tuple[float, Any]]): A list of local submissions. + + Returns: + List: The intersection of local submissions. + """ + for _, local_submission in local_submission_list: if len(self.server_data) == 0: - # no need to remove duplicates even in ``remove duplicate'' mode, - # as the duplicates will be removed in later computation + self.server_data = local_submission else: self.server_data = get_intersection_of_two_lists_remove_duplicates(self.server_data, local_submission) print(f"cardinality = {self.get_cardinality()}") return self.server_data - def get_cardinality(self): + def get_cardinality(self) -> int: + """ + Get the cardinality (number of elements) of the aggregated data. + + Returns: + int: The cardinality of the aggregated data. + """ return len(self.server_data) diff --git a/python/fedml/fa/aggregator/k_percentile_element_aggregator.py b/python/fedml/fa/aggregator/k_percentile_element_aggregator.py index 5b0eae456f..d0415716a6 100644 --- a/python/fedml/fa/aggregator/k_percentile_element_aggregator.py +++ b/python/fedml/fa/aggregator/k_percentile_element_aggregator.py @@ -17,6 +17,16 @@ class KPercentileElementAggregatorFA(FAServerAggregator): def __init__(self, args, train_data_num): + """ + Initialize the KPercentileElementAggregatorFA. + + Args: + args: Configuration arguments. + train_data_num (int): The total number of training data samples. + + Returns: + None + """ super().__init__(args) self.total_sample_num = 0 self.set_server_data(server_data=[]) @@ -24,46 +34,63 @@ def __init__(self, args, train_data_num): self.total_sample_num = 0 self.train_data_num_in_total = train_data_num self.percentage = args.k / 100 + + # Initialize server_data and previous_server_data if hasattr(args, "flag"): self.server_data = args.flag self.previous_server_data = args.flag else: self.server_data = 100 self.previous_server_data = 100 + + # Check if use_all_data attribute is specified in args if hasattr(args, "use_all_data") and args.use_all_data in [False]: - self.use_all_data = False # in each iteration, each client randomly sample some data to compute + self.use_all_data = False # In each iteration, each client randomly samples some data to compute else: - self.use_all_data = True # in each iteration, each client uses its all local data to compute + self.use_all_data = True # In each iteration, each client uses all its local data to compute + + # Initialize max_val and min_val self.max_val = self.previous_server_data self.min_val = self.previous_server_data def aggregate(self, local_submission_list: List[Tuple[float, Any]]): + """ + Aggregate local submissions from clients. + + Args: + local_submission_list (List[Tuple[float, Any]]): A list of tuples containing local submissions and weights. + + Returns: + float: The aggregated result. 
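+
+        Example (illustrative; each tuple is ``(sample_num, satisfied_counter)``,
+        the number of local values that satisfy the current flag)::
+
+            flag = agg.aggregate([(100, 40), (50, 20)])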
+ """ if self.quit: return self.server_data + total_sample_num_this_round = 0 local_satisfied_data_num_current_round = 0 - logging.info(f"flag={self.server_data}, local_submission_list={local_submission_list}") + for (sample_num, satisfied_counter) in local_submission_list: total_sample_num_this_round += sample_num local_satisfied_data_num_current_round += satisfied_counter + if total_sample_num_this_round * self.percentage == local_satisfied_data_num_current_round: self.quit = True self.previous_server_data = self.server_data elif total_sample_num_this_round * self.percentage > local_satisfied_data_num_current_round: - # decrease server_data + # Decrease server_data self.max_val = self.server_data if self.previous_server_data >= self.server_data: self.previous_server_data = self.server_data if self.server_data / 2 < self.min_val < self.max_val: - self.server_data = (self.server_data + self.min_val)/2 + self.server_data = (self.server_data + self.min_val) / 2 else: self.server_data = self.server_data / 2 - self.min_val = self.server_data # set lower bound for flag + self.min_val = self.server_data # Set lower bound for flag else: new_server_data = (self.previous_server_data + self.server_data) / 2 self.previous_server_data = self.server_data self.server_data = new_server_data - else: # increase server_data + else: # Increase server_data self.min_val = self.server_data if self.previous_server_data <= self.server_data: self.previous_server_data = self.server_data @@ -76,4 +103,6 @@ def aggregate(self, local_submission_list: List[Tuple[float, Any]]): new_server_data = (self.previous_server_data + self.server_data) / 2 self.previous_server_data = self.server_data self.server_data = new_server_data + return self.server_data + \ No newline at end of file diff --git a/python/fedml/fa/aggregator/union_aggregator.py b/python/fedml/fa/aggregator/union_aggregator.py index f3e3baf730..4a15b2f86a 100644 --- a/python/fedml/fa/aggregator/union_aggregator.py +++ b/python/fedml/fa/aggregator/union_aggregator.py @@ -4,39 +4,65 @@ def get_union_of_two_lists_keep_duplicates(list1, list2): """ - Keep duplicates in the union, e.g., list1=[1,2,3,2,3], list2=[2,3,2,3]. intersect(list1, list2) = [1,2,3,2,3] - :param list1: first list - :param list2: second list - :return: intersection of the 2 lists + Compute the union of two lists while keeping duplicates. + + Args: + list1 (List): The first list. + list2 (List): The second list. + + Returns: + List: The union of the two lists with duplicates. """ union = [] for item in list1: union.append(item) if item in list2: - list2.remove(list2.index(item)) + list2.remove(item) union.extend(list2) return union def get_union_of_two_lists_remove_duplicates(list1, list2): """ - Remove duplicates in the union, e.g., list1=[1,2,3,2,3], list2=[2,3,2,3]. intersect(list1, list2) = [1,2,3] - :param list1: first list - :param list2: second list - :return: intersection of the 2 lists + Compute the union of two lists and remove duplicates. + + Args: + list1 (List): The first list. + list2 (List): The second list. + + Returns: + List: The union of the two lists without duplicates. """ return list(set(list1 + list2)) class UnionAggregatorFA(FAServerAggregator): def __init__(self, args): + """ + Initialize the UnionAggregatorFA. + + Args: + args: Configuration arguments. 
+ + Returns: + None + """ super().__init__(args) self.set_server_data(server_data=[]) - self.union_function = get_union_of_two_lists_remove_duplicates # select the way to compute union + self.union_function = get_union_of_two_lists_remove_duplicates # Select the way to compute union def aggregate(self, local_submission_list: List[Tuple[float, Any]]): + """ + Aggregate local submissions from clients. + + Args: + local_submission_list (List[Tuple[float, Any]]): A list of tuples containing local submissions and weights. + + Returns: + List: The aggregated result. + """ for i in range(0, len(local_submission_list)): _, local_submission = local_submission_list[i] - # when server_data is [], i.e., the first round, will only process local_submission + # When server_data is [], i.e., the first round, will only process local_submission self.server_data = self.union_function(self.server_data, local_submission) return self.server_data diff --git a/python/fedml/fa/base_frame/client_analyzer.py b/python/fedml/fa/base_frame/client_analyzer.py index 624f7c290a..9fec99924e 100644 --- a/python/fedml/fa/base_frame/client_analyzer.py +++ b/python/fedml/fa/base_frame/client_analyzer.py @@ -4,6 +4,15 @@ class FAClientAnalyzer(ABC): def __init__(self, args): + """ + Initialize the client analyzer. + + Args: + args: Configuration arguments. + + Returns: + None + """ self.client_submission = 0 self.id = 0 self.args = args @@ -12,30 +21,104 @@ def __init__(self, args): self.init_msg = None def set_init_msg(self, init_msg): + """ + Set the initialization message. + + Args: + init_msg: The initialization message. + + Returns: + None + """ pass def get_init_msg(self): + """ + Get the initialization message. + + Returns: + Any: The initialization message. + """ pass def set_id(self, analyzer_id): + """ + Set the ID of the client analyzer. + + Args: + analyzer_id: The ID of the analyzer. + + Returns: + None + """ self.id = analyzer_id def get_client_submission(self): + """ + Get the client submission. + + Returns: + Any: The client submission. + """ return self.client_submission def set_client_submission(self, client_submission): + """ + Set the client submission. + + Args: + client_submission: The client submission. + + Returns: + None + """ self.client_submission = client_submission def get_server_data(self): + """ + Get the server data. + + Returns: + Any: The server data. + """ return self.server_data def set_server_data(self, server_data): + """ + Set the server data. + + Args: + server_data: The server data. + + Returns: + None + """ self.server_data = server_data @abstractmethod def local_analyze(self, train_data, args): + """ + Perform local analysis. + + Args: + train_data: The local training data. + args: Configuration arguments. + + Returns: + None + """ pass def update_dataset(self, local_train_dataset, local_sample_number): + """ + Update the local dataset. + + Args: + local_train_dataset: The local training dataset. + local_sample_number: The number of local samples. + + Returns: + None + """ self.local_train_dataset = local_train_dataset self.local_sample_number = local_sample_number diff --git a/python/fedml/fa/base_frame/server_aggregator.py b/python/fedml/fa/base_frame/server_aggregator.py index 76fc1a73bc..5ad46dd6a8 100644 --- a/python/fedml/fa/base_frame/server_aggregator.py +++ b/python/fedml/fa/base_frame/server_aggregator.py @@ -1,9 +1,17 @@ from abc import ABC from typing import List, Tuple, Any - class FAServerAggregator(ABC): def __init__(self, args): + """ + Initialize the server aggregator. 
+ + Args: + args: Configuration arguments. + + Returns: + None + """ self.id = 0 self.args = args self.eval_data = None @@ -11,21 +19,67 @@ def __init__(self, args): self.init_msg = None def get_init_msg(self): - # return self.init_msg + """ + Get the initialization message. + + Returns: + Any: The initialization message. + """ pass def set_init_msg(self, init_msg): - # self.init_msg = init_msg + """ + Set the initialization message. + + Args: + init_msg: The initialization message. + + Returns: + None + """ pass def set_id(self, aggregator_id): + """ + Set the ID of the server aggregator. + + Args: + aggregator_id: The ID of the aggregator. + + Returns: + None + """ self.id = aggregator_id def get_server_data(self): + """ + Get the server data. + + Returns: + Any: The server data. + """ return self.server_data def set_server_data(self, server_data): + """ + Set the server data. + + Args: + server_data: The server data. + + Returns: + None + """ self.server_data = server_data def aggregate(self, local_submissions: List[Tuple[float, Any]]): + """ + Aggregate local submissions from clients. + + Args: + local_submissions (List[Tuple[float, Any]]): A list of tuples containing local submissions and weights. + + Returns: + None + """ pass diff --git a/python/fedml/fa/cross_silo/client/client_initializer.py b/python/fedml/fa/cross_silo/client/client_initializer.py index 167d5e0a5a..9d4b21fd8f 100644 --- a/python/fedml/fa/cross_silo/client/client_initializer.py +++ b/python/fedml/fa/cross_silo/client/client_initializer.py @@ -13,6 +13,22 @@ def init_client( train_data_local_dict, local_analyzer=None, ): + """ + Initialize the federated learning client. + + Args: + args: Configuration arguments. + comm: Communication object. + client_rank (int): The rank of the client. + client_num (int): The total number of clients. + train_data_num (int): The total number of training data samples. + train_data_local_num_dict (dict): A dictionary mapping client indices to the number of local training samples. + train_data_local_dict (dict): A dictionary mapping client indices to their local training data. + local_analyzer: Local analyzer for the client (optional). + + Returns: + None + """ backend = args.backend trainer_dist_adapter = get_trainer_dist_adapter( @@ -38,6 +54,20 @@ def get_trainer_dist_adapter( train_data_local_dict, local_analyzer, ): + """ + Get the trainer distribution adapter. + + Args: + args: Configuration arguments. + client_rank (int): The rank of the client. + train_data_num (int): The total number of training data samples. + train_data_local_num_dict (dict): A dictionary mapping client indices to the number of local training samples. + train_data_local_dict (dict): A dictionary mapping client indices to their local training data. + local_analyzer: Local analyzer for the client. + + Returns: + TrainerDistAdapter: The trainer distribution adapter. + """ return TrainerDistAdapter( args, client_rank, @@ -49,10 +79,34 @@ def get_trainer_dist_adapter( def get_client_manager_master(args, trainer_dist_adapter, comm, client_rank, client_num, backend): + """ + Get the client master manager. + + Args: + args: Configuration arguments. + trainer_dist_adapter: Trainer distribution adapter. + comm: Communication object. + client_rank (int): The rank of the client. + client_num (int): The total number of clients. + backend: Backend for distributed training. + + Returns: + ClientMasterManager: The client master manager. 
+ """ return ClientMasterManager(args, trainer_dist_adapter, comm, client_rank, client_num, backend) def get_client_manager_salve(args, trainer_dist_adapter): + """ + Get the client slave manager. + + Args: + args: Configuration arguments. + trainer_dist_adapter: Trainer distribution adapter. + + Returns: + ClientSlaveManager: The client slave manager. + """ from .fedml_client_slave_manager import ClientSlaveManager return ClientSlaveManager(args, trainer_dist_adapter) diff --git a/python/fedml/fa/cross_silo/client/client_launcher.py b/python/fedml/fa/cross_silo/client/client_launcher.py index 0f3ea1f278..66c1045124 100644 --- a/python/fedml/fa/cross_silo/client/client_launcher.py +++ b/python/fedml/fa/cross_silo/client/client_launcher.py @@ -2,16 +2,47 @@ from fedml.arguments import load_arguments from fedml.constants import FEDML_TRAINING_PLATFORM_CROSS_SILO - class CrossSiloLauncher: + """ + A class for launching distributed trainers in a cross-silo federated learning setup. + + Attributes: + None + + Methods: + launch_dist_trainers(torch_client_filename, inputs): + Launch distributed trainers using the provided arguments. + + """ @staticmethod def launch_dist_trainers(torch_client_filename, inputs): - # this is only used by the client (DDP or single process), so there is no need to specify the backend. + """ + Launch distributed trainers using the provided arguments. + + Args: + torch_client_filename (str): The filename of the PyTorch client script. + inputs (list): A list of input arguments to be passed to the client script. + + Returns: + None + """ + # This is only used by the client (DDP or single process), so there is no need to specify the backend. args = load_arguments(FEDML_TRAINING_PLATFORM_CROSS_SILO) CrossSiloLauncher._run_cross_silo_horizontal(args, torch_client_filename, inputs) @staticmethod def _run_cross_silo_horizontal(args, torch_client_filename, inputs): + """ + Run the cross-silo horizontal federated learning process. + + Args: + args: Configuration arguments. + torch_client_filename (str): The filename of the PyTorch client script. + inputs (list): A list of input arguments to be passed to the client script. + + Returns: + None + """ python_path = subprocess.run(["which", "python"], capture_output=True, text=True).stdout.strip() process_arguments = [python_path, torch_client_filename] + inputs - subprocess.run(process_arguments) \ No newline at end of file + subprocess.run(process_arguments) diff --git a/python/fedml/fa/cross_silo/client/fa_local_analyzer.py b/python/fedml/fa/cross_silo/client/fa_local_analyzer.py index 0ec0c89dc9..9aeddc8365 100755 --- a/python/fedml/fa/cross_silo/client/fa_local_analyzer.py +++ b/python/fedml/fa/cross_silo/client/fa_local_analyzer.py @@ -3,6 +3,49 @@ class FALocalAnalyzer(object): + """ + A class representing a local analyzer for federated learning. + + Args: + client_index (int): The index of the client. + train_data_local_dict (dict): A dictionary containing local training data. + train_data_local_num_dict (dict): A dictionary containing the number of local training samples. + train_data_num (int): The total number of training samples. + args: Configuration arguments. + local_analyzer: An instance of the local analyzer. + + Attributes: + local_analyzer: An instance of the local analyzer. + client_index (int): The index of the client. + train_data_local_dict (dict): A dictionary containing local training data. + train_data_local_num_dict (dict): A dictionary containing the number of local training samples. 
+ all_train_data_num (int): The total number of training samples. + train_local: Local training data for the client. + local_sample_number: The number of local training samples for the client. + test_local: Local testing data for the client. + args: Configuration arguments. + init_msg: Initialization message for the client. + + Methods: + set_init_msg(init_msg): + Set the initialization message for the client. + + get_init_msg(): + Get the initialization message for the client. + + set_server_data(server_data): + Set the server data for the client. + + set_client_submission(client_submission): + Set the client's submission. + + update_dataset(client_index): + Update the client's dataset based on the provided client index. + + local_analyze(round_idx=None): + Perform local analysis for federated learning. + + """ def __init__( self, client_index, @@ -12,6 +55,20 @@ def __init__( args, local_analyzer, ): + """ + Initialize the FALocalAnalyzer. + + Args: + client_index (int): The index of the client. + train_data_local_dict (dict): A dictionary containing local training data. + train_data_local_num_dict (dict): A dictionary containing the number of local training samples. + train_data_num (int): The total number of training samples. + args: Configuration arguments. + local_analyzer: An instance of the local analyzer. + + Returns: + None + """ self.local_analyzer = local_analyzer self.client_index = client_index self.train_data_local_dict = train_data_local_dict @@ -24,18 +81,60 @@ def __init__( self.init_msg = None def set_init_msg(self, init_msg): + """ + Set the initialization message for the client. + + Args: + init_msg: Initialization message for the client. + + Returns: + None + """ self.local_analyzer.set_init_msg(init_msg) def get_init_msg(self): + """ + Get the initialization message for the client. + + Returns: + Initialization message for the client. + """ return self.local_analyzer.get_init_msg() def set_server_data(self, server_data): + """ + Set the server data for the client. + + Args: + server_data: Server data for the client. + + Returns: + None + """ self.local_analyzer.set_server_data(server_data) def set_client_submission(self, client_submission): + """ + Set the client's submission. + + Args: + client_submission: Client's submission data. + + Returns: + None + """ self.local_analyzer.set_client_submission(client_submission) def update_dataset(self, client_index): + """ + Update the client's dataset based on the provided client index. + + Args: + client_index (int): The index of the client. + + Returns: + None + """ self.client_index = client_index if self.train_data_local_dict is not None: @@ -51,10 +150,19 @@ def update_dataset(self, client_index): self.local_analyzer.update_dataset(self.train_local, self.local_sample_number) def local_analyze(self, round_idx=None): + """ + Perform local analysis for federated learning. + + Args: + round_idx (int): The current round index (default is None). + + Returns: + Tuple containing client submission data and the number of local samples. 
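
A rough sketch of one analysis round through the wrapper documented above, assuming a concrete analyzer from create_local_analyzer and a single client shard keyed by 0:

    from fedml.fa.cross_silo.client.fa_local_analyzer import FALocalAnalyzer
    from fedml.fa.local_analyzer.client_analyzer_creator import create_local_analyzer

    def analyze_once(args):
        fa_analyzer = FALocalAnalyzer(
            client_index=0,
            train_data_local_dict={0: ["1", "2", "3"]},
            train_data_local_num_dict={0: 3},
            train_data_num=3,
            args=args,
            local_analyzer=create_local_analyzer(args),
        )
        fa_analyzer.update_dataset(0)  # select this client's local shard
        # Returns (client_submission, local_sample_number); note that
        # local_analyze also logs its timing to the MLOps profiler.
        return fa_analyzer.local_analyze(round_idx=0)
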
+ """ self.args.round_idx = round_idx tick = time.time() self.local_analyzer.local_analyze(self.train_local, self.args) MLOpsProfilerEvent.log_to_wandb({"Train/Time": time.time() - tick, "round": round_idx}) client_submission = self.local_analyzer.get_client_submission() - return client_submission, self.local_sample_number \ No newline at end of file + return client_submission, self.local_sample_number diff --git a/python/fedml/fa/cross_silo/client/fedml_client_master_manager.py b/python/fedml/fa/cross_silo/client/fedml_client_master_manager.py index 1623114f36..0f56970226 100644 --- a/python/fedml/fa/cross_silo/client/fedml_client_master_manager.py +++ b/python/fedml/fa/cross_silo/client/fedml_client_master_manager.py @@ -11,7 +11,71 @@ class ClientMasterManager(FedMLCommManager): + """ + Manages the communication and training process for a federated learning client master. + + Args: + args (object): An object containing client configuration parameters. + trainer_dist_adapter: An instance of the trainer distribution adapter. + comm: A communication backend (default is None). + rank (int): The rank of the client (default is 0). + size (int): The size of the communication group (default is 0). + backend (str): The communication backend (default is "MPI"). + + Attributes: + trainer_dist_adapter: An instance of the trainer distribution adapter. + args (object): An object containing client configuration parameters. + num_rounds (int): The total number of communication rounds. + round_idx (int): The current communication round index. + rank (int): The rank of the client. + client_real_ids (list): A list of client real IDs. + client_real_id (str): The client's real ID. + has_sent_online_msg (bool): A flag indicating if the online message has been sent. + + Methods: + register_message_receive_handlers(): + Register message receive handlers for various message types. + handle_message_connection_ready(msg_params): + Handle the connection-ready message. + handle_message_check_status(msg_params): + Handle the check-client-status message. + handle_message_init(msg_params): + Handle the initialization message. + handle_message_receive_model_from_server(msg_params): + Handle the message to receive a model from the server. + handle_message_finish(msg_params): + Handle the message indicating the completion of training. + cleanup(): + Perform cleanup after training finishes. + send_model_to_server(receive_id, weights, local_sample_num): + Send the model and related information to the server. + send_client_status(receive_id, status="ONLINE"): + Send the client's status to the server. + report_training_status(status): + Report the training status to MLOps. + sync_process_group(round_idx, model_params=None, client_index=None, src=0): + Synchronize the process group with round information. + __train(): + Perform the training for the current round. + run(): + Start the client master manager's communication and training process. + + """ def __init__(self, args, trainer_dist_adapter, comm=None, rank=0, size=0, backend="MPI"): + """ + Initialize the ClientMasterManager. + + Args: + args (object): An object containing client configuration parameters. + trainer_dist_adapter: An instance of the trainer distribution adapter. + comm: A communication backend (default is None). + rank (int): The rank of the client (default is 0). + size (int): The size of the communication group (default is 0). + backend (str): The communication backend (default is "MPI"). 
+ + Returns: + None + """ super().__init__(args, comm, rank, size, backend) self.trainer_dist_adapter = trainer_dist_adapter self.args = args @@ -20,11 +84,17 @@ def __init__(self, args, trainer_dist_adapter, comm=None, rank=0, size=0, backen self.rank = rank self.client_real_ids = json.loads(args.client_id_list) logging.info("self.client_real_ids = {}".format(self.client_real_ids)) - # for the client, len(self.client_real_ids)==1: we only specify its client id in the list, not including others. + # For the client, len(self.client_real_ids)==1: we only specify its client id in the list, not including others. self.client_real_id = self.client_real_ids[0] self.has_sent_online_msg = False def register_message_receive_handlers(self): + """ + Register message receive handlers for various message types. + + Returns: + None + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_message_connection_ready ) @@ -43,6 +113,15 @@ def register_message_receive_handlers(self): ) def handle_message_connection_ready(self, msg_params): + """ + Handle the connection-ready message. + + Args: + msg_params (dict): Message parameters. + + Returns: + None + """ if not self.has_sent_online_msg: self.has_sent_online_msg = True self.send_client_status(0) @@ -50,9 +129,27 @@ def handle_message_connection_ready(self, msg_params): mlops.log_sys_perf(self.args) def handle_message_check_status(self, msg_params): + """ + Handle the check-client-status message. + + Args: + msg_params (dict): Message parameters. + + Returns: + None + """ self.send_client_status(0) def handle_message_init(self, msg_params): + """ + Handle the initialization message. + + Args: + msg_params (dict): Message parameters. + + Returns: + None + """ global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) data_silo_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) init_msg = msg_params.get(MyMessage.MSG_INIT_MSG_TO_CLIENTS) @@ -70,6 +167,15 @@ def handle_message_init(self, msg_params): self.round_idx += 1 def handle_message_receive_model_from_server(self, msg_params): + """ + Handle the message to receive a model from the server. + + Args: + msg_params (dict): Message parameters. + + Returns: + None + """ logging.info("handle_message_receive_model_from_server.") model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -82,14 +188,40 @@ def handle_message_receive_model_from_server(self, msg_params): self.round_idx += 1 def handle_message_finish(self, msg_params): + """ + Handle the message indicating the completion of training. + + Args: + msg_params (dict): Message parameters. + + Returns: + None + """ logging.info(" ====================cleanup ====================") self.cleanup() def cleanup(self): + """ + Perform cleanup after training finishes. + + Returns: + None + """ self.finish() mlops.log_training_finished_status() def send_model_to_server(self, receive_id, weights, local_sample_num): + """ + Send the model and related information to the server. + + Args: + receive_id (int): The ID of the receiver. + weights (object): Model weights or parameters. + local_sample_num (int): The number of local samples. 
+ + Returns: + None + """ tick = time.time() mlops.event("comm_c2s", event_started=True, event_value=str(self.round_idx)) message = Message(MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.client_real_id, receive_id,) @@ -103,6 +235,16 @@ def send_model_to_server(self, receive_id, weights, local_sample_num): ) def send_client_status(self, receive_id, status="ONLINE"): + """ + Send the client's status to the server. + + Args: + receive_id (int): The ID of the receiver. + status (str): The client's status (default is "ONLINE"). + + Returns: + None + """ logging.info("send_client_status") logging.info("self.client_real_id = {}".format(self.client_real_id)) message = Message(MyMessage.MSG_TYPE_C2S_CLIENT_STATUS, self.client_real_id, receive_id) @@ -117,9 +259,30 @@ def send_client_status(self, receive_id, status="ONLINE"): self.send_message(message) def report_training_status(self, status): + """ + Report the training status to MLOps. + + Args: + status (str): The training status to report. + + Returns: + None + """ mlops.log_training_status(status) def sync_process_group(self, round_idx, model_params=None, client_index=None, src=0): + """ + Synchronize the process group with round information. + + Args: + round_idx (int): The current round index. + model_params (object): Model weights or parameters (default is None). + client_index (int): The index of the client (default is None). + src (int): The source process rank (default is 0). + + Returns: + None + """ logging.info("sending round number to pg") round_number = [round_idx, model_params, client_index] dist.broadcast_object_list( @@ -128,6 +291,12 @@ def sync_process_group(self, round_idx, model_params=None, client_index=None, sr logging.info("round number %d broadcast to process group" % round_number[0]) def __train(self): + """ + Perform the training for the current round. + + Returns: + None + """ logging.info("#######training########### round_id = %d" % self.round_idx) mlops.event("train", event_started=True, event_value=str(self.round_idx)) @@ -139,4 +308,10 @@ def __train(self): self.send_model_to_server(0, client_submission, local_sample_num) def run(self): + """ + Start the client master manager's communication and training process. + + Returns: + None + """ super().run() diff --git a/python/fedml/fa/cross_silo/client/fedml_client_slave_manager.py b/python/fedml/fa/cross_silo/client/fedml_client_slave_manager.py index 48f30d8263..f6cc3af6e2 100644 --- a/python/fedml/fa/cross_silo/client/fedml_client_slave_manager.py +++ b/python/fedml/fa/cross_silo/client/fedml_client_slave_manager.py @@ -4,7 +4,42 @@ class ClientSlaveManager: + """ + Manages the training process for a federated learning client slave. + + Args: + args (object): An object containing client configuration parameters. + trainer_dist_adapter: An instance of the trainer distribution adapter. + + Attributes: + trainer_dist_adapter: An instance of the trainer distribution adapter. + args (object): An object containing client configuration parameters. + round_idx (int): The current training round index. + num_rounds (int): The total number of training rounds. + finished (bool): A flag indicating if training has finished. + + Methods: + train(): + Perform training for the current round. + finish(): + Finish the client slave's training. + await_sync_process_group(src=0): + Await synchronization with the process group and receive round information. + run(): + Start the client slave's training process. 
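
The round handoff in sync_process_group above and await_sync_process_group below reduces to one collective call on each side; a self-contained sketch using only torch.distributed (the real methods also pass the silo's process group):

    import torch.distributed as dist

    def master_announce_round(round_idx, model_params, client_index, pg=None):
        # Master side: broadcast the round state to every rank in the silo.
        dist.broadcast_object_list([round_idx, model_params, client_index], src=0, group=pg)

    def slave_await_round(pg=None):
        # Slave side: pass placeholders and read back the same three objects.
        objects = [None, None, None]
        dist.broadcast_object_list(objects, src=0, group=pg)
        return objects  # [round_idx, model_params, client_index]
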
+ + """ def __init__(self, args, trainer_dist_adapter): + """ + Initialize the ClientSlaveManager. + + Args: + args (object): An object containing client configuration parameters. + trainer_dist_adapter: An instance of the trainer distribution adapter. + + Returns: + None + """ self.trainer_dist_adapter = trainer_dist_adapter self.args = args self.round_idx = 0 @@ -12,6 +47,12 @@ def __init__(self, args, trainer_dist_adapter): self.finished = False def train(self): + """ + Perform training for the current round. + + Returns: + None + """ [round_idx, model_params, client_index] = self.await_sync_process_group() if round_idx: self.round_idx = round_idx @@ -28,7 +69,12 @@ def train(self): self.trainer_dist_adapter.train(self.round_idx) def finish(self): - # pass + """ + Finish the client slave's training. + + Returns: + None + """ self.trainer_dist_adapter.cleanup_pg() logging.info( "Training finished for slave client rank %s in silo %s" @@ -37,6 +83,15 @@ def finish(self): self.finished = True def await_sync_process_group(self, src=0): + """ + Await synchronization with the process group and receive round information. + + Args: + src (int): The source process rank to receive data from (default is 0). + + Returns: + list: A list containing round index, model parameters, and client index. + """ logging.info("process %d waiting for round number" % dist.get_rank()) objects = [None, None, None] dist.broadcast_object_list( @@ -46,5 +101,11 @@ def await_sync_process_group(self, src=0): return objects def run(self): + """ + Start the client slave's training process. + + Returns: + None + """ while not self.finished: self.train() diff --git a/python/fedml/fa/cross_silo/client/fedml_trainer_dist_adapter.py b/python/fedml/fa/cross_silo/client/fedml_trainer_dist_adapter.py index d758b11fec..00d8f9692f 100644 --- a/python/fedml/fa/cross_silo/client/fedml_trainer_dist_adapter.py +++ b/python/fedml/fa/cross_silo/client/fedml_trainer_dist_adapter.py @@ -4,6 +4,36 @@ class TrainerDistAdapter: + """ + Adapter for a Federated Learning Trainer with Distributed Training. + + Args: + args (object): An object containing trainer configuration parameters. + client_rank (int): The rank of the client. + train_data_num (int): The total number of training data samples. + train_data_local_num_dict (dict): A dictionary of client-specific training data sizes. + train_data_local_dict (dict): A dictionary of client-specific training data. + local_analyzer: An instance of the local analyzer (optional). + + Attributes: + client_index (int): The index of the client. + client_rank (int): The rank of the client. + local_analyzer: An instance of the local analyzer. + args (object): An object containing trainer configuration parameters. + + Methods: + local_analyze(round_idx): + Perform local analysis for a given training round. + set_server_data(server_data): + Set server data for the local analyzer. + set_init_msg(init_msg): + Set initialization message for the local analyzer. + set_client_submission(client_submission): + Set client submission for the local analyzer. + update_dataset(client_index=None): + Update the dataset for the local analyzer. + + """ def __init__( self, args, @@ -13,6 +43,23 @@ def __init__( train_data_local_dict, local_analyzer, ): + """ + Initialize the TrainerDistAdapter. + + Args: + args (object): An object containing trainer configuration parameters. + client_rank (int): The rank of the client. + train_data_num (int): The total number of training data samples. 
+ train_data_local_num_dict (dict): A dictionary of client-specific training data sizes. + train_data_local_dict (dict): A dictionary of client-specific training data. + local_analyzer: An instance of the local analyzer (optional). + + Note: + This constructor sets up the adapter and initializes it with the provided dataset and configuration. + + Returns: + None + """ if local_analyzer is None: local_analyzer = create_local_analyzer(args=args) @@ -42,6 +89,20 @@ def get_local_analyzer( args, local_analyzer, ): + """ + Get an instance of the local analyzer. + + Args: + client_index (int): The index of the client. + train_data_local_dict (dict): A dictionary of client-specific training data. + train_data_local_num_dict (dict): A dictionary of client-specific training data sizes. + train_data_num (int): The total number of training data samples. + args (object): An object containing trainer configuration parameters. + local_analyzer: An instance of the local analyzer (optional). + + Returns: + FALocalAnalyzer: An instance of the local analyzer. + """ return FALocalAnalyzer( client_index, train_data_local_dict, @@ -52,18 +113,63 @@ def get_local_analyzer( ) def local_analyze(self, round_idx): + """ + Perform local analysis for a given training round. + + Args: + round_idx (int): The index of the training round. + + Returns: + tuple: A tuple containing client submission and local sample count. + """ client_submission, local_sample_num = self.local_analyzer.local_analyze(round_idx) return client_submission, local_sample_num def set_server_data(self, server_data): + """ + Set server data for the local analyzer. + + Args: + server_data: Data received from the server. + + Returns: + None + """ self.local_analyzer.set_server_data(server_data) def set_init_msg(self, init_msg): + """ + Set initialization message for the local analyzer. + + Args: + init_msg: Initialization message received from the server. + + Returns: + None + """ self.local_analyzer.set_init_msg(init_msg) def set_client_submission(self, client_submission): + """ + Set client submission for the local analyzer. + + Args: + client_submission: Client's training submission. + + Returns: + None + """ self.local_analyzer.set_client_submission(client_submission) def update_dataset(self, client_index=None): + """ + Update the dataset for the local analyzer. + + Args: + client_index (int): The index of the client (optional). + + Returns: + None + """ _client_index = client_index or self.client_index - self.local_analyzer.update_dataset(int(_client_index)) \ No newline at end of file + self.local_analyzer.update_dataset(int(_client_index)) diff --git a/python/fedml/fa/cross_silo/client/process_group_manager.py b/python/fedml/fa/cross_silo/client/process_group_manager.py index 92519c6cc4..ff5970b89f 100644 --- a/python/fedml/fa/cross_silo/client/process_group_manager.py +++ b/python/fedml/fa/cross_silo/client/process_group_manager.py @@ -1,12 +1,46 @@ import logging import os - import torch import torch.distributed as dist - class ProcessGroupManager: + """ + Manages the initialization and cleanup of process groups for distributed training. + + Args: + rank (int): The rank of the current process. + world_size (int): The total number of processes in the group. + master_address (str): The address of the master process. + master_port (int): The port for communication with the master process. + only_gpu (bool): Whether to use NCCL backend if GPUs are available, otherwise use GLOO. 
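
A minimal sketch of the setup this manager performs, assuming the conventional MASTER_ADDR/MASTER_PORT environment variables; the real constructor may differ in details:

    import os
    import torch
    import torch.distributed as dist

    def init_messaging_pg(rank, world_size, master_address, master_port, only_gpu):
        os.environ["MASTER_ADDR"] = master_address
        os.environ["MASTER_PORT"] = str(master_port)
        # NCCL only when GPUs are requested and available, otherwise GLOO.
        backend = dist.Backend.NCCL if (only_gpu and torch.cuda.is_available()) else dist.Backend.GLOO
        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
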
+ + Attributes: + messaging_pg (dist.ProcessGroup): The initialized process group for messaging. + + Methods: + cleanup(): + Cleanup and destroy the process group. + get_process_group(): + Get the initialized process group. + + """ def __init__(self, rank, world_size, master_address, master_port, only_gpu): + """ + Initialize the ProcessGroupManager. + + Args: + rank (int): The rank of the current process. + world_size (int): The total number of processes in the group. + master_address (str): The address of the master process. + master_port (int): The port for communication with the master process. + only_gpu (bool): Whether to use NCCL backend if GPUs are available, otherwise use GLOO. + + Note: + This constructor sets up the process group and environment variables. + + Returns: + None + """ logging.info("Start process group") logging.info( "rank: %d, world_size: %d, master_address: %s, master_port: %s" @@ -31,7 +65,19 @@ def __init__(self, rank, world_size, master_address, master_port, only_gpu): logging.info("Initiated") def cleanup(self): + """ + Cleanup and destroy the process group. + + Returns: + None + """ dist.destroy_process_group() def get_process_group(self): + """ + Get the initialized process group. + + Returns: + dist.ProcessGroup: The initialized process group. + """ return self.messaging_pg diff --git a/python/fedml/fa/cross_silo/fa_client.py b/python/fedml/fa/cross_silo/fa_client.py index 2971c08193..2f43acb0c3 100644 --- a/python/fedml/fa/cross_silo/fa_client.py +++ b/python/fedml/fa/cross_silo/fa_client.py @@ -3,7 +3,39 @@ class FACrossSiloClient: + """ + Federated Learning Client for Cross-Silo Federated Learning. + + Args: + args (object): An object containing client configuration parameters. + dataset (tuple): A tuple containing dataset information, including size and partitions. + client_analyzer (FAClientAnalyzer): An instance of the client analyzer (optional). + + Attributes: + args (object): An object containing client configuration parameters. + dataset (tuple): A tuple containing dataset information, including size and partitions. + client_analyzer (FAClientAnalyzer): An instance of the client analyzer. + + Methods: + run(): + Start the Cross-Silo Federated Learning client. + + """ def __init__(self, args, dataset, client_analyzer: FAClientAnalyzer = None): + """ + Initialize the Cross-Silo Federated Learning client. + + Args: + args (object): An object containing client configuration parameters. + dataset (tuple): A tuple containing dataset information, including size and partitions. + client_analyzer (FAClientAnalyzer): An instance of the client analyzer (optional). + + Note: + This constructor sets up the client and initializes it with the provided dataset and configuration. + + Returns: + None + """ [ train_data_num, train_data_local_num_dict, @@ -21,4 +53,10 @@ def __init__(self, args, dataset, client_analyzer: FAClientAnalyzer = None): ) def run(self): + """ + Start the Cross-Silo Federated Learning client. + + Returns: + None + """ pass diff --git a/python/fedml/fa/cross_silo/fa_server.py b/python/fedml/fa/cross_silo/fa_server.py index a3242dde26..ab1e8f119b 100644 --- a/python/fedml/fa/cross_silo/fa_server.py +++ b/python/fedml/fa/cross_silo/fa_server.py @@ -3,7 +3,39 @@ class FACrossSiloServer: + """ + Federated Learning Server for Cross-Silo Federated Learning. + + Args: + args (object): An object containing server configuration parameters. + dataset (tuple): A tuple containing dataset information, including size and partitions. 
+ server_aggregator (FAServerAggregator): An instance of the server aggregator (optional). + + Attributes: + args (object): An object containing server configuration parameters. + dataset (tuple): A tuple containing dataset information, including size and partitions. + server_aggregator (FAServerAggregator): An instance of the server aggregator. + + Methods: + run(): + Start the Cross-Silo Federated Learning server. + + """ def __init__(self, args, dataset, server_aggregator: FAServerAggregator = None): + """ + Initialize the Cross-Silo Federated Learning server. + + Args: + args (object): An object containing server configuration parameters. + dataset (tuple): A tuple containing dataset information, including size and partitions. + server_aggregator (FAServerAggregator): An instance of the server aggregator (optional). + + Note: + This constructor sets up the server and initializes it with the provided dataset and configuration. + + Returns: + None + """ [ train_data_num, train_data_local_num_dict, @@ -21,4 +53,10 @@ def __init__(self, args, dataset, server_aggregator: FAServerAggregator = None): ) def run(self): + """ + Start the Cross-Silo Federated Learning server. + + Returns: + None + """ pass diff --git a/python/fedml/fa/cross_silo/server/fedml_aggregator.py b/python/fedml/fa/cross_silo/server/fedml_aggregator.py index 63d7331214..169f40608e 100644 --- a/python/fedml/fa/cross_silo/server/fedml_aggregator.py +++ b/python/fedml/fa/cross_silo/server/fedml_aggregator.py @@ -5,6 +5,41 @@ class FAAggregator(object): + """ + The FAAggregator class handles the aggregation of local models and sample numbers from clients. + + Args: + all_train_data_num (int): The total number of training data samples. + train_data_local_dict (dict): A dictionary containing the local training data for each client. + train_data_local_num_dict (dict): A dictionary containing the number of local training data samples for each client. + client_num (int): The number of clients. + args: Additional arguments. + server_aggregator: The server aggregator responsible for aggregation. + + Attributes: + aggregator: The server aggregator responsible for aggregation. + args: Additional arguments. + all_train_data_num (int): The total number of training data samples. + train_data_local_dict (dict): A dictionary containing the local training data for each client. + train_data_local_num_dict (dict): A dictionary containing the number of local training data samples for each client. + client_num (int): The number of clients. + model_dict (dict): A dictionary containing the model parameters from each client. + sample_num_dict (dict): A dictionary containing the number of samples from each client. + flag_client_model_uploaded_dict (dict): A dictionary tracking whether each client has uploaded its model. + + Methods: + get_init_msg(): Get the initialization message from the server aggregator. + set_init_msg(init_msg): Set the initialization message in the server aggregator. + get_server_data(): Get the server data from the server aggregator. + set_server_data(server_data): Set the server data in the server aggregator. + add_local_trained_result(index, model_params, sample_num): Add local model parameters and sample numbers from a client. + check_whether_all_receive(): Check if all clients have uploaded their models. + aggregate(): Aggregate local models and calculate the global result. + data_silo_selection(round_idx, client_num_in_total, client_num_per_round): Select data silos for clients in a round. 
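
Taken together, the FAAggregator methods listed here imply the following server-side round flow; `client_uploads` is a hypothetical list of (params, sample_num) pairs, one per client:

    def server_round(fa_aggregator, client_uploads):
        for idx, (params, sample_num) in enumerate(client_uploads):
            fa_aggregator.add_local_trained_result(idx, params, sample_num)
        if fa_aggregator.check_whether_all_receive():
            # Per the docstring below, aggregate() returns
            # (global_result, local_result_list).
            return fa_aggregator.aggregate()
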
+ client_selection(round_idx, client_id_list_in_total, client_num_per_round): Select clients for a round. + client_sampling(round_idx, client_num_in_total, client_num_per_round): Sample clients for a round. + """ + def __init__( self, all_train_data_num, @@ -27,24 +62,71 @@ def __init__( self.flag_client_model_uploaded_dict[idx] = False def get_init_msg(self): + """ + Get the initialization message from the server aggregator. + + Returns: + Any: The initialization message. + """ return self.aggregator.get_init_msg() def set_init_msg(self, init_msg): + """ + Set the initialization message in the server aggregator. + + Args: + init_msg: The initialization message to set. + + Returns: + None + """ self.aggregator.set_init_msg(init_msg) def get_server_data(self): + """ + Get the server data from the server aggregator. + + Returns: + Any: The server data. + """ return self.aggregator.get_server_data() def set_server_data(self, server_data): + """ + Set the server data in the server aggregator. + + Args: + server_data: The server data to set. + + Returns: + None + """ self.aggregator.set_server_data(server_data) def add_local_trained_result(self, index, model_params, sample_num): + """ + Add local model parameters and sample numbers from a client. + + Args: + index (int): The index of the client. + model_params: The local model parameters. + sample_num (int): The number of samples used for training. + + Returns: + None + """ logging.info("add_model. index = %d" % index) self.model_dict[index] = model_params self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all clients have uploaded their models. + + Returns: + bool: True if all clients have uploaded their models, False otherwise. + """ logging.debug("client_num = {}".format(self.client_num)) for idx in range(self.client_num): if not self.flag_client_model_uploaded_dict[idx]: @@ -54,6 +136,12 @@ def check_whether_all_receive(self): return True def aggregate(self): + """ + Aggregate local models and calculate the global result. + + Returns: + tuple: A tuple containing the global result and a list of local results. + """ start_time = time.time() local_result_list = [] @@ -70,16 +158,15 @@ def aggregate(self): def data_silo_selection(self, round_idx, client_num_in_total, client_num_per_round): """ + Select data silos for clients in a round. + Args: - round_idx: round index, starting from 0 - client_num_in_total: this is equal to the users in a synthetic data, - e.g., in synthetic_1_1, this value is 30 - client_num_per_round: the number of edge devices that can train + round_idx (int): The round index, starting from 0. + client_num_in_total (int): The total number of clients. + client_num_per_round (int): The number of clients that can train in a round. Returns: - data_silo_index_list: e.g., when client_num_in_total = 30, client_num_in_total = 3, - this value is the form of [0, 11, 20] - + list: A list of data silo indexes. 
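
Because selection reseeds NumPy with the round index, a given round always picks the same silos; a self-contained illustration of that determinism:

    import numpy as np

    np.random.seed(5)  # round_idx = 5
    first = np.random.choice(range(30), 3, replace=False)
    np.random.seed(5)  # the same round index ...
    second = np.random.choice(range(30), 3, replace=False)
    assert (first == second).all()  # ... yields the identical selection
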
""" logging.info( "client_num_in_total = %d, client_num_per_round = %d" % (client_num_in_total, client_num_per_round) @@ -89,37 +176,46 @@ def data_silo_selection(self, round_idx, client_num_in_total, client_num_per_rou if client_num_in_total == client_num_per_round: return [i for i in range(client_num_per_round)] else: - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) # Make sure for each comparison, we are selecting the same clients each round data_silo_index_list = np.random.choice(range(client_num_in_total), client_num_per_round, replace=False) return data_silo_index_list def client_selection(self, round_idx, client_id_list_in_total, client_num_per_round): """ + Select clients for a round. + Args: - round_idx: round index, starting from 0 - client_id_list_in_total: this is the real edge IDs. - In MLOps, its element is real edge ID, e.g., [64, 65, 66, 67]; - in simulated mode, its element is client index starting from 1, e.g., [1, 2, 3, 4] - client_num_per_round: + round_idx (int): The round index, starting from 0. + client_id_list_in_total (list): A list of real edge IDs or client indices. + client_num_per_round (int): The number of clients to select. Returns: - client_id_list_in_this_round: sampled real edge ID list, e.g., [64, 66] + list: A list of selected client IDs. """ if client_num_per_round == len(client_id_list_in_total): return client_id_list_in_total - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) # Make sure for each comparison, we are selecting the same clients each round client_id_list_in_this_round = np.random.choice(client_id_list_in_total, client_num_per_round, replace=False) return client_id_list_in_this_round def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Sample clients for a round. + + Args: + round_idx (int): The round index, starting from 0. + client_num_in_total (int): The total number of clients. + client_num_per_round (int): The number of clients to sample. + + Returns: + list: A list of sampled client indices. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: num_clients = min(client_num_per_round, client_num_in_total) - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) # Make sure for each comparison, we are selecting the same clients each round client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False) logging.info("client_indexes = %s" % str(client_indexes)) return client_indexes - - - + \ No newline at end of file diff --git a/python/fedml/fa/cross_silo/server/fedml_server_manager.py b/python/fedml/fa/cross_silo/server/fedml_server_manager.py index e0182bdab9..89baf3ed35 100644 --- a/python/fedml/fa/cross_silo/server/fedml_server_manager.py +++ b/python/fedml/fa/cross_silo/server/fedml_server_manager.py @@ -9,9 +9,84 @@ class FedMLServerManager(FedMLCommManager): + """ + Federated Learning Server Manager for Cross-Silo Federated Learning. + + Args: + args (object): An object containing server configuration parameters. + aggregator (FAAggregator): An instance of the server aggregator. + comm: The communication object. + client_rank (int): The rank of the client. + client_num (int): The total number of clients. 
+ backend (str): The backend for communication (e.g., "MQTT_S3"). + + Attributes: + args (object): An object containing server configuration parameters. + aggregator (FAAggregator): An instance of the server aggregator. + round_num (int): The number of communication rounds. + client_online_mapping (dict): A dictionary mapping client IDs to their online status. + client_real_ids (list): A list of real client IDs. + is_initialized (bool): A flag indicating whether the server is initialized. + client_id_list_in_this_round (list): A list of client IDs for the current round. + data_silo_index_list (list): A list of data silo indices for clients in the current round. + + Methods: + run(): + Start the Federated Learning server. + + send_init_msg(): + Send initialization messages to clients. + + register_message_receive_handlers(): + Register handlers for receiving messages. + + handle_message_connection_ready(msg_params): + Handle the connection ready message from clients. + + handle_message_client_status_update(msg_params): + Handle client status updates. + + handle_message_receive_model_from_client(msg_params): + Handle received models from clients. + + cleanup(): + Perform cleanup operations after completing a round of communication. + + send_message_init_config(receive_id, global_model_params, datasilo_index, + global_model_url=None, global_model_key=None): + Send initialization configuration messages to clients. + + send_message_check_client_status(receive_id, datasilo_index): + Send client status check messages to clients. + + send_message_finish(receive_id, datasilo_index): + Send finish messages to clients. + + send_message_sync_model_to_client(receive_id, global_model_params, client_index, + global_model_url=None, global_model_key=None): + Send synchronized model messages to clients. + + """ def __init__( self, args, aggregator, comm=None, client_rank=0, client_num=0, backend="MQTT_S3", ): + """ + Initialize the Federated Learning Server Manager. + + Args: + args (object): An object containing server configuration parameters. + aggregator (FAAggregator): An instance of the server aggregator. + comm: The communication object. + client_rank (int): The rank of the client. + client_num (int): The total number of clients. + backend (str): The backend for communication (e.g., "MQTT_S3"). + + Note: + This constructor sets up the server manager with the provided configuration and aggregator. + + Returns: + None + """ super().__init__(args, comm, client_rank, client_num, backend) self.args = args self.aggregator = aggregator @@ -24,9 +99,21 @@ def __init__( self.data_silo_index_list = None def run(self): + """ + Start the Federated Learning server. + + Returns: + None + """ super().run() def send_init_msg(self): + """ + Send initialization messages to clients. + + Returns: + None + """ global_result = self.aggregator.get_server_data() global_result_url = None @@ -43,6 +130,12 @@ def send_init_msg(self): mlops.event("server.wait", event_started=True, event_value=str(self.args.round_idx)) def register_message_receive_handlers(self): + """ + Register handlers for receiving messages. + + Returns: + None + """ logging.info("register_message_receive_handlers------") self.register_message_receive_handler( MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_message_connection_ready @@ -57,6 +150,15 @@ def register_message_receive_handlers(self): ) def handle_message_connection_ready(self, msg_params): + """ + Handle the connection ready message from clients. 
+ + Args: + msg_params (dict): Message parameters. + + Returns: + None + """ if not self.is_initialized: self.client_id_list_in_this_round = self.aggregator.client_selection( self.args.round_idx, self.client_real_ids, self.args.client_num_per_round @@ -78,6 +180,15 @@ def handle_message_connection_ready(self, msg_params): client_idx_this_round += 1 def handle_message_client_status_update(self, msg_params): + """ + Handle client status updates. + + Args: + msg_params (dict): Message parameters. + + Returns: + None + """ client_status = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_STATUS) if client_status == "ONLINE": self.client_online_mapping[str(msg_params.get_sender_id())] = True @@ -100,6 +211,15 @@ def handle_message_client_status_update(self, msg_params): self.is_initialized = True def handle_message_receive_model_from_client(self, msg_params): + """ + Handle received models from clients. + + Args: + msg_params (dict): Message parameters. + + Returns: + None + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) mlops.event("comm_c2s", event_started=False, event_value=str(self.args.round_idx), event_edge_id=sender_id) local_results = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) @@ -157,6 +277,12 @@ def handle_message_receive_model_from_client(self, msg_params): self.aggregator.set_init_msg(init_msg=None) def cleanup(self): + """ + Perform cleanup operations after completing a round of communication. + + Returns: + None + """ client_idx_in_this_round = 0 for client_id in self.client_id_list_in_this_round: self.send_message_finish( @@ -169,6 +295,19 @@ def cleanup(self): def send_message_init_config(self, receive_id, global_model_params, datasilo_index, global_model_url=None, global_model_key=None): + """ + Send initialization configuration messages to clients. + + Args: + receive_id: The ID of the receiving client. + global_model_params: The global model parameters. + datasilo_index: The data silo index. + global_model_url (str, optional): The URL of global model parameters. Defaults to None. + global_model_key (str, optional): The key of global model parameters. Defaults to None. + + Returns: + tuple: A tuple containing the updated global_model_url and global_model_key. + """ tick = time.time() message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id) if global_model_url is not None: @@ -187,11 +326,31 @@ def send_message_init_config(self, receive_id, global_model_params, datasilo_ind return global_model_url, global_model_key def send_message_check_client_status(self, receive_id, datasilo_index): + """ + Send client status check messages to clients. + + Args: + receive_id: The ID of the receiving client. + datasilo_index: The data silo index. + + Returns: + None + """ message = Message(MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, self.get_sender_id(), receive_id) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) def send_message_finish(self, receive_id, datasilo_index): + """ + Send finish messages to clients. + + Args: + receive_id: The ID of the receiving client. + datasilo_index: The data silo index. 
+ + Returns: + None + """ message = Message(MyMessage.MSG_TYPE_S2C_FINISH, self.get_sender_id(), receive_id) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) @@ -201,6 +360,19 @@ def send_message_finish(self, receive_id, datasilo_index): def send_message_sync_model_to_client(self, receive_id, global_model_params, client_index, global_model_url=None, global_model_key=None): + """ + Send synchronized model messages to clients. + + Args: + receive_id: The ID of the receiving client. + global_model_params: The global model parameters. + client_index: The client index. + global_model_url (str, optional): The URL of global model parameters. Defaults to None. + global_model_key (str, optional): The key of global model parameters. Defaults to None. + + Returns: + tuple: A tuple containing the updated global_model_url and global_model_key. + """ tick = time.time() logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id) message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id, ) diff --git a/python/fedml/fa/cross_silo/server/server_initializer.py b/python/fedml/fa/cross_silo/server/server_initializer.py index cff55ecc26..c296bf1744 100644 --- a/python/fedml/fa/cross_silo/server/server_initializer.py +++ b/python/fedml/fa/cross_silo/server/server_initializer.py @@ -13,6 +13,22 @@ def init_server( train_data_local_num_dict, server_aggregator, ): + """ + Initialize the Federated Learning server for Cross-Silo Federated Learning. + + Args: + args (object): An object containing server configuration parameters. + comm: The communication object. + rank (int): The rank of the server. + worker_num (int): The total number of workers. + train_data_num (int): The total number of training data samples. + train_data_local_dict (dict): A dictionary of client-specific training data. + train_data_local_num_dict (dict): A dictionary of client-specific training data sizes. + server_aggregator: An instance of the server aggregator (optional). + + Returns: + None + """ if server_aggregator is None: server_aggregator = create_global_analyzer(args, train_data_num=train_data_num) server_aggregator.set_id(0) diff --git a/python/fedml/fa/data/data_loader.py b/python/fedml/fa/data/data_loader.py index a29877cc98..2a6e820da0 100644 --- a/python/fedml/fa/data/data_loader.py +++ b/python/fedml/fa/data/data_loader.py @@ -11,12 +11,31 @@ def fa_load_data(args): + """ + Load synthetic data based on the specified dataset. + + Args: + args (argparse.Namespace): Command-line arguments. + + Returns: + list: A list containing dataset information. + """ return load_synthetic_data(args) def load_synthetic_data(args): + """ + Load synthetic data based on the specified dataset name. + + Args: + args (argparse.Namespace): Command-line arguments. + + Returns: + list: A list containing dataset information. 
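
A usage sketch for the loader documented above; the Namespace fields shown (including client_num_in_total) are inferred from the branches below and may not be exhaustive:

    from argparse import Namespace
    from fedml.fa.data.data_loader import fa_load_data

    args = Namespace(dataset="fake", data_cache_dir="./data", client_num_in_total=4)
    datasize, train_data_local_num_dict, local_data_dict = fa_load_data(args)
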
+ """ dataset_name = args.dataset if dataset_name == "fake": + # Load fake numeric data data_cache_dir = os.path.join(args.data_cache_dir, "fake_numeric_data") if not os.path.exists(data_cache_dir): os.makedirs(data_cache_dir, exist_ok=True) @@ -33,7 +52,7 @@ def load_synthetic_data(args): train_data_local_num_dict, local_data_dict, ] - # print(f"datasize, train_data_local_num_dict, local_data_dict,{dataset}") + elif dataset_name == "twitter": path = os.path.join(args.data_cache_dir, "twitter_Sentiment140") download_twitter_Sentiment140(data_cache_dir=path) @@ -70,7 +89,7 @@ def load_synthetic_data(args): if hasattr(args, "seperator"): separator = args.seperator else: - separator = "," # default seperator = "," + separator = "," # default separator = "," ( datasize, train_data_local_num_dict, @@ -110,6 +129,7 @@ def load_synthetic_data_test(): load_synthetic_data(args=args) + if __name__ == '__main__': # read_data(train_data_dir="fake_data") # download_twitter_Sentiment140("data") diff --git a/python/fedml/fa/data/fake_numeric_data/data_loader.py b/python/fedml/fa/data/fake_numeric_data/data_loader.py index a2e461b097..f615948204 100644 --- a/python/fedml/fa/data/fake_numeric_data/data_loader.py +++ b/python/fedml/fa/data/fake_numeric_data/data_loader.py @@ -5,16 +5,36 @@ def generate_fake_data(data_cache_dir): - file_path = data_cache_dir + "/fake_numeric_data.txt" + """ + Generate fake numeric data and save it to a text file in the specified directory. + + Args: + data_cache_dir (str): The directory where the fake numeric data file should be saved. + + Note: + This function generates random integer data and writes it to a text file. + + Returns: + None + """ + file_path = os.path.join(data_cache_dir, "fake_numeric_data.txt") if not os.path.exists(file_path): - f = open(file_path, "a") - for i in range(10000): - f.write(f"{random.randint(1, 100)}\n") - f.close() + with open(file_path, "a") as f: + for i in range(10000): + f.write(f"{random.randint(1, 100)}\n") def load_partition_data_fake(data_dir, client_num): + """ + Load and partition fake data from a specified directory into client-specific partitions. + + Args: + data_dir (str): The directory path where the fake data is located. + client_num (int): The total number of clients to partition the data for. + + Returns: + tuple: A tuple containing the dataset size, a dictionary of client data sizes, and a dictionary of client data. + """ dataset = read_data(data_dir=data_dir) return equally_partition_a_dataset(client_num, dataset) - diff --git a/python/fedml/fa/data/self_defined_data/data_loader.py b/python/fedml/fa/data/self_defined_data/data_loader.py index 9a249ca0b3..333a53ad40 100644 --- a/python/fedml/fa/data/self_defined_data/data_loader.py +++ b/python/fedml/fa/data/self_defined_data/data_loader.py @@ -6,6 +6,18 @@ def generate_fake_data(data_cache_dir): + """ + Generate fake numeric data and save it to a text file in the specified directory. + + Args: + data_cache_dir (str): The directory where the fake numeric data file should be saved. + + Note: + This function generates random integer data and writes it to a text file. + + Returns: + None + """ file_path = data_cache_dir + "/fake_numeric_data.txt" if not os.path.exists(file_path): @@ -16,9 +28,23 @@ def generate_fake_data(data_cache_dir): def load_partition_self_defined_data(file_folder_path, client_num, data_col_idx, separator=","): + """ + Load and partition self-defined data from a text file into client-specific partitions. 
+
+    Args:
+        file_folder_path (str): The path to the folder containing the text data files.
+        client_num (int): The total number of clients to partition the data for.
+        data_col_idx (int): The column index of the data to be used.
+        separator (str): The separator used in the data file (default is comma ',').
+
+    Raises:
+        Exception: If the specified data folder does not exist.
+
+    Returns:
+        tuple: A tuple containing the dataset size, a dictionary of client data sizes, and a dictionary of client data.
+    """
     if not os.path.exists(file_folder_path):
         raise Exception(f"No data file: {file_folder_path}")
     logging.info(f"file_folder_path = {file_folder_path}")
-    dataset = read_data_with_column_idx(file_folder_path=file_folder_path, column_idx=data_col_idx, seperator=separator)
+    dataset = read_data_with_column_idx(file_folder_path=file_folder_path, column_idx=data_col_idx, separator=separator)
     return equally_partition_a_dataset(client_num, dataset)
-
diff --git a/python/fedml/fa/data/twitter_Sentiment140/data_loader.py b/python/fedml/fa/data/twitter_Sentiment140/data_loader.py
index c99d9b8d77..007c79b524 100644
--- a/python/fedml/fa/data/twitter_Sentiment140/data_loader.py
+++ b/python/fedml/fa/data/twitter_Sentiment140/data_loader.py
@@ -9,6 +9,18 @@
 
 
 def download_twitter_Sentiment140(data_cache_dir):
+    """
+    Download the Sentiment140 Twitter dataset if it doesn't exist in the specified directory.
+
+    Args:
+        data_cache_dir (str): The directory where the dataset should be downloaded.
+
+    Note:
+        This function downloads the dataset from a URL and extracts it to the specified directory.
+
+    Returns:
+        None
+    """
     if not os.path.exists(data_cache_dir):
         os.makedirs(data_cache_dir, exist_ok=True)
         file_path = os.path.join(data_cache_dir, "trainingandtestdata.zip")
@@ -21,10 +33,30 @@
 
 
 def load_partition_data_twitter_sentiment140(dataset, client_num_in_total):
+    """
+    Load and partition the Sentiment140 Twitter dataset into client-specific partitions.
+
+    Args:
+        dataset (dict): A dictionary containing client usernames as keys and their data as values.
+        client_num_in_total (int): The total number of clients to partition the data for.
+
+    Returns:
+        tuple: A tuple containing the dataset size, a dictionary of client data sizes, and a dictionary of client data.
+    """
     return equally_partition_a_dataset_according_to_users(client_num_in_total, dataset)
 
 
 def load_partition_data_twitter_sentiment140_heavy_hitter(dataset, client_num_in_total):
+    """
+    Load and partition the Sentiment140 Twitter dataset for heavy hitters into client-specific partitions.
+
+    Args:
+        dataset (dict): A dictionary containing client usernames as keys and their data as values.
+        client_num_in_total (int): The total number of clients to partition the data for.
+
+    Returns:
+        tuple: A tuple containing the dataset size, a dictionary of client data sizes, and a dictionary of client data.
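
A sketch of calling the self-defined loader above; the folder path and column index are illustrative:

    from fedml.fa.data.self_defined_data.data_loader import load_partition_self_defined_data

    datasize, train_data_local_num_dict, local_data_dict = load_partition_self_defined_data(
        file_folder_path="./my_fa_data",  # a folder of plain-text data files
        client_num=4,
        data_col_idx=0,
        separator=",",
    )
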
+ """ local_data_dict = dict() train_data_local_num_dict = dict() heavy_hitters = list(dataset.values()) @@ -41,4 +73,4 @@ def load_partition_data_twitter_sentiment140_heavy_hitter(dataset, client_num_in datasize, train_data_local_num_dict, local_data_dict, - ) \ No newline at end of file + ) diff --git a/python/fedml/fa/data/twitter_Sentiment140/twitter_data_processing.py b/python/fedml/fa/data/twitter_Sentiment140/twitter_data_processing.py index 8c048b555d..4fc1bb6566 100644 --- a/python/fedml/fa/data/twitter_Sentiment140/twitter_data_processing.py +++ b/python/fedml/fa/data/twitter_Sentiment140/twitter_data_processing.py @@ -10,9 +10,16 @@ def is_valid(word): - if len(word) < 3 or (word[-1] in [ - '?', '!', '.', ';', ',' - ]) or word.startswith('http') or word.startswith('www'): + """ + Check if a word is valid for processing. + + Args: + word (str): The word to check. + + Returns: + bool: True if the word is valid, False otherwise. + """ + if len(word) < 3 or (word[-1] in ['?', '!', '.', ';', ',']) or word.startswith('http') or word.startswith('www'): return False if re.match(r'^[a-z_\@\#\-\;\(\)\*\:\.\'\/]+$', word): return True @@ -20,6 +27,16 @@ def is_valid(word): def truncate_or_extend(word, max_word_len): + """ + Truncate or extend a word to a specified length. + + Args: + word (str): The word to modify. + max_word_len (int): The desired maximum length of the word. + + Returns: + str: The modified word. + """ if len(word) > max_word_len: word = word[:max_word_len] else: @@ -28,10 +45,26 @@ def truncate_or_extend(word, max_word_len): def add_end_symbol(word): + """ + Add an end symbol ('$') to a word. + + Args: + word (str): The word to modify. + + Returns: + str: The modified word with an end symbol. + """ return word + '$' def generate_triehh_clients(clients, path): + """ + Generate TrieHH clients from a list of clients and save them to a file. + + Args: + clients (list): List of client names. + path (str): The directory path to save the file. + """ clients_num = len(clients) triehh_clients = [add_end_symbol(clients[i]) for i in range(clients_num)] word_freq = collections.defaultdict(lambda: 0) @@ -44,6 +77,15 @@ def generate_triehh_clients(clients, path): def preprocess_twitter_data(path): + """ + Preprocess Twitter data from a CSV file. + + Args: + path (str): The directory path where the CSV file is located. + + Returns: + dict: A dictionary containing client usernames as keys and lists of preprocessed words as values. + """ filename = os.path.join(path, 'training.1600000.processed.noemoticon.csv') dataset = {} with open(filename, encoding='ISO-8859-1') as csv_file: @@ -66,6 +108,15 @@ def preprocess_twitter_data(path): def preprocess_twitter_data_heavy_hitter(path): + """ + Preprocess Twitter data and identify heavy hitters (most frequent words) for each client. + + Args: + path (str): The directory path where the CSV file is located. + + Returns: + dict: A dictionary containing client usernames as keys and their identified heavy hitter words as values. + """ # load dataset from csv file filename = os.path.join(path, 'training.1600000.processed.noemoticon.csv') clients = {} diff --git a/python/fedml/fa/data/utils.py b/python/fedml/fa/data/utils.py index ada768dca4..a41b2b6867 100644 --- a/python/fedml/fa/data/utils.py +++ b/python/fedml/fa/data/utils.py @@ -3,6 +3,17 @@ def equally_partition_a_dataset(client_num_in_total, dataset): + """ + Equally partition a dataset among clients. + + Args: + client_num_in_total (int): The total number of clients. 
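
A quick behavioral sketch of equally_partition_a_dataset below, for the evenly divisible case (the remainder policy is not visible in this hunk):

    from fedml.fa.data.utils import equally_partition_a_dataset

    datasize, counts, shards = equally_partition_a_dataset(2, list(range(10)))
    # datasize == 10, and each of the two clients receives 10 // 2 == 5 items.
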
+ dataset (list): The dataset to partition. + + Returns: + tuple: A tuple containing the total dataset size, a dictionary of local data counts per client, + and a dictionary of local data for each client. + """ client_data_num = int(len(dataset) / client_num_in_total) local_data_dict = dict() train_data_local_num_dict = dict() @@ -20,6 +31,17 @@ def equally_partition_a_dataset(client_num_in_total, dataset): def equally_partition_a_dataset_according_to_users(client_num_in_total, dataset): + """ + Equally partition a dataset among clients based on the number of users. + + Args: + client_num_in_total (int): The total number of clients. + dataset (dict): The dataset organized by user IDs. + + Returns: + tuple: A tuple containing the total dataset size, a dictionary of local data counts per client, + and a dictionary of local data for each client. + """ user_num_for_one_client = int(math.ceil(len(dataset) / client_num_in_total)) local_data_dict = dict() train_data_local_num_dict = dict() @@ -45,6 +67,15 @@ def equally_partition_a_dataset_according_to_users(client_num_in_total, dataset) def read_data(data_dir): + """ + Read data from text files in a directory. + + Args: + data_dir (str): The path to the directory containing text data files. + + Returns: + list: A list of integers representing the dataset. + """ train_files = os.listdir(data_dir) train_files = [f for f in train_files if f.endswith(".txt")] dataset = [] @@ -56,7 +87,18 @@ def read_data(data_dir): return dataset -def read_data_with_column_idx(file_folder_path, column_idx, seperator=","): +def read_data_with_column_idx(file_folder_path, column_idx, separator=","): + """ + Read data from text files in a directory, selecting a specific column. + + Args: + file_folder_path (str): The path to the directory containing text data files. + column_idx (int): The index of the column to extract. + separator (str, optional): The separator used in the text files (default is comma). + + Returns: + list: A list of values from the selected column. + """ train_files = os.listdir(file_folder_path) train_files = [f for f in train_files if not f.startswith(".")] dataset = [] @@ -64,6 +106,6 @@ def read_data_with_column_idx(file_folder_path, column_idx, seperator=","): file_path = os.path.join(file_folder_path, f) f2 = open(file_path, "r") for line in f2: - if len(line.split(seperator)[column_idx].strip()) > 0: - dataset.append(line.split(seperator)[column_idx].strip()) - return dataset \ No newline at end of file + if len(line.split(separator)[column_idx].strip()) > 0: + dataset.append(line.split(separator)[column_idx].strip()) + return dataset diff --git a/python/fedml/fa/local_analyzer/avg.py b/python/fedml/fa/local_analyzer/avg.py index 0e4761c66a..0866988d2e 100644 --- a/python/fedml/fa/local_analyzer/avg.py +++ b/python/fedml/fa/local_analyzer/avg.py @@ -1,10 +1,31 @@ from fedml.fa.base_frame.client_analyzer import FAClientAnalyzer - class AverageClientAnalyzer(FAClientAnalyzer): + """ + A client analyzer for calculating the average of values in the training data. + + Args: + None + + Methods: + local_analyze(train_data, args): + Analyze the training data to calculate the average of values and set the client submission. + + """ + def local_analyze(self, train_data, args): + """ + Analyze the training data to calculate the average of values and set the client submission. + + Args: + train_data (list): The training data containing values to analyze. + args: Additional arguments (not used in this method). 
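
A quick check of the averaging logic below; the inputs are chosen so the running sum is exact in binary floating point, and the base class is assumed to simply store args:

    from argparse import Namespace
    from fedml.fa.local_analyzer.avg import AverageClientAnalyzer

    analyzer = AverageClientAnalyzer(Namespace())  # assumption: the base class just stores args
    analyzer.local_analyze(["1", "3"], None)       # args is unused by this method
    assert analyzer.get_client_submission() == 2.0  # 1/2 + 3/2 is exact
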
+ + Returns: + None + """ sample_num = len(train_data) average = 0.0 for value in train_data: - average = average + float(value) / float(sample_num) + average += float(value) / float(sample_num) self.set_client_submission(average) diff --git a/python/fedml/fa/local_analyzer/client_analyzer_creator.py b/python/fedml/fa/local_analyzer/client_analyzer_creator.py index 64c5a154f9..b5694ef361 100644 --- a/python/fedml/fa/local_analyzer/client_analyzer_creator.py +++ b/python/fedml/fa/local_analyzer/client_analyzer_creator.py @@ -7,8 +7,16 @@ from fedml.fa.local_analyzer.k_percentage_element import KPercentileElementClientAnalyzer from fedml.fa.local_analyzer.union import UnionClientAnalyzer - def create_local_analyzer(args): + """ + Create a specific type of local analyzer based on the task type. + + Args: + args (object): Arguments for the local analyzer creation. + + Returns: + object: A local analyzer instance based on the specified task type. + """ task_type = args.fa_task if task_type == FA_TASK_AVG: return AverageClientAnalyzer(args) @@ -24,4 +32,3 @@ def create_local_analyzer(args): return FrequencyEstimationClientAnalyzer(args) if task_type == FA_TASK_HEAVY_HITTER_TRIEHH: return TrieHHClientAnalyzer(args) - diff --git a/python/fedml/fa/local_analyzer/frequency_estimation.py b/python/fedml/fa/local_analyzer/frequency_estimation.py index 4477f068dd..98ca1154ed 100644 --- a/python/fedml/fa/local_analyzer/frequency_estimation.py +++ b/python/fedml/fa/local_analyzer/frequency_estimation.py @@ -2,12 +2,40 @@ class FrequencyEstimationClientAnalyzer(FAClientAnalyzer): + """ + A client analyzer for estimating the frequency of values in the training data. + + Args: + client_id: The unique identifier of the client. + server: The federated learning server. + + Attributes: + client_id: The unique identifier of the client. + server: The federated learning server. + + Methods: + local_analyze(train_data, args): + Analyze the training data to estimate the frequency of values and set the client submission. + + """ + def local_analyze(self, train_data, args): + """ + Analyze the training data to estimate the frequency of values and set the client submission. + + Args: + train_data (list): The training data containing values to analyze. + args: Additional arguments (not used in this method). + + Returns: + None + """ counter_dict = dict() for value in train_data: if counter_dict.get(value) is None: counter_dict[value] = 1 else: - counter_dict[value] = counter_dict[value] + 1 - self.set_client_submission(counter_dict) \ No newline at end of file + counter_dict[value] += 1 + + self.set_client_submission(counter_dict) diff --git a/python/fedml/fa/local_analyzer/heavy_hitter_triehh.py b/python/fedml/fa/local_analyzer/heavy_hitter_triehh.py index 3933d6e9bb..fac7197baa 100644 --- a/python/fedml/fa/local_analyzer/heavy_hitter_triehh.py +++ b/python/fedml/fa/local_analyzer/heavy_hitter_triehh.py @@ -3,8 +3,39 @@ from collections import defaultdict from fedml.fa.base_frame.client_analyzer import FAClientAnalyzer - class TrieHHClientAnalyzer(FAClientAnalyzer): + """ + A client analyzer for Trie-HH federated learning. + + Args: + args: Additional arguments for configuration. + + Attributes: + round_counter (int): Counter to keep track of rounds. + batch_size (int): Size of the sample batch for analysis. + client_num_per_round (int): Number of clients per round. + + Methods: + __init__(self, args): + Initialize the TrieHHClientAnalyzer with provided arguments. 
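The incremental average above (each value divided by sample_num as it is accumulated) is mathematically equivalent to the usual sum-then-divide form, which performs only one division:

    average = sum(float(value) for value in train_data) / len(train_data)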
+ + set_init_msg(self, init_msg): + Set the initial message containing batch size. + + get_init_msg(self): + Get the initial message. + + local_analyze(self, train_data, args): + Analyze the local training data and set the client submission. + + client_vote(self, sample_local_dataset): + Perform voting based on local data and return the votes. + + one_word_vote(self, word): + Perform voting for a single word in the dataset. + + """ + def __init__(self, args): super().__init__(args=args) self.round_counter = 0 @@ -12,19 +43,53 @@ def __init__(self, args): self.client_num_per_round = args.client_num_per_round def set_init_msg(self, init_msg): + """ + Set the initial message containing batch size. + + Args: + init_msg: The initial message containing batch size. + + Returns: + None + """ self.init_msg = init_msg self.batch_size = self.init_msg def get_init_msg(self): + """ + Get the initial message. + + Returns: + int: The initial message containing batch size. + """ return self.init_msg def local_analyze(self, train_data, args): + """ + Analyze the training data and set the client submission. + + Args: + train_data (list): The training data for analysis. + args: Additional arguments (not used in this method). + + Returns: + None + """ idxs = np.random.choice(range(len(train_data)), self.batch_size, replace=False) sample_local_dataset = [train_data[i] for i in idxs] votes = self.client_vote(sample_local_dataset) self.set_client_submission(votes) def client_vote(self, sample_local_dataset): + """ + Perform voting based on local data and return the votes. + + Args: + sample_local_dataset (list): Sampled local dataset for voting. + + Returns: + dict: Dictionary containing votes. + """ votes = defaultdict(int) self.round_counter += 1 self.w_global = self.get_server_data() @@ -35,13 +100,21 @@ def client_vote(self, sample_local_dataset): return votes def one_word_vote(self, word): + """ + Perform voting for a single word in the dataset. + + Args: + word (str): A word from the dataset. + + Returns: + int: Voting result (1 if valid, 0 otherwise). + """ if len(word) < self.round_counter: return 0 pre = word[0:self.round_counter - 1] - # print(f"self.w_global={self.w_global}") - # print(f"pre = {pre}, type={type(self.w_global)}") + if self.w_global is None: return 1 if pre and (pre not in self.w_global): return 0 - return 1 \ No newline at end of file + return 1 diff --git a/python/fedml/fa/local_analyzer/intersection.py b/python/fedml/fa/local_analyzer/intersection.py index 76f5fe1b92..9102f31bd4 100644 --- a/python/fedml/fa/local_analyzer/intersection.py +++ b/python/fedml/fa/local_analyzer/intersection.py @@ -2,5 +2,27 @@ class IntersectionClientAnalyzer(FAClientAnalyzer): + """ + A client analyzer for finding the intersection of values in the training data. + + Args: + None + + Methods: + local_analyze(train_data, args): + Analyze the training data to find the intersection of values and set the client submission. + + """ + def local_analyze(self, train_data, args): - self.set_client_submission(list(set(train_data))) \ No newline at end of file + """ + Analyze the training data to find the intersection of values and set the client submission. + + Args: + train_data (list): The training data containing values to analyze. + args: Additional arguments (not used in this method). 
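Putting the TrieHH pieces together, one client round looks roughly like the sketch below. The per-word vote key is an assumption, since client_vote's loop body is elided in this hunk; only the sampling and one_word_vote logic appear verbatim above.

    import numpy as np
    from collections import defaultdict

    idxs = np.random.choice(range(len(train_data)), batch_size, replace=False)
    sample = [train_data[i] for i in idxs]
    votes = defaultdict(int)
    for word in sample:
        votes[word[:round_counter]] += one_word_vote(word)   # vote key is an assumption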
+ + Returns: + None + """ + self.set_client_submission(list(set(train_data))) diff --git a/python/fedml/fa/local_analyzer/k_percentage_element.py b/python/fedml/fa/local_analyzer/k_percentage_element.py index 1c075842a6..4ea7819580 100644 --- a/python/fedml/fa/local_analyzer/k_percentage_element.py +++ b/python/fedml/fa/local_analyzer/k_percentage_element.py @@ -1,10 +1,31 @@ from fedml.fa.base_frame.client_analyzer import FAClientAnalyzer - class KPercentileElementClientAnalyzer(FAClientAnalyzer): + """ + A client analyzer for counting values larger than a given percentile. + + Args: + None + + Methods: + local_analyze(train_data, args): + Analyze the training data to count values larger than a given percentile and set the client submission. + + """ + def local_analyze(self, train_data, args): + """ + Analyze the training data to count values larger than a given percentile and set the client submission. + + Args: + train_data (list): The training data containing values to analyze. + args: Additional arguments (not used in this method). + + Returns: + None + """ counter = 0 for data in train_data: if data >= self.server_data: # flag counter += 1 - self.set_client_submission(counter) # number of values that are larger than flag + self.set_client_submission(counter) # number of values that are larger than the flag diff --git a/python/fedml/fa/local_analyzer/union.py b/python/fedml/fa/local_analyzer/union.py index 4ce99a39ad..2b0ad16b63 100644 --- a/python/fedml/fa/local_analyzer/union.py +++ b/python/fedml/fa/local_analyzer/union.py @@ -1,6 +1,27 @@ from fedml.fa.base_frame.client_analyzer import FAClientAnalyzer - class UnionClientAnalyzer(FAClientAnalyzer): + """ + A client analyzer for finding the union of values in the training data. + + Args: + None + + Methods: + local_analyze(train_data, args): + Analyze the training data to find the union of values and set the client submission. + + """ + def local_analyze(self, train_data, args): + """ + Analyze the training data to find the union of values and set the client submission. + + Args: + train_data (list): The training data containing values to analyze. + args: Additional arguments (not used in this method). + + Returns: + None + """ self.set_client_submission(list(set(train_data))) diff --git a/python/fedml/fa/runner.py b/python/fedml/fa/runner.py index f69dfc38ee..c446b79077 100644 --- a/python/fedml/fa/runner.py +++ b/python/fedml/fa/runner.py @@ -1,8 +1,21 @@ from fedml import FEDML_SIMULATION_TYPE_SP, FEDML_TRAINING_PLATFORM_SIMULATION, FEDML_TRAINING_PLATFORM_CROSS_SILO from fedml.fa.simulation.sp.simulator import FASimulatorSingleProcess - class FARunner: + """ + A class for running Federated Learning simulations. + + Args: + args: The arguments for configuring the simulation. + dataset: The dataset used for the simulation. + client_trainer: The client trainer for training clients (optional). + server_aggregator: The server aggregator for aggregating client updates (optional). + + Methods: + run(): + Run the Federated Learning simulation. + + """ def __init__( self, args, @@ -10,7 +23,19 @@ def __init__( client_trainer=None, server_aggregator=None, ): + """ + Initialize the FARunner with the provided arguments and components. + Args: + args: The arguments for configuring the simulation. + dataset: The dataset used for the simulation. + client_trainer: The client trainer for training clients (optional). + server_aggregator: The server aggregator for aggregating client updates (optional). 
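The k-percentile analyzer above reduces to a one-line count, where self.server_data is the percentile flag broadcast by the server:

    counter = sum(1 for value in train_data if value >= self.server_data)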
+ + Raises: + Exception: If an invalid training type is specified in the arguments. + + """ if args.training_type == FEDML_TRAINING_PLATFORM_SIMULATION: init_runner_func = self._init_simulation_runner elif args.training_type == FEDML_TRAINING_PLATFORM_CROSS_SILO: @@ -25,6 +50,22 @@ def __init__( def _init_simulation_runner( self, args, dataset, client_analyzer=None, server_analyzer=None ): + """ + Initialize a simulation runner based on the provided arguments. + + Args: + args: The arguments for configuring the simulation. + dataset: The dataset used for the simulation. + client_analyzer: The client analyzer for analyzing client behavior (optional). + server_analyzer: The server analyzer for analyzing server behavior (optional). + + Returns: + FASimulatorSingleProcess: A simulation runner for single-process simulation. + + Raises: + Exception: If an unsupported simulation backend is specified in the arguments. + + """ if hasattr(args, "backend") and args.backend == FEDML_SIMULATION_TYPE_SP: runner = FASimulatorSingleProcess(args, dataset) else: @@ -33,6 +74,22 @@ def _init_simulation_runner( return runner def _init_cross_silo_runner(self, args, dataset, client_analyzer=None, server_analyzer=None): + """ + Initialize a cross-silo runner based on the provided arguments. + + Args: + args: The arguments for configuring the simulation. + dataset: The dataset used for the simulation. + client_analyzer: The client analyzer for analyzing client behavior (optional). + server_analyzer: The server analyzer for analyzing server behavior (optional). + + Returns: + FACrossSiloClient or FACrossSiloServer: A cross-silo client or server runner. + + Raises: + Exception: If an invalid role is specified in the arguments. + + """ if args.role == "client": from fedml.fa.cross_silo.fa_client import FACrossSiloClient as Client runner = Client(args, dataset, client_analyzer) @@ -45,4 +102,8 @@ def _init_cross_silo_runner(self, args, dataset, client_analyzer=None, server_an return runner def run(self): + """ + Run the Federated Learning simulation. + + """ self.runner.run() diff --git a/python/fedml/fa/simulation/sp/client.py b/python/fedml/fa/simulation/sp/client.py index 1902d165b3..7de16c8ed9 100644 --- a/python/fedml/fa/simulation/sp/client.py +++ b/python/fedml/fa/simulation/sp/client.py @@ -1,10 +1,48 @@ import numpy as np - class Client: + """ + Client class for Federated Analytics simulation. + + Args: + client_idx (int): Index of the client. + local_training_data (list): Local training data for the client. + local_datasize (int): Size of the local training data. + args (object): Arguments for the simulation. + local_analyzer (object): Local analyzer instance. + + Attributes: + client_idx (int): Index of the client. + local_training_data (list): Local training data for the client. + local_datasize (int): Size of the local training data. + local_sample_number (int): Number of local samples. + args (object): Arguments for the simulation. + local_analyzer (object): Local analyzer instance. + + Methods: + update_local_dataset(client_idx, local_training_data, local_sample_number): + Update the client's local dataset and sample number. + + get_sample_number(): + Get the number of local samples. + + local_analyze(w_global): + Perform local analysis and return client's submission. + """ + def __init__( self, client_idx, local_training_data, local_datasize, args, local_analyzer, ): + """ + Initialize the Client class. + + Args: + client_idx (int): Index of the client. 
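Minimal usage sketch of FARunner (argument and dataset construction elided); training_type in args selects the simulation or cross-silo path described here:

    runner = FARunner(args, dataset)
    runner.run()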
+ local_training_data (list): Local training data for the client. + local_datasize (int): Size of the local training data. + args (object): Arguments for the simulation. + local_analyzer (object): Local analyzer instance. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_datasize = local_datasize @@ -13,18 +51,41 @@ def __init__( self.local_analyzer = local_analyzer def update_local_dataset(self, client_idx, local_training_data, local_sample_number): + """ + Update the client's local dataset and sample number. + + Args: + client_idx (int): Index of the client. + local_training_data (list): Updated local training data. + local_sample_number (int): Updated number of local samples. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_sample_number = local_sample_number self.local_analyzer.set_id(client_idx) def get_sample_number(self): + """ + Get the number of local samples. + + Returns: + int: Number of local samples. + """ return self.local_sample_number def local_analyze(self, w_global): + """ + Perform local analysis and return client's submission. + + Args: + w_global (object): Global data from the server. + + Returns: + object: Client's submission after local analysis. + """ self.local_analyzer.set_server_data(w_global) idxs = np.random.choice(range(len(self.local_training_data)), self.local_sample_number, replace=False) train_data = [self.local_training_data[i] for i in idxs] - # print(f"train data = {train_data}") + self.local_analyzer.local_analyze(train_data, self.args) return self.local_analyzer.get_client_submission() diff --git a/python/fedml/fa/simulation/sp/simulator.py b/python/fedml/fa/simulation/sp/simulator.py index 265d257dc9..627e5d9f94 100644 --- a/python/fedml/fa/simulation/sp/simulator.py +++ b/python/fedml/fa/simulation/sp/simulator.py @@ -7,7 +7,38 @@ class FASimulatorSingleProcess: + """ + Simulator for Federated Analytics with a Single Process. + + Args: + args (object): Arguments for the simulation. + dataset (list): Dataset information including train data count, local datasize, and train data for each client. + + Attributes: + args (object): Arguments for the simulation. + train_data_num_in_total (int): Total number of training data points. + client_list (list): List of client instances. + local_datasize_dict (dict): Dictionary of local datasizes for each client. + train_data_local_dict (dict): Dictionary of local training data for each client. + local_analyzer (object): Local analyzer instance. + aggregator (object): Global aggregator instance. + + Methods: + analyze(): + Run the Federated Analytics simulation. + + run(): + Run the simulation. + """ + def __init__(self, args, dataset): + """ + Initialize the FASimulatorSingleProcess class. + + Args: + args (object): Arguments for the simulation. + dataset (list): Dataset information including train data count, local datasize, and train data for each client. + """ self.args = args [ train_data_num, @@ -30,6 +61,14 @@ def __init__(self, args, dataset): def _setup_clients( self, local_datasize_dict, train_data_local_dict, local_analyzer, ): + """ + Set up client instances for the simulation. + + Args: + local_datasize_dict (dict): Dictionary of local datasizes for each client. + train_data_local_dict (dict): Dictionary of local training data for each client. + local_analyzer (object): Local analyzer instance. 
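Usage sketch of the Client API documented above; the argument values are illustrative:

    c = Client(client_idx=0, local_training_data=data,
               local_datasize=len(data), args=args, local_analyzer=analyzer)
    c.update_local_dataset(3, new_data, len(new_data))   # reassign client 3's shard
    submission = c.local_analyze(w_global)               # sample, analyze, return submission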
+ """ logging.info("############setup_clients (START)#############") for client_idx in range(self.args.client_num_per_round): c = Client( @@ -43,6 +82,9 @@ def _setup_clients( logging.info("############setup_clients (END)#############") def analyze(self): + """ + Run the Federated Analytics simulation. + """ logging.info("self.local_analyzer = {}".format(self.local_analyzer)) local_sample_num = dict() for round_idx in range(self.args.comm_round): @@ -76,4 +118,7 @@ def analyze(self): print(f"round_idx={round_idx}, aggregation result = {result}") def run(self): + """ + Run the simulation. + """ self.analyze() diff --git a/python/fedml/fa/simulation/utils.py b/python/fedml/fa/simulation/utils.py index 3ae180397d..0bd40de181 100644 --- a/python/fedml/fa/simulation/utils.py +++ b/python/fedml/fa/simulation/utils.py @@ -2,11 +2,22 @@ def client_sampling(round_idx, client_num_in_total, client_num_per_round): + """ + Sample a subset of clients for federated learning. + + Args: + round_idx (int): The index of the current federated learning round. + client_num_in_total (int): The total number of available clients. + client_num_per_round (int): The number of clients to select for the current round. + + Returns: + list: A list of selected client indexes for the current round. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: num_clients = min(client_num_per_round, client_num_in_total) - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) # Make sure for each comparison, we select the same clients each round. client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False) print("client_indexes = %s" % str(client_indexes)) return client_indexes diff --git a/python/fedml/fa/utils/trie.py b/python/fedml/fa/utils/trie.py index f7b8166692..261f6df2cf 100644 --- a/python/fedml/fa/utils/trie.py +++ b/python/fedml/fa/utils/trie.py @@ -195,10 +195,70 @@ def _levenshtein(path, node, word, distance, cigar): class Trie(object): + """ + A Trie data structure for efficiently storing and searching words. + + Args: + words (list): List of words to initialize the Trie. + + Attributes: + root (dict): The root of the Trie. + + Methods: + __contains__(word): + Check if a word is present in the Trie. + + __iter__(): + Get an iterator for the words in the Trie. + + list(unique=True): + Get a list of words in the Trie. + + add(word, count=1): + Add a word to the Trie. + + get(word): + Get the count of a word in the Trie. + + remove(word, count=1): + Remove a word from the Trie. + + has_prefix(word): + Check if any word in the Trie has a given prefix. + + fill(alphabet, length): + Fill the Trie with words of a given length using characters from the alphabet. + + all_hamming_(word, distance): + Find all words in the Trie within a given Hamming distance. + + all_hamming(word, distance): + Find all words in the Trie within a given Hamming distance (returns words only). + + hamming(word, distance): + Find the first word in the Trie within a given Hamming distance. + + best_hamming(word, distance): + Find the best match for a word in the Trie within a given Hamming distance. + + all_levenshtein_(word, distance): + Find all words in the Trie within a given Levenshtein distance. + + all_levenshtein(word, distance): + Find all words in the Trie within a given Levenshtein distance (returns words only). 
+ + levenshtein(word, distance): + Find the first word in the Trie within a given Levenshtein distance. + + best_levenshtein(word, distance): + Find the best match for a word in the Trie within a given Levenshtein distance. + """ def __init__(self, words=None): - """Initialise the class. + """ + Initialize the Trie class. - :arg list words: List of words. + Args: + words (list): List of words to initialize the Trie. """ self.root = {} @@ -207,54 +267,153 @@ def __init__(self, words=None): self.add(word) def __contains__(self, word): + """ + Check if a word is present in the Trie. + + Args: + word (str): The word to check. + + Returns: + bool: True if the word is in the Trie, False otherwise. + """ return '' in _find(self.root, word) def __iter__(self): + """ + Get an iterator for the words in the Trie. + + Returns: + Iterator: An iterator object for iterating through words in the Trie. + """ return _iterate('', self.root, True) def list(self, unique=True): + """ + Get a list of words in the Trie. + + Args: + unique (bool): Whether to return unique words only (default is True). + + Returns: + list: A list of words in the Trie. + """ return _iterate('', self.root, unique) def add(self, word, count=1): + """ + Add a word to the Trie. + + Args: + word (str): The word to add. + count (int): The count to associate with the word (default is 1). + """ _add(self.root, word, count) def get(self, word): + """ + Get the count of a word in the Trie. + + Args: + word (str): The word to get the count for. + + Returns: + int: The count of the word in the Trie or None if not found. + """ node = _find(self.root, word) if '' in node: return node[''] return None def remove(self, word, count=1): + """ + Remove a word from the Trie. + + Args: + word (str): The word to remove. + count (int): The count to decrement (default is 1). + + Returns: + int: The remaining count of the word in the Trie or None if not found. + """ return _remove(self.root, word, count) def has_prefix(self, word): + """ + Check if any word in the Trie has a given prefix. + + Args: + word (str): The prefix to check. + + Returns: + bool: True if any word has the given prefix, False otherwise. + """ return _find(self.root, word) != {} def fill(self, alphabet, length): + """ + Fill the Trie with words of a given length using characters from the alphabet. + + Args: + alphabet (str): The characters to use for filling. + length (int): The length of words to generate and add to the Trie. + """ _fill(self.root, alphabet, length) def all_hamming_(self, word, distance): + """ + Find all words in the Trie within a given Hamming distance and return detailed results. + + Args: + word (str): Query word. + distance (int): Maximum allowed Hamming distance. + + Returns: + map: A map containing tuples with (word, remaining distance, count). + """ return map( lambda x: (x[0], distance - x[1], x[2]), _hamming('', self.root, word, distance, '')) def all_hamming(self, word, distance): + """ + Find all words in the Trie within a given Hamming distance and return words only. + + Args: + word (str): Query word. + distance (int): Maximum allowed Hamming distance. + + Returns: + map: A map containing words within the specified Hamming distance. + """ return map( lambda x: x[0], _hamming('', self.root, word, distance, '')) def hamming(self, word, distance): + """ + Find the first word in the Trie within a given Hamming distance. + + Args: + word (str): Query word. + distance (int): Maximum allowed Hamming distance. 
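A short usage sketch of the Trie API documented above; the all_hamming output shown is what the documented semantics imply for this input:

    t = Trie(['cat', 'car', 'cart'])
    'cat' in t                        # True
    t.has_prefix('ca')                # True
    t.get('cat')                      # 1 (occurrence count)
    sorted(t.all_hamming('cat', 1))   # ['car', 'cat'] -- equal-length words within distance 1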
+ + Returns: + str: The first word within the specified Hamming distance or None if not found. + """ try: return next(self.all_hamming(word, distance)) except StopIteration: return None def best_hamming(self, word, distance): - """Find the best match with {word} in a trie. + """ + Find the best match with {word} in a trie using Hamming distance. - :arg str word: Query word. - :arg int distance: Maximum allowed distance. + Args: + word (str): Query word. + distance (int): Maximum allowed Hamming distance. - :returns str: Best match with {word}. + Returns: + str: Best match with {word}. """ if self.get(word): return word @@ -267,27 +426,60 @@ def best_hamming(self, word, distance): return None def all_levenshtein_(self, word, distance): + """ + Find all words in the Trie within a given Levenshtein distance and return detailed results. + + Args: + word (str): Query word. + distance (int): Maximum allowed Levenshtein distance. + + Returns: + map: A map containing tuples with (word, remaining distance, count). + """ return map( lambda x: (x[0], distance - x[1], x[2]), _levenshtein('', self.root, word, distance, '')) def all_levenshtein(self, word, distance): + """ + Find all words in the Trie within a given Levenshtein distance and return words only. + + Args: + word (str): Query word. + distance (int): Maximum allowed Levenshtein distance. + + Returns: + map: A map containing words within the specified Levenshtein distance. + """ return map( lambda x: x[0], _levenshtein('', self.root, word, distance, '')) def levenshtein(self, word, distance): + """ + Find the first word in the Trie within a given Levenshtein distance. + + Args: + word (str): Query word. + distance (int): Maximum allowed Levenshtein distance. + + Returns: + str: The first word within the specified Levenshtein distance or None if not found. + """ try: return next(self.all_levenshtein(word, distance)) except StopIteration: return None def best_levenshtein(self, word, distance): - """Find the best match with {word} in a trie. + """ + Find the best match with {word} in a trie using Levenshtein distance. - :arg str word: Query word. - :arg int distance: Maximum allowed distance. + Args: + word (str): Query word. + distance (int): Maximum allowed Levenshtein distance. - :returns str: Best match with {word}. + Returns: + str: Best match with {word}. """ if self.get(word): return word diff --git a/python/fedml/ml/aggregator/agg_operator.py b/python/fedml/ml/aggregator/agg_operator.py index ebcc939541..4f2a123e0e 100644 --- a/python/fedml/ml/aggregator/agg_operator.py +++ b/python/fedml/ml/aggregator/agg_operator.py @@ -8,6 +8,17 @@ class FedMLAggOperator: @staticmethod def agg(args, raw_grad_list: List[Tuple[float, OrderedDict]]) -> OrderedDict: + """ + Aggregate gradients from multiple clients using a federated learning aggregator. + + Args: + args: A dictionary containing training configuration parameters. + raw_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples containing + local sample counts and gradient updates from client models. + + Returns: + OrderedDict: The aggregated model parameters. + """ training_num = 0 if args.federated_optimizer == "SCAFFOLD": for i in range(len(raw_grad_list)): @@ -31,6 +42,20 @@ def agg(args, raw_grad_list: List[Tuple[float, OrderedDict]]) -> OrderedDict: def torch_aggregator(args, raw_grad_list, training_num): + """ + Aggregate gradients or parameters from multiple clients using a federated learning aggregator. + + Args: + args: A dictionary containing training configuration parameters. 
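The FedAvg branch of torch_aggregator is, at heart, a sample-count-weighted average over client state dicts; a minimal sketch rather than the exact fedml code:

    (num0, avg_params) = raw_grad_list[0]
    for k in avg_params.keys():
        for i, (local_sample_number, local_params) in enumerate(raw_grad_list):
            w = local_sample_number / training_num
            if i == 0:
                avg_params[k] = local_params[k] * w
            else:
                avg_params[k] += local_params[k] * w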
+ raw_grad_list (List[Union[Tuple[float, OrderedDict], Tuple[float, OrderedDict, OrderedDict]]]): + A list of tuples containing local sample counts and gradient updates from client models. + For some optimizers, it also includes an additional tuple element with local gradients. + training_num (int): The total number of training samples used for aggregation. + + Returns: + Union[OrderedDict, Tuple[OrderedDict, OrderedDict]]: The aggregated model parameters or a tuple + containing aggregated model parameters and aggregated local gradients, depending on the optimizer. + """ if args.federated_optimizer == "FedAvg": (num0, avg_params) = raw_grad_list[0] @@ -135,6 +160,18 @@ def torch_aggregator(args, raw_grad_list, training_num): def tf_aggregator(args, raw_grad_list, training_num): + """ + Aggregate gradients or parameters from multiple clients using a TensorFlow-based federated learning aggregator. + + Args: + args: A dictionary containing training configuration parameters. + raw_grad_list (List[Tuple[float, List[float]]]): A list of tuples containing local sample counts and + gradient updates from client models. + training_num (int): The total number of training samples used for aggregation. + + Returns: + List[float]: The aggregated model parameters. + """ (num0, avg_params) = raw_grad_list[0] if args.federated_optimizer == "FedAvg": @@ -161,6 +198,17 @@ def tf_aggregator(args, raw_grad_list, training_num): def jax_aggregator(args, raw_grad_list, training_num): + """ + Aggregate gradients or parameters from multiple clients using a JAX-based federated learning aggregator. + + Args: + args: A dictionary containing training configuration parameters. + raw_grad_list (List[Tuple[float, Dict[str, Dict[str, float]]]]): A list of tuples containing local sample counts + and gradient updates from client models. Each update is a dictionary containing 'w' and 'b' keys. + + Returns: + Dict[str, Dict[str, float]]: The aggregated model parameters containing 'w' and 'b' keys. + """ (num0, avg_params) = raw_grad_list[0] if args.federated_optimizer == "FedAvg": @@ -191,6 +239,17 @@ def jax_aggregator(args, raw_grad_list, training_num): def mxnet_aggregator(args, raw_grad_list, training_num): + """ + Aggregate gradients or parameters from multiple clients using a MXNet-based federated learning aggregator. + + Args: + args: A dictionary containing training configuration parameters. + raw_grad_list (List[Tuple[float, Dict[str, List[float]]]]): A list of tuples containing local sample counts + and gradient updates from client models. Each update is a dictionary containing lists of parameters. + + Returns: + Dict[str, List[float]]: The aggregated model parameters. + """ (num0, avg_params) = raw_grad_list[0] if args.federated_optimizer == "FedAvg": @@ -221,6 +280,20 @@ def mxnet_aggregator(args, raw_grad_list, training_num): def model_aggregator(args, raw_grad_list, training_num): + """ + Aggregate gradients or parameters from multiple clients using a federated learning aggregator based on the + specified machine learning engine. + + Args: + args: A dictionary containing training configuration parameters. + raw_grad_list (List[Union[Tuple[float, Dict[str, Dict[str, float]]], Tuple[float, Dict[str, List[float]]]]]): + A list of tuples containing local sample counts and gradient updates from client models. The format of + updates varies based on the machine learning engine. 
+ + Returns: + Union[Dict[str, Dict[str, float]], Dict[str, List[float]]]: The aggregated model parameters or gradients + based on the selected machine learning engine. + """ if hasattr(args, MLEngineBackend.ml_engine_args_flag): if args.ml_engine == MLEngineBackend.ml_engine_backend_tf: return tf_aggregator(args, raw_grad_list, training_num) diff --git a/python/fedml/ml/aggregator/aggregator_creator.py b/python/fedml/ml/aggregator/aggregator_creator.py index 0ea2f506ee..e00475fe1e 100644 --- a/python/fedml/ml/aggregator/aggregator_creator.py +++ b/python/fedml/ml/aggregator/aggregator_creator.py @@ -4,6 +4,16 @@ def create_server_aggregator(model, args): + """ + Create a server aggregator instance based on the selected dataset and configuration parameters. + + Args: + model: The machine learning model to be used for aggregation. + args: A dictionary containing training configuration parameters, including the dataset. + + Returns: + ServerAggregator: An instance of a server aggregator class suitable for the specified dataset. + """ if args.dataset == "stackoverflow_lr": aggregator = MyServerAggregatorTAGPred(model, args) elif args.dataset in ["fed_shakespeare", "stackoverflow_nwp"]: diff --git a/python/fedml/ml/aggregator/default_aggregator.py b/python/fedml/ml/aggregator/default_aggregator.py index d81507d09a..a1a0f44162 100644 --- a/python/fedml/ml/aggregator/default_aggregator.py +++ b/python/fedml/ml/aggregator/default_aggregator.py @@ -11,18 +11,48 @@ class DefaultServerAggregator(ServerAggregator): def __init__(self, model, args): + """ + Initialize the DefaultServerAggregator. + + Args: + model: The machine learning model. + args: A dictionary containing configuration parameters. + """ super().__init__(model, args) self.cpu_transfer = False if not hasattr(self.args, "cpu_transfer") else self.args.cpu_transfer def get_model_params(self): + """ + Get the model parameters. + + Returns: + OrderedDict: The model parameters. + """ if self.cpu_transfer: return self.model.cpu().state_dict() return self.model.state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (OrderedDict): The model parameters to set. + """ self.model.load_state_dict(model_parameters) def _test(self, test_data, device, args): + """ + Internal method for testing the model on a given dataset. + + Args: + test_data: The test dataset. + device: The device to run the test on. + args: A dictionary containing configuration parameters. + + Returns: + dict: A dictionary containing test metrics. + """ model = self.model model.to(device) @@ -75,6 +105,17 @@ def _test(self, test_data, device, args): return metrics def test(self, test_data, device, args): + """ + Test the model on a given dataset and log the results. + + Args: + test_data: The test dataset. + device: The device to run the test on. + args: A dictionary containing configuration parameters. + + Returns: + tuple: A tuple containing test accuracy and loss. + """ # test data test_num_samples = [] test_tot_corrects = [] diff --git a/python/fedml/ml/aggregator/my_server_aggregator.py b/python/fedml/ml/aggregator/my_server_aggregator.py index 6f4125e6fa..4e7ca9b33d 100644 --- a/python/fedml/ml/aggregator/my_server_aggregator.py +++ b/python/fedml/ml/aggregator/my_server_aggregator.py @@ -11,18 +11,48 @@ class MyServerAggregator(ServerAggregator): def __init__(self, model, args): + """ + Initialize the MyServerAggregator. + + Args: + model: The model used for aggregation. 
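Server-side round trip implied by the get/set pair above (cpu_transfer only changes whether the state dict is moved to CPU before serialization):

    global_params = aggregator.get_model_params()
    # ... FedMLAggOperator.agg combines the clients' updates into global_params ...
    aggregator.set_model_params(global_params)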
+ args: A dictionary containing configuration parameters. + """ super().__init__(model, args) self.cpu_transfer = False if not hasattr(self.args, "cpu_transfer") else self.args.cpu_transfer def get_model_params(self): + """ + Get the model parameters. + + Returns: + OrderedDict: The model parameters. + """ if self.cpu_transfer: return self.model.cpu().state_dict() return self.model.state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (OrderedDict): The model parameters to set. + """ self.model.load_state_dict(model_parameters) def _test(self, test_data, device, args): + """ + Internal method for testing the model on a given dataset. + + Args: + test_data: The test dataset. + device: The device to run the test on. + args: A dictionary containing configuration parameters. + + Returns: + dict: A dictionary containing test metrics. + """ model = self.model model.to(device) @@ -75,6 +105,17 @@ def _test(self, test_data, device, args): return metrics def test(self, test_data, device, args): + """ + Test the model on a given dataset, log the results, and return test accuracy and loss. + + Args: + test_data: The test dataset. + device: The device to run the test on. + args: A dictionary containing configuration parameters. + + Returns: + tuple: A tuple containing test accuracy and loss. + """ # test data test_num_samples = [] test_tot_corrects = [] @@ -107,6 +148,18 @@ def test(self, test_data, device, args): return (test_acc, test_loss, None, None) def test_all(self, train_data_local_dict, test_data_local_dict, device, args) -> bool: + """ + Test the model on all client datasets, log the results, and return True. + + Args: + train_data_local_dict: A dictionary of training datasets for each client. + test_data_local_dict: A dictionary of test datasets for each client. + device: The device to run the test on. + args: A dictionary containing configuration parameters. + + Returns: + bool: Always returns True. + """ train_num_samples = [] train_tot_corrects = [] train_losses = [] diff --git a/python/fedml/ml/aggregator/my_server_aggregator_classification.py b/python/fedml/ml/aggregator/my_server_aggregator_classification.py index 7f93417641..e265beb01d 100644 --- a/python/fedml/ml/aggregator/my_server_aggregator_classification.py +++ b/python/fedml/ml/aggregator/my_server_aggregator_classification.py @@ -11,12 +11,35 @@ class MyServerAggregatorCLS(ServerAggregator): def get_model_params(self): + """ + Get the model parameters. + + Returns: + OrderedDict: The model parameters. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (OrderedDict): The model parameters to set. + """ self.model.load_state_dict(model_parameters) def _test(self, test_data, device, args): + """ + Internal method for testing the model on a given dataset. + + Args: + test_data: The test dataset. + device: The device to run the test on. + args: A dictionary containing configuration parameters. + + Returns: + dict: A dictionary containing test metrics. + """ model = self.model model.to(device) @@ -42,6 +65,17 @@ def _test(self, test_data, device, args): return metrics def test(self, test_data, device, args): + """ + Test the model on a given dataset, log the results, and return test accuracy and loss. + + Args: + test_data: The test dataset. + device: The device to run the test on. + args: A dictionary containing configuration parameters. 
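All of these test() methods share one reduction pattern over the per-client lists; a sketch of the final step, which the diff context elides:

    test_acc = sum(test_tot_corrects) / sum(test_num_samples)
    test_loss = sum(test_losses) / sum(test_num_samples)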
+ + Returns: + tuple: A tuple containing test accuracy and loss. + """ # test data test_num_samples = [] test_tot_corrects = [] diff --git a/python/fedml/ml/aggregator/my_server_aggregator_nwp.py b/python/fedml/ml/aggregator/my_server_aggregator_nwp.py index 9306d42c8e..ae4ec10b7f 100644 --- a/python/fedml/ml/aggregator/my_server_aggregator_nwp.py +++ b/python/fedml/ml/aggregator/my_server_aggregator_nwp.py @@ -11,12 +11,35 @@ class MyServerAggregatorNWP(ServerAggregator): def get_model_params(self): + """ + Get the model parameters. + + Returns: + OrderedDict: The model parameters. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (OrderedDict): The model parameters to set. + """ self.model.load_state_dict(model_parameters) def _test(self, test_data, device, args): + """ + Internal method for testing the model on a given dataset. + + Args: + test_data: The test dataset. + device: The device to run the test on. + args: A dictionary containing configuration parameters. + + Returns: + dict: A dictionary containing test metrics. + """ model = self.model model.to(device) @@ -42,6 +65,17 @@ def _test(self, test_data, device, args): return metrics def test(self, test_data, device, args): + """ + Test the model on a given dataset, log the results, and return test accuracy and loss. + + Args: + test_data: The test dataset. + device: The device to run the test on. + args: A dictionary containing configuration parameters. + + Returns: + tuple: A tuple containing test accuracy and loss. + """ # test data test_num_samples = [] test_tot_corrects = [] diff --git a/python/fedml/ml/aggregator/my_server_aggregator_prediction.py b/python/fedml/ml/aggregator/my_server_aggregator_prediction.py index def0647e80..6d913bd864 100644 --- a/python/fedml/ml/aggregator/my_server_aggregator_prediction.py +++ b/python/fedml/ml/aggregator/my_server_aggregator_prediction.py @@ -11,12 +11,35 @@ class MyServerAggregatorTAGPred(ServerAggregator): def get_model_params(self): + """ + Get the model parameters. + + Returns: + OrderedDict: The model parameters. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (OrderedDict): The model parameters to set. + """ self.model.load_state_dict(model_parameters) def _test(self, test_data, device, args): + """ + Internal method for testing the model on a given dataset. + + Args: + test_data: The test dataset. + device: The device to run the test on. + args: A dictionary containing configuration parameters. + + Returns: + dict: A dictionary containing test metrics. + """ model = self.model model.to(device) @@ -59,6 +82,17 @@ def _test(self, test_data, device, args): return metrics def test(self, test_data, device, args): + """ + Test the model on a given dataset, log the results, and return test accuracy and loss. + + Args: + test_data: The test dataset. + device: The device to run the test on. + args: A dictionary containing configuration parameters. + + Returns: + tuple: A tuple containing test accuracy and loss. 
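The _test() bodies elided above follow the standard PyTorch evaluation shape, roughly:

    import torch

    model.to(device)
    model.eval()
    with torch.no_grad():
        for x, target in test_data:
            x, target = x.to(device), target.to(device)
            pred = model(x)
            # accumulate correct counts and loss into the metrics dict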
+ """ # test data test_num_samples = [] test_tot_corrects = [] diff --git a/python/fedml/ml/trainer/feddyn_trainer.py b/python/fedml/ml/trainer/feddyn_trainer.py index 1f24f27e6a..2fe646c3e0 100644 --- a/python/fedml/ml/trainer/feddyn_trainer.py +++ b/python/fedml/ml/trainer/feddyn_trainer.py @@ -8,23 +8,92 @@ def model_parameter_vector(model): + """ + Flatten and concatenate the parameters of a PyTorch model. + + Args: + model (torch.nn.Module): The PyTorch model whose parameters need to be flattened. + + Returns: + torch.Tensor: A 1D tensor containing the concatenated flattened parameters. + """ param = [p.view(-1) for p in model.parameters()] return torch.concat(param, dim=0) def parameter_vector(parameters): + """ + Flatten and concatenate a dictionary of PyTorch parameters. + + Args: + parameters (dict): A dictionary of PyTorch parameters. + + Returns: + torch.Tensor: A 1D tensor containing the concatenated flattened parameters. + """ param = [p.view(-1) for p in parameters.values()] return torch.concat(param, dim=0) class FedDynModelTrainer(ClientTrainer): + """ + A class for training and testing federated dynamic models. + + Args: + model: The neural network model to train. + id: The client's unique identifier. + args: A dictionary containing training configuration parameters. + + Attributes: + model: The neural network model for training. + id: The unique identifier of the client. + args: A dictionary containing training configuration parameters. + + Methods: + get_model_params(): + Get the current state dictionary of the model. + + set_model_params(model_parameters): + Set the model's parameters using the provided state dictionary. + + train(train_data, device, args, old_grad): + Train the model on the given training data. + + test(test_data, device, args): + Test the model's performance on the provided test data. + + """ def get_model_params(self): + """ + Get the current state dictionary of the model. + + Returns: + dict: The state dictionary of the model. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model's parameters using the provided state dictionary. + + Args: + model_parameters (dict): The state dictionary containing model parameters. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args, old_grad): + """ + Train the model on the given training data. + + Args: + train_data (torch.utils.data.DataLoader): The DataLoader containing training data. + device (str): The device to perform training (e.g., 'cuda' or 'cpu'). + args (dict): A dictionary containing training configuration parameters. + old_grad (dict): Dictionary of old gradients for dynamic regularization. + + Returns: + dict: Updated old gradients after training. + """ model = self.model for params in model.parameters(): params.requires_grad = True @@ -117,6 +186,17 @@ def train(self, train_data, device, args, old_grad): def test(self, test_data, device, args): + """ + Test the model's performance on the provided test data. + + Args: + test_data (torch.utils.data.DataLoader): The DataLoader containing test data. + device (str): The device to perform testing (e.g., 'cuda' or 'cpu'). + args (dict): A dictionary containing testing configuration parameters. + + Returns: + dict: Metrics including test accuracy and test loss. 
+ """ model = self.model model.to(device) diff --git a/python/fedml/ml/trainer/fednova_trainer.py b/python/fedml/ml/trainer/fednova_trainer.py index c0fd4cc5c9..84122ac9c5 100644 --- a/python/fedml/ml/trainer/fednova_trainer.py +++ b/python/fedml/ml/trainer/fednova_trainer.py @@ -175,13 +175,71 @@ def step(self, closure=None): class FedNovaModelTrainer(ClientTrainer): + """ + A class for training and testing federated Nova (FedNova) models. + + Args: + model: The neural network model to train. + id: The client's unique identifier. + args: A dictionary containing training configuration parameters. + + Attributes: + model: The neural network model for training. + id: The unique identifier of the client. + args: A dictionary containing training configuration parameters. + + Methods: + get_model_params(): + Get the current state dictionary of the model. + + set_model_params(model_parameters): + Set the model's parameters using the provided state dictionary. + + get_local_norm_grad(opt, cur_params, init_params, weight=0): + Calculate the local normalized gradients. + + get_local_tau_eff(opt): + Calculate the effective tau for FedNova. + + train(train_data, device, args, **kwargs): + Train the model on the given training data using FedNova optimizer. + + test(test_data, device, args): + Test the model's performance on the provided test data. + + """ + def get_model_params(self): + """ + Get the current state dictionary of the model. + + Returns: + dict: The state dictionary of the model. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model's parameters using the provided state dictionary. + + Args: + model_parameters (dict): The state dictionary containing model parameters. + """ self.model.load_state_dict(model_parameters) def get_local_norm_grad(self, opt, cur_params, init_params, weight=0): + """ + Calculate the local normalized gradients. + + Args: + opt: The FedNova optimizer instance. + cur_params (dict): The current model's parameters. + init_params (dict): The initial model's parameters. + weight (float): The weight for gradient scaling (default is 0). + + Returns: + dict: Dictionary of local normalized gradients. + """ if weight == 0: weight = opt.ratio grad_dict = {} @@ -193,12 +251,33 @@ def get_local_norm_grad(self, opt, cur_params, init_params, weight=0): return grad_dict def get_local_tau_eff(self, opt): + """ + Calculate the effective tau for FedNova. + + Args: + opt: The FedNova optimizer instance. + + Returns: + float: The effective tau for FedNova. + """ if opt.mu != 0: return opt.local_steps * opt.ratio else: return opt.local_normalizing_vec * opt.ratio def train(self, train_data, device, args, **kwargs): + """ + Train the model on the given training data using the FedNova optimizer. + + Args: + train_data (torch.utils.data.DataLoader): The DataLoader containing training data. + device (str): The device to perform training (e.g., 'cuda' or 'cpu'). + args (dict): A dictionary containing training configuration parameters. + **kwargs: Additional keyword arguments. + + Returns: + Tuple[float, dict, float]: Tuple containing the average loss, local normalized gradients, and effective tau. + """ model = self.model model.to(device) @@ -248,6 +327,17 @@ def train(self, train_data, device, args, **kwargs): def test(self, test_data, device, args): + """ + Test the model's performance on the provided test data. + + Args: + test_data (torch.utils.data.DataLoader): The DataLoader containing test data. 
+ device (str): The device to perform testing (e.g., 'cuda' or 'cpu'). + args (dict): A dictionary containing testing configuration parameters. + + Returns: + dict: Metrics including test accuracy and test loss. + """ model = self.model model.to(device) diff --git a/python/fedml/ml/trainer/fedprox_trainer.py b/python/fedml/ml/trainer/fedprox_trainer.py index 06ebb4feab..e4741a1fa6 100644 --- a/python/fedml/ml/trainer/fedprox_trainer.py +++ b/python/fedml/ml/trainer/fedprox_trainer.py @@ -7,13 +7,63 @@ class FedProxModelTrainer(ClientTrainer): + """ + A class for training and testing federated Proximal (FedProx) models. + + Args: + model: The neural network model to train. + id: The client's unique identifier. + args: A dictionary containing training configuration parameters. + + Attributes: + model: The neural network model for training. + id: The unique identifier of the client. + args: A dictionary containing training configuration parameters. + + Methods: + get_model_params(): + Get the current state dictionary of the model. + + set_model_params(model_parameters): + Set the model's parameters using the provided state dictionary. + + train(train_data, device, args): + Train the model on the given training data with optional FedProx regularization. + + train_iterations(train_data, device, args): + Train the model for a specified number of local iterations. + + test(test_data, device, args): + Test the model's performance on the provided test data. + + """ def get_model_params(self): + """ + Get the current state dictionary of the model. + + Returns: + dict: The state dictionary of the model. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model's parameters using the provided state dictionary. + + Args: + model_parameters (dict): The state dictionary containing model parameters. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model on the given training data with optional FedProx regularization. + + Args: + train_data (torch.utils.data.DataLoader): The DataLoader containing training data. + device (str): The device to perform training (e.g., 'cuda' or 'cpu'). + args (dict): A dictionary containing training configuration parameters. + """ model = self.model model.to(device) @@ -79,6 +129,14 @@ def train(self, train_data, device, args): def train_iterations(self, train_data, device, args): + """ + Train the model for a specified number of local iterations. + + Args: + train_data (torch.utils.data.DataLoader): The DataLoader containing training data. + device (str): The device to perform training (e.g., 'cuda' or 'cpu'). + args (dict): A dictionary containing training configuration parameters. + """ model = self.model model.to(device) @@ -145,6 +203,17 @@ def train_iterations(self, train_data, device, args): def test(self, test_data, device, args): + """ + Test the model's performance on the provided test data. + + Args: + test_data (torch.utils.data.DataLoader): The DataLoader containing test data. + device (str): The device to perform testing (e.g., 'cuda' or 'cpu'). + args (dict): A dictionary containing testing configuration parameters. + + Returns: + dict: Metrics including test accuracy and test loss. 
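The "optional FedProx regularization" named above is the proximal term (mu/2) * ||w - w_global||^2 added to the local loss. A hedged sketch, since the exact fedml code and the name of the mu hyperparameter are elided in this hunk:

    import torch

    prox = 0.0
    for w, w_glob in zip(model.parameters(), global_model.parameters()):
        prox = prox + torch.sum((w - w_glob) ** 2)
    loss = task_loss + (mu / 2.0) * prox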
+ """ model = self.model model.to(device) diff --git a/python/fedml/ml/trainer/trainer_creator.py b/python/fedml/ml/trainer/trainer_creator.py index 67f629d33e..0441159dd3 100644 --- a/python/fedml/ml/trainer/trainer_creator.py +++ b/python/fedml/ml/trainer/trainer_creator.py @@ -4,10 +4,20 @@ def create_model_trainer(model, args): + """ + Create and return an appropriate model trainer based on the dataset type. + + Args: + model: The neural network model to be trained. + args: A dictionary containing training configuration parameters, including the dataset type. + + Returns: + ModelTrainer: An instance of a model trainer tailored to the dataset type. + """ if args.dataset == "stackoverflow_lr": model_trainer = ModelTrainerTAGPred(model, args) elif args.dataset in ["fed_shakespeare", "stackoverflow_nwp"]: model_trainer = ModelTrainerNWP(model, args) - else: # default model trainer is for classification problem + else: # Default model trainer is for classification problem model_trainer = ModelTrainerCLS(model, args) return model_trainer From d6686bad48bcb8f4d4c6b659004c7e11bfc3742a Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Sat, 16 Sep 2023 12:17:34 +0530 Subject: [PATCH 25/70] ed --- python/fedml/data/ImageNet/data_loader.py | 190 +++++++++++++--- python/fedml/data/ImageNet/datasets.py | 109 +++++++-- python/fedml/data/ImageNet/datasets_hdf5.py | 50 +++- python/fedml/data/Landmarks/data_loader.py | 110 +++++++++ python/fedml/data/Landmarks/datasets.py | 38 ++++ .../data/Landmarks/download_without_tf.py | 55 +++-- .../data/Landmarks/download_without_tff.py | 21 ++ python/fedml/data/Landmarks/utils.py | 5 + python/fedml/data/MNIST/data_loader.py | 74 ++++-- .../data/MNIST/mnist_mobile_preprocessor.py | 42 +++- python/fedml/data/MNIST/stats.py | 15 ++ .../fedml/data/NUS_WIDE/nus_wide_dataset.py | 94 ++++++++ .../data/UCI/data_loader_for_susy_and_ro.py | 104 +++++++++ .../lending_club_loan/lending_club_dataset.py | 126 ++++++++++- python/fedml/data/reddit/data_loader.py | 16 ++ python/fedml/data/reddit/datasets.py | 103 ++++++++- python/fedml/data/reddit/divide_data.py | 213 ++++++++++++++++-- python/fedml/data/reddit/nlp.py | 99 +++++++- .../data/stackoverflow_lr/data_loader.py | 36 +++ python/fedml/data/stackoverflow_lr/dataset.py | 16 +- python/fedml/data/stackoverflow_lr/utils.py | 86 +++++++ .../data/stackoverflow_nwp/data_loader.py | 26 +++ .../fedml/data/stackoverflow_nwp/dataset.py | 46 ++-- .../synthetic_0.5_0.5/generate_synthetic.py | 23 ++ .../data/synthetic_0_0/generate_synthetic.py | 23 ++ .../fedml/data/synthetic_1_1/data_loader.py | 37 ++- .../data/synthetic_1_1/generate_synthetic.py | 30 ++- python/fedml/data/synthetic_1_1/stats.py | 19 ++ 28 files changed, 1647 insertions(+), 159 deletions(-) diff --git a/python/fedml/data/ImageNet/data_loader.py b/python/fedml/data/ImageNet/data_loader.py index 22ab9a54f6..84150d1dd9 100644 --- a/python/fedml/data/ImageNet/data_loader.py +++ b/python/fedml/data/ImageNet/data_loader.py @@ -13,11 +13,51 @@ from .datasets_hdf5 import ImageNet_truncated_hdf5 +import numpy as np +import torch + class Cutout(object): + """ + Apply the Cutout data augmentation technique to an image. + + Cutout is a technique used for regularization during training deep neural networks. + It randomly masks out a rectangular region of the input image. + + Args: + length (int): The length of the square mask to apply. + + Usage: + transform = Cutout(length=16) # Create an instance of the Cutout transform. 
+ transformed_image = transform(input_image) # Apply the Cutout transform to an image. + + Note: + The Cutout transform is typically applied as part of a data augmentation pipeline. + + References: + - Original paper: https://arxiv.org/abs/1708.04552 + + """ + def __init__(self, length): + """ + Initialize the Cutout transform with the specified length. + + Args: + length (int): The length of the square mask to apply. + """ self.length = length def __call__(self, img): + """ + Apply Cutout transformation to an input image. + + Args: + img (torch.Tensor): The input image tensor to which Cutout will be applied. + + Returns: + torch.Tensor: The input image with a randomly masked region. + + """ h, w = img.size(1), img.size(2) mask = np.ones((h, w), np.float32) y = np.random.randint(h) @@ -36,6 +76,13 @@ def __call__(self, img): def _data_transforms_ImageNet(): + """ + Define data transforms for the ImageNet dataset. + + Returns: + transforms.Compose: A composition of data augmentation transforms for training + and validation data. + """ # IMAGENET_MEAN = [0.5071, 0.4865, 0.4409] # IMAGENET_STD = [0.2673, 0.2564, 0.2762] @@ -43,41 +90,55 @@ def _data_transforms_ImageNet(): IMAGENET_STD = [0.229, 0.224, 0.225] image_size = 224 - train_transform = transforms.Compose( - [ - # transforms.ToPILImage(), - transforms.RandomResizedCrop(image_size), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD), - ] - ) + train_transform = transforms.Compose([ + transforms.RandomResizedCrop(image_size), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD), + ]) train_transform.transforms.append(Cutout(16)) - valid_transform = transforms.Compose( - [ - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD), - ] - ) + valid_transform = transforms.Compose([ + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD), + ]) return train_transform, valid_transform - -# for centralized training def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None): - return get_dataloader_ImageNet(datadir, train_bs, test_bs, dataidxs) + """ + Get data loaders for centralized training. + Args: + dataset (str): The dataset name. + datadir (str): The path to the dataset directory. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (list, optional): List of data indices to use for training (default: None). -# for local devices -def get_dataloader_test( - dataset, datadir, train_bs, test_bs, dataidxs_train, dataidxs_test -): - return get_dataloader_test_ImageNet( - datadir, train_bs, test_bs, dataidxs_train, dataidxs_test - ) + Returns: + DataLoader: Training and testing data loaders. + """ + return get_dataloader_ImageNet(datadir, train_bs, test_bs, dataidxs) + +def get_dataloader_test(dataset, datadir, train_bs, test_bs, dataidxs_train, dataidxs_test): + """ + Get data loaders for local devices. + + Args: + dataset (str): The dataset name. + datadir (str): The path to the dataset directory. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs_train (list): List of data indices to use for training. + dataidxs_test (list): List of data indices to use for testing. + + Returns: + DataLoader: Training and testing data loaders. 
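The rest of __call__ follows the standard Cutout recipe (the tail of the method falls outside this hunk, so treat this as a sketch): clamp the square to the image border, zero it in the mask, then multiply the mask into the image.

    y1, y2 = np.clip(y - self.length // 2, 0, h), np.clip(y + self.length // 2, 0, h)
    x1, x2 = np.clip(x - self.length // 2, 0, w), np.clip(x + self.length // 2, 0, w)
    mask[y1:y2, x1:x2] = 0.0
    mask = torch.from_numpy(mask).expand_as(img)
    img = img * mask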
+ """ + return get_dataloader_test_ImageNet(datadir, train_bs, test_bs, dataidxs_train, dataidxs_test) def get_dataloader_ImageNet_truncated( @@ -89,7 +150,25 @@ def get_dataloader_ImageNet_truncated( net_dataidx_map=None, ): """ - imagenet_dataset_train, imagenet_dataset_test should be ImageNet or ImageNet_hdf5 + Get data loaders for a truncated version of the ImageNet dataset. + + Args: + imagenet_dataset_train: The training dataset (ImageNet or ImageNet_hdf5). + imagenet_dataset_test: The testing dataset (ImageNet or ImageNet_hdf5). + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (list, optional): List of data indices to use for training (default: None). + net_dataidx_map (dict, optional): Mapping of data indices to network indices (default: None). + + Returns: + tuple: A tuple containing training and testing data loaders. + + Raises: + NotImplementedError: If the dataset type is not supported. + + Note: + - The `imagenet_dataset_train` and `imagenet_dataset_test` should be instances of `ImageNet` or `ImageNet_hdf5`. + """ if type(imagenet_dataset_train) == ImageNet: dl_obj = ImageNet_truncated @@ -138,6 +217,19 @@ def get_dataloader_ImageNet_truncated( def get_dataloader_ImageNet(datadir, train_bs, test_bs, dataidxs=None): + """ + Get data loaders for the ImageNet dataset. + + Args: + datadir (str): The path to the dataset directory. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (list, optional): List of data indices to use for training (default: None). + + Returns: + tuple: A tuple containing training and testing data loaders. + + """ dl_obj = ImageNet transform_train, transform_test = _data_transforms_ImageNet() @@ -176,6 +268,20 @@ def get_dataloader_ImageNet(datadir, train_bs, test_bs, dataidxs=None): def get_dataloader_test_ImageNet( datadir, train_bs, test_bs, dataidxs_train=None, dataidxs_test=None ): + """ + Get data loaders for the ImageNet dataset for testing. + + Args: + datadir (str): The path to the dataset directory. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs_train (list, optional): List of data indices to use for training (default: None). + dataidxs_test (list, optional): List of data indices to use for testing (default: None). + + Returns: + tuple: A tuple containing training and testing data loaders. + + """ dl_obj = ImageNet transform_train, transform_test = _data_transforms_ImageNet() @@ -219,14 +325,25 @@ def distributed_centralized_ImageNet_loader( dataset, data_dir, world_size, rank, batch_size ): """ - Used for generating distributed dataloader for - accelerating centralized training + Generate a distributed dataloader for accelerating centralized training. + + Args: + dataset (str): The dataset name ("ILSVRC2012" or "ILSVRC2012_hdf5"). + data_dir (str): The path to the dataset directory. + world_size (int): The total number of processes in the distributed training. + rank (int): The rank of the current process in the distributed training. + batch_size (int): Batch size for training and testing data. + + Returns: + tuple: A tuple containing various training and testing data related information. 
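+
+    Example:
+        A sketch for one process of a two-process job; the data path is an
+        assumption:
+
+            outputs = distributed_centralized_ImageNet_loader(
+                "ILSVRC2012", "./data/ImageNet",
+                world_size=2, rank=0, batch_size=64,
+            )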
+
+    """
     train_bs = batch_size
     test_bs = batch_size
 
     transform_train, transform_test = _data_transforms_ImageNet()
 
+
     if dataset == "ILSVRC2012":
         train_dataset = ImageNet(
             data_dir=data_dir, dataidxs=None, train=True, transform=transform_train
         )
@@ -278,6 +395,21 @@ def load_partition_data_ImageNet(
     client_number=100,
     batch_size=10,
 ):
+    """
+    Load and partition data for the ImageNet dataset.
+
+    Args:
+        dataset (str): The dataset name ("ILSVRC2012" or "ILSVRC2012_hdf5").
+        data_dir (str): The path to the dataset directory.
+        partition_method (str, optional): The partitioning method (default: None).
+        partition_alpha (float, optional): The partitioning alpha value (default: None).
+        client_number (int, optional): The number of clients (default: 100).
+        batch_size (int, optional): Batch size for training and testing data (default: 10).
+
+    Returns:
+        tuple: A tuple containing various data-related information.
+
+    """
 
     if dataset == "ILSVRC2012":
         train_dataset = ImageNet(data_dir=data_dir, dataidxs=None, train=True)
 
diff --git a/python/fedml/data/ImageNet/datasets.py b/python/fedml/data/ImageNet/datasets.py
index 5b47b65184..f4103a18a2 100644
--- a/python/fedml/data/ImageNet/datasets.py
+++ b/python/fedml/data/ImageNet/datasets.py
@@ -19,6 +19,15 @@ def has_file_allowed_extension(filename, extensions):
 
 
 def find_classes(dir):
+    """Find class names from subdirectories in a given directory.
+
+    Args:
+        dir (str): The root directory containing subdirectories, each representing a class.
+
+    Returns:
+        list: A sorted list of class names.
+        dict: A dictionary mapping class names to their respective indices.
+    """
     classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
     classes.sort()
     class_to_idx = {classes[i]: i for i in range(len(classes))}
@@ -26,6 +35,18 @@ def find_classes(dir):
 
 
 def make_dataset(dir, class_to_idx, extensions):
+    """Create a dataset of image file paths and their corresponding class indices.
+
+    Args:
+        dir (str): The root directory containing subdirectories, each representing a class.
+        class_to_idx (dict): A dictionary mapping class names to their respective indices.
+        extensions (tuple): A tuple of allowed file extensions.
+
+    Returns:
+        list: A list of tuples, each containing the file path and class index.
+        dict: A dictionary mapping class indices to the number of samples per class.
+        dict: A dictionary mapping class indices to data index ranges.
+    """
     images = []
     data_local_num_dict = dict()
 
@@ -55,14 +76,29 @@
 
 
 def pil_loader(path):
-    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+    """Load an image using PIL (Python Imaging Library).
+
+    Args:
+        path (str): The path to the image file.
+
+    Returns:
+        PIL.Image.Image: The loaded image in RGB format.
+    """
     with open(path, "rb") as f:
         img = Image.open(f)
         return img.convert("RGB")
 
 
 def accimage_loader(path):
-    import accimage  # pylint: disable=E0401
+    """Load an image using accimage (a fast image-loading backend built on Intel IPP).
+
+    Args:
+        path (str): The path to the image file.
+
+    Returns:
+        accimage.Image: The loaded image.
+    """
+    import accimage
 
     try:
         return accimage.Image(path)
@@ -72,6 +108,14 @@
 
 
 def default_loader(path):
+    """Load an image using the default loader (PIL or accimage).
+
+    Args:
+        path (str): The path to the image file.
+
+    Returns:
+        PIL.Image.Image or accimage.Image: The loaded image.
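+
+    Example (the image path below is an illustrative assumption):
+
+        img = default_loader("./data/ImageNet/train/n01440764/img_0001.JPEG")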
+ """ from torchvision import get_image_backend if get_image_backend() == "accimage": @@ -91,8 +135,20 @@ def __init__( download=False, ): """ - Generating this class too many times will be time-consuming. - So it will be better calling this once and put it into ImageNet_truncated. + Initialize the ImageNet dataset. + + Args: + data_dir (str): Root directory of the dataset. + dataidxs (int or list, optional): List of indices to select specific data subsets. + train (bool, optional): If True, loads the training dataset; otherwise, loads the validation dataset. + transform (callable, optional): A function/transform to apply to the data. + target_transform (callable, optional): A function/transform to apply to the labels. + download (bool, optional): Whether to download the dataset if it's not found locally. + + Note: + Generating this class too many times will be time-consuming. + It's better to call this once and use ImageNet_truncated. + """ self.dataidxs = dataidxs self.train = train @@ -110,9 +166,10 @@ def __init__( self.data_local_num_dict, self.net_dataidx_map, ) = self.__getdatasets__() - if dataidxs == None: + + if dataidxs is None: self.local_data = self.all_data - elif type(dataidxs) == int: + elif isinstance(dataidxs, int): (begin, end) = self.net_dataidx_map[dataidxs] self.local_data = self.all_data[begin:end] else: @@ -130,20 +187,18 @@ def get_net_dataidx_map(self): def get_data_local_num_dict(self): return self.data_local_num_dict + def __getdatasets__(self): - # all_data = datasets.ImageFolder(data_dir, self.transform, self.target_transform) classes, class_to_idx = find_classes(self.data_dir) - IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif"] + all_data, data_local_num_dict, net_dataidx_map = make_dataset( self.data_dir, class_to_idx, IMG_EXTENSIONS ) if len(all_data) == 0: - raise ( - RuntimeError( - "Found 0 files in subfolders of: " + self.data_dir + "\n" - "Supported extensions are: " + ",".join(IMG_EXTENSIONS) - ) + raise RuntimeError( + f"Found 0 files in subfolders of: {self.data_dir}\n" + f"Supported extensions are: {','.join(IMG_EXTENSIONS)}" ) return all_data, data_local_num_dict, net_dataidx_map @@ -153,9 +208,8 @@ def __getitem__(self, index): index (int): Index Returns: - tuple: (image, target) where target is index of the target class. + tuple: (image, target) where target is the index of the target class. """ - # img, target = self.data[index], self.target[index] path, target = self.local_data[index] img = self.loader(path) @@ -174,7 +228,7 @@ def __len__(self): class ImageNet_truncated(data.Dataset): def __init__( self, - imagenet_dataset: ImageNet, + imagenet_dataset, dataidxs, net_dataidx_map, train=True, @@ -182,7 +236,19 @@ def __init__( target_transform=None, download=False, ): + """ + Initialize a truncated version of the ImageNet dataset. + Args: + imagenet_dataset (ImageNet): The original ImageNet dataset. + dataidxs (int or list): List of indices to select specific data subsets. + net_dataidx_map (dict): Mapping of data indices in the original dataset. + train (bool, optional): If True, loads the training dataset; otherwise, loads the validation dataset. + transform (callable, optional): A function/transform to apply to the data. + target_transform (callable, optional): A function/transform to apply to the labels. + download (bool, optional): Whether to download the dataset if it's not found locally. 
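+
+        Example:
+            A sketch; assumes ``full_set`` is an already-constructed ImageNet
+            instance:
+
+                client_ds = ImageNet_truncated(
+                    full_set, dataidxs=0,
+                    net_dataidx_map=full_set.get_net_dataidx_map(),
+                )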
+ + """ self.dataidxs = dataidxs self.train = train self.transform = transform @@ -191,9 +257,10 @@ def __init__( self.net_dataidx_map = net_dataidx_map self.loader = default_loader self.all_data = imagenet_dataset.get_local_data() - if dataidxs == None: + + if dataidxs is None: self.local_data = self.all_data - elif type(dataidxs) == int: + elif isinstance(dataidxs, int): (begin, end) = self.net_dataidx_map[dataidxs] self.local_data = self.all_data[begin:end] else: @@ -208,10 +275,9 @@ def __getitem__(self, index): index (int): Index Returns: - tuple: (image, target) where target is index of the target class. + tuple: (image, target) where target is the index of the target class. """ - # img, target = self.data[index], self.target[index] - + path, target = self.local_data[index] img = self.loader(path) if self.transform is not None: @@ -224,3 +290,4 @@ def __getitem__(self, index): def __len__(self): return len(self.local_data) + \ No newline at end of file diff --git a/python/fedml/data/ImageNet/datasets_hdf5.py b/python/fedml/data/ImageNet/datasets_hdf5.py index 042016fee8..c20f29da83 100644 --- a/python/fedml/data/ImageNet/datasets_hdf5.py +++ b/python/fedml/data/ImageNet/datasets_hdf5.py @@ -13,7 +13,14 @@ class DatasetHDF5(data.Dataset): def __init__(self, hdf5fn, t, transform=None, target_transform=None): """ - t: 'train' or 'val' + Initialize a custom dataset from an HDF5 file. + + Args: + hdf5fn (str): Filepath to the HDF5 file. + t (str): 'train' or 'val' to specify the dataset split. + transform (callable, optional): A function/transform to apply to the data. + target_transform (callable, optional): A function/transform to apply to the labels. + """ super(DatasetHDF5, self).__init__() self.hf = h5py.File(hdf5fn, "r", libver="latest", swmr=True) @@ -21,8 +28,7 @@ def __init__(self, hdf5fn, t, transform=None, target_transform=None): self.n_images = self.hf["%s_img" % self.t].shape[0] self.dlabel = self.hf["%s_labels" % self.t][...] self.d = self.hf["%s_img" % self.t] - # self.transform = transform - # self.target_transform = target_transform + def _get_dataset_x_and_target(self, index): img = self.d[index, ...] @@ -30,6 +36,13 @@ def _get_dataset_x_and_target(self, index): return img, np.int64(target) def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is the label of the image. + """ img, target = self._get_dataset_x_and_target(index) # if self.transform is not None: # img = self.transform(img) @@ -52,8 +65,20 @@ def __init__( download=False, ): """ - Generating this class too many times will be time-consuming. - So it will be better calling this once and put it into ImageNet_truncated. + Initialize the ImageNet dataset using HDF5 files. + + Args: + data_dir (str): Directory containing the HDF5 file. + dataidxs (int or list, optional): List of indices to select specific data subsets. + train (bool, optional): If True, loads the training dataset; otherwise, loads the validation dataset. + transform (callable, optional): A function/transform to apply to the data. + target_transform (callable, optional): A function/transform to apply to the labels. + download (bool, optional): Whether to download the dataset if it's not found locally. + + Note: + Generating this class too many times will be time-consuming. + It's better to call this once and use ImageNet_truncated. 
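+
+        Example:
+            A sketch; the HDF5 directory below is an assumption:
+
+                train_set = ImageNet_hdf5("./data/imagenet_hdf5", train=True)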
+
+        """
         self.dataidxs = dataidxs
         self.train = train
@@ -117,7 +142,7 @@ def __getitem__(self, index):
             index (int): Index
 
         Returns:
-            tuple: (image, target) where target is index of the target class.
+            tuple: (image, target) where target is the label of the image.
         """
 
         img, target = self.all_data_hdf5[self.local_data_idx[index]]
@@ -146,6 +171,19 @@ def __init__(
         target_transform=None,
         download=False,
     ):
+        """
+        Initialize a truncated version of the ImageNet dataset using HDF5 files.
+
+        Args:
+            imagenet_dataset (ImageNet_hdf5): The original ImageNet HDF5 dataset.
+            dataidxs (int or list): List of indices to select specific data subsets.
+            net_dataidx_map (dict): Mapping of data indices in the original dataset.
+            train (bool, optional): If True, loads the training dataset; otherwise, loads the validation dataset.
+            transform (callable, optional): A function/transform to apply to the data.
+            target_transform (callable, optional): A function/transform to apply to the labels.
+            download (bool, optional): Whether to download the dataset if it's not found locally.
+
+        """
         self.dataidxs = dataidxs
         self.train = train
 
diff --git a/python/fedml/data/Landmarks/data_loader.py b/python/fedml/data/Landmarks/data_loader.py
index 63514bb088..923df16293 100644
--- a/python/fedml/data/Landmarks/data_loader.py
+++ b/python/fedml/data/Landmarks/data_loader.py
@@ -67,9 +67,24 @@ def _read_csv(path: str):
 
 
 class Cutout(object):
     def __init__(self, length):
+        """
+        Initialize the Cutout transformation.
+
+        Args:
+            length (int): The side length of the square patch to cut out of the image.
+        """
        self.length = length
 
     def __call__(self, img):
+        """
+        Apply the Cutout transformation to the input image.
+
+        Args:
+            img (torch.Tensor): The input image tensor in CHW layout.
+
+        Returns:
+            torch.Tensor: The image tensor with a square patch masked out.
+        """
         h, w = img.size(1), img.size(2)
         mask = np.ones((h, w), np.float32)
         y = np.random.randint(h)
@@ -126,6 +141,17 @@ def get_mapping_per_user(fn):
     [{'user_id': xxx, 'image_id': xxx, 'class': xxx} ... {'user_id': xxx, 'image_id': xxx, 'class': xxx} ... ]
     }
+
+    Load mapping information per user from a CSV file.
+
+    Args:
+        fn (str): The filename of the CSV file containing the user-image mapping.
+
+    Returns:
+        tuple: A tuple containing:
+            - data_files (list): A list of dictionaries containing mapping information.
+            - data_local_num_dict (dict): A dictionary mapping user IDs to the number of data entries they have.
+            - net_dataidx_map (dict): A dictionary mapping user IDs to data index ranges.
     """
     mapping_table = _read_csv(fn)
     expected_cols = ["user_id", "image_id", "class"]
@@ -163,6 +189,21 @@ def get_mapping_per_user(fn):
 
 
 def get_dataloader(
     dataset, datadir, train_files, test_files, train_bs, test_bs, dataidxs=None
 ):
+    """
+    Get data loaders for centralized training.
+
+    Args:
+        dataset (str): The name of the dataset.
+        datadir (str): The directory containing the data files.
+        train_files (list): A list of training data files.
+        test_files (list): A list of testing data files.
+        train_bs (int): The batch size for training.
+        test_bs (int): The batch size for testing.
+        dataidxs (list, optional): List of data indices to select specific data entries. Defaults to None.
+
+    Returns:
+        tuple: The training and testing data loaders.
+    """
     return get_dataloader_Landmarks(
         datadir, train_files, test_files, train_bs, test_bs, dataidxs
     )
@@ -179,6 +220,22 @@ def get_dataloader_test(
     dataidxs_train,
     dataidxs_test,
 ):
+    """
+    Get data loaders for testing with specified data indices.
+ + Args: + dataset (str): The name of the dataset. + datadir (str): The directory containing the data files. + train_files (list): A list of training data files. + test_files (list): A list of testing data files. + train_bs (int): The batch size for training. + test_bs (int): The batch size for testing. + dataidxs_train (list): List of data indices to select specific training data entries. + dataidxs_test (list): List of data indices to select specific testing data entries. + + Returns: + DataLoader: Data loaders for training and testing. + """ return get_dataloader_test_Landmarks( datadir, train_files, @@ -193,6 +250,20 @@ def get_dataloader_test( def get_dataloader_Landmarks( datadir, train_files, test_files, train_bs, test_bs, dataidxs=None ): + """ + Get data loaders for Landmarks dataset. + + Args: + datadir (str): The directory containing the data files. + train_files (list): A list of training data files. + test_files (list): A list of testing data files. + train_bs (int): The batch size for training. + test_bs (int): The batch size for testing. + dataidxs (list, optional): List of data indices to select specific data entries. Defaults to None. + + Returns: + DataLoader: Data loaders for training and testing. + """ dl_obj = Landmarks transform_train, transform_test = _data_transforms_landmarks() @@ -233,6 +304,21 @@ def get_dataloader_test_Landmarks( dataidxs_train=None, dataidxs_test=None, ): + """ + Get data loaders for testing Landmarks dataset with specified data indices. + + Args: + datadir (str): The directory containing the data files. + train_files (list): A list of training data files. + test_files (list): A list of testing data files. + train_bs (int): The batch size for training. + test_bs (int): The batch size for testing. + dataidxs_train (list, optional): List of data indices to select specific training data entries. Defaults to None. + dataidxs_test (list, optional): List of data indices to select specific testing data entries. Defaults to None. + + Returns: + DataLoader: Data loaders for training and testing. + """ dl_obj = Landmarks transform_train, transform_test = _data_transforms_landmarks() @@ -274,6 +360,30 @@ def load_partition_data_landmarks( client_number=233, batch_size=10, ): + """ + Load partitioned data for the Landmarks dataset. + + Args: + dataset (str): The name of the dataset. + data_dir (str): The directory containing the data files. + fed_train_map_file (str): The path to the federated train data mapping file. + fed_test_map_file (str): The path to the federated test data mapping file. + partition_method (str, optional): The partitioning method for data. Defaults to None. + partition_alpha (float, optional): The alpha value for partitioning. Defaults to None. + client_number (int): The number of clients/participants. Defaults to 233. + batch_size (int): The batch size for data loaders. Defaults to 10. + + Returns: + Tuple: A tuple containing the following elements: + - train_data_num (int): The number of training data samples. + - test_data_num (int): The number of testing data samples. + - train_data_global (DataLoader): Global training data loader. + - test_data_global (DataLoader): Global testing data loader. + - data_local_num_dict (dict): Dictionary mapping client IDs to the number of local data samples. + - train_data_local_dict (dict): Dictionary mapping client IDs to their local training data loaders. + - test_data_local_dict (dict): Dictionary mapping client IDs to their local testing data loaders. 
+ - class_num (int): The number of unique classes in the dataset. + """ train_files, data_local_num_dict, net_dataidx_map = get_mapping_per_user( fed_train_map_file diff --git a/python/fedml/data/Landmarks/datasets.py b/python/fedml/data/Landmarks/datasets.py index ac8364c75a..020e855001 100644 --- a/python/fedml/data/Landmarks/datasets.py +++ b/python/fedml/data/Landmarks/datasets.py @@ -6,6 +6,18 @@ class Landmarks(data.Dataset): + """ + Custom dataset class for the Landmarks dataset. + + Args: + data_dir (str): The directory containing the data files. + allfiles (list): A list of data entries in the form of dictionaries with 'user_id', 'image_id', and 'class'. + dataidxs (list, optional): List of data indices to select specific data entries. Defaults to None. + train (bool, optional): Indicates whether the dataset is for training. Defaults to True. + transform (callable, optional): A function/transform to apply to the data. Defaults to None. + target_transform (callable, optional): A function/transform to apply to the target. Defaults to None. + download (bool, optional): Whether to download the data. Defaults to False. + """ def __init__( self, data_dir, @@ -19,6 +31,16 @@ def __init__( """ allfiles is [{'user_id': xxx, 'image_id': xxx, 'class': xxx} ... {'user_id': xxx, 'image_id': xxx, 'class': xxx} ... ] + Initialize the Landmarks dataset. + + Args: + data_dir (str): The directory containing the data files. + allfiles (list): A list of data entries in the form of dictionaries with 'user_id', 'image_id', and 'class'. + dataidxs (list, optional): List of data indices to select specific data entries. Defaults to None. + train (bool, optional): Indicates whether the dataset is for training. Defaults to True. + transform (callable, optional): A function/transform to apply to the data. Defaults to None. + target_transform (callable, optional): A function/transform to apply to the target. Defaults to None. + download (bool, optional): Whether to download the data. Defaults to False. """ self.allfiles = allfiles if dataidxs == None: @@ -32,6 +54,13 @@ def __init__( self.target_transform = target_transform def __len__(self): + """ + Get the number of data entries in the dataset. + + Returns: + int: The number of data entries. + """ + # if self.user_id != None: # return sum([len(local_data) for local_data in self.mapping_per_user.values()]) # else: @@ -39,6 +68,15 @@ def __len__(self): return len(self.local_files) def __getitem__(self, idx): + """ + Get a data sample and its corresponding label by index. + + Args: + idx (int): Index of the data sample to retrieve. + + Returns: + tuple: A tuple containing the data sample and its corresponding label. + """ # if self.user_id != None: # img_name = self.mapping_per_user[self.user_id][idx]['image_id'] # label = self.mapping_per_user[self.user_id][idx]['class'] diff --git a/python/fedml/data/Landmarks/download_without_tf.py b/python/fedml/data/Landmarks/download_without_tf.py index f38f3205b2..66c7bad004 100644 --- a/python/fedml/data/Landmarks/download_without_tf.py +++ b/python/fedml/data/Landmarks/download_without_tf.py @@ -49,12 +49,15 @@ def _listener_process(queue: multiprocessing.Queue, log_file: str): - """Sets up a separate process for handling logging messages. + """ + Sets up a separate process for handling logging messages. + This setup is required because without it, the logging messages will be duplicated when multiple processes are created for downloading GLD dataset. + Args: - queue: The queue to receive logging messages. 
- log_file: The file which the messages will be written to. + queue (multiprocessing.Queue): The queue to receive logging messages. + log_file (str): The file to which the messages will be written. """ root = logging.getLogger() h = logging.FileHandler(log_file) @@ -77,27 +80,32 @@ def _listener_process(queue: multiprocessing.Queue, log_file: str): def _read_csv(path: str) -> List[Dict[str, str]]: - """Reads a csv file, and returns the content inside a list of dictionaries. + """ + Reads a CSV file and returns the content inside a list of dictionaries. + Args: - path: The path to the csv file. + path (str): The path to the CSV file. + Returns: - A list of dictionaries. Each row in the csv file will be a list entry. The - dictionary is keyed by the column names. + List[Dict[str, str]]: A list of dictionaries. Each row in the CSV file will be a list entry. + The dictionary is keyed by the column names. """ with open(path, "r") as f: return list(csv.DictReader(f)) def _filter_images(shard: int, all_images: Set[str], image_dir: str, base_url: str): - """Download full GLDv2 dataset, only keep images that are included in the federated gld v2 dataset. + """ + Download full GLDv2 dataset, only keep images that are included in the federated GLD v2 dataset. + Args: - shard: The shard of the GLDv2 dataset. - all_images: A set which contains all images included in the federated GLD - dataset. - image_dir: The directory to keep all filtered images. - base_url: The base url for downloading GLD v2 dataset images. + shard (int): The shard of the GLDv2 dataset. + all_images (Set[str]): A set that contains all images included in the federated GLD dataset. + image_dir (str): The directory to keep all filtered images. + base_url (str): The base URL for downloading GLD v2 dataset images. + Raises: - IOError: when failed to download checksum. + IOError: When failed to download checksum. """ shard_str = "%03d" % shard images_tar_url = "%s/train/images_%s.tar" % (base_url, shard_str) @@ -135,10 +143,14 @@ def _download_data(num_worker: int, cache_dir: str, base_url: str): Download the entire GLD v2 dataset, subset the dataset to only include the images in the federated GLD v2 dataset, and create both gld23k and gld160k datasets. + Args: - num_worker: The number of threads for downloading the GLD v2 dataset. - cache_dir: The directory for caching temporary results. - base_url: The base url for downloading GLD images. + num_worker (int): The number of threads for downloading the GLD v2 dataset. + cache_dir (str): The directory for caching temporary results. + base_url (str): The base URL for downloading GLD images. + + Raises: + IOError: When failed to download checksum. """ logger = logging.getLogger(LOGGER) logging.info("Start to download fed gldv2 mapping files") @@ -194,6 +206,15 @@ def load_data( gld23k: bool = False, base_url: str = GLD_SHARD_BASE_URL, ): + """ + Load the GLD v2 dataset. + + Args: + num_worker (int): The number of threads for downloading the GLD v2 dataset. + cache_dir (str): The directory for caching temporary results. + gld23k (bool): Whether to load the gld23k dataset. + base_url (str): The base URL for downloading GLD images. 
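+
+    Example (a sketch; the worker count and cache directory are assumptions):
+
+        load_data(num_worker=4, cache_dir="./gld_cache", gld23k=True)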
+ """ if not os.path.exists(cache_dir): os.mkdir(cache_dir) diff --git a/python/fedml/data/Landmarks/download_without_tff.py b/python/fedml/data/Landmarks/download_without_tff.py index eb351d433f..7e0d028fd7 100644 --- a/python/fedml/data/Landmarks/download_without_tff.py +++ b/python/fedml/data/Landmarks/download_without_tff.py @@ -48,6 +48,7 @@ def _listener_process(queue: multiprocessing.Queue, log_file: str): """Sets up a separate process for handling logging messages. This setup is required because without it, the logging messages will be duplicated when multiple processes are created for downloading GLD dataset. + Args: queue: The queue to receive logging messages. log_file: The file which the messages will be written to. @@ -74,8 +75,10 @@ def _listener_process(queue: multiprocessing.Queue, log_file: str): def _read_csv(path: str) -> List[Dict[str, str]]: """Reads a csv file, and returns the content inside a list of dictionaries. + Args: path: The path to the csv file. + Returns: A list of dictionaries. Each row in the csv file will be a list entry. The dictionary is keyed by the column names. @@ -88,10 +91,12 @@ def _create_dataset_with_mapping( image_dir: str, mapping: List[Dict[str, str]] ) -> List[tf.train.Example]: """Builds a dataset based on the mapping file and the images in the image dir. + Args: image_dir: The directory contains the image files. mapping: A list of dictionaries. Each dictionary contains 'image_id' and 'class' columns. + Returns: A list of `tf.train.Example`. """ @@ -126,6 +131,7 @@ def _create_dataset_with_mapping( def _create_train_data_files(cache_dir: str, image_dir: str, mapping_file: str): """Create the train data and persist it into a separate file per user. + Args: cache_dir: The directory caching the intermediate results. image_dir: The directory containing all the downloaded images. @@ -165,6 +171,7 @@ def _create_train_data_files(cache_dir: str, image_dir: str, mapping_file: str): def _create_test_data_file(cache_dir: str, image_dir: str, mapping_file: str): """Create the test data and persist it into a file. + Args: cache_dir: The directory caching the intermediate results. image_dir: The directory containing all the downloaded images. @@ -195,6 +202,7 @@ def _create_federated_gld_dataset( cache_dir: str, image_dir: str, train_mapping_file: str, test_mapping_file: str ): """Generate fedreated GLDv2 dataset with the downloaded images. + Args: cache_dir: The directory for caching the intermediate results. image_dir: The directory that contains the filtered images. @@ -217,6 +225,7 @@ def _create_federated_gld_dataset( def _create_mini_gld_dataset(cache_dir: str, image_dir: str): """Generate mini federated GLDv2 dataset with the downloaded images. + Args: cache_dir: The directory for caching the intermediate results. image_dir: The directory that contains the filtered images. @@ -249,12 +258,14 @@ def _create_mini_gld_dataset(cache_dir: str, image_dir: str): def _filter_images(shard: int, all_images: Set[str], image_dir: str, base_url: str): """Download full GLDv2 dataset, only keep images that are included in the federated gld v2 dataset. + Args: shard: The shard of the GLDv2 dataset. all_images: A set which contains all images included in the federated GLD dataset. image_dir: The directory to keep all filtered images. base_url: The base url for downloading GLD v2 dataset images. + Raises: IOError: when failed to download checksum. 
""" @@ -301,6 +312,7 @@ def _download_data(num_worker: int, cache_dir: str, base_url: str): Download the entire GLD v2 dataset, subset the dataset to only include the images in the federated GLD v2 dataset, and create both gld23k and gld160k datasets. + Args: num_worker: The number of threads for downloading the GLD v2 dataset. cache_dir: The directory for caching temporary results. @@ -362,6 +374,15 @@ def load_data( gld23k: bool = False, base_url: str = GLD_SHARD_BASE_URL, ): + """ + Load the GLD v2 dataset. + + Args: + num_worker (int, optional): The number of threads for downloading the GLD v2 dataset. + cache_dir (str, optional): The directory for caching temporary results. + gld23k (bool, optional): Whether to load the gld23k dataset. + base_url (str, optional): The base URL for downloading GLD images. + """ if not os.path.exists(cache_dir): os.mkdir(cache_dir) diff --git a/python/fedml/data/Landmarks/utils.py b/python/fedml/data/Landmarks/utils.py index aa75034d08..68e1b0e7d5 100644 --- a/python/fedml/data/Landmarks/utils.py +++ b/python/fedml/data/Landmarks/utils.py @@ -17,6 +17,7 @@ class Progbar(object): """Displays a progress bar. + Arguments: target: Total number of steps expected, None if unknown. width: Progress bar width on screen. @@ -242,6 +243,7 @@ def chunk_read(response, chunk_size=8192, reporthook=None): def _extract_archive(file_path, path=".", archive_format="auto"): """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats. + Arguments: file_path: path to the archive file path: path to extract the archive file @@ -251,6 +253,7 @@ def _extract_archive(file_path, path=".", archive_format="auto"): The default 'auto' is ['tar', 'zip']. None or an empty list will return no matches found. Returns: + True if a match was found and an archive extraction was completed, False otherwise. """ @@ -301,6 +304,7 @@ def get_file( Files in tar, tar.gz, tar.bz, and zip formats can also be extracted. Passing a hash will verify the file after download. The command line programs `shasum` and `sha256sum` can compute the hash. + Arguments: fname: Name of the file. If an absolute path `/path/to/file.txt` is specified the file will be saved at that location. @@ -320,6 +324,7 @@ def get_file( defaults to the [Keras Directory](/faq/#where-is-the-keras-configuration-filed-stored). Returns: + Path to the downloaded file """ if cache_dir is None: diff --git a/python/fedml/data/MNIST/data_loader.py b/python/fedml/data/MNIST/data_loader.py index 18e0c29bcf..cbf587c708 100755 --- a/python/fedml/data/MNIST/data_loader.py +++ b/python/fedml/data/MNIST/data_loader.py @@ -14,6 +14,15 @@ def download_mnist(data_cache_dir): + """ + Download the MNIST dataset if it's not already downloaded. + + Args: + data_cache_dir (str): Directory where the dataset should be stored. + + Returns: + None + """ if not os.path.exists(data_cache_dir): os.makedirs(data_cache_dir, exist_ok=True) @@ -30,18 +39,18 @@ def download_mnist(data_cache_dir): zip_ref.extractall(data_cache_dir) def read_data(train_data_dir, test_data_dir): - """parses data in given train and test data directories - - assumes: - - the data in the input directories are .json files with - keys 'users' and 'user_data' - - the set of train set users is the same as the set of test set users - - Return: - clients: list of non-unique client ids - groups: list of group ids; empty list if none found - train_data: dictionary of train data - test_data: dictionary of test data + """ + Parses data in the given train and test data directories. 
+ + Args: + train_data_dir (str): Path to the directory containing train data. + test_data_dir (str): Path to the directory containing test data. + + Returns: + clients (list): List of non-unique client ids. + groups (list): List of group ids; empty list if none found. + train_data (dict): Dictionary of train data. + test_data (dict): Dictionary of test data. """ clients = [] groups = [] @@ -71,24 +80,29 @@ def read_data(train_data_dir, test_data_dir): return clients, groups, train_data, test_data - def batch_data(args, data, batch_size): - """ - data is a dict := {'x': [numpy array], 'y': [numpy array]} (on one client) - returns x, y, which are both numpy array of length: batch_size + Prepare data batches. + + Args: + args: Additional arguments (not specified). + data (dict): Data dictionary containing 'x' and 'y'. + batch_size (int): Size of each batch. + + Returns: + batch_data (list): List of data batches. """ data_x = data["x"] data_y = data["y"] - # randomly shuffle data + # Randomly shuffle data np.random.seed(100) rng_state = np.random.get_state() np.random.shuffle(data_x) np.random.set_state(rng_state) np.random.shuffle(data_y) - # loop through mini-batches + # Loop through mini-batches batch_data = list() for i in range(0, len(data_x), batch_size): batched_x = data_x[i : i + batch_size] @@ -99,6 +113,18 @@ def batch_data(args, data, batch_size): def load_partition_data_mnist_by_device_id(batch_size, device_id, train_path="MNIST_mobile", test_path="MNIST_mobile"): + """ + Load partitioned MNIST data by device ID. + + Args: + batch_size (int): Size of each batch. + device_id (str): ID of the device. + train_path (str): Path to the train data directory. + test_path (str): Path to the test data directory. + + Returns: + Tuple containing data information. + """ train_path += os.path.join("/", device_id, "train") test_path += os.path.join("/", device_id, "test") return load_partition_data_mnist(batch_size, train_path, test_path) @@ -108,6 +134,18 @@ def load_partition_data_mnist( args, batch_size, train_path=os.path.join(os.getcwd(), "MNIST", "train"), test_path=os.path.join(os.getcwd(), "MNIST", "test") ): + """ + Load partitioned MNIST data. + + Args: + args: Additional arguments (not specified). + batch_size (int): Size of each batch. + train_path (str): Path to the train data directory. + test_path (str): Path to the test data directory. + + Returns: + Tuple containing data information. + """ users, groups, train_data, test_data = read_data(train_path, test_path) if len(groups) == 0: diff --git a/python/fedml/data/MNIST/mnist_mobile_preprocessor.py b/python/fedml/data/MNIST/mnist_mobile_preprocessor.py index 0d65c0e95b..b058e3c38a 100644 --- a/python/fedml/data/MNIST/mnist_mobile_preprocessor.py +++ b/python/fedml/data/MNIST/mnist_mobile_preprocessor.py @@ -28,18 +28,24 @@ def add_args(parser): def read_data(train_data_dir, test_data_dir): - """parses data in given train and test data directories - - assumes: - - the data in the input directories are .json files with - keys 'users' and 'user_data' - - the set of train set users is the same as the set of test set users - - Return: - clients: list of client ids - groups: list of group ids; empty list if none found - train_data: dictionary of train data - test_data: dictionary of test data + """ + Parse data from train and test data directories. + + Assumes: + - Data in the input directories are .json files with keys 'users' and 'user_data'. + - The set of train set users is the same as the set of test set users. 
+ + Args: + train_data_dir (str): Path to the directory containing train data. + test_data_dir (str): Path to the directory containing test data. + + Returns: + clients (list): List of client ids. + train_num_samples (list): List of the number of samples for each client in the training data. + test_num_samples (list): List of the number of samples for each client in the test data. + train_data (dict): Dictionary of training data. + test_data (dict): Dictionary of test data. + client_list (list): List of client arguments. """ clients = [] train_num_samples = [] @@ -94,6 +100,18 @@ def __init__(self, client_id, client_num_per_round, comm_round): def client_sampling(round_idx, client_num_in_total, client_num_per_round): + """ + Randomly select clients for federated learning. + + Args: + round_idx (int): Index of the current federated learning round. + client_num_in_total (int): Total number of clients available. + client_num_per_round (int): Number of clients to select for the current round. + + Returns: + client_indexes (list): List of selected client indexes for the current round. + """ + if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: diff --git a/python/fedml/data/MNIST/stats.py b/python/fedml/data/MNIST/stats.py index 761e1cd563..cf499a16bf 100755 --- a/python/fedml/data/MNIST/stats.py +++ b/python/fedml/data/MNIST/stats.py @@ -24,6 +24,15 @@ def load_data(name): + """ + Load user and sample data from JSON files in a specified directory. + + Args: + name (str): The name of the dataset. + + Returns: + tuple: A tuple containing lists of users and their corresponding number of samples. + """ users = [] num_samples = [] @@ -47,6 +56,12 @@ def load_data(name): def print_dataset_stats(name): + """ + Print statistics about the dataset, including user count, total samples, mean, std, skewness, and histogram. + + Args: + name (str): The name of the dataset. + """ users, num_samples = load_data(name) num_users = len(users) diff --git a/python/fedml/data/NUS_WIDE/nus_wide_dataset.py b/python/fedml/data/NUS_WIDE/nus_wide_dataset.py index e6931d60b2..15e00fa14b 100644 --- a/python/fedml/data/NUS_WIDE/nus_wide_dataset.py +++ b/python/fedml/data/NUS_WIDE/nus_wide_dataset.py @@ -6,6 +6,16 @@ def get_top_k_labels(data_dir, top_k=5): + """ + Get the top k labels based on their frequency in the dataset. + + Args: + data_dir (str): The directory containing the dataset. + top_k (int): The number of top labels to retrieve. + + Returns: + list: A list of the top k labels. + """ data_path = "Groundtruth/AllLabels" label_counts = {} for filename in os.listdir(os.path.join(data_dir, data_path)): @@ -21,6 +31,18 @@ def get_top_k_labels(data_dir, top_k=5): def get_labeled_data_with_2_party(data_dir, selected_labels, n_samples, dtype="Train"): + """ + Load labeled data for a two-party scenario. + + Args: + data_dir (str): The directory containing the dataset. + selected_labels (list): The selected labels for the data. + n_samples (int): The number of samples to load. + dtype (str): The data type (e.g., 'Train' or 'Test'). + + Returns: + tuple: A tuple containing XA (image features), XB (tags), and Y (labels). + """ # get labels data_path = "Groundtruth/TrainTestLabels/" dfs = [] @@ -71,6 +93,18 @@ def get_labeled_data_with_2_party(data_dir, selected_labels, n_samples, dtype="T def get_labeled_data_with_3_party(data_dir, selected_labels, n_samples, dtype="Train"): + """ + Load labeled data for a three-party scenario. 
+ + Args: + data_dir (str): The directory containing the dataset. + selected_labels (list): The selected labels for the data. + n_samples (int): The number of samples to load. + dtype (str): The data type (e.g., 'Train' or 'Test'). + + Returns: + tuple: A tuple containing XA (image features), XB1 (tags for party 1), XB2 (tags for party 2), and Y (labels). + """ Xa, Xb, Y = get_labeled_data_with_2_party( data_dir=data_dir, selected_labels=selected_labels, @@ -83,6 +117,18 @@ def get_labeled_data_with_3_party(data_dir, selected_labels, n_samples, dtype="T def NUS_WIDE_load_two_party_data(data_dir, selected_labels, neg_label=-1, n_samples=-1): + """ + Load two-party data for NUS-WIDE dataset. + + Args: + data_dir (str): The directory containing the dataset. + selected_labels (list): The selected labels for the data. + neg_label (int): The negative label value. + n_samples (int): The number of samples to load. + + Returns: + tuple: A tuple containing training data and testing data for two parties. + """ print("# load_two_party_data") Xa, Xb, y = get_labeled_data_with_2_party( @@ -134,6 +180,19 @@ def NUS_WIDE_load_two_party_data(data_dir, selected_labels, neg_label=-1, n_samp def NUS_WIDE_load_three_party_data( data_dir, selected_labels, neg_label=-1, n_samples=-1 ): + """ + Load three-party data for NUS-WIDE dataset. + + Args: + data_dir (str): The directory containing the dataset. + selected_labels (list): The selected labels for the data. + neg_label (int): The negative label value. + n_samples (int): The number of samples to load. + + Returns: + tuple: A tuple containing training data and testing data for three parties. + """ + print("# load_three_party_data") Xa, Xb, Xc, y = get_labeled_data_with_3_party( data_dir=data_dir, selected_labels=selected_labels, n_samples=n_samples @@ -185,6 +244,20 @@ def prepare_party_data( n_samples, is_three_party=False, ): + """ + Prepare data for a federated learning scenario. + + Args: + src_data_folder (str): The source data folder. + des_data_folder (str): The destination data folder. + selected_labels (list): The selected labels for the data. + neg_label (int): The negative label value. + n_samples (int): The number of samples to load. + is_three_party (bool): Whether it's a three-party scenario. + + Returns: + None + """ print("# preparing data ...") train_data_list, test_data_list = ( @@ -235,6 +308,16 @@ def prepare_party_data( def get_data_folder_name(sel_lbls, is_three_party): + """ + Generate a folder name based on selected labels and party type. + + Args: + sel_lbls (list): List of selected labels. + is_three_party (bool): Indicates whether it's a three-party scenario. + + Returns: + str: Generated folder name. + """ folder_name = sel_lbls[0] for idx, lbl in enumerate(sel_lbls): if idx == 0: @@ -246,6 +329,17 @@ def get_data_folder_name(sel_lbls, is_three_party): def load_prepared_parties_data(data_dir, sel_lbls, load_three_party): + """ + Load prepared party data from a specific directory. + + Args: + data_dir (str): The directory containing the prepared data. + sel_lbls (list): List of selected labels. + load_three_party (bool): Indicates whether to load three-party data. + + Returns: + tuple: A tuple containing training and testing data lists. 
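+
+    Example:
+        A sketch; the data directory and label list below are illustrative
+        assumptions:
+
+            train_list, test_list = load_prepared_parties_data(
+                "./data/NUS_WIDE", ["sky", "clouds"], load_three_party=False
+            )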
+ """ print( "# load prepared {0} party data".format("three" if load_three_party else "two") ) diff --git a/python/fedml/data/UCI/data_loader_for_susy_and_ro.py b/python/fedml/data/UCI/data_loader_for_susy_and_ro.py index 5936c49c07..79916526a8 100644 --- a/python/fedml/data/UCI/data_loader_for_susy_and_ro.py +++ b/python/fedml/data/UCI/data_loader_for_susy_and_ro.py @@ -5,7 +5,60 @@ class DataLoader(object): + """ + DataLoader class for managing data loading and preprocessing. + + Args: + data_name (str): The name of the dataset. + data_path (str): The path to the dataset CSV file. + client_list (list): A list of client IDs. + sample_num_in_total (int): The total number of data samples. + beta (float): A parameter for data loading. + + Attributes: + data_name (str): The name of the dataset. + data_path (str): The path to the dataset CSV file. + client_list (list): A list of client IDs. + sample_num_in_total (int): The total number of data samples. + beta (float): A parameter for data loading. + streaming_full_dataset_X (list): A list to store data samples. + streaming_full_dataset_Y (list): A list to store data labels. + StreamingDataDict (dict): A dictionary to store streaming data for clients. + + Methods: + load_datastream(): + Load and preprocess the data for streaming and return it as a dictionary. + load_adversarial_data(): + Load adversarial data based on the beta parameter. + load_stochastic_data(): + Load stochastic data based on the beta parameter. + read_csv_file(percent): + Read and return data samples and labels from a CSV file. + read_csv_file_for_cluster(percent): + Read and cluster data samples based on the beta parameter. + kMeans(X): + Perform K-means clustering on the data. + preprocessing(): + Perform preprocessing on the data. + + """ def __init__(self, data_name, data_path, client_list, sample_num_in_total, beta): + """ + Initialize the DataLoader with dataset information and parameters. + + Args: + data_name (str): The name of the dataset. + data_path (str): The path to the dataset CSV file. + client_list (list): A list of client IDs. + sample_num_in_total (int): The total number of data samples. + beta (float): A parameter for data loading. + + Note: + This constructor initializes the DataLoader with dataset details and parameters. + + Returns: + None + """ # SUSY, Room Occupancy; self.data_name = data_name self.data_path = data_path @@ -24,6 +77,12 @@ def __init__(self, data_name, data_path, client_list, sample_num_in_total, beta) """ def load_datastream(self): + """ + Load and preprocess the data for streaming and return it as a dictionary. + + Returns: + dict: A dictionary containing streaming data for clients. + """ self.preprocessing() self.load_adversarial_data() self.load_stochastic_data() @@ -37,14 +96,35 @@ def load_datastream(self): # beta (clustering, GMM) def load_adversarial_data(self): + """ + Load adversarial data based on the beta parameter. + + Returns: + dict: A dictionary containing adversarial streaming data for clients. + """ streaming_data = self.read_csv_file_for_cluster(self.beta) return streaming_data def load_stochastic_data(self): + """ + Load stochastic data based on the beta parameter. + + Returns: + dict: A dictionary containing stochastic streaming data for clients. + """ streaming_data = self.read_csv_file(self.beta) return streaming_data def read_csv_file(self, percent): + """ + Read and return data samples and labels from a CSV file. + + Args: + percent (float): The percentage of data to read. 
+ + Returns: + dict: A dictionary containing streaming data for clients. + """ # print("start from:") iteration_number = int(self.sample_num_in_total / len(self.client_list)) @@ -105,6 +185,15 @@ def read_csv_file(self, percent): return self.StreamingDataDict def read_csv_file_for_cluster(self, percent): + """ + Read and cluster data samples based on the beta parameter. + + Args: + percent (float): The percentage of data to read and cluster. + + Returns: + dict: A dictionary containing clustered streaming data for clients. + """ data = [] label = [] for client_id in self.client_list: @@ -134,11 +223,26 @@ def read_csv_file_for_cluster(self, percent): return self.StreamingDataDict def kMeans(self, X): + """ + Perform K-means clustering on the data. + + Args: + X (list): List of data samples. + + Returns: + array: Cluster labels for data samples. + """ kmeans = KMeans(n_clusters=len(self.client_list)) kmeans.fit(X) return kmeans.labels_ def preprocessing(self): + """ + Perform preprocessing on the data. + + Returns: + None + """ # print("sample_num_in_total = " + str(self.sample_num_in_total)) data = [] with open(self.data_path) as csvfile: diff --git a/python/fedml/data/lending_club_loan/lending_club_dataset.py b/python/fedml/data/lending_club_loan/lending_club_dataset.py index 15812e28d4..f93ac6f132 100644 --- a/python/fedml/data/lending_club_loan/lending_club_dataset.py +++ b/python/fedml/data/lending_club_loan/lending_club_dataset.py @@ -105,20 +105,45 @@ def normalize(x): + """ + Normalize a numerical array using StandardScaler. + + Args: + x (array-like): The data to normalize. + + Returns: + array-like: Normalized data. + """ scaler = StandardScaler() x_scaled = scaler.fit_transform(x) return x_scaled - def normalize_df(df): + """ + Normalize a DataFrame using StandardScaler. + + Args: + df (pd.DataFrame): The DataFrame to normalize. + + Returns: + pd.DataFrame: Normalized DataFrame. + """ column_names = df.columns x = df.values x_scaled = normalize(x) scaled_df = pd.DataFrame(data=x_scaled, columns=column_names) return scaled_df - def loan_condition(status): + """ + Determine if a loan is a good or bad loan based on its status. + + Args: + status (str): Loan status. + + Returns: + str: "Good Loan" or "Bad Loan". + """ bad_loan = [ "Charged Off", "Default", @@ -132,22 +157,45 @@ def loan_condition(status): else: return "Good Loan" - def compute_annual_income(row): + """ + Compute the annual income for a loan applicant. + + Args: + row (pd.Series): A row of loan data. + + Returns: + float: Annual income. + """ if row["verification_status"] == row["verification_status_joint"]: return row["annual_inc_joint"] return row["annual_inc"] - def determine_good_bad_loan(df_loan): - print("[INFO] determine good or bad loan") + """ + Determine if a loan is a good or bad loan based on its status. + Args: + df_loan (pd.DataFrame): DataFrame containing loan data. + + Returns: + pd.DataFrame: DataFrame with "target" column indicating loan condition. + """ + print("[INFO] determine good or bad loan") df_loan["target"] = np.nan df_loan["target"] = df_loan["loan_status"].apply(loan_condition) return df_loan - def determine_annual_income(df_loan): + """ + Determine annual income for loan applicants. + + Args: + df_loan (pd.DataFrame): DataFrame containing loan data. + + Returns: + pd.DataFrame: DataFrame with "annual_inc_comp" column for annual income. 
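+
+    Example (a sketch; ``df_loan`` is assumed to hold raw LendingClub rows):
+
+        df_loan = determine_annual_income(df_loan)
+        df_loan["annual_inc_comp"].head()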
+ """ print("[INFO] determine annual income") df_loan["annual_inc_comp"] = np.nan @@ -156,15 +204,32 @@ def determine_annual_income(df_loan): def determine_issue_year(df_loan): - print("[INFO] determine issue year") + """ + Determine the issue year of loans. - # transform the issue dates by year + Args: + df_loan (pd.DataFrame): DataFrame containing loan data. + + Returns: + pd.DataFrame: DataFrame with "issue_year" column for issue years. + """ + print("[INFO] determine issue year") + # Transform the issue dates by year dt_series = pd.to_datetime(df_loan["issue_d"]) df_loan["issue_year"] = dt_series.dt.year return df_loan def digitize_columns(data_frame): + """ + Digitize categorical columns in the DataFrame. + + Args: + data_frame (pd.DataFrame): The DataFrame to digitize. + + Returns: + pd.DataFrame: DataFrame with categorical columns converted to numerical values. + """ print("[INFO] digitize columns") data_frame = data_frame.replace( @@ -185,6 +250,15 @@ def digitize_columns(data_frame): def prepare_data(file_path): + """ + Prepare loan data from a CSV file. + + Args: + file_path (str): Path to the CSV file containing loan data. + + Returns: + pd.DataFrame: DataFrame with processed loan data. + """ print("[INFO] prepare loan data.") df_loan = pd.read_csv(file_path, low_memory=False) @@ -200,6 +274,15 @@ def prepare_data(file_path): def process_data(loan_df): + """ + Process loan data. + + Args: + loan_df (pd.DataFrame): DataFrame containing loan data. + + Returns: + pd.DataFrame: DataFrame with processed loan features and target. + """ loan_feat_df = loan_df[all_feature_list] loan_feat_df = loan_feat_df.fillna(-99) assert loan_feat_df.isnull().sum().sum() == 0 @@ -211,6 +294,15 @@ def process_data(loan_df): def load_processed_data(data_dir): + """ + Load processed loan data from a CSV file, or preprocess and save it if not available. + + Args: + data_dir (str): Directory path for data files. + + Returns: + pd.DataFrame: DataFrame with processed loan data. + """ file_path = data_dir + "processed_loan.csv" if os.path.exists(file_path): print(f"[INFO] load processed loan data from {file_path}") @@ -226,6 +318,15 @@ def load_processed_data(data_dir): def loan_load_two_party_data(data_dir): + """ + Load two-party loan data. + + Args: + data_dir (str): Directory path for data files. + + Returns: + tuple: Training and testing data for two parties. + """ print("[INFO] load two party data") processed_loan_df = load_processed_data(data_dir) party_a_feat_list = qualification_feat + loan_feat @@ -253,6 +354,15 @@ def loan_load_two_party_data(data_dir): def loan_load_three_party_data(data_dir): + """ + Load three-party loan data. + + Args: + data_dir (str): Directory path for data files. + + Returns: + tuple: Training and testing data for three parties (Party A, Party B, Party C). + """ print("[INFO] load three party data") processed_loan_df = load_processed_data(data_dir) party_a_feat_list = qualification_feat + loan_feat diff --git a/python/fedml/data/reddit/data_loader.py b/python/fedml/data/reddit/data_loader.py index 65dff93415..939a3fee0d 100644 --- a/python/fedml/data/reddit/data_loader.py +++ b/python/fedml/data/reddit/data_loader.py @@ -35,6 +35,22 @@ def load_partition_data_reddit( batch_size, n_proc_in_silo=0, ): + """ + Load and partition Reddit dataset for Federated Learning. + + Args: + args: An object containing configuration parameters. + dataset: The Reddit dataset. + data_dir: The directory containing the dataset. + partition_method: The method used for data partitioning. 
+ partition_alpha: A parameter for data partitioning. + client_number: The number of clients/partitions. + batch_size: The batch size for data loading. + n_proc_in_silo: The number of processes in the silo (default: 0). + + Returns: + tuple: A tuple containing various data components for Federated Learning. + """ from .nlp import load_and_cache_examples, mask_tokens from transformers import (AdamW, AlbertTokenizer, AutoConfig, diff --git a/python/fedml/data/reddit/datasets.py b/python/fedml/data/reddit/datasets.py index 24bf05c475..2199f67540 100644 --- a/python/fedml/data/reddit/datasets.py +++ b/python/fedml/data/reddit/datasets.py @@ -13,10 +13,53 @@ class Reddit_dataset(): + """ + Dataset class for Reddit data. + + Args: + root (str): The root directory where the data is stored. + train (bool): Whether to load the training or testing dataset. + + Attributes: + train_file (str): The file name for the training dataset. + test_file (str): The file name for the testing dataset. + vocab_tokens_size (int): The size of the token vocabulary. + vocab_tags_size (int): The size of the tag vocabulary. + raw_data (list): A list of tokenized text data. + dict (dict): A mapping dictionary from sample id to target tag. + + Methods: + __getitem__(self, index): + Get an item from the dataset by index. + __mapping_dict__(self): + Get the mapping dictionary. + __len__(self): + Get the length of the dataset. + raw_folder(self): + Get the raw data folder path. + processed_folder(self): + Get the processed data folder path. + class_to_idx(self): + Get a mapping from class names to class indices. + _check_exists(self): + Check if the dataset exists. + load_token_vocab(self, vocab_size, path): + Load token vocabulary from a file. + load_file(self, path, is_train): + Load the dataset from files. + + """ classes = [] MAX_SEQ_LEN = 20000 def __init__(self, root, train=True): + """ + Initialize the Reddit_dataset. + + Args: + root (str): The root directory where the data is stored. + train (bool): Whether to load the training or testing dataset. + """ self.train = train # training set or test set self.root = root @@ -61,34 +104,90 @@ def __getitem__(self, index): return tokens def __mapping_dict__(self): + """ + Get the mapping dictionary. + + Returns: + dict: A dictionary mapping sample IDs to target tags. + """ + return self.dict def __len__(self): + """ + Get the length of the dataset. + + Returns: + int: The number of samples in the dataset. + """ return len(self.raw_data) @property def raw_folder(self): + """ + Get the raw data folder path. + + Returns: + str: The path to the raw data folder. + """ return self.root @property def processed_folder(self): + """ + Get the processed data folder path. + + Returns: + str: The path to the processed data folder. + """ return self.root @property def class_to_idx(self): + """ + Get a mapping from class names to class indices. + + Returns: + dict: A dictionary mapping class names to class indices. + """ return {_class: i for i, _class in enumerate(self.classes)} def _check_exists(self): - return (os.path.exists(os.path.join(self.processed_folder, - self.data_file))) + """ + Check if the dataset exists. + + Returns: + bool: True if the dataset exists, False otherwise. + """ + return (os.path.exists(os.path.join(self.processed_folder, self.data_file))) def load_token_vocab(self, vocab_size, path): + """ + Load token vocabulary from a file. + + Args: + vocab_size (int): The size of the token vocabulary. + path (str): The path to the vocabulary file. 
+ + Returns: + list: A list of tokens from the vocabulary. + """ tokens_file = "reddit_vocab.pkl" with open(os.path.join(path, tokens_file), 'rb') as f: tokens = pickle.load(f) return tokens[:vocab_size] def load_file(self, path, is_train): + """ + Load the dataset from files. + + Args: + path (str): The path to the dataset files. + is_train (bool): Whether to load the training or testing dataset. + + Returns: + tuple: A tuple containing text data and a mapping dictionary. + """ file_name = os.path.join( path, 'train') if self.train else os.path.join(path, 'test') diff --git a/python/fedml/data/reddit/divide_data.py b/python/fedml/data/reddit/divide_data.py index 96f562c422..e4be2ed988 100644 --- a/python/fedml/data/reddit/divide_data.py +++ b/python/fedml/data/reddit/divide_data.py @@ -12,24 +12,129 @@ class Partition(object): - """ Dataset partitioning helper """ + """ + Helper class for dataset partitioning. + + Args: + data (list): The dataset to be partitioned. + index (list): A list of indices specifying the partition. + + Attributes: + data (list): The dataset to be partitioned. + index (list): A list of indices specifying the partition. + + Methods: + __len__(): + Get the length of the partition. + __getitem__(index): + Get an item from the partition by index. + + """ def __init__(self, data, index): + """ + Initialize a dataset partition. + + Args: + data (list): The dataset to be partitioned. + index (list): A list of indices specifying the partition. + + Returns: + None + """ self.data = data self.index = index def __len__(self): + """ + Get the length of the partition. + + Returns: + int: The length of the partition. + """ return len(self.index) def __getitem__(self, index): + """ + Get an item from the partition by index. + + Args: + index (int): The index of the item to retrieve. + + Returns: + object: The item from the partition. + """ data_idx = self.index[index] return self.data[data_idx] -class DataPartitioner(object): - """Partition data by trace or random""" +import csv +import logging +import numpy as np +from collections import defaultdict +from random import Random +class DataPartitioner(object): + """ + Partition data by trace or random for federated learning. + + Args: + data: The dataset to be partitioned. + args: An object containing configuration parameters. + numOfClass (int): The number of classes in the dataset (default: 0). + seed (int): The seed for randomization (default: 10). + isTest (bool): Whether the partitioning is for a test dataset (default: False). + + Attributes: + partitions (list): A list of partitions, where each partition is a list of sample indices. + rng (Random): A random number generator. + data: The dataset to be partitioned. + labels: The labels of the dataset. + args: An object containing configuration parameters. + isTest (bool): Whether the partitioning is for a test dataset. + data_len (int): The length of the dataset. + task: The task type. + numOfLabels (int): The number of labels in the dataset. + client_label_cnt (defaultdict): A dictionary to count labels for each client. + + Methods: + getNumOfLabels(): + Get the number of unique labels in the dataset. + getDataLen(): + Get the length of the dataset. + getClientLen(): + Get the number of clients/partitions. + getClientLabel(): + Get the number of unique labels for each client. + trace_partition(data_map_file): + Partition data based on a trace file. + partition_data_helper(num_clients, data_map_file=None): + Helper function for partitioning data. 
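Condensed, trace_partition reads a CSV with one row per sample (client id in the first column, label in the last) and turns it into per-client index lists. A self-contained sketch of the same idea:

    import csv

    def partitions_from_trace(csv_path):
        sample_to_client, client_ids = {}, {}
        with open(csv_path) as f:
            reader = csv.reader(f)
            next(reader)  # skip the header row
            for sample_id, row in enumerate(reader):
                cid = client_ids.setdefault(row[0], len(client_ids))
                sample_to_client[sample_id] = cid
        partitions = [[] for _ in range(len(client_ids))]
        for idx, cid in sample_to_client.items():
            partitions[cid].append(idx)
        return partitions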
+ uniform_partition(num_clients): + Uniformly partition data randomly. + use(partition, istest): + Get a partition of the dataset for a specific client. + getSize(): + Get the size of each partition (number of samples). + + """ def __init__(self, data, args, numOfClass=0, seed=10, isTest=False): + """ + Initialize the DataPartitioner. + + Args: + data: The dataset to be partitioned. + args: An object containing configuration parameters. + numOfClass (int): The number of classes in the dataset (default: 0). + seed (int): The seed for randomization (default: 10). + isTest (bool): Whether the partitioning is for a test dataset (default: False). + + Note: + This constructor sets up the DataPartitioner with the provided dataset and configuration. + + Returns: + None + """ self.partitions = [] self.rng = Random() self.rng.seed(seed) @@ -46,24 +151,57 @@ def __init__(self, data, args, numOfClass=0, seed=10, isTest=False): self.client_label_cnt = defaultdict(set) def getNumOfLabels(self): + """ + Get the number of unique labels in the dataset. + + Returns: + int: The number of unique labels. + """ return self.numOfLabels def getDataLen(self): + """ + Get the length of the dataset. + + Returns: + int: The length of the dataset. + """ return self.data_len def getClientLen(self): + """ + Get the number of clients/partitions. + + Returns: + int: The number of clients/partitions. + """ return len(self.partitions) def getClientLabel(self): + """ + Get the number of unique labels for each client. + + Returns: + list: A list of the number of unique labels for each client. + """ return [len(self.client_label_cnt[i]) for i in range(self.getClientLen())] def trace_partition(self, data_map_file): - """Read data mapping from data_map_file. Format: """ - logging.info(f"Partitioning data by profile {data_map_file}...") + """ + Partition data based on a trace file. + Args: + data_map_file (str): The path to the data mapping file. + + Returns: + None + """ + logging.info(f"Partitioning data by profile {data_map_file}...") + clientId_maps = {} unique_clientIds = {} - # load meta data from the data_map_file + + # Load meta data from the data_map_file with open(data_map_file) as csv_file: csv_reader = csv.reader(csv_file, delimiter=',') read_first = True @@ -80,8 +218,7 @@ def trace_partition(self, data_map_file): unique_clientIds[client_id] = len(unique_clientIds) clientId_maps[sample_id] = unique_clientIds[client_id] - self.client_label_cnt[unique_clientIds[client_id]].add( - row[-1]) + self.client_label_cnt[unique_clientIds[client_id]].add(row[-1]) sample_id += 1 # Partition data given mapping @@ -91,15 +228,33 @@ def trace_partition(self, data_map_file): self.partitions[clientId_maps[idx]].append(idx) def partition_data_helper(self, num_clients, data_map_file=None): + """ + Helper function for partitioning data. + + Args: + num_clients (int): The number of clients/partitions. + data_map_file (str): The path to the data mapping file (default: None). - # read mapping file to partition trace + Returns: + None + """ + # Read mapping file to partition trace if data_map_file is not None: self.trace_partition(data_map_file) else: self.uniform_partition(num_clients=num_clients) def uniform_partition(self, num_clients): - # random partition + """ + Uniformly partition data randomly. + + Args: + num_clients (int): The number of clients/partitions. 
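Wiring the pieces together, a hypothetical usage of DataPartitioner; the fields expected on args (e.g. task, batch_size, num_loaders) come from the caller's configuration and are assumptions here:

    # `train_dataset` and `args` are supplied by the surrounding application.
    partitioner = DataPartitioner(train_dataset, args, numOfClass=10, seed=10)
    partitioner.partition_data_helper(num_clients=100, data_map_file=None)  # random split
    print(partitioner.getClientLen(), partitioner.getSize()["size"][:5])
    loader = select_dataset(rank=1, partition=partitioner,
                            batch_size=args.batch_size, args=args)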
+ + Returns: + None + """ + # Random partition numOfLabels = self.getNumOfLabels() data_len = self.getDataLen() logging.info(f"Randomly partitioning data, {data_len} samples...") @@ -108,11 +263,21 @@ def uniform_partition(self, num_clients): self.rng.shuffle(indexes) for _ in range(num_clients): - part_len = int(1./num_clients * data_len) + part_len = int(1. / num_clients * data_len) self.partitions.append(indexes[0:part_len]) indexes = indexes[part_len:] def use(self, partition, istest): + """ + Get a partition of the dataset for a specific client. + + Args: + partition (int): The index of the client/partition. + istest (bool): Whether the partition is for a test dataset. + + Returns: + Partition: A partition of the dataset for the specified client. + """ resultIndex = self.partitions[partition] exeuteLength = len(resultIndex) if not istest else int( @@ -123,12 +288,31 @@ def use(self, partition, istest): return Partition(self.data, resultIndex) def getSize(self): - # return the size of samples + """ + Get the size of each partition (number of samples). + + Returns: + dict: A dictionary containing the size of each partition. + """ + # Return the size of samples return {'size': [len(partition) for partition in self.partitions]} def select_dataset(rank, partition, batch_size, args, isTest=False, collate_fn=None): - """Load data given client Id""" + """ + Load data for a specific client based on client ID. + + Args: + rank (int): The client's rank or ID. + partition (Partition): A partition of the dataset for the client. + batch_size (int): The batch size for data loading. + args: An object containing configuration parameters. + isTest (bool): Whether the data loading is for a test dataset (default: False). + collate_fn (callable, optional): A function used to collate data samples into batches (default: None). + + Returns: + DataLoader: A DataLoader object for loading the client's data. + """ partition = partition.use(rank - 1, isTest) dropLast = False if isTest else True num_loaders = min(int(len(partition)/args.batch_size/2), args.num_loaders) @@ -145,6 +329,3 @@ def select_dataset(rank, partition, batch_size, args, isTest=False, collate_fn=N return DataLoader(partition, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_loaders, drop_last=dropLast, collate_fn=collate_fn) return DataLoader(partition, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_loaders, drop_last=dropLast) - - - diff --git a/python/fedml/data/reddit/nlp.py b/python/fedml/data/reddit/nlp.py index 4711a4846f..a62a681289 100644 --- a/python/fedml/data/reddit/nlp.py +++ b/python/fedml/data/reddit/nlp.py @@ -44,13 +44,34 @@ def chunks_idx(l, n): + """ + Split a list into 'n' roughly equal-sized chunks and yield the start and end indices of each chunk. + + Args: + l (list): The list to be split. + n (int): The number of chunks to split the list into. + + Yields: + tuple: A tuple containing the start and end indices of each chunk. + """ d, r = divmod(len(l), n) for i in range(n): si = (d+1)*(i if i < r else r) + d*(0 if i < r else i - r) yield si, si+(d+1 if i < r else d) - def feature_creation_worker(files, tokenizer, block_size, worker_idx): + """ + Worker function for creating features from a list of text files. + + Args: + files (list): A list of file paths containing text data. + tokenizer: The tokenizer to convert text to tokens. + block_size (int): The maximum block size for tokenized text. + worker_idx (int): The index of the worker. 
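chunks_idx, documented above, is easiest to sanity-check by hand: it yields n contiguous (start, end) ranges whose sizes differ by at most one, with the remainder going to the leading chunks.

    files = [f"shard_{i}.txt" for i in range(10)]  # hypothetical file list
    print(list(chunks_idx(files, 3)))
    # -> [(0, 4), (4, 7), (7, 10)]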
+ + Returns: + tuple: A tuple containing examples (tokenized text), client mapping, and sample client IDs. + """ examples = [] sample_client = [] client_mapping = collections.defaultdict(list) @@ -83,8 +104,43 @@ def feature_creation_worker(files, tokenizer, block_size, worker_idx): class TextDataset(Dataset): + """ + Dataset for text data used in language modeling tasks. + + Args: + tokenizer: The tokenizer to convert text to tokens. + args: An object containing dataset configuration parameters. + file_path (str): The directory containing the dataset files. + block_size (int): The maximum block size for tokenized text (default: 512). + + Attributes: + examples (list): A list of tokenized text examples. + sample_client (list): A list of sample client IDs. + client_mapping (dict): A dictionary mapping client IDs to tokenized text examples. + + Methods: + __len__(): + Get the number of examples in the dataset. + __getitem__(item): + Get an example from the dataset. + + """ def __init__(self, tokenizer, args, file_path, block_size=512): + """ + Initialize the TextDataset. + + Args: + tokenizer: The tokenizer to convert text to tokens. + args: An object containing dataset configuration parameters. + file_path (str): The directory containing the dataset files. + block_size (int): The maximum block size for tokenized text (default: 512). + + Note: + This constructor processes and loads the dataset from files or creates features if not cached. + Returns: + None + """ block_size = block_size - \ (tokenizer.model_max_length - tokenizer.max_len_single_sentence) @@ -135,7 +191,7 @@ def __init__(self, tokenizer, args, file_path, block_size=512): self.client_mapping[true_user_id] = client_mapping[user_id] user_id_base = true_sample_client[-1] + 1 - # Note that we are loosing the last truncated example here for the sake of simplicity (no padding) + # Note that we are losing the last truncated example here for the sake of simplicity (no padding) # If your dataset is small, first you should look for a bigger one :-) and second you # can change this behavior by adding (model specific) padding. logger.info("Saving features into cached file %s", @@ -159,13 +215,39 @@ def __init__(self, tokenizer, args, file_path, block_size=512): self.targets = [0 for i in range(len(self.data))] def __len__(self): + """ + Get the number of examples in the dataset. + + Returns: + int: The number of examples in the dataset. + """ return len(self.examples) def __getitem__(self, item): + """ + Get an example from the dataset. + + Args: + item: The index of the example to retrieve. + + Returns: + torch.Tensor: The tokenized text example as a PyTorch tensor. + """ return torch.tensor(self.examples[item], dtype=torch.long) def load_and_cache_examples(args, tokenizer, evaluate=False): + """ + Load and cache examples from the dataset for training or evaluation. + + Args: + args: An object containing dataset configuration parameters. + tokenizer: The tokenizer to convert text to tokens. + evaluate (bool): Whether to load examples for evaluation (default: False). + + Returns: + TextDataset: A dataset containing tokenized text examples. + """ file_path = os.path.join(args.data_cache_dir, 'test') if evaluate else os.path.join( args.data_cache_dir, 'train') @@ -173,7 +255,18 @@ def load_and_cache_examples(args, tokenizer, evaluate=False): def mask_tokens(inputs, tokenizer, args, device='cpu') -> Tuple[torch.Tensor, torch.Tensor]: - """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. 
""" + """ + Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. + + Args: + inputs (torch.Tensor): The input token IDs. + tokenizer: The tokenizer to convert text to tokens. + args: An object containing configuration parameters. + device (str): The device to use for computations (default: 'cpu'). + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing masked input tokens and labels for masked language modeling. + """ labels = inputs.clone().to(device=device) # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) probability_matrix = torch.full( diff --git a/python/fedml/data/stackoverflow_lr/data_loader.py b/python/fedml/data/stackoverflow_lr/data_loader.py index 0aa5087017..6edd993bec 100644 --- a/python/fedml/data/stackoverflow_lr/data_loader.py +++ b/python/fedml/data/stackoverflow_lr/data_loader.py @@ -21,6 +21,19 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): + """ + Get DataLoader objects for training and testing data. + + Args: + dataset: The dataset to use. + data_dir (str): The directory containing the data. + train_bs (int): The batch size for training. + test_bs (int): The batch size for testing. + client_idx (int, optional): The client index (None for global data). + + Returns: + tuple: A tuple containing training and testing DataLoader objects. + """ if client_idx is None: train_dl = data.DataLoader( @@ -94,6 +107,18 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): def load_partition_data_distributed_federated_stackoverflow_lr( process_id, dataset, data_dir, batch_size=DEFAULT_BATCH_SIZE ): + """ + Load partitioned data for distributed federated stackoverflow_lr. + + Args: + process_id (int): The process ID. + dataset: The dataset to use. + data_dir (str): The directory containing the data. + batch_size (int, optional): The batch size (default is 64). + + Returns: + tuple: A tuple containing data for distributed federated stackoverflow_lr. + """ # get global dataset if process_id == 0: train_data_global, test_data_global = get_dataloader( @@ -131,6 +156,17 @@ def load_partition_data_distributed_federated_stackoverflow_lr( def load_partition_data_federated_stackoverflow_lr( dataset, data_dir, batch_size=DEFAULT_BATCH_SIZE ): + """ + Load partitioned data for federated stackoverflow_lr. + + Args: + dataset: The dataset to use. + data_dir (str): The directory containing the data. + batch_size (int, optional): The batch size (default is 64). + + Returns: + tuple: A tuple containing data for federated stackoverflow_lr. + """ logging.info("load_partition_data_federated_stackoverflow_lr START") global cache_data diff --git a/python/fedml/data/stackoverflow_lr/dataset.py b/python/fedml/data/stackoverflow_lr/dataset.py index 7d7dab6ecb..3b312d90c5 100644 --- a/python/fedml/data/stackoverflow_lr/dataset.py +++ b/python/fedml/data/stackoverflow_lr/dataset.py @@ -4,7 +4,15 @@ class StackOverflowDataset(data.Dataset): - """StackOverflow dataset""" + """ + StackOverflow dataset. + + Args: + h5_path (str): Path to the h5 file. + client_idx (int): Index of the train file. + datast (str): "train" or "test" denoting the train set or test set. + preprocess (dict of callable, optional): Optional preprocessing functions with keys "input" and "target". 
+ """ __train_client_id_list = None __test_client_id_list = None @@ -33,6 +41,12 @@ def __init__(self, h5_path, client_idx, datast, preprocess=None): self.target_fn = preprocess["target"] def get_client_id_list(self): + """ + Get a list of client IDs based on the dataset type. + + Returns: + list: List of client IDs. + """ if self.datast == "train": if StackOverflowDataset.__train_client_id_list is None: with h5py.File(self.h5_path, "r") as h5_file: diff --git a/python/fedml/data/stackoverflow_lr/utils.py b/python/fedml/data/stackoverflow_lr/utils.py index 7bcb011106..5fbf007e60 100644 --- a/python/fedml/data/stackoverflow_lr/utils.py +++ b/python/fedml/data/stackoverflow_lr/utils.py @@ -17,6 +17,15 @@ def get_word_count_file(data_dir): + """ + Get the path to the word count file. + + Args: + data_dir (str): The directory where the file is located. + + Returns: + str: The full path to the word count file. + """ # word_count_file_path global word_count_file_path if word_count_file_path is None: @@ -25,6 +34,15 @@ def get_word_count_file(data_dir): def get_tag_count_file(data_dir): + """ + Get the path to the tag count file. + + Args: + data_dir (str): The directory where the file is located. + + Returns: + str: The full path to the tag count file. + """ # tag_count_file_path global tag_count_file_path if tag_count_file_path is None: @@ -33,6 +51,16 @@ def get_tag_count_file(data_dir): def get_most_frequent_words(data_dir=None, vocab_size=10000): + """ + Get a list of the most frequent words. + + Args: + data_dir (str, optional): The directory where the word count file is located. + vocab_size (int, optional): The number of most frequent words to retrieve. + + Returns: + list: A list of the most frequent words. + """ frequent_words = [] with open(get_word_count_file(data_dir), "r") as f: frequent_words = [next(f).split()[0] for i in range(vocab_size)] @@ -40,12 +68,31 @@ def get_most_frequent_words(data_dir=None, vocab_size=10000): def get_tags(data_dir=None, tag_size=500): + """ + Get a list of tags. + + Args: + data_dir (str, optional): The directory where the tag count file is located. + tag_size (int, optional): The number of tags to retrieve. + + Returns: + list: A list of tags. + """ f = open(get_tag_count_file(data_dir), "r") frequent_tags = json.load(f) return list(frequent_tags.keys())[:tag_size] def get_word_dict(data_dir): + """ + Get a dictionary that maps words to their IDs. + + Args: + data_dir (str): The directory where the word count file is located. + + Returns: + collections.OrderedDict: A dictionary mapping words to their IDs. + """ global word_dict if word_dict == None: words = get_most_frequent_words(data_dir) @@ -56,6 +103,15 @@ def get_word_dict(data_dir): def get_tag_dict(data_dir): + """ + Get a dictionary that maps tags to their IDs. + + Args: + data_dir (str): The directory where the tag count file is located. + + Returns: + collections.OrderedDict: A dictionary mapping tags to their IDs. + """ global tag_dict if tag_dict == None: tags = get_tags(data_dir) @@ -66,6 +122,16 @@ def get_tag_dict(data_dir): def preprocess_inputs(sentences, data_dir): + """ + Preprocess a list of sentences into a bag-of-words representation. + + Args: + sentences (list): List of sentences to preprocess. + data_dir (str): The directory where the word count file is located. + + Returns: + list: List of preprocessed bag-of-words representations. 
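The bag-of-words preprocessing described above reduces each sentence to a multi-hot vector over the fixed vocabulary. A sketch, assuming a token-to-id dict and silently dropping out-of-vocabulary words:

    import numpy as np

    def to_bag_of_words_sketch(tokens, token_to_id, vocab_size):
        vec = np.zeros(vocab_size, dtype=np.float32)
        for tok in tokens:
            idx = token_to_id.get(tok)
            if idx is not None:
                vec[idx] = 1.0  # presence, not counts
        return vec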
+ """ sentences = [sentence.split(" ") for sentence in sentences] vocab_size = len(get_word_dict(data_dir)) @@ -87,6 +153,16 @@ def to_bag_of_words(sentence): def preprocess_targets(tags, data_dir): + """ + Preprocess a list of tags into a bag-of-words representation. + + Args: + tags (list): List of tags to preprocess. + data_dir (str): The directory where the tag count file is located. + + Returns: + list: List of preprocessed bag-of-words representations. + """ tags = [tag.split("|") for tag in tags] tag_size = len(get_tag_dict(data_dir)) @@ -129,6 +205,16 @@ def to_bag_of_words(sentence): def preprocess_target(tag, data_dir): + """ + Preprocess a single sentence into a bag-of-words representation. + + Args: + sentence (str): The sentence to preprocess. + data_dir (str): The directory where the word count file is located. + + Returns: + numpy.ndarray: Preprocessed bag-of-words representation. + """ tag = tag.split("|") tag_size = len(get_tag_dict(data_dir)) diff --git a/python/fedml/data/stackoverflow_nwp/data_loader.py b/python/fedml/data/stackoverflow_nwp/data_loader.py index c1a1b2d008..ccc8bf8250 100644 --- a/python/fedml/data/stackoverflow_nwp/data_loader.py +++ b/python/fedml/data/stackoverflow_nwp/data_loader.py @@ -21,6 +21,19 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): + """ + Get data loaders for training and testing. + + Args: + dataset: The dataset object. + data_dir (str): The directory containing the data. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + client_idx (int or None): Index of the client (None for global dataset). + + Returns: + tuple: A tuple containing train and test data loaders (train_dl, test_dl). + """ def _tokenizer(x): return utils.tokenizer(x, data_dir) @@ -79,6 +92,19 @@ def _tokenizer(x): def load_partition_data_distributed_federated_stackoverflow_nwp( process_id, dataset, data_dir, batch_size=DEFAULT_BATCH_SIZE ): + """ + Load partitioned data for distributed federated StackOverflow NWP. + + Args: + process_id (int): The process ID or rank. + dataset: The dataset object. + data_dir (str): The directory containing the data. + batch_size (int): Batch size. + + Returns: + tuple: A tuple containing client number, train data number, global train data, + global test data, local data number, local train data, local test data, and vocabulary length. + """ # get global dataset if process_id == 0: diff --git a/python/fedml/data/stackoverflow_nwp/dataset.py b/python/fedml/data/stackoverflow_nwp/dataset.py index 8ebf6f07a7..c9ae22fdbc 100644 --- a/python/fedml/data/stackoverflow_nwp/dataset.py +++ b/python/fedml/data/stackoverflow_nwp/dataset.py @@ -4,29 +4,44 @@ class StackOverflowDataset(data.Dataset): - """StackOverflow dataset""" + """ + StackOverflow dataset. + + Args: + h5_path (str): Path to the h5 file. + client_idx (int): Index of the train file. + datast (str): "train" or "test" denoting the train set or test set. + preprocess (callable, optional): Optional preprocessing function. + + Attributes: + _EXAMPLE (str): Name of the "examples" attribute in the h5 file. + _TOKENS (str): Name of the "tokens" attribute in the h5 file. 
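For the next-word-prediction data that follows, each sample is a token sequence shifted against itself, exactly the (sample[:-1], sample[1:]) pair returned by __getitem__ further below:

    import numpy as np

    tokens = np.array([11, 7, 42, 9, 23])  # hypothetical token ids
    x, y = tokens[:-1], tokens[1:]
    # x -> [11, 7, 42, 9], y -> [7, 42, 9, 23]: position t predicts token t+1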
+ + """ __train_client_id_list = None __test_client_id_list = None - def __init__(self, h5_path, client_idx, datast, preprocess): - """ - Args: - h5_path (string) : path to the h5 file - client_idx (idx) : index of train file - datast (string) : "train" or "test" denoting on train set or test set - preprocess (callable, optional) : Optional preprocessing - """ - + def __init__(self, h5_path, client_idx, datast, preprocess=None): self._EXAMPLE = "examples" self._TOKENS = "tokens" self.h5_path = h5_path self.datast = datast - self.client_id = self.get_client_id_list()[client_idx] # pylint: disable=E1136 + self.client_id = self.get_client_id_list()[client_idx] + self.preprocess = preprocess def get_client_id_list(self): + """ + Get the list of client IDs for the specified dataset. + + Returns: + list: List of client IDs. + + Raises: + Exception: If an invalid dataset is specified. + """ if self.datast == "train": if StackOverflowDataset.__train_client_id_list is None: with h5py.File(self.h5_path, "r") as h5_file: @@ -42,7 +57,7 @@ def get_client_id_list(self): ) return StackOverflowDataset.__test_client_id_list else: - raise Exception("Please specify either train or test set!") + raise Exception("Please specify either 'train' or 'test' set!") def __len__(self): with h5py.File(self.h5_path, "r") as h5_file: @@ -50,8 +65,7 @@ def __len__(self): def __getitem__(self, idx): with h5py.File(self.h5_path, "r") as h5_file: - sample = h5_file[self._EXAMPLE][self.client_id][self._TOKENS][()][ - idx - ].decode("utf8") - sample = self.preprocess(sample) + sample = h5_file[self._EXAMPLE][self.client_id][self._TOKENS][()][idx].decode("utf8") + if self.preprocess is not None: + sample = self.preprocess(sample) return np.asarray(sample[:-1]), np.asarray(sample[1:]) diff --git a/python/fedml/data/synthetic_0.5_0.5/generate_synthetic.py b/python/fedml/data/synthetic_0.5_0.5/generate_synthetic.py index 014587f6d6..bde7ab7304 100644 --- a/python/fedml/data/synthetic_0.5_0.5/generate_synthetic.py +++ b/python/fedml/data/synthetic_0.5_0.5/generate_synthetic.py @@ -8,12 +8,35 @@ def softmax(x): + """ + Compute the softmax function for an array of values. + + Args: + x (numpy.ndarray): Input array. + + Returns: + numpy.ndarray: Softmax probabilities for the input array. + """ ex = np.exp(x) sum_ex = np.sum(np.exp(x)) return ex / sum_ex def generate_synthetic(alpha, beta, iid): + """ + Generate synthetic data for federated learning. + + Args: + NUM_USER (int): Number of users/clients. + alpha (float): Mean of the normal distribution for generating model weights. + beta (float): Mean of the normal distribution for generating model bias. + iid (int): Indicator for generating independent (1) or non-independent (0) data. + + Returns: + tuple: A tuple containing synthetic data for X (features) and y (labels). + - X_split (list): List of lists containing feature data for each user. + - y_split (list): List of lists containing label data for each user. + """ dimension = 60 NUM_CLASS = 10 np.random.seed(0) diff --git a/python/fedml/data/synthetic_0_0/generate_synthetic.py b/python/fedml/data/synthetic_0_0/generate_synthetic.py index 53c544944d..be1f67efc8 100644 --- a/python/fedml/data/synthetic_0_0/generate_synthetic.py +++ b/python/fedml/data/synthetic_0_0/generate_synthetic.py @@ -8,12 +8,35 @@ def softmax(x): + """ + Compute the softmax function for an array of values. + + Args: + x (numpy.ndarray): Input array. + + Returns: + numpy.ndarray: Softmax probabilities for the input array. 
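One caution on the softmax helper added here: np.exp overflows for large inputs. A numerically stable variant, offered as a side note rather than a change to the patch:

    import numpy as np

    def softmax_stable(x):
        # Subtracting the max leaves the result unchanged but keeps np.exp finite.
        z = np.exp(x - np.max(x))
        return z / np.sum(z)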
+ """ ex = np.exp(x) sum_ex = np.sum(np.exp(x)) return ex / sum_ex def generate_synthetic(alpha, beta, iid): + """ + Generate synthetic data for federated learning. + + Args: + NUM_USER (int): Number of users/clients. + alpha (float): Mean of the normal distribution for generating model weights. + beta (float): Mean of the normal distribution for generating model bias. + iid (int): Indicator for generating independent (1) or non-independent (0) data. + + Returns: + tuple: A tuple containing synthetic data for X (features) and y (labels). + - X_split (list): List of lists containing feature data for each user. + - y_split (list): List of lists containing label data for each user. + """ dimension = 60 NUM_CLASS = 10 np.random.seed(0) diff --git a/python/fedml/data/synthetic_1_1/data_loader.py b/python/fedml/data/synthetic_1_1/data_loader.py index 021963596c..5ad41fed63 100644 --- a/python/fedml/data/synthetic_1_1/data_loader.py +++ b/python/fedml/data/synthetic_1_1/data_loader.py @@ -14,8 +14,28 @@ def load_partition_data_federated_synthetic_1_1( - data_dir=None, batch_size=DEFAULT_BATCH_SIZE + train_file_path, test_file_path, batch_size=DEFAULT_BATCH_SIZE ): + """ + Load federated synthetic data for training and testing. + + Args: + train_file_path (str): Path to the training data JSON file. + test_file_path (str): Path to the testing data JSON file. + batch_size (int): Batch size for data loaders. + + Returns: + tuple: A tuple containing client and data-related information: + - client_num (int): Number of clients. + - train_data_num (int): Number of samples in the global training dataset. + - test_data_num (int): Number of samples in the global testing dataset. + - train_data_global (torch.utils.data.DataLoader): DataLoader for the global training dataset. + - test_data_global (torch.utils.data.DataLoader): DataLoader for the global testing dataset. + - data_local_num_dict (dict): Dictionary containing the number of samples for each client. + - train_data_local_dict (dict): Dictionary of DataLoader objects for local training data. + - test_data_local_dict (dict): Dictionary of DataLoader objects for local testing data. + - output_dim (int): Dimension of the output (e.g., number of classes). + """ logging.info("load_partition_data_federated_synthetic_1_1 START") with open(train_file_path, "r") as train_f, open(test_file_path, "r") as test_f: @@ -118,7 +138,13 @@ def load_partition_data_federated_synthetic_1_1( ) -def test_data_loader(): +def test_data_loader(train_file_path): + """ + Test the data loader function by comparing the number of samples with the original data. + + Args: + train_file_path (str): Path to the training data JSON file. + """ ( client_num, train_data_num, @@ -129,9 +155,10 @@ def test_data_loader(): train_data_local_dict, test_data_local_dict, output_dim, - ) = load_partition_data_federated_synthetic_1_1() - f = open(train_file_path, "r") - train_data = json.load(f) + ) = load_partition_data_federated_synthetic_1_1(train_file_path, train_file_path) + + with open(train_file_path, "r") as f: + train_data = json.load(f) assert train_data["num_samples"] == list(data_local_num_dict.values()) diff --git a/python/fedml/data/synthetic_1_1/generate_synthetic.py b/python/fedml/data/synthetic_1_1/generate_synthetic.py index 9a46ad43ca..edc2b15b65 100644 --- a/python/fedml/data/synthetic_1_1/generate_synthetic.py +++ b/python/fedml/data/synthetic_1_1/generate_synthetic.py @@ -8,23 +8,43 @@ def softmax(x): + """ + Compute the softmax values for a given array. 
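The generate_synthetic functions share one generative model: per-user linear weights and Gaussian features, with the label taken as the argmax of the softmax scores. A condensed single-user sketch (the originals additionally draw a decaying diagonal feature covariance, per-user means, and lognormal sample counts):

    import numpy as np

    def generate_one_user(n, dimension=60, num_class=10, alpha=0.0, beta=0.0, seed=0):
        rng = np.random.default_rng(seed)
        W = rng.normal(alpha, 1, (dimension, num_class))
        b = rng.normal(alpha, 1, num_class)
        mean_x = rng.normal(beta, 1, dimension)
        xx = rng.multivariate_normal(mean_x, np.eye(dimension), n)
        yy = np.argmax(xx @ W + b, axis=1)  # argmax of softmax == argmax of logits
        return xx.tolist(), yy.tolist()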
+ + Args: + x (numpy.ndarray): Input array. + + Returns: + numpy.ndarray: Softmax values. + """ ex = np.exp(x) sum_ex = np.sum(np.exp(x)) return ex / sum_ex def generate_synthetic(alpha, beta, iid): + """ + Generate synthetic data for federated learning. + + Args: + alpha (float): Mean of user weights. + beta (float): Mean of user biases. + iid (int): Unused parameter. + + Returns: + list: List of user data samples. + list: List of labels for user data samples. + """ dimension = 60 NUM_CLASS = 10 np.random.seed(0) samples_per_user = np.random.lognormal(4, 2, (NUM_USER)).astype(int) + 50 - print(samples_per_user) - # num_samples = np.sum(samples_per_user) + X_split = [[] for _ in range(NUM_USER)] y_split = [[] for _ in range(NUM_USER)] - #### define some eprior #### + mean_W = np.random.normal(0, alpha, NUM_USER) mean_b = mean_W B = np.random.normal(0, beta, NUM_USER) @@ -37,7 +57,7 @@ def generate_synthetic(alpha, beta, iid): for i in range(NUM_USER): mean_x[i] = np.random.normal(B[i], 1, dimension) - # print(mean_x[i]) + for i in range(NUM_USER): @@ -54,7 +74,7 @@ def generate_synthetic(alpha, beta, iid): X_split[i] = xx.tolist() y_split[i] = yy.tolist() - print("{}-th users has {} exampls".format(i, len(y_split[i]))) + return X_split, y_split diff --git a/python/fedml/data/synthetic_1_1/stats.py b/python/fedml/data/synthetic_1_1/stats.py index 5328562587..43f57bccd0 100755 --- a/python/fedml/data/synthetic_1_1/stats.py +++ b/python/fedml/data/synthetic_1_1/stats.py @@ -23,6 +23,16 @@ def load_data(name): + """ + Load user and sample data from JSON files in a dataset directory. + + Args: + name (str): The name of the dataset. + + Returns: + list: List of user names. + list: List of the number of samples per user. + """ users = [] num_samples = [] @@ -57,6 +67,15 @@ def load_data(name): def print_dataset_stats(name): + """ + Print statistics of a dataset, including the number of users and samples. + + Args: + name (str): The name of the dataset. + + Returns: + None + """ users, num_samples = load_data(name) num_users = len(users) From 4057c0f8e840ea22aae3b7d3b86cff5d75de0ba0 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Mon, 18 Sep 2023 11:41:09 +0530 Subject: [PATCH 26/70] add --- .../fedml/data/FederatedEMNIST/data_loader.py | 39 ++++ python/fedml/data/cinic10/data_loader.py | 157 +++++++++++++- python/fedml/data/cinic10/datasets.py | 20 +- .../data/edge_case_examples/data_loader.py | 100 ++++++++- .../fedml/data/edge_case_examples/datasets.py | 157 ++++++++++++-- python/fedml/data/fed_cifar100/data_loader.py | 48 ++++- python/fedml/data/fed_cifar100/utils.py | 23 ++- .../fedml/data/fed_shakespeare/data_loader.py | 65 ++++-- python/fedml/data/fed_shakespeare/utils.py | 65 +++++- .../base/data_manager/base_data_manager.py | 111 +++++++++- .../base/preprocess/base_preprocessor.py | 30 +++ .../base/raw_data/base_raw_data_loader.py | 192 ++++++++++++++++++ .../data/fednlp/base/raw_data/partition.py | 21 ++ python/fedml/data/fednlp/base/utils.py | 92 ++++++++- 14 files changed, 1057 insertions(+), 63 deletions(-) diff --git a/python/fedml/data/FederatedEMNIST/data_loader.py b/python/fedml/data/FederatedEMNIST/data_loader.py index c8160b2a7f..fb4c1414b5 100644 --- a/python/fedml/data/FederatedEMNIST/data_loader.py +++ b/python/fedml/data/FederatedEMNIST/data_loader.py @@ -21,6 +21,19 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): + """ + Create data loaders for training and testing data. + + Args: + dataset (str): The dataset name. 
+ data_dir (str): The directory where the dataset is located. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + client_idx (int or None): Index of the client to load data for. If None, load data for all clients. + + Returns: + tuple: A tuple containing the training and testing data loaders. + """ train_h5 = h5py.File(os.path.join(data_dir, DEFAULT_TRAIN_FILE), "r") test_h5 = h5py.File(os.path.join(data_dir, DEFAULT_TEST_FILE), "r") @@ -76,6 +89,19 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): def load_partition_data_distributed_federated_emnist( process_id, dataset, data_dir, batch_size=DEFAULT_BATCH_SIZE ): + """ + Load partitioned data for a federated EMNIST dataset. + + Args: + process_id (int): The ID of the current process (0 for server, >0 for clients). + dataset (str): The dataset name. + data_dir (str): The directory where the dataset is located. + batch_size (int): Batch size for data loaders. + + Returns: + tuple: A tuple containing information about the dataset, including the number of clients, + the number of samples in the global training data, global and local data loaders, and class number. + """ if process_id == 0: # get global dataset @@ -133,6 +159,19 @@ def load_partition_data_distributed_federated_emnist( def load_partition_data_federated_emnist( dataset, data_dir, batch_size=DEFAULT_BATCH_SIZE ): + """ + Load partitioned data for federated EMNIST dataset. + + Args: + dataset (str): The dataset name. + data_dir (str): The directory where the dataset is located. + batch_size (int): Batch size for data loaders. + + Returns: + tuple: A tuple containing information about the dataset, including the number of clients, + the number of samples in the global training and testing data, global and local data loaders, + the number of samples per client, and the class number. + """ # client ids train_file_path = os.path.join(data_dir, DEFAULT_TRAIN_FILE) diff --git a/python/fedml/data/cinic10/data_loader.py b/python/fedml/data/cinic10/data_loader.py index 10515a1569..13a28552c4 100644 --- a/python/fedml/data/cinic10/data_loader.py +++ b/python/fedml/data/cinic10/data_loader.py @@ -14,6 +14,15 @@ def read_data_distribution( filename="./data_preprocessing/non-iid-distribution/CIFAR10/distribution.txt", ): + """ + Reads the data distribution from a text file. + + Args: + filename (str): The path to the distribution file. + + Returns: + dict: A dictionary representing the data distribution. + """ distribution = {} with open(filename, "r") as data: for x in data.readlines(): @@ -33,6 +42,15 @@ def read_data_distribution( def read_net_dataidx_map( filename="./data_preprocessing/non-iid-distribution/CIFAR10/net_dataidx_map.txt", ): + """ + Reads the network data index map from a text file. + + Args: + filename (str): The path to the network data index map file. + + Returns: + dict: A dictionary mapping network IDs to data indices. + """ net_dataidx_map = {} with open(filename, "r") as data: for x in data.readlines(): @@ -48,6 +66,16 @@ def read_net_dataidx_map( def record_net_data_stats(y_train, net_dataidx_map): + """ + Records network-specific data statistics. + + Args: + y_train (numpy.ndarray): Array of ground truth labels for the entire dataset. + net_dataidx_map (dict): A dictionary mapping network IDs to data indices. + + Returns: + dict: A dictionary containing network-specific class counts. 
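record_net_data_stats, documented above, is easiest to read from a toy example:

    import numpy as np

    y_train = np.array([0, 1, 1, 2, 0, 2, 2])          # toy labels
    net_dataidx_map = {0: [0, 1, 2], 1: [3, 4, 5, 6]}  # two clients
    print(record_net_data_stats(y_train, net_dataidx_map))
    # -> {0: {0: 1, 1: 2}, 1: {0: 1, 2: 3}}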
+    """
     net_cls_counts = {}

     for net_i, dataidx in net_dataidx_map.items():
@@ -63,6 +91,15 @@ def __init__(self, length):
         self.length = length

     def __call__(self, img):
+        """
+        Applies the Cutout augmentation to an image.
+
+        Args:
+            img (torch.Tensor): The input image.
+
+        Returns:
+            torch.Tensor: The image with the Cutout augmentation applied.
+        """
         h, w = img.size(1), img.size(2)
         mask = np.ones((h, w), np.float32)
         y = np.random.randint(h)
@@ -81,8 +118,15 @@ def __call__(self, img):


 def _data_transforms_cinic10():
+    """
+    Define data transformations for the CINIC-10 dataset.
+
+    Returns:
+        tuple: A tuple containing two transformation functions, one for training and one for validation/test.
+    """
     cinic_mean = [0.47889522, 0.47227842, 0.43047404]
     cinic_std = [0.24205776, 0.23828046, 0.25874835]
+
     # Transformer for train set: random crops and horizontal flip
     train_transform = transforms.Compose(
         [
@@ -120,6 +164,15 @@ def _data_transforms_cinic10():


 def load_cinic10_data(datadir):
+    """
+    Load CINIC-10 data from the specified directory.
+
+    Args:
+        datadir (str): The directory containing CINIC-10 data.
+
+    Returns:
+        tuple: A tuple containing training and testing data.
+    """
     _train_dir = datadir + str("/train")
     logging.info("_train_dir = " + str(_train_dir))
     _test_dir = datadir + str("/test")
@@ -168,6 +221,19 @@ def load_cinic10_data(datadir):


 def partition_data(dataset, datadir, partition, n_nets, alpha):
+    """
+    Partition the dataset into subsets for federated learning.
+
+    Args:
+        dataset: The dataset to be partitioned.
+        datadir (str): The directory containing the dataset.
+        partition (str): The type of partitioning to be applied ("homo", "hetero", "hetero-fix").
+        n_nets (int): The number of clients (networks) to partition the data for.
+        alpha (float): A hyperparameter controlling the heterogeneity of the data partition.
+
+    Returns:
+        tuple: (X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts).
+    """
     logging.info("*********partition data***************")
     X_train, y_train, X_test, y_test = load_cinic10_data(datadir)
@@ -176,7 +242,7 @@ def partition_data(dataset, datadir, partition, n_nets, alpha):
     y_train = np.array(y_train)
     y_test = np.array(y_test)
     n_train = len(X_train)
-    # n_test = len(X_test)
+

     if partition == "homo":
         total_num = n_train
@@ -193,12 +259,12 @@ def partition_data(dataset, datadir, partition, n_nets, alpha):

         while min_size < 10:
             idx_batch = [[] for _ in range(n_nets)]
-            # for each class in the dataset
+
             for k in range(K):
                 idx_k = np.where(y_train == k)[0]
                 np.random.shuffle(idx_k)
                 proportions = np.random.dirichlet(np.repeat(alpha, n_nets))
-                ## Balance
+
                 proportions = np.array(
                     [
                         p * (len(idx_j) < N / n_nets)
@@ -234,21 +300,60 @@ def partition_data(dataset, datadir, partition, n_nets, alpha):
     return X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts


-# for centralized training
+
 def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None):
+    """
+    Get data loaders for centralized training using the CINIC-10 dataset.
+
+    Args:
+        dataset (str): The dataset name.
+        datadir (str): The directory containing the dataset.
+        train_bs (int): Batch size for training data.
+        test_bs (int): Batch size for testing data.
+        dataidxs (list, optional): List of data indices to use for training. Default is None.
+
+    Returns:
+        tuple: A tuple containing the training and testing data loaders.
+    """
     return get_dataloader_cinic10(datadir, train_bs, test_bs, dataidxs)


-# for local devices
+
 def get_dataloader_test(
     dataset, datadir, train_bs, test_bs, dataidxs_train, dataidxs_test
 ):
+    """
+    Get data loaders for testing on local devices using the CINIC-10 dataset.
+
+    Args:
+        dataset (str): The dataset name.
+        datadir (str): The directory containing the dataset.
+        train_bs (int): Batch size for training data.
+        test_bs (int): Batch size for testing data.
+        dataidxs_train (list): List of data indices to use for training.
+        dataidxs_test (list): List of data indices to use for testing.
+
+    Returns:
+        tuple: A tuple containing the training and testing data loaders for local devices.
+    """
     return get_dataloader_test_cinic10(
         datadir, train_bs, test_bs, dataidxs_train, dataidxs_test
     )


 def get_dataloader_cinic10(datadir, train_bs, test_bs, dataidxs=None):
+    """
+    Get data loaders for centralized training using the CINIC-10 dataset.
+
+    Args:
+        datadir (str): The directory containing the dataset.
+        train_bs (int): Batch size for training data.
+        test_bs (int): Batch size for testing data.
+        dataidxs (list, optional): List of data indices to use for training. Default is None.
+
+    Returns:
+        tuple: A tuple containing the training and testing data loaders.
+    """
     dl_obj = ImageFolderTruncated

     transform_train, transform_test = _data_transforms_cinic10()
@@ -272,6 +377,19 @@ def get_dataloader_cinic10(datadir, train_bs, test_bs, dataidxs=None):
 def get_dataloader_test_cinic10(
     datadir, train_bs, test_bs, dataidxs_train=None, dataidxs_test=None
 ):
+    """
+    Get data loaders for testing on local devices using the CINIC-10 dataset.
+
+    Args:
+        datadir (str): The directory containing the dataset.
+        train_bs (int): Batch size for training data.
+        test_bs (int): Batch size for testing data.
+        dataidxs_train (list): List of data indices to use for training.
+        dataidxs_test (list): List of data indices to use for testing.
+
+    Returns:
+        tuple: A tuple containing the training and testing data loaders for local devices.
+    """
     dl_obj = ImageFolderTruncated

     transform_train, transform_test = _data_transforms_cinic10()
@@ -301,6 +419,21 @@ def load_partition_data_distributed_cinic10(
     client_number,
     batch_size,
 ):
+    """
+    Load partitioned data for distributed training using the CINIC-10 dataset.
+
+    Args:
+        process_id (int): The ID of the current process.
+        dataset (str): The dataset name.
+        data_dir (str): The directory containing the dataset.
+        partition_method (str): The data partitioning method (e.g., 'homo' or 'hetero').
+        partition_alpha (float): The alpha parameter for data partitioning.
+        client_number (int): The number of clients.
+        batch_size (int): Batch size for data loaders.
+
+    Returns:
+        tuple: A tuple containing training and testing data information for distributed training.
+    """
     (
         X_train,
         y_train,
         X_test,
         y_test,
         net_dataidx_map,
         traindata_cls_counts,
@@ -360,6 +493,20 @@ def load_partition_data_distributed_cinic10(
 def load_partition_data_cinic10(
     dataset, data_dir, partition_method, partition_alpha, client_number, batch_size
 ):
+    """
+    Load partitioned CINIC-10 data for federated training.
+
+    Args:
+        dataset (str): The dataset name.
+        data_dir (str): The directory containing the dataset.
+        partition_method (str): The data partitioning method (e.g., 'homo' or 'hetero').
+        partition_alpha (float): The alpha parameter for data partitioning.
+        client_number (int): The number of clients.
+        batch_size (int): Batch size for data loaders.
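The "hetero" branch of partition_data above is the usual Dirichlet label-skew recipe. A condensed sketch, without the minimum-partition-size rebalancing loop the real code adds:

    import numpy as np

    def dirichlet_partition(y, n_clients, alpha, seed=0):
        rng = np.random.default_rng(seed)
        idx_batch = [[] for _ in range(n_clients)]
        for k in np.unique(y):
            idx_k = rng.permutation(np.where(y == k)[0])
            # Per-class client shares drawn from Dir(alpha); small alpha -> heavy skew.
            p = rng.dirichlet(np.repeat(alpha, n_clients))
            cuts = (np.cumsum(p) * len(idx_k)).astype(int)[:-1]
            for j, part in enumerate(np.split(idx_k, cuts)):
                idx_batch[j].extend(part.tolist())
        return {j: sorted(idx) for j, idx in enumerate(idx_batch)}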
+ """ transform = transforms.Compose([transforms.ToTensor()]) cifar10_train_ds = CIFAR10_truncated( @@ -109,6 +151,20 @@ def load_cifar10_data(datadir): def partition_data(dataset, datadir, partition, n_nets, alpha, args): + """ + Partition the dataset based on the specified method and parameters. + + Args: + dataset (str): The name of the dataset (e.g., "mnist", "emnist", "cifar10"). + datadir (str): The directory where the dataset is stored. + partition (str): The partitioning method ("homo" or "hetero-dir"). + n_nets (int): The number of clients/networks. + alpha (float): A parameter for data partitioning. + args: Additional arguments. + + Returns: + dict: A dictionary mapping network indices to data indices. + """ if dataset == "mnist": X_train, y_train, X_test, y_test = load_mnist_data(datadir) n_train = X_train.shape[0] @@ -244,6 +300,19 @@ def partition_data(dataset, datadir, partition, n_nets, alpha, args): def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None): + """ + Get data loaders for the specified dataset. + + Args: + dataset (str): The name of the dataset (e.g., "mnist", "emnist", "cifar10"). + datadir (str): The directory where the dataset is stored. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (dict): A dictionary mapping network indices to data indices. + + Returns: + tuple: A tuple containing training and testing data loaders. + """ if dataset in ("mnist", "emnist", "cifar10"): if dataset == "mnist": dl_obj = MNIST_truncated @@ -320,6 +389,24 @@ def get_dataloader_normal_case( ardis_dataset=None, attack_case="normal-case", ): + """ + Get data loaders for the specified dataset with support for poison attacks. + + Args: + dataset (str): The name of the dataset (e.g., "mnist", "emnist", "cifar10"). + datadir (str): The directory where the dataset is stored. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (dict): A dictionary mapping network indices to data indices. + user_id (int): The user ID for poison attack. + num_total_users (int): The total number of users. + poison_type (str): The type of poison attack (e.g., "southwest"). + ardis_dataset: ARDIS dataset for poison attack (if applicable). + attack_case (str): The type of attack case (e.g., "normal-case"). + + Returns: + tuple: A tuple containing training and testing data loaders. + """ if dataset in ("mnist", "emnist", "cifar10"): if dataset == "mnist": dl_obj = MNIST_truncated @@ -391,6 +478,17 @@ def get_dataloader_normal_case( def load_poisoned_dataset(args): + """ + Load a poisoned dataset based on the provided arguments. + + Args: + args (Namespace): Command-line arguments containing dataset details. + + Returns: + DataLoader: DataLoader for the poisoned dataset. + DataLoader: DataLoader for the clean test dataset. + DataLoader: DataLoader for the targetted task test dataset. 
+ """ use_cuda = not args.using_gpu and torch.cuda.is_available() kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} if args.dataset in ("mnist", "emnist"): diff --git a/python/fedml/data/edge_case_examples/datasets.py b/python/fedml/data/edge_case_examples/datasets.py index e493a16750..2cf26a649f 100644 --- a/python/fedml/data/edge_case_examples/datasets.py +++ b/python/fedml/data/edge_case_examples/datasets.py @@ -37,6 +37,18 @@ class MNIST_truncated(data.Dataset): def __init__( self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False, ): + """ + Initialize the MNIST_truncated dataset. + + Args: + root (str): Root directory where the dataset is stored. + dataidxs (list, optional): List of data indices to include in the dataset. + train (bool, optional): Whether to load the training or testing data. + transform (callable, optional): A function/transform to apply to the data. + target_transform (callable, optional): A function/transform to apply to the target. + download (bool, optional): Whether to download the dataset if it's not found. + """ + self.root = root self.dataidxs = dataidxs @@ -48,6 +60,13 @@ def __init__( self.data, self.target = self.__build_truncated_dataset__() def __build_truncated_dataset__(self): + """ + Build the truncated dataset based on the provided data indices. + + Returns: + torch.Tensor: The truncated data. + torch.Tensor: The corresponding labels/targets. + """ mnist_dataobj = MNIST(self.root, self.train, self.transform, self.target_transform, self.download) @@ -94,6 +113,17 @@ class EMNIST_truncated(data.Dataset): def __init__( self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False, ): + """ + Initialize the EMNIST_truncated dataset. + + Args: + root (str): Root directory where the dataset is stored. + dataidxs (list, optional): List of data indices to include in the dataset. + train (bool, optional): Whether to load the training or testing data. + transform (callable, optional): A function/transform to apply to the data. + target_transform (callable, optional): A function/transform to apply to the target. + download (bool, optional): Whether to download the dataset if it's not found. + """ self.root = root self.dataidxs = dataidxs @@ -105,6 +135,13 @@ def __init__( self.data, self.target = self.__build_truncated_dataset__() def __build_truncated_dataset__(self): + """ + Build the truncated dataset based on the provided data indices. + + Returns: + torch.Tensor: The truncated data. + torch.Tensor: The corresponding labels/targets. + """ emnist_dataobj = EMNIST( self.root, split="digits", @@ -154,20 +191,28 @@ def __len__(self): def get_ardis_dataset(): - # load the data from csv's + """Load the ARDIS dataset and prepare it for training. + + This function loads the ARDIS dataset from CSV files, reshapes the images, + and prepares the dataset for training. + + Returns: + torch.utils.data.Dataset: The ARDIS dataset prepared for training. 
+ """ + # Load the data from CSV files ardis_images = np.loadtxt("./../../../data/edge_case_examples/ARDIS/ARDIS_train_2828.csv", dtype="float") ardis_labels = np.loadtxt("./../../../data/edge_case_examples/ARDIS/ARDIS_train_labels.csv", dtype="float") - #### reshape to be [samples][width][height] + # Reshape the images to [samples][width][height] ardis_images = ardis_images.reshape(ardis_images.shape[0], 28, 28).astype("float32") - # labels are one-hot encoded + # Labels are one-hot encoded; extract images and labels for digit 7 indices_seven = np.where(ardis_labels[:, 7] == 1)[0] images_seven = ardis_images[indices_seven, :] images_seven = torch.tensor(images_seven).type(torch.uint8) + labels_seven = torch.tensor([7 for _ in ardis_labels]) - labels_seven = torch.tensor([7 for y in ardis_labels]) - + # Create an EMNIST dataset for digit 7 ardis_dataset = EMNIST( "./../../../data", split="digits", @@ -176,13 +221,23 @@ def get_ardis_dataset(): transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), ) + # Set the data and targets to the extracted images and labels ardis_dataset.data = images_seven ardis_dataset.targets = labels_seven return ardis_dataset - def get_southwest_dataset(attack_case="normal-case"): + """Load the Southwest dataset for a specified attack case. + + This function loads the Southwest dataset for a given attack case. + + Args: + attack_case (str): The attack case to load. Options are "normal-case" and "almost-edge-case". + + Returns: + pickle.Unpickler: The loaded Southwest dataset for the specified attack case. + """ if attack_case == "normal-case": with open( "./../../../data/edge_case_examples/southwest_cifar10/southwest_images_honest_full_normal.pkl", "rb", @@ -200,8 +255,8 @@ def get_southwest_dataset(attack_case="normal-case"): class EMNIST_NormalCase_truncated(data.Dataset): """ - we use this class for normal case attack where normal - users also hold the poisoned data point with true label + Dataset class for normal case attack where normal + users also hold the poisoned data point with true label. """ def __init__( @@ -218,7 +273,22 @@ def __init__( ardis_dataset_train=None, attack_case="normal-case", ): + """ + Initializes the EMNIST_NormalCase_truncated dataset. + Args: + root (str): Root directory where the dataset is stored. + dataidxs (list, optional): List of indices to select specific data points. Default is None. + train (bool): True for training dataset, False for testing dataset. + transform (callable, optional): A function/transform to apply to the data. Default is None. + target_transform (callable, optional): A function/transform to apply to the target. Default is None. + download (bool): Whether to download the dataset if it's not found in the root directory. Default is False. + user_id (int): ID of the user accessing the dataset. + num_total_users (int): Total number of users in the scenario. + poison_type (str): Type of poisoning data. Default is "ardis". + ardis_dataset_train (torch.utils.data.Dataset): ARDIS dataset used for poisoning. Default is None. + attack_case (str): The type of attack case. Options are "normal-case" and "almost-edge-case". Default is "normal-case". + """ self.root = root self.dataidxs = dataidxs self.train = train @@ -229,9 +299,9 @@ def __init__( if attack_case == "normal-case": self._num_users_hold_edge_data = int( 3383 / 20 - ) # we allow 1/20 of the users (other than the attacker) to hold the edge data. 
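The normal-case datasets above give each designated honest user a disjoint shard of the ARDIS "7" images while keeping the true label. A sketch of that assignment, with hypothetical helper and parameter names:

    import numpy as np

    def assign_edge_shards(num_edge_points, holder_user_ids, seed=0):
        # Shuffle once, then split the edge-case indices evenly across holders.
        rng = np.random.default_rng(seed)
        idx = rng.permutation(num_edge_points)
        shards = np.array_split(idx, len(holder_user_ids))
        return {uid: shard.tolist() for uid, shard in zip(holder_user_ids, shards)}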
+ ) # We allow 1/20 of the users (other than the attacker) to hold the edge data. else: - # almost edge case + # Almost edge case self._num_users_hold_edge_data = 66 # ~2% of users hold data if poison_type == "ardis": @@ -249,17 +319,18 @@ def __init__( self.saved_ardis_dataset_train = self.ardis_dataset_train.data[user_partition] self.saved_ardis_label_train = self.ardis_dataset_train.targets[user_partition] else: - NotImplementedError("Unsupported poison type for normal case attack ...") + raise NotImplementedError("Unsupported poison type for normal case attack ...") - # logging.info("USER: {} got {} points".format(user_id, len(self.saved_ardis_dataset_train.data))) self.data, self.target = self.__build_truncated_dataset__() - # if self.dataidxs is not None: - # print("$$$$$$$$ Inside data loader: user ID: {}, Combined data: {}, Ori data shape: {}".format( - # user_id, self.data.shape, len(dataidxs))) - def __build_truncated_dataset__(self): + """ + Builds the truncated dataset by combining the EMNIST dataset with the ARDIS dataset. + Returns: + np.ndarray: Combined data. + np.ndarray: Combined target labels. + """ emnist_dataobj = EMNIST( self.root, split="digits", @@ -290,7 +361,7 @@ def __getitem__(self, index): index (int): Index Returns: - tuple: (image, target) where target is index of the target class. + tuple: (image, target) where target is the index of the target class. """ img, target = self.data[index], self.target[index] @@ -307,6 +378,21 @@ def __len__(self): class CIFAR10_truncated(data.Dataset): + """ + Dataset class for a truncated version of the CIFAR-10 dataset. + + This class allows you to create a truncated version of the CIFAR-10 dataset + by selecting specific data indices. + + Args: + root (str): Root directory where the dataset is stored. + dataidxs (list, optional): List of indices to select specific data points. Default is None. + train (bool): True for training dataset, False for testing dataset. + transform (callable, optional): A function/transform to apply to the data. Default is None. + target_transform (callable, optional): A function/transform to apply to the target. Default is None. + download (bool): Whether to download the dataset if it's not found in the root directory. Default is False. + """ + def __init__( self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False, ): @@ -321,12 +407,16 @@ def __init__( self.data, self.target = self.__build_truncated_dataset__() def __build_truncated_dataset__(self): + """ + Builds the truncated dataset by selecting specific data indices. + Returns: + np.ndarray: Combined data. + np.ndarray: Combined target labels. + """ cifar_dataobj = CIFAR10(self.root, self.train, self.transform, self.target_transform, self.download) if self.train: - # print("train member of the class: {}".format(self.train)) - # data = cifar_dataobj.train_data data = cifar_dataobj.data target = np.array(cifar_dataobj.targets) else: @@ -345,7 +435,7 @@ def __getitem__(self, index): index (int): Index Returns: - tuple: (image, target) where target is index of the target class. + tuple: (image, target) where target is the index of the target class. 
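The `_num_users_hold_edge_data` bookkeeping above determines how many benign users share the poisoned edge points; each user then receives a `user_partition` slice of the ARDIS tensors. A plausible sketch of that assignment, assuming an even split with `np.array_split` (the exact split logic is elided in the hunk above and may differ):

    import numpy as np

    num_users_hold_edge_data = int(3383 / 20)   # as in the hunk above
    num_edge_points = 660                       # stand-in for the ARDIS sevens
    partitions = np.array_split(np.arange(num_edge_points),
                                num_users_hold_edge_data)

    user_id = 5                                 # hypothetical benign user
    user_partition = partitions[user_id % num_users_hold_edge_data]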
""" img, target = self.data[index], self.target[index] @@ -363,8 +453,8 @@ def __len__(self): class CIFAR10NormalCase_truncated(data.Dataset): """ - we use this class for normal case attack where normal - users also hold the poisoned data point with true label + Dataset class for normal case attack where normal + users also hold the poisoned data point with true label. """ def __init__( @@ -381,6 +471,22 @@ def __init__( ardis_dataset_train=None, attack_case="normal-case", ): + """ + Initializes the CIFAR10NormalCase_truncated dataset. + + Args: + root (str): Root directory where the dataset is stored. + dataidxs (list, optional): List of indices to select specific data points. Default is None. + train (bool): True for training dataset, False for testing dataset. + transform (callable, optional): A function/transform to apply to the data. Default is None. + target_transform (callable, optional): A function/transform to apply to the target. Default is None. + download (bool): Whether to download the dataset if it's not found in the root directory. Default is False. + user_id (int): ID of the user accessing the dataset. + num_total_users (int): Total number of users in the scenario. + poison_type (str): Type of poisoning data. Default is "southwest". + ardis_dataset_train (np.ndarray): ARDIS dataset used for poisoning. Default is None. + attack_case (str): The type of attack case. Options are "normal-case" and "almost-edge-case". Default is "normal-case". + """ self.root = root self.dataidxs = dataidxs @@ -447,6 +553,13 @@ def __init__( # user_id, self.data.shape, len(dataidxs))) def __build_truncated_dataset__(self): + """ + Builds the truncated dataset by combining the CIFAR-10 dataset with the poisoned ARDIS dataset. + + Returns: + np.ndarray: Combined data. + np.ndarray: Combined target labels. + """ cifar_dataobj = CIFAR10(self.root, self.train, self.transform, self.target_transform, self.download) diff --git a/python/fedml/data/fed_cifar100/data_loader.py b/python/fedml/data/fed_cifar100/data_loader.py index e54111e575..c05d909ca1 100644 --- a/python/fedml/data/fed_cifar100/data_loader.py +++ b/python/fedml/data/fed_cifar100/data_loader.py @@ -23,7 +23,19 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): - + """ + Get data loaders for training and testing. + + Args: + dataset (str): Dataset name. + data_dir (str): Directory containing the data. + train_bs (int): Batch size for training data loader. + test_bs (int): Batch size for testing data loader. + client_idx (int, optional): Index of the client to load data for. + + Returns: + tuple: A tuple containing the training data loader and testing data loader. 
+ """ train_h5 = h5py.File(os.path.join(data_dir, DEFAULT_TRAIN_FILE), "r") test_h5 = h5py.File(os.path.join(data_dir, DEFAULT_TEST_FILE), "r") train_x = [] @@ -31,7 +43,7 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): test_x = [] test_y = [] - # load data in numpy format from h5 file + # Load data in numpy format from h5 file if client_idx is None: train_x = np.vstack( [ @@ -62,14 +74,14 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): [test_h5[_EXAMPLE][client_id_test][_LABEL][()]] ).squeeze() - # preprocess + # Preprocess train_x = utils.preprocess_cifar_img(torch.tensor(train_x), train=True) train_y = torch.tensor(train_y) if len(test_x) != 0: test_x = utils.preprocess_cifar_img(torch.tensor(test_x), train=False) test_y = torch.tensor(test_y) - # generate dataloader + # Generate data loader train_ds = data.TensorDataset(train_x, train_y) train_dl = data.DataLoader( dataset=train_ds, batch_size=train_bs, shuffle=True, drop_last=False @@ -91,11 +103,22 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): def load_partition_data_distributed_federated_cifar100( process_id, dataset, data_dir, batch_size=DEFAULT_BATCH_SIZE ): - + """ + Load distributed federated CIFAR-100 dataset for a specific client. + + Args: + process_id (int): Identifier of the client process. + dataset (str): Dataset name. + data_dir (str): Directory containing the data. + batch_size (int, optional): Batch size for data loader. + + Returns: + tuple: A tuple containing information about the dataset, including the number of classes. + """ class_num = 100 if process_id == 0: - # get global dataset + # Get global dataset train_data_global, test_data_global = get_dataloader( dataset, data_dir, batch_size, batch_size ) @@ -107,7 +130,7 @@ def load_partition_data_distributed_federated_cifar100( test_data_local = None local_data_num = 0 else: - # get local dataset + # Get local dataset train_data_local, test_data_local = get_dataloader( dataset, data_dir, batch_size, batch_size, process_id - 1 ) @@ -132,6 +155,17 @@ def load_partition_data_distributed_federated_cifar100( def load_partition_data_federated_cifar100( dataset, data_dir, batch_size=DEFAULT_BATCH_SIZE ): + """ + Load federated CIFAR-100 dataset for multiple clients. + + Args: + dataset (str): Dataset name. + data_dir (str): Directory containing the data. + batch_size (int, optional): Batch size for data loader. + + Returns: + tuple: A tuple containing information about the dataset, including the number of classes. + """ class_num = 100 diff --git a/python/fedml/data/fed_cifar100/utils.py b/python/fedml/data/fed_cifar100/utils.py index 323654200c..d2108ce4dc 100644 --- a/python/fedml/data/fed_cifar100/utils.py +++ b/python/fedml/data/fed_cifar100/utils.py @@ -9,7 +9,18 @@ # def cifar100_transform(img_mean, img_std, train=True, crop_size=(24, 24)): def cifar100_transform(img_mean, img_std, train=True, crop_size=32): - """cropping, flipping, and normalizing.""" + """ + Define data transformations for CIFAR-100 dataset. + + Args: + img_mean (tuple): Mean values for image normalization. + img_std (tuple): Standard deviation values for image normalization. + train (bool): Whether the transformations are for training or testing data. + crop_size (int): Size of the crop (default is 32). + + Returns: + torchvision.transforms.Compose: A composition of data transformations. 
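Assuming the federated CIFAR-100 h5 files referenced by DEFAULT_TRAIN_FILE/DEFAULT_TEST_FILE already sit under `data_dir`, per-client loaders come from passing a `client_idx`; `None` concatenates every client. Hypothetical usage sketch:

    # Hypothetical call; requires the fed_cifar100 h5 files on disk.
    train_dl, test_dl = get_dataloader(
        dataset="fed_cifar100",
        data_dir="./data/fed_cifar100",
        train_bs=20,
        test_bs=20,
        client_idx=0,
    )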
+ """ if train: return transforms.Compose( [ @@ -40,6 +51,16 @@ def cifar100_transform(img_mean, img_std, train=True, crop_size=32): def preprocess_cifar_img(img, train): + """ + Preprocess CIFAR-100 images for use in a PyTorch model. + + Args: + img (torch.Tensor): Input images. + train (bool): Whether the data is for training or testing. + + Returns: + torch.Tensor: Preprocessed images as a PyTorch tensor. + """ # scale img to range [0,1] to fit ToTensor api img = torch.div(img, 255.0) transoformed_img = torch.stack( diff --git a/python/fedml/data/fed_shakespeare/data_loader.py b/python/fedml/data/fed_shakespeare/data_loader.py index 5f21334dd6..aad409a6b3 100644 --- a/python/fedml/data/fed_shakespeare/data_loader.py +++ b/python/fedml/data/fed_shakespeare/data_loader.py @@ -21,19 +21,31 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): - + """ + Get data loaders for the specified dataset. + + Args: + dataset (str): The name of the dataset. + data_dir (str): The directory containing the dataset. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + client_idx (int): Index of the client (None for all clients). + + Returns: + tuple: A tuple of DataLoader objects for training and testing data. + """ train_h5 = h5py.File(os.path.join(data_dir, DEFAULT_TRAIN_FILE), "r") test_h5 = h5py.File(os.path.join(data_dir, DEFAULT_TEST_FILE), "r") train_ds = [] test_ds = [] - # load data + # Load data if client_idx is None: - # get ids of all clients + # Get IDs of all clients train_ids = client_ids_train test_ids = client_ids_test else: - # get ids of single client + # Get IDs of a single client train_ids = [client_ids_train[client_idx]] test_ids = [client_ids_test[client_idx]] @@ -46,7 +58,7 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): raw_test = [x.decode("utf8") for x in raw_test] test_ds.extend(utils.preprocess(raw_test)) - # split data + # Split data train_x, train_y = utils.split(train_ds) test_x, test_y = utils.split(test_ds) train_ds = data.TensorDataset(torch.tensor(train_x[:, :]), torch.tensor(train_y[:])) @@ -62,26 +74,36 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): test_h5.close() return train_dl, test_dl - def load_partition_data_distributed_federated_shakespeare( process_id, dataset, data_dir, batch_size=DEFAULT_BATCH_SIZE ): - + """ + Load partitioned data for distributed federated learning with Shakespearean text data. + + Args: + process_id (int): The process ID of the current worker (0 for the server). + dataset (str): The name of the dataset. + data_dir (str): The directory containing the dataset. + batch_size (int): Batch size for data loaders. + + Returns: + tuple: A tuple containing information about the data partitions and vocabulary size. 
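`preprocess_cifar_img` first rescales the raw uint8 pixel values into [0, 1] before the torchvision transforms run. A tiny self-contained check of that scaling step:

    import torch

    raw = torch.randint(0, 256, (4, 32, 32, 3), dtype=torch.uint8)
    scaled = torch.div(raw.float(), 255.0)   # the same scaling the function applies
    assert 0.0 <= scaled.min() and scaled.max() <= 1.0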
+ """ if process_id == 0: - # get global dataset + # Get global dataset train_data_global, test_data_global = get_dataloader( dataset, data_dir, batch_size, batch_size, process_id - 1 ) - train_data_num = len(train_data_global) - test_data_num = len(test_data_global) + train_data_num = len(train_data_global.dataset) + test_data_num = len(test_data_global.dataset) logging.info("train_dl_global number = " + str(train_data_num)) logging.info("test_dl_global number = " + str(test_data_num)) train_data_local = None test_data_local = None local_data_num = 0 else: - # get local dataset - # client id list + # Get local dataset + # Client ID list train_file_path = os.path.join(data_dir, DEFAULT_TRAIN_FILE) test_file_path = os.path.join(data_dir, DEFAULT_TEST_FILE) with h5py.File(train_file_path, "r") as train_h5, h5py.File( @@ -117,8 +139,18 @@ def load_partition_data_distributed_federated_shakespeare( def load_partition_data_federated_shakespeare( dataset, data_dir, batch_size=DEFAULT_BATCH_SIZE ): - - # client id list + """ + Load partitioned data for federated learning with Shakespearean text data. + + Args: + dataset (str): The name of the dataset. + data_dir (str): The directory containing the dataset. + batch_size (int): Batch size for data loaders (default is DEFAULT_BATCH_SIZE). + + Returns: + tuple: A tuple containing information about the data partitions and vocabulary size. + """ + # Client ID list train_file_path = os.path.join(data_dir, DEFAULT_TRAIN_FILE) test_file_path = os.path.join(data_dir, DEFAULT_TEST_FILE) with h5py.File(train_file_path, "r") as train_h5, h5py.File( @@ -128,7 +160,7 @@ def load_partition_data_federated_shakespeare( client_ids_train = list(train_h5[_EXAMPLE].keys()) client_ids_test = list(test_h5[_EXAMPLE].keys()) - # get local dataset + # Get local dataset data_local_num_dict = dict() train_data_local_dict = dict() test_data_local_dict = dict() @@ -149,7 +181,7 @@ def load_partition_data_federated_shakespeare( train_data_local_dict[client_idx] = train_data_local test_data_local_dict[client_idx] = test_data_local - # global dataset + # Global dataset train_data_global = data.DataLoader( data.ConcatDataset( list(dl.dataset for dl in list(train_data_local_dict.values())) @@ -185,3 +217,4 @@ def load_partition_data_federated_shakespeare( VOCAB_LEN, ) + diff --git a/python/fedml/data/fed_shakespeare/utils.py b/python/fedml/data/fed_shakespeare/utils.py index 8393710249..13db267b2e 100644 --- a/python/fedml/data/fed_shakespeare/utils.py +++ b/python/fedml/data/fed_shakespeare/utils.py @@ -21,8 +21,14 @@ def get_word_dict(): + """ + Get a dictionary mapping words to their corresponding IDs. + + Returns: + collections.OrderedDict: A dictionary with words as keys and their IDs as values. + """ global word_dict - if word_dict == None: + if word_dict is None: words = [_pad] + CHAR_VOCAB + [_bos] + [_eos] word_dict = collections.OrderedDict() for i, w in enumerate(words): @@ -31,18 +37,42 @@ def get_word_dict(): def get_word_list(): + """ + Get a list of words in the vocabulary. + + Returns: + list: A list of words in the vocabulary. + """ global word_list - if word_list == None: + if word_list is None: word_dict = get_word_dict() word_list = list(word_dict.keys()) return word_list def id_to_word(idx): + """ + Convert a word ID to the corresponding word. + + Args: + idx (int): The word ID. + + Returns: + str: The corresponding word. + """ return get_word_list()[idx] def char_to_id(char): + """ + Convert a character to its corresponding ID using the word_dict. 
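The Shakespeare vocabulary is a fixed, ordered char-to-id table: pad first, then the character vocabulary, then the begin/end sentinels, with a single out-of-vocabulary bucket after everything (matching `to_ids(num_oov_buckets=1)`). A compact sketch with a stand-in character set (token spellings here are placeholders):

    import collections

    _pad, _bos, _eos = "<pad>", "<bos>", "<eos>"   # placeholder token spellings
    CHAR_VOCAB = list("abc")                        # stand-in for the real charset

    words = [_pad] + CHAR_VOCAB + [_bos] + [_eos]
    word_dict = collections.OrderedDict((w, i) for i, w in enumerate(words))

    def char_to_id_sketch(ch):
        # unknown characters land in the single OOV bucket after the vocab
        return word_dict.get(ch, len(word_dict))

    assert char_to_id_sketch("a") == 1
    assert char_to_id_sketch("z") == len(word_dict)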
+ + Args: + char (str): The character to convert. + + Returns: + int: The corresponding ID for the character. + """ word_dict = get_word_dict() if char in word_dict: return word_dict[char] @@ -51,15 +81,29 @@ def char_to_id(char): def preprocess(sentences, max_seq_len=SEQUENCE_LENGTH): + """ + Preprocess a list of sentences by converting characters to IDs and padding. + Args: + sentences (list): A list of sentences, where each sentence is a string. + max_seq_len (int): Maximum sequence length (including start and end tokens). + + Returns: + list: A list of sequences, where each sequence is a list of token IDs. + """ sequences = [] def to_ids(sentence, num_oov_buckets=1): """ - map list of sentence to list of [idx..] and pad to max_seq_len + 1 + Map a sentence to a list of token IDs and pad it to the specified length. + Args: - num_oov_buckets : The number of out of vocabulary buckets. - max_seq_len: Integer determining shape of padded batches. + sentence (str): The input sentence. + num_oov_buckets (int): The number of out-of-vocabulary (OOV) buckets. + max_seq_len (int): Maximum sequence length (including start and end tokens). + + Returns: + list: A list of token IDs, padded to max_seq_len. """ tokens = [char_to_id(c) for c in sentence] tokens = [char_to_id(_bos)] + tokens + [char_to_id(_eos)] @@ -77,12 +121,23 @@ def to_ids(sentence, num_oov_buckets=1): def split(dataset): + """ + Split a dataset into input sequences (x) and target sequences (y). + + Args: + dataset (list): A list of sequences, where each sequence is a list of token IDs. + + Returns: + tuple: A tuple containing two arrays, x and y, where x represents input sequences + and y represents target sequences. + """ ds = np.asarray(dataset) x = ds[:, :-1] y = ds[:, 1:] return x, y + if __name__ == "__main__": print( split( diff --git a/python/fedml/data/fednlp/base/data_manager/base_data_manager.py b/python/fedml/data/fednlp/base/data_manager/base_data_manager.py index 4afc77e375..b0addb59f7 100644 --- a/python/fedml/data/fednlp/base/data_manager/base_data_manager.py +++ b/python/fedml/data/fednlp/base/data_manager/base_data_manager.py @@ -12,8 +12,29 @@ class BaseDataManager(ABC): + """Abstract base class for managing data in federated learning scenarios. + + This class defines the common interface and functionality for managing data in federated learning, + including loading, partitioning, and distributing datasets to clients. + + Attributes: + args: The command-line arguments passed to the manager. + model_args: The model-specific arguments. + train_batch_size: The batch size for training data. + eval_batch_size: The batch size for evaluation data. + process_id: The identifier of the current process. + num_workers: The total number of workers (including the server). + """ @abstractmethod def __init__(self, args, model_args, process_id, num_workers): + """Initialize the BaseDataManager. + + Args: + args: Command-line arguments. + model_args: Model-specific arguments. + process_id: Identifier of the current process. + num_workers: Total number of workers (including the server). + """ self.model_args = model_args self.args = args self.train_batch_size = model_args.train_batch_size @@ -44,6 +65,14 @@ def __init__(self, args, model_args, process_id, num_workers): @staticmethod def load_attributes(data_path): + """Load data attributes from an HDF5 data file. + + Args: + data_path: Path to the HDF5 data file. + + Returns: + Dictionary containing data attributes. 
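`split` turns each padded id sequence into a next-token prediction pair by shifting one position, exactly as the function body above does. Worked example of that slicing:

    import numpy as np

    dataset = [[1, 2, 3, 4], [5, 6, 7, 8]]   # already-padded id sequences
    ds = np.asarray(dataset)
    x, y = ds[:, :-1], ds[:, 1:]
    # x == [[1, 2, 3], [5, 6, 7]] and y == [[2, 3, 4], [6, 7, 8]]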
+ """ data_file = h5py.File(data_path, "r", swmr=True) attributes = json.loads(data_file["attributes"][()]) data_file.close() @@ -51,6 +80,15 @@ def load_attributes(data_path): @staticmethod def load_num_clients(partition_file_path, partition_name): + """Load the number of clients from a partition file. + + Args: + partition_file_path: Path to the partition file. + partition_name: Name of the partition. + + Returns: + The number of clients. + """ data_file = h5py.File(partition_file_path, "r", swmr=True) num_clients = int(data_file[partition_name]["n_clients"][()]) data_file.close() @@ -58,11 +96,27 @@ def load_num_clients(partition_file_path, partition_name): @abstractmethod def read_instance_from_h5(self, data_file, index_list, desc): + """Read instances from an HDF5 data file. + + Args: + data_file: HDF5 data file object. + index_list: List of indices to read. + desc: Description of the read operation. + + Returns: + Data instances. + """ pass def sample_client_index(self, process_id, num_workers): - """ - Sample client indices according to process_id + """Sample client indices according to the process_id. + + Args: + process_id (int): The identifier of the current process. + num_workers (int): The total number of workers. + + Returns: + list or None: A list of client indices if process_id is not 0, else None. """ # process_id = 0 means this process is the server process if process_id == 0: @@ -71,6 +125,14 @@ def sample_client_index(self, process_id, num_workers): return self._simulated_sampling(process_id) def _simulated_sampling(self, process_id): + """Simulated client sampling for federated learning. + + Args: + process_id (int): The identifier of the current process. + + Returns: + list: A list of sampled client indices. + """ res_client_indexes = list() for round_idx in range(self.args.comm_round): if self.num_clients == self.num_workers: @@ -92,6 +154,14 @@ def get_all_clients(self): return list(range(0, self.num_clients)) def load_centralized_data(self, cut_off=None): + """Load centralized training and testing data. + + Args: + cut_off (int, optional): The maximum number of data points to load. + + Returns: + tuple: A tuple containing centralized training and testing data loaders. + """ state, res = self._load_data_loader_from_cache(-1) if state: ( @@ -169,6 +239,14 @@ def load_centralized_data(self, cut_off=None): return train_dl, test_dl def load_federated_data(self, test_cut_off=None): + """Load federated training and testing data. + + Args: + test_cut_off (int, optional): The maximum number of testing data points to load. + + Returns: + tuple: A tuple containing federated training and testing data and related information. + """ ( train_data_num, test_data_num, @@ -193,6 +271,16 @@ def load_federated_data(self, test_cut_off=None): ) def _load_federated_data_server(self, test_only=False, test_cut_off=None): + """Load federated training and testing data from the server. + + Args: + test_only (bool, optional): Whether to load only testing data. Defaults to False. + test_cut_off (int, optional): The maximum number of testing data points to load. + + Returns: + tuple: A tuple containing the number of training data points, the number of testing data points, + federated training data loader, and federated testing data loader. 
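`_simulated_sampling` draws one client subset per communication round whenever there are more clients than workers. A minimal sketch of that loop (stand-in sizes; the seeding here only keeps the sketch deterministic and is not claimed to match the manager's seeding):

    import random

    num_clients, num_workers, comm_round = 100, 10, 3
    sampled_per_round = []
    for round_idx in range(comm_round):
        random.seed(round_idx)
        sampled_per_round.append(random.sample(range(num_clients), num_workers))
    # one list of num_workers distinct client indices per round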
+        """
        # state, res = self._load_data_loader_from_cache(-1)
        state = False
        train_data_local_dict = None
@@ -288,6 +376,12 @@ def _load_federated_data_server(self, test_only=False, test_cut_off=None):
        return (train_data_num, test_data_num, train_data_global, test_data_global)

    def _load_federated_data_local(self):
+        """Load federated training and testing data for local clients.
+
+        Returns:
+            tuple: A tuple containing dictionaries with local client data loaders, the number of clients,
+            and the number of training data points and testing data points.
+        """

        data_file = h5py.File(self.args.data_file_path, "r", swmr=True)
        partition_file = h5py.File(self.args.partition_file_path, "r", swmr=True)
@@ -397,8 +491,17 @@ def _load_federated_data_local(self):
        )

    def _load_data_loader_from_cache(self, client_id):
-        """
-        Different clients has different cache file. client_id = -1 means loading the cached file on server end.
+        """Load the cached data loader from the cache file for a specific client.
+
+        Different clients have different cache files. client_id = -1 means
+        loading the cached file on the server end.
+
+        Args:
+            client_id (int): The ID of the client for which to load the cached data loader.
+
+        Returns:
+            tuple: A tuple containing a boolean indicating whether the data loader was loaded from cache,
+            and the cached data loader if available.
        """
        args = self.args
        model_args = self.model_args
diff --git a/python/fedml/data/fednlp/base/preprocess/base_preprocessor.py b/python/fedml/data/fednlp/base/preprocess/base_preprocessor.py
index 352b4d38fc..db1b6bd8b3 100644
--- a/python/fedml/data/fednlp/base/preprocess/base_preprocessor.py
+++ b/python/fedml/data/fednlp/base/preprocess/base_preprocessor.py
@@ -1,11 +1,39 @@
 from abc import ABC, abstractmethod
 
 
 class BasePreprocessor(ABC):
+    """Abstract base class for data preprocessors.
+
+    This class defines the common interface for data preprocessors, which are responsible for transforming
+    and preparing data for further processing or analysis.
+
+    Attributes:
+        **kwargs: Additional keyword arguments specific to the preprocessor implementation.
+
+    Methods:
+        transform(*args): Abstract method to transform data.
+
+    """
+
    @abstractmethod
    def __init__(self, **kwargs):
+        """Initialize the BasePreprocessor with optional keyword arguments.
+
+        Args:
+            **kwargs: Additional keyword arguments specific to the preprocessor implementation.
+        """
        self.__dict__.update(kwargs)

    @abstractmethod
    def transform(self, *args):
+        """Transform data using the preprocessor.
+
+        This method should be implemented by subclasses to apply data transformation operations.
+
+        Args:
+            *args: Variable-length arguments representing the input data to be transformed.
+
+        Returns:
+            Transformed data or processed result.
+        """
        pass
diff --git a/python/fedml/data/fednlp/base/raw_data/base_raw_data_loader.py b/python/fedml/data/fednlp/base/raw_data/base_raw_data_loader.py
index 5f50a80a95..9665cb276c 100644
--- a/python/fedml/data/fednlp/base/raw_data/base_raw_data_loader.py
+++ b/python/fedml/data/fednlp/base/raw_data/base_raw_data_loader.py
@@ -7,27 +7,96 @@
 
 
 class BaseRawDataLoader(ABC):
+    """Abstract base class for raw data loaders.
+
+    This class defines the common interface for raw data loaders, which are responsible for loading
+    and processing raw data from various sources.
+
+    Attributes:
+        data_path (str): The path to the raw data.
+        attributes (dict): A dictionary to store attributes related to the loaded data.
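`_load_data_loader_from_cache` resolves one cache file per client (client_id = -1 for the server) and reports a hit/miss flag plus the payload. A generic sketch of that lookup pattern, assuming pickle-based cache files; the directory layout and file names here are illustrative, not FedML's actual ones:

    import os
    import pickle

    def load_from_cache(cache_dir, client_id):
        # client_id == -1 denotes the server-side cache file (hypothetical naming)
        name = "server" if client_id == -1 else "client_%d" % client_id
        path = os.path.join(cache_dir, name + "_cached.pkl")
        if os.path.exists(path):
            with open(path, "rb") as f:
                return True, pickle.load(f)
        return False, None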
+ + Methods: + load_data(): Abstract method to load the raw data. + process_data_file(file_path): Abstract method to process a data file. + generate_h5_file(file_path): Abstract method to generate an HDF5 file from the loaded data. + + """ + @abstractmethod def __init__(self, data_path): + """Initialize the BaseRawDataLoader. + + Args: + data_path (str): The path to the raw data. + """ self.data_path = data_path self.attributes = dict() self.attributes["index_list"] = None @abstractmethod def load_data(self): + """Load the raw data. + + This method should be implemented by subclasses to load raw data from the specified data_path. + + Returns: + None + """ pass @abstractmethod def process_data_file(self, file_path): + """Process a data file. + + This method should be implemented by subclasses to process a specific data file. + + Args: + file_path (str): The path to the data file to be processed. + + Returns: + None + """ pass @abstractmethod def generate_h5_file(self, file_path): + """Generate an HDF5 file from the loaded data. + + This method should be implemented by subclasses to generate an HDF5 file from the loaded data. + + Args: + file_path (str): The path to the HDF5 file to be generated. + + Returns: + None + """ pass class TextClassificationRawDataLoader(BaseRawDataLoader): + """Raw data loader for text classification tasks. + + This class extends the BaseRawDataLoader and provides specific functionality for loading and processing + raw data for text classification tasks. + + Attributes: + X (dict): A dictionary to store input data. + Y (dict): A dictionary to store target labels. + attributes (dict): Additional attributes related to the loaded data, including 'num_labels', + 'label_vocab', and 'task_type' which is set to "text_classification". + + Methods: + generate_h5_file(file_path): Generate an HDF5 file from the loaded data. + + """ + def __init__(self, data_path): + """Initialize the TextClassificationRawDataLoader. + + Args: + data_path (str): The path to the raw data. + """ super(TextClassificationRawDataLoader, self).__init__(data_path) self.X = dict() self.Y = dict() @@ -36,6 +105,14 @@ def __init__(self, data_path): self.attributes["task_type"] = "text_classification" def generate_h5_file(self, file_path): + """Generate an HDF5 file from the loaded data. + + Args: + file_path (str): The path to the HDF5 file to be generated. + + Returns: + None + """ f = h5py.File(file_path, "w") f["attributes"] = json.dumps(self.attributes) for key in self.X.keys(): @@ -45,7 +122,30 @@ def generate_h5_file(self, file_path): class SpanExtractionRawDataLoader(BaseRawDataLoader): + """Raw data loader for span extraction tasks. + + This class extends the BaseRawDataLoader and provides specific functionality for loading and processing + raw data for span extraction tasks. + + Attributes: + context_X (dict): A dictionary to store context input data. + question_X (dict): A dictionary to store question input data. + Y (dict): A dictionary to store target spans. + Y_answer (dict): A dictionary to store target answers. + attributes (dict): Additional attributes related to the loaded data, including 'task_type' which is + set to "span_extraction". + + Methods: + generate_h5_file(file_path): Generate an HDF5 file from the loaded data. + + """ + def __init__(self, data_path): + """Initialize the SpanExtractionRawDataLoader. + + Args: + data_path (str): The path to the raw data. 
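Each `generate_h5_file` implementation above writes the same layout: a JSON-encoded "attributes" entry plus one dataset per example key. A tiny standalone sketch of that layout (hypothetical file name and toy data; the real loaders write their own X/Y dictionaries):

    import json

    import h5py

    X = {"0": "some text", "1": "more text"}
    Y = {"0": "label_a", "1": "label_b"}

    with h5py.File("demo.h5", "w") as f:
        f["attributes"] = json.dumps({"task_type": "text_classification"})
        for key in X:
            f["X/" + key] = X[key]
            f["Y/" + key] = Y[key]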
+ """ super(SpanExtractionRawDataLoader, self).__init__(data_path) self.context_X = dict() self.question_X = dict() @@ -54,6 +154,14 @@ def __init__(self, data_path): self.Y_answer = dict() def generate_h5_file(self, file_path): + """Generate an HDF5 file from the loaded data. + + Args: + file_path (str): The path to the HDF5 file to be generated. + + Returns: + None + """ f = h5py.File(file_path, "w") f["attributes"] = json.dumps(self.attributes) for key in self.context_X.keys(): @@ -65,7 +173,28 @@ def generate_h5_file(self, file_path): class SeqTaggingRawDataLoader(BaseRawDataLoader): + """Raw data loader for sequence tagging tasks. + + This class extends the BaseRawDataLoader and provides specific functionality for loading and processing + raw data for sequence tagging tasks. + + Attributes: + X (dict): A dictionary to store input sequences. + Y (dict): A dictionary to store target labels. + attributes (dict): Additional attributes related to the loaded data, including 'num_labels', + 'label_vocab', and 'task_type' which is set to "seq_tagging". + + Methods: + generate_h5_file(file_path): Generate an HDF5 file from the loaded data. + + """ + def __init__(self, data_path): + """Initialize the SeqTaggingRawDataLoader. + + Args: + data_path (str): The path to the raw data. + """ super(SeqTaggingRawDataLoader, self).__init__(data_path) self.X = dict() self.Y = dict() @@ -74,6 +203,14 @@ def __init__(self, data_path): self.attributes["task_type"] = "seq_tagging" def generate_h5_file(self, file_path): + """Generate an HDF5 file from the loaded data. + + Args: + file_path (str): The path to the HDF5 file to be generated. + + Returns: + None + """ f = h5py.File(file_path, "w") f["attributes"] = json.dumps(self.attributes) utf8_type = h5py.string_dtype("utf-8", None) @@ -84,13 +221,41 @@ def generate_h5_file(self, file_path): class Seq2SeqRawDataLoader(BaseRawDataLoader): + """Raw data loader for sequence-to-sequence (seq2seq) tasks. + + This class extends the BaseRawDataLoader and provides specific functionality for loading and processing + raw data for sequence-to-sequence tasks. + + Attributes: + X (dict): A dictionary to store source sequences. + Y (dict): A dictionary to store target sequences. + task_type (str): The type of the task, which is set to "seq2seq". + + Methods: + generate_h5_file(file_path): Generate an HDF5 file from the loaded data. + + """ + def __init__(self, data_path): + """Initialize the Seq2SeqRawDataLoader. + + Args: + data_path (str): The path to the raw data. + """ super(Seq2SeqRawDataLoader, self).__init__(data_path) self.X = dict() self.Y = dict() self.task_type = "seq2seq" def generate_h5_file(self, file_path): + """Generate an HDF5 file from the loaded data. + + Args: + file_path (str): The path to the HDF5 file to be generated. + + Returns: + None + """ f = h5py.File(file_path, "w") f["attributes"] = json.dumps(self.attributes) for key in self.X.keys(): @@ -100,12 +265,39 @@ def generate_h5_file(self, file_path): class LanguageModelRawDataLoader(BaseRawDataLoader): + """Raw data loader for language modeling tasks. + + This class extends the BaseRawDataLoader and provides specific functionality for loading and processing + raw data for language modeling tasks. + + Attributes: + X (dict): A dictionary to store language model input data. + task_type (str): The type of the task, which is set to "lm". + + Methods: + generate_h5_file(file_path): Generate an HDF5 file from the loaded data. 
+ + """ + def __init__(self, data_path): + """Initialize the LanguageModelRawDataLoader. + + Args: + data_path (str): The path to the raw data. + """ super(LanguageModelRawDataLoader, self).__init__(data_path) self.X = dict() self.task_type = "lm" def generate_h5_file(self, file_path): + """Generate an HDF5 file from the loaded data. + + Args: + file_path (str): The path to the HDF5 file to be generated. + + Returns: + None + """ f = h5py.File(file_path, "w") f["attributes"] = json.dumps(self.attributes) for key in tqdm(self.X.keys(), desc="generate data h5 file"): diff --git a/python/fedml/data/fednlp/base/raw_data/partition.py b/python/fedml/data/fednlp/base/raw_data/partition.py index e56e77e8bd..83d7f83634 100644 --- a/python/fedml/data/fednlp/base/raw_data/partition.py +++ b/python/fedml/data/fednlp/base/raw_data/partition.py @@ -5,6 +5,27 @@ def uniform_partition(train_index_list, test_index_list=None, n_clients=N_CLIENTS): + """Uniformly partition data indices into multiple clients. + + This function partitions a list of training data indices into 'n_clients' subsets, + ensuring a roughly equal distribution of data among clients. Optionally, it can also + partition a list of test data indices in a similar manner. + + Args: + train_index_list (list): List of training data indices. + test_index_list (list, optional): List of test data indices. Default is None. + n_clients (int): Number of clients to partition the data for. + + Returns: + dict: A dictionary containing the data partition information. + - 'n_clients': Number of clients. + - 'partition_data': A dictionary where each key represents a client ID (0 to n_clients-1), + and the value is another dictionary containing the partitioned data for that client. + For each client: + - 'train': List of training data indices. + - 'test': List of test data indices (if 'test_index_list' is provided). + + """ partition_dict = dict() partition_dict["n_clients"] = n_clients partition_dict["partition_data"] = dict() diff --git a/python/fedml/data/fednlp/base/utils.py b/python/fedml/data/fednlp/base/utils.py index 6048cfef05..2bd53a6c0e 100644 --- a/python/fedml/data/fednlp/base/utils.py +++ b/python/fedml/data/fednlp/base/utils.py @@ -18,6 +18,18 @@ class SpacyTokenizer: + """Tokenizer class for different languages using spaCy models. + + Attributes: + __zh_tokenizer: Chinese tokenizer instance. + __en_tokenizer: English tokenizer instance. + __cs_tokenizer: Czech tokenizer instance. + __de_tokenizer: German tokenizer instance. + __ru_tokenizer: Russian tokenizer instance. + + Methods: + get_tokenizer(lang): Get a spaCy tokenizer for the specified language. + """ def __init__(self): self.__zh_tokenizer = None self.__en_tokenizer = None @@ -27,6 +39,17 @@ def __init__(self): @staticmethod def get_tokenizer(lang): + """Get a spaCy tokenizer for the specified language. + + Args: + lang (str): The language code (e.g., "zh" for Chinese, "en" for English). + + Returns: + spacy.language.Language: A spaCy tokenizer instance. + + Raises: + Exception: If an unacceptable language code is provided. 
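`uniform_partition` produces a dict keyed by client id with per-client train (and optionally test) index lists. A compact sketch of that output shape, assuming an even `np.array_split` over the indices (the real function may shuffle or round differently):

    import numpy as np

    def uniform_partition_sketch(train_index_list, n_clients):
        partition = {"n_clients": n_clients, "partition_data": {}}
        chunks = np.array_split(train_index_list, n_clients)
        for cid, chunk in enumerate(chunks):
            partition["partition_data"][cid] = {"train": chunk.tolist()}
        return partition

    print(uniform_partition_sketch(list(range(10)), 3))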
+ """ if lang == "zh": # nlp = spacy.load("zh_core_web_sm") nlp = Chinese() @@ -46,37 +69,49 @@ def get_tokenizer(lang): @property def zh_tokenizer(self): + """Chinese tokenizer property.""" if self.__zh_tokenizer is None: self.__zh_tokenizer = self.get_tokenizer("zh") return self.__zh_tokenizer @property def en_tokenizer(self): + """English tokenizer property.""" if self.__en_tokenizer is None: self.__en_tokenizer = self.get_tokenizer("en") return self.__en_tokenizer @property def cs_tokenizer(self): + """Czech tokenizer property.""" if self.__cs_tokenizer is None: self.__cs_tokenizer = self.get_tokenizer("cs") return self.__cs_tokenizer @property def de_tokenizer(self): + """German tokenizer property.""" if self.__de_tokenizer is None: self.__de_tokenizer = self.get_tokenizer("de") return self.__de_tokenizer @property def ru_tokenizer(self): + """Russian tokenizer property.""" if self.__ru_tokenizer is None: self.__ru_tokenizer = self.get_tokenizer("ru") return self.__ru_tokenizer def build_vocab(x): - # x -> [num_seqs, num_tokens] + """Build a vocabulary from a list of tokenized sequences. + + Args: + x (list): List of tokenized sequences, where each sequence is a list of tokens. + + Returns: + dict: A vocabulary where tokens are keys and their corresponding indices are values. + """ vocab = dict() for single_x in x: for token in single_x: @@ -88,6 +123,14 @@ def build_vocab(x): def build_freq_vocab(x): + """Build a frequency-based vocabulary from a list of tokenized sequences. + + Args: + x (list): List of tokenized sequences, where each sequence is a list of tokens. + + Returns: + dict: A vocabulary where tokens are keys and their frequencies are values. + """ freq_vocab = dict() for single_x in x: for token in single_x: @@ -99,6 +142,16 @@ def build_freq_vocab(x): def padding_data(x, max_sequence_length): + """Pad sequences in a list to a specified maximum sequence length. + + Args: + x (list): List of sequences, where each sequence is a list of tokens. + max_sequence_length (int): The desired maximum sequence length for padding. + + Returns: + list: Padded sequences with a length of max_sequence_length. + list: Sequence lengths before padding. + """ padding_x = [] seq_lens = [] for single_x in x: @@ -115,6 +168,17 @@ def padding_data(x, max_sequence_length): def padding_char_data(x, max_sequence_length, max_word_length): + """Pad character-level sequences in a list to specified maximum lengths. + + Args: + x (list): List of sequences, where each sequence is a list of character tokens. + max_sequence_length (int): The desired maximum sequence length for padding. + max_word_length (int): The desired maximum word length for character tokens. + + Returns: + list: Padded character sequences with specified word and sequence lengths. + list: Word lengths before padding. + """ padding_x = [] word_lens = [] for sent in x: @@ -142,6 +206,15 @@ def padding_char_data(x, max_sequence_length, max_word_length): def token_to_idx(x, vocab): + """Convert tokenized sequences to indices using a vocabulary. + + Args: + x (list): List of tokenized sequences, where each sequence is a list of tokens. + vocab (dict): A vocabulary where tokens are keys and their corresponding indices are values. + + Returns: + list: Sequences with tokens replaced by their corresponding indices. + """ idx_x = [] for single_x in x: new_single_x = [] @@ -247,6 +320,15 @@ def NER_data_formatter(ner_data): def generate_h5_from_dict(file_name, data_dict): + """Generate an HDF5 file from a nested dictionary. 
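`padding_data` truncates or pads every sequence to `max_sequence_length` and also reports the pre-padding length. Worked sketch of that contract for a single sequence, assuming id 0 is the pad token as in the vocab builders above:

    def pad_one(seq, max_len, pad_id=0):
        truncated = seq[:max_len]
        padded = truncated + [pad_id] * (max_len - len(truncated))
        return padded, len(truncated)

    padded, seq_len = pad_one([5, 6, 7], 5)
    # padded == [5, 6, 7, 0, 0], seq_len == 3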
+ + Args: + file_name (str): The name of the HDF5 file to be created. + data_dict (dict): The nested dictionary containing data to be stored in the HDF5 file. + + Returns: + None + """ def dict_to_h5_recursive(h5_file, path, dic): for key, value in dic.items(): if isinstance(value, dict): @@ -270,6 +352,14 @@ def dict_to_h5_recursive(h5_file, path, dic): def decode_data_from_h5(data): + """Decode data from bytes to UTF-8 string if necessary. + + Args: + data (bytes or any): The input data, which may be in bytes. + + Returns: + str or any: The decoded data as a UTF-8 string, or the input data if it's not in bytes. + """ if isinstance(data, bytes): return data.decode("utf8") return data From 1fcfc561b45bb9a78b61a3012cbb9978c37da273 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Mon, 18 Sep 2023 12:58:19 +0530 Subject: [PATCH 27/70] fg --- python/fedml/data/cifar10/data_loader.py | 159 +++++++++++++- python/fedml/data/cifar10/datasets.py | 56 ++++- python/fedml/data/cifar10/efficient_loader.py | 200 +++++++++++++++++- python/fedml/data/cifar10/without_reload.py | 12 ++ python/fedml/data/cifar100/data_loader.py | 126 ++++++++++- python/fedml/data/cifar100/datasets.py | 44 +++- 6 files changed, 564 insertions(+), 33 deletions(-) diff --git a/python/fedml/data/cifar10/data_loader.py b/python/fedml/data/cifar10/data_loader.py index 459c1bbc53..cb85eb4a27 100644 --- a/python/fedml/data/cifar10/data_loader.py +++ b/python/fedml/data/cifar10/data_loader.py @@ -11,6 +11,15 @@ def read_data_distribution( filename="./data_preprocessing/non-iid-distribution/CIFAR10/distribution.txt", ): + """ + Read data distribution from a file. + + Args: + filename (str, optional): Path to the distribution file (default: "./data_preprocessing/non-iid-distribution/CIFAR10/distribution.txt"). + + Returns: + dict: A dictionary representing the data distribution. + """ distribution = {} with open(filename, "r") as data: for x in data.readlines(): @@ -26,10 +35,18 @@ def read_data_distribution( ) return distribution - def read_net_dataidx_map( filename="./data_preprocessing/non-iid-distribution/CIFAR10/net_dataidx_map.txt", ): + """ + Read network data index map from a file. + + Args: + filename (str, optional): Path to the network data index map file (default: "./data_preprocessing/non-iid-distribution/CIFAR10/net_dataidx_map.txt"). + + Returns: + dict: A dictionary representing the network data index map. + """ net_dataidx_map = {} with open(filename, "r") as data: for x in data.readlines(): @@ -43,8 +60,17 @@ def read_net_dataidx_map( net_dataidx_map[key] = [int(i.strip()) for i in tmp_array] return net_dataidx_map - def record_net_data_stats(y_train, net_dataidx_map): + """ + Record network data statistics. + + Args: + y_train (numpy.ndarray): Labels for the training data. + net_dataidx_map (dict): Network data index map. + + Returns: + dict: A dictionary containing network data statistics. + """ net_cls_counts = {} for net_i, dataidx in net_dataidx_map.items(): @@ -54,12 +80,27 @@ def record_net_data_stats(y_train, net_dataidx_map): logging.debug("Data statistics: %s" % str(net_cls_counts)) return net_cls_counts - class Cutout(object): + """ + Apply cutout augmentation to an image. + + Args: + length (int): Length of the cutout square. + """ + def __init__(self, length): self.length = length def __call__(self, img): + """ + Apply cutout to the image. + + Args: + img (PIL.Image.Image): Input image. + + Returns: + PIL.Image.Image: Image with cutout applied. 
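One detail worth noting about `Cutout`: `__call__` indexes `img.size(1)`/`img.size(2)` and applies a torch mask, so in practice it receives a CHW tensor (it is appended after `ToTensor`/`Normalize` in the transform pipeline) rather than a PIL image. Quick usage sketch against the class defined above:

    import torch

    cutout = Cutout(length=16)     # the Cutout class from this module
    img = torch.rand(3, 32, 32)    # CHW tensor, as produced by ToTensor
    out = cutout(img)
    assert out.shape == img.shape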
+        """
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        y = np.random.randint(h)
@@ -78,6 +119,12 @@ def __call__(self, img):


def _data_transforms_cifar10():
+    """
+    Define data transformations for the CIFAR-10 dataset.
+
+    Returns:
+        tuple: A tuple of two transformations, one for training and one for validation.
+    """
    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

@@ -104,6 +151,15 @@ def _data_transforms_cifar10():


def load_cifar10_data(datadir):
+    """
+    Load the CIFAR-10 dataset.
+
+    Args:
+        datadir (str): Directory where the CIFAR-10 dataset is located.
+
+    Returns:
+        tuple: A tuple containing training and testing data and labels.
+    """
    train_transform, test_transform = _data_transforms_cifar10()

    cifar10_train_ds = CIFAR10_truncated(
@@ -120,11 +176,24 @@ def load_cifar10_data(datadir):


def partition_data(dataset, datadir, partition, n_nets, alpha):
+    """
+    Partition the CIFAR-10 dataset for federated learning.
+
+    Args:
+        dataset: Unused; kept for API compatibility.
+        datadir (str): Directory where the CIFAR-10 dataset is located.
+        partition (str): Partitioning method: "homo", "hetero", or "hetero-fix".
+        n_nets (int): Number of clients (networks).
+        alpha (float): Dirichlet distribution parameter for data partitioning.
+
+    Returns:
+        tuple: A tuple containing data, labels, data index map, and data statistics.
+    """
    np.random.seed(10)
    logging.info("*********partition data***************")
    X_train, y_train, X_test, y_test = load_cifar10_data(datadir)
    n_train = X_train.shape[0]
-    # n_test = X_test.shape[0]
+

    if partition == "homo":
        total_num = n_train
@@ -184,6 +253,19 @@ def partition_data(dataset, datadir, partition, n_nets, alpha):

# for centralized training
def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None):
+    """
+    Get data loaders for centralized training.
+
+    Args:
+        dataset: Unused; kept for API compatibility.
+        datadir (str): Directory where the CIFAR-10 dataset is located.
+        train_bs (int): Batch size for training data.
+        test_bs (int): Batch size for testing data.
+        dataidxs (list, optional): List of data indices to include (default: None).
+
+    Returns:
+        DataLoader: Training and testing data loaders.
+    """
    return get_dataloader_CIFAR10(datadir, train_bs, test_bs, dataidxs)


@@ -191,12 +273,38 @@ def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None):
def get_dataloader_test(
    dataset, datadir, train_bs, test_bs, dataidxs_train, dataidxs_test
):
+    """
+    Get train and test data loaders for explicit train/test index subsets of CIFAR-10.
+
+    Args:
+        dataset: Unused; kept for API compatibility.
+        datadir (str): Directory where the CIFAR-10 dataset is located.
+        train_bs (int): Batch size for training data.
+        test_bs (int): Batch size for testing data.
+        dataidxs_train (list): List of data indices to include in the training set.
+        dataidxs_test (list): List of data indices to include in the testing set.
+
+    Returns:
+        DataLoader: Training and testing data loaders for CIFAR-10.
+    """
    return get_dataloader_test_CIFAR10(
        datadir, train_bs, test_bs, dataidxs_train, dataidxs_test
    )


def get_dataloader_CIFAR10(datadir, train_bs, test_bs, dataidxs=None):
+    """
+    Get data loaders for the CIFAR-10 dataset.
+
+    Args:
+        datadir (str): Directory where the CIFAR-10 dataset is located.
+        train_bs (int): Batch size for training data.
+        test_bs (int): Batch size for testing data.
+        dataidxs (list, optional): List of data indices to include (default: None).
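The "hetero" branch of `partition_data` assigns each class's samples to clients with Dirichlet-distributed proportions. A compact sketch of the core idea for a single class (stand-in sizes; the real loop also rebalances across all classes and clients):

    import numpy as np

    np.random.seed(10)
    n_nets, alpha = 4, 0.5
    idx_k = np.random.permutation(200)   # indices of one class
    proportions = np.random.dirichlet(np.repeat(alpha, n_nets))
    cut_points = (np.cumsum(proportions)[:-1] * len(idx_k)).astype(int)
    splits = np.split(idx_k, cut_points)
    # splits[j] holds this class's samples assigned to client j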
+
+    Returns:
+        DataLoader: Training and testing data loaders for CIFAR-10.
+    """
    dl_obj = CIFAR10_truncated

    transform_train, transform_test = _data_transforms_cifar10()
@@ -219,6 +327,19 @@
def get_dataloader_test_CIFAR10(
    datadir, train_bs, test_bs, dataidxs_train=None, dataidxs_test=None
):
+    """
+    Get train and test data loaders for explicit index subsets of CIFAR-10.
+
+    Args:
+        datadir (str): Directory where the CIFAR-10 dataset is located.
+        train_bs (int): Batch size for training data.
+        test_bs (int): Batch size for testing data.
+        dataidxs_train (list, optional): List of data indices to include in the training set.
+        dataidxs_test (list, optional): List of data indices to include in the testing set.
+
+    Returns:
+        DataLoader: Training and testing data loaders for CIFAR-10.
+    """
    dl_obj = CIFAR10_truncated

    transform_train, transform_test = _data_transforms_cifar10()
@@ -257,6 +378,21 @@ def load_partition_data_distributed_cifar10(
    client_number,
    batch_size,
):
+    """
+    Load the partitioned CIFAR-10 dataset for distributed training.
+
+    Args:
+        process_id (int): ID of the current process.
+        dataset: Unused; kept for API compatibility.
+        data_dir (str): Directory where the CIFAR-10 dataset is located.
+        partition_method (str): Partitioning method: "homo", "hetero", or "hetero-fix".
+        partition_alpha (float): Dirichlet distribution parameter for data partitioning.
+        client_number (int): Number of clients (networks).
+        batch_size (int): Batch size for training and testing.
+
+    Returns:
+        tuple: A tuple containing training and testing data loaders, data statistics, and the number of classes.
+    """
    (
        X_train,
        y_train,
@@ -318,6 +454,21 @@ def load_partition_data_cifar10(
    batch_size,
    n_proc_in_silo=0,
):
+    """
+    Load the partitioned CIFAR-10 dataset for federated learning.
+
+    Args:
+        dataset: Unused; kept for API compatibility.
+        data_dir (str): Directory where the CIFAR-10 dataset is located.
+        partition_method (str): Partitioning method: "homo", "hetero", or "hetero-fix".
+        partition_alpha (float): Dirichlet distribution parameter for data partitioning.
+        client_number (int): Number of clients (networks).
+        batch_size (int): Batch size for training and testing.
+        n_proc_in_silo (int, optional): Number of processes in a silo (default: 0).
+
+    Returns:
+        tuple: A tuple containing training and testing data loaders, data statistics, and the number of classes.
+    """
    (
        X_train,
        y_train,
diff --git a/python/fedml/data/cifar10/datasets.py b/python/fedml/data/cifar10/datasets.py
index 8f64d1d76e..3df54283ca 100644
--- a/python/fedml/data/cifar10/datasets.py
+++ b/python/fedml/data/cifar10/datasets.py
@@ -18,17 +18,45 @@
 
 
 def default_loader(path):
-    return pil_loader(path)
+    """
+    Default loader function for loading images.
+
+    Args:
+        path (str): Path to the image file.
+
+    Returns:
+        PIL.Image.Image: Loaded image in RGB format.
+    """
+    return pil_loader(path)
 

def pil_loader(path):
+    """
+    Custom PIL image loader function.
+
+    Args:
+        path (str): Path to the image file.
+
+    Returns:
+        PIL.Image.Image: Loaded image in RGB format.
+    """
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, "rb") as f:
        img = Image.open(f)
        return img.convert("RGB")

-
class CIFAR10_truncated(data.Dataset):
+    """
+    Custom dataset class for truncated CIFAR-10 data.
+
+    Args:
+        root (str): Root directory where CIFAR-10 dataset is located.
+ dataidxs (list, optional): List of data indices to include (default: None). + train (bool, optional): Whether the dataset is for training (default: True). + transform (callable, optional): Optional transform to be applied to the image (default: None). + target_transform (callable, optional): Optional transform to be applied to the target (default: None). + download (bool, optional): Whether to download the dataset if not found (default: False). + """ + def __init__( self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False, ): @@ -43,12 +71,17 @@ def __init__( self.data, self.target = self.__build_truncated_dataset__() def __build_truncated_dataset__(self): + """ + Build the truncated CIFAR-10 dataset. + + Returns: + tuple: Tuple containing data and target arrays. + """ print("download = " + str(self.download)) cifar_dataobj = CIFAR10(self.root, self.train, self.transform, self.target_transform, self.download) if self.train: - # print("train member of the class: {}".format(self.train)) - # data = cifar_dataobj.train_data + data = cifar_dataobj.data target = np.array(cifar_dataobj.targets) else: @@ -62,6 +95,12 @@ def __build_truncated_dataset__(self): return data, target def truncate_channel(self, index): + """ + Truncate channels (G and B) in the images specified by the given index. + + Args: + index (numpy.ndarray): Array of indices specifying which images to truncate. + """ for i in range(index.shape[0]): gs_index = index[i] self.data[gs_index, :, :, 1] = 0.0 @@ -73,7 +112,7 @@ def __getitem__(self, index): index (int): Index Returns: - tuple: (image, target) where target is index of the target class. + tuple: (image, target) where target is the index of the target class. """ img, target = self.data[index], self.target[index] @@ -86,4 +125,11 @@ def __getitem__(self, index): return img, target def __len__(self): + """ + Get the number of samples in the dataset. + + Returns: + int: Number of samples in the dataset. + """ return len(self.data) + \ No newline at end of file diff --git a/python/fedml/data/cifar10/efficient_loader.py b/python/fedml/data/cifar10/efficient_loader.py index d86edbb753..8751c7293b 100644 --- a/python/fedml/data/cifar10/efficient_loader.py +++ b/python/fedml/data/cifar10/efficient_loader.py @@ -10,7 +10,16 @@ # generate the non-IID distribution for all methods -def read_data_distribution(filename="./data_preprocessing/non-iid-distribution/CIFAR10/distribution.txt",): +def read_data_distribution(filename="./data_preprocessing/non-iid-distribution/CIFAR10/distribution.txt"): + """ + Read data distribution from a file and return it as a dictionary. + + Args: + filename (str): The path to the file containing data distribution information. + + Returns: + dict: A dictionary representing the data distribution. + """ distribution = {} with open(filename, "r") as data: for x in data.readlines(): @@ -24,8 +33,16 @@ def read_data_distribution(filename="./data_preprocessing/non-iid-distribution/C distribution[first_level_key][second_level_key] = int(tmp[1].strip().replace(",", "")) return distribution +def read_net_dataidx_map(filename="./data_preprocessing/non-iid-distribution/CIFAR10/net_dataidx_map.txt"): + """ + Read network data index mapping from a file and return it as a dictionary. + + Args: + filename (str): The path to the file containing network data index mapping information. 
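Hypothetical usage of `CIFAR10_truncated` with an explicit index subset (this downloads CIFAR-10 on first run):

    import numpy as np

    dataidxs = list(range(100))    # keep only the first 100 training samples
    ds = CIFAR10_truncated("./data", dataidxs=dataidxs, train=True, download=True)
    print(len(ds))                 # -> 100
    ds.truncate_channel(np.arange(10))   # zero the G/B channels of 10 images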
-def read_net_dataidx_map(filename="./data_preprocessing/non-iid-distribution/CIFAR10/net_dataidx_map.txt",): + Returns: + dict: A dictionary representing the network data index mapping. + """ net_dataidx_map = {} with open(filename, "r") as data: for x in data.readlines(): @@ -41,6 +58,16 @@ def read_net_dataidx_map(filename="./data_preprocessing/non-iid-distribution/CIF def record_net_data_stats(y_train, net_dataidx_map): + """ + Record data statistics for each network based on network data index mapping. + + Args: + y_train (numpy.ndarray): The labels of the training data. + net_dataidx_map (dict): The network data index mapping. + + Returns: + dict: A dictionary containing data statistics for each network. + """ net_cls_counts = {} for net_i, dataidx in net_dataidx_map.items(): @@ -56,6 +83,15 @@ def __init__(self, length): self.length = length def __call__(self, img): + """ + Apply Cutout augmentation to the input image. + + Args: + img (PIL.Image): The input image. + + Returns: + PIL.Image: The image after applying Cutout augmentation. + """ h, w = img.size(1), img.size(2) mask = np.ones((h, w), np.float32) y = np.random.randint(h) @@ -74,6 +110,12 @@ def __call__(self, img): def _data_transforms_cifar10(): + """ + Define data transforms for CIFAR-10 dataset. + + Returns: + transforms.Compose: Training and validation data transforms. + """ CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124] CIFAR_STD = [0.24703233, 0.24348505, 0.26158768] @@ -89,15 +131,32 @@ def _data_transforms_cifar10(): train_transform.transforms.append(Cutout(16)) - valid_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(CIFAR_MEAN, CIFAR_STD),]) + valid_transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize(CIFAR_MEAN, CIFAR_STD),] + ) return train_transform, valid_transform def load_cifar10_data(datadir, process_id, synthetic_data_url, private_local_data, resize=32, augmentation=True, data_efficient_load=False): + """ + Load CIFAR-10 dataset with specified configurations. + + Args: + datadir (str): Directory where CIFAR-10 dataset is stored. + process_id (int): ID of the current process. + synthetic_data_url (str): URL for synthetic data (not used in the provided code). + private_local_data (bool): Whether to use private local data (not used in the provided code). + resize (int): Resize images to this size (not used in the provided code). + augmentation (bool): Perform data augmentation (not used in the provided code). + data_efficient_load (bool): Load data efficiently (not used in the provided code). + + Returns: + tuple: Tuple containing X_train, y_train, X_test, y_test, cifar10_train_ds, and cifar10_test_ds. + """ train_transform, test_transform = _data_transforms_cifar10() - is_download = True; + is_download = True if data_efficient_load: cifar10_train_ds = CIFAR10(datadir, train=True, download=True, transform=train_transform) @@ -113,11 +172,27 @@ def load_cifar10_data(datadir, process_id, synthetic_data_url, private_local_dat def partition_data(dataset, datadir, partition, n_nets, alpha, process_id, synthetic_data_url, private_local_data): + """ + Partition the CIFAR-10 dataset into subsets for federated learning. + + Args: + dataset (str): Name of the dataset (not used in the provided code). + datadir (str): Directory where CIFAR-10 dataset is stored. + partition (str): Partitioning method (homo, hetero, hetero-fix). + n_nets (int): Number of clients (networks). + alpha (float): Alpha value for partitioning (not used in the provided code). 
+ process_id (int): ID of the current process. + synthetic_data_url (str): URL for synthetic data (not used in the provided code). + private_local_data (bool): Whether to use private local data (not used in the provided code). + + Returns: + tuple: Tuple containing X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts, cifar10_train_ds, and cifar10_test_ds. + """ np.random.seed(10) logging.info("*********partition data***************") X_train, y_train, X_test, y_test, cifar10_train_ds, cifar10_test_ds = load_cifar10_data(datadir, process_id, synthetic_data_url, private_local_data) n_train = X_train.shape[0] - # n_test = X_test.shape[0] + if partition == "homo": total_num = n_train @@ -174,6 +249,22 @@ def get_dataloader( full_train_dataset=None, full_test_dataset=None, ): + """ + Get data loaders for CIFAR-10 dataset. + + Args: + dataset (str): Name of the dataset. + datadir (str): Directory where CIFAR-10 dataset is stored. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (list): List of data indices for custom data loading (default: None). + data_efficient_load (bool): Use data-efficient loading (default: False). + full_train_dataset: Full training dataset (default: None). + full_test_dataset: Full testing dataset (default: None). + + Returns: + tuple: Tuple containing training and testing data loaders. + """ return get_dataloader_CIFAR10( datadir, train_bs, @@ -184,12 +275,24 @@ def get_dataloader( full_test_dataset=full_test_dataset, ) - # for local devices def get_dataloader_test(dataset, datadir, train_bs, test_bs, dataidxs_train, dataidxs_test): + """ + Get data loaders for testing CIFAR-10 dataset on local devices. + + Args: + dataset (str): Name of the dataset. + datadir (str): Directory where CIFAR-10 dataset is stored. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs_train (list): List of training data indices. + dataidxs_test (list): List of testing data indices. + + Returns: + tuple: Tuple containing training and testing data loaders. + """ return get_dataloader_test_CIFAR10(datadir, train_bs, test_bs, dataidxs_train, dataidxs_test) - def get_dataloader_CIFAR10( datadir, train_bs, @@ -199,6 +302,21 @@ def get_dataloader_CIFAR10( full_train_dataset=None, full_test_dataset=None, ): + """ + Get data loaders for CIFAR-10 dataset. + + Args: + datadir (str): Directory where CIFAR-10 dataset is stored. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (list): List of data indices for custom data loading (default: None). + data_efficient_load (bool): Use data-efficient loading (default: False). + full_train_dataset: Full training dataset (default: None). + full_test_dataset: Full testing dataset (default: None). + + Returns: + tuple: Tuple containing training and testing data loaders. + """ transform_train, transform_test = _data_transforms_cifar10() if data_efficient_load: @@ -217,8 +335,20 @@ def get_dataloader_CIFAR10( return train_dl, test_dl - def get_dataloader_test_CIFAR10(datadir, train_bs, test_bs, dataidxs_train=None, dataidxs_test=None): + """ + Get data loaders for testing CIFAR-10 dataset on local devices. + + Args: + datadir (str): Directory where CIFAR-10 dataset is stored. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs_train (list): List of training data indices. + dataidxs_test (list): List of testing data indices. 
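+
+    Example:
+        A minimal sketch (the directory, batch sizes, and index lists are
+        placeholders, not real data):
+
+            train_dl, test_dl = get_dataloader_test_CIFAR10(
+                "./data", train_bs=64, test_bs=64,
+                dataidxs_train=list(range(128)),
+                dataidxs_test=list(range(64)),
+            )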
+ + Returns: + tuple: Tuple containing training and testing data loaders. + """ dl_obj = CIFAR10_truncated transform_train, transform_test = _data_transforms_cifar10() @@ -231,7 +361,6 @@ def get_dataloader_test_CIFAR10(datadir, train_bs, test_bs, dataidxs_train=None, return train_dl, test_dl - def load_partition_data_distributed_cifar10( process_id, dataset, @@ -242,6 +371,24 @@ def load_partition_data_distributed_cifar10( batch_size, data_efficient_load=True, ): + """ + Load partitioned CIFAR-10 data for distributed learning. + + Args: + process_id (int): ID of the current process. + dataset (str): Name of the dataset. + data_dir (str): Directory where CIFAR-10 dataset is stored. + partition_method (str): Partitioning method (homo, hetero, hetero-fix). + partition_alpha (float): Alpha value for partitioning. + client_number (int): Number of clients (networks). + batch_size (int): Batch size for training and testing. + data_efficient_load (bool): Use data-efficient loading (default: True). + + Returns: + tuple: Tuple containing training data size, global training data loader, + global testing data loader, local data size, local training data loader, + local testing data loader, and class count. + """ ( X_train, y_train, @@ -318,6 +465,28 @@ def efficient_load_partition_data_cifar10( n_proc_in_silo=0, data_efficient_load=True, ): + """ + Efficiently load partitioned CIFAR-10 data for distributed learning. + + Args: + dataset (str): Name of the dataset. + data_dir (str): Directory where CIFAR-10 dataset is stored. + partition_method (str): Partitioning method (homo, hetero, hetero-fix). + partition_alpha (float): Alpha value for partitioning. + client_number (int): Number of clients (networks). + batch_size (int): Batch size for training and testing. + process_id (int): ID of the current process (default: 0). + synthetic_data_url (str): URL for synthetic data (default: ""). + private_local_data (str): Path to private local data (default: ""). + n_proc_in_silo (int): Number of processes in the silo (default: 0). + data_efficient_load (bool): Use data-efficient loading (default: True). + + Returns: + tuple: Tuple containing training data size, global testing data size, + global training data loader, global testing data loader, dictionary of + local data sample numbers, dictionary of local training data loaders, + dictionary of local testing data loaders, and class count. 
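+
+    Example:
+        A minimal sketch (argument values are illustrative; the names on the
+        left mirror the Returns description above):
+
+            (train_data_num, test_data_num, train_data_global, test_data_global,
+             data_local_num_dict, train_data_local_dict, test_data_local_dict,
+             class_num) = efficient_load_partition_data_cifar10(
+                "cifar10", "./data", "hetero", 0.5,
+                client_number=10, batch_size=64,
+            )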
+ """ ( X_train, y_train, @@ -327,7 +496,16 @@ def efficient_load_partition_data_cifar10( traindata_cls_counts, cifar10_train_ds, cifar10_test_ds, - ) = partition_data(dataset, data_dir, partition_method, client_number, partition_alpha, process_id, synthetic_data_url, private_local_data) + ) = partition_data( + dataset, + data_dir, + partition_method, + client_number, + partition_alpha, + process_id, + synthetic_data_url, + private_local_data, + ) class_num = len(np.unique(y_train)) logging.info("traindata_cls_counts = " + str(traindata_cls_counts)) train_data_num = sum([len(net_dataidx_map[r]) for r in range(client_number)]) @@ -382,4 +560,4 @@ def efficient_load_partition_data_cifar10( train_data_local_dict, test_data_local_dict, class_num, - ) + ) \ No newline at end of file diff --git a/python/fedml/data/cifar10/without_reload.py b/python/fedml/data/cifar10/without_reload.py index c483f9f2a1..6f60f27e22 100644 --- a/python/fedml/data/cifar10/without_reload.py +++ b/python/fedml/data/cifar10/without_reload.py @@ -28,6 +28,12 @@ def __init__(self, root, dataidxs=None, train=True, transform=None, target_trans self.data, self.targets = self.__build_truncated_dataset__() def __build_truncated_dataset__(self): + """ + Build the truncated CIFAR-10 dataset by loading data based on data indices. + + Returns: + tuple: A tuple containing the data and targets (class labels). + """ print("download = " + str(self.download)) cifar_dataobj = CIFAR10(self.root, self.train, self.transform, self.target_transform, self.download) @@ -47,6 +53,12 @@ def __build_truncated_dataset__(self): return data, targets def truncate_channel(self, index): + """ + Truncate the green and blue channels of specified images in the dataset. + + Args: + index (numpy.ndarray): An array of indices indicating which images to truncate. + """ for i in range(index.shape[0]): gs_index = index[i] self.data[gs_index, :, :, 1] = 0.0 diff --git a/python/fedml/data/cifar100/data_loader.py b/python/fedml/data/cifar100/data_loader.py index bee691bd13..96b9f8ab72 100644 --- a/python/fedml/data/cifar100/data_loader.py +++ b/python/fedml/data/cifar100/data_loader.py @@ -78,6 +78,12 @@ def __call__(self, img): def _data_transforms_cifar100(): + """ + Get data transforms for CIFAR-100 dataset. + + Returns: + tuple: A tuple containing train and validation data transforms. + """ CIFAR_MEAN = [0.5071, 0.4865, 0.4409] CIFAR_STD = [0.2673, 0.2564, 0.2762] @@ -103,6 +109,15 @@ def _data_transforms_cifar100(): return train_transform, valid_transform def load_cifar100_data(datadir): + """ + Load CIFAR-100 dataset. + + Args: + datadir (str): The directory where CIFAR-100 dataset is stored. + + Returns: + tuple: A tuple containing training data, training labels, testing data, and testing labels. + """ train_transform, test_transform = _data_transforms_cifar100() cifar100_train_ds = CIFAR100_truncated( @@ -115,13 +130,26 @@ def load_cifar100_data(datadir): X_train, y_train = cifar100_train_ds.data, cifar100_train_ds.target X_test, y_test = cifar100_test_ds.data, cifar100_test_ds.target - return (X_train, y_train, X_test, y_test) + return X_train, y_train, X_test, y_test def partition_data(dataset, datadir, partition, n_nets, alpha): + """ + Partition CIFAR-100 data for federated learning. + + Args: + dataset (str): The dataset name. + datadir (str): The directory where CIFAR-100 dataset is stored. + partition (str): The data partitioning method ("homo", "hetero", or "hetero-fix"). + n_nets (int): The number of clients (networks). 
+ alpha (float): Alpha parameter for data partitioning. + + Returns: + tuple: A tuple containing training data, training labels, testing data, testing labels, network data index map, and class counts. + """ logging.info("*********partition data***************") X_train, y_train, X_test, y_test = load_cifar100_data(datadir) n_train = X_train.shape[0] - # n_test = X_test.shape[0] + if partition == "homo": total_num = n_train @@ -179,21 +207,58 @@ def partition_data(dataset, datadir, partition, n_nets, alpha): return X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts -# for centralized training + def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None): + """ + Get data loader for centralized training. + + Args: + dataset (str): The dataset name. + datadir (str): The directory where CIFAR-100 dataset is stored. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (list, optional): List of data indices to use. Defaults to None. + + Returns: + tuple: A tuple containing training data loader and testing data loader. + """ return get_dataloader_CIFAR100(datadir, train_bs, test_bs, dataidxs) - -# for local devices def get_dataloader_test( dataset, datadir, train_bs, test_bs, dataidxs_train, dataidxs_test ): + """ + Get data loader for local devices. + + Args: + dataset (str): The dataset name. + datadir (str): The directory where CIFAR-100 dataset is stored. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs_train (list): List of training data indices. + dataidxs_test (list): List of testing data indices. + + Returns: + tuple: A tuple containing training data loader and testing data loader. + """ return get_dataloader_test_CIFAR100( datadir, train_bs, test_bs, dataidxs_train, dataidxs_test ) def get_dataloader_CIFAR100(datadir, train_bs, test_bs, dataidxs=None): + """ + Get data loader for CIFAR-100 dataset. + + Args: + datadir (str): The directory where CIFAR-100 dataset is stored. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (list, optional): List of data indices to use. Defaults to None. + + Returns: + tuple: A tuple containing training data loader and testing data loader. + """ dl_obj = CIFAR100_truncated transform_train, transform_test = _data_transforms_cifar100() @@ -216,6 +281,19 @@ def get_dataloader_CIFAR100(datadir, train_bs, test_bs, dataidxs=None): def get_dataloader_test_CIFAR100( datadir, train_bs, test_bs, dataidxs_train=None, dataidxs_test=None ): + """ + Get data loader for testing CIFAR-100 dataset. + + Args: + datadir (str): The directory where CIFAR-100 dataset is stored. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs_train (list, optional): List of training data indices. Defaults to None. + dataidxs_test (list, optional): List of testing data indices. Defaults to None. + + Returns: + tuple: A tuple containing training data loader and testing data loader. + """ dl_obj = CIFAR100_truncated transform_train, transform_test = _data_transforms_cifar100() @@ -254,6 +332,21 @@ def load_partition_data_distributed_cifar100( client_number, batch_size, ): + """ + Load partitioned CIFAR-100 data for distributed training. + + Args: + process_id (int): The process ID. + dataset (str): The dataset name. + data_dir (str): The directory where CIFAR-100 dataset is stored. 
+ partition_method (str): The data partitioning method ("homo", "hetero", or "hetero-fix"). + partition_alpha (float): Alpha parameter for data partitioning. + client_number (int): The number of clients (networks). + batch_size (int): Batch size for training and testing data. + + Returns: + tuple: A tuple containing various data loaders and class count information. + """ ( X_train, y_train, @@ -268,7 +361,7 @@ def load_partition_data_distributed_cifar100( logging.info("traindata_cls_counts = " + str(traindata_cls_counts)) train_data_num = sum([len(net_dataidx_map[r]) for r in range(client_number)]) - # get global test data + if process_id == 0: train_data_global, test_data_global = get_dataloader( dataset, data_dir, batch_size, batch_size @@ -279,13 +372,13 @@ def load_partition_data_distributed_cifar100( test_data_local = None local_data_num = 0 else: - # get local dataset + dataidxs = net_dataidx_map[process_id - 1] local_data_num = len(dataidxs) logging.info( "rank = %d, local_sample_number = %d" % (process_id, local_data_num) ) - # training batch size = 64; algorithms batch size = 32 + train_data_local, test_data_local = get_dataloader( dataset, data_dir, batch_size, batch_size, dataidxs ) @@ -310,6 +403,21 @@ def load_partition_data_distributed_cifar100( def load_partition_data_cifar100( dataset, data_dir, partition_method, partition_alpha, client_number, batch_size ): + """ + Load and partition CIFAR-100 data for federated learning. + + Args: + dataset (str): The dataset name. + data_dir (str): The directory where CIFAR-100 dataset is stored. + partition_method (str): The data partitioning method ("homo", "hetero", or "hetero-fix"). + partition_alpha (float): Alpha parameter for data partitioning. + client_number (int): The number of clients (networks). + batch_size (int): Batch size for training and testing data. + + Returns: + tuple: A tuple containing various data loaders and class count information. + + """ ( X_train, y_train, @@ -363,4 +471,4 @@ def load_partition_data_cifar100( train_data_local_dict, test_data_local_dict, class_num, - ) + ) \ No newline at end of file diff --git a/python/fedml/data/cifar100/datasets.py b/python/fedml/data/cifar100/datasets.py index ee0b332bdc..c7a2cec84a 100644 --- a/python/fedml/data/cifar100/datasets.py +++ b/python/fedml/data/cifar100/datasets.py @@ -17,21 +17,46 @@ def default_loader(path): - return pil_loader(path) + """ + Default image loader function using PIL to open and convert an image to RGB format. + + Args: + path (str): The path to the image file. + Returns: + PIL.Image: The loaded image in RGB format. + """ + return pil_loader(path) def pil_loader(path): - # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) + """ + Image loader function using PIL to open and convert an image to RGB format. + + Args: + path (str): The path to the image file. + + Returns: + PIL.Image: The loaded image in RGB format. + """ with open(path, "rb") as f: img = Image.open(f) return img.convert("RGB") - class CIFAR100_truncated(data.Dataset): def __init__( self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False, ): + """ + Custom dataset class for truncated CIFAR-100 dataset. + Args: + root (str): The root directory where the dataset is stored. + dataidxs (list or None): List of data indices to include in the dataset. If None, includes all data. + train (bool): Indicates whether the dataset is for training (True) or testing (False). 
+ transform (callable, optional): A function/transform to apply to the data. + target_transform (callable, optional): A function/transform to apply to the target (class label). + download (bool, optional): Whether to download the dataset if not found locally. + """ self.root = root self.dataidxs = dataidxs self.train = train @@ -42,7 +67,12 @@ def __init__( self.data, self.target = self.__build_truncated_dataset__() def __build_truncated_dataset__(self): + """ + Build the truncated CIFAR-100 dataset by loading data based on data indices. + Returns: + tuple: A tuple containing the data and target (class labels). + """ cifar_dataobj = CIFAR100(self.root, self.train, self.transform, self.target_transform, self.download) data = cifar_dataobj.data @@ -55,6 +85,12 @@ def __build_truncated_dataset__(self): return data, target def truncate_channel(self, index): + """ + Truncate the green and blue channels of specified images in the dataset. + + Args: + index (numpy.ndarray): An array of indices indicating which images to truncate. + """ for i in range(index.shape[0]): gs_index = index[i] self.data[gs_index, :, :, 1] = 0.0 @@ -66,7 +102,7 @@ def __getitem__(self, index): index (int): Index Returns: - tuple: (image, target) where target is index of the target class. + tuple: (image, target) where target is the index of the target class. """ img, target = self.data[index], self.target[index] From fc72f1ee0adc4119d3c27f82bc3849d56b20b1e6 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Tue, 19 Sep 2023 13:13:50 +0530 Subject: [PATCH 28/70] update --- python/fedml/cross_silo/fedml_client.py | 28 ++ python/fedml/cross_silo/fedml_server.py | 28 ++ .../lightsecagg/lsa_fedml_aggregator.py | 213 +++++++++-- .../cross_silo/lightsecagg/lsa_fedml_api.py | 76 +++- .../lightsecagg/lsa_fedml_client_manager.py | 117 ++++++ .../lightsecagg/lsa_fedml_server_manager.py | 198 ++++++++-- .../cross_silo/secagg/sa_fedml_aggregator.py | 196 +++++++--- .../fedml/cross_silo/secagg/sa_fedml_api.py | 72 +++- .../secagg/sa_fedml_client_manager.py | 354 ++++++++++++++++-- .../secagg/sa_fedml_server_manager.py | 351 +++++++++++++++-- .../cross_silo/server/fedml_aggregator.py | 243 ++++++++++-- .../cross_silo/server/fedml_server_manager.py | 350 ++++++++++++++--- .../cross_silo/server/server_initializer.py | 27 +- 13 files changed, 1983 insertions(+), 270 deletions(-) diff --git a/python/fedml/cross_silo/fedml_client.py b/python/fedml/cross_silo/fedml_client.py index 5e009977d0..b1198997ca 100644 --- a/python/fedml/cross_silo/fedml_client.py +++ b/python/fedml/cross_silo/fedml_client.py @@ -3,6 +3,25 @@ class FedMLCrossSiloClient: + """ + Represents a client for a cross-silo federated learning setup. + + Args: + args (object): An object containing various configuration parameters. + device (torch.device): The device (e.g., 'cpu' or 'cuda') for computation. + dataset (tuple): A tuple containing dataset-related information. + model (torch.nn.Module): The PyTorch model used in federated learning. + model_trainer (ClientTrainer, optional): An optional client trainer. + + Raises: + Exception: If an unsupported federated optimizer is specified in args. + + Attributes: + None + + Methods: + run(): Placeholder method for client execution. 
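+
+    Example:
+        A minimal sketch (assumes args, device, dataset, and model have
+        already been prepared by the launcher):
+
+            client = FedMLCrossSiloClient(args, device, dataset, model)
+            client.run()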
+ """ def __init__(self, args, device, dataset, model, model_trainer: ClientTrainer = None): if args.federated_optimizer == "FedAvg": [ @@ -61,4 +80,13 @@ def __init__(self, args, device, dataset, model, model_trainer: ClientTrainer = raise Exception("Exception") def run(self): + """ + Placeholder method for client execution. + + Args: + None + + Returns: + None + """ pass diff --git a/python/fedml/cross_silo/fedml_server.py b/python/fedml/cross_silo/fedml_server.py index 6778469b52..97d9890c66 100644 --- a/python/fedml/cross_silo/fedml_server.py +++ b/python/fedml/cross_silo/fedml_server.py @@ -2,6 +2,25 @@ class FedMLCrossSiloServer: + """ + Represents a server for a cross-silo federated learning setup. + + Args: + args (object): An object containing various configuration parameters. + device (torch.device): The device (e.g., 'cpu' or 'cuda') for computation. + dataset (tuple): A tuple containing dataset-related information. + model (torch.nn.Module): The PyTorch model used in federated learning. + server_aggregator (ServerAggregator, optional): An optional server aggregator. + + Raises: + Exception: If an unsupported federated optimizer is specified in args. + + Attributes: + None + + Methods: + run(): Placeholder method for server execution. + """ def __init__(self, args, device, dataset, model, server_aggregator: ServerAggregator = None): if args.federated_optimizer == "FedAvg": from fedml.cross_silo.server import server_initializer @@ -65,4 +84,13 @@ def __init__(self, args, device, dataset, model, server_aggregator: ServerAggreg raise Exception("Exception") def run(self): + """ + Placeholder method for server execution. + + Args: + None + + Returns: + None + """ pass diff --git a/python/fedml/cross_silo/lightsecagg/lsa_fedml_aggregator.py b/python/fedml/cross_silo/lightsecagg/lsa_fedml_aggregator.py index 68cbd66f85..8c64fb1217 100644 --- a/python/fedml/cross_silo/lightsecagg/lsa_fedml_aggregator.py +++ b/python/fedml/cross_silo/lightsecagg/lsa_fedml_aggregator.py @@ -16,6 +16,49 @@ class LightSecAggAggregator(object): + """ + Initialize a LightSecAggAggregator for federated learning. + + Args: + train_global (Dataset): The global training dataset. + test_global (Dataset): The global test dataset. + all_train_data_num (int): The total number of training data points globally. + train_data_local_dict (dict): A dictionary of local training datasets for each client. + test_data_local_dict (dict): A dictionary of local test datasets for each client. + train_data_local_num_dict (dict): A dictionary of the number of local training data points for each client. + client_num (int): The number of client nodes participating in federated learning. + device (torch.device): The device on which the server runs. + args (argparse.Namespace): Command-line arguments and configurations. + model_trainer: An instance of the model trainer for federated learning. + + Attributes: + trainer: The model trainer for federated learning. + args (argparse.Namespace): Command-line arguments and configurations. + train_global (Dataset): The global training dataset. + test_global (Dataset): The global test dataset. + val_global: The validation dataset generated from the global test dataset. + all_train_data_num (int): The total number of training data points globally. + train_data_local_dict (dict): A dictionary of local training datasets for each client. + test_data_local_dict (dict): A dictionary of local test datasets for each client. 
+ train_data_local_num_dict (dict): A dictionary of the number of local training data points for each client. + client_num (int): The number of client nodes participating in federated learning. + device (torch.device): The device on which the server runs. + model_dict (dict): A dictionary to store the local models submitted by clients. + sample_num_dict (dict): A dictionary to store the number of samples each client used for training. + aggregate_encoded_mask_dict (dict): A dictionary to store encoded aggregate masks from clients. + flag_client_model_uploaded_dict (dict): A dictionary to track whether a client has uploaded its model. + flag_client_aggregate_encoded_mask_uploaded_dict (dict): A dictionary to track whether a client has uploaded its encoded aggregate mask. + total_dimension: The total dimension of the model's parameters. + dimensions (list): A list of dimensions for each parameter of the model. + targeted_number_active_clients (int): The targeted number of active clients for aggregation. + privacy_guarantee (int): The privacy guarantee parameter. + prime_number: The prime number used in aggregation. + precision_parameter: The precision parameter used in aggregation. + + Returns: + None + """ + def __init__( self, train_global, @@ -62,14 +105,35 @@ def __init__( self.precision_parameter = args.precision_parameter def get_global_model_params(self): + """ + Get the global model parameters from the model trainer. + + Returns: + dict: The global model parameters. + """ global_model_params = self.trainer.get_model_params() - self.dimensions, self.total_dimension = model_dimension(global_model_params) + self.dimensions, self.total_dimension = model_dimension( + global_model_params) return global_model_params def set_global_model_params(self, model_parameters): + """ + Set the global model parameters in the model trainer. + + Args: + model_parameters (dict): The global model parameters to be set. + """ self.trainer.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, sample_num): + """ + Add the locally trained model results for a client. + + Args: + index (int): The index of the client. + model_params (dict): The locally trained model parameters. + sample_num (int): The number of samples used for training. + """ logging.info("add_model. index = %d" % index) # for key in model_params.keys(): # model_params[key] = model_params[key].to(self.device) @@ -78,11 +142,24 @@ def add_local_trained_result(self, index, model_params, sample_num): self.flag_client_model_uploaded_dict[index] = True def add_local_aggregate_encoded_mask(self, index, aggregate_encoded_mask): + """ + Add the locally generated aggregate encoded mask for a client. + + Args: + index (int): The index of the client. + aggregate_encoded_mask (array): The encoded aggregate mask. + """ logging.info("add_aggregate_encoded_mask index = %d" % index) self.aggregate_encoded_mask_dict[index] = aggregate_encoded_mask self.flag_client_aggregate_encoded_mask_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check whether all clients have uploaded their local models. + + Returns: + bool: True if all clients have uploaded their models, False otherwise. + """ for idx in range(self.client_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -91,6 +168,12 @@ def check_whether_all_receive(self): return True def check_whether_all_aggregate_encoded_mask_receive(self): + """ + Check whether all clients have uploaded their aggregate encoded masks. 
+ + Returns: + bool: True if all clients have uploaded their masks, False otherwise. + """ for idx in range(self.client_num): if not self.flag_client_aggregate_encoded_mask_uploaded_dict[idx]: return False @@ -100,38 +183,60 @@ def check_whether_all_aggregate_encoded_mask_receive(self): def aggregate_mask_reconstruction(self, active_clients): """ - Recover the aggregate-mask via decoding + Recover the aggregate-mask via decoding. + + Args: + active_clients (list): List of active client indices for aggregation. + + Returns: + array: The reconstructed aggregate mask. """ d = self.total_dimension N = self.client_num U = self.targeted_number_active_clients T = self.privacy_guarantee p = self.prime_number - logging.debug("d = {}, N = {}, U = {}, T = {}, p = {}".format(d, N, U, T, p)) + logging.debug( + "d = {}, N = {}, U = {}, T = {}, p = {}".format(d, N, U, T, p)) d = int(np.ceil(float(d) / (U - T))) * (U - T) alpha_s = np.array(range(N)) + 1 beta_s = np.array(range(U)) + (N + 1) logging.info("Server starts the reconstruction of aggregate_mask") - aggregate_encoded_mask_buffer = np.zeros((U, d // (U - T)), dtype="int64") + aggregate_encoded_mask_buffer = np.zeros( + (U, d // (U - T)), dtype="int64") # logging.info( # "active_clients = {}, aggregate_encoded_mask_dict = {}".format( # active_clients, self.aggregate_encoded_mask_dict # ) # ) for i, client_idx in enumerate(active_clients): - aggregate_encoded_mask_buffer[i, :] = self.aggregate_encoded_mask_dict[client_idx] + aggregate_encoded_mask_buffer[i, + :] = self.aggregate_encoded_mask_dict[client_idx] eval_points = alpha_s[active_clients] - aggregate_mask = LCC_decoding_with_points(aggregate_encoded_mask_buffer, eval_points, beta_s, p) - logging.info("Server finish the reconstruction of aggregate_mask via LCC decoding") + aggregate_mask = LCC_decoding_with_points( + aggregate_encoded_mask_buffer, eval_points, beta_s, p) + logging.info( + "Server finish the reconstruction of aggregate_mask via LCC decoding") aggregate_mask = np.reshape(aggregate_mask, (U * (d // (U - T)), 1)) aggregate_mask = aggregate_mask[0:d] # logging.info("aggregated mask = {}".format(aggregate_mask)) return aggregate_mask def aggregate_model_reconstruction(self, active_clients_first_round, active_clients_second_round): + """ + Perform aggregate model reconstruction using encoded masks. + + Args: + active_clients_first_round (list): List of active client indices in the first round. + active_clients_second_round (list): List of active client indices in the second round. + + Returns: + dict: The averaged global model parameters after reconstruction. 
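+
+        Example:
+            A minimal sketch (the client index lists are placeholders):
+
+                averaged_params = self.aggregate_model_reconstruction(
+                    active_clients_first_round=[0, 1, 2],
+                    active_clients_second_round=[0, 1, 2],
+                )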
+ """ start_time = time.time() - aggregate_mask = self.aggregate_mask_reconstruction(active_clients_second_round) + aggregate_mask = self.aggregate_mask_reconstruction( + active_clients_second_round) p = self.prime_number q_bits = self.precision_parameter logging.info("Server starts the reconstruction of aggregate_model") @@ -146,7 +251,7 @@ def aggregate_model_reconstruction(self, active_clients_first_round, active_clie averaged_params[k] += local_model_params[k] cur_shape = np.shape(averaged_params[k]) d = self.dimensions[j] - cur_mask = np.array(aggregate_mask[pos : pos + d, :]) + cur_mask = np.array(aggregate_mask[pos: pos + d, :]) cur_mask = np.reshape(cur_mask, cur_shape) # Cancel out the aggregate-mask to recover the aggregate-model @@ -157,7 +262,8 @@ def aggregate_model_reconstruction(self, active_clients_first_round, active_clie # Convert the model from finite to real # logging.info("Server converts the aggregate_model from finite to tensor") # logging.info("aggregate model before transform = {}".format(averaged_params)) - averaged_params = transform_finite_to_tensor(averaged_params, p, q_bits) + averaged_params = transform_finite_to_tensor( + averaged_params, p, q_bits) # do the avg after transform for j, k in enumerate(averaged_params): @@ -188,15 +294,18 @@ def data_silo_selection(self, round_idx, client_num_in_total, client_num_per_rou """ logging.info( - "client_num_in_total = %d, client_num_per_round = %d" % (client_num_in_total, client_num_per_round) + "client_num_in_total = %d, client_num_per_round = %d" % ( + client_num_in_total, client_num_per_round) ) assert client_num_in_total >= client_num_per_round if client_num_in_total == client_num_per_round: return [i for i in range(client_num_per_round)] else: - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round - data_silo_index_list = np.random.choice(range(client_num_in_total), client_num_per_round, replace=False) + # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) + data_silo_index_list = np.random.choice( + range(client_num_in_total), client_num_per_round, replace=False) return data_silo_index_list def client_selection(self, round_idx, client_id_list_in_total, client_num_per_round): @@ -213,31 +322,78 @@ def client_selection(self, round_idx, client_id_list_in_total, client_num_per_ro """ if client_num_per_round == len(client_id_list_in_total): return client_id_list_in_total - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round - client_id_list_in_this_round = np.random.choice(client_id_list_in_total, client_num_per_round, replace=False) + # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) + client_id_list_in_this_round = np.random.choice( + client_id_list_in_total, client_num_per_round, replace=False) return client_id_list_in_this_round def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Sample a list of clients for the current training round. + + Args: + round_idx (int): Round index, starting from 0. + client_num_in_total (int): Total number of clients in the dataset. + client_num_per_round (int): The number of clients to sample for the current round. + + Returns: + list: List of sampled client indices for the current round. 
+ """ if client_num_in_total == client_num_per_round: - client_indexes = [client_index for client_index in range(client_num_in_total)] + client_indexes = [ + client_index for client_index in range(client_num_in_total)] else: num_clients = min(client_num_per_round, client_num_in_total) - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round - client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False) + # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) + client_indexes = np.random.choice( + range(client_num_in_total), num_clients, replace=False) logging.info("client_indexes = %s" % str(client_indexes)) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation dataset subset. + + Args: + num_samples (int): The number of samples to include in the validation set. + + Returns: + torch.utils.data.DataLoader: DataLoader for the validation dataset subset. + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) - sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) - subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) - sample_testset = torch.utils.data.DataLoader(subset, batch_size=self.args.batch_size) + sample_indices = random.sample( + range(test_data_num), min(num_samples, test_data_num)) + subset = torch.utils.data.Subset( + self.test_global.dataset, sample_indices) + sample_testset = torch.utils.data.DataLoader( + subset, batch_size=self.args.batch_size) return sample_testset else: return self.test_global def test_on_server_for_all_clients(self, round_idx): + """ + Perform testing on the server for all clients and log the results. + + Args: + round_idx (int): Round index, starting from 0. + + This method tests the performance of the global model on both the training and testing datasets for all clients + and logs the results. It calculates and logs the training accuracy, training loss, test accuracy, and test loss. + + If the `round_idx` is a multiple of the specified `frequency_of_the_test` or it's the final round (`comm_round - 1`), + testing is performed; otherwise, it is skipped. + + The results are logged using the `wandb` library if the `enable_wandb` flag is set. + + Note: The method assumes that the `trainer` attribute has appropriate testing methods defined. 
+ + Returns: + None + """ # if self.trainer.test_on_the_server( # self.train_data_local_dict, # self.test_data_local_dict, @@ -247,13 +403,15 @@ def test_on_server_for_all_clients(self, round_idx): # return if round_idx % self.args.frequency_of_the_test == 0 or round_idx == self.args.comm_round - 1: - logging.info("################test_on_server_for_all_clients : {}".format(round_idx)) + logging.info( + "################test_on_server_for_all_clients : {}".format(round_idx)) train_num_samples = [] train_tot_corrects = [] train_losses = [] for client_idx in range(self.args.client_num_in_total): # train data - metrics = self.trainer.test(self.train_data_local_dict[client_idx], self.device, self.args) + metrics = self.trainer.test( + self.train_data_local_dict[client_idx], self.device, self.args) train_tot_correct, train_num_sample, train_loss = ( metrics["test_correct"], metrics["test_total"], @@ -272,7 +430,8 @@ def test_on_server_for_all_clients(self, round_idx): stats = {"training_acc": train_acc, "training_loss": train_loss} logging.info(stats) - mlops.log({"accuracy": round(train_acc, 4), "loss": round(train_loss, 4)}) + mlops.log({"accuracy": round(train_acc, 4), + "loss": round(train_loss, 4)}) # test data test_num_samples = [] @@ -280,9 +439,11 @@ def test_on_server_for_all_clients(self, round_idx): test_losses = [] if round_idx == self.args.comm_round - 1: - metrics = self.trainer.test(self.test_global, self.device, self.args) + metrics = self.trainer.test( + self.test_global, self.device, self.args) else: - metrics = self.trainer.test(self.val_global, self.device, self.args) + metrics = self.trainer.test( + self.val_global, self.device, self.args) test_tot_correct, test_num_sample, test_loss = ( metrics["test_correct"], diff --git a/python/fedml/cross_silo/lightsecagg/lsa_fedml_api.py b/python/fedml/cross_silo/lightsecagg/lsa_fedml_api.py index b50f342c62..9f7b99c9fa 100644 --- a/python/fedml/cross_silo/lightsecagg/lsa_fedml_api.py +++ b/python/fedml/cross_silo/lightsecagg/lsa_fedml_api.py @@ -8,6 +8,33 @@ def FedML_LSA_Horizontal( args, client_rank, client_num, comm, device, dataset, model, model_trainer=None, preprocessed_sampling_lists=None, ): + """ + Initialize and run the Federated Learning with LightSecAgg (LSA) in a horizontal setup. + + Args: + args (object): Command-line arguments and configuration. + client_rank (int): Rank or identifier of the current client (0 for the server). + client_num (int): Total number of clients participating in the federated learning. + comm (object): Communication backend for distributed training. + device (object): The device on which the training will be performed (e.g., GPU or CPU). + dataset (list): A list containing dataset-related information: + - train_data_num (int): Number of samples in the global training dataset. + - test_data_num (int): Number of samples in the global test dataset. + - train_data_global (object): Global training dataset. + - test_data_global (object): Global test dataset. + - train_data_local_num_dict (dict): Dictionary mapping client indices to the number of local training samples. + - train_data_local_dict (dict): Dictionary mapping client indices to their local training dataset. + - test_data_local_dict (dict): Dictionary mapping client indices to their local test dataset. + - class_num (int): Number of classes in the dataset. + model (object): The federated learning model to be trained. + model_trainer (object, optional): The model trainer responsible for training and testing. 
If not provided, + it will be created based on the model and args. + preprocessed_sampling_lists (list, optional): Preprocessed client sampling lists. If provided, the server will + use these preprocessed sampling lists during initialization. + + Returns: + None + """ [ train_data_num, test_data_num, @@ -67,6 +94,29 @@ def init_server( model_trainer, preprocessed_sampling_lists=None, ): + """ + Initialize the server for Federated Learning with LightSecAgg (LSA) in a horizontal setup. + + Args: + args (object): Command-line arguments and configuration. + device (object): The device on which the training will be performed (e.g., GPU or CPU). + comm (object): Communication backend for distributed training. + client_rank (int): Rank or identifier of the server (0 for the server). + client_num (int): Total number of clients participating in the federated learning. + model (object): The federated learning model to be trained. + train_data_num (int): Number of samples in the global training dataset. + train_data_global (object): Global training dataset. + test_data_global (object): Global test dataset. + train_data_local_dict (dict): Dictionary mapping client indices to their local training dataset. + test_data_local_dict (dict): Dictionary mapping client indices to their local test dataset. + train_data_local_num_dict (dict): Dictionary mapping client indices to the number of local training samples. + model_trainer (object): The model trainer responsible for training and testing. + preprocessed_sampling_lists (list, optional): Preprocessed client sampling lists. If provided, the server will + use these preprocessed sampling lists during initialization. + + Returns: + None + """ if model_trainer is None: model_trainer = create_model_trainer(model, args) model_trainer.set_id(0) @@ -88,7 +138,8 @@ def init_server( # start the distributed training backend = args.backend if preprocessed_sampling_lists is None: - server_manager = FedMLServerManager(args, aggregator, comm, client_rank, client_num, backend) + server_manager = FedMLServerManager( + args, aggregator, comm, client_rank, client_num, backend) else: server_manager = FedMLServerManager( args, @@ -117,6 +168,26 @@ def init_client( test_data_local_dict, model_trainer=None, ): + """ + Initialize a client for Federated Learning with LightSecAgg (LSA) in a horizontal setup. + + Args: + args (object): Command-line arguments and configuration. + device (object): The device on which the training will be performed (e.g., GPU or CPU). + comm (object): Communication backend for distributed training. + client_rank (int): Rank or identifier of the current client. + client_num (int): Total number of clients participating in the federated learning. + model (object): The federated learning model to be trained. + train_data_num (int): Number of samples in the global training dataset. + train_data_local_num_dict (dict): Dictionary mapping client indices to the number of local training samples. + train_data_local_dict (dict): Dictionary mapping client indices to their local training dataset. + test_data_local_dict (dict): Dictionary mapping client indices to their local test dataset. + model_trainer (object, optional): The model trainer responsible for training and testing. If not provided, + it will be created based on the model and args. 
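+
+    Example:
+        A minimal sketch (assumes the surrounding FedML runtime has already
+        built the communicator, model, and partitioned data dictionaries):
+
+            init_client(args, device, comm, client_rank=1, client_num=4,
+                        model=model, train_data_num=train_data_num,
+                        train_data_local_num_dict=train_data_local_num_dict,
+                        train_data_local_dict=train_data_local_dict,
+                        test_data_local_dict=test_data_local_dict)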
+ + Returns: + None + """ if model_trainer is None: model_trainer = create_model_trainer(model, args) model_trainer.set_id(client_rank) @@ -131,5 +202,6 @@ def init_client( args, model_trainer, ) - client_manager = FedMLClientManager(args, trainer, comm, client_rank, client_num, backend) + client_manager = FedMLClientManager( + args, trainer, comm, client_rank, client_num, backend) client_manager.run() diff --git a/python/fedml/cross_silo/lightsecagg/lsa_fedml_client_manager.py b/python/fedml/cross_silo/lightsecagg/lsa_fedml_client_manager.py index f46372e529..dcdb627c84 100644 --- a/python/fedml/cross_silo/lightsecagg/lsa_fedml_client_manager.py +++ b/python/fedml/cross_silo/lightsecagg/lsa_fedml_client_manager.py @@ -19,6 +19,17 @@ class FedMLClientManager(FedMLCommManager): def __init__(self, args, trainer, comm=None, client_rank=0, client_num=0, backend="MPI"): + """ + Initialize the FedMLClientManager. + + Args: + args (argparse.Namespace): The command-line arguments. + trainer: The trainer for the client. + comm: The communication backend. + client_rank (int): The rank of the client. + client_num (int): The total number of clients. + backend (str): The communication backend (default is "MPI"). + """ super().__init__(args, comm, client_rank, client_num, backend) self.args = args self.trainer = trainer @@ -51,6 +62,9 @@ def __init__(self, args, trainer, comm=None, client_rank=0, client_num=0, backen self.sys_stats_process = None def register_message_receive_handlers(self): + """ + Register message receive handlers for various message types. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_message_connection_ready ) @@ -74,6 +88,12 @@ def register_message_receive_handlers(self): ) def handle_message_connection_ready(self, msg_params): + """ + Handle the connection-ready message. + + Args: + msg_params (dict): Parameters of the message. + """ if not self.has_sent_online_msg: self.has_sent_online_msg = True self.send_client_status(0) @@ -81,9 +101,21 @@ def handle_message_connection_ready(self, msg_params): mlops.log_sys_perf(self.args) def handle_message_check_status(self, msg_params): + """ + Handle the check-client-status message. + + Args: + msg_params (dict): Parameters of the message. + """ self.send_client_status(0) def handle_message_init(self, msg_params): + """ + Handle the initialization message. + + Args: + msg_params (dict): Parameters of the message. + """ global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -100,6 +132,12 @@ def handle_message_init(self, msg_params): self.__offline() def handle_message_receive_encoded_mask_from_server(self, msg_params): + """ + Handle the received encoded mask from the server. + + Args: + msg_params (dict): Parameters of the message. + """ encoded_mask = msg_params.get(MyMessage.MSG_ARG_KEY_ENCODED_MASK) client_id = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_ID) # logging.info( @@ -114,6 +152,12 @@ def handle_message_receive_encoded_mask_from_server(self, msg_params): self.__train() def handle_message_receive_model_from_server(self, msg_params): + """ + Handle the received model from the server. + + Args: + msg_params (dict): Parameters of the message. 
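+
+        Note:
+            A sketch of the payload keys read in the method body below:
+
+                msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)   # global weights
+                msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX)   # data-silo index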
+ """ logging.info("handle_message_receive_model_from_server.") model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -130,6 +174,12 @@ def handle_message_receive_model_from_server(self, msg_params): self.__offline() def handle_message_receive_active_from_server(self, msg_params): + """ + Handle the received active clients message from the server. + + Args: + msg_params (dict): Parameters of the message. + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) # Receive the set of active client id in first round active_clients_first_round = msg_params.get(MyMessage.MSG_ARG_KEY_ACTIVE_CLIENTS) @@ -146,10 +196,20 @@ def handle_message_receive_active_from_server(self, msg_params): self.send_aggregate_encoded_mask_to_server(0, aggregate_encoded_mask) def start_training(self): + """ + Start the training process. + """ self.round_idx = 0 self.__train() def send_client_status(self, receive_id, status="ONLINE"): + """ + Send the client status to another entity. + + Args: + receive_id: The ID of the entity receiving the status. + status (str): The status to send (default is "ONLINE"). + """ logging.info("send_client_status") message = Message(MyMessage.MSG_TYPE_C2S_CLIENT_STATUS, self.client_real_id, receive_id) sys_name = platform.system() @@ -163,9 +223,23 @@ def send_client_status(self, receive_id, status="ONLINE"): self.send_message(message) def report_training_status(self, status): + """ + Report the training status to MLOps. + + Args: + status: The training status to report. + """ mlops.log_training_status(status) def send_model_to_server(self, receive_id, weights, local_sample_num): + """ + Send the model to the server. + + Args: + receive_id: The ID of the entity receiving the model. + weights: The model parameters to send. + local_sample_num: The number of local samples used for training. + """ mlops.event("comm_c2s", event_started=True, event_value=str(self.round_idx)) message = Message(MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), receive_id,) @@ -178,21 +252,49 @@ def send_model_to_server(self, receive_id, weights, local_sample_num): ) def send_encoded_mask_to_server(self, receive_id, encoded_mask): + """ + Send the encoded mask to the server. + + Args: + receive_id: The ID of the entity receiving the encoded mask. + encoded_mask: The encoded mask to send. + """ message = Message(MyMessage.MSG_TYPE_C2S_SEND_ENCODED_MASK_TO_SERVER, self.get_sender_id(), 0) message.add_params(MyMessage.MSG_ARG_KEY_ENCODED_MASK, encoded_mask) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_ID, receive_id) self.send_message(message) def send_aggregate_encoded_mask_to_server(self, receive_id, aggregate_encoded_mask): + """ + Send the aggregate encoded mask to the server. + + Args: + receive_id: The ID of the entity receiving the aggregate encoded mask. + aggregate_encoded_mask: The aggregate encoded mask to send. + """ message = Message(MyMessage.MSG_TYPE_C2S_SEND_MASK_TO_SERVER, self.get_sender_id(), receive_id) message.add_params(MyMessage.MSG_ARG_KEY_AGGREGATE_ENCODED_MASK, aggregate_encoded_mask) self.send_message(message) def add_encoded_mask(self, index, encoded_mask): + """ + Add an encoded mask to the internal dictionary. + + Args: + index: The index of the encoded mask. + encoded_mask: The encoded mask to add. 
+ """ + self.encoded_mask_dict[index] = encoded_mask self.flag_encoded_mask_dict[index] = True def check_whether_all_encoded_mask_receive(self): + """ + Check if all encoded masks have been received. + + Returns: + bool: True if all encoded masks have been received, False otherwise. + """ for idx in range(self.worker_num): if not self.flag_encoded_mask_dict[idx]: return False @@ -201,6 +303,12 @@ def check_whether_all_encoded_mask_receive(self): return True def encoded_mask_sharing(self, encoded_mask_set): + """ + Share encoded masks with other clients. + + Args: + encoded_mask_set (list): A list of encoded masks. + """ for receive_id in range(1, self.size + 1): print(receive_id) print("the size is ", self.size) @@ -213,6 +321,9 @@ def encoded_mask_sharing(self, encoded_mask_set): self.flag_encoded_mask_dict[receive_id - 1] = True def __offline(self): + """ + Perform the offline phase, including mask encoding and sharing. + """ # Encoding the local generated mask logging.info("#######Client %d offline encoding round_id = %d######" % (self.get_sender_id(), self.round_idx)) @@ -237,6 +348,9 @@ def __offline(self): logging.info("finish share") def __train(self): + """ + Perform the training for the client. + """ logging.info("#######training########### round_id = %d" % self.round_idx) mlops.event("train", event_started=True, event_value=str(self.round_idx)) @@ -262,4 +376,7 @@ def __train(self): self.send_model_to_server(0, masked_weights, local_sample_num) def run(self): + """ + Start the client's execution. + """ super().run() diff --git a/python/fedml/cross_silo/lightsecagg/lsa_fedml_server_manager.py b/python/fedml/cross_silo/lightsecagg/lsa_fedml_server_manager.py index 89269c689e..38595164ac 100644 --- a/python/fedml/cross_silo/lightsecagg/lsa_fedml_server_manager.py +++ b/python/fedml/cross_silo/lightsecagg/lsa_fedml_server_manager.py @@ -12,6 +12,8 @@ class FedMLServerManager(FedMLCommManager): + """FedML Server Manager class.""" + def __init__( self, args, @@ -23,6 +25,19 @@ def __init__( is_preprocessed=False, preprocessed_client_lists=None, ): + """ + Initialize the FedMLServerManager. + + Args: + args: Arguments for the manager. + aggregator: The aggregator for global model updates. + comm: Communication object. + client_rank: Rank of the client. + client_num: Number of clients. + backend: Communication backend. + is_preprocessed: Whether the data is preprocessed. + preprocessed_client_lists: List of preprocessed client data. + """ super().__init__(args, comm, client_rank, client_num, backend) self.args = args self.aggregator = aggregator @@ -55,6 +70,9 @@ def run(self): super().run() def send_init_msg(self): + """ + Send initialization messages to clients. + """ global_model_params = self.aggregator.get_global_model_params() client_idx_in_this_round = 0 @@ -64,9 +82,13 @@ def send_init_msg(self): ) client_idx_in_this_round += 1 - mlops.event("server.wait", event_started=True, event_value=str(self.round_idx)) + mlops.event("server.wait", event_started=True, + event_value=str(self.round_idx)) def register_message_receive_handlers(self): + """ + Register message receive handlers. + """ print("register_message_receive_handlers------") self.register_message_receive_handler( MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_messag_connection_ready @@ -89,11 +111,19 @@ def register_message_receive_handlers(self): ) def handle_messag_connection_ready(self, msg_params): + """ + Handle the 'connection is ready' message. + + Args: + msg_params: Parameters of the message. 
+ """ + self.client_id_list_in_this_round = self.aggregator.client_selection( self.round_idx, self.client_real_ids, self.args.client_num_per_round ) self.data_silo_index_list = self.aggregator.data_silo_selection( - self.round_idx, self.args.client_num_in_total, len(self.client_id_list_in_this_round), + self.round_idx, self.args.client_num_in_total, len( + self.client_id_list_in_this_round), ) if not self.is_initialized: mlops.log_round_info(self.round_num, -1) @@ -107,6 +137,12 @@ def handle_messag_connection_ready(self, msg_params): client_idx_in_this_round += 1 def handle_message_client_status_update(self, msg_params): + """ + Handle client status update message. + + Args: + msg_params: Parameters of the message. + """ client_status = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_STATUS) if client_status == "ONLINE": self.client_online_mapping[str(msg_params.get_sender_id())] = True @@ -120,7 +156,8 @@ def handle_message_client_status_update(self, msg_params): break logging.info( - "sender_id = %d, all_client_is_online = %s" % (msg_params.get_sender_id(), str(all_client_is_online)) + "sender_id = %d, all_client_is_online = %s" % ( + msg_params.get_sender_id(), str(all_client_is_online)) ) if all_client_is_online: @@ -129,12 +166,25 @@ def handle_message_client_status_update(self, msg_params): self.is_initialized = True def handle_message_receive_encoded_mask_from_client(self, msg_params): + """ + Handle received encoded mask from client. + + Args: + msg_params: Parameters of the message. + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) receive_id = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_ID) encoded_mask = msg_params.get(MyMessage.MSG_ARG_KEY_ENCODED_MASK) - self.send_message_encoded_mask_to_client(sender_id, receive_id, encoded_mask) + self.send_message_encoded_mask_to_client( + sender_id, receive_id, encoded_mask) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle received model from client. + + Args: + msg_params: Parameters of the message. + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) mlops.event( "comm_c2s", event_started=False, event_value=str(self.round_idx), event_edge_id=sender_id, @@ -144,7 +194,8 @@ def handle_message_receive_model_from_client(self, msg_params): local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) self.aggregator.add_local_trained_result( - self.client_real_ids.index(sender_id), model_params, local_sample_number + self.client_real_ids.index( + sender_id), model_params, local_sample_number ) self.active_clients_first_round.append(sender_id - 1) b_all_received = self.aggregator.check_whether_all_receive() @@ -153,13 +204,23 @@ def handle_message_receive_model_from_client(self, msg_params): if b_all_received: # Specify the active clients for the first round and inform them for receiver_id in range(1, self.size + 1): - self.send_message_to_active_client(receiver_id, self.active_clients_first_round) + self.send_message_to_active_client( + receiver_id, self.active_clients_first_round) def handle_message_receive_aggregate_encoded_mask_from_client(self, msg_params): + """ + Handle received aggregate encoded mask from client. + + Args: + msg_params: Parameters of the message. 
+ """ + # Receive the aggregate of encoded masks for active clients sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) - aggregate_encoded_mask = msg_params.get(MyMessage.MSG_ARG_KEY_AGGREGATE_ENCODED_MASK) - self.aggregator.add_local_aggregate_encoded_mask(sender_id - 1, aggregate_encoded_mask) + aggregate_encoded_mask = msg_params.get( + MyMessage.MSG_ARG_KEY_AGGREGATE_ENCODED_MASK) + self.aggregator.add_local_aggregate_encoded_mask( + sender_id - 1, aggregate_encoded_mask) logging.info( "Server handle_message_receive_aggregate_mask = %d from_client = %d" % (len(aggregate_encoded_mask), sender_id) @@ -167,12 +228,14 @@ def handle_message_receive_aggregate_encoded_mask_from_client(self, msg_params): # Active clients for the second round self.active_clients_second_round.append(sender_id - 1) b_all_received = self.aggregator.check_whether_all_aggregate_encoded_mask_receive() - logging.info("Server: mask_all_received = " + str(b_all_received) + " in round_idx %d" % self.round_idx) + logging.info("Server: mask_all_received = " + + str(b_all_received) + " in round_idx %d" % self.round_idx) # TODO: add a timeout step # After receiving enough aggregate of encoded masks, server recovers the aggregate-model if b_all_received: - mlops.event("server.wait", event_started=False, event_value=str(self.round_idx)) + mlops.event("server.wait", event_started=False, + event_value=str(self.round_idx)) mlops.event( "server.agg_and_eval", event_started=True, event_value=str(self.round_idx), ) @@ -197,17 +260,20 @@ def handle_message_receive_aggregate_encoded_mask_from_client(self, msg_params): self.round_idx, self.client_real_ids, self.args.client_num_per_round ) self.data_silo_index_list = self.aggregator.data_silo_selection( - self.round_idx, self.args.client_num_in_total, len(self.client_id_list_in_this_round), + self.round_idx, self.args.client_num_in_total, len( + self.client_id_list_in_this_round), ) client_idx_in_this_round = 0 for receiver_id in self.client_id_list_in_this_round: self.send_message_sync_model_to_client( - receiver_id, global_model_params, self.data_silo_index_list[client_idx_in_this_round], + receiver_id, global_model_params, self.data_silo_index_list[ + client_idx_in_this_round], ) client_idx_in_this_round += 1 - mlops.log_aggregated_model_info(self.round_idx + 1, self.aggregated_model_url) + mlops.log_aggregated_model_info( + self.round_idx + 1, self.aggregated_model_url) self.aggregated_model_url = None # start the next round @@ -216,18 +282,24 @@ def handle_message_receive_aggregate_encoded_mask_from_client(self, msg_params): self.active_clients_second_round = [] if self.round_idx == self.round_num: - logging.info("=================TRAINING IS FINISHED!=============") + logging.info( + "=================TRAINING IS FINISHED!=============") sleep(3) self.finish() if self.is_preprocessed: mlops.log_training_finished_status() - logging.info("=============training is finished. Cleanup...============") + logging.info( + "=============training is finished. Cleanup...============") self.cleanup() else: logging.info("waiting for another round...") - mlops.event("server.wait", event_started=True, event_value=str(self.round_idx)) + mlops.event("server.wait", event_started=True, + event_value=str(self.round_idx)) def cleanup(self): + """ + Cleanup the server after training. 
+ """ client_idx_in_this_round = 0 for client_id in self.client_id_list_in_this_round: @@ -239,34 +311,86 @@ def cleanup(self): self.finish() def send_message_init_config(self, receive_id, global_model_params, datasilo_index): - message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id) - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) + """ + Send an initialization configuration message to a client. + + Args: + receive_id: ID of the receiving client. + global_model_params: Global model parameters. + datasilo_index: Index of the data silo. + """ + message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, + self.get_sender_id(), receive_id) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, "PythonClient") self.send_message(message) def send_message_encoded_mask_to_client(self, sender_id, receive_id, encoded_mask): - message = Message(MyMessage.MSG_TYPE_S2C_ENCODED_MASK_TO_CLIENT, self.get_sender_id(), receive_id,) + """ + Send an encoded mask to a client. + + Args: + sender_id: ID of the sender client. + receive_id: ID of the receiving client. + encoded_mask: Encoded mask to be sent. + """ + message = Message( + MyMessage.MSG_TYPE_S2C_ENCODED_MASK_TO_CLIENT, self.get_sender_id(), receive_id,) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_ID, sender_id) message.add_params(MyMessage.MSG_ARG_KEY_ENCODED_MASK, encoded_mask) self.send_message(message) def send_message_check_client_status(self, receive_id, datasilo_index): - message = Message(MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, self.get_sender_id(), receive_id) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) + """ + Send a message to check the status of a client. + + Args: + receive_id: ID of the receiving client. + datasilo_index: Index of the data silo. + """ + + message = Message(MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, + self.get_sender_id(), receive_id) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) def send_message_finish(self, receive_id, datasilo_index): - message = Message(MyMessage.MSG_TYPE_S2C_FINISH, self.get_sender_id(), receive_id) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) + """ + Send a finish message to a client. + + Args: + receive_id: ID of the receiving client. + datasilo_index: Index of the data silo. + """ + message = Message(MyMessage.MSG_TYPE_S2C_FINISH, + self.get_sender_id(), receive_id) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) - logging.info(" ====================send cleanup message to {}====================".format(str(datasilo_index))) + logging.info(" ====================send cleanup message to {}====================".format( + str(datasilo_index))) def send_message_sync_model_to_client(self, receive_id, global_model_params, client_index): - logging.info("send_message_sync_model_to_client. 
receive_id = %d" % receive_id) - message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id,) - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) + """ + Send a synchronization message with the global model to a client. + + Args: + receive_id: ID of the receiving client. + global_model_params: Global model parameters to be synchronized. + client_index: Index of the client. + """ + logging.info( + "send_message_sync_model_to_client. receive_id = %d" % receive_id) + message = Message( + MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id,) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, "PythonClient") self.send_message(message) @@ -275,7 +399,17 @@ def send_message_sync_model_to_client(self, receive_id, global_model_params, cli ) def send_message_to_active_client(self, receive_id, active_clients): - logging.info("Server send_message_to_active_client. receive_id = %d" % receive_id) - message = Message(MyMessage.MSG_TYPE_S2C_SEND_TO_ACTIVE_CLIENT, self.get_sender_id(), receive_id,) - message.add_params(MyMessage.MSG_ARG_KEY_ACTIVE_CLIENTS, active_clients) + """ + Send a message to active clients. + + Args: + receive_id: ID of the receiving client. + active_clients: List of active client IDs. + """ + logging.info( + "Server send_message_to_active_client. receive_id = %d" % receive_id) + message = Message( + MyMessage.MSG_TYPE_S2C_SEND_TO_ACTIVE_CLIENT, self.get_sender_id(), receive_id,) + message.add_params( + MyMessage.MSG_ARG_KEY_ACTIVE_CLIENTS, active_clients) self.send_message(message) diff --git a/python/fedml/cross_silo/secagg/sa_fedml_aggregator.py b/python/fedml/cross_silo/secagg/sa_fedml_aggregator.py index c8d3c57668..46162edc1d 100644 --- a/python/fedml/cross_silo/secagg/sa_fedml_aggregator.py +++ b/python/fedml/cross_silo/secagg/sa_fedml_aggregator.py @@ -29,6 +29,20 @@ def __init__( args, model_trainer, ): + """ + + Args: + train_global: Global training data. + test_global: Global test data. + all_train_data_num: Total number of training samples. + train_data_local_dict: Local training data for all clients. + test_data_local_dict: Local test data for all clients. + train_data_local_num_dict: Number of local training samples for all clients. + client_num: Total number of clients. + device: Computing device (e.g., 'cuda' or 'cpu'). + args: Command-line arguments. + model_trainer: Model trainer instance. 
+ """ self.trainer = model_trainer self.args = args @@ -54,9 +68,12 @@ def __init__( self.privacy_guarantee = int(np.floor(args.worker_num / 2)) self.prime_number = args.prime_number self.precision_parameter = args.precision_parameter - self.public_key_others = np.empty(self.num_pk_per_user * self.args.worker_num).astype("int64") - self.b_u_SS_others = np.empty((self.args.worker_num, self.args.worker_num), dtype="int64") - self.s_sk_SS_others = np.empty((self.args.worker_num, self.args.worker_num), dtype="int64") + self.public_key_others = np.empty( + self.num_pk_per_user * self.args.worker_num).astype("int64") + self.b_u_SS_others = np.empty( + (self.args.worker_num, self.args.worker_num), dtype="int64") + self.s_sk_SS_others = np.empty( + (self.args.worker_num, self.args.worker_num), dtype="int64") for idx in range(self.client_num): self.flag_client_model_uploaded_dict[idx] = False @@ -66,14 +83,36 @@ def __init__( self.dimensions = [] def get_global_model_params(self): + """ + Get the global model parameters. + + Returns: + global_model_params: Global model parameters. + """ global_model_params = self.trainer.get_model_params() - self.dimensions, self.total_dimension = model_dimension(global_model_params) + self.dimensions, self.total_dimension = model_dimension( + global_model_params) return global_model_params def set_global_model_params(self, model_parameters): + """ + Set the global model parameters. + + Args: + model_parameters: Global model parameters to be set. + """ self.trainer.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, sample_num): + """ + Add the locally trained model and sample count from a client. + + Args: + index: Index of the client. + model_params: Locally trained model parameters. + sample_num: Number of samples used for training. + """ + logging.info("add_model. index = %d" % index) # for key in model_params.keys(): # model_params[key] = model_params[key].to(self.device) @@ -82,6 +121,12 @@ def add_local_trained_result(self, index, model_params, sample_num): self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all clients have uploaded their locally trained models. + + Returns: + bool: True if all clients have uploaded their models, False otherwise. + """ for idx in range(self.client_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -91,7 +136,15 @@ def check_whether_all_receive(self): def aggregate_mask_reconstruction(self, active_clients, SS_rx, public_key_list): """ - Recover the aggregate-mask via decoding + Recover the aggregate-mask via decoding. + + Args: + active_clients (list): List of active client indices. + SS_rx (numpy.ndarray): Received secret shares. + public_key_list (numpy.ndarray): List of public keys. + + Returns: + numpy.ndarray: The reconstructed aggregate mask. 
""" d = self.total_dimension T = self.privacy_guarantee @@ -102,7 +155,8 @@ def aggregate_mask_reconstruction(self, active_clients, SS_rx, public_key_list): for i in range(self.targeted_number_active_clients): if self.flag_client_model_uploaded_dict[i]: - SS_input = np.reshape(SS_rx[i, active_clients[: T + 1]], (T + 1, 1)) + SS_input = np.reshape( + SS_rx[i, active_clients[: T + 1]], (T + 1, 1)) b_u = BGW_decoding(SS_input, active_clients[: T + 1], p) np.random.seed(b_u[0][0]) mask = np.random.randint(0, p, size=d).astype(int) @@ -110,7 +164,8 @@ def aggregate_mask_reconstruction(self, active_clients, SS_rx, public_key_list): # z = np.mod(z - temp, p) else: mask = np.zeros(d, dtype="int") - SS_input = np.reshape(SS_rx[i, active_clients[: T + 1]], (T + 1, 1)) + SS_input = np.reshape( + SS_rx[i, active_clients[: T + 1]], (T + 1, 1)) s_sk_dec = BGW_decoding(SS_input, active_clients[: T + 1], p) for j in range(self.targeted_number_active_clients): s_pk_list_ = public_key_list[1, :] @@ -138,8 +193,21 @@ def aggregate_mask_reconstruction(self, active_clients, SS_rx, public_key_list): def aggregate_model_reconstruction( self, active_clients_first_round, active_clients_second_round, SS_rx, public_key_list ): + """ + Reconstruct the aggregate model using secret shares and aggregate masks. + + Args: + active_clients_first_round (list): List of active client indices in the first round. + active_clients_second_round (list): List of active client indices in the second round. + SS_rx (numpy.ndarray): Received secret shares. + public_key_list (numpy.ndarray): List of public keys. + + Returns: + dict: The reconstructed aggregate model parameters. + """ start_time = time.time() - aggregate_mask = self.aggregate_mask_reconstruction(active_clients_second_round, SS_rx, public_key_list) + aggregate_mask = self.aggregate_mask_reconstruction( + active_clients_second_round, SS_rx, public_key_list) p = self.prime_number q_bits = self.precision_parameter logging.info("Server starts the reconstruction of aggregate_model") @@ -164,9 +232,9 @@ def aggregate_model_reconstruction( cur_shape = np.shape(averaged_params[k]) d = self.dimensions[j] - #aggregate_mask = aggregate_mask.reshape((aggregate_mask.shape[0], 1)) + # aggregate_mask = aggregate_mask.reshape((aggregate_mask.shape[0], 1)) # logging.info('aggregate_mask shape = {}'.format(np.shape(aggregate_mask))) - cur_mask = np.array(aggregate_mask[pos : pos + d]) + cur_mask = np.array(aggregate_mask[pos: pos + d]) cur_mask = np.reshape(cur_mask, cur_shape) # Cancel out the aggregate-mask to recover the aggregate-model @@ -174,10 +242,11 @@ def aggregate_model_reconstruction( averaged_params[k] = np.mod(averaged_params[k], p) pos += d - # Convert the model from finite to real - logging.info("Server converts the aggregate_model from finite to tensor") - averaged_params = transform_finite_to_tensor(averaged_params, p, q_bits) + logging.info( + "Server converts the aggregate_model from finite to tensor") + averaged_params = transform_finite_to_tensor( + averaged_params, p, q_bits) # do the avg after transform for j, k in enumerate(averaged_params): w = 1 / len(active_clients_first_round) @@ -189,69 +258,107 @@ def aggregate_model_reconstruction( def data_silo_selection(self, round_idx, client_num_in_total, client_num_per_round): """ + Select a subset of clients for data siloing. 
Args: - round_idx: round index, starting from 0 - client_num_in_total: this is equal to the users in a synthetic data, - e.g., in synthetic_1_1, this value is 30 - client_num_per_round: the number of edge devices that can train + round_idx (int): Round index, starting from 0. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to select. Returns: - data_silo_index_list: e.g., when client_num_in_total = 30, client_num_in_total = 3, - this value is the form of [0, 11, 20] - + list: List of selected client indices. """ logging.info( - "client_num_in_total = %d, client_num_per_round = %d" % (client_num_in_total, client_num_per_round) + "client_num_in_total = %d, client_num_per_round = %d" % ( + client_num_in_total, client_num_per_round) ) assert client_num_in_total >= client_num_per_round if client_num_in_total == client_num_per_round: return [i for i in range(client_num_per_round)] else: - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round - data_silo_index_list = np.random.choice(range(client_num_in_total), client_num_per_round, replace=False) + # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) + data_silo_index_list = np.random.choice( + range(client_num_in_total), client_num_per_round, replace=False) return data_silo_index_list def client_selection(self, round_idx, client_id_list_in_total, client_num_per_round): """ + Select a subset of clients for training. + Args: - round_idx: round index, starting from 0 - client_id_list_in_total: this is the real edge IDs. - In MLOps, its element is real edge ID, e.g., [64, 65, 66, 67]; - in simulated mode, its element is client index starting from 1, e.g., [1, 2, 3, 4] - client_num_per_round: + round_idx (int): Round index, starting from 0. + client_id_list_in_total (list): List of real edge IDs or client indices. + client_num_per_round (int): Number of clients to select. Returns: - client_id_list_in_this_round: sampled real edge ID list, e.g., [64, 66] + list: List of selected client IDs or indices. """ if client_num_per_round == len(client_id_list_in_total): return client_id_list_in_total - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round - client_id_list_in_this_round = np.random.choice(client_id_list_in_total, client_num_per_round, replace=False) + # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) + client_id_list_in_this_round = np.random.choice( + client_id_list_in_total, client_num_per_round, replace=False) return client_id_list_in_this_round def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Randomly sample a subset of clients for training. + + Args: + round_idx (int): Round index, starting from 0. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to select. + + Returns: + list: List of selected client indices. 
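Both selectors above rely on the same trick: reseeding numpy with round_idx so that repeated runs of an experiment sample identical clients in each round. A minimal sketch:

    import numpy as np

    def select_clients(round_idx, client_ids, num_per_round):
        if num_per_round >= len(client_ids):
            return list(client_ids)
        np.random.seed(round_idx)                    # same round -> same sample
        return list(np.random.choice(client_ids, num_per_round, replace=False))

    assert select_clients(5, [64, 65, 66, 67], 2) == select_clients(5, [64, 65, 66, 67], 2)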
+ """ if client_num_in_total == client_num_per_round: - client_indexes = [client_index for client_index in range(client_num_in_total)] + client_indexes = [ + client_index for client_index in range(client_num_in_total)] else: num_clients = min(client_num_per_round, client_num_in_total) - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round - client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False) + # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) + client_indexes = np.random.choice( + range(client_num_in_total), num_clients, replace=False) logging.info("client_indexes = %s" % str(client_indexes)) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set. + + Args: + num_samples (int): Number of samples in the validation set. + + Returns: + DataLoader: DataLoader for the validation set. + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) - sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) - subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) - sample_testset = torch.utils.data.DataLoader(subset, batch_size=self.args.batch_size) + sample_indices = random.sample( + range(test_data_num), min(num_samples, test_data_num)) + subset = torch.utils.data.Subset( + self.test_global.dataset, sample_indices) + sample_testset = torch.utils.data.DataLoader( + subset, batch_size=self.args.batch_size) return sample_testset else: return self.test_global def test_on_server_for_all_clients(self, round_idx): + """ + Perform testing on the server for all clients. + + Args: + round_idx (int): Round index. 
+ + Returns: + None + """ # if self.trainer.test_on_the_server( # self.train_data_local_dict, # self.test_data_local_dict, @@ -261,13 +368,15 @@ def test_on_server_for_all_clients(self, round_idx): # return if round_idx % self.args.frequency_of_the_test == 0 or round_idx == self.args.comm_round - 1: - logging.info("################test_on_server_for_all_clients : {}".format(round_idx)) + logging.info( + "################test_on_server_for_all_clients : {}".format(round_idx)) train_num_samples = [] train_tot_corrects = [] train_losses = [] for client_idx in range(self.args.client_num_in_total): # train data - metrics = self.trainer.test(self.train_data_local_dict[client_idx], self.device, self.args) + metrics = self.trainer.test( + self.train_data_local_dict[client_idx], self.device, self.args) train_tot_correct, train_num_sample, train_loss = ( metrics["test_correct"], metrics["test_total"], @@ -286,7 +395,8 @@ def test_on_server_for_all_clients(self, round_idx): stats = {"training_acc": train_acc, "training_loss": train_loss} logging.info(stats) - mlops.log({"accuracy": round(train_acc, 4), "loss": round(train_loss, 4)}) + mlops.log({"accuracy": round(train_acc, 4), + "loss": round(train_loss, 4)}) # test data test_num_samples = [] @@ -294,9 +404,11 @@ def test_on_server_for_all_clients(self, round_idx): test_losses = [] if round_idx == self.args.comm_round - 1: - metrics = self.trainer.test(self.test_global, self.device, self.args) + metrics = self.trainer.test( + self.test_global, self.device, self.args) else: - metrics = self.trainer.test(self.val_global, self.device, self.args) + metrics = self.trainer.test( + self.val_global, self.device, self.args) test_tot_correct, test_num_sample, test_loss = ( metrics["test_correct"], diff --git a/python/fedml/cross_silo/secagg/sa_fedml_api.py b/python/fedml/cross_silo/secagg/sa_fedml_api.py index ba0b6cbb8f..b8f3dbd8df 100644 --- a/python/fedml/cross_silo/secagg/sa_fedml_api.py +++ b/python/fedml/cross_silo/secagg/sa_fedml_api.py @@ -8,6 +8,26 @@ def FedML_SA_Horizontal( args, client_rank, client_num, comm, device, dataset, model, model_trainer=None, preprocessed_sampling_lists=None, ): + """ + Initialize and run the Secure Aggregation-based Horizontal Federated Learning. + + This function initializes either the server or client based on the client_rank and runs + the Secure Aggregation-based Horizontal Federated Learning. + + Args: + args: Command-line arguments. + client_rank: Rank of the client. + client_num: Total number of clients. + comm: Communication backend. + device: Computing device (e.g., 'cuda' or 'cpu'). + dataset: Federated dataset containing data and metadata. + model: Federated model. + model_trainer: Model trainer instance (default: None). + preprocessed_sampling_lists: Preprocessed sampling lists (default: None). + + Returns: + None + """ [ train_data_num, test_data_num, @@ -67,6 +87,31 @@ def init_server( model_trainer, preprocessed_sampling_lists=None, ): + """ + Initialize the server for Secure Aggregation-based Horizontal Federated Learning. + + This function initializes the server for Secure Aggregation-based Horizontal Federated Learning. + + Args: + args: Command-line arguments. + device: Computing device (e.g., 'cuda' or 'cpu'). + comm: Communication backend. + client_rank: Rank of the client (server rank is 0). + client_num: Total number of clients. + model: Federated model. + train_data_num: Total number of training samples. + train_data_global: Global training data. + test_data_global: Global test data. 
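The rollup in test_on_server_for_all_clients is a sample-weighted average, total corrects over total samples rather than a mean of per-client accuracies; sketched here assuming each client's metrics dict reports a summed loss along with the counts, as the division by sum(train_num_samples) suggests:

    import numpy as np

    per_client = [                                   # (correct, total, summed loss)
        (90, 100, 35.0),
        (40, 50, 30.0),
        (10, 50, 80.0),
    ]
    corrects, totals, losses = map(np.array, zip(*per_client))
    train_acc = corrects.sum() / totals.sum()        # 140 / 200 = 0.7
    train_loss = losses.sum() / totals.sum()         # 145 / 200 = 0.725
    print(round(float(train_acc), 4), round(float(train_loss), 4))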
+ train_data_local_dict: Local training data for all clients. + test_data_local_dict: Local test data for all clients. + train_data_local_num_dict: Number of local training samples for all clients. + model_trainer: Model trainer instance. + preprocessed_sampling_lists: Preprocessed sampling lists (default: None). + + Returns: + None + """ + if model_trainer is None: model_trainer = create_model_trainer(model, args) model_trainer.set_id(0) @@ -88,7 +133,8 @@ def init_server( # start the distributed training backend = args.backend if preprocessed_sampling_lists is None: - server_manager = FedMLServerManager(args, aggregator, comm, client_rank, client_num, backend) + server_manager = FedMLServerManager( + args, aggregator, comm, client_rank, client_num, backend) else: server_manager = FedMLServerManager( args, @@ -117,6 +163,27 @@ def init_client( test_data_local_dict, model_trainer=None, ): + """ + Initialize a client for Secure Aggregation-based Horizontal Federated Learning. + + This function initializes a client for Secure Aggregation-based Horizontal Federated Learning. + + Args: + args: Command-line arguments. + device: Computing device (e.g., 'cuda' or 'cpu'). + comm: Communication backend. + client_rank: Rank of the client. + client_num: Total number of clients. + model: Federated model. + train_data_num: Total number of training samples. + train_data_local_num_dict: Number of local training samples for all clients. + train_data_local_dict: Local training data for all clients. + test_data_local_dict: Local test data for all clients. + model_trainer: Model trainer instance (default: None). + + Returns: + None + """ if model_trainer is None: model_trainer = create_model_trainer(model, args) model_trainer.set_id(client_rank) @@ -131,5 +198,6 @@ def init_client( args, model_trainer, ) - client_manager = FedMLClientManager(args, trainer, comm, client_rank, client_num, backend) + client_manager = FedMLClientManager( + args, trainer, comm, client_rank, client_num, backend) client_manager.run() diff --git a/python/fedml/cross_silo/secagg/sa_fedml_client_manager.py b/python/fedml/cross_silo/secagg/sa_fedml_client_manager.py index 8eff9828ea..652c8f36ee 100644 --- a/python/fedml/cross_silo/secagg/sa_fedml_client_manager.py +++ b/python/fedml/cross_silo/secagg/sa_fedml_client_manager.py @@ -19,6 +19,20 @@ class FedMLClientManager(FedMLCommManager): def __init__(self, args, trainer, comm=None, client_rank=0, client_num=0, backend="MPI"): + """ + Initialize the client object. + + Args: + args: Command-line arguments passed to the client. + trainer: The trainer object responsible for training. + comm: Communication handler (optional). + client_rank: Rank of the client (optional). + client_num: Number of clients (optional). + backend: Communication backend (optional). 
+ + Returns: + None + """ super().__init__(args, comm, client_rank, client_num, backend) self.args = args self.trainer = trainer @@ -35,9 +49,12 @@ def __init__(self, args, trainer, comm=None, client_rank=0, client_num=0, backen self.privacy_guarantee = int(np.floor(args.worker_num / 2)) self.prime_number = args.prime_number self.precision_parameter = args.precision_parameter - self.public_key_others = np.empty(self.num_pk_per_user * self.worker_num).astype("int64") - self.b_u_SS_others = np.empty((self.worker_num, self.worker_num), dtype="int64") - self.s_sk_SS_others = np.empty((self.worker_num, self.worker_num), dtype="int64") + self.public_key_others = np.empty( + self.num_pk_per_user * self.worker_num).astype("int64") + self.b_u_SS_others = np.empty( + (self.worker_num, self.worker_num), dtype="int64") + self.s_sk_SS_others = np.empty( + (self.worker_num, self.worker_num), dtype="int64") self.client_real_ids = json.loads(args.client_id_list) logging.info("self.client_real_ids = {}".format(self.client_real_ids)) @@ -48,6 +65,18 @@ def __init__(self, args, trainer, comm=None, client_rank=0, client_num=0, backen self.sys_stats_process = None def register_message_receive_handlers(self): + """ + Register message receive handlers for different message types. + + This method registers handlers for various message types that the client + can receive from the server. + + Args: + self: The client instance. + + Returns: + None + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_message_connection_ready ) @@ -56,7 +85,8 @@ def register_message_receive_handlers(self): MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, self.handle_message_check_status ) - self.register_message_receive_handler(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init) + self.register_message_receive_handler( + MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init) self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.handle_message_receive_model_from_server, @@ -75,6 +105,19 @@ def register_message_receive_handlers(self): ) def handle_message_connection_ready(self, msg_params): + """ + Handle a connection-ready message from the server. + + This method handles the initial connection-ready message from the server, + sends a client status message, and logs system performance. + + Args: + self: The client instance. + msg_params: A dictionary containing message parameters. + + Returns: + None + """ if not self.has_sent_online_msg: self.has_sent_online_msg = True self.send_client_status(0) @@ -82,10 +125,37 @@ def handle_message_connection_ready(self, msg_params): mlops.log_sys_perf(self.args) def handle_message_check_status(self, msg_params): + """ + Handle a message to check the client's status. + + This method handles a message from the server to check the client's status + and responds accordingly. + + Args: + self: The client instance. + msg_params: A dictionary containing message parameters. + + Returns: + None + """ self.send_client_status(0) def handle_message_init(self, msg_params): - global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) + """ + Handle an initialization message from the server. + + This method handles an initialization message from the server, updates + the client's dataset and model, and reports the training status to MLOps. + + Args: + self: The client instance. + msg_params: A dictionary containing message parameters. 
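The registration calls above populate a message-type to callback table inside FedMLCommManager; a toy dispatcher showing the shape of that pattern, not the actual fedml implementation:

    class ToyCommManager:
        def __init__(self):
            self._handlers = {}

        def register_message_receive_handler(self, msg_type, handler):
            self._handlers[msg_type] = handler       # message type -> callback

        def dispatch(self, msg_type, msg_params):
            self._handlers[msg_type](msg_params)     # invoked when a message arrives

    mgr = ToyCommManager()
    mgr.register_message_receive_handler("S2C_INIT_CONFIG",
                                         lambda m: print("init:", m))
    mgr.dispatch("S2C_INIT_CONFIG", {"client_index": 3})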
+ + Returns: + None + """ + global_model_params = msg_params.get( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) logging.info("client_index = %s" % str(client_index)) @@ -93,7 +163,8 @@ def handle_message_init(self, msg_params): # Notify MLOps with training status. self.report_training_status(MyMessage.MSG_MLOPS_CLIENT_STATUS_TRAINING) - self.dimensions, self.total_dimension = model_dimension(global_model_params) + self.dimensions, self.total_dimension = model_dimension( + global_model_params) self.trainer.update_dataset(int(client_index)) self.trainer.update_model(global_model_params) @@ -102,6 +173,19 @@ def handle_message_init(self, msg_params): self.__offline() def handle_message_receive_model_from_server(self, msg_params): + """ + Handle the reception of a model from the server. + + This method updates the client's dataset and model based on the received + model parameters and handles the completion of training if it's the last round. + + Args: + self: The client instance. + msg_params: A dictionary containing message parameters. + + Returns: + None + """ logging.info("handle_message_receive_model_from_server.") model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -116,22 +200,67 @@ def handle_message_receive_model_from_server(self, msg_params): return self.round_idx += 1 if (not self.dimensions) or (not self.total_dimension): - self.dimensions, self.total_dimension = model_dimension(model_params) + self.dimensions, self.total_dimension = model_dimension( + model_params) self.__offline() def handle_message_receive_pk_others(self, msg_params): - self.public_key_others = msg_params.get(MyMessage.MSG_ARG_KEY_PK_OTHERS) - logging.info(" self.public_key_others = {}".format( self.public_key_others)) - self.public_key_others = np.reshape(self.public_key_others, (self.num_pk_per_user, self.worker_num)) + """ + Handle the reception of public keys from other clients. + + This method handles the reception of public keys from other clients for secure aggregation. + + Args: + self: The client instance. + msg_params: A dictionary containing message parameters. + + Returns: + None + """ + + self.public_key_others = msg_params.get( + MyMessage.MSG_ARG_KEY_PK_OTHERS) + logging.info(" self.public_key_others = {}".format( + self.public_key_others)) + self.public_key_others = np.reshape( + self.public_key_others, (self.num_pk_per_user, self.worker_num)) def handle_message_receive_ss_others(self, msg_params): - self.s_sk_SS_others = msg_params.get(MyMessage.MSG_ARG_KEY_SK_SS_OTHERS).flatten() - self.b_u_SS_others = msg_params.get(MyMessage.MSG_ARG_KEY_B_SS_OTHERS).flatten() + """ + Handle the reception of encoded masks from other clients. + + This method handles the reception of encoded masks (s_sk_SS and b_u_SS) from other clients + for secure aggregation. + + Args: + self: The client instance. + msg_params: A dictionary containing message parameters. + + Returns: + None + """ + self.s_sk_SS_others = msg_params.get( + MyMessage.MSG_ARG_KEY_SK_SS_OTHERS).flatten() + self.b_u_SS_others = msg_params.get( + MyMessage.MSG_ARG_KEY_B_SS_OTHERS).flatten() self.s_pk_list = self.public_key_others[1, :] self.s_uv = np.mod(self.s_pk_list * self.my_s_sk, self.prime_number) self.__train() def handle_message_receive_active_from_server(self, msg_params): + """ + Handle the reception of active client IDs from the server. 
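The line s_uv = np.mod(s_pk_list * self.my_s_sk, self.prime_number) is a key-agreement step: it only yields the same seed on both ends if the public key is a fixed multiple of the secret key. Assuming my_pk_gen has the linear form pk = g * sk (mod p), which is an assumption since my_pk_gen is not shown here, both clients derive the identical pairwise seed:

    import numpy as np

    p, g = 2 ** 15 - 19, 7               # prime field and a public generator (assumed)
    rng = np.random.default_rng(0)
    sk_u, sk_v = rng.integers(1, p, size=2)
    pk_u, pk_v = (g * sk_u) % p, (g * sk_v) % p      # assumed my_pk_gen(sk, p, 0)
    seed_at_u = (pk_v * sk_u) % p                    # s_uv as computed on client u
    seed_at_v = (pk_u * sk_v) % p                    # the same value on client v
    assert seed_at_u == seed_at_v                    # shared seed for the pairwise mask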
+ + This method handles the reception of active client IDs from the server and decides which + encoded masks to send based on active clients. + + Args: + self: The client instance. + msg_params: A dictionary containing message parameters. + + Returns: + None + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) # Receive the set of active client id in first round active_clients = msg_params.get(MyMessage.MSG_ARG_KEY_ACTIVE_CLIENTS) @@ -148,8 +277,22 @@ def handle_message_receive_active_from_server(self, msg_params): self._send_others_ss_to_server(SS_info) def send_client_status(self, receive_id, status="ONLINE"): + """ + Send a client status message to the server. + + This method sends a client status message to the server to indicate the client's status. + + Args: + self: The client instance. + receive_id: The ID of the receiving entity (usually the server). + status: The status message (default is "ONLINE"). + + Returns: + None + """ logging.info("send_client_status") - message = Message(MyMessage.MSG_TYPE_C2S_CLIENT_STATUS, self.client_real_id, receive_id) + message = Message(MyMessage.MSG_TYPE_C2S_CLIENT_STATUS, + self.client_real_id, receive_id) sys_name = platform.system() if sys_name == "Darwin": sys_name = "Mac" @@ -161,11 +304,40 @@ def send_client_status(self, receive_id, status="ONLINE"): self.send_message(message) def report_training_status(self, status): + """ + Report the training status to MLOps. + + This method reports the training status to MLOps for tracking. + + Args: + self: The client instance. + status: The training status message. + + Returns: + None + """ mlops.log_training_status(status) def send_model_to_server(self, receive_id, weights, local_sample_num): - mlops.event("comm_c2s", event_started=True, event_value=str(self.round_idx)) - message = Message(MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), receive_id,) + """ + Send the trained model to the server. + + This method sends the trained model and relevant information to the server. + + Args: + self: The client instance. + receive_id: The ID of the receiving entity (usually the server). + weights: The model parameters/weights. + local_sample_num: The number of local training samples. + + Returns: + None + """ + + mlops.event("comm_c2s", event_started=True, + event_value=str(self.round_idx)) + message = Message( + MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), receive_id,) message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights) message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num) @@ -176,6 +348,18 @@ def send_model_to_server(self, receive_id, weights, local_sample_num): ) def _send_public_key_to_sever(self, public_key): + """ + Send the public key to the server. + + This method sends the client's public key to the server for secure aggregation. + + Args: + self: The client instance. + public_key: The public key to send. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_PK_TO_SERVER, self.get_sender_id(), 0 ) @@ -183,6 +367,19 @@ def _send_public_key_to_sever(self, public_key): self.send_message(message) def _send_secret_share_to_sever(self, b_u_SS, s_sk_SS): + """ + Send the secret shares to the server. + + This method sends the secret shares (b_u_SS and s_sk_SS) to the server for secure aggregation. + + Args: + self: The client instance. + b_u_SS: The encoded mask (b values). + s_sk_SS: The encoded mask (s_sk values). 
+ + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_SS_TO_SERVER, self.get_sender_id(), 0 ) @@ -191,12 +388,24 @@ def _send_secret_share_to_sever(self, b_u_SS, s_sk_SS): self.send_message(message) def _send_others_ss_to_server(self, ss_info): + """ + Send secret shares to the server. + + This method sends secret shares (ss_info) to the server for secure aggregation. + + Args: + self: The client instance. + ss_info: Secret shares to send. + + Returns: + None + """ # for j, k in enumerate(self.finite_w): - # if j == 0: - # logging.info("Sent from %d" % (self.rank - 1)) - # logging.info(self.finite_w[k][0]) - # break + # if j == 0: + # logging.info("Sent from %d" % (self.rank - 1)) + # logging.info(self.finite_w[k][0]) + # break message = Message( MyMessage.MSG_TYPE_C2S_SEND_SS_OTHERS_TO_SERVER, @@ -210,16 +419,42 @@ def _send_others_ss_to_server(self, ss_info): self.send_message(message) def get_model_dimension(self, weights): + """ + Get the dimensions of the model. + + This method calculates and returns the dimensions of the model based on its weights. + + Args: + self: The client instance. + weights: Model weights. + + Returns: + None + """ self.dimensions, self.total_dimension = model_dimension(weights) def mask(self, weights): + """ + Apply masking to the model weights. + + This method applies masking to the model weights to protect privacy during aggregation. + + Args: + self: The client instance. + weights: Model weights. + + Returns: + Masked model weights. + """ if (not self.dimensions) or (not self.total_dimension): - self.dimensions, self.total_dimension = self.get_model_dimension(weights) + self.dimensions, self.total_dimension = self.get_model_dimension( + weights) q_bits = self.precision_parameter self.infinite_w = copy.deepcopy(weights) - weights_finite = transform_tensor_to_finite(weights, self.prime_number, q_bits) + weights_finite = transform_tensor_to_finite( + weights, self.prime_number, q_bits) self.finite_w = copy.deepcopy(weights_finite) @@ -228,10 +463,12 @@ def mask(self, weights): for i in range(1, self.worker_num + 1): if self.rank == i: np.random.seed(self.b_u) - temp = np.random.randint(0, self.prime_number, size=d).astype(int) + temp = np.random.randint( + 0, self.prime_number, size=d).astype(int) logging.info("b for %d to %d" % (self.rank, i)) logging.info(temp) - self.local_mask = np.mod(self.local_mask + temp, self.prime_number) + self.local_mask = np.mod( + self.local_mask + temp, self.prime_number) # temp = np.zeros(d,dtype='int') elif self.rank > i: np.random.seed(self.s_uv[i - 1]) @@ -242,12 +479,14 @@ def mask(self, weights): logging.info("{},{}".format(self.rank - 1, i - 1)) # Debugging Block End # ################################## - temp = np.random.randint(0, self.prime_number, size=d).astype(int) + temp = np.random.randint( + 0, self.prime_number, size=d).astype(int) logging.info("s for %d to %d" % (self.rank, i)) logging.info(temp) # if self.rank == 1: # print '############ (seed, temp)=', self.s_uv[i-1], temp - self.local_mask = np.mod(self.local_mask + temp, self.prime_number) + self.local_mask = np.mod( + self.local_mask + temp, self.prime_number) else: np.random.seed(self.s_uv[i - 1]) ################################## @@ -257,23 +496,40 @@ def mask(self, weights): logging.info("{},{}".format(self.rank - 1, i - 1)) # Debugging Block End # ################################## - temp = -np.random.randint(0, self.prime_number, size=d).astype(int) + temp = - \ + np.random.randint(0, self.prime_number, size=d).astype(int) 
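transform_tensor_to_finite and its inverse are not part of this diff; one plausible fixed-point scheme consistent with the q_bits precision parameter scales by 2^q_bits, rounds, and wraps negatives into the top half of Z_p:

    import numpy as np

    p, q_bits = 2 ** 15 - 19, 8          # prime modulus and fractional precision bits

    def to_finite(x):
        # scale to fixed point, round, and wrap into Z_p (negatives land near p)
        return np.mod(np.round(x * (1 << q_bits)).astype(np.int64), p)

    def to_real(x):
        signed = np.where(x > p // 2, x - p, x)      # undo the negative wrap
        return signed / float(1 << q_bits)

    w = np.array([0.5, -1.25, 3.0])
    assert np.allclose(to_real(to_finite(w)), w)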
logging.info("s for %d to %d" % (self.rank, i)) logging.info(temp) # if self.rank == 1: # print '############ (seed, temp)=', self.s_uv[i-1], temp - self.local_mask = np.mod(self.local_mask + temp, self.prime_number) + self.local_mask = np.mod( + self.local_mask + temp, self.prime_number) logging.info("Client") logging.info(self.rank) - masked_weights = model_masking(weights_finite, self.dimensions, self.local_mask, self.prime_number) + masked_weights = model_masking( + weights_finite, self.dimensions, self.local_mask, self.prime_number) return masked_weights def __offline(self): + """ + Perform offline setup for secure aggregation. + + This method performs the necessary offline setup for secure aggregation, including generating + keys, secret shares, and sending them to the server. + + Args: + self: The client instance. + + Returns: + None + """ np.random.seed(self.rank) - self.sk = np.random.randint(0, self.prime_number, size=(2)).astype("int64") + self.sk = np.random.randint( + 0, self.prime_number, size=(2)).astype("int64") self.pk = my_pk_gen(self.sk, self.prime_number, 0) - self.key = np.concatenate((self.pk, self.sk)) # length=4 : c_pk, s_pk, c_sk, s_sk + # length=4 : c_pk, s_pk, c_sk, s_sk + self.key = np.concatenate((self.pk, self.sk)) self._send_public_key_to_sever(self.key[0:2]) @@ -282,8 +538,10 @@ def __offline(self): self.b_u = self.my_c_sk - self.SS_input = np.reshape(np.array([self.my_c_sk, self.my_s_sk]), (2, 1)) - self.my_SS = BGW_encoding(self.SS_input, self.worker_num, self.privacy_guarantee, self.prime_number) + self.SS_input = np.reshape( + np.array([self.my_c_sk, self.my_s_sk]), (2, 1)) + self.my_SS = BGW_encoding( + self.SS_input, self.worker_num, self.privacy_guarantee, self.prime_number) self.b_u_SS = self.my_SS[:, 0, 0].astype("int64") self.s_sk_SS = self.my_SS[:, 1, 0].astype("int64") @@ -293,14 +551,29 @@ def __offline(self): self._send_secret_share_to_sever(self.b_u_SS, self.s_sk_SS) def __train(self): - logging.info("#######training########### round_id = %d" % self.round_idx) - mlops.event("train", event_started=True, event_value=str(self.round_idx)) + """ + Perform the training for a round. + + This method initiates the training process for the current round and sends the trained model + to the server after applying masking. + + Args: + self: The client instance. + + Returns: + None + """ + logging.info("#######training########### round_id = %d" % + self.round_idx) + mlops.event("train", event_started=True, + event_value=str(self.round_idx)) weights, local_sample_num = self.trainer.train(self.round_idx) # logging.info( # "Client %d original weights = %s" % (self.get_sender_id(), weights) # ) - mlops.event("train", event_started=False, event_value=str(self.round_idx)) + mlops.event("train", event_started=False, + event_value=str(self.round_idx)) # Mask the local model masked_weights = self.mask(weights) @@ -312,4 +585,15 @@ def __train(self): self.send_model_to_server(0, masked_weights, local_sample_num) def run(self): + """ + Run the client. + + This method starts the client and its communication loop. + + Args: + self: The client instance. 
+ + Returns: + None + """ super().run() diff --git a/python/fedml/cross_silo/secagg/sa_fedml_server_manager.py b/python/fedml/cross_silo/secagg/sa_fedml_server_manager.py index 0614b7e966..a24a7aa7fc 100644 --- a/python/fedml/cross_silo/secagg/sa_fedml_server_manager.py +++ b/python/fedml/cross_silo/secagg/sa_fedml_server_manager.py @@ -23,6 +23,19 @@ def __init__( is_preprocessed=False, preprocessed_client_lists=None, ): + """ + Initialize the Federated Learning Server Manager. + + Args: + args (object): Arguments object containing configuration parameters. + aggregator (object): Federated learning aggregator. + comm (object, optional): Communication manager (default: None). + client_rank (int, optional): Rank of the client (default: 0). + client_num (int, optional): Number of clients (default: 0). + backend (str, optional): Backend for communication (default: "MQTT_S3"). + is_preprocessed (bool, optional): Whether the data is preprocessed (default: False). + preprocessed_client_lists (list, optional): List of preprocessed clients (default: None). + """ super().__init__(args, comm, client_rank, client_num, backend) self.args = args self.aggregator = aggregator @@ -46,15 +59,19 @@ def __init__( self.ss_received = 0 self.num_pk_per_user = 2 self.public_key_list = np.empty( - shape=(self.num_pk_per_user, self.targeted_number_active_clients), dtype="int64" + shape=(self.num_pk_per_user, + self.targeted_number_active_clients), dtype="int64" ) self.b_u_SS_list = np.empty( - (self.targeted_number_active_clients, self.targeted_number_active_clients), dtype="int64" + (self.targeted_number_active_clients, + self.targeted_number_active_clients), dtype="int64" ) self.s_sk_SS_list = np.empty( - (self.targeted_number_active_clients, self.targeted_number_active_clients), dtype="int64" + (self.targeted_number_active_clients, + self.targeted_number_active_clients), dtype="int64" ) - self.SS_rx = np.empty((self.targeted_number_active_clients, self.targeted_number_active_clients), dtype="int64") + self.SS_rx = np.empty((self.targeted_number_active_clients, + self.targeted_number_active_clients), dtype="int64") self.aggregated_model_url = None @@ -63,9 +80,26 @@ def __init__( self.data_silo_index_list = None def run(self): + """ + Start the Federated Learning Server Manager. + + This method starts the server manager and begins the federated learning process. + """ super().run() def send_init_msg(self): + """ + Send initialization messages to clients. + + This method sends initialization messages to all clients, providing them with the + global model parameters to start training. + + Args: + None + + Returns: + None + """ global_model_params = self.aggregator.get_global_model_params() client_idx_in_this_round = 0 @@ -75,9 +109,22 @@ def send_init_msg(self): ) client_idx_in_this_round += 1 - mlops.event("server.wait", event_started=True, event_value=str(self.round_idx)) + mlops.event("server.wait", event_started=True, + event_value=str(self.round_idx)) def register_message_receive_handlers(self): + """ + Register message receive handlers for server communication. + + This method registers various message receive handlers for different types of + communication messages received by the server. 
+ + Args: + None + + Returns: + None + """ print("register_message_receive_handlers------") self.register_message_receive_handler( MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_messag_connection_ready @@ -104,11 +151,25 @@ def register_message_receive_handlers(self): ) def handle_messag_connection_ready(self, msg_params): + """ + Handle a connection-ready message from clients. + + This function processes client connection requests and initializes necessary + parameters for the server's operation. + + Args: + self: The server instance. + msg_params (dict): A dictionary containing message parameters. + + Returns: + None + """ self.client_id_list_in_this_round = self.aggregator.client_selection( self.round_idx, self.client_real_ids, self.args.client_num_per_round ) self.data_silo_index_list = self.aggregator.data_silo_selection( - self.round_idx, self.args.client_num_in_total, len(self.client_id_list_in_this_round), + self.round_idx, self.args.client_num_in_total, len( + self.client_id_list_in_this_round), ) if not self.is_initialized: mlops.log_round_info(self.round_num, -1) @@ -122,6 +183,19 @@ def handle_messag_connection_ready(self, msg_params): client_idx_in_this_round += 1 def handle_message_client_status_update(self, msg_params): + """ + Handle a message containing client status updates. + + This function updates the server's record of client statuses and takes + appropriate actions when all clients are online. + + Args: + self: The server instance. + msg_params (dict): A dictionary containing message parameters. + + Returns: + None + """ client_status = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_STATUS) if client_status == "ONLINE": self.client_online_mapping[str(msg_params.get_sender_id())] = True @@ -135,7 +209,8 @@ def handle_message_client_status_update(self, msg_params): break logging.info( - "sender_id = %d, all_client_is_online = %s" % (msg_params.get_sender_id(), str(all_client_is_online)) + "sender_id = %d, all_client_is_online = %s" % ( + msg_params.get_sender_id(), str(all_client_is_online)) ) if all_client_is_online: @@ -144,18 +219,45 @@ def handle_message_client_status_update(self, msg_params): self.is_initialized = True def _handle_message_receive_public_key(self, msg_params): + """ + Handle the reception of public keys from clients. + + This function receives and processes public keys from active clients, + combining them for further use. + + Args: + self: The server instance. + msg_params (dict): A dictionary containing message parameters. + + Returns: + None + """ # Receive the aggregate of encoded masks for active clients sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) public_key = msg_params.get(MyMessage.MSG_ARG_KEY_PK) self.public_key_list[:, sender_id - 1] = public_key self.public_keys_received += 1 if self.public_keys_received == self.targeted_number_active_clients: - data = np.reshape(self.public_key_list, self.num_pk_per_user * self.targeted_number_active_clients) + data = np.reshape( + self.public_key_list, self.num_pk_per_user * self.targeted_number_active_clients) for i in range(self.targeted_number_active_clients): logging.info("sending data = {}".format(data)) self._send_public_key_others_to_user(i + 1, data) def _handle_message_receive_ss(self, msg_params): + """ + Handle the reception of encoded masks from clients. + + This function receives and processes encoded masks from active clients, + aggregating them for further use. + + Args: + self: The server instance. + msg_params (dict): A dictionary containing message parameters. 
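The public-key bookkeeping above is pure array plumbing: one (c_pk, s_pk) column per sender, flattened for the wire once all have arrived, and reshaped back on the client exactly as handle_message_receive_pk_others does. A sketch:

    import numpy as np

    num_pk_per_user, n_clients = 2, 4
    public_key_list = np.empty((num_pk_per_user, n_clients), dtype="int64")
    for sender_id in range(1, n_clients + 1):        # one message per client
        public_key_list[:, sender_id - 1] = [10 * sender_id, 10 * sender_id + 1]
    wire = np.reshape(public_key_list, num_pk_per_user * n_clients)   # flat broadcast
    received = np.reshape(wire, (num_pk_per_user, n_clients))         # client side
    assert np.array_equal(received, public_key_list)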
+ + Returns: + None + """ # Receive the aggregate of encoded masks for active clients sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) b_u_SS = msg_params.get(MyMessage.MSG_ARG_KEY_B_SS) @@ -165,9 +267,23 @@ def _handle_message_receive_ss(self, msg_params): self.ss_received += 1 if self.ss_received == self.targeted_number_active_clients: for i in range(self.targeted_number_active_clients): - self._send_ss_others_to_user(i + 1, self.b_u_SS_list[:, i], self.s_sk_SS_list[:, i]) + self._send_ss_others_to_user( + i + 1, self.b_u_SS_list[:, i], self.s_sk_SS_list[:, i]) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle the reception of a trained model from a client. + + This function receives and processes a trained model from a client, + updating the server's records and taking appropriate actions. + + Args: + self: The server instance. + msg_params (dict): A dictionary containing message parameters. + + Returns: + None + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) mlops.event( "comm_c2s", event_started=False, event_value=str(self.round_idx), event_edge_id=sender_id, @@ -177,7 +293,8 @@ def handle_message_receive_model_from_client(self, msg_params): local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) self.aggregator.add_local_trained_result( - self.client_real_ids.index(sender_id), model_params, local_sample_number + self.client_real_ids.index( + sender_id), model_params, local_sample_number ) self.active_clients_first_round.append(sender_id - 1) b_all_received = self.aggregator.check_whether_all_receive() @@ -186,9 +303,23 @@ def handle_message_receive_model_from_client(self, msg_params): if b_all_received: # Specify the active clients for the first round and inform them for receiver_id in range(1, self.size + 1): - self._send_message_to_active_client(receiver_id, self.active_clients_first_round) + self._send_message_to_active_client( + receiver_id, self.active_clients_first_round) def _handle_message_receive_ss_others_from_client(self, msg_params): + """ + Handle the reception of encoded masks from clients in the second round. + + This function receives and processes encoded masks from clients in the + second round, and performs model aggregation and evaluation. + + Args: + self: The server instance. + msg_params (dict): A dictionary containing message parameters. 
+ + Returns: + None + """ # Receive the aggregate of encoded masks for active clients sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) ss_others = msg_params.get(MyMessage.MSG_ARG_KEY_SS_OTHERS) @@ -213,13 +344,15 @@ def _handle_message_receive_ss_others_from_client(self, msg_params): self.round_idx, self.client_real_ids, self.args.client_num_per_round ) self.data_silo_index_list = self.aggregator.data_silo_selection( - self.round_idx, self.args.client_num_in_total, len(self.client_id_list_in_this_round), + self.round_idx, self.args.client_num_in_total, len( + self.client_id_list_in_this_round), ) client_idx_in_this_round = 0 for receiver_id in self.client_id_list_in_this_round: self.send_message_sync_model_to_client( - receiver_id, global_model_params, self.data_silo_index_list[client_idx_in_this_round], + receiver_id, global_model_params, self.data_silo_index_list[ + client_idx_in_this_round], ) client_idx_in_this_round += 1 @@ -232,31 +365,50 @@ def _handle_message_receive_ss_others_from_client(self, msg_params): self.ss_received = 0 self.num_pk_per_user = 2 self.public_key_list = np.empty( - shape=(self.num_pk_per_user, self.targeted_number_active_clients), dtype="int64" + shape=(self.num_pk_per_user, + self.targeted_number_active_clients), dtype="int64" ) self.b_u_SS_list = np.empty( - (self.targeted_number_active_clients, self.targeted_number_active_clients), dtype="int64" + (self.targeted_number_active_clients, + self.targeted_number_active_clients), dtype="int64" ) self.s_sk_SS_list = np.empty( - (self.targeted_number_active_clients, self.targeted_number_active_clients), dtype="int64" + (self.targeted_number_active_clients, + self.targeted_number_active_clients), dtype="int64" ) self.SS_rx = np.empty( - (self.targeted_number_active_clients, self.targeted_number_active_clients), dtype="int64" + (self.targeted_number_active_clients, + self.targeted_number_active_clients), dtype="int64" ) if self.round_idx == self.round_num: - logging.info("=================TRAINING IS FINISHED!=============") + logging.info( + "=================TRAINING IS FINISHED!=============") sleep(3) self.finish() if self.is_preprocessed: mlops.log_training_finished_status() - logging.info("=============training is finished. Cleanup...============") + logging.info( + "=============training is finished. Cleanup...============") self.cleanup() else: logging.info("waiting for another round...") - mlops.event("server.wait", event_started=True, event_value=str(self.round_idx)) + mlops.event("server.wait", event_started=True, + event_value=str(self.round_idx)) def cleanup(self): + """ + Cleanup function to finish the training process. + + This function is responsible for cleaning up after the training process, + sending finish messages to clients, and finalizing the server's state. + + Args: + self: The server instance. + + Returns: + None + """ client_idx_in_this_round = 0 for client_id in self.client_id_list_in_this_round: @@ -268,28 +420,98 @@ def cleanup(self): self.finish() def send_message_init_config(self, receive_id, global_model_params, datasilo_index): - message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id) - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) + """ + Send an initialization configuration message to a client. + + This function sends an initialization message containing global model + parameters and other configuration details to a specific client. 
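Every sender method in these managers follows the same three-step Message pattern: construct with (type, sender, receiver), attach keyed params, send. A toy stand-in, since the real Message and MyMessage classes live in fedml's communication layer:

    class ToyMessage:
        def __init__(self, msg_type, sender_id, receive_id):
            self.msg_type = msg_type
            self.sender_id = sender_id
            self.receive_id = receive_id
            self.params = {}

        def add_params(self, key, value):
            self.params[key] = value

    msg = ToyMessage("S2C_INIT_CONFIG", sender_id=0, receive_id=2)
    msg.add_params("model_params", {"w": [0.0]})
    msg.add_params("client_index", "1")
    msg.add_params("client_os", "PythonClient")
    print(msg.msg_type, msg.receive_id, sorted(msg.params))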
+ + Args: + self: The server instance. + receive_id (int): The ID of the receiving client. + global_model_params (dict): Global model parameters. + datasilo_index (int): The index of the data silo associated with the client. + + Returns: + None + """ + message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, + self.get_sender_id(), receive_id) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, "PythonClient") self.send_message(message) def send_message_check_client_status(self, receive_id, datasilo_index): - message = Message(MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, self.get_sender_id(), receive_id) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) + """ + Send a message to check the status of a client. + + This function sends a message to a client to check its status and readiness. + + Args: + self: The server instance. + receive_id (int): The ID of the receiving client. + datasilo_index (int): The index of the data silo associated with the client. + + Returns: + None + """ + + message = Message(MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, + self.get_sender_id(), receive_id) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) def send_message_finish(self, receive_id, datasilo_index): - message = Message(MyMessage.MSG_TYPE_S2C_FINISH, self.get_sender_id(), receive_id) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) + """ + Send a finish message to a client. + + This function sends a finish message to a client to signal the end of the + training process. + + Args: + self: The server instance. + receive_id (int): The ID of the receiving client. + datasilo_index (int): The index of the data silo associated with the client. + + Returns: + None + """ + message = Message(MyMessage.MSG_TYPE_S2C_FINISH, + self.get_sender_id(), receive_id) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) - logging.info(" ====================send cleanup message to {}====================".format(str(datasilo_index))) + logging.info(" ====================send cleanup message to {}====================".format( + str(datasilo_index))) def send_message_sync_model_to_client(self, receive_id, global_model_params, client_index): - logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id) - message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id,) - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) + """ + Send a message to synchronize the global model with a client. + + This function sends a synchronization message to a specific client, + containing the global model parameters and client index. + + Args: + self: The server instance. + receive_id (int): The ID of the receiving client. + global_model_params (dict): Global model parameters. + client_index (int): The index of the client. + + Returns: + None + """ + logging.info( + "send_message_sync_model_to_client. 
receive_id = %d" % receive_id) + message = Message( + MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id,) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, "PythonClient") self.send_message(message) @@ -298,20 +520,71 @@ def send_message_sync_model_to_client(self, receive_id, global_model_params, cli ) def _send_public_key_others_to_user(self, receive_id, public_key_other): - logging.info("Server send_message_to_active_client. receive_id = %d" % receive_id) - message = Message(MyMessage.MSG_TYPE_S2C_OTHER_PK_TO_CLIENT, self.get_sender_id(), receive_id) + """ + Send public keys to a user/client. + + This function sends public keys to a specific user/client, typically during + a secure communication setup. + + Args: + self: The server instance. + receive_id (int): The ID of the receiving user/client. + public_key_other: The public keys to send. + + Returns: + None + """ + + logging.info( + "Server send_message_to_active_client. receive_id = %d" % receive_id) + message = Message(MyMessage.MSG_TYPE_S2C_OTHER_PK_TO_CLIENT, + self.get_sender_id(), receive_id) message.add_params(MyMessage.MSG_ARG_KEY_PK_OTHERS, public_key_other) self.send_message(message) def _send_ss_others_to_user(self, receive_id, b_ss_others, sk_ss_others): - logging.info("Server send_message_to_active_client. receive_id = %d" % receive_id) - message = Message(MyMessage.MSG_TYPE_S2C_OTHER_SS_TO_CLIENT, self.get_sender_id(), receive_id) + """ + Send encoded masks to a user/client. + + This function sends encoded masks to a specific user/client, typically during + a secure communication setup. + + Args: + self: The server instance. + receive_id (int): The ID of the receiving user/client. + b_ss_others: Encoded masks (b values) to send. + sk_ss_others: Encoded masks (sk values) to send. + + Returns: + None + """ + logging.info( + "Server send_message_to_active_client. receive_id = %d" % receive_id) + message = Message(MyMessage.MSG_TYPE_S2C_OTHER_SS_TO_CLIENT, + self.get_sender_id(), receive_id) message.add_params(MyMessage.MSG_ARG_KEY_B_SS_OTHERS, b_ss_others) message.add_params(MyMessage.MSG_ARG_KEY_SK_SS_OTHERS, sk_ss_others) self.send_message(message) def _send_message_to_active_client(self, receive_id, active_clients): - logging.info("Server send_message_to_active_client. receive_id = %d" % receive_id) - message = Message(MyMessage.MSG_TYPE_S2C_ACTIVE_CLIENT_LIST, self.get_sender_id(), receive_id) - message.add_params(MyMessage.MSG_ARG_KEY_ACTIVE_CLIENTS, active_clients) + """ + Send a message to active clients. + + This function sends a message to a specific user/client containing a list of + active clients, typically during initialization. + + Args: + self: The server instance. + receive_id (int): The ID of the receiving user/client. + active_clients (list): A list of active client IDs. + + Returns: + None + """ + logging.info( + "Server send_message_to_active_client. 
receive_id = %d" % receive_id) + message = Message(MyMessage.MSG_TYPE_S2C_ACTIVE_CLIENT_LIST, + self.get_sender_id(), receive_id) + message.add_params( + MyMessage.MSG_ARG_KEY_ACTIVE_CLIENTS, active_clients) self.send_message(message) diff --git a/python/fedml/cross_silo/server/fedml_aggregator.py b/python/fedml/cross_silo/server/fedml_aggregator.py index 5d56b8c1cf..7a4c13a2af 100644 --- a/python/fedml/cross_silo/server/fedml_aggregator.py +++ b/python/fedml/cross_silo/server/fedml_aggregator.py @@ -11,6 +11,33 @@ class FedMLAggregator(object): + """ + Represents an aggregator for federated learning. + + Args: + train_global (object): The global training dataset. + test_global (object): The global testing dataset. + all_train_data_num (int): The total number of training data points. + train_data_local_dict (dict): A dictionary containing local training datasets. + test_data_local_dict (dict): A dictionary containing local testing datasets. + train_data_local_num_dict (dict): A dictionary containing the number of local training data points. + client_num (int): The number of clients. + device (torch.device): The device (e.g., 'cpu' or 'cuda') for computation. + args (object): An object containing various configuration parameters. + server_aggregator (ServerAggregator, optional): An optional server aggregator. + + Attributes: + None + + Methods: + get_global_model_params(): Get the global model parameters. + set_global_model_params(model_parameters): Set the global model parameters. + add_local_trained_result(index, model_params, sample_num): Add locally trained model results. + check_whether_all_receive(): Check if all clients have uploaded their models. + aggregate(): Aggregate model updates from clients. + assess_contribution(): Assess the contribution of clients. + """ + def __init__( self, train_global, @@ -49,23 +76,62 @@ def __init__( self.flag_client_model_uploaded_dict[idx] = False def get_global_model_params(self): + """ + Get the global model parameters. + + Args: + None + + Returns: + object: The global model parameters. + """ return self.aggregator.get_model_params() def set_global_model_params(self, model_parameters): + """ + Set the global model parameters. + + Args: + model_parameters (object): The global model parameters. + + Returns: + None + """ self.aggregator.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, sample_num): + """ + Add locally trained model results. + + Args: + index (int): The index of the client. + model_params (object): The locally trained model parameters. + sample_num (int): The number of training samples used. + + Returns: + None + """ logging.info("add_model. index = %d" % index) - # for dictionary model_params, we let the user level code to control the device + # for dictionary model_params, we let the user level code control the device if type(model_params) is not dict: - model_params = ml_engine_adapter.model_params_to_device(self.args, model_params, self.device) + model_params = ml_engine_adapter.model_params_to_device( + self.args, model_params, self.device) self.model_dict[index] = model_params self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all clients have uploaded their models. + + Args: + None + + Returns: + bool: True if all clients have uploaded their models, False otherwise. 
+ """ logging.debug("client_num = {}".format(self.client_num)) for idx in range(self.client_num): if not self.flag_client_model_uploaded_dict[idx]: @@ -75,27 +141,44 @@ def check_whether_all_receive(self): return True def aggregate(self): + """ + Aggregate model updates from clients. + + Args: + None + + Returns: + object: The aggregated model parameters. + list: The list of models after outlier removal. + list: The list of model indexes. + """ start_time = time.time() model_list = [] for idx in range(self.client_num): - model_list.append((self.sample_num_dict[idx], self.model_dict[idx])) + model_list.append( + (self.sample_num_dict[idx], self.model_dict[idx])) # model_list is the list after outlier removal - model_list, model_list_idxes = self.aggregator.on_before_aggregation(model_list) + model_list, model_list_idxes = self.aggregator.on_before_aggregation( + model_list) Context().add(Context.KEY_CLIENT_MODEL_LIST, model_list) averaged_params = self.aggregator.aggregate(model_list) if type(averaged_params) is dict: - if len(averaged_params) == self.client_num + 1: # aggregator pass extra {-1 : global_parms_dict} as global_params - itr_count = len(averaged_params) - 1 # do not apply on_after_aggregation to client -1 + # aggregator pass extra {-1 : global_parms_dict} as global_params + if len(averaged_params) == self.client_num + 1: + # do not apply on_after_aggregation to client -1 + itr_count = len(averaged_params) - 1 else: itr_count = len(averaged_params) for client_index in range(itr_count): - averaged_params[client_index] = self.aggregator.on_after_aggregation(averaged_params[client_index]) + averaged_params[client_index] = self.aggregator.on_after_aggregation( + averaged_params[client_index]) else: - averaged_params = self.aggregator.on_after_aggregation(averaged_params) + averaged_params = self.aggregator.on_after_aggregation( + averaged_params) self.set_global_model_params(averaged_params) @@ -104,6 +187,17 @@ def aggregate(self): return averaged_params, model_list, model_list_idxes def assess_contribution(self): + """ + Assess the contribution of clients. + + If enabled, this method assesses the contribution of clients in the federated learning process. 
+ + Args: + None + + Returns: + None + """ if hasattr(self.args, "enable_contribution") and \ self.args.enable_contribution is not None and self.args.enable_contribution: self.aggregator.assess_contribution() @@ -123,15 +217,18 @@ def data_silo_selection(self, round_idx, client_num_in_total, client_num_per_rou """ logging.info( - "client_num_in_total = %d, client_num_per_round = %d" % (client_num_in_total, client_num_per_round) + "client_num_in_total = %d, client_num_per_round = %d" % ( + client_num_in_total, client_num_per_round) ) assert client_num_in_total >= client_num_per_round if client_num_in_total == client_num_per_round: return [i for i in range(client_num_per_round)] else: - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round - data_silo_index_list = np.random.choice(range(client_num_in_total), client_num_per_round, replace=False) + # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) + data_silo_index_list = np.random.choice( + range(client_num_in_total), client_num_per_round, replace=False) return data_silo_index_list def client_selection(self, round_idx, client_id_list_in_total, client_num_per_round): @@ -148,33 +245,73 @@ def client_selection(self, round_idx, client_id_list_in_total, client_num_per_ro """ if client_num_per_round == len(client_id_list_in_total): return client_id_list_in_total - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round - client_id_list_in_this_round = np.random.choice(client_id_list_in_total, client_num_per_round, replace=False) + # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) + client_id_list_in_this_round = np.random.choice( + client_id_list_in_total, client_num_per_round, replace=False) return client_id_list_in_this_round def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Sample a subset of clients for a federated learning round. + + Args: + round_idx (int): The round index, starting from 0. + client_num_in_total (int): The total number of clients. + client_num_per_round (int): The number of clients to sample for the round. + + Returns: + list: A list of sampled client indexes. + + """ if client_num_in_total == client_num_per_round: - client_indexes = [client_index for client_index in range(client_num_in_total)] + client_indexes = [ + client_index for client_index in range(client_num_in_total)] else: num_clients = min(client_num_per_round, client_num_in_total) - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round - client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False) + # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) + client_indexes = np.random.choice( + range(client_num_in_total), num_clients, replace=False) logging.info("client_indexes = %s" % str(client_indexes)) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set. + + Args: + num_samples (int): The number of samples to include in the validation set (default is 10,000). + + Returns: + object: The validation dataset. 
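+
+        Example:
+            A hedged sketch; `aggregator` is assumed to wrap a
+            stackoverflow-style `test_global` loader, from which a random
+            subset is re-wrapped in a DataLoader:
+
+                val_loader = aggregator._generate_validation_set(num_samples=5000)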
+ + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) - sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) - subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) - sample_testset = torch.utils.data.DataLoader(subset, batch_size=self.args.batch_size) + sample_indices = random.sample( + range(test_data_num), min(num_samples, test_data_num)) + subset = torch.utils.data.Subset( + self.test_global.dataset, sample_indices) + sample_testset = torch.utils.data.DataLoader( + subset, batch_size=self.args.batch_size) return sample_testset else: return self.test_global def test_on_server_for_all_clients(self, round_idx): + """ + Test the global model on all clients. + + Args: + round_idx (int): The round index. + + Returns: + None + """ if round_idx % self.args.frequency_of_the_test == 0 or round_idx == self.args.comm_round - 1: - logging.info("################test_on_server_for_all_clients : {}".format(round_idx)) + logging.info( + "################test_on_server_for_all_clients : {}".format(round_idx)) self.aggregator.test_all( self.train_data_local_dict, self.test_data_local_dict, @@ -184,25 +321,39 @@ def test_on_server_for_all_clients(self, round_idx): if round_idx == self.args.comm_round - 1: # we allow to return four metrics, such as accuracy, AUC, loss, etc. - metric_result_in_current_round = self.aggregator.test(self.test_global, self.device, self.args) + metric_result_in_current_round = self.aggregator.test( + self.test_global, self.device, self.args) else: - metric_result_in_current_round = self.aggregator.test(self.val_global, self.device, self.args) - logging.info("metric_result_in_current_round = {}".format(metric_result_in_current_round)) - metric_results_in_the_last_round = Context().get(Context.KEY_METRICS_ON_AGGREGATED_MODEL) + metric_result_in_current_round = self.aggregator.test( + self.val_global, self.device, self.args) + logging.info("metric_result_in_current_round = {}".format( + metric_result_in_current_round)) + metric_results_in_the_last_round = Context().get( + Context.KEY_METRICS_ON_AGGREGATED_MODEL) Context().add(Context.KEY_METRICS_ON_AGGREGATED_MODEL, metric_result_in_current_round) if metric_results_in_the_last_round is not None: Context().add(Context.KEY_METRICS_ON_LAST_ROUND, metric_results_in_the_last_round) else: Context().add(Context.KEY_METRICS_ON_LAST_ROUND, metric_result_in_current_round) key_metrics_on_last_round = Context().get(Context.KEY_METRICS_ON_LAST_ROUND) - logging.info("key_metrics_on_last_round = {}".format(key_metrics_on_last_round)) + logging.info("key_metrics_on_last_round = {}".format( + key_metrics_on_last_round)) if round_idx == self.args.comm_round - 1: mlops.log({"round_idx": round_idx}) else: mlops.log({"round_idx": round_idx}) - + def get_dummy_input_tensor(self): + """ + Get a dummy input tensor for testing purposes. + + This method retrieves a dummy input tensor from the test dataset. + + Returns: + list: A list of dummy input tensors. 
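+
+        Example:
+            Illustrative only; `aggregator` is an assumed instance backed by
+            a non-empty test DataLoader:
+
+                dummy_inputs = aggregator.get_dummy_input_tensor()
+                # every returned tensor keeps only the first sample (batch size 1)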
+ + """ test_data = None if self.test_global: test_data = self.test_global @@ -210,18 +361,29 @@ def get_dummy_input_tensor(self): for k, v in self.test_data_local_dict.items(): if v: test_data = v - break - + break + with torch.no_grad(): - batch_idx, features_label_tensors = next(enumerate(test_data)) # test_data -> dataloader obj + batch_idx, features_label_tensors = next( + enumerate(test_data)) # test_data -> dataloader obj dummy_list = [] for tensor in features_label_tensors: - dummy_tensor = tensor[:1] # only take the first element as dummy input + # only take the first element as dummy input + dummy_tensor = tensor[:1] dummy_list.append(dummy_tensor) features = dummy_list[:-1] # Can adapt Process Multi-Label return features def get_input_shape_type(self): + """ + Get the input shape and type information. + + This method retrieves the input shape and type information from the test dataset. + + Returns: + tuple: A tuple containing two lists - input shape and input type. + + """ test_data = None if self.test_global: test_data = self.test_global @@ -230,12 +392,14 @@ def get_input_shape_type(self): if v: test_data = v break - + with torch.no_grad(): - batch_idx, features_label_tensors = next(enumerate(test_data)) # test_data -> dataloader obj + batch_idx, features_label_tensors = next( + enumerate(test_data)) # test_data -> dataloader obj dummy_list = [] for tensor in features_label_tensors: - dummy_tensor = tensor[:1] # only take the first element as dummy input + # only take the first element as dummy input + dummy_tensor = tensor[:1] dummy_list.append(dummy_tensor) features = dummy_list[:-1] # Can adapt Multi-Label @@ -248,10 +412,19 @@ def get_input_shape_type(self): input_type.append("int") else: input_type.append("float") - + return input_shape, input_type - + def save_dummy_input_tensor(self): + """ + Save the dummy input tensor to a file. + + This method saves the input shape and type information to a file named 'dummy_input_tensor.pkl'. + + Returns: + None + + """ import pickle features = self.get_input_size_type() with open('dummy_input_tensor.pkl', 'wb') as handle: diff --git a/python/fedml/cross_silo/server/fedml_server_manager.py b/python/fedml/cross_silo/server/fedml_server_manager.py index bb6739edf0..8a0974a485 100644 --- a/python/fedml/cross_silo/server/fedml_server_manager.py +++ b/python/fedml/cross_silo/server/fedml_server_manager.py @@ -13,6 +13,28 @@ class FedMLServerManager(FedMLCommManager): + """ + Represents the server manager for federated learning. + + Args: + args: The configuration arguments. + aggregator: The aggregator for federated learning. + comm: The communication backend (default is None). + client_rank: The rank of the client (default is 0). + client_num: The number of clients (default is 0). + backend: The communication backend (default is "MQTT_S3"). + + Attributes: + ONLINE_STATUS_FLAG (str): Flag indicating online status. + RUN_FINISHED_STATUS_FLAG (str): Flag indicating run finished status. + + Methods: + is_main_process(): Check if the current process is the main process. + run(): Run the server manager. + send_init_msg(): Send initialization messages to clients. + register_message_receive_handlers(): Register message receive handlers for communication. + + """ ONLINE_STATUS_FLAG = "ONLINE" RUN_FINISHED_STATUS_FLAG = "FINISHED" @@ -35,12 +57,26 @@ def __init__( self.data_silo_index_list = None def is_main_process(self): + """ + Check if the current process is the main process. 
+ + Returns: + bool: True if the current process is the main process, False otherwise. + """ return getattr(self.aggregator, "aggregator", None) is None or self.aggregator.aggregator.is_main_process() def run(self): super().run() def send_init_msg(self): + """ + Send initialization messages to clients. + + This method sends initialization messages to clients, including model parameters and configuration. + + Returns: + None + """ global_model_params = self.aggregator.get_global_model_params() global_model_url = None @@ -54,25 +90,37 @@ def send_init_msg(self): ) client_idx_in_this_round += 1 - mlops.event("server.wait", event_started=True, event_value=str(self.args.round_idx)) + mlops.event("server.wait", event_started=True, + event_value=str(self.args.round_idx)) try: # get input type and shape for inference dummy_input_tensor = self.aggregator.get_dummy_input_tensor() if not getattr(self.args, "skip_log_model_net", False): - model_net_url = mlops.log_training_model_net_info(self.aggregator.aggregator.model, dummy_input_tensor) + model_net_url = mlops.log_training_model_net_info( + self.aggregator.aggregator.model, dummy_input_tensor) # type and shape for later configuration input_shape, input_type = self.aggregator.get_input_shape_type() # Send output input size and type (saved as json) to s3, # and transfer when click "Create Model Card" - model_input_url = mlops.log_training_model_input_info(list(input_shape), list(input_type)) + model_input_url = mlops.log_training_model_input_info( + list(input_shape), list(input_type)) except Exception as e: - logging.info("Cannot get dummy input size or shape for model serving") + logging.info( + "Cannot get dummy input size or shape for model serving") def register_message_receive_handlers(self): + """ + Register message receive handlers for communication. + + This method registers message receive handlers for handling different types of messages. + + Returns: + None + """ logging.info("register_message_receive_handlers------") self.register_message_receive_handler( MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_message_connection_ready @@ -87,12 +135,24 @@ def register_message_receive_handlers(self): ) def handle_message_connection_ready(self, msg_params): + """ + Handle the message indicating that the client connection is ready. + + This method processes the message from a client indicating that its connection is ready for communication. + + Args: + msg_params (dict): The message parameters. 
+ + Returns: + None + """ if not self.is_initialized: self.client_id_list_in_this_round = self.aggregator.client_selection( self.args.round_idx, self.client_real_ids, self.args.client_num_per_round ) self.data_silo_index_list = self.aggregator.data_silo_selection( - self.args.round_idx, self.args.client_num_in_total, len(self.client_id_list_in_this_round), + self.args.round_idx, self.args.client_num_in_total, len( + self.client_id_list_in_this_round), ) mlops.log_round_info(self.round_num, -1) @@ -104,15 +164,30 @@ def handle_message_connection_ready(self, msg_params): self.send_message_check_client_status( client_id, self.data_silo_index_list[client_idx_in_this_round], ) - logging.info("Connection ready for client" + str(client_id)) + logging.info( + "Connection ready for client" + str(client_id)) except Exception as e: - logging.info("Connection not ready for client" + str(client_id)) + logging.info( + "Connection not ready for client" + str(client_id)) client_idx_in_this_round += 1 def process_online_status(self, client_status, msg_params): + """ + Process the online status message from a client. + + This method processes the online status message from a client and checks if all clients are online. + + Args: + client_status (str): The client status. + msg_params (dict): The message parameters. + + Returns: + None + """ self.client_online_mapping[str(msg_params.get_sender_id())] = True - logging.info("self.client_online_mapping = {}".format(self.client_online_mapping)) + logging.info("self.client_online_mapping = {}".format( + self.client_online_mapping)) all_client_is_online = True for client_id in self.client_id_list_in_this_round: @@ -121,17 +196,31 @@ def process_online_status(self, client_status, msg_params): break logging.info( - "sender_id = %d, all_client_is_online = %s" % (msg_params.get_sender_id(), str(all_client_is_online)) + "sender_id = %d, all_client_is_online = %s" % ( + msg_params.get_sender_id(), str(all_client_is_online)) ) if all_client_is_online: - mlops.log_aggregation_status(MyMessage.MSG_MLOPS_SERVER_STATUS_RUNNING) + mlops.log_aggregation_status( + MyMessage.MSG_MLOPS_SERVER_STATUS_RUNNING) # send initialization message to all clients to start training self.send_init_msg() self.is_initialized = True def process_finished_status(self, client_status, msg_params): + """ + Process the finished status message from a client. + + This method processes the finished status message from a client and checks if all clients have finished. + + Args: + client_status (str): The client status. + msg_params (dict): The message parameters. + + Returns: + None + """ self.client_finished_mapping[str(msg_params.get_sender_id())] = True all_client_is_finished = True @@ -141,7 +230,8 @@ def process_finished_status(self, client_status, msg_params): break logging.info( - "sender_id = %d, all_client_is_finished = %s" % (msg_params.get_sender_id(), str(all_client_is_finished)) + "sender_id = %d, all_client_is_finished = %s" % ( + msg_params.get_sender_id(), str(all_client_is_finished)) ) if all_client_is_finished: @@ -152,6 +242,17 @@ def process_finished_status(self, client_status, msg_params): self.finish() def handle_message_client_status_update(self, msg_params): + """ + Handle the client status update message. + + This method processes the client status update message and takes appropriate actions based on the status. + + Args: + msg_params (dict): The message parameters. 
+ + Returns: + None + """ client_status = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_STATUS) logging.info(f"received client status {client_status}") if client_status == FedMLServerManager.ONLINE_STATUS_FLAG: @@ -160,38 +261,59 @@ def handle_message_client_status_update(self, msg_params): self.process_finished_status(client_status, msg_params) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle the message receiving the model from a client. + + This method handles the message that receives the model parameters from a client and performs aggregation. + + Args: + msg_params (dict): The message parameters. + + Returns: + None + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) - mlops.event("comm_c2s", event_started=False, event_value=str(self.args.round_idx), event_edge_id=sender_id) + mlops.event("comm_c2s", event_started=False, event_value=str( + self.args.round_idx), event_edge_id=sender_id) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) self.aggregator.add_local_trained_result( - self.client_real_ids.index(sender_id), model_params, local_sample_number + self.client_real_ids.index( + sender_id), model_params, local_sample_number ) b_all_received = self.aggregator.check_whether_all_receive() logging.info("b_all_received = " + str(b_all_received)) if b_all_received: - mlops.event("server.wait", event_started=False, event_value=str(self.args.round_idx)) - mlops.event("server.agg_and_eval", event_started=True, event_value=str(self.args.round_idx)) + mlops.event("server.wait", event_started=False, + event_value=str(self.args.round_idx)) + mlops.event("server.agg_and_eval", event_started=True, + event_value=str(self.args.round_idx)) tick = time.time() global_model_params, model_list, model_list_idxes = self.aggregator.aggregate() - logging.info("self.client_id_list_in_this_round = {}".format(self.client_id_list_in_this_round)) + logging.info("self.client_id_list_in_this_round = {}".format( + self.client_id_list_in_this_round)) new_client_id_list_in_this_round = [] for client_idx in model_list_idxes: - new_client_id_list_in_this_round.append(self.client_id_list_in_this_round[client_idx]) - logging.info("new_client_id_list_in_this_round = {}".format(new_client_id_list_in_this_round)) - Context().add(Context.KEY_CLIENT_ID_LIST_IN_THIS_ROUND, new_client_id_list_in_this_round) + new_client_id_list_in_this_round.append( + self.client_id_list_in_this_round[client_idx]) + logging.info("new_client_id_list_in_this_round = {}".format( + new_client_id_list_in_this_round)) + Context().add(Context.KEY_CLIENT_ID_LIST_IN_THIS_ROUND, + new_client_id_list_in_this_round) if self.is_main_process(): - MLOpsProfilerEvent.log_to_wandb({"AggregationTime": time.time() - tick, "round": self.args.round_idx}) + MLOpsProfilerEvent.log_to_wandb( + {"AggregationTime": time.time() - tick, "round": self.args.round_idx}) self.aggregator.test_on_server_for_all_clients(self.args.round_idx) self.aggregator.assess_contribution() - mlops.event("server.agg_and_eval", event_started=False, event_value=str(self.args.round_idx)) + mlops.event("server.agg_and_eval", event_started=False, + event_value=str(self.args.round_idx)) # send round info to the MQTT backend mlops.log_round_info(self.round_num, self.args.round_idx) @@ -200,12 +322,15 @@ def handle_message_receive_model_from_client(self, msg_params): self.args.round_idx, self.client_real_ids, self.args.client_num_per_round ) self.data_silo_index_list = 
self.aggregator.data_silo_selection( - self.args.round_idx, self.args.client_num_in_total, len(self.client_id_list_in_this_round), + self.args.round_idx, self.args.client_num_in_total, len( + self.client_id_list_in_this_round), ) - Context().add(Context.KEY_CLIENT_ID_LIST_IN_THIS_ROUND, self.client_id_list_in_this_round) + Context().add(Context.KEY_CLIENT_ID_LIST_IN_THIS_ROUND, + self.client_id_list_in_this_round) if self.args.round_idx == 0 and self.is_main_process(): - MLOpsProfilerEvent.log_to_wandb({"BenchmarkStart": time.time()}) + MLOpsProfilerEvent.log_to_wandb( + {"BenchmarkStart": time.time()}) client_idx_in_this_round = 0 global_model_url = None @@ -232,13 +357,24 @@ def handle_message_receive_model_from_client(self, msg_params): self.args.round_idx += 1 if self.is_main_process(): - mlops.log_aggregated_model_info(self.args.round_idx, model_url=global_model_url) + mlops.log_aggregated_model_info( + self.args.round_idx, model_url=global_model_url) - logging.info("\n\n==========end {}-th round training===========\n".format(self.args.round_idx)) + logging.info( + "\n\n==========end {}-th round training===========\n".format(self.args.round_idx)) if self.args.round_idx < self.round_num: - mlops.event("server.wait", event_started=True, event_value=str(self.args.round_idx)) + mlops.event("server.wait", event_started=True, + event_value=str(self.args.round_idx)) def cleanup(self): + """ + Send cleanup messages to clients. + + This method sends cleanup messages to all clients to signal the end of communication. + + Returns: + None + """ client_idx_in_this_round = 0 for client_id in self.client_id_list_in_this_round: self.send_message_finish( @@ -248,73 +384,175 @@ def cleanup(self): def send_message_init_config(self, receive_id, global_model_params, datasilo_index, global_model_url=None, global_model_key=None): + """ + Send an initialization message with configuration to a client. + + This method sends an initialization message to a client containing configuration information and model parameters. + + Args: + receive_id (int): The receiver's ID. + global_model_params (dict): Global model parameters. + datasilo_index (int): The data silo index of the client. + global_model_url (str): The URL of the global model (optional). + global_model_key (str): The key of the global model (optional). + + Returns: + str: The URL of the global model. + str: The key of the global model. 
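+
+        Example:
+            A minimal sketch; the receiver ID, silo index, and `params`
+            (obtained from `get_global_model_params()`) are placeholders:
+
+                url, key = server_manager.send_message_init_config(
+                    receive_id=1, global_model_params=params, datasilo_index=0)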
+ """ if self.is_main_process(): tick = time.time() - message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id) + message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, + self.get_sender_id(), receive_id) if global_model_url is not None: - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL, global_model_url) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL, global_model_url) if global_model_key is not None: - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY, global_model_key) - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY, global_model_key) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, "PythonClient") self.send_message(message) - global_model_url = message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL) - global_model_key = message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY) - MLOpsProfilerEvent.log_to_wandb({"Communiaction/Send_Total": time.time() - tick}) + global_model_url = message.get( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL) + global_model_key = message.get( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY) + MLOpsProfilerEvent.log_to_wandb( + {"Communiaction/Send_Total": time.time() - tick}) return global_model_url, global_model_key def send_message_check_client_status(self, receive_id, datasilo_index): - message = Message(MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, self.get_sender_id(), receive_id) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) + """ + Send a message to check the status of a client. + + This method sends a message to a client to check its status. + + Args: + receive_id (int): The receiver's ID. + datasilo_index (int): The data silo index of the client. + + Returns: + None + """ + message = Message(MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, + self.get_sender_id(), receive_id) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) def send_message_finish(self, receive_id, datasilo_index): - message = Message(MyMessage.MSG_TYPE_S2C_FINISH, self.get_sender_id(), receive_id) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) + """ + Send a finish message to a client. + + This method sends a finish message to a client to signal the end of communication. + + Args: + receive_id (int): The receiver's ID. + datasilo_index (int): The data silo index of the client. + + Returns: + None + """ + message = Message(MyMessage.MSG_TYPE_S2C_FINISH, + self.get_sender_id(), receive_id) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) logging.info( "finish from send id {} to receive id {}.".format(message.get_sender_id(), message.get_receiver_id())) - logging.info(" ====================send cleanup message to {}====================".format(str(datasilo_index))) + logging.info(" ====================send cleanup message to {}====================".format( + str(datasilo_index))) def send_message_sync_model_to_client(self, receive_id, global_model_params, client_index, global_model_url=None, global_model_key=None): + """ + Send a message to synchronize the global model to a client. 
+ + This method sends a message to a client to synchronize the global model parameters. + + Args: + receive_id (int): The receiver's ID. + global_model_params (dict): Global model parameters. + client_index (int): The client index. + global_model_url (str): The URL of the global model (optional). + global_model_key (str): The key of the global model (optional). + + Returns: + str: The URL of the global model. + str: The key of the global model. + """ + if self.is_main_process(): tick = time.time() - logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id) - message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id, ) - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) + logging.info( + "send_message_sync_model_to_client. receive_id = %d" % receive_id) + message = Message( + MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id, ) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) if global_model_url is not None: - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL, global_model_url) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL, global_model_url) if global_model_key is not None: - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY, global_model_key) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY, global_model_key) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, "PythonClient") self.send_message(message) - MLOpsProfilerEvent.log_to_wandb({"Communiaction/Send_Total": time.time() - tick}) + MLOpsProfilerEvent.log_to_wandb( + {"Communiaction/Send_Total": time.time() - tick}) - global_model_url = message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL) - global_model_key = message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY) + global_model_url = message.get( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL) + global_model_key = message.get( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY) return global_model_url, global_model_key def send_message_diff_sync_model_to_client(self, receive_id, client_model_params, client_index): + """ + Send a message to synchronize a different global model to a client. + + This method sends a message to a client to synchronize a different global model parameters. + Unlike `send_message_sync_model_to_client`, this method does not synchronize the global model for all clients, + but rather sends a specific client's model. + + Args: + receive_id (int): The receiver's ID. + client_model_params (dict): The client's model parameters. + client_index (int): The client index. + + Returns: + str: The URL of the global model. + str: The key of the global model. + """ global_model_url = None global_model_key = None if self.is_main_process(): tick = time.time() - logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id) - message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id, ) - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, client_model_params) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) + logging.info( + "send_message_sync_model_to_client. 
receive_id = %d" % receive_id) + message = Message( + MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id, ) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS, client_model_params) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, "PythonClient") self.send_message(message) - MLOpsProfilerEvent.log_to_wandb({"Communiaction/Send_Total": time.time() - tick}) + MLOpsProfilerEvent.log_to_wandb( + {"Communiaction/Send_Total": time.time() - tick}) - global_model_url = message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL) - global_model_key = message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY) + global_model_url = message.get( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL) + global_model_key = message.get( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY) return global_model_url, global_model_key diff --git a/python/fedml/cross_silo/server/server_initializer.py b/python/fedml/cross_silo/server/server_initializer.py index 5877d96fea..f402be918a 100644 --- a/python/fedml/cross_silo/server/server_initializer.py +++ b/python/fedml/cross_silo/server/server_initializer.py @@ -18,6 +18,30 @@ def init_server( train_data_local_num_dict, server_aggregator, ): + """ + Initialize the server for federated learning. + + This function sets up the server for federated learning, including creating an aggregator, + starting distributed training, and running the server manager. + + Args: + args (argparse.Namespace): Command-line arguments and configurations. + device (torch.device): The device on which the server runs. + comm (Communicator): The communication backend. + rank (int): The rank of the server in the distributed environment. + worker_num (int): The number of worker nodes participating in federated learning. + model (torch.nn.Module): The model used for federated learning. + train_data_num (int): The number of training data points globally. + train_data_global (Dataset): The global training dataset. + test_data_global (Dataset): The global test dataset. + train_data_local_dict (dict): A dictionary of local training datasets for each client. + test_data_local_dict (dict): A dictionary of local test datasets for each client. + train_data_local_num_dict (dict): A dictionary of the number of local training data points for each client. + server_aggregator (ServerAggregator, optional): The server aggregator. If not provided, it will be created. 
+ + Returns: + None + """ if server_aggregator is None: server_aggregator = create_server_aggregator(model, args) server_aggregator.set_id(0) @@ -38,5 +62,6 @@ def init_server( # start the distributed training backend = args.backend - server_manager = FedMLServerManager(args, aggregator, comm, rank, worker_num, backend) + server_manager = FedMLServerManager( + args, aggregator, comm, rank, worker_num, backend) server_manager.run() From 832356cec13bbdda341b45c536bebf7a8a5d8c90 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Wed, 20 Sep 2023 19:13:00 +0530 Subject: [PATCH 29/70] add --- .../core/security/defense/RFA_defense.py | 60 +++++- .../core/security/defense/bulyan_defense.py | 71 +++++++ .../core/security/defense/cclip_defense.py | 73 ++++++- .../defense/coordinate_wise_median_defense.py | 28 ++- .../coordinate_wise_trimmed_mean_defense.py | 34 +++- .../core/security/defense/crfl_defense.py | 63 +++++- .../security/defense/cross_round_defense.py | 62 ++++++ .../core/security/defense/defense_base.py | 34 +++- .../security/defense/foolsgold_defense.py | 42 ++++ .../defense/geometric_median_defense.py | 32 ++- .../core/security/defense/krum_defense.py | 25 +++ .../defense/norm_diff_clipping_defense.py | 50 ++++- .../security/defense/outlier_detection.py | 34 +++- .../residual_based_reweighting_defense.py | 114 ++++++++++- .../defense/robust_learning_rate_defense.py | 70 ++++++- .../core/security/defense/slsgd_defense.py | 63 ++++++ .../core/security/defense/soteria_defense.py | 47 +++++ .../security/defense/three_sigma_defense.py | 86 +++++++- .../defense/three_sigma_geomedian_defense.py | 107 +++++++++- .../defense/three_sigma_krum_defense.py | 133 +++++++++++++ .../core/security/defense/wbc_defense.py | 52 ++++- .../core/security/defense/weak_dp_defense.py | 44 ++++ python/fedml/core/security/fedml_attacker.py | 129 +++++++++++- python/fedml/core/security/fedml_defender.py | 144 ++++++++++++++ python/fedml/cross_device/mnn_server.py | 39 +++- .../cross_silo/client/client_initializer.py | 73 ++++++- .../cross_silo/client/client_launcher.py | 70 ++++++- .../client/fedml_client_master_manager.py | 188 ++++++++++++++++-- .../client/fedml_client_slave_manager.py | 56 ++++++ .../fedml/cross_silo/client/fedml_trainer.py | 97 ++++++++- .../client/fedml_trainer_dist_adapter.py | 89 ++++++++- .../client/process_group_manager.py | 35 +++- python/fedml/cross_silo/client/utils.py | 51 ++++- 33 files changed, 2166 insertions(+), 129 deletions(-) diff --git a/python/fedml/core/security/defense/RFA_defense.py b/python/fedml/core/security/defense/RFA_defense.py index ceedcf6b65..1bba3a6809 100644 --- a/python/fedml/core/security/defense/RFA_defense.py +++ b/python/fedml/core/security/defense/RFA_defense.py @@ -12,15 +12,65 @@ class RFADefense(BaseDefenseMethod): - def __init__(self, config): - pass + """ + Robust Aggregation for Federated Learning (RFA) Defense. - def defend_on_aggregation( - self, + This defense method computes a geometric median in aggregation. + + Args: + config: Configuration parameters (currently unused). + + Attributes: + None + + Methods: + defend_on_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], base_aggregation_func: Callable = None, extra_auxiliary_info: Any = None, - ): + ) -> OrderedDict: + Defend against potential adversarial behavior during aggregation. + + References: + - "RFA: Robust Aggregation for Federated Learning." + https://arxiv.org/pdf/1912.13445.pdf + """ + + def __init__(self, config): + """ + Initialize the RFADefense. 
+ + Args: + config: Configuration parameters (currently unused). + """ + pass + + def defend_on_aggregation( + self, + raw_client_grad_list: List[Tuple[float, OrderedDict]], + base_aggregation_func: Callable = None, + extra_auxiliary_info: Any = None, + ) -> OrderedDict: + """ + Defend against potential adversarial behavior during aggregation. + + This method computes a geometric median aggregation of client gradients. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + base_aggregation_func (Callable, optional): + Base aggregation function (currently unused). + extra_auxiliary_info (Any, optional): + Extra auxiliary information (currently unused). + + Returns: + OrderedDict: + Aggregated parameters after applying the defense. + + Notes: + This defense method computes a geometric median aggregation of client gradients. + """ (num0, avg_params) = raw_client_grad_list[0] weights = {num for (num, params) in raw_client_grad_list} weights = {weight / sum(weights, 0.0) for weight in weights} diff --git a/python/fedml/core/security/defense/bulyan_defense.py b/python/fedml/core/security/defense/bulyan_defense.py index cea55960ae..253488934d 100644 --- a/python/fedml/core/security/defense/bulyan_defense.py +++ b/python/fedml/core/security/defense/bulyan_defense.py @@ -21,6 +21,21 @@ class BulyanDefense(BaseDefenseMethod): + """ + Bulyan Defense for Federated Learning. + + Bulyan Defense is a defense method for federated learning that aims to mitigate the impact of Byzantine clients + by selecting a subset of clients' gradients for aggregation. + + Args: + config: Configuration parameters for the defense. + - byzantine_client_num (int): The number of Byzantine (malicious) clients. + - client_num_per_round (int): The total number of clients participating in each aggregation round. + + Attributes: + byzantine_client_num (int): The number of Byzantine (malicious) clients. + client_num_per_round (int): The total number of clients participating in each aggregation round. + """ def __init__(self, config): self.byzantine_client_num = config.byzantine_client_num self.client_num_per_round = config.client_num_per_round @@ -37,6 +52,18 @@ def run( base_aggregation_func: Callable = None, extra_auxiliary_info: Any = None, ) -> OrderedDict: + """ + Run the Bulyan Defense to aggregate gradients. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples containing the number of samples + and gradients for each client. + base_aggregation_func (Callable, optional): The base aggregation function to use. Default is None. + extra_auxiliary_info (Any, optional): Additional auxiliary information. Default is None. + + Returns: + OrderedDict: The aggregated gradients after applying the Bulyan Defense. + """ # note: raw_client_grad_list is a list, each item is (sample_num, gradients). num_clients = len(raw_client_grad_list) (num0, localw0) = raw_client_grad_list[0] @@ -70,6 +97,18 @@ def run( return aggregated_params def _bulyan(self, users_params, users_count, corrupted_count): + """ + Perform the Bulyan aggregation. + + Args: + users_params (numpy.ndarray): Gradients of users' parameters. + users_count (int): The total number of users. + corrupted_count (int): The number of corrupted (Byzantine) users. + + Returns: + Tuple[List[int], List[numpy.ndarray], numpy.ndarray]: A tuple containing the selected indices, + selected set of gradients, and the aggregated gradients. 
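+
+        Note:
+            Bulyan requires `users_count >= 4 * corrupted_count + 3`, so
+            tolerating even one Byzantine user takes at least 7 users; the
+            selection set keeps `users_count - 2 * corrupted_count` members.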
+ """ assert users_count >= 4 * corrupted_count + 3 set_size = users_count - 2 * corrupted_count selection_set = [] @@ -98,6 +137,16 @@ def _bulyan(self, users_params, users_count, corrupted_count): @staticmethod def trimmed_mean(users_params, corrupted_count): + """ + Compute the trimmed mean of users' gradients. + + Args: + users_params (numpy.ndarray): Gradients of users' parameters. + corrupted_count (int): The number of corrupted (Byzantine) users. + + Returns: + numpy.ndarray: The trimmed mean of gradients. + """ users_params = np.array(users_params) number_to_consider = int(users_params.shape[0] - corrupted_count) - 1 @@ -120,6 +169,19 @@ def _krum( distances=None, return_index=False, ): + """ + Perform the Krum selection. + + Args: + users_params (numpy.ndarray): Gradients of users' parameters. + users_count (int): The total number of users. + corrupted_count (int): The number of corrupted (Byzantine) users. + distances (dict, optional): Precomputed distances between users. Default is None. + return_index (bool, optional): Whether to return the selected index. Default is False. + + Returns: + numpy.ndarray or int: The selected gradients or index. + """ non_malicious_count = users_count - corrupted_count minimal_error = 1e20 @@ -141,6 +203,15 @@ def _krum( @staticmethod def _krum_create_distances(users_params): + """ + Create pairwise distances between users' gradients. + + Args: + users_params (numpy.ndarray): Gradients of users' parameters. + + Returns: + dict: A dictionary containing pairwise distances between users' gradients. + """ distances = defaultdict(dict) for i in range(len(users_params)): for j in range(i): diff --git a/python/fedml/core/security/defense/cclip_defense.py b/python/fedml/core/security/defense/cclip_defense.py index eba983cb48..63e232ba8e 100755 --- a/python/fedml/core/security/defense/cclip_defense.py +++ b/python/fedml/core/security/defense/cclip_defense.py @@ -13,10 +13,29 @@ class CClipDefense(BaseDefenseMethod): + """ + CClip Defense for Federated Learning. + + CClip (Coordinate-wise Clipping) Defense is a defense method for federated learning that clips gradients at each + coordinate to mitigate the impact of Byzantine clients. + + Args: + config: Configuration parameters for the defense. + - tau (float, optional): The clipping radius. Default is 10. + - bucket_size (int, optional): The number of elements in each bucket when partitioning gradients. + Default is None. + + Attributes: + tau (float): The clipping radius. + bucket_size (int): The number of elements in each bucket when partitioning gradients. + initial_guess (OrderedDict): The initial guess for the global model. + """ + def __init__(self, config): self.config = config if hasattr(config, "tau") and type(config.tau) in [int, float] and config.tau > 0: - self.tau = config.tau # clipping raduis; tau = 10 / (1-beta), beta is the coefficient of momentum + # clipping raduis; tau = 10 / (1-beta), beta is the coefficient of momentum + self.tau = config.tau else: self.tau = 10 # default: no momentum, beta = 0 # element # in each bucket; a grad_list is partitioned into floor(len(grad_list)/bucket_size) buckets @@ -28,10 +47,23 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Apply CClip Defense before aggregation. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples containing the number of samples + and gradients for each client. 
+ extra_auxiliary_info (Any, optional): Additional auxiliary information. Default is None. + + Returns: + List[Tuple[float, OrderedDict]]: The modified gradients after applying CClip Defense. + """ + client_grad_buckets = Bucket.bucketization( raw_client_grad_list, self.bucket_size ) - self.initial_guess = self._compute_an_initial_guess(client_grad_buckets) + self.initial_guess = self._compute_an_initial_guess( + client_grad_buckets) bucket_num = len(client_grad_buckets) vec_local_w = [ ( @@ -47,25 +79,58 @@ def defend_before_aggregation( tuple = OrderedDict() sample_num, bucket_params = client_grad_buckets[i] for k in bucket_params.keys(): - tuple[k] = (bucket_params[k] - self.initial_guess[k]) * cclip_score[i] + tuple[k] = (bucket_params[k] - + self.initial_guess[k]) * cclip_score[i] new_grad_list.append((sample_num, tuple)) return new_grad_list def defend_after_aggregation(self, global_model): + """ + Apply CClip Defense after aggregation. + + Args: + global_model (OrderedDict): The global model after aggregation. + + Returns: + OrderedDict: The modified global model after applying CClip Defense. + """ + for k in global_model.keys(): global_model[k] = self.initial_guess[k] + global_model[k] return global_model @staticmethod def _compute_an_initial_guess(client_grad_list): + """ + Compute an initial guess for the global model. + + Args: + client_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples containing the number of samples + and gradients for each client. + + Returns: + OrderedDict: The initial guess for the global model. + """ # randomly select a gradient as the initial guess return client_grad_list[np.random.randint(0, len(client_grad_list))][1] def _compute_cclip_score(self, local_w, refs): + """ + Compute the CClip score for each local gradient. + + Args: + local_w (List[Tuple[float, numpy.ndarray]]): A list of tuples containing the number of samples and + vectorized local gradients. + refs (numpy.ndarray): Vectorized reference gradient. + + Returns: + List[float]: A list of CClip scores for each local gradient. + """ cclip_score = [] num_client = len(local_w) for i in range(0, num_client): - dist = utils.compute_euclidean_distance(local_w[i][1], refs).item() + 1e-8 + dist = utils.compute_euclidean_distance( + local_w[i][1], refs).item() + 1e-8 score = min(1, self.tau / dist) cclip_score.append(score) return cclip_score diff --git a/python/fedml/core/security/defense/coordinate_wise_median_defense.py b/python/fedml/core/security/defense/coordinate_wise_median_defense.py index 30357a1d67..6412b55cdc 100644 --- a/python/fedml/core/security/defense/coordinate_wise_median_defense.py +++ b/python/fedml/core/security/defense/coordinate_wise_median_defense.py @@ -12,6 +12,19 @@ class CoordinateWiseMedianDefense(BaseDefenseMethod): + """ + Coordinate-wise Median Defense for Federated Learning. + + Coordinate-wise Median Defense is a defense method for federated learning that computes the median of the gradients + for each coordinate to mitigate the impact of Byzantine clients. + + Args: + config: Configuration parameters for the defense. (Currently, no specific parameters are required.) + + Attributes: + None + """ + def __init__(self, config): pass @@ -21,6 +34,18 @@ def defend_on_aggregation( base_aggregation_func: Callable = None, extra_auxiliary_info: Any = None, ): + """ + Apply Coordinate-wise Median Defense on aggregation. 
+
+        Args:
+            raw_client_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples containing the number of samples
+                and gradients for each client.
+            base_aggregation_func (Callable, optional): The base aggregation function. Default is None.
+            extra_auxiliary_info (Any, optional): Additional auxiliary information. Default is None.
+
+        Returns:
+            OrderedDict: The aggregated global model after applying Coordinate-wise Median Defense.
+        """
 
         vectorized_params = []
 
         for i in range(0, len(raw_client_grad_list)):
@@ -35,11 +60,10 @@ def defend_on_aggregation(
         index = 0
         (num0, averaged_params) = raw_client_grad_list[0]
         for k, params in averaged_params.items():
-            median_params = vec_median_params[index : index + params.numel()].view(
+            median_params = vec_median_params[index: index + params.numel()].view(
                 params.size()
            )
             index += params.numel()
             averaged_params[k] = median_params
 
         return averaged_params
-
diff --git a/python/fedml/core/security/defense/coordinate_wise_trimmed_mean_defense.py b/python/fedml/core/security/defense/coordinate_wise_trimmed_mean_defense.py
index 1a717946d7..6bbc4f97bf 100644
--- a/python/fedml/core/security/defense/coordinate_wise_trimmed_mean_defense.py
+++ b/python/fedml/core/security/defense/coordinate_wise_trimmed_mean_defense.py
@@ -12,15 +12,45 @@
 
 
 class CoordinateWiseTrimmedMeanDefense(BaseDefenseMethod):
+    """
+    Coordinate-wise Trimmed Mean Defense for Federated Learning.
+
+    Coordinate-wise Trimmed Mean Defense is a defense method for federated learning that computes the trimmed mean of
+    gradients for each coordinate to mitigate the impact of Byzantine clients.
+
+    Args:
+        config: Configuration parameters for the defense, including 'beta' which represents the fraction of trimmed
+            values; total trimmed values: client_num * beta * 2.
+
+    Attributes:
+        beta (float): The fraction of trimmed values, which determines the number of gradients to be trimmed on each side.
+    """
+
    def __init__(self, config):
-        self.beta = config.beta # fraction of trimmed values; total trimmed values: client_num * beta * 2
+        """
+        Initialize the CoordinateWiseTrimmedMeanDefense with the specified configuration.
+
+        Args:
+            config: Configuration parameters for the defense.
+        """
+        self.beta = config.beta

    def defend_before_aggregation(
        self,
        raw_client_grad_list: List[Tuple[float, OrderedDict]],
        extra_auxiliary_info: Any = None,
    ):
+        """
+        Apply Coordinate-wise Trimmed Mean Defense before aggregation.
+
+        Args:
+            raw_client_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples containing the number of samples
+                and gradients for each client.
+            extra_auxiliary_info (Any, optional): Additional auxiliary information. Default is None.
+
+        Returns:
+            List[Tuple[float, OrderedDict]]: The client gradients after coordinate-wise trimming, in the same format as the input. 
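+
+        Example:
+            Illustrative numbers: with 10 clients and `beta = 0.1`,
+            `int(0.1 * 10) = 1` value is trimmed from each tail of every
+            coordinate, i.e. `client_num * beta * 2 = 2` values in total.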
+        """
         if self.beta > 1 / 2 or self.beta < 0:
-            raise ValueError("the bound of beta is [0, 1/2)")
+            raise ValueError("The bound of 'beta' is [0, 1/2)")

        return trimmed_mean(raw_client_grad_list, int(self.beta * len(raw_client_grad_list)))
diff --git a/python/fedml/core/security/defense/crfl_defense.py b/python/fedml/core/security/defense/crfl_defense.py
index 13712da1f3..bff5d53691 100644
--- a/python/fedml/core/security/defense/crfl_defense.py
+++ b/python/fedml/core/security/defense/crfl_defense.py
@@ -9,8 +9,38 @@
 """
 
 
+from .defense_base import BaseDefenseMethod
+from .utils import compute_model_norm
+from .gaussian import compute_noise_using_sigma
+from collections import OrderedDict
+
+
 class CRFLDefense(BaseDefenseMethod):
+    """
+    CRFL (Certifiably Robust Federated Learning) Defense.
+
+    CRFL is a defense method for federated learning that clips the global model's weights if they exceed a
+    dynamic threshold and adds Gaussian noise to the clipped weights for parameter smoothing.
+
+    Args:
+        config: Configuration parameters for the defense, including 'clip_threshold' (optional), 'sigma', 'comm_round',
+            and 'dataset'.
+
+    Attributes:
+        epoch (int): The current training epoch.
+        user_defined_clip_threshold (float, optional): A user-defined clipping threshold for model weights.
+        sigma (float): The standard deviation of the Gaussian noise added to clipped weights.
+        total_ite_num (int): The total number of communication rounds.
+        dataset_param_function (function): A function to compute the dynamic clipping threshold based on the dataset.
+    """
+
     def __init__(self, config):
+        """
+        Initialize the CRFLDefense with the specified configuration.
+
+        Args:
+            config: Configuration parameters for the defense.
+        """
         self.config = config
         self.epoch = 1
         if hasattr(config, "clip_threshold"):
@@ -20,7 +50,7 @@ def __init__(self, config):
         if hasattr(config, "sigma") and isinstance(config.sigma, float):
             self.sigma = config.sigma
         else:
-            self.sigma = 0.01  # in the code of CRFL, the author set sigma to 0.01
+            self.sigma = 0.01  # Default sigma value as used in CRFL code
         self.total_ite_num = config.comm_round
         if config.dataset == "mnist":
             self.dataset_param_function = self._crfl_compute_param_for_mnist
@@ -31,15 +61,18 @@ def __init__(self, config):
         elif self.user_defined_clip_threshold is not None:
             self.dataset_param_function = self._crfl_self_defined_dataset_param
         else:
-            raise Exception(f"dataset not supported: {config.dataset} and clip_threshold not defined ")
+            raise Exception(
+                f"Dataset not supported: {config.dataset} and clip_threshold not defined.")
 
     def defend_after_aggregation(self, global_model):
         """
-        clip the global model; dynamic threshold is adjusted according to the dataset;
-        in the experiment, the authors set the dynamic threshold as follows:
-        dataset == MNIST: dynamic_thres = epoch * 0.1 + 2
-        dataseet == LOAN: dynamic_thres = epoch * 0.025 + 2
-        datset == EMNIST: dynamic_thres = epoch * 0.25 + 4
+        Apply the CRFL defense after model aggregation.
+
+        Args:
+            global_model (OrderedDict): The global model to be defended.
+
+        Returns:
+            OrderedDict: The defended global model after clipping and adding Gaussian noise. 
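+
+        Note:
+            The dynamic threshold follows the CRFL experiments, e.g.
+            `epoch * 0.1 + 2` for MNIST, `epoch * 0.025 + 2` for LOAN, and
+            `epoch * 0.25 + 4` for EMNIST; a smaller user-defined
+            `clip_threshold` takes precedence when provided.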
""" clip_threshold = self.dataset_param_function() if self.user_defined_clip_threshold is not None and self.user_defined_clip_threshold < clip_threshold: @@ -51,7 +84,8 @@ def defend_after_aggregation(self, global_model): self.epoch += 1 new_global_model = OrderedDict() for k in global_model.keys(): - new_global_model[k] = global_model[k] + Gaussian.compute_noise_using_sigma(self.sigma, global_model[k].shape) + new_global_model[k] = global_model[k] + \ + compute_noise_using_sigma(self.sigma, global_model[k].shape) return new_global_model def _crfl_self_defined_dataset_param(self): @@ -68,8 +102,17 @@ def _crfl_compute_param_for_emnist(self): @staticmethod def clip_weight_norm(model, clip_threshold): - total_norm = utils.compute_model_norm(model) - print(f"total_norm = {total_norm}") + """ + Clip the weight norm of the model. + + Args: + model (OrderedDict): The model whose weights are to be clipped. + clip_threshold (float): The threshold value for clipping. + + Returns: + OrderedDict: The model with clipped weights. + """ + total_norm = compute_model_norm(model) if total_norm > clip_threshold: clip_coef = clip_threshold / (total_norm + 1e-6) new_model = OrderedDict() diff --git a/python/fedml/core/security/defense/cross_round_defense.py b/python/fedml/core/security/defense/cross_round_defense.py index 0d8eb34bd5..ae8174563a 100644 --- a/python/fedml/core/security/defense/cross_round_defense.py +++ b/python/fedml/core/security/defense/cross_round_defense.py @@ -13,6 +13,24 @@ # too much difference: malicious, need further defense # todo: pretraining round? class CrossRoundDefense(BaseDefenseMethod): + """ + CrossRoundDefense for Federated Learning. + + CrossRoundDefense is a defense method for federated learning that detects potentially poisoned workers + based on cosine similarity between client and global model features across training rounds. + + Args: + config: Configuration parameters for the defense, including 'upperbound' and 'lowerbound'. + + Attributes: + potentially_poisoned_worker_list (list): List of potentially poisoned worker indices. + lazy_worker_list (list): List of lazy worker indices. + upperbound (float): Threshold for detecting potential attacks. + lowerbound (float): Threshold for defining "very limited difference." + client_cache (list): Cache of client features for comparison across training rounds. + training_round (int): The current training round. + is_attack_existing (bool): Flag indicating whether an attack exists in the current round. + """ def __init__(self, config): self.potentially_poisoned_worker_list = [] self.lazy_worker_list = None @@ -28,6 +46,16 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Apply CrossRoundDefense before model aggregation. + + Args: + raw_client_grad_list (list): List of client gradients for the current round. + extra_auxiliary_info: Global model or auxiliary information. + + Returns: + list: List of potentially poisoned client gradients. + """ self.is_attack_existing = False client_features = self._get_importance_feature(raw_client_grad_list) if self.training_round == 1: @@ -71,9 +99,25 @@ def defend_before_aggregation( return raw_client_grad_list def get_potential_poisoned_clients(self): + """ + Get the list of potentially poisoned client indices. + + Returns: + list: List of potentially poisoned client indices. 
+ """ return self.potentially_poisoned_worker_list def compute_client_cosine_scores(self, client_features, global_model_feature): + """ + Compute cosine similarity scores between client features and global model features. + + Args: + client_features (list): List of client feature vectors. + global_model_feature (list): Feature vector of the global model. + + Returns: + tuple: Two lists of cosine similarity scores for each client (client-wise and global-wise). + """ client_wise_scores = [] global_wise_scores = [] num_client = len(client_features) @@ -85,6 +129,15 @@ def compute_client_cosine_scores(self, client_features, global_model_feature): return client_wise_scores, global_wise_scores def _get_importance_feature(self, raw_client_grad_list): + """ + Extract importance features from client gradients. + + Args: + raw_client_grad_list (list): List of client gradients. + + Returns: + list: List of extracted importance feature vectors. + """ ret_feature_vector_list = [] for idx in range(len(raw_client_grad_list)): raw_grad = raw_client_grad_list[idx] @@ -96,6 +149,15 @@ def _get_importance_feature(self, raw_client_grad_list): @classmethod def _get_importance_feature_of_a_model(self, grad): + """ + Extract importance feature from a client gradient. + + Args: + grad (OrderedDict): Client gradient. + + Returns: + numpy.ndarray: Importance feature vector. + """ # Get last key-value tuple (weight_name, importance_feature) = list(grad.items())[-2] # print(importance_feature) diff --git a/python/fedml/core/security/defense/defense_base.py b/python/fedml/core/security/defense/defense_base.py index 4abc3bbecf..77a1adbaa9 100644 --- a/python/fedml/core/security/defense/defense_base.py +++ b/python/fedml/core/security/defense/defense_base.py @@ -4,8 +4,20 @@ class BaseDefenseMethod(ABC): + """ + Base class for defense methods in Federated Learning. + + Attributes: + config: Configuration parameters for the defense method. + """ @abstractmethod def __init__(self, config): + """ + Initialize the defense method with the specified configuration. + + Args: + config: Configuration parameters for the defense method. + """ pass def defend_before_aggregation( @@ -14,12 +26,14 @@ def defend_before_aggregation( extra_auxiliary_info: Any = None, ) -> List[Tuple[float, OrderedDict]]: """ - args: - client_grad_list: client_grad_list is a list, each item is (sample_num, gradients) - extra_auxiliary_info: for methods which need extra info (e.g., data, previous model/gradient), - please use this variable. - return: - Note: the data type of the return variable should be the same as the input + Apply defense before model aggregation. + + Args: + raw_client_grad_list (list): List of client gradients for the current round. + extra_auxiliary_info: Additional information required for defense. + + Returns: + list: List of defended client gradients. """ pass @@ -41,4 +55,10 @@ def defend_on_aggregation( pass def get_malicious_client_idxs(self): - return [] \ No newline at end of file + """ + Get the indices of potentially malicious clients. + + Returns: + list: List of indices of potentially malicious clients. 
+ """ + return [] diff --git a/python/fedml/core/security/defense/foolsgold_defense.py b/python/fedml/core/security/defense/foolsgold_defense.py index 5db59eecad..4814637d96 100644 --- a/python/fedml/core/security/defense/foolsgold_defense.py +++ b/python/fedml/core/security/defense/foolsgold_defense.py @@ -12,7 +12,21 @@ class FoolsGoldDefense(BaseDefenseMethod): + """ + Defense method using FoolsGold for federated learning. + + Attributes: + config: Configuration parameters for the defense method. + memory: Memory for storing client importance features. + """ + def __init__(self, config): + """ + Initialize the FoolsGoldDefense. + + Args: + config: Configuration parameters for the defense method. + """ super().__init__(config) self.config = config self.memory = None @@ -22,6 +36,16 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Apply FoolsGold defense before model aggregation. + + Args: + raw_client_grad_list (list): List of client gradients for the current round. + extra_auxiliary_info: Additional information required for defense. + + Returns: + list: List of defended client gradients. + """ client_num = len(raw_client_grad_list) importance_feature_list = self._get_importance_feature(raw_client_grad_list) # print(len(importance_feature_list)) @@ -47,6 +71,15 @@ def defend_before_aggregation( # Takes in grad, compute similarity, get weightings @classmethod def fools_gold_score(cls, feature_vec_list): + """ + Compute FoolsGold scores for client importance features. + + Args: + feature_vec_list (list): List of client importance features. + + Returns: + list: List of FoolsGold scores. + """ import sklearn.metrics.pairwise as smp n_clients = len(feature_vec_list) cs = smp.cosine_similarity(feature_vec_list) - np.eye(n_clients) @@ -75,6 +108,15 @@ def fools_gold_score(cls, feature_vec_list): return alpha def _get_importance_feature(self, raw_client_grad_list): + """ + Get the importance feature from client gradients. + + Args: + raw_client_grad_list (list): List of client gradients. + + Returns: + list: List of importance features. + """ # Foolsgold uses the last layer's gradient/weights as the importance feature. ret_feature_vector_list = [] for idx in range(len(raw_client_grad_list)): diff --git a/python/fedml/core/security/defense/geometric_median_defense.py b/python/fedml/core/security/defense/geometric_median_defense.py index adf60edaa2..edd2dc733d 100644 --- a/python/fedml/core/security/defense/geometric_median_defense.py +++ b/python/fedml/core/security/defense/geometric_median_defense.py @@ -20,7 +20,23 @@ class GeometricMedianDefense(BaseDefenseMethod): + """ + Defense method using Geometric Median for federated learning. + + Attributes: + byzantine_client_num: Number of Byzantine clients in the system. + client_num_per_round: Number of clients participating in each round. + batch_num: Number of batches used for geometric median computation. + batch_size: Size of each batch for gradient aggregation. + """ + def __init__(self, config): + """ + Initialize the GeometricMedianDefense. + + Args: + config: Configuration parameters for the defense method. 
+ """ self.byzantine_client_num = config.byzantine_client_num self.client_num_per_round = config.client_num_per_round # 2(1 + ε )q ≤ batch_num ≤ client_num_per_round @@ -37,7 +53,19 @@ def defend_on_aggregation( base_aggregation_func: Callable = None, extra_auxiliary_info: Any = None, ): - batch_grad_list = Bucket.bucketization(raw_client_grad_list, self.batch_size) + """ + Apply Geometric Median defense on gradient aggregation. + + Args: + raw_client_grad_list (list): List of client gradients for the current round. + base_aggregation_func (Callable): Base aggregation function to use (optional). + extra_auxiliary_info: Additional information required for defense (optional). + + Returns: + OrderedDict: Aggregated global model parameters. + """ + batch_grad_list = Bucket.bucketization( + raw_client_grad_list, self.batch_size) (num0, avg_params) = batch_grad_list[0] alphas = {alpha for (alpha, params) in batch_grad_list} alphas = {alpha / sum(alphas, 0.0) for alpha in alphas} @@ -45,5 +73,3 @@ def defend_on_aggregation( batch_grads = [params[k] for (alpha, params) in batch_grad_list] avg_params[k] = compute_geometric_median(alphas, batch_grads) return avg_params - - diff --git a/python/fedml/core/security/defense/krum_defense.py b/python/fedml/core/security/defense/krum_defense.py index 5201cc8c09..19d8af73cf 100755 --- a/python/fedml/core/security/defense/krum_defense.py +++ b/python/fedml/core/security/defense/krum_defense.py @@ -16,6 +16,12 @@ class KrumDefense(BaseDefenseMethod): def __init__(self, config): + """ + Initialize the KrumDefense method. + + Args: + config (object): Configuration object containing defense parameters. + """ self.config = config self.byzantine_client_num = config.byzantine_client_num @@ -29,6 +35,16 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Perform defense before aggregation using the KrumDefense method. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): List of client gradients. + extra_auxiliary_info (Any): Additional information (optional). + + Returns: + List[Tuple[float, OrderedDict]]: List of defended client gradients. + """ num_client = len(raw_client_grad_list) # in the Krum paper, it says 2 * byzantine_client_num + 2 < client # if not 2 * self.byzantine_client_num + 2 <= num_client - self.krum_param_m: @@ -48,6 +64,15 @@ def defend_before_aggregation( return [raw_client_grad_list[i] for i in score_index] def _compute_krum_score(self, vec_grad_list): + """ + Compute Krum scores for the given list of gradient vectors. + + Args: + vec_grad_list (List[torch.Tensor]): List of gradient vectors. + + Returns: + List[float]: List of Krum scores. + """ krum_scores = [] num_client = len(vec_grad_list) for i in range(0, num_client): diff --git a/python/fedml/core/security/defense/norm_diff_clipping_defense.py b/python/fedml/core/security/defense/norm_diff_clipping_defense.py index a01306e7f6..bbb064478b 100644 --- a/python/fedml/core/security/defense/norm_diff_clipping_defense.py +++ b/python/fedml/core/security/defense/norm_diff_clipping_defense.py @@ -13,20 +13,38 @@ class NormDiffClippingDefense(BaseDefenseMethod): def __init__(self, config): + """ + Initialize the NormDiffClippingDefense method. + + Args: + config (object): Configuration object containing defense parameters. + """ self.config = config - self.norm_bound = config.norm_bound # for norm diff clipping; in the paper, they set it to 0.1, 0.17, and 0.33. 
+ # for norm diff clipping; in the paper, they set it to 0.1, 0.17, and 0.33. + self.norm_bound = config.norm_bound def defend_before_aggregation( self, raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Perform defense before aggregation using norm difference clipping. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): List of client gradients. + extra_auxiliary_info (Any): Global model for clipping (optional). + + Returns: + List[Tuple[float, OrderedDict]]: List of defended client gradients. + """ global_model = extra_auxiliary_info vec_global_w = utils.vectorize_weight(global_model) new_grad_list = [] for (sample_num, local_w) in raw_client_grad_list: vec_local_w = utils.vectorize_weight(local_w) - clipped_weight_diff = self._get_clipped_norm_diff(vec_local_w, vec_global_w) + clipped_weight_diff = self._get_clipped_norm_diff( + vec_local_w, vec_global_w) clipped_w = self._get_clipped_weights( local_w, global_model, clipped_weight_diff ) @@ -34,20 +52,44 @@ def defend_before_aggregation( return new_grad_list def _get_clipped_norm_diff(self, vec_local_w, vec_global_w): + """ + Compute the clipped norm difference between local and global weights. + + Args: + vec_local_w (torch.Tensor): Vectorized local weights. + vec_global_w (torch.Tensor): Vectorized global weights. + + Returns: + torch.Tensor: Clipped weight difference. + """ vec_diff = vec_local_w - vec_global_w weight_diff_norm = torch.norm(vec_diff).item() - clipped_weight_diff = vec_diff / max(1, weight_diff_norm / self.norm_bound) + clipped_weight_diff = vec_diff / \ + max(1, weight_diff_norm / self.norm_bound) return clipped_weight_diff @staticmethod def _get_clipped_weights(local_w, global_w, weight_diff): + """ + Compute clipped weights based on global and local weights. + + Args: + local_w (OrderedDict): Local model weights. + global_w (OrderedDict): Global model weights. + weight_diff (torch.Tensor): Clipped weight difference. + + Returns: + OrderedDict: Clipped local model weights. + """ + # rule: global_w + clipped(local_w - global_w) recons_local_w = OrderedDict() index_bias = 0 for item_index, (k, v) in enumerate(local_w.items()): if utils.is_weight_param(k): recons_local_w[k] = ( - weight_diff[index_bias: index_bias + v.numel()].view(v.size()) + weight_diff[index_bias: index_bias + + v.numel()].view(v.size()) + global_w[k] ) index_bias += v.numel() diff --git a/python/fedml/core/security/defense/outlier_detection.py b/python/fedml/core/security/defense/outlier_detection.py index e24f6c594f..793d4ffd11 100644 --- a/python/fedml/core/security/defense/outlier_detection.py +++ b/python/fedml/core/security/defense/outlier_detection.py @@ -6,7 +6,14 @@ class OutlierDetection(BaseDefenseMethod): + def __init__(self, config): + """ + Initialize the OutlierDetection method. + + Args: + config (object): Configuration object containing defense parameters. + """ self.cross_round_check = CrossRoundDefense(config) self.three_sigma_check = ThreeSigmaKrumDefense(config) @@ -15,11 +22,30 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): - raw_client_grad_list = self.cross_round_check.defend_before_aggregation(raw_client_grad_list, extra_auxiliary_info) + """ + Perform outlier detection defense before aggregation. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): List of client gradients. + extra_auxiliary_info (Any): Additional information (optional). 
+ + Returns: + List[Tuple[float, OrderedDict]]: List of defended client gradients. + """ + raw_client_grad_list = self.cross_round_check.defend_before_aggregation( + raw_client_grad_list, extra_auxiliary_info) if self.cross_round_check.is_attack_existing: - self.three_sigma_check.set_potential_malicious_clients(self.cross_round_check.get_potential_poisoned_clients()) - raw_client_grad_list = self.three_sigma_check.defend_before_aggregation(raw_client_grad_list, extra_auxiliary_info) + self.three_sigma_check.set_potential_malicious_clients( + self.cross_round_check.get_potential_poisoned_clients()) + raw_client_grad_list = self.three_sigma_check.defend_before_aggregation( + raw_client_grad_list, extra_auxiliary_info) return raw_client_grad_list def get_malicious_client_idxs(self): - return self.three_sigma_check.get_malicious_client_idxs() \ No newline at end of file + """ + Get the indices of potential malicious clients. + + Returns: + List[int]: List of indices of potential malicious clients. + """ + return self.three_sigma_check.get_malicious_client_idxs() diff --git a/python/fedml/core/security/defense/residual_based_reweighting_defense.py b/python/fedml/core/security/defense/residual_based_reweighting_defense.py index 32c1c07b14..f71fabb977 100644 --- a/python/fedml/core/security/defense/residual_based_reweighting_defense.py +++ b/python/fedml/core/security/defense/residual_based_reweighting_defense.py @@ -16,6 +16,12 @@ class ResidualBasedReweightingDefense(BaseDefenseMethod): def __init__(self, config): + """ + Initialize the ResidualBasedReweightingDefense method. + + Args: + config (object): Configuration object containing defense parameters. + """ if hasattr(config, "lambda_param"): self.lambda_param = config.lambda_param else: @@ -31,16 +37,36 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Perform defense before aggregation using residual-based reweighting. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): List of client gradients. + extra_auxiliary_info (Any): Additional information (optional). + + Returns: + List[Tuple[float, OrderedDict]]: List of defended client gradients. + """ return self.IRLS_other_split_restricted(raw_client_grad_list) def IRLS_other_split_restricted(self, raw_client_grad_list): + """ + Perform the Iteratively Reweighted Least Squares (IRLS) defense with restricted mode. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): List of client gradients. + + Returns: + List[Tuple[float, OrderedDict]]: List of defended client gradients. + """ reweight_algorithm = median_reweight_algorithm_restricted if self.mode == "median": reweight_algorithm = median_reweight_algorithm_restricted elif self.mode == "theilsen": reweight_algorithm = theilsen_reweight_algorithm_restricted elif self.mode == "gaussian": - reweight_algorithm = gaussian_reweight_algorithm_restricted # in gaussian reweight algorithm, lambda is sigma + # in gaussian reweight algorithm, lambda is sigma + reweight_algorithm = gaussian_reweight_algorithm_restricted SHARD_SIZE = 2000 w = [grad for (_, grad) in raw_client_grad_list] @@ -70,13 +96,15 @@ def IRLS_other_split_restricted(self, raw_client_grad_list): else: num_shards = int(math.ceil(total_num / SHARD_SIZE)) for i in range(num_shards): - y = transposed_y_list[i * SHARD_SIZE : (i + 1) * SHARD_SIZE, ...] + y = transposed_y_list[i * + SHARD_SIZE: (i + 1) * SHARD_SIZE, ...] 
reweight, restricted_y = reweight_algorithm( y, self.lambda_param, self.thresh ) print(reweight.sum(dim=0)) reweight_sum += reweight.sum(dim=0) - y_result[i * SHARD_SIZE : (i + 1) * SHARD_SIZE, ...] = restricted_y + y_result[i * SHARD_SIZE: (i + 1) + * SHARD_SIZE, ...] = restricted_y # put restricted y back to w y_result = torch.t(y_result) @@ -89,13 +117,25 @@ def IRLS_other_split_restricted(self, raw_client_grad_list): def median_reweight_algorithm_restricted(y, LAMBDA, thresh): + """ + Perform reweighting using the Median Reweight Algorithm with restricted mode. + + Args: + y (torch.Tensor): Input data. + LAMBDA (float): Lambda parameter. + thresh (float): Threshold value. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple containing reweight values and restricted data. + """ num_models = y.shape[1] total_num = y.shape[0] X_pure = y.sort()[1].sort()[1].type(torch.float) # calculate H matrix X_pure = X_pure.unsqueeze(2) - X = torch.cat((torch.ones(total_num, num_models, 1).to(y.device), X_pure), dim=-1) + X = torch.cat( + (torch.ones(total_num, num_models, 1).to(y.device), X_pure), dim=-1) X_X = torch.matmul(X.transpose(1, 2), X) X_X = torch.matmul(X, torch.inverse(X_X)) H = torch.matmul(X_X, X.transpose(1, 2)) @@ -121,6 +161,16 @@ def median_reweight_algorithm_restricted(y, LAMBDA, thresh): def median(input): + """ + Calculate the median of the input data. + + Args: + input (torch.Tensor): Input data. + + Returns: + torch.Tensor: Median value. + """ + shape = input.shape input = input.sort()[0] if shape[-1] % 2 != 0: @@ -133,6 +183,17 @@ def median(input): def theilsen_reweight_algorithm_restricted(y, LAMBDA, thresh): + """ + Perform reweighting using the Theil-Sen Reweight Algorithm with restricted mode. + + Args: + y (torch.Tensor): Input data. + LAMBDA (float): Lambda parameter. + thresh (float): Threshold value. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple containing reweight values and restricted data. + """ num_models = y.shape[1] total_num = y.shape[0] slopes, intercepts = theilsen(y) @@ -140,7 +201,8 @@ def theilsen_reweight_algorithm_restricted(y, LAMBDA, thresh): # calculate H matrix X_pure = X_pure.unsqueeze(2) - X = torch.cat((torch.ones(total_num, num_models, 1).to(y.device), X_pure), dim=-1) + X = torch.cat( + (torch.ones(total_num, num_models, 1).to(y.device), X_pure), dim=-1) X_X = torch.matmul(X.transpose(1, 2), X) X_X = torch.matmul(X, torch.inverse(X_X)) H = torch.matmul(X_X, X.transpose(1, 2)) @@ -173,12 +235,24 @@ def theilsen_reweight_algorithm_restricted(y, LAMBDA, thresh): def gaussian_reweight_algorithm_restricted(y, sig, thresh): + """ + Perform reweighting using the Gaussian Reweight Algorithm with restricted mode. + + Args: + y (torch.Tensor): Input data. + sig (float): Sigma parameter for the Gaussian distribution. + thresh (float): Threshold value. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple containing reweight values and restricted data. + """ num_models = y.shape[1] total_num = y.shape[0] slopes, intercepts = repeated_median(y) X_pure = y.sort()[1].sort()[1].type(torch.float) X_pure = X_pure.unsqueeze(2) - X = torch.cat((torch.ones(total_num, num_models, 1).to(y.device), X_pure), dim=-1) + X = torch.cat( + (torch.ones(total_num, num_models, 1).to(y.device), X_pure), dim=-1) beta = torch.cat( ( @@ -205,10 +279,29 @@ def gaussian_reweight_algorithm_restricted(y, sig, thresh): def gaussian_zero_mean(x, sig=1): + """ + Compute the Gaussian reweighting with zero mean. + + Args: + x (torch.Tensor): Input data. 
+ sig (float, optional): Sigma parameter for the Gaussian distribution. Default is 1. + + Returns: + torch.Tensor: Reweighted data. + """ return torch.exp(-x * x / (2 * sig * sig)) def repeated_median(y): + """ + Compute the repeated median and intercepts for the Theil-Sen Reweight Algorithm. + + Args: + y (torch.Tensor): Input data. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple containing slopes and intercepts. + """ num_models = y.shape[1] total_num = y.shape[0] y = y.sort()[0] @@ -238,6 +331,15 @@ def repeated_median(y): def theilsen(y): + """ + Compute the Theil-Sen estimator for slopes and intercepts. + + Args: + y (torch.Tensor): Input data. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple containing slopes and intercepts. + """ num_models = y.shape[1] total_num = y.shape[0] y = y.sort()[0] diff --git a/python/fedml/core/security/defense/robust_learning_rate_defense.py b/python/fedml/core/security/defense/robust_learning_rate_defense.py index ccbb1c24df..bacacfcfa3 100644 --- a/python/fedml/core/security/defense/robust_learning_rate_defense.py +++ b/python/fedml/core/security/defense/robust_learning_rate_defense.py @@ -24,22 +24,66 @@ class RobustLearningRateDefense(BaseDefenseMethod): + """ + Robust Learning Rate Defense. + + This defense method adjusts the learning rates of clients based on the robust threshold. + + Args: + config: Configuration parameters. + + Attributes: + robust_threshold (int): The robust threshold used for learning rate adjustment. + server_learning_rate (int): The server's learning rate. + + Methods: + run( + raw_client_grad_list: List[Tuple[float, OrderedDict]], + base_aggregation_func: Callable = None, + extra_auxiliary_info: Any = None, + ) -> OrderedDict: + Adjust the learning rates of clients based on the robust threshold. + + """ + def __init__(self, config): + """ + Initialize the RobustLearningRateDefense. + + Args: + config: Configuration parameters. + """ self.robust_threshold = config.robust_threshold # e.g., robust threshold = 4 self.server_learning_rate = 1 def run( - self, - raw_client_grad_list: List[Tuple[float, OrderedDict]], - base_aggregation_func: Callable = None, - extra_auxiliary_info: Any = None, - ): + self, + raw_client_grad_list: List[Tuple[float, OrderedDict]], + base_aggregation_func: Callable = None, + extra_auxiliary_info: Any = None, + ) -> OrderedDict: + """ + Adjust the learning rates of clients based on the robust threshold. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + base_aggregation_func (Callable, optional): + Base aggregation function (currently unused). + extra_auxiliary_info (Any, optional): + Extra auxiliary information (currently unused). + + Returns: + OrderedDict: + Aggregated parameters after adjusting learning rates based on the robust threshold. 
+ """ if self.robust_threshold == 0: return base_aggregation_func(raw_client_grad_list) # avg_params total_sample_num = get_total_sample_num(raw_client_grad_list) (num0, avg_params) = raw_client_grad_list[0] for k in avg_params.keys(): - client_update_sign = [] # self._compute_robust_learning_rates(model_list) + # self._compute_robust_learning_rates(model_list) + client_update_sign = [] for i in range(0, len(raw_client_grad_list)): local_sample_number, local_model_params = raw_client_grad_list[i] client_update_sign.append(torch.sign(local_model_params[k])) @@ -53,7 +97,19 @@ def run( return avg_params def _compute_robust_learning_rates(self, client_update_sign): + """ + Compute robust learning rates based on the client update signs. + + Args: + client_update_sign (list of torch.Tensor): + List of tensors containing the sign of client updates. + + Returns: + torch.Tensor: + Adjusted learning rates for clients. + """ client_lr = torch.abs(sum(client_update_sign)) - client_lr[client_lr < self.robust_threshold] = -self.server_learning_rate + client_lr[client_lr < self.robust_threshold] = - \ + self.server_learning_rate client_lr[client_lr >= self.robust_threshold] = self.server_learning_rate return client_lr diff --git a/python/fedml/core/security/defense/slsgd_defense.py b/python/fedml/core/security/defense/slsgd_defense.py index c39b60da7c..ac8d848da9 100644 --- a/python/fedml/core/security/defense/slsgd_defense.py +++ b/python/fedml/core/security/defense/slsgd_defense.py @@ -27,7 +27,42 @@ class SLSGDDefense(BaseDefenseMethod): + """ + Stochastic Leader Selection for SGD Defense. + + This defense method performs leader selection and aggregation for federated learning. + + Args: + config: Configuration parameters. + + Attributes: + b (int): Parameter of trimmed mean. + alpha (float): Weighting factor for aggregation. + option_type (int): Type of option. + config: Configuration parameters. + + Methods: + defend_before_aggregation( + raw_client_grad_list: List[Tuple[float, OrderedDict]], + extra_auxiliary_info: Any = None, + ) -> List[Tuple[float, OrderedDict]]: + Perform preprocessing and leader selection on client gradients before aggregation. + + defend_on_aggregation( + raw_client_grad_list: List[Tuple[float, OrderedDict]], + base_aggregation_func: Callable = None, + extra_auxiliary_info: Any = None, + ) -> OrderedDict: + Perform aggregation with leader selection based on the given configuration. + + """ def __init__(self, config): + """ + Initialize the SLSGDDefense. + + Args: + config: Configuration parameters. + """ self.b = config.trim_param_b # parameter of trimmed mean if config.alpha > 1 or config.alpha < 0: raise ValueError("the bound of alpha is [0, 1]") @@ -40,6 +75,19 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Perform preprocessing and leader selection on client gradients before aggregation. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + extra_auxiliary_info (Any, optional): + Extra auxiliary information (currently unused). + + Returns: + List[Tuple[float, OrderedDict]]: + Processed and selected client gradients. 
+ """ if self.b > math.ceil(len(raw_client_grad_list) / 2) - 1 or self.b < 0: raise ValueError( "the bound of b is [0, {}])".format( @@ -60,6 +108,21 @@ def defend_on_aggregation( base_aggregation_func: Callable = None, extra_auxiliary_info: Any = None, ): + """ + Perform aggregation with leader selection based on the given configuration. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + base_aggregation_func (Callable, optional): + Base aggregation function (currently unused). + extra_auxiliary_info (Any, optional): + Extra auxiliary information (currently unused). + + Returns: + OrderedDict: + Aggregated parameters after leader selection and aggregation. + """ global_model = extra_auxiliary_info avg_params = base_aggregation_func(args=self.config, raw_grad_list=raw_client_grad_list) for k in avg_params.keys(): diff --git a/python/fedml/core/security/defense/soteria_defense.py b/python/fedml/core/security/defense/soteria_defense.py index a85203eade..e9f372737e 100644 --- a/python/fedml/core/security/defense/soteria_defense.py +++ b/python/fedml/core/security/defense/soteria_defense.py @@ -26,6 +26,29 @@ class SoteriaDefense(BaseDefenseMethod): + """ + Soteria Defense for Federated Learning. + + This defense method performs a Soteria-based defense for federated learning. + + Args: + num_class (int): Number of classes in the dataset. + model: The federated learning model. + defense_data: Defense data for the model. + defense_label (int): Defense label. + + Methods: + run( + raw_client_grad_list: List[Tuple[float, OrderedDict]], + base_aggregation_func: Callable = None, + extra_auxiliary_info: Any = None, + ) -> Dict: + Perform Soteria-based defense on the federated learning model. + + label_to_onehot(target, num_classes=100) -> torch.Tensor: + Convert labels to one-hot encoding. + + """ def __init__( self, num_class, @@ -33,6 +56,15 @@ def __init__( defense_data, defense_label=84, ): + """ + Initialize the SoteriaDefense. + + Args: + num_class (int): Number of classes in the dataset. + model: The federated learning model. + defense_data: Defense data for the model. + defense_label (int): Defense label. + """ self.num_class = num_class # number of classess of the dataset self.model = model self.defense_data = defense_data @@ -46,6 +78,21 @@ def run( base_aggregation_func: Callable = None, extra_auxiliary_info: Any = None, ) -> Dict: + """ + Perform Soteria-based defense on the federated learning model. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + base_aggregation_func (Callable, optional): + Base aggregation function (currently unused). + extra_auxiliary_info (Any, optional): + Extra auxiliary information (currently unused). + + Returns: + Dict: + Aggregation result after Soteria-based defense. 
+ """ # load local model self.model.load_state_dict(raw_client_grad_list, strict=True) original_dy_dx = extra_auxiliary_info # refs for local gradient diff --git a/python/fedml/core/security/defense/three_sigma_defense.py b/python/fedml/core/security/defense/three_sigma_defense.py index ebdb56ce5b..efe7719103 100644 --- a/python/fedml/core/security/defense/three_sigma_defense.py +++ b/python/fedml/core/security/defense/three_sigma_defense.py @@ -6,7 +6,7 @@ from ..common import utils from scipy import spatial -### Original paper: https://arxiv.org/pdf/2107.05252.pdf +# Original paper: https://arxiv.org/pdf/2107.05252.pdf # training: In each iteration, each client k splits its local dataset into batches of size B, # and runs for E local epochs batched-gradient descent through the local dataset # to obtain local model, and sends it to the server. @@ -41,7 +41,39 @@ class ThreeSigmaDefense(BaseDefenseMethod): + """ + Three-Sigma Defense for Federated Learning. + + This defense method performs a Three-Sigma-based defense for federated learning. + + Args: + config: Configuration object for defense parameters. + + Methods: + defend_before_aggregation( + raw_client_grad_list: List[Tuple[float, OrderedDict]], + extra_auxiliary_info: Any = None, + ) -> List[Tuple[float, OrderedDict]]: + Perform defense before aggregation. + + compute_gaussian_distribution() -> Tuple[float, float]: + Compute the Gaussian distribution parameters. + + compute_client_scores(raw_client_grad_list) -> List[float]: + Compute client scores. + + fools_gold_score(feature_vec_list) -> List[float]: + Compute Fool's Gold scores. + + """ + def __init__(self, config): + """ + Initialize the ThreeSigmaDefense. + + Args: + config: Configuration object for defense parameters. + """ self.memory = None self.iteration_num = 1 self.score_list = [] @@ -74,6 +106,18 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Perform defense before aggregation. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + extra_auxiliary_info (Any, optional): + Extra auxiliary information (currently unused). + + Returns: + List[Tuple[float, OrderedDict]]: Batched gradient list after defense. + """ # grad_list = [grad for (_, grad) in raw_client_grad_list] client_scores = self.compute_client_scores(raw_client_grad_list) if self.iteration_num < self.pretraining_round_number: @@ -96,8 +140,6 @@ def defend_before_aggregation( raw_client_grad_list.pop(i) print(f"pop -- i = {i}") - - batch_grad_list = Bucket.bucketization( raw_client_grad_list, self.bucketing_batch_size ) @@ -120,6 +162,12 @@ def defend_before_aggregation( # return avg_params def compute_gaussian_distribution(self): + """ + Compute the Gaussian distribution parameters. + + Returns: + Tuple[float, float]: Mean (mu) and standard deviation (sigma). + """ n = len(self.score_list) mu = sum(list(self.score_list)) / n temp = 0 @@ -131,8 +179,18 @@ def compute_gaussian_distribution(self): return mu, sigma def compute_client_scores(self, raw_client_grad_list): + """ + Compute client scores. + + Args: + raw_client_grad_list: List of tuples containing client gradients as OrderedDict. + + Returns: + List[float]: List of client scores. 
+ """ if self.score_function == "foolsgold": - importance_feature_list = self._get_importance_feature(raw_client_grad_list) + importance_feature_list = self._get_importance_feature( + raw_client_grad_list) if self.memory is None: self.memory = importance_feature_list else: # memory: potential bugs: grads in different iterations may be from different clients @@ -141,6 +199,15 @@ def compute_client_scores(self, raw_client_grad_list): return self.fools_gold_score(self.memory) def _get_importance_feature(self, raw_client_grad_list): + """ + Get importance features for Fool's Gold score computation. + + Args: + raw_client_grad_list: List of tuples containing client gradients as OrderedDict. + + Returns: + List[float]: List of importance features. + """ # print(f"raw_client_grad_list = {raw_client_grad_list}") # Foolsgold uses the last layer's gradient/weights as the importance feature. ret_feature_vector_list = [] @@ -162,6 +229,15 @@ def _get_importance_feature(self, raw_client_grad_list): @staticmethod def fools_gold_score(feature_vec_list): + """ + Compute Fool's Gold scores. + + Args: + feature_vec_list: List of importance features. + + Returns: + List[float]: List of Fool's Gold scores. + """ n_clients = len(feature_vec_list) cs = np.zeros((n_clients, n_clients)) for i in range(n_clients): @@ -183,7 +259,7 @@ def fools_gold_score(feature_vec_list): alpha[alpha <= 0.0] = 1e-15 # Rescale so that max value is alpha - # print(np.max(alpha)) + # print(np.max(alpha)) alpha = alpha / np.max(alpha) alpha[(alpha == 1.0)] = 0.999999 diff --git a/python/fedml/core/security/defense/three_sigma_geomedian_defense.py b/python/fedml/core/security/defense/three_sigma_geomedian_defense.py index 73d4ac9a05..9c78d19646 100644 --- a/python/fedml/core/security/defense/three_sigma_geomedian_defense.py +++ b/python/fedml/core/security/defense/three_sigma_geomedian_defense.py @@ -9,7 +9,42 @@ class ThreeSigmaGeoMedianDefense(BaseDefenseMethod): + """ + Three-Sigma Defense with Geometric Median for Federated Learning. + + This defense method performs a Three-Sigma-based defense with geometric median for federated learning. + + Args: + config: Configuration object for defense parameters. + + Methods: + defend_before_aggregation( + raw_client_grad_list: List[Tuple[float, OrderedDict]], + extra_auxiliary_info: Any = None, + ) -> List[Tuple[float, OrderedDict]]: + Perform defense before aggregation. + + compute_gaussian_distribution() -> Tuple[float, float]: + Compute the Gaussian distribution parameters. + + compute_client_scores(raw_client_grad_list) -> List[float]: + Compute client scores. + + fools_gold_score(feature_vec_list) -> List[float]: + Compute Fool's Gold scores. + + l2_scores(importance_feature_list) -> List[float]: + Compute L2 scores. + + """ + def __init__(self, config): + """ + Initialize the ThreeSigmaGeoMedianDefense. + + Args: + config: Configuration object for defense parameters. + """ self.memory = None self.iteration_num = 1 self.score_list = [] @@ -39,6 +74,18 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Perform defense before aggregation. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + extra_auxiliary_info (Any, optional): + Extra auxiliary information (currently unused). + + Returns: + List[Tuple[float, OrderedDict]]: Gradient list after defense. 
+ """ # grad_list = [grad for (_, grad) in raw_client_grad_list] client_scores = self.compute_client_scores(raw_client_grad_list) print(f"client scores = {client_scores}") @@ -64,6 +111,13 @@ def defend_before_aggregation( return raw_client_grad_list def compute_gaussian_distribution(self): + """ + Compute the Gaussian distribution parameters. + + Returns: + Tuple[float, float]: Mean (mu) and standard deviation (sigma). + """ + n = len(self.score_list) mu = sum(list(self.score_list)) / n temp = 0 @@ -75,7 +129,17 @@ def compute_gaussian_distribution(self): return mu, sigma def compute_client_scores(self, raw_client_grad_list): - importance_feature_list = self._get_importance_feature(raw_client_grad_list) + """ + Compute client scores. + + Args: + raw_client_grad_list: List of tuples containing client gradients as OrderedDict. + + Returns: + List[float]: List of client scores. + """ + importance_feature_list = self._get_importance_feature( + raw_client_grad_list) if self.score_function == "foolsgold": if self.memory is None: self.memory = importance_feature_list @@ -88,19 +152,39 @@ def compute_client_scores(self, raw_client_grad_list): # (num0, avg_params) = raw_client_grad_list[0] # alphas = {alpha for (alpha, params) in raw_client_grad_list} # alphas = {alpha / sum(alphas, 0.0) for alpha in alphas} - alphas = [1/len(raw_client_grad_list)] * len(raw_client_grad_list) - self.geo_median = compute_geometric_median(alphas, importance_feature_list) + alphas = [1/len(raw_client_grad_list)] * \ + len(raw_client_grad_list) + self.geo_median = compute_geometric_median( + alphas, importance_feature_list) return self.l2_scores(importance_feature_list) def l2_scores(self, importance_feature_list): + """ + Compute L2 scores. + + Args: + importance_feature_list: List of importance features. + + Returns: + List[float]: List of L2 scores. + """ scores = [] for feature in importance_feature_list: - score = compute_euclidean_distance(torch.Tensor(feature), self.geo_median) + score = compute_euclidean_distance( + torch.Tensor(feature), self.geo_median) scores.append(score) return scores - def _get_importance_feature(self, raw_client_grad_list): + """ + Get importance features for score computation. + + Args: + raw_client_grad_list: List of tuples containing client gradients as OrderedDict. + + Returns: + List[float]: List of importance features. + """ # print(f"raw_client_grad_list = {raw_client_grad_list}") # Foolsgold uses the last layer's gradient/weights as the importance feature. ret_feature_vector_list = [] @@ -122,6 +206,15 @@ def _get_importance_feature(self, raw_client_grad_list): @staticmethod def fools_gold_score(feature_vec_list): + """ + Compute Fool's Gold scores. + + Args: + feature_vec_list: List of importance features. + + Returns: + List[float]: List of Fool's Gold scores. 
+ """ n_clients = len(feature_vec_list) cs = np.zeros((n_clients, n_clients)) for i in range(n_clients): @@ -143,7 +236,7 @@ def fools_gold_score(feature_vec_list): alpha[alpha <= 0.0] = 1e-15 # Rescale so that max value is alpha - # print(np.max(alpha)) + # print(np.max(alpha)) alpha = alpha / np.max(alpha) alpha[(alpha == 1.0)] = 0.999999 @@ -154,4 +247,4 @@ def fools_gold_score(feature_vec_list): print("alpha = {}".format(alpha)) - return alpha \ No newline at end of file + return alpha diff --git a/python/fedml/core/security/defense/three_sigma_krum_defense.py b/python/fedml/core/security/defense/three_sigma_krum_defense.py index 565aa0f962..8d476ad45e 100644 --- a/python/fedml/core/security/defense/three_sigma_krum_defense.py +++ b/python/fedml/core/security/defense/three_sigma_krum_defense.py @@ -14,7 +14,52 @@ class ThreeSigmaKrumDefense(BaseDefenseMethod): + """ + Three-Sigma Defense with Krum-based Malicious Client Detection for Federated Learning. + + This defense method performs a Three-Sigma-based defense with Krum-based malicious client detection for federated learning. + + Args: + config: Configuration object for defense parameters. + + Methods: + defend_before_aggregation( + raw_client_grad_list: List[Tuple[float, OrderedDict]], + extra_auxiliary_info: Any = None, + ) -> List[Tuple[float, OrderedDict]]: + Perform defense before aggregation. + + kick_out_poisoned_local_models( + client_scores: List[float], + raw_client_grad_list: List[Tuple[float, OrderedDict]] + ) -> Tuple[List[Tuple[float, OrderedDict]], List[float]]: + Remove poisoned local models based on client scores. + + get_malicious_client_idxs() -> List[int]: + Get indices of detected malicious clients. + + set_potential_malicious_clients(potential_malicious_client_idxs: List[int]): + Set potential malicious client indices. + + compute_avg_with_krum(raw_client_grad_list: List[Tuple[float, OrderedDict]]) -> List[float]: + Compute an average feature with Krum-based malicious client detection. + + compute_l2_scores(raw_client_grad_list: List[Tuple[float, OrderedDict]]) -> List[float]: + Compute L2 scores for client models. + + compute_client_cosine_scores(raw_client_grad_list: List[Tuple[float, OrderedDict]]) -> List[float]: + Compute cosine similarity scores between client models. + + _get_importance_feature(raw_client_grad_list: List[Tuple[float, OrderedDict]]) -> List[float]: + Get importance features from raw client gradients. + """ def __init__(self, config): + """ + Initialize the ThreeSigmaKrumDefense. + + Args: + config: Configuration object for defense parameters. + """ self.average = None self.upper_bound = 0 self.malicious_client_idxs = [] @@ -31,6 +76,18 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Perform defense before aggregation. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + extra_auxiliary_info (Any, optional): + Extra auxiliary information (currently unused). + + Returns: + List[Tuple[float, OrderedDict]]: Gradient list after defense. + """ if self.average is None: self.average = self.compute_avg_with_krum(raw_client_grad_list) client_scores = self.compute_l2_scores(raw_client_grad_list) @@ -46,6 +103,19 @@ def defend_before_aggregation( return new_client_models def compute_an_average_feature(self, importance_feature_list): + """ + Remove poisoned local models based on client scores. 
+ + Args: + client_scores (List[float]): List of client scores. + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + + Returns: + Tuple[List[Tuple[float, OrderedDict]], List[float]]: + Tuple containing gradient list after removing poisoned models + and updated client scores. + """ alphas = [1 / len(importance_feature_list)] * len(importance_feature_list) return compute_middle_point(alphas, importance_feature_list) @@ -99,6 +169,13 @@ def compute_an_average_feature(self, importance_feature_list): # return raw_client_grad_list def kick_out_poisoned_local_models(self, client_scores, raw_client_grad_list): + """ + Get indices of detected malicious clients. + + Returns: + List[int]: List of indices of malicious clients. + """ + print(f"upper bound = {self.upper_bound}") # traverse the score list in a reversed order self.malicious_client_idxs = [] @@ -112,12 +189,38 @@ def kick_out_poisoned_local_models(self, client_scores, raw_client_grad_list): return raw_client_grad_list, client_scores def get_malicious_client_idxs(self): + """ + Set potential malicious client indices. + + Args: + potential_malicious_client_idxs: List of potential malicious client indices. + """ return self.malicious_client_idxs def set_potential_malicious_clients(self, potential_malicious_client_idxs): + """ + Compute an average feature with Krum-based malicious client detection. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + + Returns: + List[float]: List representing an average feature. + """ self.potential_malicious_client_idxs = None # potential_malicious_client_idxs todo def compute_avg_with_krum(self, raw_client_grad_list): + """ + Compute L2 scores for client models. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + + Returns: + List[float]: List of L2 scores. + """ importance_feature_list = self._get_importance_feature(raw_client_grad_list) krum_scores = compute_krum_score( importance_feature_list, @@ -133,6 +236,16 @@ def compute_avg_with_krum(self, raw_client_grad_list): return self.compute_an_average_feature(honest_importance_feature_list) def compute_l2_scores(self, raw_client_grad_list): + """ + Compute L2 scores for client models. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + + Returns: + List[float]: List of L2 scores. + """ importance_feature_list = self._get_importance_feature(raw_client_grad_list) scores = [] for feature in importance_feature_list: @@ -141,6 +254,16 @@ def compute_l2_scores(self, raw_client_grad_list): return scores def compute_client_cosine_scores(self, raw_client_grad_list): + """ + Compute cosine similarity scores between client models. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + + Returns: + List[float]: List of cosine similarity scores. + """ importance_feature_list = self._get_importance_feature(raw_client_grad_list) cosine_scores = [] num_client = len(importance_feature_list) @@ -158,6 +281,16 @@ def compute_client_cosine_scores(self, raw_client_grad_list): return cosine_scores def _get_importance_feature(self, raw_client_grad_list): + """ + Get importance features from raw client gradients. 
+ + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + + Returns: + List[float]: List of importance feature vectors. + """ # print(f"raw_client_grad_list = {raw_client_grad_list}") # Foolsgold uses the last layer's gradient/weights as the importance feature. ret_feature_vector_list = [] diff --git a/python/fedml/core/security/defense/wbc_defense.py b/python/fedml/core/security/defense/wbc_defense.py index e65dc4d597..e804a1f8ba 100644 --- a/python/fedml/core/security/defense/wbc_defense.py +++ b/python/fedml/core/security/defense/wbc_defense.py @@ -23,7 +23,36 @@ class WbcDefense(BaseDefenseMethod): + """ + Weight-Based Client Defense for Federated Learning. + + This defense method performs weight-based client defense for federated learning. + + Args: + args: Argument object containing client and batch indices. + + Methods: + run( + raw_client_grad_list: List[Tuple[float, OrderedDict]], + base_aggregation_func: Callable = None, + extra_auxiliary_info: Any = None, + ) -> Dict: + Run the weight-based client defense. + + Attributes: + args: Argument object containing client and batch indices. + client_idx: Index of the client. + batch_idx: Index of the batch. + old_gradient: Dictionary to store old gradients for weight perturbation. + """ + def __init__(self, args): + """ + Initialize the WbcDefense. + + Args: + args: Argument object containing client and batch indices. + """ self.args = args self.client_idx = args.client_idx self.batch_idx = args.batch_idx @@ -35,6 +64,20 @@ def run( base_aggregation_func: Callable = None, extra_auxiliary_info: Any = None, ) -> Dict: + """ + Run the weight-based client defense. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + base_aggregation_func (Callable, optional): + Base aggregation function (currently unused). + extra_auxiliary_info (Any, optional): + Extra auxiliary information (currently unused). + + Returns: + Dict: Dictionary containing aggregated model parameters. + """ num_client = len(raw_client_grad_list) vec_local_w = [ ( @@ -53,7 +96,8 @@ def run( for (k, v) in model_param.items(): if "weight" in k: grad_tensor = ( - raw_client_grad_list[self.client_idx][1][k].cpu().numpy() + raw_client_grad_list[self.client_idx][1][k].cpu( + ).numpy() ) # for testing, simply pre-defin old gradient self.old_gradient[k] = grad_tensor * 0.2 @@ -67,7 +111,8 @@ def run( ) learning_rate = 0.1 new_model_param[k] = torch.from_numpy( - model_param[k].cpu().numpy() + pertubation * learning_rate + model_param[k].cpu().numpy() + + pertubation * learning_rate ) else: new_model_param[k] = model_param[k] @@ -82,7 +127,8 @@ def run( if i != self.client_idx or self.batch_idx == 0: param_list.append(models_param[i]) else: - param_list.append((models_param[self.client_idx][0], new_model_param)) + param_list.append( + (models_param[self.client_idx][0], new_model_param)) logging.info(f"New. param: {param_list[i]}") return base_aggregation_func(self.args, param_list) # avg_params diff --git a/python/fedml/core/security/defense/weak_dp_defense.py b/python/fedml/core/security/defense/weak_dp_defense.py index 06e465c3f4..25c3372c5e 100644 --- a/python/fedml/core/security/defense/weak_dp_defense.py +++ b/python/fedml/core/security/defense/weak_dp_defense.py @@ -9,6 +9,27 @@ class WeakDPDefense(BaseDefenseMethod): + """ + Weak Differential Privacy (DP) Defense for Federated Learning. 
+ + This defense method adds weak differential privacy noise to client gradients to enhance privacy. + + Args: + config: Configuration object containing defense parameters. + + Methods: + run( + raw_client_grad_list: List[Tuple[float, OrderedDict]], + base_aggregation_func: Callable = None, + extra_auxiliary_info: Any = None, + ) -> Dict: + Run the weak DP defense. + + Attributes: + config: Configuration object containing defense parameters. + stddev: Standard deviation for adding noise to gradients. + """ + def __init__(self, config): self.config = config self.stddev = config.stddev # for weak DP defenses @@ -19,6 +40,20 @@ def run( base_aggregation_func: Callable = None, extra_auxiliary_info: Any = None, ) -> Dict: + """ + Run the weak DP defense by adding noise to client gradients. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + base_aggregation_func (Callable, optional): + Base aggregation function (currently unused). + extra_auxiliary_info (Any, optional): + Extra auxiliary information (currently unused). + + Returns: + Dict: Dictionary containing aggregated model parameters with added noise. + """ new_grad_list = [] for (sample_num, local_w) in raw_client_grad_list: new_w = self._add_noise(local_w) @@ -26,6 +61,15 @@ def run( return base_aggregation_func(self.config, new_grad_list) # avg_params def _add_noise(self, param): + """ + Add Gaussian noise to the parameters. + + Args: + param (OrderedDict): Client parameters. + + Returns: + OrderedDict: Parameters with added noise. + """ dp_param = dict() for k in param.keys(): dp_param[k] = param[k] + torch.randn(param[k].size()) * self.stddev diff --git a/python/fedml/core/security/fedml_attacker.py b/python/fedml/core/security/fedml_attacker.py index 6264ccfb1c..34739231bc 100644 --- a/python/fedml/core/security/fedml_attacker.py +++ b/python/fedml/core/security/fedml_attacker.py @@ -11,6 +11,18 @@ class FedMLAttacker: + """ + Represents an attacker in a federated learning system. + + The `FedMLAttacker` class is responsible for managing different types of attacks, including model poisoning, data poisoning, + and data reconstruction attacks, within a federated learning setting. + + Attributes: + _attacker_instance (FedMLAttacker): A singleton instance of the `FedMLAttacker` class. + is_enabled (bool): Whether the attacker is enabled. + attack_type (str): The type of attack being used. + attacker (Any): The specific attacker object. + """ _attacker_instance = None @staticmethod @@ -21,11 +33,31 @@ def get_instance(): return FedMLAttacker._attacker_instance def __init__(self): + """ + Initialize a FedMLAttacker instance. + + This constructor sets up the attacker instance and initializes its properties. + + Attributes: + is_enabled (bool): Whether the attacker is enabled. + attack_type (str): The type of attack being used. + attacker (Any): The specific attacker object. + + """ self.is_enabled = False self.attack_type = None self.attacker = None def init(self, args): + """ + Initialize the attacker with provided arguments. + + This method initializes the attacker based on the provided arguments. + + Args: + args: The arguments used to configure the attacker. + + """ if hasattr(args, "enable_attack") and args.enable_attack: logging.info("------init attack..." + args.attack_type.strip()) self.is_enabled = True @@ -56,13 +88,35 @@ def init(self, args): self.is_enabled = False def is_attack_enabled(self): + """ + Check if the attacker is enabled. 
+ + Returns: + bool: True if the attacker is enabled, False otherwise. + + """ return self.is_enabled def get_attack_types(self): + """ + Get the type of attack. + + Returns: + str: The type of attack being used. + + """ return self.attack_type # --------------- for model poisoning attacks --------------- # def is_model_attack(self): + """ + Check if the attack is a model poisoning attack. + + Returns: + bool: True if it's a model poisoning attack, False otherwise. + + """ + if self.is_attack_enabled() and self.attack_type in [ ATTACK_METHOD_BYZANTINE_ATTACK, BACKDOOR_ATTACK_MODEL_REPLACEMENT ]: @@ -70,6 +124,23 @@ def is_model_attack(self): return False def attack_model(self, raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None): + """ + Attack the model with poisoned gradients. + + This method is used for model poisoning attacks. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples, each containing a weight and a + dictionary of client gradients. + extra_auxiliary_info (Any, optional): Additional auxiliary information for the attack. + + Returns: + Any: The poisoned client gradients. + + Raises: + Exception: If the attacker is not initialized. + + """ if self.attacker is None: raise Exception("attacker is not initialized!") return self.attacker.attack_model(raw_client_grad_list, extra_auxiliary_info) @@ -77,16 +148,48 @@ def attack_model(self, raw_client_grad_list: List[Tuple[float, OrderedDict]], ex # --------------- for data poisoning attacks --------------- # def is_data_poisoning_attack(self): + """ + Check if the attack is a data poisoning attack. + + Returns: + bool: True if it's a data poisoning attack, False otherwise. + + """ if self.is_attack_enabled() and self.attack_type in [ATTACK_LABEL_FLIPPING]: return True return False def is_to_poison_data(self): + """ + Check if data should be poisoned. + + Returns: + bool: True if data should be poisoned, False otherwise. + + Raises: + Exception: If the attacker is not initialized. + + """ if self.attacker is None: raise Exception("attacker is not initialized!") return self.attacker.is_to_poison_data() def poison_data(self, dataset): + """ + Poison the dataset. + + This method is used for data poisoning attacks. + + Args: + dataset: The dataset to be poisoned. + + Returns: + Any: The poisoned dataset. + + Raises: + Exception: If the attacker is not initialized. + + """ if self.attacker is None: raise Exception("attacker is not initialized!") return self.attacker.poison_data(dataset) @@ -94,12 +197,34 @@ def poison_data(self, dataset): # --------------- for data reconstructing attacks --------------- # def is_data_reconstruction_attack(self): + """ + Check if the attack is a data reconstruction attack. + + Returns: + bool: True if it's a data reconstruction attack, False otherwise. + + """ if self.is_attack_enabled() and self.attack_type in [ATTACK_METHOD_DLG]: return True return False def reconstruct_data(self, raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None): + """ + Reconstruct the data from gradients. + + This method is used for data reconstruction attacks. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples, each containing a weight and a + dictionary of client gradients. + extra_auxiliary_info (Any, optional): Additional auxiliary information for the attack. + + Raises: + Exception: If the attacker is not initialized. 
+ + """ if self.attacker is None: raise Exception("attacker is not initialized!") - self.attacker.reconstruct_data(raw_client_grad_list, extra_auxiliary_info=extra_auxiliary_info) - # --------------- for data reconstructing attacks --------------- # \ No newline at end of file + self.attacker.reconstruct_data( + raw_client_grad_list, extra_auxiliary_info=extra_auxiliary_info) + # --------------- for data reconstructing attacks --------------- # diff --git a/python/fedml/core/security/fedml_defender.py b/python/fedml/core/security/fedml_defender.py index d88f8dfcd9..56b688317c 100644 --- a/python/fedml/core/security/fedml_defender.py +++ b/python/fedml/core/security/fedml_defender.py @@ -38,21 +38,64 @@ class FedMLDefender: + """ + A class for managing defense mechanisms in federated learning. + + This class handles the configuration and execution of defense mechanisms to enhance the robustness + of federated learning against adversarial attacks. + + Methods: + get_instance: Get an instance of the FedMLDefender class. + init: Initialize the defense mechanism based on configuration. + is_defense_enabled: Check if defense mechanisms are enabled. + defend: Defend against adversarial attacks on client gradients. + is_defense_on_aggregation: Check if defense occurs during aggregation. + is_defense_before_aggregation: Check if defense occurs before aggregation. + is_defense_after_aggregation: Check if defense occurs after aggregation. + defend_before_aggregation: Apply defense before gradient aggregation. + defend_on_aggregation: Apply defense during gradient aggregation. + defend_after_aggregation: Apply defense after gradient aggregation. + get_malicious_client_idxs: Get the indices of malicious clients. + get_benign_client_idxs: Get the indices of benign clients. + + Attributes: + None + """ + _defender_instance = None @staticmethod def get_instance(): + """ + Get an instance of the FedMLDefender class. + + Returns: + FedMLDefender: An instance of the FedMLDefender class. + """ + if FedMLDefender._defender_instance is None: FedMLDefender._defender_instance = FedMLDefender() return FedMLDefender._defender_instance def __init__(self): + """ + Initialize a FedMLDefender instance. + """ self.is_enabled = False self.defense_type = None self.defender = None def init(self, args): + """ + Initialize the defense mechanism based on configuration. + + Args: + args: The command-line arguments. + + Raises: + Exception: If the defense mechanism type is not defined. + """ if hasattr(args, "enable_defense") and args.enable_defense: self.args = args logging.info("------init defense..." + args.defense_type) @@ -114,6 +157,12 @@ def init(self, args): self.is_enabled = False def is_defense_enabled(self): + """ + Check if defense mechanisms are enabled. + + Returns: + bool: True if defense is enabled, False otherwise. + """ return self.is_enabled def defend( @@ -122,6 +171,21 @@ def defend( base_aggregation_func: Callable = None, extra_auxiliary_info: Any = None, ): + """ + Defend against adversarial attacks on client gradients. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples, each containing a weight and a + dictionary of client gradients. + base_aggregation_func (Callable, optional): The base aggregation function for gradient aggregation. + extra_auxiliary_info (Any, optional): Additional auxiliary information for the defense mechanism. + + Returns: + Any: The defended client gradients or the result of the aggregation function. 
+ + Raises: + Exception: If the defender is not initialized. + """ if self.defender is None: raise Exception("defender is not initialized!") return self.defender.run( @@ -129,9 +193,22 @@ def defend( ) def is_defense_on_aggregation(self): + """ + Check if defense occurs during gradient aggregation. + + Returns: + bool: True if defense occurs during aggregation, False otherwise. + """ return self.is_defense_enabled() and self.defense_type in [DEFENSE_SLSGD, DEFENSE_RFA, DEFENSE_WISE_MEDIAN, DEFENSE_GEO_MEDIAN] def is_defense_before_aggregation(self): + """ + Check if defense occurs before gradient aggregation. + + Returns: + bool: True if defense occurs before aggregation, False otherwise. + """ + return self.is_defense_enabled() and self.defense_type in [ DEFENSE_SLSGD, DEFENSE_FOOLSGOLD, @@ -147,6 +224,13 @@ def is_defense_before_aggregation(self): ] def is_defense_after_aggregation(self): + """ + Check if defense occurs after gradient aggregation. + + Returns: + bool: True if defense occurs after aggregation, False otherwise. + """ + return self.is_defense_enabled() and self.defense_type in [DEFENSE_CRFL, DEFENSE_CCLIP] def defend_before_aggregation( @@ -154,6 +238,20 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Apply defense before gradient aggregation. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples, each containing a weight and a + dictionary of client gradients. + extra_auxiliary_info (Any, optional): Additional auxiliary information for the defense mechanism. + + Returns: + List[Tuple[float, OrderedDict]]: The defended client gradients. + + Raises: + Exception: If the defender is not initialized. + """ if self.defender is None: raise Exception("defender is not initialized!") if self.is_defense_before_aggregation(): @@ -168,6 +266,21 @@ def defend_on_aggregation( base_aggregation_func: Callable = None, extra_auxiliary_info: Any = None, ): + """ + Apply defense during gradient aggregation. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples, each containing a weight and a + dictionary of client gradients. + base_aggregation_func (Callable, optional): The base aggregation function for gradient aggregation. + extra_auxiliary_info (Any, optional): Additional auxiliary information for the defense mechanism. + + Returns: + Any: The defended client gradients or the result of the aggregation function. + + Raises: + Exception: If the defender is not initialized. + """ if self.defender is None: raise Exception("defender is not initialized!") if self.is_defense_on_aggregation(): @@ -177,6 +290,18 @@ def defend_on_aggregation( return base_aggregation_func(args=self.args, raw_grad_list=raw_client_grad_list) def defend_after_aggregation(self, global_model): + """ + Apply defense after gradient aggregation. + + Args: + global_model: The global model after gradient aggregation. + + Returns: + Any: The defended global model or its equivalent. + + Raises: + Exception: If the defender is not initialized. + """ if self.defender is None: raise Exception("defender is not initialized!") if self.is_defense_after_aggregation(): @@ -184,7 +309,26 @@ def defend_after_aggregation(self, global_model): return global_model def get_malicious_client_idxs(self): + """ + Get the indices of malicious clients. + + Returns: + List[int]: A list of indices corresponding to malicious clients. 
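Taken together, the three `is_defense_*` predicates split a defense into hooks that run before, during, and after aggregation. A hedged sketch of how a server loop might order them; the `fedavg` stand-in uses toy scalar parameters, and its keyword names follow the `base_aggregation_func(args=..., raw_grad_list=...)` call shown above:

```python
from fedml.core.security.fedml_defender import FedMLDefender

def fedavg(args, raw_grad_list):
    # stand-in aggregator: weighted average of (weight, state_dict) pairs
    total = sum(w for w, _ in raw_grad_list)
    return {k: sum(w * g[k] for w, g in raw_grad_list) / total
            for k in raw_grad_list[0][1]}

def secure_aggregate(args, raw_client_grad_list):
    defender = FedMLDefender.get_instance()
    if not defender.is_defense_enabled():
        return fedavg(args=args, raw_grad_list=raw_client_grad_list)
    grads = defender.defend_before_aggregation(raw_client_grad_list)
    new_global = defender.defend_on_aggregation(grads, base_aggregation_func=fedavg)
    return defender.defend_after_aggregation(new_global)
```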
+ """ + return self.defender.get_malicious_client_idxs() def get_benign_client_idxs(self, client_idxs): + """ + Get the indices of benign clients from a list of client indices. + + Args: + client_idxs (List[int]): A list of client indices. + + Returns: + List[int]: A list of indices corresponding to benign clients. + + Notes: + This method assumes that malicious clients have been identified using defense mechanisms. + """ return [i for i in client_idxs if i not in self.defender.get_malicious_client_idxs()] diff --git a/python/fedml/cross_device/mnn_server.py b/python/fedml/cross_device/mnn_server.py index 3502929944..599c72978a 100644 --- a/python/fedml/cross_device/mnn_server.py +++ b/python/fedml/cross_device/mnn_server.py @@ -4,15 +4,50 @@ class ServerMNN: + """ + A class representing the server in federated learning using MNN (Mobile Neural Networks). + + This class is responsible for coordinating and aggregating model updates from client devices. + + Args: + args: The command-line arguments. + device: The device for computations. + test_dataloader: The DataLoader for testing data. + model: The federated learning model. + server_aggregator: The server aggregator (optional). + + Attributes: + None + + Methods: + run: Run the server for federated learning. + """ + def __init__(self, args, device, test_dataloader, model, server_aggregator=None): + """ + Initialize a ServerMNN instance. + + Args: + args: The command-line arguments. + device: The device for computations. + test_dataloader: The DataLoader for testing data. + model: The federated learning model. + server_aggregator: The server aggregator (optional). + """ if args.federated_optimizer == "FedAvg": - logging.info("test_data_global.iter_number = {}".format(test_dataloader.iter_number)) + logging.info("test_data_global.iter_number = {}".format( + test_dataloader.iter_number)) fedavg_cross_device( args, 0, args.worker_num, None, device, test_dataloader, model, server_aggregator=server_aggregator ) else: - raise Exception("Exception") + raise Exception("Unsupported federated optimizer") def run(self): + """ + Run the server for federated learning. + + This method coordinates and aggregates model updates from client devices. + """ pass diff --git a/python/fedml/cross_silo/client/client_initializer.py b/python/fedml/cross_silo/client/client_initializer.py index 54fd865710..0b6503b0d7 100644 --- a/python/fedml/cross_silo/client/client_initializer.py +++ b/python/fedml/cross_silo/client/client_initializer.py @@ -20,6 +20,25 @@ def init_client( test_data_local_dict, model_trainer=None, ): + """ + Initialize and run a federated learning client. + + Args: + args: The command-line arguments. + device: The device to perform computations on. + comm: The communication backend. + client_rank: The rank of the client. + client_num: The total number of clients. + model: The federated learning model. + train_data_num: The total number of training data samples. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training data samples. + train_data_local_dict: A dictionary mapping client IDs to their local training data. + test_data_local_dict: A dictionary mapping client IDs to their local testing data. + model_trainer: The model trainer (optional). 
+
+    Returns:
+        None
+    """
     backend = args.backend
 
     trainer_dist_adapter = get_trainer_dist_adapter(
@@ -36,8 +55,8 @@ def init_client(
     if (
         args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL
         or (
-            args.scenario == FEDML_CROSS_SILO_SCENARIO_HORIZONTAL and
-            getattr(args, FEDML_CROSS_SILO_CUSTOMIZED_HIERARCHICAL_KEY, False)
+            args.scenario == FEDML_CROSS_SILO_SCENARIO_HORIZONTAL and
+            getattr(args, FEDML_CROSS_SILO_CUSTOMIZED_HIERARCHICAL_KEY, False)
         )
     ):
         if args.proc_rank_in_silo == 0:
@@ -46,13 +65,16 @@ def init_client(
             )
 
         else:
-            client_manager = get_client_manager_salve(args, trainer_dist_adapter)
+            client_manager = get_client_manager_salve(
+                args, trainer_dist_adapter)
 
     elif args.scenario == FEDML_CROSS_SILO_SCENARIO_HORIZONTAL:
-        client_manager = get_client_manager_master(args, trainer_dist_adapter, comm, client_rank, client_num, backend)
+        client_manager = get_client_manager_master(
+            args, trainer_dist_adapter, comm, client_rank, client_num, backend)
 
     else:
-        raise RuntimeError("we do not support {}. Please check whether this is typo.".format(args.scenario))
+        raise RuntimeError(
+            "we do not support {}. Please check whether this is a typo.".format(args.scenario))
 
     client_manager.run()
 
 
@@ -68,6 +90,23 @@ def get_trainer_dist_adapter(
     test_data_local_dict,
     model_trainer,
 ):
+    """
+    Get a trainer distributed adapter.
+
+    Args:
+        args: The command-line arguments.
+        device: The device to perform computations on.
+        client_rank: The rank of the client.
+        model: The federated learning model.
+        train_data_num: The total number of training data samples.
+        train_data_local_num_dict: A dictionary mapping client IDs to the number of local training data samples.
+        train_data_local_dict: A dictionary mapping client IDs to their local training data.
+        test_data_local_dict: A dictionary mapping client IDs to their local testing data.
+        model_trainer: The model trainer (optional).
+
+    Returns:
+        TrainerDistAdapter: The trainer distributed adapter.
+    """
     return TrainerDistAdapter(
         args,
         device,
@@ -82,10 +121,34 @@ def get_trainer_dist_adapter(
 
 
 def get_client_manager_master(args, trainer_dist_adapter, comm, client_rank, client_num, backend):
+    """
+    Get the federated learning client manager for the master.
+
+    Args:
+        args: The command-line arguments.
+        trainer_dist_adapter: The trainer distributed adapter.
+        comm: The communication backend.
+        client_rank: The rank of the client.
+        client_num: The total number of clients.
+        backend: The communication backend.
+
+    Returns:
+        ClientMasterManager: The federated learning client manager for the master.
+    """
     return ClientMasterManager(args, trainer_dist_adapter, comm, client_rank, client_num, backend)
 
 
 def get_client_manager_salve(args, trainer_dist_adapter):
+    """
+    Get the federated learning client manager for a slave.
+
+    Args:
+        args: The command-line arguments.
+        trainer_dist_adapter: The trainer distributed adapter.
+
+    Returns:
+        ClientSlaveManager: The federated learning client manager for a slave.
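The scenario dispatch in `init_client` reduces to a small decision table; a self-contained sketch (plain strings stand in for the `FEDML_CROSS_SILO_*` constants):

```python
def pick_manager(scenario, customized_hierarchical=False, proc_rank_in_silo=0):
    # hierarchical silos (or horizontal ones opting into customized hierarchy)
    # run one master process and slave processes for the remaining ranks
    if scenario == "hierarchical" or (scenario == "horizontal" and customized_hierarchical):
        return "master" if proc_rank_in_silo == 0 else "slave"
    if scenario == "horizontal":
        return "master"
    raise RuntimeError("we do not support {}. Please check whether this is a typo.".format(scenario))

assert pick_manager("hierarchical", proc_rank_in_silo=0) == "master"
assert pick_manager("hierarchical", proc_rank_in_silo=2) == "slave"
assert pick_manager("horizontal") == "master"
```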
+    """
     from .fedml_client_slave_manager import ClientSlaveManager
 
     return ClientSlaveManager(args, trainer_dist_adapter)
diff --git a/python/fedml/cross_silo/client/client_launcher.py b/python/fedml/cross_silo/client/client_launcher.py
index 1a4831b11e..76ff8ee703 100644
--- a/python/fedml/cross_silo/client/client_launcher.py
+++ b/python/fedml/cross_silo/client/client_launcher.py
@@ -27,25 +27,73 @@ class CrossSiloLauncher:
 
     @staticmethod
     def launch_dist_trainers(torch_client_filename, inputs):
+        """
+        Launch distributed trainers for cross-silo federated learning.
+
+        Args:
+            torch_client_filename (str): The filename of the torch client script to run.
+            inputs (list): A list of input arguments to pass to the torch client script.
+
+        Returns:
+            None
+        """
         # this is only used by the client (DDP or single process), so there is no need to specify the backend.
         args = load_arguments(FEDML_TRAINING_PLATFORM_CROSS_SILO)
         if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL:
-            CrossSiloLauncher._run_cross_silo_hierarchical(args, torch_client_filename, inputs)
+            CrossSiloLauncher._run_cross_silo_hierarchical(
+                args, torch_client_filename, inputs)
         elif args.scenario == FEDML_CROSS_SILO_SCENARIO_HORIZONTAL:
-            CrossSiloLauncher._run_cross_silo_horizontal(args, torch_client_filename, inputs)
+            CrossSiloLauncher._run_cross_silo_horizontal(
+                args, torch_client_filename, inputs)
         else:
-            raise Exception("we do not support {}, check whether this is typo in args.scenario".format(args.scenario))
+            raise Exception(
+                "we do not support {}; check whether this is a typo in args.scenario".format(args.scenario))
 
     @staticmethod
     def _run_cross_silo_horizontal(args, torch_client_filename, inputs):
-        python_path = subprocess.run(["which", "python"], capture_output=True, text=True).stdout.strip()
+        """
+        Run cross-silo federated learning in the horizontal scenario.
+
+        Args:
+            args: The command-line arguments.
+            torch_client_filename (str): The filename of the torch client script to run.
+            inputs (list): A list of input arguments to pass to the torch client script.
+
+        Returns:
+            None
+        """
+
+        python_path = subprocess.run(
+            ["which", "python"], capture_output=True, text=True).stdout.strip()
         process_arguments = [python_path, torch_client_filename] + inputs
         subprocess.run(process_arguments)
 
     @staticmethod
     def _run_cross_silo_hierarchical(args, torch_client_filename, inputs):
+        """
+        Run cross-silo federated learning in the hierarchical scenario.
+
+        Args:
+            args: The command-line arguments.
+            torch_client_filename (str): The filename of the torch client script to run.
+            inputs (list): A list of input arguments to pass to the torch client script.
+
+        Returns:
+            None
+        """
+
        def get_torchrun_arguments(node_rank):
-            torchrun_path = subprocess.run(["which", "torchrun"], capture_output=True, text=True).stdout.strip()
+            """
+            Get the torchrun command arguments for launching on each node.
+
+            Args:
+                node_rank (int): The rank of the current node.
+
+            Returns:
+                list: List of command arguments for torchrun.
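Per node, `get_torchrun_arguments` assembles a `torchrun` command line. A hedged sketch of such an invocation; the flags inside the elided list above are not shown in this hunk, so the ones here are illustrative stand-ins:

```python
import shutil

def build_torchrun_cmd(node_rank, n_nodes, n_proc_per_node,
                       master_addr, master_port, script, inputs):
    torchrun_path = shutil.which("torchrun")  # same lookup as the `which` call above
    return [
        torchrun_path,
        f"--nnodes={n_nodes}",
        f"--nproc_per_node={n_proc_per_node}",
        f"--node_rank={node_rank}",
        f"--master_addr={master_addr}",
        f"--master_port={master_port}",
        script,
    ] + inputs

cmd = build_torchrun_cmd(0, 2, 4, "10.0.0.1", 29500,
                         "torch_client.py", ["--cf", "config.yaml"])
# subprocess.run(cmd) would then be issued once per node in the silo
```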
+        """
+        torchrun_path = subprocess.run(
+            ["which", "torchrun"], capture_output=True, text=True).stdout.strip()
 
             return [
                 torchrun_path,
@@ -58,8 +106,10 @@ def get_torchrun_arguments(node_rank):
                 torch_client_filename,
             ] + inputs
 
-        network_interface = None if not hasattr(args, "network_interface") else args.network_interface
-        print(f"Using network interface {network_interface} for process group and TRPC communication")
+        network_interface = None if not hasattr(
+            args, "network_interface") else args.network_interface
+        print(
+            f"Using network interface {network_interface} for process group and TRPC communication")
         env_variables = {
             "OMP_NUM_THREADS": "4",
         }
@@ -78,7 +128,8 @@ def get_torchrun_arguments(node_rank):
         device_type = get_device_type(args)
         if torch.cuda.is_available() and device_type == "gpu":
             gpu_count = torch.cuda.device_count()
-            print(f"Using number of GPUs ({gpu_count}) as number of processeses.")
+            print(
+                f"Using number of GPUs ({gpu_count}) as number of processes.")
             args.n_proc_per_node = gpu_count
         else:
             print(f"Using number 1 as number of processeses.")
@@ -95,7 +146,8 @@ def get_torchrun_arguments(node_rank):
         else:
             print(f"Automatic Client Launcher")
 
-        which_pdsh = subprocess.run(["which", "pdsh"], capture_output=True, text=True).stdout.strip()
+        which_pdsh = subprocess.run(
+            ["which", "pdsh"], capture_output=True, text=True).stdout.strip()
 
         if not which_pdsh:
             raise Exception(
diff --git a/python/fedml/cross_silo/client/fedml_client_master_manager.py b/python/fedml/cross_silo/client/fedml_client_master_manager.py
index 938ce2d152..51da33f9e5 100644
--- a/python/fedml/cross_silo/client/fedml_client_master_manager.py
+++ b/python/fedml/cross_silo/client/fedml_client_master_manager.py
@@ -24,6 +24,17 @@ class ClientMasterManager(FedMLCommManager):
     RUN_FINISHED_STATUS_FLAG = "FINISHED"
 
     def __init__(self, args, trainer_dist_adapter, comm=None, rank=0, size=0, backend="MPI"):
+        """
+        Initialize the ClientMasterManager.
+
+        Args:
+            args: The command-line arguments.
+            trainer_dist_adapter: The trainer distributed adapter.
+            comm: The communication backend.
+            rank: The rank of the client.
+            size: The total number of clients.
+            backend: The communication backend (default is "MPI").
+        """
         super().__init__(args, comm, rank, size, backend)
         self.trainer_dist_adapter = trainer_dist_adapter
         self.args = args
@@ -50,21 +61,42 @@ def __init__(self, args, trainer_dist_adapter, comm=None, rank=0, size=0, backen
 
     @property
     def use_customized_hierarchical(self) -> bool:
+        """
+        Check if customized hierarchical cross-silo is enabled.
+
+        Returns:
+            bool: True if customized hierarchical is enabled, False otherwise.
+        """
         return getattr(self.args, FEDML_CROSS_SILO_CUSTOMIZED_HIERARCHICAL_KEY, False)
 
     @property
     def has_customized_sync_process_group(self) -> bool:
+        """
+        Check if a customized sync process group method is available in the trainer.
+
+        Returns:
+            bool: True if a customized sync process group method is available, False otherwise.
+        """
         return check_method_override(
             cls_obj=self.trainer_dist_adapter.trainer.trainer,
             method_name="sync_process_group"
         )
 
     def is_main_process(self):
+        """
+        Check if the current process is the main process.
+
+        Returns:
+            bool: True if the current process is the main process, False otherwise.
+ """ return getattr(self.trainer_dist_adapter, "trainer", None) is None or \ getattr(self.trainer_dist_adapter.trainer, "trainer", None) is None or \ self.trainer_dist_adapter.trainer.trainer.is_main_process() def register_message_receive_handlers(self): + """ + Register message receive handlers for various message types. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_message_connection_ready ) @@ -73,7 +105,8 @@ def register_message_receive_handlers(self): MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, self.handle_message_check_status ) - self.register_message_receive_handler(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init) + self.register_message_receive_handler( + MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init) self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.handle_message_receive_model_from_server, ) @@ -83,6 +116,12 @@ def register_message_receive_handlers(self): ) def handle_message_connection_ready(self, msg_params): + """ + Handle the connection-ready message. + + Args: + msg_params (dict): Parameters of the message. + """ if not self.has_sent_online_msg: self.has_sent_online_msg = True self.send_client_status(0) @@ -90,15 +129,28 @@ def handle_message_connection_ready(self, msg_params): mlops.log_sys_perf(self.args) def handle_message_check_status(self, msg_params): + """ + Handle the check-client-status message. + + Args: + msg_params (dict): Parameters of the message. + """ self.send_client_status(0) def handle_message_init(self, msg_params): + """ + Handle the initialization message. + + Args: + msg_params (dict): Parameters of the message. + """ if self.is_inited: return self.is_inited = True - global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) + global_model_params = msg_params.get( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS) data_silo_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) logging.info("data_silo_index = %s" % str(data_silo_index)) @@ -107,10 +159,12 @@ def handle_message_init(self, msg_params): self.report_training_status(MyMessage.MSG_MLOPS_CLIENT_STATUS_TRAINING) if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL: - global_model_params = convert_model_params_to_ddp(global_model_params) + global_model_params = convert_model_params_to_ddp( + global_model_params) self.sync_process_group(0, global_model_params, data_silo_index) elif self.use_customized_hierarchical: - self.customized_sync_process_group(0, global_model_params, data_silo_index) + self.customized_sync_process_group( + 0, global_model_params, data_silo_index) self.trainer_dist_adapter.update_dataset(int(data_silo_index)) self.trainer_dist_adapter.update_model(global_model_params) @@ -120,6 +174,12 @@ def handle_message_init(self, msg_params): self.round_idx += 1 def handle_message_receive_model_from_server(self, msg_params): + """ + Handle the received model from the server. + + Args: + msg_params (dict): Parameters of the message. 
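`register_message_receive_handlers` follows the usual message-type-to-callback registry; a self-contained sketch of the dispatch it sets up:

```python
handlers = {}

def register_handler(msg_type, handler):
    handlers[msg_type] = handler

def dispatch(msg_type, msg_params):
    # the comm layer calls this when a message of msg_type arrives
    handlers[msg_type](msg_params)

register_handler("S2C_INIT_CONFIG", lambda p: print("init:", p))
dispatch("S2C_INIT_CONFIG", {"client_index": 3})  # init: {'client_index': 3}
```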
+ """ logging.info("handle_message_receive_model_from_server.") model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -128,50 +188,87 @@ def handle_message_receive_model_from_server(self, msg_params): model_params = convert_model_params_to_ddp(model_params) self.sync_process_group(self.round_idx, model_params, client_index) elif self.use_customized_hierarchical: - self.customized_sync_process_group(self.round_idx, model_params, client_index) + self.customized_sync_process_group( + self.round_idx, model_params, client_index) self.trainer_dist_adapter.update_dataset(int(client_index)) - logging.info("current round index {}, total rounds {}".format(self.round_idx, self.num_rounds)) + logging.info("current round index {}, total rounds {}".format( + self.round_idx, self.num_rounds)) self.trainer_dist_adapter.update_model(model_params) if self.round_idx < self.num_rounds: self.__train() self.round_idx += 1 else: mlops.stop_sys_perf() - self.send_client_status(0, ClientMasterManager.RUN_FINISHED_STATUS_FLAG) + self.send_client_status( + 0, ClientMasterManager.RUN_FINISHED_STATUS_FLAG) if self.is_main_process(): mlops.log_training_finished_status() self.finish() def handle_message_finish(self, msg_params): + """ + Handle the finish message. + + Args: + msg_params (dict): Parameters of the message. + """ logging.info(" ====================cleanup ====================") self.cleanup() def cleanup(self): - self.send_client_status(0, ClientMasterManager.RUN_FINISHED_STATUS_FLAG) + self.send_client_status( + 0, ClientMasterManager.RUN_FINISHED_STATUS_FLAG) if self.is_main_process(): mlops.log_training_finished_status() self.finish() def send_model_to_server(self, receive_id, weights, local_sample_num): + """ + Send the model to the server. + + Args: + receive_id: The ID of the entity receiving the model. + weights: The model weights to send. + local_sample_num: The number of local training samples. + + Returns: + None + """ if self.is_main_process(): tick = time.time() - mlops.event("comm_c2s", event_started=True, event_value=str(self.round_idx)) - message = Message(MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.client_real_id, receive_id) + mlops.event("comm_c2s", event_started=True, + event_value=str(self.round_idx)) + message = Message( + MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.client_real_id, receive_id) message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights) - message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num) + message.add_params( + MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num) self.send_message(message) - MLOpsProfilerEvent.log_to_wandb({"Communication/Send_Total": time.time() - tick}) + MLOpsProfilerEvent.log_to_wandb( + {"Communication/Send_Total": time.time() - tick}) mlops.log_client_model_info( self.round_idx + 1, self.num_rounds, model_url=message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL), ) def send_client_status(self, receive_id, status=ONLINE_STATUS_FLAG): + """ + Send the client status to another entity. + + Args: + receive_id: The ID of the entity receiving the status. + status (str): The status to send (default is "ONLINE"). 
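`send_model_to_server` and `send_client_status` both build a `Message` and attach key/value parameters before sending. A self-contained stand-in mirroring those calls (the real class lives in fedml's communication layer; this toy version only reproduces the interface used above):

```python
class Message:
    def __init__(self, msg_type, sender_id, receiver_id):
        self.params = {"type": msg_type, "sender": sender_id, "receiver": receiver_id}

    def add_params(self, key, value):
        self.params[key] = value

    def get(self, key):
        return self.params.get(key)

msg = Message("C2S_SEND_MODEL_TO_SERVER", sender_id=1, receiver_id=0)
msg.add_params("model_params", {"w": [0.1, 0.2]})
msg.add_params("num_samples", 128)
print(msg.get("num_samples"))  # 128
```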
+ + Returns: + None + """ if self.is_main_process(): logging.info("send_client_status") - logging.info("self.client_real_id = {}".format(self.client_real_id)) - message = Message(MyMessage.MSG_TYPE_C2S_CLIENT_STATUS, self.client_real_id, receive_id) + logging.info("self.client_real_id = {}".format( + self.client_real_id)) + message = Message(MyMessage.MSG_TYPE_C2S_CLIENT_STATUS, + self.client_real_id, receive_id) sys_name = platform.system() if sys_name == "Darwin": sys_name = "Mac" @@ -182,11 +279,21 @@ def send_client_status(self, receive_id, status=ONLINE_STATUS_FLAG): message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, sys_name) if getattr(self.args, "using_mlops", False) and status == ClientMasterManager.RUN_FINISHED_STATUS_FLAG: - mlops.log_server_payload(self.args.run_id, self.client_real_id, json.dumps(message.get_params())) + mlops.log_server_payload( + self.args.run_id, self.client_real_id, json.dumps(message.get_params())) else: self.send_message(message) def report_training_status(self, status): + """ + Report the training status to MLOps. + + Args: + status: The training status to report. + + Returns: + None + """ mlops.log_training_status(status) def sync_process_group( @@ -196,12 +303,25 @@ def sync_process_group( client_index: Optional[int] = None, src: int = 0 ) -> None: + """ + Synchronize the process group for hierarchical cross-silo scenarios. + + Args: + round_idx (int): The round index. + model_params: The model parameters. + client_index (int): The client index. + src (int): The source index. + + Returns: + None + """ logging.info("sending round number to pg") round_number = [round_idx, model_params, client_index] dist.broadcast_object_list( round_number, src=src, group=self.trainer_dist_adapter.process_group_manager.get_process_group(), ) - logging.info("round number %d broadcast to process group" % round_number[0]) + logging.info("round number %d broadcast to process group" % + round_number[0]) def customized_sync_process_group( self, @@ -210,6 +330,18 @@ def customized_sync_process_group( client_index: Optional[int] = None, src: int = 0 ) -> None: + """ + Synchronize the process group using a customized method for hierarchical cross-silo scenarios. + + Args: + round_idx (int): The round index. + model_params: The model parameters. + client_index (int): The client index. + src (int): The source index. + + Returns: + None + """ trainer = self.trainer_dist_adapter.trainer.trainer trainer_class_name = trainer.__class__.__name__ @@ -222,13 +354,23 @@ def customized_sync_process_group( trainer.sync_process_group(round_idx, model_params, client_index, src) def __train(self): - logging.info("#######training########### round_id = %d" % self.round_idx) + """ + Perform the training process. 
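`sync_process_group` relies on `torch.distributed.broadcast_object_list` to push `[round_idx, model_params, client_index]` from rank 0 to every trainer process in the silo. A runnable single-process demonstration of that primitive:

```python
import os
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

payload = [7, {"w": [0.1, 0.2]}, 3]  # round_idx, model_params, client_index
dist.broadcast_object_list(payload, src=0)
print(payload)  # every rank sees the same payload after the call

dist.destroy_process_group()
```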
- mlops.event("train", event_started=True, event_value=str(self.round_idx)) + Returns: + None + """ + logging.info("#######training########### round_id = %d" % + self.round_idx) - weights, local_sample_num = self.trainer_dist_adapter.train(self.round_idx) + mlops.event("train", event_started=True, + event_value=str(self.round_idx)) - mlops.event("train", event_started=False, event_value=str(self.round_idx)) + weights, local_sample_num = self.trainer_dist_adapter.train( + self.round_idx) + + mlops.event("train", event_started=False, + event_value=str(self.round_idx)) # the current model is still DDP-wrapped under cross-silo-hi setting if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL: @@ -237,4 +379,10 @@ def __train(self): self.send_model_to_server(0, weights, local_sample_num) def run(self): + """ + Run the client manager. + + Returns: + None + """ super().run() diff --git a/python/fedml/cross_silo/client/fedml_client_slave_manager.py b/python/fedml/cross_silo/client/fedml_client_slave_manager.py index a5320fed95..401f18672c 100644 --- a/python/fedml/cross_silo/client/fedml_client_slave_manager.py +++ b/python/fedml/cross_silo/client/fedml_client_slave_manager.py @@ -8,6 +8,13 @@ class ClientSlaveManager: def __init__(self, args, trainer_dist_adapter): + """ + Initialize a federated learning client manager for a slave. + + Args: + args: The command-line arguments. + trainer_dist_adapter: The trainer distributed adapter. + """ self.trainer_dist_adapter = trainer_dist_adapter self.args = args self.round_idx = 0 @@ -31,10 +38,22 @@ def __init__(self, args, trainer_dist_adapter): @property def use_customized_hierarchical(self) -> bool: + """ + Determine whether customized hierarchical cross-silo is enabled. + + Returns: + bool: True if customized hierarchical cross-silo is enabled, False otherwise. + """ return getattr(self.args, FEDML_CROSS_SILO_CUSTOMIZED_HIERARCHICAL_KEY, False) @property def has_customized_await_sync_process_group(self) -> bool: + """ + Check if the trainer has a customized "await_sync_process_group" method. + + Returns: + bool: True if the method is overridden, False otherwise. + """ return check_method_override( cls_obj=self.trainer_dist_adapter.trainer.trainer, method_name="await_sync_process_group" @@ -42,12 +61,21 @@ def has_customized_await_sync_process_group(self) -> bool: @property def has_customized_cleanup_process_group(self) -> bool: + """ + Check if the trainer has a customized "cleanup_process_group" method. + + Returns: + bool: True if the method is overridden, False otherwise. + """ return check_method_override( cls_obj=self.trainer_dist_adapter.trainer.trainer, method_name="cleanup_process_group" ) def train(self): + """ + Perform a training round for the federated learning client. + """ if self.use_customized_hierarchical: [round_idx, model_params, client_index] = self.customized_await_sync_process_group() else: @@ -67,6 +95,9 @@ def train(self): self.trainer_dist_adapter.train(self.round_idx) def finish(self): + """ + Finish the federated learning client's training process. + """ if self.use_customized_hierarchical: self.customized_cleanup_process_group() else: @@ -78,6 +109,16 @@ def finish(self): self.finished = True def await_sync_process_group(self, src: int = 0) -> list: + """ + Await synchronization of the process group. + + Args: + src (int): The source rank for synchronization. + + Returns: + list: A list containing round number, model parameters, and client index. 
+ """ + logging.info("process %d waiting for round number" % dist.get_rank()) objects = [None, None, None] dist.broadcast_object_list( @@ -87,6 +128,15 @@ def await_sync_process_group(self, src: int = 0) -> list: return objects def customized_await_sync_process_group(self, src: int = 0) -> list: + """ + Perform a customized await synchronization of the process group. + + Args: + src (int): The source rank for synchronization. + + Returns: + list: A list containing round number, model parameters, and client index. + """ trainer = self.trainer_dist_adapter.trainer.trainer trainer_class_name = trainer.__class__.__name__ @@ -99,10 +149,16 @@ def customized_await_sync_process_group(self, src: int = 0) -> list: return trainer.await_sync_process_group(src) def customized_cleanup_process_group(self) -> None: + """ + Perform a customized cleanup of the process group. + """ trainer = self.trainer_dist_adapter.trainer.trainer if self.has_customized_cleanup_process_group: trainer.cleanup_process_group() def run(self): + """ + Run the federated learning client manager. + """ while not self.finished: self.train() diff --git a/python/fedml/cross_silo/client/fedml_trainer.py b/python/fedml/cross_silo/client/fedml_trainer.py index 827644cc42..8244d26766 100755 --- a/python/fedml/cross_silo/client/fedml_trainer.py +++ b/python/fedml/cross_silo/client/fedml_trainer.py @@ -6,6 +6,41 @@ class FedMLTrainer(object): + """ + A class representing a Federated Machine Learning Trainer. + + This class manages the training process for federated learning on a client. + + Args: + client_index: The index of the client. + train_data_local_dict: A dictionary mapping client IDs to their local training data. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training data samples. + test_data_local_dict: A dictionary mapping client IDs to their local testing data. + train_data_num: The total number of training data samples. + device: The device for computations. + args: The command-line arguments. + model_trainer: The model trainer. + + Attributes: + trainer: The model trainer. + client_index: The index of the client. + train_data_local_dict: A dictionary mapping client IDs to their local training data. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training data samples. + test_data_local_dict: A dictionary mapping client IDs to their local testing data. + all_train_data_num: The total number of training data samples. + train_local: The local training data for the client. + local_sample_number: The number of local training data samples. + test_local: The local testing data for the client. + device: The device for computations. + args: The command-line arguments. + + Methods: + update_model: Update the federated learning model with new weights. + update_dataset: Update the local dataset for training. + train: Train the federated learning model for a specified round. + test: Test the federated learning model. + """ + def __init__( self, client_index, @@ -17,12 +52,26 @@ def __init__( args, model_trainer, ): + """ + Initialize a Federated Machine Learning Trainer. + + Args: + client_index: The index of the client. + train_data_local_dict: A dictionary mapping client IDs to their local training data. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training data samples. + test_data_local_dict: A dictionary mapping client IDs to their local testing data. + train_data_num: The total number of training data samples. 
+ device: The device for computations. + args: The command-line arguments. + model_trainer: The model trainer. + """ self.trainer = model_trainer self.client_index = client_index if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL: - self.train_data_local_dict = split_data_for_dist_trainers(train_data_local_dict, args.n_proc_in_silo) + self.train_data_local_dict = split_data_for_dist_trainers( + train_data_local_dict, args.n_proc_in_silo) else: self.train_data_local_dict = train_data_local_dict @@ -38,9 +87,21 @@ def __init__( self.args.device = device def update_model(self, weights): + """ + Update the federated learning model with new weights. + + Args: + weights: The new model weights. + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """ + Update the local dataset for training. + + Args: + client_index: The index of the client. + """ self.client_index = client_index if self.train_data_local_dict is not None: @@ -61,24 +122,45 @@ def update_dataset(self, client_index): else: self.test_local = None - self.trainer.update_dataset(self.train_local, self.test_local, self.local_sample_number) + self.trainer.update_dataset( + self.train_local, self.test_local, self.local_sample_number) def train(self, round_idx=None): + """ + Train the federated learning model for a specified round. + + Args: + round_idx: The index of the training round (optional). + + Returns: + tuple: A tuple containing weights and the number of local training data samples. + """ self.args.round_idx = round_idx tick = time.time() - self.trainer.on_before_local_training(self.train_local, self.device, self.args) + self.trainer.on_before_local_training( + self.train_local, self.device, self.args) self.trainer.train(self.train_local, self.device, self.args) - self.trainer.on_after_local_training(self.train_local, self.device, self.args) + self.trainer.on_after_local_training( + self.train_local, self.device, self.args) - MLOpsProfilerEvent.log_to_wandb({"Train/Time": time.time() - tick, "round": round_idx}) + MLOpsProfilerEvent.log_to_wandb( + {"Train/Time": time.time() - tick, "round": round_idx}) weights = self.trainer.get_model_params() # transform Tensor to list return weights, self.local_sample_number def test(self): + """ + Test the federated learning model. + + Returns: + tuple: A tuple containing training accuracy, training loss, the number of training samples, + testing accuracy, testing loss, and the number of testing samples. + """ # train data - train_metrics = self.trainer.test(self.train_local, self.device, self.args) + train_metrics = self.trainer.test( + self.train_local, self.device, self.args) train_tot_correct, train_num_sample, train_loss = ( train_metrics["test_correct"], train_metrics["test_total"], @@ -86,7 +168,8 @@ def test(self): ) # test data - test_metrics = self.trainer.test(self.test_local, self.device, self.args) + test_metrics = self.trainer.test( + self.test_local, self.device, self.args) test_tot_correct, test_num_sample, test_loss = ( test_metrics["test_correct"], test_metrics["test_total"], diff --git a/python/fedml/cross_silo/client/fedml_trainer_dist_adapter.py b/python/fedml/cross_silo/client/fedml_trainer_dist_adapter.py index 60383d31cf..27ae2d441b 100644 --- a/python/fedml/cross_silo/client/fedml_trainer_dist_adapter.py +++ b/python/fedml/cross_silo/client/fedml_trainer_dist_adapter.py @@ -7,6 +7,38 @@ class TrainerDistAdapter: + """ + A class representing a Trainer Distribution Adapter for federated learning. 
+ + This adapter facilitates training a federated learning model with distributed computing support. + + Args: + args: The command-line arguments. + device: The device for computations. + client_rank: The rank of the client. + model: The federated learning model. + train_data_num: The total number of training data samples. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training data samples. + train_data_local_dict: A dictionary mapping client IDs to their local training data. + test_data_local_dict: A dictionary mapping client IDs to their local testing data. + model_trainer: The model trainer (optional). + + Attributes: + process_group_manager: The process group manager for distributed training. + client_index: The index of the client. + client_rank: The rank of the client. + device: The device for computations. + trainer: The federated learning trainer. + args: The command-line arguments. + + Methods: + get_trainer: Get the federated learning trainer. + train: Train the federated learning model for a round. + update_model: Update the federated learning model with new parameters. + update_dataset: Update the dataset for training. + cleanup_pg: Clean up the process group for distributed training. + """ + def __init__( self, args, @@ -19,11 +51,26 @@ def __init__( test_data_local_dict, model_trainer, ): + """ + Initialize a Trainer Distribution Adapter. + + Args: + args: The command-line arguments. + device: The device for computations. + client_rank: The rank of the client. + model: The federated learning model. + train_data_num: The total number of training data samples. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training data samples. + train_data_local_dict: A dictionary mapping client IDs to their local training data. + test_data_local_dict: A dictionary mapping client IDs to their local testing data. + model_trainer: The model trainer (optional). + """ ml_engine_adapter.model_to_device(args, model, device) if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL: - self.process_group_manager, model = ml_engine_adapter.model_ddp(args, model, device) + self.process_group_manager, model = ml_engine_adapter.model_ddp( + args, model, device) if model_trainer is None: model_trainer = create_model_trainer(model, args) @@ -62,6 +109,22 @@ def get_trainer( args, model_trainer, ): + """ + Get the federated learning trainer. + + Args: + client_index: The index of the client. + train_data_local_dict: A dictionary mapping client IDs to their local training data. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training data samples. + test_data_local_dict: A dictionary mapping client IDs to their local testing data. + train_data_num: The total number of training data samples. + device: The device for computations. + args: The command-line arguments. + model_trainer: The model trainer. + + Returns: + FedMLTrainer: The federated learning trainer. + """ return FedMLTrainer( client_index, train_data_local_dict, @@ -74,17 +137,41 @@ def get_trainer( ) def train(self, round_idx): + """ + Train the federated learning model for a round. + + Args: + round_idx: The index of the training round. + + Returns: + tuple: A tuple containing weights and local sample number. + """ weights, local_sample_num = self.trainer.train(round_idx) return weights, local_sample_num def update_model(self, model_params): + """ + Update the federated learning model with new parameters. 
+ + Args: + model_params: The new model parameters. + """ self.trainer.update_model(model_params) def update_dataset(self, client_index=None): + """ + Update the dataset for training. + + Args: + client_index: The index of the client (optional). + """ _client_index = client_index or self.client_index self.trainer.update_dataset(int(_client_index)) def cleanup_pg(self): + """ + Clean up the process group for distributed training. + """ if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL: logging.info( "Cleaningup process group for client %s in silo %s" diff --git a/python/fedml/cross_silo/client/process_group_manager.py b/python/fedml/cross_silo/client/process_group_manager.py index 92519c6cc4..571ad3c2ab 100644 --- a/python/fedml/cross_silo/client/process_group_manager.py +++ b/python/fedml/cross_silo/client/process_group_manager.py @@ -6,6 +6,23 @@ class ProcessGroupManager: + """ + A class for managing the process group for distributed training. + + This class initializes and manages the process group for distributed training using PyTorch's distributed library. + + Args: + rank (int): The rank of the current process. + world_size (int): The total number of processes in the group. + master_address (str): The address of the master node for coordination. + master_port (int): The port number for coordination with the master node. + only_gpu (bool): Whether to use NCCL backend for GPU-based communication. + + Methods: + cleanup: Clean up the process group and release resources. + get_process_group: Get the initialized process group. + """ + def __init__(self, rank, world_size, master_address, master_port, only_gpu): logging.info("Start process group") logging.info( @@ -17,10 +34,13 @@ def __init__(self, rank, world_size, master_address, master_port, only_gpu): os.environ["WORLD_SIZE"] = str(world_size) os.environ["RANK"] = str(rank) - env_dict = {key: os.environ[key] for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE",)} - logging.info(f"[{os.getpid()}] Initializing process group with: {env_dict}") + env_dict = {key: os.environ[key] for key in ( + "MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE",)} + logging.info( + f"[{os.getpid()}] Initializing process group with: {env_dict}") - backend = dist.Backend.NCCL if (only_gpu and torch.cuda.is_available()) else dist.Backend.GLOO + backend = dist.Backend.NCCL if ( + only_gpu and torch.cuda.is_available()) else dist.Backend.GLOO logging.info(f"Process group backend: {backend}") # initialize the process group @@ -31,7 +51,16 @@ def __init__(self, rank, world_size, master_address, master_port, only_gpu): logging.info("Initiated") def cleanup(self): + """ + Clean up the process group and release associated resources. + """ dist.destroy_process_group() def get_process_group(self): + """ + Get the initialized process group. + + Returns: + dist.ProcessGroup: The initialized process group. + """ return self.messaging_pg diff --git a/python/fedml/cross_silo/client/utils.py b/python/fedml/cross_silo/client/utils.py index 960aa5e3ac..308cc5b38e 100644 --- a/python/fedml/cross_silo/client/utils.py +++ b/python/fedml/cross_silo/client/utils.py @@ -3,25 +3,54 @@ # ref: https://discuss.pytorch.org/t/failed-to-load-model-trained-by-ddp-for-inference/84841/2?u=amir_zsh def convert_model_params_from_ddp(ddp_model_params): - model_params = OrderedDict() - for k, v in ddp_model_params.items(): - name = k[7:] # remove 'module.' 
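`ProcessGroupManager` picks NCCL only when every process is GPU-backed, and otherwise falls back to GLOO; a minimal runnable check of that choice:

```python
import torch
import torch.distributed as dist

def pick_backend(only_gpu: bool) -> str:
    # NCCL needs CUDA on every participant; GLOO works anywhere
    return dist.Backend.NCCL if (only_gpu and torch.cuda.is_available()) else dist.Backend.GLOO

print(pick_backend(only_gpu=False))  # gloo on any machine
```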
of DataParallel/DistributedDataParallel - model_params[name] = v - return model_params + """ + Convert model parameters from DataParallel/DistributedDataParallel format to a regular model format. + Args: + ddp_model_params (dict): Model parameters in DataParallel/DistributedDataParallel format. -def convert_model_params_to_ddp(ddp_model_params): + Returns: + OrderedDict: Model parameters in the regular format. + """ model_params = OrderedDict() for k, v in ddp_model_params.items(): - name = f"module.{k}" # add 'module.' of DataParallel/DistributedDataParallel + name = k[7:] # Remove 'module.' of DataParallel/DistributedDataParallel model_params[name] = v return model_params +def convert_model_params_to_ddp(model_params): + """ + Convert model parameters from a regular format to DataParallel/DistributedDataParallel format. + + Args: + model_params (dict): Model parameters in the regular format. + + Returns: + OrderedDict: Model parameters in DataParallel/DistributedDataParallel format. + """ + ddp_model_params = OrderedDict() + for k, v in model_params.items(): + # Add 'module.' for DataParallel/DistributedDataParallel + name = f"module.{k}" + ddp_model_params[name] = v + return ddp_model_params + + def check_method_override(cls_obj, method_name: str) -> bool: - # check if method has been overriden by class + """ + Check if a method has been overridden by a class. + + Args: + cls_obj (object): The class object. + method_name (str): The name of the method to check for override. + + Returns: + bool: True if the method has been overridden, False otherwise. + """ + # Check if method has been overridden by class return ( - method_name in cls_obj.__class__.__dict__ and - hasattr(cls_obj, method_name) and - callable(getattr(cls_obj, method_name)) + method_name in cls_obj.__class__.__dict__ and + hasattr(cls_obj, method_name) and + callable(getattr(cls_obj, method_name)) ) From 133597755c7ac360d71e84e625a013ce61bc48c6 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Thu, 21 Sep 2023 11:45:48 +0530 Subject: [PATCH 30/70] add --- python/fedml/core/data/noniid_partition.py | 42 +- python/fedml/core/mpc/lightsecagg.py | 173 ++++++- python/fedml/core/mpc/secagg.py | 421 ++++++++++++++++- .../fedml/core/schedule/runtime_estimate.py | 56 +++ .../core/schedule/seq_train_scheduler.py | 121 ++++- .../core/security/attack/backdoor_attack.py | 86 +++- .../core/security/attack/byzantine_attack.py | 37 ++ .../fedml/core/security/attack/dlg_attack.py | 55 ++- .../attack/edge_case_backdoor_attack.py | 20 + .../security/attack/invert_gradient_attack.py | 430 +++++++++++++++--- .../security/attack/label_flipping_attack.py | 37 +- .../fedml/core/security/attack/lazy_worker.py | 73 ++- .../model_replacement_backdoor_attack.py | 26 ++ .../revealing_labels_from_gradients_attack.py | 63 +++ .../common/attack_defense_data_loader.py | 36 +- python/fedml/core/security/common/bucket.py | 13 + python/fedml/core/security/common/net.py | 22 + python/fedml/core/security/common/utils.py | 192 +++++++- .../server_mnn/fedml_aggregator.py | 175 ++++++- .../server_mnn/fedml_server_manager.py | 242 ++++++++-- .../cross_device/server_mnn/server_mnn_api.py | 53 ++- 21 files changed, 2177 insertions(+), 196 deletions(-) diff --git a/python/fedml/core/data/noniid_partition.py b/python/fedml/core/data/noniid_partition.py index 368710ddd9..065102c063 100644 --- a/python/fedml/core/data/noniid_partition.py +++ b/python/fedml/core/data/noniid_partition.py @@ -55,7 +55,8 @@ def non_iid_partition_with_dirichlet_distribution( ) else: idx_k = 
np.asarray( - [np.any(label_list[i] == cat) for i in range(len(label_list))] + [np.any(label_list[i] == cat) + for i in range(len(label_list))] ) # Get the indices of images that have category = c @@ -87,6 +88,26 @@ def non_iid_partition_with_dirichlet_distribution( def partition_class_samples_with_dirichlet_distribution( N, alpha, client_num, idx_batch, idx_k ): + """ + Partition class samples using the Dirichlet distribution. + + Parameters: + N (int): Total number of samples to partition. + alpha (float): Parameter for the Dirichlet distribution. + client_num (int): Number of clients. + idx_batch (list of arrays): List of arrays containing sample indices for each client. + idx_k (array): Array of sample indices to be partitioned. + + Returns: + tuple: A tuple containing the updated idx_batch and the minimum batch size. + + This function partitions class samples using the Dirichlet distribution to create unbalanced proportions + for each client. It shuffles the sample indices, calculates the proportions, and generates batch lists + for each client. The minimum batch size is also computed. + + Example: + idx_batch, min_size = partition_class_samples_with_dirichlet_distribution(N, alpha, client_num, idx_batch, idx_k) + """ np.random.shuffle(idx_k) # using dirichlet distribution to determine the unbalanced proportion for each client (client_num in total) # e.g., when client_num = 4, proportions = [0.29543505 0.38414498 0.31998781 0.00043216], sum(proportions) = 1 @@ -94,7 +115,8 @@ def partition_class_samples_with_dirichlet_distribution( # get the index in idx_k according to the dirichlet distribution proportions = np.array( - [p * (len(idx_j) < N / client_num) for p, idx_j in zip(proportions, idx_batch)] + [p * (len(idx_j) < N / client_num) + for p, idx_j in zip(proportions, idx_batch)] ) proportions = proportions / proportions.sum() proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1] @@ -110,6 +132,22 @@ def partition_class_samples_with_dirichlet_distribution( def record_data_stats(y_train, net_dataidx_map, task="classification"): + """ + Record data statistics for each client. + + Parameters: + y_train (array): Labels for the entire dataset. + net_dataidx_map (dict): Mapping of client indices to their respective data indices. + task (str): Task type, either "classification" or "segmentation". + + Returns: + dict: A dictionary containing class counts for each client. + + This function records data statistics for each client, specifically the count of each class in their data. + + Example: + net_cls_counts = record_data_stats(y_train, net_dataidx_map, task="classification") + """ net_cls_counts = {} for net_i, dataidx in net_dataidx_map.items(): diff --git a/python/fedml/core/mpc/lightsecagg.py b/python/fedml/core/mpc/lightsecagg.py index bc77bdf15d..34fa72ac19 100644 --- a/python/fedml/core/mpc/lightsecagg.py +++ b/python/fedml/core/mpc/lightsecagg.py @@ -6,6 +6,16 @@ def modular_inv(a, p): + """ + Compute the modular multiplicative inverse of 'a' modulo 'p'. + + Parameters: + a (int): The integer for which to find the modular inverse. + p (int): The prime number modulo which to compute the inverse. + + Returns: + int: The modular multiplicative inverse of 'a' modulo 'p'. + """ x, y, m = 1, 0, p while a > 1: q = a // m @@ -23,14 +33,35 @@ def modular_inv(a, p): def divmod(_num, _den, _p): - # compute num / den modulo prime p + """ + Compute the result of _num / _den modulo prime _p. + + Parameters: + _num (int): The numerator. + _den (int): The denominator. 
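The Dirichlet split above draws per-client proportions and then cuts the class's index array at the cumulative points; a self-contained illustration of that core step (the `len(idx_j) < N / client_num` capping term is omitted for brevity):

```python
import numpy as np

np.random.seed(0)
client_num, alpha = 4, 0.5
idx_k = np.random.permutation(100)           # indices of one class
idx_batch = [[] for _ in range(client_num)]  # per-client index lists so far

proportions = np.random.dirichlet(np.repeat(alpha, client_num))
# cumulative cut points into idx_k, one fewer than the number of clients
cuts = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
idx_batch = [idx_j + idx.tolist()
             for idx_j, idx in zip(idx_batch, np.split(idx_k, cuts))]
print([len(b) for b in idx_batch])  # unbalanced shard sizes
```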
+ _p (int): The prime number modulo which to compute the result. + + Returns: + int: The result of (_num / _den) modulo _p. + """ + # Compute the modulus of inputs _num = np.mod(_num, _p) _den = np.mod(_den, _p) _inv = modular_inv(_den, _p) return np.mod(np.int64(_num) * np.int64(_inv), _p) -def PI(vals, p): # upper-case PI -- product of inputs +def PI(vals, p): + """ + Compute the product of values in 'vals' modulo prime 'p'. + + Parameters: + vals (list of int): List of integers to be multiplied. + p (int): The prime number modulo which to compute the product. + + Returns: + int: The product of values in 'vals' modulo 'p'. + """ accum = 1 for v in vals: tmp = np.mod(v, p) @@ -39,6 +70,18 @@ def PI(vals, p): # upper-case PI -- product of inputs def LCC_encoding_with_points(X, alpha_s, beta_s, p): + """ + Perform Lagrange-Cauchy Coding encoding of data 'X' using specified points. + + Parameters: + X (numpy.ndarray): The input data matrix. + alpha_s (list of int): List of alpha values for encoding. + beta_s (list of int): List of beta values for encoding. + p (int): The prime number modulo which to perform encoding. + + Returns: + numpy.ndarray: The encoded data matrix. + """ m, d = np.shape(X) U = gen_Lagrange_coeffs(beta_s, alpha_s, p).astype("int64") X_LCC = np.zeros((len(beta_s), d), dtype="int64") @@ -48,6 +91,18 @@ def LCC_encoding_with_points(X, alpha_s, beta_s, p): def LCC_decoding_with_points(f_eval, eval_points, target_points, p): + """ + Perform Lagrange-Cauchy Coding decoding of data 'f_eval' using specified evaluation and target points. + + Parameters: + f_eval (numpy.ndarray): The data to decode. + eval_points (list of int): List of evaluation points. + target_points (list of int): List of target points. + p (int): The prime number modulo which to perform decoding. + + Returns: + numpy.ndarray: The decoded data. + """ alpha_s_eval = eval_points beta_s = target_points U_dec = gen_Lagrange_coeffs(beta_s, alpha_s_eval, p) @@ -57,6 +112,18 @@ def LCC_decoding_with_points(f_eval, eval_points, target_points, p): def gen_Lagrange_coeffs(alpha_s, beta_s, p, is_K1=0): + """ + Generate Lagrange coefficients for encoding and decoding. + + Parameters: + alpha_s (list of int): List of alpha values. + beta_s (list of int): List of beta values. + p (int): The prime number modulo which to compute the coefficients. + is_K1 (int, optional): A flag indicating whether it's for K=1 (1 for K=1, 0 otherwise). + + Returns: + numpy.ndarray: The Lagrange coefficients. + """ if is_K1 == 1: num_alpha = 1 else: @@ -81,12 +148,24 @@ def gen_Lagrange_coeffs(alpha_s, beta_s, p, is_K1=0): def model_masking(weights_finite, dimensions, local_mask, prime_number): + """ + Apply masking to model weights. + + Parameters: + weights_finite (dict): A dictionary of model weights. + dimensions (list of int): List of dimensions corresponding to weights. + local_mask (numpy.ndarray): The masking values. + prime_number (int): The prime number modulo which to perform masking. + + Returns: + dict: The masked model weights. + """ pos = 0 for i, k in enumerate(weights_finite): tmp = weights_finite[k] cur_shape = tmp.shape d = dimensions[i] - cur_mask = local_mask[pos : pos + d, :] + cur_mask = local_mask[pos: pos + d, :] cur_mask = np.reshape(cur_mask, cur_shape) weights_finite[k] += cur_mask weights_finite[k] = np.mod(weights_finite[k], prime_number) @@ -102,6 +181,20 @@ def mask_encoding( prime_number, local_mask, ): + """ + Encode a masking scheme for privacy-preserving federated learning. 
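Both `modular_inv` and this module's `divmod` implement division in GF(p). Since p is prime, the extended-Euclid inverse above agrees with the Fermat inverse, which gives a quick self-check:

```python
p = 2**13 - 1  # a Mersenne prime, standing in for the field size
a = 1234
inv = pow(a, p - 2, p)       # Fermat inverse; matches modular_inv(a, p)
assert (a * inv) % p == 1

num, den = 77, 1234
quotient = (num * inv) % p   # the value divmod(num, den, p) computes
assert (quotient * den) % p == num % p
```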
+ + Parameters: + total_dimension (int): Total dimension. + num_clients (int): Number of clients. + targeted_number_active_clients (int): Targeted number of active clients. + privacy_guarantee (int): Privacy guarantee parameter. + prime_number (int): The prime number modulo which to perform encoding. + local_mask (numpy.ndarray): The local mask. + + Returns: + numpy.ndarray: The encoded mask set. + """ d = total_dimension N = num_clients U = targeted_number_active_clients @@ -124,6 +217,17 @@ def mask_encoding( def compute_aggregate_encoded_mask(encoded_mask_dict, p, active_clients): + """ + Compute the aggregate encoded mask from a dictionary of encoded masks for active clients. + + Parameters: + encoded_mask_dict (dict): A dictionary containing encoded masks for clients. + p (int): The prime number modulo which to compute the aggregate mask. + active_clients (list): List of active client IDs. + + Returns: + list: The aggregate encoded mask as a list. + """ aggregate_encoded_mask = np.zeros((np.shape(encoded_mask_dict[0]))) for client_id in active_clients: aggregate_encoded_mask += encoded_mask_dict[client_id] @@ -133,8 +237,14 @@ def compute_aggregate_encoded_mask(encoded_mask_dict, p, active_clients): def aggregate_models_in_finite(weights_finite, prime_number): """ - weights_finite : array of state_dict() - prime_number : size of the finite field + Aggregate model weights in a finite field. + + Parameters: + weights_finite (list): List of model weights (state_dict) from different clients. + prime_number (int): The size of the finite field. + + Returns: + dict: The aggregated model weights in the finite field. """ w_sum = copy.deepcopy(weights_finite[0]) @@ -148,6 +258,17 @@ def aggregate_models_in_finite(weights_finite, prime_number): def my_q(X, q_bit, p): + """ + Quantize input values using fixed-point representation. + + Parameters: + X (numpy.ndarray): Input values to be quantized. + q_bit (int): Number of quantization bits. + p (int): The prime number modulo which to quantize. + + Returns: + numpy.ndarray: Quantized values. + """ X_int = np.round(X * (2**q_bit)) is_negative = (abs(np.sign(X_int)) - np.sign(X_int)) / 2 out = X_int + p * is_negative @@ -155,6 +276,17 @@ def my_q(X, q_bit, p): def my_q_inv(X_q, q_bit, p): + """ + Inverse quantize values back to their original range. + + Parameters: + X_q (numpy.ndarray): Quantized values to be de-quantized. + q_bit (int): Number of quantization bits. + p (int): The prime number modulo which to perform inverse quantization. + + Returns: + numpy.ndarray: De-quantized values. + """ flag = X_q - (p - 1) / 2 is_negative = (abs(np.sign(flag)) + np.sign(flag)) / 2 X_q = X_q - p * is_negative @@ -162,6 +294,17 @@ def my_q_inv(X_q, q_bit, p): def transform_finite_to_tensor(model_params, p, q_bits): + """ + Transform model parameters from finite field representation to tensor representation. + + Parameters: + model_params (dict): Model parameters represented in a finite field. + p (int): The prime number used for finite field representation. + q_bits (int): Number of quantization bits. + + Returns: + dict: Transformed model parameters in tensor representation. + """ for k in model_params.keys(): tmp = np.array(model_params[k]) tmp_real = my_q_inv(tmp, q_bits, p) @@ -185,6 +328,17 @@ def transform_finite_to_tensor(model_params, p, q_bits): def transform_tensor_to_finite(model_params, p, q_bits): + """ + Transform model parameters from tensor representation to finite field representation. 
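A hedged usage sketch of the encode/decode pair above; LCC here is the Lagrange-coded-computing primitive from the LightSecAgg design, so data placed at points `alpha_s` defines a degree K-1 polynomial and any K encoded evaluations recover it. The sketch assumes `fedml` is importable and uses a small prime to keep int64 arithmetic from overflowing:

```python
import numpy as np
from fedml.core.mpc.lightsecagg import (
    LCC_encoding_with_points, LCC_decoding_with_points)

p, K, N, d = 7919, 3, 5, 4            # prime field, data rows, shares, width
X = np.random.randint(0, p, size=(K, d)).astype("int64")

alpha_s = np.arange(1, K + 1)         # points carrying the data rows
beta_s = np.arange(K + 1, K + N + 1)  # points carrying the encoded shares
X_enc = LCC_encoding_with_points(X, alpha_s, beta_s, p)

# any K of the N shares interpolate the degree-(K-1) polynomial back
X_rec = LCC_decoding_with_points(X_enc[:K], beta_s[:K], alpha_s, p)
assert np.array_equal(np.mod(X_rec, p), X)
```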
+ + Parameters: + model_params (dict): Model parameters represented as tensors. + p (int): The prime number used for finite field representation. + q_bits (int): Number of quantization bits. + + Returns: + dict: Transformed model parameters in finite field representation. + """ for k in model_params.keys(): tmp = np.array(model_params[k]) tmp_finite = my_q(tmp, q_bits, p) @@ -193,6 +347,15 @@ def transform_tensor_to_finite(model_params, p, q_bits): def model_dimension(weights): + """ + Compute the dimensions and total dimension of model weights. + + Parameters: + weights (dict): Model weights (state_dict). + + Returns: + tuple: A tuple containing dimensions (list) and total dimension (int). + """ logging.info("Get model dimension") dimensions = [] for k in weights.keys(): diff --git a/python/fedml/core/mpc/secagg.py b/python/fedml/core/mpc/secagg.py index 45874faba8..1660cbb27e 100644 --- a/python/fedml/core/mpc/secagg.py +++ b/python/fedml/core/mpc/secagg.py @@ -6,6 +6,16 @@ def modular_inv(a, p): + """ + Compute the modular inverse of 'a' modulo 'p' using the extended Euclidean algorithm. + + Parameters: + a (int): The number for which to find the modular inverse. + p (int): The modulus. + + Returns: + int: The modular inverse of 'a' modulo 'p'. + """ x, y, m = 1, 0, p while a > 1: q = a // m @@ -23,6 +33,18 @@ def modular_inv(a, p): def divmod(_num, _den, _p): + """ + Compute 'num' divided by 'den' modulo prime 'p'. + + Parameters: + _num (int): The numerator. + _den (int): The denominator. + _p (int): The prime modulus. + + Returns: + int: The result of 'num' / 'den' modulo 'p'. + """ + # compute num / den modulo prime p _num = np.mod(_num, _p) _den = np.mod(_den, _p) @@ -31,6 +53,16 @@ def divmod(_num, _den, _p): def PI(vals, p): # upper-case PI -- product of inputs + """ + Compute the product of a list of values modulo 'p'. + + Parameters: + vals (list): List of values. + p (int): The modulus. + + Returns: + int: The product of the values modulo 'p'. + """ accum = np.int64(1) for v in vals: tmp = np.mod(v, p) @@ -39,6 +71,18 @@ def PI(vals, p): # upper-case PI -- product of inputs def LCC_encoding_with_points(X, alpha_s, beta_s, p): + """ + Linear Code with Complementary coefficients (LCC) encoding of a matrix 'X' with given alpha and beta points. + + Parameters: + X (numpy.ndarray): Input matrix to be encoded. + alpha_s (list): List of alpha points. + beta_s (list): List of beta points. + p (int): The modulus. + + Returns: + numpy.ndarray: Encoded matrix using LCC encoding. + """ m, d = np.shape(X) U = gen_Lagrange_coeffs(beta_s, alpha_s, p).astype("int64") X_LCC = np.zeros((len(beta_s), d), dtype="int64") @@ -48,6 +92,18 @@ def LCC_encoding_with_points(X, alpha_s, beta_s, p): def LCC_decoding_with_points(f_eval, eval_points, target_points, p): + """ + Linear Code with Complementary coefficients (LCC) decoding with given evaluation and target points. + + Parameters: + f_eval (numpy.ndarray): Evaluation points. + eval_points (list): List of evaluation points. + target_points (list): List of target points. + p (int): The modulus. + + Returns: + int: Decoded result using LCC decoding. + """ alpha_s_eval = eval_points beta_s = target_points U_dec = gen_Lagrange_coeffs(beta_s, alpha_s_eval, p) @@ -57,6 +113,18 @@ def LCC_decoding_with_points(f_eval, eval_points, target_points, p): def gen_Lagrange_coeffs(alpha_s, beta_s, p, is_K1=0): + """ + Generate Lagrange coefficients for given alpha and beta points. + + Parameters: + alpha_s (list): List of alpha points. 
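`my_q` maps reals to fixed-point field elements, folding negatives into the top of the field, and `my_q_inv` undoes it. A self-contained round trip using the same arithmetic as the two functions above:

```python
import numpy as np

p, q_bits = 2**13 - 1, 4  # field size and fractional bits (illustrative)

def quantize(X):
    X_int = np.round(X * (2 ** q_bits))
    is_negative = (abs(np.sign(X_int)) - np.sign(X_int)) / 2
    return X_int + p * is_negative   # negatives wrap into the top of the field

def dequantize(X_q):
    flag = X_q - (p - 1) / 2
    is_negative = (abs(np.sign(flag)) + np.sign(flag)) / 2
    return (X_q - p * is_negative) / (2 ** q_bits)

x = np.array([-1.5, -0.25, 0.0, 0.75])
assert np.allclose(dequantize(quantize(x)), x)  # exact for multiples of 2^-4
```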
+ beta_s (list): List of beta points. + p (int): The modulus. + is_K1 (int): Indicator for K1 coefficient generation. + + Returns: + numpy.ndarray: Lagrange coefficients matrix. + """ if is_K1 == 1: num_alpha = 1 else: @@ -81,6 +149,23 @@ def gen_Lagrange_coeffs(alpha_s, beta_s, p, is_K1=0): def model_masking(weights_finite, dimensions, local_mask, prime_number): + """ + Apply masking to model weights. + + Parameters: + weights_finite (dict): Dictionary of model weights. + dimensions (list): List of dimensions for each weight. + local_mask (numpy.ndarray): Local mask to be applied. + prime_number (int): The prime number for modulo operation. + + Returns: + dict: Updated model weights after masking. + + This function applies a local mask to model weights by element-wise addition and modulo operation. + + Example: + updated_weights = model_masking(weights_finite, dimensions, local_mask, prime_number) + """ pos = 0 reshaped_local_mask = local_mask.reshape((local_mask.shape[0], 1)) for i, k in enumerate(weights_finite): @@ -95,7 +180,7 @@ def model_masking(weights_finite, dimensions, local_mask, prime_number): tmp = weights_finite[k] cur_shape = tmp.shape d = dimensions[i] - cur_mask = reshaped_local_mask[pos : pos + d, :] + cur_mask = reshaped_local_mask[pos: pos + d, :] cur_mask = np.reshape(cur_mask, cur_shape) weights_finite[k] += cur_mask weights_finite[k] = np.mod(weights_finite[k], prime_number) @@ -118,6 +203,26 @@ def model_masking(weights_finite, dimensions, local_mask, prime_number): def mask_encoding( total_dimension, num_clients, targeted_number_active_clients, privacy_guarantee, prime_number, local_mask ): + """ + Encode a local mask for privacy. + + Parameters: + total_dimension (int): Total dimension. + num_clients (int): Total number of clients. + targeted_number_active_clients (int): Targeted number of active clients. + privacy_guarantee (int): Privacy guarantee parameter. + prime_number (int): The prime number for modulo operation. + local_mask (numpy.ndarray): Local mask. + + Returns: + numpy.ndarray: Encoded mask. + + This function encodes a local mask for privacy using parameters like total dimension, number of clients, etc. + + Example: + encoded_mask = mask_encoding(total_dimension, num_clients, targeted_number_active_clients, privacy_guarantee, prime_number, local_mask) + """ + d = total_dimension N = num_clients U = targeted_number_active_clients @@ -132,12 +237,30 @@ def mask_encoding( LCC_in = np.concatenate([local_mask, n_i], axis=0) LCC_in = np.reshape(LCC_in, (U, d // (U - T))) - encoded_mask_set = LCC_encoding_with_points(LCC_in, alpha_s, beta_s, p).astype("int64") + encoded_mask_set = LCC_encoding_with_points( + LCC_in, alpha_s, beta_s, p).astype("int64") return encoded_mask_set def compute_aggregate_encoded_mask(encoded_mask_dict, p, active_clients): + """ + Compute the aggregate encoded mask. + + Parameters: + encoded_mask_dict (dict): Dictionary of encoded masks for each client. + p (int): The prime number for modulo operation. + active_clients (list): List of active client IDs. + + Returns: + numpy.ndarray: Aggregate encoded mask. + + This function computes the aggregate encoded mask from individual client masks. 
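+
+    A numeric sketch (illustrative values): with p = 7 and two active clients
+    holding encoded masks [3, 5] and [6, 4], the sum [9, 9] reduces to [2, 2]
+    modulo 7 before decoding.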
+ + Example: + aggregate_mask = compute_aggregate_encoded_mask(encoded_mask_dict, p, active_clients) + """ + aggregate_encoded_mask = np.zeros((np.shape(encoded_mask_dict[0]))) for client_id in active_clients: aggregate_encoded_mask += encoded_mask_dict[client_id] @@ -147,8 +270,19 @@ def compute_aggregate_encoded_mask(encoded_mask_dict, p, active_clients): def aggregate_models_in_finite(weights_finite, prime_number): """ - weights_finite : array of state_dict() - prime_number : size of the finite field + Aggregate model weights in a finite field. + + Parameters: + weights_finite (list of dict): List of model weight dictionaries. + prime_number (int): The prime number for modulo operation. + + Returns: + dict: Aggregated model weights. + + This function aggregates model weights in a finite field using modulo operation. + + Example: + aggregated_weights = aggregate_models_in_finite(weights_finite, prime_number) """ w_sum = copy.deepcopy(weights_finite[0]) @@ -162,6 +296,23 @@ def aggregate_models_in_finite(weights_finite, prime_number): def BGW_encoding(X, N, T, p): + """ + Encode data using BGW encoding. + + Parameters: + X (numpy.ndarray): Data to be encoded. + N (int): Number of evaluation points. + T (int): Degree of polynomial. + p (int): Prime number. + + Returns: + numpy.ndarray: Encoded data. + + This function encodes data using BGW encoding scheme. + + Example: + encoded_data = BGW_encoding(X, N, T, p) + """ m = len(X) d = len(X[0]) @@ -173,11 +324,27 @@ def BGW_encoding(X, N, T, p): for i in range(N): for t in range(T + 1): - X_BGW[i, :, :] = np.mod(X_BGW[i, :, :] + R[t, :, :] * (alpha_s[i] ** t), p) + X_BGW[i, :, :] = np.mod( + X_BGW[i, :, :] + R[t, :, :] * (alpha_s[i] ** t), p) return X_BGW def gen_BGW_lambda_s(alpha_s, p): + """ + Generate lambda values for BGW encoding. + + Parameters: + alpha_s (numpy.ndarray): Array of alpha values. + p (int): Prime number. + + Returns: + numpy.ndarray: Generated lambda values. + + This function generates lambda values for BGW encoding. + + Example: + lambda_values = gen_BGW_lambda_s(alpha_s, p) + """ lambda_s = np.zeros((1, len(alpha_s)), dtype="int64") for i in range(len(alpha_s)): @@ -190,6 +357,23 @@ def gen_BGW_lambda_s(alpha_s, p): def BGW_decoding(f_eval, worker_idx, p): # decode the output from T+1 evaluation points + """ + Decode data using BGW decoding. + + Parameters: + f_eval (numpy.ndarray): Evaluated data. + worker_idx (list): List of worker indices. + p (int): Prime number. + + Returns: + numpy.ndarray: Decoded data. + + This function decodes data using BGW decoding scheme. + + Example: + decoded_data = BGW_decoding(f_eval, worker_idx, p) + """ + # f_eval : [RT X d ] # worker_idx : [ 1 X RT] # output : [ 1 X d ] @@ -211,12 +395,30 @@ def BGW_decoding(f_eval, worker_idx, p): # decode the output from T+1 evaluatio def LCC_encoding(X, N, K, T, p): + """ + Encode data using LCC encoding. + + Parameters: + X (numpy.ndarray): Data to be encoded. + N (int): Number of evaluation points. + K (int): Number of known points. + T (int): Number of random points. + p (int): Prime number. + + Returns: + numpy.ndarray: Encoded data. + + This function encodes data using LCC encoding scheme. 
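+
+    Shape sketch (dimensions are illustrative): with m = 4, d = 2, K = 2,
+    T = 1, and N = 3, X splits into two (2, 2) data blocks plus one random
+    (2, 2) block, and the encoded output has shape (N, m // K, d) = (3, 2, 2).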
+ + Example: + encoded_data = LCC_encoding(X, N, K, T, p) + """ m = len(X) d = len(X[0]) # print(m,d,m//K) X_sub = np.zeros((K + T, m // K, d), dtype="int64") for i in range(K): - X_sub[i] = X[i * m // K : (i + 1) * m // K :] + X_sub[i] = X[i * m // K: (i + 1) * m // K:] for i in range(K, K + T): X_sub[i] = np.random.randint(p, size=(m // K, d)) @@ -232,17 +434,37 @@ def LCC_encoding(X, N, K, T, p): X_LCC = np.zeros((N, m // K, d), dtype="int64") for i in range(N): for j in range(K + T): - X_LCC[i, :, :] = np.mod(X_LCC[i, :, :] + np.mod(U[i][j] * X_sub[j, :, :], p), p) + X_LCC[i, :, :] = np.mod( + X_LCC[i, :, :] + np.mod(U[i][j] * X_sub[j, :, :], p), p) return X_LCC def LCC_encoding_w_Random(X, R_, N, K, T, p): + """ + Encode data using LCC encoding with random values. + + Parameters: + X (numpy.ndarray): Data to be encoded. + R_ (numpy.ndarray): Random values for encoding. + N (int): Number of evaluation points. + K (int): Number of known points. + T (int): Number of random points. + p (int): Prime number. + + Returns: + numpy.ndarray: Encoded data. + + This function encodes data using LCC encoding scheme with random values. + + Example: + encoded_data = LCC_encoding_w_Random(X, R_, N, K, T, p) + """ m = len(X) d = len(X[0]) # print(m,d,m//K) X_sub = np.zeros((K + T, m // K, d), dtype="int64") for i in range(K): - X_sub[i] = X[i * m // K : (i + 1) * m // K :] + X_sub[i] = X[i * m // K: (i + 1) * m // K:] for i in range(K, K + T): X_sub[i] = R_[i - K, :, :].astype("int64") @@ -262,17 +484,39 @@ def LCC_encoding_w_Random(X, R_, N, K, T, p): X_LCC = np.zeros((N, m // K, d), dtype="int64") for i in range(N): for j in range(K + T): - X_LCC[i, :, :] = np.mod(X_LCC[i, :, :] + np.mod(U[i][j] * X_sub[j, :, :], p), p) + X_LCC[i, :, :] = np.mod( + X_LCC[i, :, :] + np.mod(U[i][j] * X_sub[j, :, :], p), p) return X_LCC def LCC_encoding_w_Random_partial(X, R_, N, K, T, p, worker_idx): + """ + Encode data using LCC encoding with random values for a subset of workers. + + Parameters: + X (numpy.ndarray): Data to be encoded. + R_ (numpy.ndarray): Random values for encoding. + N (int): Number of evaluation points. + K (int): Number of known points. + T (int): Number of random points. + p (int): Prime number. + worker_idx (list): List of worker indices. + + Returns: + numpy.ndarray: Encoded data. + + This function encodes data using LCC encoding scheme with random values for a subset of workers. + + Example: + encoded_data = LCC_encoding_w_Random_partial(X, R_, N, K, T, p, worker_idx) + """ + m = len(X) d = len(X[0]) # print(m,d,m//K) X_sub = np.zeros((K + T, m // K, d), dtype="int64") for i in range(K): - X_sub[i] = X[i * m // K : (i + 1) * m // K :] + X_sub[i] = X[i * m // K: (i + 1) * m // K:] for i in range(K, K + T): X_sub[i] = R_[i - K, :, :].astype("int64") @@ -290,11 +534,33 @@ def LCC_encoding_w_Random_partial(X, R_, N, K, T, p, worker_idx): X_LCC = np.zeros((N_out, m // K, d), dtype="int64") for i in range(N_out): for j in range(K + T): - X_LCC[i, :, :] = np.mod(X_LCC[i, :, :] + np.mod(U[i][j] * X_sub[j, :, :], p), p) + X_LCC[i, :, :] = np.mod( + X_LCC[i, :, :] + np.mod(U[i][j] * X_sub[j, :, :], p), p) return X_LCC def LCC_decoding(f_eval, f_deg, N, K, T, worker_idx, p): + """ + Decode the encoded data using LCC decoding. + + Parameters: + f_eval (numpy.ndarray): Encoded data to be decoded. + f_deg (int): Degree of the encoded data. + N (int): Number of evaluation points. + K (int): Number of known points. + T (int): Number of random points. + worker_idx (list): List of worker indices. 
+ p (int): Prime number. + + Returns: + numpy.ndarray: Decoded data. + + This function decodes the encoded data using LCC decoding scheme. + + Example: + decoded_data = LCC_decoding(f_eval, f_deg, N, K, T, worker_idx, p) + """ + RT_LCC = f_deg * (K + T - 1) + 1 n_beta = K # +T @@ -314,6 +580,23 @@ def LCC_decoding(f_eval, f_deg, N, K, T, worker_idx, p): def Gen_Additive_SS(d, n_out, p): + """ + Generate additive secret sharing. + + Parameters: + d (int): Dimension of the secret. + n_out (int): Number of output shares. + p (int): Prime number. + + Returns: + numpy.ndarray: Additive secret sharing matrix. + + This function generates additive secret sharing matrix. + + Example: + secret_sharing_matrix = Gen_Additive_SS(d, n_out, p) + """ + # x_model should be one dimension temp = np.random.randint(0, p, size=(n_out - 1, d)) @@ -327,6 +610,22 @@ def Gen_Additive_SS(d, n_out, p): def my_pk_gen(my_sk, p, g): + """ + Generate public key. + + Parameters: + my_sk (int): Private key. + p (int): Prime number. + g (int): Generator. + + Returns: + int: Public key. + + This function generates a public key from a private key. + + Example: + public_key = my_pk_gen(my_sk, p, g) + """ # print 'my_pk_gen option: g=',g if g == 0: return my_sk @@ -335,6 +634,23 @@ def my_pk_gen(my_sk, p, g): def my_key_agreement(my_sk, u_pk, p, g): + """ + Perform key agreement. + + Parameters: + my_sk (int): Private key. + u_pk (int): Other party's public key. + p (int): Prime number. + g (int): Generator. + + Returns: + int: Shared secret key. + + This function performs key agreement between two parties. + + Example: + shared_secret_key = my_key_agreement(my_sk, u_pk, p, g) + """ if g == 0: return np.mod(my_sk * u_pk, p) else: @@ -342,6 +658,22 @@ def my_key_agreement(my_sk, u_pk, p, g): def my_q(X, q_bit, p): + """ + Quantize data to a finite field. + + Parameters: + X (numpy.ndarray): Data to be quantized. + q_bit (int): Number of bits for quantization. + p (int): Prime number. + + Returns: + numpy.ndarray: Quantized data. + + This function quantizes data to a specific number of bits within a finite field. + + Example: + quantized_data = my_q(X, q_bit, p) + """ X_int = np.round(X * (2 ** q_bit)) is_negative = (abs(np.sign(X_int)) - np.sign(X_int)) / 2 out = X_int + p * is_negative @@ -349,6 +681,23 @@ def my_q(X, q_bit, p): def transform_tensor_to_finite(model_params, p, q_bits): + """ + Transform model tensor parameters to finite field. + + Parameters: + model_params (dict): Dictionary of model parameters. + p (int): Prime number for the finite field. + q_bits (int): Number of bits for quantization. + + Returns: + dict: Transformed model parameters in the finite field. + + This function takes a dictionary of model parameters (typically tensors) and transforms them to the specified finite field. + + Example: + finite_model_params = transform_tensor_to_finite(model_params, p, q_bits) + """ + for k in model_params.keys(): tmp = np.array(model_params[k]) tmp_finite = my_q(tmp, q_bits, p) @@ -357,6 +706,22 @@ def transform_tensor_to_finite(model_params, p, q_bits): def my_q_inv(X_q, q_bit, p): + """ + Inverse quantize data from a finite field. + + Parameters: + X_q (numpy.ndarray): Data in the finite field to be inverse quantized. + q_bit (int): Number of bits for quantization. + p (int): Prime number. + + Returns: + numpy.ndarray: Inverse quantized data in the real field. + + This function performs inverse quantization of data from a finite field to the real field. 
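+
+    A worked sketch (values are illustrative, mirroring the my_q example):
+    with q_bit = 8 and p = 2**13 - 1 = 8191, X_q = 8063 exceeds
+    (p - 1) / 2 = 4095, so it is shifted to 8063 - 8191 = -128 and rescaled
+    to -128 / 2**8 = -0.5.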
+
+    Example:
+        real_data = my_q_inv(X_q, q_bit, p)
+    """
     flag = X_q - (p - 1) / 2
     is_negative = (abs(np.sign(flag)) + np.sign(flag)) / 2
     X_q = X_q - p * is_negative
@@ -364,6 +729,22 @@ def my_q_inv(X_q, q_bit, p):


 def transform_finite_to_tensor(model_params, p, q_bits):
+    """
+    Transform model parameters from a finite field to tensor.
+
+    Parameters:
+        model_params (dict): Dictionary of model parameters in the finite field.
+        p (int): Prime number for the finite field.
+        q_bits (int): Number of bits for quantization.
+
+    Returns:
+        dict: Transformed model parameters as tensors in the real field.
+
+    This function takes a dictionary of model parameters in the finite field and transforms them to tensors in the real field.
+
+    Example:
+        tensor_model_params = transform_finite_to_tensor(model_params, p, q_bits)
+    """
     for k in model_params.keys():
         tmp = np.array(model_params[k])
         tmp_real = my_q_inv(tmp, q_bits, p)
@@ -377,12 +758,28 @@ def transform_finite_to_tensor(model_params, p, q_bits):
             0 - Wed, 13 Oct 2021 07:50:59 utils.py[line:33] DEBUG tmp_real = 256812209.4375
         """
         # logging.debug("tmp_real = {}".format(tmp_real))
-        tmp_real = torch.Tensor([tmp_real]) if isinstance(tmp_real, np.floating) else torch.Tensor(tmp_real)
+        tmp_real = torch.Tensor([tmp_real]) if isinstance(
+            tmp_real, np.floating) else torch.Tensor(tmp_real)
         model_params[k] = tmp_real
     return model_params


 def model_dimension(weights):
+    """
+    Get the dimension of a model.
+
+    Parameters:
+        weights (dict): Dictionary of model weights.
+
+    Returns:
+        list: List of dimensions of model parameters.
+        int: Total dimension of the model.
+
+    This function calculates the dimensions of model parameters and the total dimension of the model.
+
+    Example:
+        dimensions, total_dimension = model_dimension(weights)
+    """
     logging.info("Get model dimension")
     dimensions = []
     for k in weights.keys():
diff --git a/python/fedml/core/schedule/runtime_estimate.py b/python/fedml/core/schedule/runtime_estimate.py
index f4984407d1..4500478e0d 100644
--- a/python/fedml/core/schedule/runtime_estimate.py
+++ b/python/fedml/core/schedule/runtime_estimate.py
@@ -2,6 +2,24 @@


 def linear_fit(x, y):
+    """
+    Fit a linear model to the given data.
+
+    Parameters:
+        x (array-like): The independent variable data.
+        y (array-like): The dependent variable data.
+
+    Returns:
+        z1 (array-like): Coefficients of the linear fit.
+        p1 (numpy.poly1d): The polynomial representing the linear fit.
+        yvals (array-like): Predicted values based on the linear fit.
+        fit_error (float): Mean absolute percentage error of the fit.
+
+    Example:
+        x = [1, 2, 3, 4, 5]
+        y = [2, 4, 5, 4, 5]
+        z1, p1, yvals, fit_error = linear_fit(x, y)
+    """
     z1 = np.polyfit(x, y, 1)
     p1 = np.poly1d(z1)
     print(p1)
@@ -21,6 +39,44 @@ def t_sample_fit(
         0: {0: [], 1: [], 2: []...},
         1: {0: [], 1: [], 2: []...},
     }
+
+    Fit linear models to runtime data for each worker and client combination.
+
+    Parameters:
+        num_workers (int): The number of workers.
+        num_clients (int): The number of clients.
+        runtime_history (dict): A dictionary containing runtime history data.
+            Format: {
+                worker_id: {
+                    client_id: [list of runtimes]
+                }
+            }
+        train_data_local_num_dict (dict): A dictionary containing the number of local training data samples for each client.
+            Format: {
+                client_id: num_samples
+            }
+        uniform_client (bool): Whether all clients share the same cost profile, so a single fit can be reused across clients.
+        uniform_gpu (bool): Whether all workers' GPUs share the same cost profile, so a single fit can be reused across workers.
+ + Returns: + fit_params (dict): Fitted parameters (slope and intercept) of the linear models for each worker and client. + Format: { + worker_id: { + client_id: (slope, intercept) + } + } + fit_funcs (dict): Fitted linear functions for each worker and client. + Format: { + worker_id: { + client_id: p1 (linear function) + } + } + fit_errors (dict): Fit errors (mean absolute percentage error) for each worker and client. + Format: { + worker_id: { + client_id: fit_error + } + } """ fit_params = {} fit_funcs = {} diff --git a/python/fedml/core/schedule/seq_train_scheduler.py b/python/fedml/core/schedule/seq_train_scheduler.py index cd155df271..2e2b7082a7 100644 --- a/python/fedml/core/schedule/seq_train_scheduler.py +++ b/python/fedml/core/schedule/seq_train_scheduler.py @@ -7,6 +7,41 @@ class SeqTrainScheduler: + """ + Initialize the Sequential Training Scheduler. + + Parameters: + workloads (list): List of client workloads. + constraints (list): List of constraints corresponding to each resource. + memory (list): List of memory constraints for each resource. + cost_funcs (list of lists or list of functions): Cost functions for assigning workloads. + uniform_client (bool): Whether the client workloads are uniform. + uniform_gpu (bool): Whether the GPU resources are uniform. + prune_equal_sub_solution (bool): Whether to prune equal sub-solutions. + + Attributes: + workloads (list): List of client workloads. + constraints (list): List of constraints corresponding to each resource. + memory (list): List of memory constraints for each resource. + cost_funcs (list of lists or list of functions): Cost functions for assigning workloads. + uniform_client (bool): Whether the client workloads are uniform. + uniform_gpu (bool): Whether the GPU resources are uniform. + len_x (int): Number of workloads (clients). + len_y (int): Number of constraints (resources). + iter_times (int): Iteration counter. + + Example: + scheduler = SeqTrainScheduler( + workloads=[100, 200, 150], + constraints=[10, 20], + memory=[300, 400], + cost_funcs=[[cost_func1, cost_func2], [cost_func3, cost_func4]], + uniform_client=True, + uniform_gpu=False, + prune_equal_sub_solution=True, + ) + """ + def __init__( self, workloads, @@ -33,6 +68,23 @@ def __init__( self.iter_times = 0 def obtain_client_cost(self, resource_id, client_id): + """ + Calculate the cost of assigning a workload to a resource. + + Parameters: + resource_id (int): Index of the resource. + client_id (int): Index of the client. + + Returns: + float: The calculated cost. + + This method calculates the cost of assigning a workload to a resource based on the specified cost functions + and resource and client characteristics. It handles different scenarios based on the values of + `uniform_client` and `uniform_gpu`. + + Example: + cost = scheduler.obtain_client_cost(0, 1) + """ if self.uniform_client and self.uniform_gpu: # cost = self.cost_funcs[0][0](self.client_data_nums[client_id]) cost = self.cost_funcs[0][0](self.workloads[client_id]) @@ -44,12 +96,29 @@ def obtain_client_cost(self, resource_id, client_id): cost = self.cost_funcs[resource_id][0](self.workloads[client_id]) else: # cost = self.cost_funcs[resource_id][client_id](self.client_data_nums[client_id]) - cost = self.cost_funcs[resource_id][client_id](self.workloads[client_id]) + cost = self.cost_funcs[resource_id][client_id]( + self.workloads[client_id]) if cost < 0.0: cost = 0.0 return cost def assign_a_workload_serial(self, x_maps, cost_maps): + """ + Assign workloads to resources sequentially. 
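+
+        Data-layout sketch (a reading of DP_schedule's initialization below,
+        not an official spec): each x_map is a length-len_x array whose entry
+        j holds the resource id assigned to workload j, with -1 marking an
+        unassigned workload, and each cost_map accumulates the per-resource
+        cost.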
+ + Parameters: + x_maps (list): List of workload assignment maps. + cost_maps (list): List of cost maps corresponding to workload assignments. + + Returns: + tuple: A tuple containing updated x_maps and cost_maps. + + This method assigns workloads to resources sequentially while minimizing the cost. It explores various workload + assignments and prunes suboptimal solutions based on the `prune_equal_sub_solution` attribute. + + Example: + x_maps, cost_maps = scheduler.assign_a_workload_serial(x_maps, cost_maps) + """ # Find the case with the minimum cost. self.iter_times += 1 costs = [] @@ -108,6 +177,24 @@ def assign_a_workload_serial(self, x_maps, cost_maps): return self.assign_a_workload_serial(x_maps, cost_maps) def assign_a_workload(self, x_maps, cost_maps, resource_maps): + """ + Assign workloads to resources considering both parallel and serial execution. + + Parameters: + x_maps (list): List of workload assignment maps. + cost_maps (list): List of cost maps corresponding to workload assignments. + resource_maps (list): List of resource maps. + + Returns: + tuple: A tuple containing updated x_maps, cost_maps, and resource_maps. + + This method assigns workloads to resources while considering both parallel and serial execution possibilities. + It explores various workload assignments and prunes suboptimal solutions based on the `prune_equal_sub_solution` + attribute. + + Example: + x_maps, cost_maps, resource_maps = scheduler.assign_a_workload(x_maps, cost_maps, resource_maps) + """ # Find the case with the minimum cost. costs = [] for i in range(len(cost_maps)): @@ -139,7 +226,8 @@ def assign_a_workload(self, x_maps, cost_maps, resource_maps): new_maps.append(np.copy(x_map)) new_maps[-1][target_index] = i new_costs.append(np.copy(cost_map)) - new_costs[-1][i] = max((self.y[i] * self.x[target_index]), new_costs[-1][i]) + new_costs[-1][i] = max((self.y[i] * + self.x[target_index]), new_costs[-1][i]) new_resources.append(np.copy(resource_map)) new_resources[-1][i] += self.x[target_index] @@ -163,6 +251,22 @@ def assign_a_workload(self, x_maps, cost_maps, resource_maps): return self.assign_a_workload(x_maps, cost_maps, resource_maps) def DP_schedule(self, mode): + """ + Perform Dynamic Programming (DP) based scheduling. + + Parameters: + mode (int): Scheduling mode, 0 for serial, 1 for parallel. + + Returns: + tuple: A tuple containing the schedules and output_schedules. + + This method performs dynamic programming-based scheduling to assign workloads to resources while minimizing + the cost. It explores various workload assignments and prunes suboptimal solutions based on the scheduling mode. + The schedules are returned in the format of a list of dictionaries. 
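+
+        Illustrative solution shape (a sketch, not actual output): with three
+        workloads and two resources, an assignment x_map = [0, 1, 0] places
+        workloads 0 and 2 on resource 0 and workload 1 on resource 1.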
+ + Example: + schedules, output_schedules = scheduler.DP_schedule(1) + """ x_maps = [] x_maps.append(np.negative(np.ones((self.len_x)))) cost_maps = [] @@ -172,9 +276,11 @@ def DP_schedule(self, mode): if mode == 1: resource_maps = [] resource_maps.append(np.zeros((self.len_y))) - x_maps, cost_maps, resource_maps = self.assign_a_workload(x_maps, cost_maps, resource_maps) + x_maps, cost_maps, resource_maps = self.assign_a_workload( + x_maps, cost_maps, resource_maps) else: - x_maps, cost_maps = self.assign_a_workload_serial(x_maps, cost_maps) + x_maps, cost_maps = self.assign_a_workload_serial( + x_maps, cost_maps) # print(f"x_maps: {x_maps} len(x_maps): {len(x_maps)}") # print(f"cost_maps: {cost_maps} len(cost_maps): {len(cost_maps)}") @@ -195,9 +301,11 @@ def DP_schedule(self, mode): # logging.info(f"schedules: {schedules} len(schedules): {len(schedules)}") logging.info(f"self.iter_times: {self.iter_times}") logging.info( - "The optimal maximum cost: %f, assignment: %s\n" % (costs[target_index], str(x_maps[target_index])) + "The optimal maximum cost: %f, assignment: %s\n" % ( + costs[target_index], str(x_maps[target_index])) ) - logging.info(f"target_index: {target_index} cost_map: {cost_maps[target_index]}") + logging.info( + f"target_index: {target_index} cost_map: {cost_maps[target_index]}") # print(f"schedules: {schedules} len(schedules): {len(schedules)}") # print(f"self.iter_times: {self.iter_times}") @@ -239,4 +347,3 @@ def DP_schedule(self, mode): schedule[num_bunches] = jobs output_schedules.append(schedule) return schedules, output_schedules - diff --git a/python/fedml/core/security/attack/backdoor_attack.py b/python/fedml/core/security/attack/backdoor_attack.py index f0f882bdb2..ae7734097b 100644 --- a/python/fedml/core/security/attack/backdoor_attack.py +++ b/python/fedml/core/security/attack/backdoor_attack.py @@ -32,6 +32,16 @@ class BackdoorAttack(BaseAttackMethod): def __init__( self, backdoor_client_num, client_num, num_std=None, dataset=None, backdoor_type="pattern", ): + """ + Initialize the BackdoorAttack. + + Args: + backdoor_client_num (int): Number of malicious clients for the backdoor attack. + client_num (int): Total number of clients. + num_std (float): Number of standard deviations for clipping gradients (default=None). + dataset (Tuple[Tensor, Tensor] or None): Dataset for generating backdoor (default=None). + backdoor_type (str): Type of backdoor ("pattern" or "random"). + """ self.backdoor_client_num = backdoor_client_num self.client_num = client_num self.num_std = num_std @@ -52,9 +62,20 @@ def __init__( pass def attack_model(self, raw_client_grad_list: List[Tuple[float, OrderedDict]], - extra_auxiliary_info: Any = None): + extra_auxiliary_info: Any = None): + """ + Attack the model using a backdoor attack strategy. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): List of client gradients. + extra_auxiliary_info (Any): Extra auxiliary information. + + Returns: + np.ndarray: New gradients for malicious clients. 
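+
+        Example (shape sketch; names are illustrative):
+            new_user_grads = attack.attack_model(raw_client_grad_list)
+            # a flat array, clipped to grads_mean +/- num_std * grads_stdev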
+ """ # the local_w comes from local training (regular) - backdoor_idxs = self._get_malicious_client_idx(len(raw_client_grad_list)) + backdoor_idxs = self._get_malicious_client_idx( + len(raw_client_grad_list)) (num0, averaged_params) = raw_client_grad_list[0] # fake grad @@ -64,54 +85,105 @@ def attack_model(self, raw_client_grad_list: List[Tuple[float, OrderedDict]], for i in backdoor_idxs: (_, param) = raw_client_grad_list[i] # grad = np.concatenate([param.grad.data.cpu().numpy().flatten() for param in model.parameters()]) // for real net - grad = np.concatenate([param[p_name].numpy().flatten() * 0.5 for p_name in param]) + grad = np.concatenate( + [param[p_name].numpy().flatten() * 0.5 for p_name in param]) grads.append(grad) grads_mean = np.mean(grads, axis=0) grads_stdev = np.var(grads, axis=0) ** 0.5 learning_rate = 0.1 - original_params_flat = np.concatenate([averaged_params[p_name].numpy().flatten() for p_name in averaged_params]) + original_params_flat = np.concatenate( + [averaged_params[p_name].numpy().flatten() for p_name in averaged_params]) initial_params_flat = ( original_params_flat - learning_rate * grads_mean ) # the corrected param after the user optimized, because we still want the model to improve - mal_net_params = self.train_malicious_network(initial_params_flat, original_params_flat) + mal_net_params = self.train_malicious_network( + initial_params_flat, original_params_flat) # Getting from the final required mal_net_params to the gradients that needs to be applied on the parameters of the previous round. new_params = mal_net_params + learning_rate * grads_mean new_grads = (initial_params_flat - new_params) / learning_rate # authors in the paper claims to limit the range of parameters but the code limits the gradient. new_user_grads = np.clip( - new_grads, grads_mean - self.num_std * grads_stdev, grads_mean + self.num_std * grads_stdev, + new_grads, grads_mean - self.num_std * + grads_stdev, grads_mean + self.num_std * grads_stdev, ) # the returned gradient controls the local update for malicious clients return new_user_grads @staticmethod def add_pattern(img): + """ + Add a pattern to an image (currently disabled). + + Args: + img (Tensor): Input image. + + Returns: + Tensor: Image with added pattern (disabled). + """ # disable img[:, :5, :5] = 2.8 return img def train_malicious_network(self, initial_params_flat, param): + """ + Train a malicious network (currently skipped). + + Args: + initial_params_flat (np.ndarray): Initial flattened model parameters. + param (np.ndarray): Original model parameters. + + Returns: + np.ndarray: Flattened malicious model parameters. + """ # skip training process # return flatten_params(param) return param def _get_malicious_client_idx(self, client_num): + """ + Get indices of malicious clients. + + Args: + client_num (int): Total number of clients. + + Returns: + List[int]: List of indices of malicious clients. + """ return random.sample(range(client_num), self.backdoor_client_num) def flatten_params(params): + """ + Flatten model parameters. + + Args: + params (Iterable[Tensor]): Model parameters. + + Returns: + np.ndarray: Flattened parameters as a NumPy array. + """ # for real net return np.concatenate([i.data.cpu().numpy().flatten() for i in params]) def row_into_parameters(row, parameters): + """ + Map a flattened row of parameters to the original model parameters. + + Args: + row (np.ndarray): Flattened row of parameters. + parameters (Iterable[Tensor]): Model parameters to map to. 
+ + Returns: + None + """ # for real net offset = 0 for param in parameters: new_size = functools.reduce(lambda x, y: x * y, param.shape) - current_data = row[offset : offset + new_size] + current_data = row[offset: offset + new_size] param.data[:] = torch.from_numpy(current_data.reshape(param.shape)) offset += new_size diff --git a/python/fedml/core/security/attack/byzantine_attack.py b/python/fedml/core/security/attack/byzantine_attack.py index c4f6b63257..7dfeead988 100644 --- a/python/fedml/core/security/attack/byzantine_attack.py +++ b/python/fedml/core/security/attack/byzantine_attack.py @@ -13,13 +13,30 @@ class ByzantineAttack(BaseAttackMethod): + def __init__(self, args): + """ + Initialize the ByzantineAttack. + + Args: + args (Namespace): Command-line arguments containing attack configuration. + """ self.byzantine_client_num = args.byzantine_client_num self.attack_mode = args.attack_mode # random: randomly generate a weight; zero: set the weight to 0 self.device = fedml.device.get_device(args) def attack_model(self, raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None): + """ + Attack the model using Byzantine clients. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): List of client gradients. + extra_auxiliary_info (Any): Extra auxiliary information (global model). + + Returns: + List[Tuple[float, OrderedDict]]: List of modified client gradients. + """ if len(raw_client_grad_list) < self.byzantine_client_num: self.byzantine_client_num = len(raw_client_grad_list) byzantine_idxs = sample_some_clients(len(raw_client_grad_list), self.byzantine_client_num) @@ -35,6 +52,16 @@ def attack_model(self, raw_client_grad_list: List[Tuple[float, OrderedDict]], return byzantine_local_w def _attack_zero_mode(self, model_list, byzantine_idxs): + """ + Perform zero-value Byzantine attack on the model gradients. + + Args: + model_list (List[Tuple[float, OrderedDict]]): List of client gradients. + byzantine_idxs (List[int]): Indices of Byzantine clients. + + Returns: + List[Tuple[float, OrderedDict]]: List of modified client gradients. + """ new_model_list = [] for i in range(0, len(model_list)): if i not in byzantine_idxs: @@ -48,6 +75,16 @@ def _attack_zero_mode(self, model_list, byzantine_idxs): return new_model_list def _attack_random_mode(self, model_list, byzantine_idxs): + """ + Perform random Byzantine attack on the model gradients. + + Args: + model_list (List[Tuple[float, OrderedDict]]): List of client gradients. + byzantine_idxs (List[int]): Indices of Byzantine clients. + + Returns: + List[Tuple[float, OrderedDict]]: List of modified client gradients. + """ new_model_list = [] for i in range(0, len(model_list)): diff --git a/python/fedml/core/security/attack/dlg_attack.py b/python/fedml/core/security/attack/dlg_attack.py index f9424625a3..176cd4cb7b 100644 --- a/python/fedml/core/security/attack/dlg_attack.py +++ b/python/fedml/core/security/attack/dlg_attack.py @@ -25,10 +25,15 @@ class DLGAttack(BaseAttackMethod): def __init__(self, args): + """ + Initialize the DLGAttack. + + Args: + args (Namespace): Command-line arguments containing attack configuration. 
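+
+        Note (a sketch of the branch below): only a handful of datasets set
+        the dummy-data shape here (e.g., cifar10/cifar100 use (1, 3, 32, 32));
+        unsupported datasets raise an exception.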
+ """ self.model = None self.model_type = args.model - if args.dataset in ["cifar10", "cifar100"]: self.original_data_size = torch.Size([1, 3, 32, 32]) if args.dataset == "cifar10": @@ -39,7 +44,8 @@ def __init__(self, args): self.original_data_size = torch.Size([1, 28, 28]) self.original_label_size = torch.Size([1, 10]) else: - raise Exception(f"do not support this dataset for DLG attack: {args.dataset}") + raise Exception( + f"do not support this dataset for DLG attack: {args.dataset}") self.criterion = cross_entropy_for_onehot # cifar 100: # original data size = torch.Size([1, 3, 32, 32]) @@ -60,6 +66,13 @@ def __init__(self, args): # attack the last iteration, as it contains more information def get_model(self): + """ + Get the model based on the specified model type. + + Returns: + torch.nn.Module: The model instance. + """ + if self.model_type == "LeNet": return LeNet() elif self.model_type == "resnet56": @@ -70,13 +83,36 @@ def get_model(self): raise Exception(f"do not support this model: {self.model_type}") def reconstruct_data(self, raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None): + """ + Reconstruct the data using the provided client gradients. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): List of client gradients. + extra_auxiliary_info (Any): Extra auxiliary information (global model of last round). + + Note: + This method performs data reconstruction based on specified conditions. + + """ if self.iteration_num in self.attack_iteration_idxs: for (_, local_model) in raw_client_grad_list: print(f"-----------attack---------------") - self.reconstruct_data_using_a_model(a_model=local_model, extra_auxiliary_info=extra_auxiliary_info) + self.reconstruct_data_using_a_model( + a_model=local_model, extra_auxiliary_info=extra_auxiliary_info) self.iteration_num += 1 def reconstruct_data_using_a_model(self, a_model: OrderedDict, extra_auxiliary_info: Any = None): + """ + Reconstruct data using a specific model and auxiliary information. + + Args: + a_model (OrderedDict): Client model parameters. + extra_auxiliary_info (Any): Extra auxiliary information (global model of last round). + + Returns: + torch.Tensor: Reconstructed data. + torch.Tensor: Reconstructed labels. 
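+
+        Note (sketch of the recovery loop below): the per-layer gradient is
+        approximated as a_model[k] - global_model_of_last_round[k] for
+        weight/bias entries, with protected layers zeroed out, and dummy data
+        is then optimized to reproduce it.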
+ """ self.model = self.get_model() global_model_of_last_round = extra_auxiliary_info gradient = [] @@ -85,10 +121,13 @@ def reconstruct_data_using_a_model(self, a_model: OrderedDict, extra_auxiliary_i for k, _ in global_model_of_last_round.items(): if "weight" in k or "bias" in k: if self.protected_layers is not None and layer_counter in self.protected_layers: - gradient.append(torch.from_numpy(np.zeros(global_model_of_last_round[k].size())).float()) + gradient.append(torch.from_numpy( + np.zeros(global_model_of_last_round[k].size())).float()) # if the layer is protected, set to 0 else: - gradient.append(a_model[k] - global_model_of_last_round[k].to(self.device)) # !!!!!!!!!!!!!!!!!!todo: to double check + # !!!!!!!!!!!!!!!!!!todo: to double check + gradient.append( + a_model[k] - global_model_of_last_round[k].to(self.device)) layer_counter += 1 gradient = tuple(gradient) dummy_data = torch.randn(self.original_data_size) @@ -101,8 +140,10 @@ def reconstruct_data_using_a_model(self, a_model: OrderedDict, extra_auxiliary_i def closure(): optimizer.zero_grad() dummy_pred = self.model(dummy_data) - dummy_loss = self.criterion(dummy_pred, F.softmax(dummy_label, dim=-1)) - dummy_grad = torch.autograd.grad(dummy_loss, self.model.parameters(), create_graph=True) + dummy_loss = self.criterion( + dummy_pred, F.softmax(dummy_label, dim=-1)) + dummy_grad = torch.autograd.grad( + dummy_loss, self.model.parameters(), create_graph=True) dummy_grad = tuple( g.to(self.device) for g in dummy_grad) # extract tensor from tuple and move to device diff --git a/python/fedml/core/security/attack/edge_case_backdoor_attack.py b/python/fedml/core/security/attack/edge_case_backdoor_attack.py index 2e2a0d64ea..4b016fcda2 100644 --- a/python/fedml/core/security/attack/edge_case_backdoor_attack.py +++ b/python/fedml/core/security/attack/edge_case_backdoor_attack.py @@ -25,6 +25,16 @@ def __init__( backdoor_dataset, batch_size, ): + """ + Initialize the EdgeCaseBackdoorAttack. + + Args: + client_num (int): Total number of clients in the system. + poisoned_client_num (int): Number of clients to poison with backdoor samples. + backdoor_sample_percentage (float): Percentage of backdoor samples to insert. + backdoor_dataset (Dataset): Backdoor dataset containing poisoned samples. + batch_size (int): Batch size for data loaders. + """ self.client_num = client_num self.attack_epoch = 0 self.poisoned_client_num = poisoned_client_num @@ -34,6 +44,16 @@ def __init__( self.batch_size = batch_size def poison_data(self, dataset): + """ + Poison the training data of selected clients with backdoor samples. + + Args: + dataset (list): List containing various data related to clients and the dataset. + + Returns: + list: List of data loaders for each client, including backdoored clients. + """ + [ train_data_num, test_data_num, diff --git a/python/fedml/core/security/attack/invert_gradient_attack.py b/python/fedml/core/security/attack/invert_gradient_attack.py index 0f65fd2871..a11e0c80c7 100644 --- a/python/fedml/core/security/attack/invert_gradient_attack.py +++ b/python/fedml/core/security/attack/invert_gradient_attack.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass import logging import math import time @@ -37,6 +38,16 @@ class InvertAttack(BaseAttackMethod): def __init__( self, attack_client_idx=0, trained_model=False, model=None, num_images=1, use_updates=False, ): + """ + Initialize the Invert Attack. + + Args: + attack_client_idx (int): Index of the target client for the attack. 
+ trained_model (bool): Whether the model is already trained. + model: The model used for the attack. + num_images (int): Number of images to use for the attack. + use_updates (bool): Whether to use model updates for the attack. + """ defs = ConservativeStrategy() loss_fn = Classification() self.use_updates = use_updates @@ -48,15 +59,27 @@ def __init__( self.num_images = num_images # = batch_size in local training def reconstruct_data(self, a_gradient: dict, extra_auxiliary_info: Any = None): + """ + Reconstruct the data after the attack. + + Args: + a_gradient (dict): Gradient information. + extra_auxiliary_info: Additional auxiliary information. + + Returns: + tuple: A tuple containing the reconstructed data and statistics. + """ self.ground_truth = extra_auxiliary_info[0][0] self.labels = extra_auxiliary_info[0][1] if not self.use_updates: rec_machine = GradientReconstructor( - self.model, (self.dm, self.ds), config=extra_auxiliary_info[1], num_images=self.num_images, + self.model, (self.dm, + self.ds), config=extra_auxiliary_info[1], num_images=self.num_images, ) self.input_gradient = a_gradient - output, stats = rec_machine.reconstruct(self.input_gradient, self.labels, self.img_shape) + output, stats = rec_machine.reconstruct( + self.input_gradient, self.labels, self.img_shape) else: rec_machine = FedAvgReconstructor( self.model, @@ -67,10 +90,12 @@ def reconstruct_data(self, a_gradient: dict, extra_auxiliary_info: Any = None): use_updates=self.use_updates, ) self.input_parameters = a_gradient - output, stats = rec_machine.reconstruct(self.input_parameters, self.labels, self.img_shape) + output, stats = rec_machine.reconstruct( + self.input_parameters, self.labels, self.img_shape) test_mse = (output.detach() - self.ground_truth).pow(2).mean() - feat_mse = (self.model(output.detach()) - self.model(self.ground_truth)).pow(2).mean() + feat_mse = (self.model(output.detach()) - + self.model(self.ground_truth)).pow(2).mean() test_psnr = psnr(output, self.ground_truth, factor=1 / self.ds) logging.info( f"Rec. loss: {stats['opt']:2.4f} | MSE: {test_mse:2.4f} | PSNR: {test_psnr:4.2f} | FMSE: {feat_mse:2.4e} |" @@ -83,8 +108,6 @@ def reconstruct_data(self, a_gradient: dict, extra_auxiliary_info: Any = None): """Optimization setups.""" -from dataclasses import dataclass - @dataclass # class ConservativeStrategy(Strategy): @@ -108,8 +131,10 @@ def __init__(self, lr=None, epochs=None, dryrun=False): class Loss: """Abstract class, containing necessary methods. - Abstract class to collect information about the 'higher-level' loss function, used to train an energy-based model - containing the evaluation of the loss function, its gradients w.r.t. to first and second argument and evaluations + + Abstract class to collect information about the 'higher-level' loss function, + used to train an energy-based model containing the evaluation of the loss + function, its gradients w.r.t. to first and second argument and evaluations of the actual metric that is targeted. """ @@ -181,13 +206,34 @@ def metric(self, x=None, y=None): def _label_to_onehot(target, num_classes=100): + """Convert class labels to one-hot encoded tensors. + + Args: + target (torch.Tensor): Class labels. + num_classes (int): Number of classes. + + Returns: + torch.Tensor: One-hot encoded tensor with shape (target.size(0), num_classes). 
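+
+    Example (illustrative):
+        >>> _label_to_onehot(torch.tensor([2]), num_classes=4)
+        tensor([[0., 0., 1., 0.]])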
+ """ target = torch.unsqueeze(target, 1) - onehot_target = torch.zeros(target.size(0), num_classes, device=target.device) + onehot_target = torch.zeros(target.size( + 0), num_classes, device=target.device) onehot_target.scatter_(1, target, 1) return onehot_target def _validate_config(config): + """Validate and fill in missing configuration values with defaults. + + Args: + config (dict): Configuration dictionary. + + Returns: + dict: Validated configuration dictionary with missing keys filled in with defaults. + + Raises: + ValueError: If deprecated keys are found in the configuration. + """ for key in DEFAULT_CONFIG.keys(): if config.get(key) is None: config[key] = DEFAULT_CONFIG[key] @@ -198,13 +244,32 @@ def _validate_config(config): class GradientReconstructor: - """Instantiate a reconstruction algorithm.""" + """ + Instantiate a reconstruction algorithm for gradients. + + Args: + model: The PyTorch model used for the reconstruction. + mean_std: Tuple of mean and standard deviation used for normalization. + config: Configuration dictionary for algorithm setup. + num_images: Number of images to use for reconstruction. + + Attributes: + config (dict): Algorithm configuration parameters. + model: The PyTorch model used for reconstruction. + setup (dict): Device and data type setup for the model. + mean_std (tuple): Mean and standard deviation used for normalization. + num_images (int): Number of images to use for reconstruction. + inception (InceptionScore): Inception score calculator (optional). + loss_fn (torch.nn.Module): Loss function used for reconstruction. + iDLG (bool): Flag indicating whether to use the iDLG trick. + """ def __init__(self, model, mean_std=(0.0, 1.0), config=DEFAULT_CONFIG, num_images=1): """Initialize with algorithm setup.""" self.config = _validate_config(config) self.model = model - self.setup = dict(device=next(model.parameters()).device, dtype=next(model.parameters()).dtype) + self.setup = dict(device=next(model.parameters()).device, + dtype=next(model.parameters()).dtype) self.mean_std = mean_std self.num_images = num_images @@ -218,7 +283,20 @@ def __init__(self, model, mean_std=(0.0, 1.0), config=DEFAULT_CONFIG, num_images def reconstruct( self, input_data, labels, img_shape=(3, 32, 32), dryrun=False, eval=True, tol=None, ): - """Reconstruct image from gradient.""" + """ + Reconstruct images from gradients. + + Args: + input_data (torch.Tensor): Input gradient data. + labels (torch.Tensor): Labels associated with the input data. + img_shape (tuple): Image shape (channels, height, width). + dryrun (bool): Whether to perform a dry run. + eval (bool): Whether to set the model to evaluation mode. + tol (float): Tolerance threshold for reconstruction. + + Returns: + tuple: A tuple containing the reconstructed data and statistics. 
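+
+        Example (shape sketch; variables are illustrative):
+            output, stats = rec_machine.reconstruct(input_gradient, labels,
+                                                    img_shape=(3, 32, 32))
+            # output: (num_images, 3, 32, 32); stats["opt"] holds the best
+            # trial's reconstruction score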
+ """ start_time = time.time() if eval: self.model.eval() @@ -230,7 +308,8 @@ def reconstruct( if labels is None: if self.num_images == 1 and self.iDLG: # iDLG trick: - last_weight_min = torch.argmin(torch.sum(input_data[-2], dim=-1), dim=-1) + last_weight_min = torch.argmin( + torch.sum(input_data[-2], dim=-1), dim=-1) labels = last_weight_min.detach().reshape((1,)).requires_grad_(False) self.reconstruct_label = False else: @@ -249,7 +328,8 @@ def loss_fn(pred, labels): try: for trial in range(self.config["restarts"]): - x_trial, labels = self._run_trial(x[trial], input_data, labels, dryrun=dryrun) + x_trial, labels = self._run_trial( + x[trial], input_data, labels, dryrun=dryrun) # Finalize scores[trial] = self._score_trial(x_trial, input_data, labels) x[trial] = x_trial @@ -263,7 +343,8 @@ def loss_fn(pred, labels): # Choose optimal result: print("Choosing optimal result ...") - scores = scores[torch.isfinite(scores)] # guard against NaN/-Inf scores? + # guard against NaN/-Inf scores? + scores = scores[torch.isfinite(scores)] optimal_index = torch.argmin(scores) print(f"Optimal result score: {scores[optimal_index]:2.4f}") stats["opt"] = scores[optimal_index].item() @@ -273,26 +354,50 @@ def loss_fn(pred, labels): return x_optimal.detach(), stats def _init_images(self, img_shape): + """ + Initialize images for reconstruction. + + Args: + img_shape (tuple): Image shape (channels, height, width). + + Returns: + torch.Tensor: Initialized image data. + """ if self.config["init"] == "randn": return torch.randn((self.config["restarts"], self.num_images, *img_shape)) else: raise ValueError() def _run_trial(self, x_trial, input_data, labels, dryrun=False): + """ + Run a reconstruction trial. + + Args: + x_trial (torch.Tensor): Image data for the trial. + input_data (torch.Tensor): Input gradient data. + labels (torch.Tensor): Labels associated with the input data. + dryrun (bool): Whether to perform a dry run. + + Returns: + tuple: A tuple containing the reconstructed image data and labels. 
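+
+        Note: the optimizer comes from self.config["optim"] ("adam", "sgd",
+        or "LBFGS"), and with lr_decay enabled the learning rate drops by 10x
+        at roughly 3/8, 5/8, and 7/8 of max_iterations (see the milestones
+        below).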
+ """ x_trial.requires_grad = True if self.reconstruct_label: output_test = self.model(x_trial) - labels = torch.randn(output_test.shape[1]).to(**self.setup).requires_grad_(True) + labels = torch.randn(output_test.shape[1]).to( + **self.setup).requires_grad_(True) if self.config["optim"] == "adam": - optimizer = torch.optim.Adam([x_trial, labels], lr=self.config["lr"]) + optimizer = torch.optim.Adam( + [x_trial, labels], lr=self.config["lr"]) else: raise ValueError() else: if self.config["optim"] == "adam": optimizer = torch.optim.Adam([x_trial], lr=self.config["lr"]) elif self.config["optim"] == "sgd": # actually gd - optimizer = torch.optim.SGD([x_trial], lr=0.01, momentum=0.9, nesterov=True) + optimizer = torch.optim.SGD( + [x_trial], lr=0.01, momentum=0.9, nesterov=True) elif self.config["optim"] == "LBFGS": optimizer = torch.optim.LBFGS([x_trial]) else: @@ -303,12 +408,14 @@ def _run_trial(self, x_trial, input_data, labels, dryrun=False): if self.config["lr_decay"]: scheduler = torch.optim.lr_scheduler.MultiStepLR( optimizer, - milestones=[max_iterations // 2.667, max_iterations // 1.6, max_iterations // 1.142,], + milestones=[max_iterations // 2.667, + max_iterations // 1.6, max_iterations // 1.142,], gamma=0.1, ) # 3/8 5/8 7/8 try: for iteration in range(max_iterations): - closure = self._gradient_closure(optimizer, x_trial, input_data, labels) + closure = self._gradient_closure( + optimizer, x_trial, input_data, labels) rec_loss = optimizer.step(closure) if self.config["lr_decay"]: scheduler.step() @@ -316,16 +423,19 @@ def _run_trial(self, x_trial, input_data, labels, dryrun=False): with torch.no_grad(): # Project into image space if self.config["boxed"]: - x_trial.data = torch.max(torch.min(x_trial, (1 - dm) / ds), -dm / ds) + x_trial.data = torch.max( + torch.min(x_trial, (1 - dm) / ds), -dm / ds) if (iteration + 1 == max_iterations) or iteration % 500 == 0: - print(f"It: {iteration}. Rec. loss: {rec_loss.item():2.4f}.") + print( + f"It: {iteration}. Rec. loss: {rec_loss.item():2.4f}.") if (iteration + 1) % 500 == 0: if self.config["filter"] == "none": pass elif self.config["filter"] == "median": - x_trial.data = MedianPool2d(kernel_size=3, stride=1, padding=1, same=False)(x_trial) + x_trial.data = MedianPool2d( + kernel_size=3, stride=1, padding=1, same=False)(x_trial) else: raise ValueError() @@ -337,11 +447,24 @@ def _run_trial(self, x_trial, input_data, labels, dryrun=False): return x_trial.detach(), labels def _gradient_closure(self, optimizer, x_trial, input_gradient, label): + """ + Create a closure for gradient computation. + + Args: + optimizer: The optimizer used for reconstruction. + x_trial (torch.Tensor): Image data for the trial. + input_gradient (torch.Tensor): Input gradient data. + label (torch.Tensor): Labels associated with the input data. + + Returns: + function: A closure for gradient computation. 
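+
+        Sketch of what the closure does (see the body below): it recomputes
+        the trial gradient with create_graph=True, scores it against
+        input_gradient via reconstruction_costs, optionally adds a
+        total-variation penalty, and reduces the image gradient to its sign
+        when config["signed"] is set.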
+ """ def closure(): optimizer.zero_grad() self.model.zero_grad() loss = self.loss_fn(self.model(x_trial), label) - gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=True) + gradient = torch.autograd.grad( + loss, self.model.parameters(), create_graph=True) rec_loss = reconstruction_costs( [gradient], input_gradient, @@ -351,7 +474,8 @@ def closure(): ) if self.config["total_variation"] > 0: - rec_loss += self.config["total_variation"] * total_variation(x_trial) + rec_loss += self.config["total_variation"] * \ + total_variation(x_trial) rec_loss.backward() if self.config["signed"]: x_trial.grad.sign_() @@ -360,11 +484,23 @@ def closure(): return closure def _score_trial(self, x_trial, input_gradient, label): + """ + Score a reconstruction trial. + + Args: + x_trial (torch.Tensor): Reconstructed image data. + input_gradient (torch.Tensor): Input gradient data. + label (torch.Tensor): Labels associated with the input data. + + Returns: + float: The score for the reconstruction trial. + """ if self.config["scoring_choice"] == "loss": self.model.zero_grad() x_trial.grad = None loss = self.loss_fn(self.model(x_trial), label) - gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=False) + gradient = torch.autograd.grad( + loss, self.model.parameters(), create_graph=False) return reconstruction_costs( [gradient], input_gradient, @@ -377,7 +513,33 @@ def _score_trial(self, x_trial, input_gradient, label): class FedAvgReconstructor(GradientReconstructor): - """Reconstruct an image from weights after n gradient descent steps.""" + """ + Reconstruct an image from model weights after performing gradient descent steps. + + Args: + model: The PyTorch model used for the reconstruction. + mean_std: Tuple of mean and standard deviation used for normalization. + local_steps: Number of local gradient descent steps. + local_lr: Learning rate for local gradient descent. + config: Configuration dictionary for algorithm setup. + num_images: Number of images to use for reconstruction. + use_updates: Flag indicating whether to use weight updates. + batch_size: Batch size for local gradient descent. + + Attributes: + config (dict): Algorithm configuration parameters. + model: The PyTorch model used for reconstruction. + setup (dict): Device and data type setup for the model. + mean_std (tuple): Mean and standard deviation used for normalization. + num_images (int): Number of images to use for reconstruction. + inception (InceptionScore): Inception score calculator (optional). + loss_fn (torch.nn.Module): Loss function used for reconstruction. + iDLG (bool): Flag indicating whether to use the iDLG trick. + local_steps (int): Number of local gradient descent steps. + local_lr (float): Learning rate for local gradient descent. + use_updates (bool): Flag indicating whether to use weight updates. + batch_size (int): Batch size for local gradient descent. + """ def __init__( self, @@ -390,7 +552,19 @@ def __init__( use_updates=True, batch_size=0, ): - """Initialize with model, (mean, std) and config.""" + """ + Initialize the FedAvgReconstructor with the given parameters. + + Args: + model: The PyTorch model used for the reconstruction. + mean_std: Tuple of mean and standard deviation used for normalization. + local_steps: Number of local gradient descent steps. + local_lr: Learning rate for local gradient descent. + config: Configuration dictionary for algorithm setup. + num_images: Number of images to use for reconstruction. 
+ use_updates: Flag indicating whether to use weight updates. + batch_size: Batch size for local gradient descent. + """ super().__init__(model, mean_std, config, num_images) self.local_steps = local_steps self.local_lr = local_lr @@ -398,6 +572,18 @@ def __init__( self.batch_size = batch_size def _gradient_closure(self, optimizer, x_trial, input_parameters, labels): + """ + Closure function for computing gradients during optimization. + + Args: + optimizer (torch.optim.Optimizer): The optimizer used for gradient descent. + x_trial (torch.Tensor): The input image to be optimized. + input_parameters (torch.Tensor): The ground truth model weights. + labels (torch.Tensor): The labels used for reconstruction. + + Returns: + Callable: A closure function for computing gradients and loss. + """ def closure(): optimizer.zero_grad() self.model.zero_grad() @@ -419,7 +605,8 @@ def closure(): weights=self.config["weights"], ) if self.config["total_variation"] > 0: - rec_loss += self.config["total_variation"] * total_variation(x_trial) + rec_loss += self.config["total_variation"] * \ + total_variation(x_trial) rec_loss.backward() if self.config["signed"]: x_trial.grad.sign_() @@ -428,6 +615,17 @@ def closure(): return closure def _score_trial(self, x_trial, input_parameters, labels): + """ + Compute the score of a trial reconstruction. + + Args: + x_trial (torch.Tensor): The reconstructed image. + input_parameters (torch.Tensor): The ground truth model weights. + labels (torch.Tensor): The labels used for reconstruction. + + Returns: + float: The score of the trial reconstruction. + """ if self.config["scoring_choice"] == "loss": self.model.zero_grad() parameters = loss_steps( @@ -451,7 +649,22 @@ def _score_trial(self, x_trial, input_parameters, labels): def loss_steps( model, inputs, labels, loss_fn=torch.nn.CrossEntropyLoss(), lr=1e-4, local_steps=4, use_updates=True, batch_size=0, ): - """Take a few gradient descent steps to fit the model to the given input.""" + """ + Perform gradient descent steps to fit the model to the given input data. + + Args: + model (nn.Module): The neural network model to be optimized. + inputs (torch.Tensor): The input data for optimization. + labels (torch.Tensor): The labels for the input data. + loss_fn (torch.nn.Module, optional): The loss function used for optimization. Default is CrossEntropyLoss. + lr (float, optional): The learning rate for gradient descent. Default is 1e-4. + local_steps (int, optional): The number of gradient descent steps to perform. Default is 4. + use_updates (bool, optional): Whether to use parameter updates during optimization. Default is True. + batch_size (int, optional): Batch size for mini-batch gradient descent. Default is 0 (full batch). + + Returns: + List[torch.Tensor]: A list of model parameter tensors after optimization. 
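+
+    Example (illustrative call):
+        params = loss_steps(model, inputs, labels, lr=1e-3, local_steps=2,
+                            use_updates=True)
+        # with use_updates=True the returned parameters are deltas relative
+        # to the starting weights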
+ """ patched_model = MetaMonkey(model) if use_updates: patched_model_origin = deepcopy(patched_model) @@ -461,8 +674,9 @@ def loss_steps( labels_ = labels else: idx = i % (inputs.shape[0] // batch_size) - outputs = patched_model(inputs[idx * batch_size : (idx + 1) * batch_size], patched_model.parameters,) - labels_ = labels[idx * batch_size : (idx + 1) * batch_size] + outputs = patched_model( + inputs[idx * batch_size: (idx + 1) * batch_size], patched_model.parameters,) + labels_ = labels[idx * batch_size: (idx + 1) * batch_size] loss = loss_fn(outputs, labels_).sum() grad = torch.autograd.grad( loss, patched_model.parameters.values(), retain_graph=True, create_graph=True, only_inputs=True, @@ -483,21 +697,38 @@ def loss_steps( def reconstruction_costs(gradients, input_gradient, cost_fn="l2", indices="def", weights="equal"): - """Input gradient is given data.""" + """ + Calculate reconstruction costs between gradients and input gradient. + + Args: + gradients (List[torch.Tensor]): List of gradients to be compared with the input gradient. + input_gradient (torch.Tensor): The input gradient (data). + cost_fn (str, optional): The reconstruction cost function to use ("l2" or "sim"). Default is "l2". + indices (Union[str, List[int]], optional): The indices of gradients to consider or method to choose them. + Default is "def" (all gradients). + weights (Union[str, List[float]], optional): The weights for each gradient during reconstruction cost calculation. + Default is "equal" (equal weights). + + Returns: + float: The total reconstruction cost averaged over the provided gradients. + """ if isinstance(indices, list): pass elif indices == "def": indices = torch.arange(len(input_gradient)) elif indices == "top10": - _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), 10) + _, indices = torch.topk(torch.stack( + [p.norm() for p in input_gradient], dim=0), 10) else: raise ValueError() ex = input_gradient[0] if weights == "linear": - weights = torch.arange(len(input_gradient), 0, -1, dtype=ex.dtype, device=ex.device) / len(input_gradient) + weights = torch.arange(len( + input_gradient), 0, -1, dtype=ex.dtype, device=ex.device) / len(input_gradient) elif weights == "exp": - weights = torch.arange(len(input_gradient), 0, -1, dtype=ex.dtype, device=ex.device) + weights = torch.arange(len(input_gradient), 0, -1, + dtype=ex.dtype, device=ex.device) weights = weights.softmax(dim=0) weights = weights / weights[0] else: @@ -509,7 +740,8 @@ def reconstruction_costs(gradients, input_gradient, cost_fn="l2", indices="def", costs = 0 for i in indices: if cost_fn == "sim": - costs -= (trial_gradient[i] * input_gradient[i]).sum() * weights[i] + costs -= (trial_gradient[i] * + input_gradient[i]).sum() * weights[i] pnorm[0] += trial_gradient[i].pow(2).sum() * weights[i] pnorm[1] += input_gradient[i].pow(2).sum() * weights[i] if cost_fn == "sim": @@ -529,13 +761,27 @@ class MetaMonkey(torch.nn.Module): """ def __init__(self, net): - """Init with network.""" + """ + Initialize MetaMonkey with a neural network. + + Args: + net (torch.nn.Module): The neural network to be patched. + """ super().__init__() self.net = net self.parameters = OrderedDict(net.named_parameters()) def forward(self, inputs, parameters=None): - """Live Patch ... :> ...""" + """ + Forward pass through the network with optional live patching of modules. + + Args: + inputs (torch.Tensor): The input data. + parameters (OrderedDict, optional): Dictionary of parameters to be used for live patching. 
+
+        Returns:
+            torch.Tensor: The output tensor.
+        """
         # If no parameter dictionary is given, everything is normal
         if parameters is None:
             return self.net(inputs)
@@ -573,7 +819,8 @@ def forward(self, inputs, parameters=None):
                     if module.num_batches_tracked is not None:
                         module.num_batches_tracked += 1
                         if module.momentum is None:  # use cumulative moving average
-                            exponential_average_factor = 1.0 / float(module.num_batches_tracked)
+                            exponential_average_factor = 1.0 / \
+                                float(module.num_batches_tracked)
                         else:  # use exponential moving average
                             exponential_average_factor = module.momentum
@@ -595,7 +842,8 @@ def forward(self, inputs, parameters=None):
                 lin_weights = next(param_gen)
                 lin_bias = next(param_gen)
                 method_pile.append(module.forward)
-                module.forward = partial(F.linear, weight=lin_weights, bias=lin_bias)
+                module.forward = partial(
+                    F.linear, weight=lin_weights, bias=lin_bias)

             elif next(module.parameters(), None) is None:
                 # Pass over modules that do not contain parameters
@@ -605,7 +853,8 @@ def forward(self, inputs, parameters=None):
                 pass
             else:
                 # Warn for other containers
-                warnings.warn(f"Patching for module {module.__class__} is not implemented.")
+                warnings.warn(
+                    f"Patching for module {module.__class__} is not implemented.")

         output = self.net(inputs)
@@ -622,13 +871,15 @@ def forward(self, inputs, parameters=None):


 class MedianPool2d(nn.Module):
-    """Median pool (usable as median filter when stride=1) module.
-    Args:
-        kernel_size: size of pooling kernel, int or 2-tuple
-        stride: pool stride, int or 2-tuple
-        padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
-        same: override padding and enforce same padding, boolean
     """
+    Median pooling module (usable as a median filter when stride=1).
+
+    Args:
+        kernel_size: Size of the pooling kernel, can be an integer or a 2-tuple.
+        stride: Pooling stride, can be an integer or a 2-tuple.
+        padding: Pooling padding, can be an integer or a 4-tuple (left, right, top, bottom).
+        same: If True, override padding and enforce "same" padding. If False, use the specified padding.
+    """

     def __init__(self, kernel_size=3, stride=1, padding=0, same=True):
         """Initialize with kernel_size, stride, padding."""
@@ -639,6 +890,15 @@ def __init__(self, kernel_size=3, stride=1, padding=0, same=True):
         self.same = same

     def _padding(self, x):
+        """
+        Calculate the padding required based on the 'same' attribute and input size.
+
+        Args:
+            x: Input tensor.
+
+        Returns:
+            Tuple (pl, pr, pt, pb): Padding values for left, right, top, and bottom.
+        """
         if self.same:
             ih, iw = x.size()[2:]
             if ih % self.stride[0] == 0:
@@ -659,44 +919,88 @@ def _padding(self, x):
         return padding

     def forward(self, x):
+        """
+        Perform median pooling on the input tensor.
+
+        Args:
+            x: Input tensor.
+
+        Returns:
+            torch.Tensor: Output tensor after median pooling.
+        """
         # using existing pytorch functions and tensor ops so that we get autograd,
         # would likely be more efficient to implement from scratch at C/Cuda level
         x = F.pad(x, self._padding(x), mode="reflect")
-        x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
+        x = x.unfold(2, self.k[0], self.stride[0]).unfold(
+            3, self.k[1], self.stride[1])
         x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]
         return x


 class InceptionScore(torch.nn.Module):
-    """Class that manages and returns the inception score of images."""
+    """Class that manages and returns the inception score of images.
+
+    Args:
+        batch_size (int): Batch size for calculating the Inception Score.
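
The unfold-based median trick in MedianPool2d.forward can be seen in isolation below; with reflect padding of 1 and stride 1, the output keeps the input's spatial size:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 3, 8, 8)
    k, stride = 3, 1
    xp = F.pad(x, (1, 1, 1, 1), mode="reflect")             # "same" padding for k=3
    patches = xp.unfold(2, k, stride).unfold(3, k, stride)  # (1, 3, 8, 8, 3, 3)
    out = patches.contiguous().view(*patches.shape[:4], -1).median(dim=-1)[0]
    print(out.shape)  # torch.Size([1, 3, 8, 8])
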
+        setup (dict): A dictionary containing device and dtype setup for the model.
+
+    Attributes:
+        preprocessing (torch.nn.Module): Preprocessing module to resize images to (299, 299).
+        model (torch.nn.Module): Inception V3 model used for scoring.
+        batch_size (int): Batch size for scoring.
+
+    Note:
+        The input image batch should have dimensions BCHW and should be normalized.
+        B should be divisible by self.batch_size.
+    """

     def __init__(self, batch_size=32, setup=dict(device=torch.device("cpu"), dtype=torch.float)):
         """Initialize with setup and target inception batch size."""
         super().__init__()
-        self.preprocessing = torch.nn.Upsample(size=(299, 299), mode="bilinear", align_corners=False)
-        self.model = torchvision.models.inception_v3(pretrained=True).to(**setup)
+        self.preprocessing = torch.nn.Upsample(
+            size=(299, 299), mode="bilinear", align_corners=False)
+        self.model = torchvision.models.inception_v3(
+            pretrained=True).to(**setup)
         self.model.eval()
         self.batch_size = batch_size

     def forward(self, image_batch):
-        """Image batch should have dimensions BCHW and should be normalized.
-        B should be divisible by self.batch_size.
+        """Calculate the Inception Score for an image batch.
+
+        Args:
+            image_batch (torch.Tensor): Input image batch with dimensions BCHW.
+
+        Returns:
+            torch.Tensor: Inception Score.
         """
         B, C, H, W = image_batch.shape
         batches = B // self.batch_size
         scores = []
         for batch in range(batches):
-            input = self.preprocessing(image_batch[batch * self.batch_size : (batch + 1) * self.batch_size])
+            input = self.preprocessing(
+                image_batch[batch * self.batch_size: (batch + 1) * self.batch_size])
             scores.append(self.model(input))  # pylint: disable=E1102
         prob_yx = torch.nn.functional.softmax(torch.cat(scores, 0), dim=1)
-        entropy = torch.where(prob_yx > 0, -prob_yx * prob_yx.log(), torch.zeros_like(prob_yx))
+        entropy = torch.where(prob_yx > 0, -prob_yx *
+                              prob_yx.log(), torch.zeros_like(prob_yx))
         return entropy.sum()


 def psnr(img_batch, ref_batch, batched=False, factor=1.0):
-    """Standard PSNR."""
+    """Calculate the Peak Signal-to-Noise Ratio (PSNR) between two image batches.
+
+    Args:
+        img_batch (torch.Tensor): Input image batch.
+        ref_batch (torch.Tensor): Reference image batch.
+        batched (bool): If True, compute PSNR for the entire batch. If False, compute individual PSNRs.
+        factor (float): Scaling factor for PSNR computation.
+
+    Returns:
+        float or torch.Tensor: PSNR value(s).
+    """
     def get_psnr(img_in, img_ref):
         mse = ((img_in - img_ref) ** 2).mean()
         if mse > 0 and torch.isfinite(mse):
             return 10 * torch.log10(factor ** 2 / mse)
@@ -711,14 +1015,22 @@ def get_psnr(img_in, img_ref):
         [B, C, m, n] = img_batch.shape
         psnrs = []
         for sample in range(B):
-            psnrs.append(get_psnr(img_batch.detach()[sample, :, :, :], ref_batch[sample, :, :, :]))
+            psnrs.append(get_psnr(img_batch.detach()[
+                         sample, :, :, :], ref_batch[sample, :, :, :]))
         psnr = torch.stack(psnrs, dim=0).mean()
     return psnr.item()


 def total_variation(x):
-    """Anisotropic TV."""
+    """Calculate the Anisotropic Total Variation (TV) of an image.
+
+    Args:
+        x (torch.Tensor): Input image.
+
+    Returns:
+        torch.Tensor: Total Variation value.
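
As a quick sanity check of the PSNR definition used here: for images scaled to [0, 1] (factor=1.0) the score is 10 * log10(factor**2 / MSE). The noise level below is arbitrary:

    import torch

    ref = torch.rand(1, 3, 32, 32)
    noisy = (ref + 0.05 * torch.randn_like(ref)).clamp(0, 1)
    mse = ((noisy - ref) ** 2).mean()
    print(10 * torch.log10(1.0 ** 2 / mse))  # roughly 26 dB for this noise level
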
+ """ dx = torch.mean(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])) dy = torch.mean(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])) return dx + dy diff --git a/python/fedml/core/security/attack/label_flipping_attack.py b/python/fedml/core/security/attack/label_flipping_attack.py index 690f0f0748..d604da29ca 100644 --- a/python/fedml/core/security/attack/label_flipping_attack.py +++ b/python/fedml/core/security/attack/label_flipping_attack.py @@ -19,6 +19,12 @@ class LabelFlippingAttack(BaseAttackMethod): def __init__(self, args): + """ + Initialize the Label Flipping Attack. + + Args: + args: An object containing attack configuration parameters. + """ self.original_class_list = args.original_class_list self.target_class_list = args.target_class_list self.batch_size = args.batch_size @@ -43,9 +49,21 @@ def __init__(self, args): self.counter = 0 def get_ite_num(self): + """ + Get the current iteration number. + + Returns: + int: The current iteration number. + """ return math.floor(self.counter / self.client_num_per_round) # ite num starts from 0 def is_to_poison_data(self): + """ + Check if data poisoning should be performed for the current iteration. + + Returns: + bool: True if data poisoning should be performed, False otherwise. + """ self.counter += 1 if self.get_ite_num() < self.poison_start_round_id or self.get_ite_num() > self.poison_end_round_id: return False @@ -55,11 +73,26 @@ def is_to_poison_data(self): return rand < self.ratio_of_poisoned_client def print_dataset(self, dataset): + """ + Print information about the given dataset. + + Args: + dataset: The dataset to print information about. + """ print("---------------print dataset------------") for batch_idx, (data, target) in enumerate(dataset): print(f"{batch_idx} ----- {target}") def poison_data(self, local_dataset): + """ + Poison the local dataset by flipping labels. + + Args: + local_dataset: The local dataset to poison. + + Returns: + DataLoader: The poisoned data loader. + """ get_client_data_stat(local_dataset) # print("=======================1 end ") # self.print_dataset(local_dataset) @@ -83,7 +116,7 @@ def poison_data(self, local_dataset): total_counter += item[1] # print(f"total counter = {total_counter}") - ####################### below are correct ###############################3 + # below are correct ###############################3 tmp_y = replace_original_class_with_target_class( data_labels=tmp_local_dataset_y, @@ -94,4 +127,4 @@ def poison_data(self, local_dataset): poisoned_data = DataLoader(dataset, batch_size=self.batch_size) get_client_data_stat(poisoned_data) - return poisoned_data \ No newline at end of file + return poisoned_data diff --git a/python/fedml/core/security/attack/lazy_worker.py b/python/fedml/core/security/attack/lazy_worker.py index 6aa32a4fce..9a80b5caa4 100644 --- a/python/fedml/core/security/attack/lazy_worker.py +++ b/python/fedml/core/security/attack/lazy_worker.py @@ -12,6 +12,12 @@ class LazyWorkerAttack(BaseAttackMethod): def __init__(self, config): + """ + Initialize the Lazy Worker Attack. + + Args: + config: An object containing attack configuration parameters. + """ self.lazy_worker_num = config.lazy_worker_num self.attack_mode = ( config.attack_mode @@ -39,6 +45,16 @@ def attack_model( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Perform the Lazy Worker Attack on the global model. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): List of client gradients. 
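
A self-contained sketch of the label-flipping transformation that poison_data applies (toy data; the real class reads its class lists from args). Building each mask against the unmodified labels keeps flips from chaining when a target class is itself in the original list:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def flip_labels(y, original_classes, target_classes):
        flipped = y.clone()
        for orig, tgt in zip(original_classes, target_classes):
            flipped[y == orig] = tgt  # mask built on the unmodified labels
        return flipped

    x = torch.randn(8, 4)
    y = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1])
    poisoned = DataLoader(TensorDataset(x, flip_labels(y, [1], [7])), batch_size=4)
    for _, labels in poisoned:
        print(labels)  # every 1 has become 7
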
+ extra_auxiliary_info (Any): Additional auxiliary information. + + Returns: + List[Tuple[float, OrderedDict]]: Updated list of client gradients with the attack. + """ if self.round == 1: self.client_cache = [grad for (_, grad) in raw_client_grad_list] return raw_client_grad_list @@ -74,10 +90,23 @@ def attack_model( else: # client previous_model_params = self.client_cache[i] previous_model_params = mask_func(previous_model_params) - new_model_list.append((local_sample_number, previous_model_params)) + new_model_list.append( + (local_sample_number, previous_model_params)) return new_model_list def _add_a_mask_on_clients(self, model_list, lazy_worker_idxs, mask_func: Callable): + """ + Perform the Lazy Worker Attack on the client models. + + Args: + model_list (List[Tuple[float, OrderedDict]]): List of client models. + lazy_worker_idxs (List[int]): List of lazy worker indices. + mask_func (Callable): Masking function. + + Returns: + List[Tuple[float, OrderedDict]]: Updated list of client models with the attack. + + """ new_model_list = [] for i in range(0, len(model_list)): if i not in lazy_worker_idxs: @@ -88,10 +117,23 @@ def _add_a_mask_on_clients(self, model_list, lazy_worker_idxs, mask_func: Callab local_sample_number, _ = model_list[i] previous_model_params = self.client_cache[i] previous_model_params = mask_func(previous_model_params) - new_model_list.append((local_sample_number, previous_model_params)) + new_model_list.append( + (local_sample_number, previous_model_params)) return new_model_list def _add_a_mask_on_global(self, model_list, lazy_worker_idxs, mask_func: Callable): + """ + Perform the Lazy Worker Attack on the global model. + + Args: + model_list (List[Tuple[float, OrderedDict]]): List of client models. + lazy_worker_idxs (List[int]): List of lazy worker indices. + mask_func (Callable): Masking function. + + Returns: + List[Tuple[float, OrderedDict]]: Updated list of client models with the attack. + + """ new_model_list = [] for i in range(0, len(model_list)): if i not in lazy_worker_idxs: @@ -100,10 +142,17 @@ def _add_a_mask_on_global(self, model_list, lazy_worker_idxs, mask_func: Callabl local_sample_number, _ = model_list[i] previous_model_params = self.client_cache[i] previous_model_params = mask_func(previous_model_params) - new_model_list.append((local_sample_number, previous_model_params)) + new_model_list.append( + (local_sample_number, previous_model_params)) return new_model_list def random_mask(self, previous_model_params): + """ + Add a random mask in [-1, 1]. + + Args: + previous_model_params (OrderedDict): Previous model parameters. + """ # add a random mask in [-1, 1] for k in previous_model_params.keys(): if is_weight_param(k): @@ -120,6 +169,12 @@ def random_mask(self, previous_model_params): return previous_model_params def gaussian_mask(self, previous_model_params): + """ + Add a gaussian mask. + + Args: + previous_model_params (OrderedDict): Previous model parameters. + """ # add a gaussian mask for k in previous_model_params.keys(): if is_weight_param(k): @@ -131,6 +186,12 @@ def gaussian_mask(self, previous_model_params): return previous_model_params def uniform_mask(self, previous_model): + """ + Randomly generate a uniform mask. + + Args: + previous_model (OrderedDict): Previous model parameters. 
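
A hedged sketch of the Gaussian-mask variant: perturb every weight tensor of the cached model with standard-normal noise while leaving batch-norm statistics untouched, mirroring the is_weight_param filter used elsewhere in this codebase:

    from collections import OrderedDict

    import torch

    def gaussian_mask(params):
        masked = OrderedDict()
        for name, value in params.items():
            if any(s in name for s in ("running_mean", "running_var", "num_batches_tracked")):
                masked[name] = value  # keep BN statistics as-is
            else:
                masked[name] = value + torch.randn_like(value)
        return masked

    params = OrderedDict(torch.nn.Linear(4, 2).state_dict())
    print(list(gaussian_mask(params).keys()))  # ['weight', 'bias']
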
+ """ # randomly generate a uniform mask unif_param = random.uniform(-1, 1) print(f"unif_mode_param = {unif_param}") @@ -147,5 +208,11 @@ def uniform_mask(self, previous_model): return previous_model def no_mask(self, previous_model_params): + """ + Directly return the model in the last round. + + Args: + previous_model_params (OrderedDict): Previous model parameters. + """ # directly return the model in the last round return previous_model_params diff --git a/python/fedml/core/security/attack/model_replacement_backdoor_attack.py b/python/fedml/core/security/attack/model_replacement_backdoor_attack.py index be4495d188..5c84bb3401 100644 --- a/python/fedml/core/security/attack/model_replacement_backdoor_attack.py +++ b/python/fedml/core/security/attack/model_replacement_backdoor_attack.py @@ -24,6 +24,12 @@ class ModelReplacementBackdoorAttack(BaseAttackMethod): def __init__(self, args): + """ + Initialize the Model Replacement Backdoor Attack. + + Args: + args: An object containing attack parameters. + """ if hasattr(args, "malicious_client_id") and isinstance(args.malicious_client_id, int): # assume only 1 malicious client self.malicious_client_id = args.malicious_client_id @@ -46,6 +52,16 @@ def attack_model( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Attack the global model by replacing the model of a selected malicious client. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): List of client gradients. + extra_auxiliary_info (Any): Additional auxiliary information. + + Returns: + List[Tuple[float, OrderedDict]]: Updated list of client gradients with the model replacement attack. + """ participant_num = len(raw_client_grad_list) if self.attack_training_rounds is not None and self.training_round not in self.attack_training_rounds: return raw_client_grad_list @@ -71,6 +87,16 @@ def attack_model( return raw_client_grad_list def compute_gamma(self, global_model, original_client_model): + """ + Compute the scaling factor gamma for model replacement. + + Args: + global_model (OrderedDict): Global model parameters. + original_client_model (OrderedDict): Model parameters of the malicious client. + + Returns: + float: Scaling factor gamma. + """ # total_client_num / η, η: global learning rate; # when η = total_client_num/participant_num, the model is fully replaced by the average of the local models malicious_client_model_vec = vectorize_weight(original_client_model) diff --git a/python/fedml/core/security/attack/revealing_labels_from_gradients_attack.py b/python/fedml/core/security/attack/revealing_labels_from_gradients_attack.py index 4594fc99e1..95caa3ec44 100644 --- a/python/fedml/core/security/attack/revealing_labels_from_gradients_attack.py +++ b/python/fedml/core/security/attack/revealing_labels_from_gradients_attack.py @@ -22,10 +22,27 @@ class RevealingLabelsFromGradientsAttack(BaseAttackMethod): def __init__(self, batch_size, model_type): + """ + Initialize the Revealing Labels from Gradients Attack. + + Args: + batch_size (int): Batch size for the attack. + model_type (str): The type of the target model (e.g., "ResNet50"). + """ self.batch_size = batch_size self.model_type = model_type def reconstruct_data(self, a_gradient: dict, extra_auxiliary_info: Any = None): + """ + Reconstruct data labels using gradients information. + + Args: + a_gradient (dict): A dictionary containing gradients information. + extra_auxiliary_info (Any): Additional auxiliary information (e.g., ground truth labels). 
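
For background on why gradients leak labels at all (the repo's routine goes further and runs an SVD over the gradient matrix): with cross-entropy loss, the gradient of the final-layer bias is the mean of softmax(logits) - onehot(y) over the batch, so classes present in the batch tend to produce negative entries. A toy demonstration, with the caveat that this sign test is only reliable when the logits are near-random:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    model = torch.nn.Linear(16, 5)
    x = torch.randn(4, 16)
    y = torch.tensor([0, 2, 2, 4])

    loss = F.cross_entropy(model(x), y)
    loss.backward()
    inferred = (model.bias.grad < 0).nonzero().flatten()
    print(inferred)  # typically recovers classes {0, 2, 4}
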
+ + Returns: + None + """ vec_local_weight = utils.vectorize_weight(a_gradient) print(vec_local_weight) @@ -37,12 +54,33 @@ def reconstruct_data(self, a_gradient: dict, extra_auxiliary_info: Any = None): return def _attack_on_gradients(self, gt_labels, v): + """ + Attack on gradients to infer labels. + + Args: + gt_labels (set): Ground truth labels. + v: Gradients information. + + Returns: + None + """ grads = np.sign(v) _, pred_labels = self._infer_labels(grads, gt_k=self.batch_size, epsilon=1e-10) print("In gt, not in pr:", [i for i in gt_labels if i not in pred_labels]) print("In pr, not in gt:", [i for i in pred_labels if i not in gt_labels]) def _infer_labels(self, grads, gt_k=None, epsilon=1e-8): + """ + Infer labels from gradients. + + Args: + grads: Gradients information. + gt_k: Number of ground truth labels to consider. + epsilon: A small value to avoid numerical instability. + + Returns: + Tuple[int, list]: Tuple containing the number of predicted labels and the list of inferred labels. + """ m, n = np.shape(grads) B, s, C = np.linalg.svd(grads, full_matrices=False) pred_k = np.linalg.matrix_rank(grads) @@ -91,6 +129,20 @@ def _infer_labels(self, grads, gt_k=None, epsilon=1e-8): @staticmethod def _solve_perceptron(X, y, fit_intercept=True, max_iter=1000, tol=1e-3, eta0=1.0): + """ + Solve the perceptron problem. + + Args: + X: Input data. + y: Target labels. + fit_intercept: Whether to fit an intercept. + max_iter: Maximum number of iterations. + tol: Tolerance for stopping criterion. + eta0: Learning rate. + + Returns: + bool: True if the perceptron problem is successfully solved, False otherwise. + """ from sklearn.linear_model import Perceptron clf = Perceptron( @@ -105,6 +157,17 @@ def _solve_perceptron(X, y, fit_intercept=True, max_iter=1000, tol=1e-3, eta0=1. @staticmethod def solve_lp(grads, b, c): + """ + Solve a linear programming problem. + + Args: + grads: Gradients information. + b: Target vector. + c: Coefficients matrix. + + Returns: + bool: True if the linear programming problem is successfully solved, False otherwise. + """ # from cvxopt import matrix, solvers np.solvers.options["show_progress"] = False diff --git a/python/fedml/core/security/common/attack_defense_data_loader.py b/python/fedml/core/security/common/attack_defense_data_loader.py index c01328b748..31e188617c 100644 --- a/python/fedml/core/security/common/attack_defense_data_loader.py +++ b/python/fedml/core/security/common/attack_defense_data_loader.py @@ -10,6 +10,19 @@ class AttackDefenseDataLoader: def load_cifar10_data( cls, client_num, batch_size, data_dir="../../../../../data/cifar10", partition_method="homo", partition_alpha=None ): + """ + Load CIFAR-10 dataset and partition it among clients. + + Args: + client_num (int): The number of clients to partition the dataset for. + batch_size (int): The batch size for DataLoader objects. + data_dir (str): The directory where the CIFAR-10 dataset is located. + partition_method (str): The method for partitioning the dataset among clients. + partition_alpha (float): The alpha parameter for partitioning (used when partition_method is "hetero"). + + Returns: + dict: A dictionary containing DataLoader objects for each client. + """ return load_partition_data_cifar10( "cifar10", data_dir=data_dir, @@ -24,13 +37,14 @@ def get_data_loader_from_data(cls, batch_size, X, Y, **kwargs): """ Get a data loader created from a given set of data. 
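
Usage sketch for the numpy-to-DataLoader path documented below, assuming integer classification labels (hence the cast to long):

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    X = np.random.randn(100, 3, 32, 32).astype(np.float32)
    Y = np.random.randint(0, 10, size=100)

    dataset = TensorDataset(torch.from_numpy(X).float(), torch.from_numpy(Y).long())
    loader = DataLoader(dataset, batch_size=16, shuffle=True)
    xb, yb = next(iter(loader))
    print(xb.shape, yb.shape)  # torch.Size([16, 3, 32, 32]) torch.Size([16])
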
- :param batch_size: batch size of data loader - :type batch_size: int - :param X: data features - :type X: numpy.Array() - :param Y: data labels - :type Y: numpy.Array() - :return: torch.utils.data.DataLoader + Args: + batch_size (int): Batch size of the DataLoader. + X (numpy.ndarray): Data features. + Y (numpy.ndarray): Data labels. + **kwargs: Additional arguments for DataLoader. + + Returns: + torch.utils.data.DataLoader: DataLoader object for the provided data. """ X_torch = torch.from_numpy(X).float() @@ -48,9 +62,13 @@ def get_data_loader_from_data(cls, batch_size, X, Y, **kwargs): @classmethod def load_data_loader_from_file(cls, filename): """ - Loads DataLoader object from a file if available. + Load a DataLoader object from a file. + + Args: + filename (str): The name of the file containing the DataLoader object. - :param filename: string + Returns: + torch.utils.data.DataLoader: Loaded DataLoader object. """ print("Loading data loader from file: {}".format(filename)) diff --git a/python/fedml/core/security/common/bucket.py b/python/fedml/core/security/common/bucket.py index ac07019aeb..0f400887a3 100644 --- a/python/fedml/core/security/common/bucket.py +++ b/python/fedml/core/security/common/bucket.py @@ -5,6 +5,19 @@ class Bucket: @classmethod def bucketization(cls, client_grad_list, batch_size): + """ + Perform bucketization of client gradients. + + Args: + client_grad_list (list): A list of tuples containing client gradients, where each tuple consists of + the number of samples and a dictionary of gradient values. + batch_size (int): The desired batch size for bucketization. + + Returns: + list: A list of batched client gradients, where each batch is represented as a tuple containing + the total number of samples and a dictionary of batched gradient values. + + """ (num0, averaged_params) = client_grad_list[0] batch_grad_list = [] for batch_idx in range(0, math.ceil(len(client_grad_list) / batch_size)): diff --git a/python/fedml/core/security/common/net.py b/python/fedml/core/security/common/net.py index 4023ede220..5c8f7e554b 100644 --- a/python/fedml/core/security/common/net.py +++ b/python/fedml/core/security/common/net.py @@ -1,6 +1,18 @@ import torch.nn as nn class LeNet(nn.Module): + """ + LeNet-5 is a convolutional neural network architecture that was designed for handwritten and machine-printed character + recognition tasks. This implementation includes four convolutional layers and one fully connected layer. + + Args: + None + + Attributes: + body (nn.Sequential): The convolutional layers of the LeNet model. + fc (nn.Sequential): The fully connected layer of the LeNet model. + + """ def __init__(self): super(LeNet, self).__init__() act = nn.Sigmoid @@ -17,6 +29,16 @@ def __init__(self): self.fc = nn.Sequential(nn.Linear(768, 10)) def forward(self, x): + """ + Forward pass of the LeNet model. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width). + + Returns: + torch.Tensor: Output tensor of shape (batch_size, num_classes). + + """ out = self.body(x) out = out.view(out.size(0), -1) out = self.fc(out) diff --git a/python/fedml/core/security/common/utils.py b/python/fedml/core/security/common/utils.py index 9e12a17fdc..f753be2c35 100644 --- a/python/fedml/core/security/common/utils.py +++ b/python/fedml/core/security/common/utils.py @@ -6,6 +6,15 @@ def vectorize_weight(state_dict): + """ + Vectorizes the weight tensors in the given state_dict. + + Args: + state_dict (OrderedDict): The state_dict containing model weights. 
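
What weight vectorization does, in miniature: concatenate every flattened weight tensor of a state_dict into a single 1-D tensor, skipping batch-norm running statistics (the same filter is_weight_param applies):

    import torch

    def vectorize(state_dict):
        return torch.cat([
            v.flatten() for k, v in state_dict.items()
            if "running_" not in k and "num_batches_tracked" not in k
        ])

    sd = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 2)).state_dict()
    print(vectorize(sd).shape)  # torch.Size([58]): 4*8 + 8 + 8*2 + 2
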
+ + Returns: + torch.Tensor: A concatenated tensor of flattened weights. + """ weight_list = [] for (k, v) in state_dict.items(): if is_weight_param(k): @@ -14,27 +23,62 @@ def vectorize_weight(state_dict): def is_weight_param(k): + """ + Checks if a parameter key is a weight parameter. + + Args: + k (str): The parameter key. + + Returns: + bool: True if the key corresponds to a weight parameter, False otherwise. + """ return ( - "running_mean" not in k - and "running_var" not in k - and "num_batches_tracked" not in k + "running_mean" not in k + and "running_var" not in k + and "num_batches_tracked" not in k ) def compute_euclidean_distance(v1, v2, device='cpu'): + """ + Computes the Euclidean distance between two tensors. + + Args: + v1 (torch.Tensor): The first tensor. + v2 (torch.Tensor): The second tensor. + device (str): The device for computation (default is 'cpu'). + + Returns: + torch.Tensor: The Euclidean distance between the two tensors. + """ v1 = v1.to(device) v2 = v2.to(device) return (v1 - v2).norm() def compute_model_norm(model): + """ + Computes the norm of a model's weights. + + Args: + model: The model. + + Returns: + torch.Tensor: The norm of the model's weights. + """ return vectorize_weight(model).norm() def compute_middle_point(alphas, model_list): """ - alphas: weights of model_dict - model_dict: a model submitted by a user + Computes the weighted sum of model weights. + + Args: + alphas (list): List of weights. + model_list (list): List of model weights. + + Returns: + numpy.ndarray: The weighted sum of model weights. """ sum_batch = torch.zeros(model_list[0].shape) for a, a_batch_w in zip(alphas, model_list): @@ -88,6 +132,15 @@ def compute_geometric_median(weights, client_grads): def get_total_sample_num(model_list): + """ + Calculates the total number of samples across multiple clients. + + Args: + model_list (list): List of tuples containing local sample numbers and model parameters. + + Returns: + int: Total number of samples. + """ sample_num = 0 for i in range(len(model_list)): local_sample_num, local_model_params = model_list[i] @@ -96,6 +149,17 @@ def get_total_sample_num(model_list): def get_malicious_client_id_list(random_seed, client_num, malicious_client_num): + """ + Generates a list of malicious client IDs. + + Args: + random_seed (int): Random seed for reproducibility. + client_num (int): Total number of clients. + malicious_client_num (int): Number of malicious clients to generate. + + Returns: + list: List of malicious client IDs. + """ if client_num == malicious_client_num: client_indexes = [client_index for client_index in range(client_num)] else: @@ -103,7 +167,8 @@ def get_malicious_client_id_list(random_seed, client_num, malicious_client_num): np.random.seed( random_seed ) # make sure for each comparison, we are selecting the same clients each round - client_indexes = np.random.choice(range(client_num), num_clients, replace=False) + client_indexes = np.random.choice( + range(client_num), num_clients, replace=False) print("malicious client_indexes = %s" % str(client_indexes)) return client_indexes @@ -112,9 +177,15 @@ def replace_original_class_with_target_class( data_labels, original_class_list=None, target_class_list=None ): """ - :param targets: Target class IDs - :type targets: list - :return: new class IDs + Replaces original class labels in data_labels with corresponding target class labels. + + Args: + data_labels (list): List of class labels. + original_class_list (list): List of original class labels to be replaced. 
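
A minimal sketch of the weighted middle point used by the robust-aggregation helpers. This version normalizes by the total weight; the repo's compute_middle_point may instead assume pre-normalized alphas:

    import torch

    def middle_point(alphas, model_vecs):
        acc = torch.zeros_like(model_vecs[0])
        for a, vec in zip(alphas, model_vecs):
            acc += a * vec
        return acc / sum(alphas)

    vecs = [torch.ones(4), 3 * torch.ones(4)]
    print(middle_point([0.5, 0.5], vecs))  # tensor([2., 2., 2., 2.])
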
+ target_class_list (list): List of target class labels to replace with. + + Returns: + list: Updated list of class labels. """ if ( len(original_class_list) == 0 @@ -141,12 +212,11 @@ def replace_original_class_with_target_class( def log_client_data_statistics(poisoned_client_ids, train_data_local_dict): """ - Logs all client data statistics. + Logs data distribution statistics for each client in the dataset. - :param poisoned_client_ids: list of malicious clients - :type poisoned_client_ids: list - :param train_data_local_dict: distributed dataset - :type train_data_local_dict: list(tuple) + Args: + poisoned_client_ids (list): List of malicious client IDs. + train_data_local_dict (list): Distributed dataset. """ for client_idx in range(len(train_data_local_dict)): if client_idx in poisoned_client_ids: @@ -163,6 +233,13 @@ def log_client_data_statistics(poisoned_client_ids, train_data_local_dict): def get_client_data_stat(local_dataset): + """ + Prints data distribution statistics for a local dataset. + + Args: + local_dataset (Iterable): Local dataset. + + """ print("-==========================") targets_set = {} for batch_idx, (data, targets) in enumerate(local_dataset): @@ -200,17 +277,51 @@ def get_client_data_stat(local_dataset): def cross_entropy_for_onehot(pred, target): + """ + Computes the cross-entropy loss between predicted and target one-hot encoded vectors. + + Args: + pred (torch.Tensor): Predicted logit values. + target (torch.Tensor): Target one-hot encoded vectors. + + Returns: + torch.Tensor: Cross-entropy loss. + + """ return torch.mean(torch.sum(-target * F.log_softmax(pred, dim=-1), 1)) def label_to_onehot(target, num_classes=100): + """ + Converts class labels to one-hot encoded vectors. + + Args: + target (torch.Tensor): Class labels. + num_classes (int, optional): Number of classes. Defaults to 100. + + Returns: + torch.Tensor: One-hot encoded vectors. + + """ target = torch.unsqueeze(target, 1) - onehot_target = torch.zeros(target.size(0), num_classes, device=target.device) + onehot_target = torch.zeros(target.size( + 0), num_classes, device=target.device) onehot_target.scatter_(1, target, 1) return onehot_target def trimmed_mean(model_list, trimmed_num): + """ + Trims the list of models by removing a specified number of models from both ends. + + Args: + model_list (list): List of model tuples containing local sample numbers and gradients. + trimmed_num (int): Number of models to trim from each end. + + Returns: + list: Trimmed list of models. + + """ temp_model_list = [] for i in range(0, len(model_list)): local_sample_num, client_grad = model_list[i] @@ -221,18 +332,42 @@ def trimmed_mean(model_list, trimmed_num): compute_a_score(local_sample_num), ) ) - temp_model_list.sort(key=lambda grad: grad[2]) # sort by coordinate-wise scores - temp_model_list = temp_model_list[trimmed_num: len(model_list) - trimmed_num] + # sort by coordinate-wise scores + temp_model_list.sort(key=lambda grad: grad[2]) + temp_model_list = temp_model_list[trimmed_num: len( + model_list) - trimmed_num] model_list = [(t[0], t[1]) for t in temp_model_list] return model_list def compute_a_score(local_sample_number): + """ + Compute a score for a client based on its local sample number. + + Args: + local_sample_number (int): Number of local samples for a client. + + Returns: + int: A score for the client. 
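
The trimming idea behind trimmed_mean, shown on plain vectors: rank the updates by a score and drop trim entries from each end before averaging (hedged sketch; the repo currently scores each client by its local sample count, as compute_a_score notes below):

    import torch

    def trimmed_mean_vecs(vecs, scores, trim):
        order = sorted(range(len(vecs)), key=lambda i: scores[i])
        kept = order[trim:len(vecs) - trim]
        return torch.stack([vecs[i] for i in kept]).mean(dim=0)

    vecs = [torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([100.0]), torch.tensor([3.0])]
    scores = [1.0, 2.0, 100.0, 3.0]
    print(trimmed_mean_vecs(vecs, scores, trim=1))  # tensor([2.5000]); the outlier is dropped
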
+
+    """
     # todo: change to coordinate-wise score
     return local_sample_number


 def compute_krum_score(vec_grad_list, client_num_after_trim, p=2):
+    """
+    Compute Krum scores for clients based on their gradients.
+
+    Args:
+        vec_grad_list (list): List of gradient vectors for each client.
+        client_num_after_trim (int): Number of clients to consider.
+        p (int, optional): Power parameter for distance calculation. Defaults to 2.
+
+    Returns:
+        list: List of Krum scores for each client.
+
+    """
     krum_scores = []
     num_client = len(vec_grad_list)
     for i in range(0, num_client):
@@ -252,6 +387,16 @@ def compute_krum_score(vec_grad_list, client_num_after_trim, p=2):


 def compute_gaussian_distribution(score_list):
+    """
+    Compute the mean (mu) and standard deviation (sigma) of a list of scores.
+
+    Args:
+        score_list (list): List of scores.
+
+    Returns:
+        Tuple[float, float]: Mean (mu) and standard deviation (sigma).
+
+    """
     n = len(score_list)
     mu = sum(list(score_list)) / n
     temp = 0
@@ -263,4 +408,15 @@ def compute_gaussian_distribution(score_list):


 def sample_some_clients(client_num, sampled_client_num):
-    return random.sample(range(client_num), sampled_client_num)
\ No newline at end of file
+    """
+    Sample a specified number of clients from the total number of clients.
+
+    Args:
+        client_num (int): Total number of clients.
+        sampled_client_num (int): Number of clients to sample.
+
+    Returns:
+        list: List of sampled client indices.
+
+    """
+    return random.sample(range(client_num), sampled_client_num)
diff --git a/python/fedml/cross_device/server_mnn/fedml_aggregator.py b/python/fedml/cross_device/server_mnn/fedml_aggregator.py
index cf6c2c23c1..aafa5539b0 100644
--- a/python/fedml/cross_device/server_mnn/fedml_aggregator.py
+++ b/python/fedml/cross_device/server_mnn/fedml_aggregator.py
@@ -18,6 +18,19 @@ class FedMLAggregator(object):

     def __init__(
         self, test_dataloader, worker_num, device, args, aggregator, ):
+        """
+        Initialize the FedMLAggregator.
+
+        Args:
+            test_dataloader: DataLoader for the test dataset.
+            worker_num: Number of worker nodes (clients).
+            device: The device (e.g., CPU or GPU) to use for computations.
+            args: Arguments for configuration.
+            aggregator: The aggregator used for federated learning aggregation.
+
+        Returns:
+            None
+        """
         self.aggregator = aggregator
         self.args = args
@@ -32,23 +45,67 @@ def __init__(
         self.flag_client_model_uploaded_dict[idx] = False

     def get_global_model_params(self):
+        """
+        Get the current global model parameters from the aggregator.
+
+        Returns:
+            The parameters of the global model.
+        """
         return self.aggregator.get_model_params()

     # TODO: refactor MNN-related file processing
     def get_global_model_params_file(self):
+        """
+        Get the file path of the global model parameters.
+
+        Returns:
+            str: File path of the global model parameters.
+        """
         return self.aggregator.get_model_params_file()

     def set_global_model_params(self, model_parameters):
-        logging.info("FedDebug. model_parameters = {}".format(model_parameters))
+        """
+        Set the global model parameters.
+
+        Args:
+            model_parameters: Parameters of the global model.
+
+        Returns:
+            None
+        """
+        logging.info(
+            "FedDebug. 
model_parameters = {}".format(model_parameters)) self.aggregator.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, sample_num): + """ + Add the results of local model training for aggregation. + + Args: + index (int): Index of the local client. + model_params: Parameters of the locally trained model. + sample_num (int): Number of samples used for training. + + Returns: + None + """ logging.info("add_model. index = %d" % index) self.model_dict[index] = model_params self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all clients have uploaded their local models. + + Returns: + bool: True if all clients have uploaded their models, False otherwise. + """ logging.info("worker_num = {}".format(self.worker_num)) for idx in range(self.worker_num): if not self.flag_client_model_uploaded_dict[idx]: @@ -58,27 +115,48 @@ def check_whether_all_receive(self): return True def _test_individual_model_perf_before_agg(self, model_file_path, round_idx): - self.test_on_server_for_all_clients_mnn(model_file_path, round_idx, report_metrics=False) + """ + Test the performance of an individual model before aggregation. + + Args: + model_file_path (str): File path of the individual model. + round_idx (int): Index of the current federated learning round. + + Returns: + None + """ + self.test_on_server_for_all_clients_mnn( + model_file_path, round_idx, report_metrics=False) def aggregate(self): + """ + Aggregate local model updates to obtain the global model. + + Returns: + averaged_params: Averaged global model parameters. + """ logging.info("FedMLDebug. Individual model performance:") for idx in range(self.worker_num): - logging.info("self.model_dict[idx] = {}".format(self.model_dict[idx])) + logging.info("self.model_dict[idx] = {}".format( + self.model_dict[idx])) mnn_file_path = self.model_dict[idx] - self._test_individual_model_perf_before_agg(mnn_file_path, self.args.round_idx) + self._test_individual_model_perf_before_agg( + mnn_file_path, self.args.round_idx) start_time = time.time() model_list = [] training_num = 0 for idx in range(self.worker_num): - logging.info("self.model_dict[idx] = {}".format(self.model_dict[idx])) + logging.info("self.model_dict[idx] = {}".format( + self.model_dict[idx])) mnn_file_path = self.model_dict[idx] tensor_params_dict = read_mnn_as_tensor_dict(mnn_file_path) model_list.append((self.sample_num_dict[idx], tensor_params_dict)) training_num += self.sample_num_dict[idx] logging.info("training_num = {}".format(training_num)) - logging.info("len of self.model_dict[idx] = " + str(len(self.model_dict))) + logging.info( + "len of self.model_dict[idx] = " + str(len(self.model_dict))) # logging.info("################aggregate: %d" % len(model_list)) averaged_params = self.aggregator.aggregate(model_list) @@ -102,14 +180,17 @@ def data_silo_selection(self, round_idx, data_silo_num_in_total, client_num_in_t """ logging.info( - "data_silo_num_in_total = %d, client_num_in_total = %d" % (data_silo_num_in_total, client_num_in_total) + "data_silo_num_in_total = %d, client_num_in_total = %d" % ( + data_silo_num_in_total, client_num_in_total) ) assert data_silo_num_in_total >= client_num_in_total if client_num_in_total == data_silo_num_in_total: return [i for i in range(data_silo_num_in_total)] else: - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round - data_silo_index_list = 
np.random.choice(range(data_silo_num_in_total), client_num_in_total, replace=False)
+            # make sure for each comparison, we are selecting the same clients each round
+            np.random.seed(round_idx)
+            data_silo_index_list = np.random.choice(
+                range(data_silo_num_in_total), client_num_in_total, replace=False)
             return data_silo_index_list

     def client_selection(self, round_idx, client_id_list_in_total, client_num_per_round):
@@ -126,21 +207,49 @@ def client_selection(self, round_idx, client_id_list_in_total, client_num_per_ro
         """
         if client_num_per_round == len(client_id_list_in_total) or len(client_id_list_in_total) == 1:  # for debugging
             return client_id_list_in_total
-        np.random.seed(round_idx)  # make sure for each comparison, we are selecting the same clients each round
-        client_id_list_in_this_round = np.random.choice(client_id_list_in_total, client_num_per_round, replace=False)
+        # make sure for each comparison, we are selecting the same clients each round
+        np.random.seed(round_idx)
+        client_id_list_in_this_round = np.random.choice(
+            client_id_list_in_total, client_num_per_round, replace=False)
         return client_id_list_in_this_round

     def client_sampling(self, round_idx, client_num_in_total, client_num_per_round):
+        """
+        Sample clients for the current round, seeding the RNG with the round
+        index so that repeated runs select the same clients in each round.
+
+        Args:
+            round_idx (int): Round index, starting from 0.
+            client_num_in_total (int): Total number of clients.
+            client_num_per_round (int): Number of clients to sample per round.
+
+        Returns:
+            client_indexes: List of selected client indexes.
+        """
+
         if client_num_in_total == client_num_per_round:
-            client_indexes = [client_index for client_index in range(client_num_in_total)]
+            client_indexes = [
+                client_index for client_index in range(client_num_in_total)]
         else:
             num_clients = min(client_num_per_round, client_num_in_total)
-            np.random.seed(round_idx)  # make sure for each comparison, we are selecting the same clients each round
-            client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False)
+            # make sure for each comparison, we are selecting the same clients each round
+            np.random.seed(round_idx)
+            client_indexes = np.random.choice(
+                range(client_num_in_total), num_clients, replace=False)
         logging.info("client_indexes = %s" % str(client_indexes))
         return client_indexes

     def _test(self, test_data, device, args):
+        """
+        Evaluate the model on the given test data and collect metrics.
+
+        Args:
+            test_data: Test data.
+            device: Device on which to perform testing.
+            args: Additional arguments.
+
+        Returns:
+            metrics: Dictionary containing test metrics.
+        """
         model = self.model
         model.to(device)
@@ -166,6 +275,17 @@
         return metrics

     def test(self, test_data, device, args):
+        """
+        Evaluate the model on the test data and aggregate per-batch results.
+
+        Args:
+            test_data: Test data.
+            device: Device on which to perform testing.
+            args: Additional arguments.
+
+        Returns:
+            Tuple containing test accuracy, test loss, and additional metrics.
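
A minimal evaluation loop of the kind _test implements, accumulating summed loss and correct predictions over a test loader (toy model and data):

    import torch
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset

    model = torch.nn.Linear(10, 3)
    loader = DataLoader(
        TensorDataset(torch.randn(64, 10), torch.randint(0, 3, (64,))), batch_size=16
    )

    model.eval()
    correct, loss_sum, n = 0, 0.0, 0
    with torch.no_grad():
        for xb, yb in loader:
            logits = model(xb)
            loss_sum += F.cross_entropy(logits, yb, reduction="sum").item()
            correct += (logits.argmax(dim=1) == yb).sum().item()
            n += yb.size(0)
    print(correct / n, loss_sum / n)  # accuracy and mean loss
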
+ """ # test data test_num_samples = [] test_tot_corrects = [] @@ -199,7 +319,8 @@ def test(self, test_data, device, args): def test_on_server_for_all_clients(self, round_idx, global_model_file=None): if round_idx % self.args.frequency_of_the_test == 0 or round_idx == self.args.comm_round - 1: - logging.info("################test_on_server_for_all_clients : {}".format(round_idx)) + logging.info( + "################test_on_server_for_all_clients : {}".format(round_idx)) self.aggregator.test_all( self.train_data_local_dict, self.test_data_local_dict, @@ -209,10 +330,13 @@ def test_on_server_for_all_clients(self, round_idx, global_model_file=None): if round_idx == self.args.comm_round - 1: # we allow to return four metrics, such as accuracy, AUC, loss, etc. - metric_result_in_current_round = self.aggregator.test(self.test_global, self.device, self.args) + metric_result_in_current_round = self.aggregator.test( + self.test_global, self.device, self.args) else: - metric_result_in_current_round = self.aggregator.test(self.val_global, self.device, self.args) - logging.info("metric_result_in_current_round = {}".format(metric_result_in_current_round)) + metric_result_in_current_round = self.aggregator.test( + self.val_global, self.device, self.args) + logging.info("metric_result_in_current_round = {}".format( + metric_result_in_current_round)) if round_idx == self.args.comm_round - 1: mlops.log({"round_idx": round_idx}) @@ -237,8 +361,10 @@ def test_on_server_for_all_clients_mnn(self, mnn_file_path, round_idx, report_me example = self.test_global.next() input_data = example[0] output_target = example[1] - data = input_data[0] # which input, model may have more than one inputs - label = output_target[0] # also, model may have more than one outputs + # which input, model may have more than one inputs + data = input_data[0] + # also, model may have more than one outputs + label = output_target[0] result = module.forward(data) predict = F.argmax(result, 1) @@ -250,13 +376,15 @@ def test_on_server_for_all_clients_mnn(self, mnn_file_path, round_idx, report_me target = F.one_hot(F.cast(label, F.int), 10, 1, 0) loss = nn.loss.cross_entropy(result, target) - logging.info(f"correct = {correct}, self.test_global.size = {self.test_global.size}") + logging.info( + f"correct = {correct}, self.test_global.size = {self.test_global.size}") test_accuracy = correct / self.test_global.size test_loss = loss.read() if report_metrics: logging.info("test acc = {}".format(test_accuracy)) - logging.info("test loss = {}, round loss {}".format(test_loss, round(float(np.round(test_loss, 4)), 4))) + logging.info("test loss = {}, round loss {}".format( + test_loss, round(float(np.round(test_loss, 4)), 4))) mlops.log( { @@ -268,5 +396,6 @@ def test_on_server_for_all_clients_mnn(self, mnn_file_path, round_idx, report_me if self.args.enable_wandb: wandb.log( - {"round idx": round_idx, "test acc": test_accuracy, "test loss": test_loss, } + {"round idx": round_idx, "test acc": test_accuracy, + "test loss": test_loss, } ) diff --git a/python/fedml/cross_device/server_mnn/fedml_server_manager.py b/python/fedml/cross_device/server_mnn/fedml_server_manager.py index 12b49ae68c..99ddf07579 100644 --- a/python/fedml/cross_device/server_mnn/fedml_server_manager.py +++ b/python/fedml/cross_device/server_mnn/fedml_server_manager.py @@ -12,6 +12,21 @@ class FedMLServerManager(FedMLCommManager): + """ + Federated Learning Server Manager. + + This class manages the server-side operations of federated learning. 
+ + Args: + args: Arguments for the federated learning process. + aggregator: Server aggregator for aggregating model updates. + comm: Communication backend for distributed training (default: None). + rank (int): The rank of the current worker (default: 0). + size (int): The total number of workers (default: 0). + backend (str): The communication backend (default: "MPI"). + is_preprocessed (bool): Flag indicating if data is preprocessed (default: False). + preprocessed_client_lists: List of preprocessed client data (default: None). + """ ONLINE_STATUS_FLAG = "ONLINE" RUN_FINISHED_STATUS_FLAG = "FINISHED" @@ -37,10 +52,12 @@ def __init__( self.global_model_file_path = self.args.global_model_file_path self.model_file_cache_folder = self.args.model_file_cache_folder logging.info( - "self.global_model_file_path = {}".format(self.global_model_file_path) + "self.global_model_file_path = {}".format( + self.global_model_file_path) ) logging.info( - "self.model_file_cache_folder = {}".format(self.model_file_cache_folder) + "self.model_file_cache_folder = {}".format( + self.model_file_cache_folder) ) self.client_online_mapping = {} @@ -56,6 +73,14 @@ def run(self): super().run() def start_train(self): + """ + Start the federated training process. + + This method initiates federated training by sending start training messages to all clients. + + Returns: + None + """ start_train_json = { "edges": [ { @@ -148,7 +173,8 @@ def start_train(self): "timestamp": "1651635148138", } for client_id in self.client_real_ids: - logging.info("com_manager_status - client_id = {}".format(client_id)) + logging.info( + "com_manager_status - client_id = {}".format(client_id)) self.send_message_json( "flserver_agent/" + str(client_id) + "/start_train", json.dumps(start_train_json), @@ -162,6 +188,13 @@ def send_init_msg(self): MNN (file) -> numpy -> pytorch -> aggregation -> numpy -> MNN (the same file) S2C - send the model to clients send MNN file + + Initialize and send model to clients. + + This method sends the initial model to clients to start the federated learning process. + + Returns: + """ global_model_url = None global_model_key = None @@ -174,16 +207,26 @@ def send_init_msg(self): self.data_silo_index_list[client_idx_in_this_round], global_model_url, global_model_key ) - logging.info(f"global_model_url = {global_model_url}, global_model_key = {global_model_key}") + logging.info( + f"global_model_url = {global_model_url}, global_model_key = {global_model_key}") client_idx_in_this_round += 1 - mlops.event("server.wait", event_started=True, event_value=str(self.args.round_idx)) + mlops.event("server.wait", event_started=True, + event_value=str(self.args.round_idx)) # Todo: for serving the cross-device model, # how to transform it to pytorch and upload the model network to ModelOps # mlops.log_training_model_net_info(self.aggregator.aggregator.model) def register_message_receive_handlers(self): + """ + Register message receive handlers. + + This method registers message handlers for processing incoming messages from clients. + + Returns: + None + """ print("register_message_receive_handlers------") self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_CLIENT_STATUS, @@ -199,9 +242,20 @@ def register_message_receive_handlers(self): ) def process_online_status(self, client_status, msg_params): + """ + Process online status message from clients. + + Args: + client_status (str): The status message from clients. + msg_params: Parameters of the received message. 
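
register_message_receive_handlers above wires message types to callbacks; a hedged, illustrative miniature of that dispatch pattern (all names here are invented for the example, not the repo's API):

    class MiniServer:
        def __init__(self):
            self.handlers = {}

        def register(self, msg_type, handler):
            self.handlers[msg_type] = handler

        def receive(self, msg_type, payload):
            # Look up and invoke the callback registered for this message type.
            self.handlers[msg_type](payload)

    server = MiniServer()
    server.register("client_status", lambda p: print("status:", p))
    server.register("model_upload", lambda p: print("model from client", p["sender"]))
    server.receive("client_status", "ONLINE")
    server.receive("model_upload", {"sender": 3})
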
+ + Returns: + None + """ self.client_online_mapping[str(msg_params.get_sender_id())] = True - logging.info("self.client_online_mapping = {}".format(self.client_online_mapping)) + logging.info("self.client_online_mapping = {}".format( + self.client_online_mapping)) all_client_is_online = True for client_id in self.client_id_list_in_this_round: @@ -210,17 +264,29 @@ def process_online_status(self, client_status, msg_params): break logging.info( - "sender_id = %d, all_client_is_online = %s" % (msg_params.get_sender_id(), str(all_client_is_online)) + "sender_id = %d, all_client_is_online = %s" % ( + msg_params.get_sender_id(), str(all_client_is_online)) ) if all_client_is_online: - mlops.log_aggregation_status(MyMessage.MSG_MLOPS_SERVER_STATUS_RUNNING) + mlops.log_aggregation_status( + MyMessage.MSG_MLOPS_SERVER_STATUS_RUNNING) # send initialization message to all clients to start training self.send_init_msg() self.is_initialized = True def process_finished_status(self, client_status, msg_params): + """ + Process finished status message from clients. + + Args: + client_status (str): The status message from clients. + msg_params: Parameters of the received message. + + Returns: + None + """ self.client_finished_mapping[str(msg_params.get_sender_id())] = True all_client_is_finished = True @@ -230,7 +296,8 @@ def process_finished_status(self, client_status, msg_params): break logging.info( - "sender_id = %d, all_client_is_finished = %s" % (msg_params.get_sender_id(), str(all_client_is_finished)) + "sender_id = %d, all_client_is_finished = %s" % ( + msg_params.get_sender_id(), str(all_client_is_finished)) ) if all_client_is_finished: @@ -239,6 +306,15 @@ def process_finished_status(self, client_status, msg_params): self.finish() def handle_message_client_status_update(self, msg_params): + """ + Handle client status update message. + + Args: + msg_params: Parameters of the received message. + + Returns: + None + """ client_status = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_STATUS) if client_status == FedMLServerManager.ONLINE_STATUS_FLAG: self.process_online_status(client_status, msg_params) @@ -246,6 +322,15 @@ def handle_message_client_status_update(self, msg_params): self.process_finished_status(client_status, msg_params) def handle_message_connection_ready(self, msg_params): + """ + Handle connection ready message. + + Args: + msg_params: Parameters of the received message. + + Returns: + None + """ if not self.is_initialized: self.client_id_list_in_this_round = self.aggregator.client_selection( self.args.round_idx, self.client_real_ids, self.args.client_num_per_round @@ -270,22 +355,34 @@ def handle_message_connection_ready(self, msg_params): self.send_message_check_client_status( client_id, self.data_silo_index_list[client_idx_in_this_round], ) - logging.info("Connection ready for client: " + str(client_id)) + logging.info( + "Connection ready for client: " + str(client_id)) except Exception as e: logging.info("Connection not ready for client: {}".format( str(client_id), traceback.format_exc())) client_idx_in_this_round += 1 def handle_message_receive_model_from_client(self, msg_params): + """ + Handle received model from client. + + Args: + msg_params: Parameters of the received message. 
+ + Returns: + None + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) - mlops.event("comm_c2s", event_started=False, event_value=str(self.args.round_idx), event_edge_id=sender_id) + mlops.event("comm_c2s", event_started=False, event_value=str( + self.args.round_idx), event_edge_id=sender_id) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) self.aggregator.add_local_trained_result( - self.client_real_ids.index(sender_id), model_params, local_sample_number + self.client_real_ids.index( + sender_id), model_params, local_sample_number ) b_all_received = self.aggregator.check_whether_all_receive() logging.info("b_all_received = %s " % str(b_all_received)) @@ -298,23 +395,26 @@ def handle_message_receive_model_from_client(self, msg_params): ) logging.info("=================================================") - mlops.event("server.wait", event_started=False, event_value=str(self.args.round_idx)) + mlops.event("server.wait", event_started=False, + event_value=str(self.args.round_idx)) - mlops.event("server.agg_and_eval", event_started=True, event_value=str(self.args.round_idx)) + mlops.event("server.agg_and_eval", event_started=True, + event_value=str(self.args.round_idx)) global_model_params = self.aggregator.aggregate() - + # self.aggregator.test_on_server_for_all_clients( # self.args.round_idx, self.global_model_file_path # ) - - write_tensor_dict_to_mnn(self.global_model_file_path, global_model_params) + + write_tensor_dict_to_mnn( + self.global_model_file_path, global_model_params) self.aggregator.test_on_server_for_all_clients_mnn( self.global_model_file_path, self.args.round_idx ) - - mlops.event("server.agg_and_eval", event_started=False, event_value=str(self.args.round_idx)) + mlops.event("server.agg_and_eval", event_started=False, + event_value=str(self.args.round_idx)) # send round info to the MQTT backend mlops.log_round_info(self.round_num, self.args.round_idx) @@ -333,8 +433,8 @@ def handle_message_receive_model_from_client(self, msg_params): global_model_key = None logging.info("round idx {}, client_num_in_total {}, data_silo_index_list length {}," "client_id_list_in_this_round length {}.".format( - self.args.round_idx, self.args.client_num_in_total, - len(self.data_silo_index_list), len(self.client_id_list_in_this_round))) + self.args.round_idx, self.args.client_num_in_total, + len(self.data_silo_index_list), len(self.client_id_list_in_this_round))) for receiver_id in self.client_id_list_in_this_round: global_model_url, global_model_key = self.send_message_sync_model_to_client( receiver_id, @@ -350,11 +450,21 @@ def handle_message_receive_model_from_client(self, msg_params): self.args.round_idx, model_url=global_model_url, ) - logging.info("\n\n==========end {}-th round training===========\n".format(self.args.round_idx)) + logging.info( + "\n\n==========end {}-th round training===========\n".format(self.args.round_idx)) if self.args.round_idx < self.round_num: - mlops.event("server.wait", event_started=True, event_value=str(self.args.round_idx)) + mlops.event("server.wait", event_started=True, + event_value=str(self.args.round_idx)) def cleanup(self): + """ + Clean up and send finish message to clients. + + This method sends a finish message to all clients to indicate the completion of the federated learning round. 
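
Behind aggregator.aggregate() sits a sample-weighted FedAvg step; a hedged, self-contained sketch over toy state_dicts (the MNN path additionally converts model files to tensor dicts first):

    from collections import OrderedDict

    import torch

    def fedavg(model_list):
        # model_list: [(sample_num, state_dict), ...]
        total = sum(n for n, _ in model_list)
        keys = model_list[0][1].keys()
        return OrderedDict(
            (k, sum((n / total) * state[k] for n, state in model_list)) for k in keys
        )

    def make_sd(v):
        return OrderedDict(w=torch.full((2,), v))

    print(fedavg([(10, make_sd(1.0)), (30, make_sd(2.0))])["w"])  # tensor([1.7500, 1.7500])
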
+ + Returns: + None + """ client_idx_in_this_round = 0 for client_id in self.client_id_list_in_this_round: self.send_message_finish( @@ -363,25 +473,55 @@ def cleanup(self): client_idx_in_this_round += 1 def send_message_finish(self, receive_id, datasilo_index): - message = Message(MyMessage.MSG_TYPE_S2C_FINISH, self.get_sender_id(), receive_id) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) + """ + Send finish message to a client. + + Args: + receive_id: The ID of the client to receive the finish message. + datasilo_index: The data silo index associated with the client. + + Returns: + None + """ + message = Message(MyMessage.MSG_TYPE_S2C_FINISH, + self.get_sender_id(), receive_id) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) logging.info( "finish from send id {} to receive id {}.".format(message.get_sender_id(), message.get_receiver_id())) - logging.info(" ====================send cleanup message to {}====================".format(str(datasilo_index))) + logging.info(" ====================send cleanup message to {}====================".format( + str(datasilo_index))) def send_message_init_config(self, receive_id, global_model_params, client_index, global_model_url, global_model_key): + """ + Send initialization configuration message to a client. + + Args: + receive_id: The ID of the client to receive the message. + global_model_params: The global model parameters to be sent. + client_index: The client's index. + global_model_url: URL for global model parameters (if available). + global_model_key: Key for global model parameters (if available). + + Returns: + Tuple: A tuple containing the global model URL and key after sending the message. + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) if global_model_url is not None: - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL, global_model_url) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL, global_model_url) if global_model_key is not None: - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY, global_model_key) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY, global_model_key) logging.info("global_model_params = {}".format(global_model_params)) - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, "AndroidClient") self.send_message(message) @@ -390,28 +530,58 @@ def send_message_init_config(self, receive_id, global_model_params, client_index return global_model_url, global_model_key def send_message_check_client_status(self, receive_id, datasilo_index): + """ + Send message to check client status. + + Args: + receive_id: The ID of the client to receive the message. + datasilo_index: The data silo index associated with the client. 
+ + Returns: + None + """ + message = Message( MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, self.get_sender_id(), receive_id ) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) def send_message_sync_model_to_client( self, receive_id, global_model_params, data_silo_index, global_model_url=None, global_model_key=None ): - logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id) + """ + Send model synchronization message to a client. + + Args: + receive_id: The ID of the client to receive the model synchronization message. + global_model_params: The global model parameters to be synchronized. + data_silo_index: The data silo index associated with the client. + global_model_url: URL for global model parameters (if available). + global_model_key: Key for global model parameters (if available). + + Returns: + Tuple: A tuple containing the global model URL and key after sending the message. + """ + logging.info( + "send_message_sync_model_to_client. receive_id = %d" % receive_id) message = Message( MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id, ) - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) if global_model_url is not None: - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL, global_model_url) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL, global_model_url) if global_model_key is not None: - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY, global_model_key) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(data_silo_index)) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY, global_model_key) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(data_silo_index)) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, "AndroidClient") self.send_message(message) diff --git a/python/fedml/cross_device/server_mnn/server_mnn_api.py b/python/fedml/cross_device/server_mnn/server_mnn_api.py index e0ca0786ee..ce33707b4d 100644 --- a/python/fedml/cross_device/server_mnn/server_mnn_api.py +++ b/python/fedml/cross_device/server_mnn/server_mnn_api.py @@ -6,27 +6,68 @@ def fedavg_cross_device(args, process_id, worker_number, comm, device, test_dataloader, model, server_aggregator=None): - logging.info("test_data_global.iter_number = {}".format(test_dataloader.iter_number)) + """ + Federated Averaging across Multiple Devices (Cross-Device Aggregation). + + This function performs federated averaging across multiple devices using cross-device aggregation. + + Args: + args: Arguments for the federated learning process. + process_id (int): The process ID of the current worker. + worker_number (int): The total number of workers. + comm: Communication backend for distributed training. + device: The device (e.g., CPU or GPU) to perform computations. + test_dataloader: DataLoader for the test dataset. + model: The federated learning model. + server_aggregator: Server aggregator for aggregating model updates (default: None). 
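+
+    Example:
+        Only process 0 hosts the aggregation service; a sketch with
+        hypothetical, pre-built arguments::
+
+            fedavg_cross_device(args, 0, 1, comm, device, test_dataloader, model)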
+ + Returns: + None + """ + logging.info("test_data_global.iter_number = {}".format( + test_dataloader.iter_number)) if process_id == 0: - init_server(args, device, comm, process_id, worker_number, model, test_dataloader, server_aggregator) + init_server(args, device, comm, process_id, worker_number, + model, test_dataloader, server_aggregator) def init_server(args, device, comm, rank, size, model, test_dataloader, aggregator): + """ + Initialize the Federated Learning Server. + + This function initializes the federated learning server for aggregation. + + Args: + args: Arguments for the federated learning process. + device: The device (e.g., CPU or GPU) to perform computations. + comm: Communication backend for distributed training. + rank (int): The rank of the current worker. + size (int): The total number of workers. + model: The federated learning model. + test_dataloader: DataLoader for the test dataset. + aggregator: Server aggregator for aggregating model updates. + + Returns: + None + """ if aggregator is None: aggregator = create_server_aggregator(model, args) aggregator.set_id(-1) td_id = id(test_dataloader) logging.info("test_dataloader = {}".format(td_id)) - logging.info("test_data_global.iter_number = {}".format(test_dataloader.iter_number)) + logging.info("test_data_global.iter_number = {}".format( + test_dataloader.iter_number)) worker_num = size - aggregator = FedMLAggregator(test_dataloader, worker_num, device, args, aggregator) + aggregator = FedMLAggregator( + test_dataloader, worker_num, device, args, aggregator) - # start the distributed training + # Start the distributed training backend = args.backend - server_manager = FedMLServerManager(args, aggregator, comm, rank, size, backend) + server_manager = FedMLServerManager( + args, aggregator, comm, rank, size, backend) if not args.using_mlops: server_manager.start_train() server_manager.run() From 19b64e5bd4e65d85601bc6467397621ccd235474 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Fri, 22 Sep 2023 18:10:24 +0530 Subject: [PATCH 31/70] push --- python/fedml/core/dp/common/utils.py | 62 ++++ .../core/dp/fedml_differential_privacy.py | 43 ++- python/fedml/core/dp/frames/NbAFL.py | 75 ++++- .../fedml/core/dp/frames/base_dp_solution.py | 92 +++++- python/fedml/core/dp/frames/cdp.py | 41 ++- python/fedml/core/dp/frames/dp_clip.py | 102 +++++- python/fedml/core/dp/frames/ldp.py | 29 +- .../fedml/core/dp/mechanisms/dp_mechanism.py | 73 ++++- python/fedml/core/dp/mechanisms/gaussian.py | 54 +++- python/fedml/core/dp/mechanisms/laplace.py | 38 ++- .../dp/test/test_fed_privacy_mechanism.py | 18 ++ python/fedml/core/mlops/mlops_configs.py | 100 +++++- python/fedml/core/mlops/mlops_device_perfs.py | 78 ++++- python/fedml/core/mlops/mlops_job_perfs.py | 60 +++- python/fedml/core/mlops/mlops_metrics.py | 305 ++++++++++++++++-- .../fedml/core/mlops/mlops_profiler_event.py | 55 ++++ python/fedml/core/mlops/mlops_runtime_log.py | 60 +++- .../core/mlops/mlops_runtime_log_daemon.py | 208 ++++++++++-- python/fedml/core/mlops/mlops_status.py | 85 +++++ python/fedml/core/mlops/mlops_utils.py | 39 ++- python/fedml/core/mlops/stats_impl.py | 88 ++++- python/fedml/core/mlops/system_stats.py | 25 ++ 22 files changed, 1596 insertions(+), 134 deletions(-) diff --git a/python/fedml/core/dp/common/utils.py b/python/fedml/core/dp/common/utils.py index 91558cb6ea..46b37f6048 100644 --- a/python/fedml/core/dp/common/utils.py +++ b/python/fedml/core/dp/common/utils.py @@ -7,6 +7,20 @@ def check_bounds(lower, upper): + """ + Check if the provided 
lower and upper bounds are valid. + + Args: + lower (Real): The lower bound. + upper (Real): The upper bound. + + Returns: + Tuple[Real, Real]: A tuple containing the validated lower and upper bounds. + + Raises: + TypeError: If lower or upper is not a numeric type. + ValueError: If the lower bound is greater than the upper bound. + """ if not isinstance(lower, Real) or not isinstance(upper, Real): raise TypeError("Bounds must be numeric") if lower > upper: @@ -15,18 +29,54 @@ def check_bounds(lower, upper): def check_numeric_value(value): + """ + Check if the provided value is a numeric type. + + Args: + value (Real): The value to be checked. + + Returns: + bool: True if the value is numeric, False otherwise. + + Raises: + TypeError: If the value is not a numeric type. + """ if not isinstance(value, Real): raise TypeError("Value to be randomised must be a number") return True def check_integer_value(value): + """ + Check if the provided value is an integer. + + Args: + value (Integral): The value to be checked. + + Returns: + bool: True if the value is an integer, False otherwise. + + Raises: + TypeError: If the value is not an integer. + """ if not isinstance(value, Integral): raise TypeError("Value to be randomised must be an integer") return True def check_epsilon_delta(epsilon, delta, allow_zero=False): + """ + Check if the provided epsilon and delta values are valid for differential privacy. + + Args: + epsilon (Real): Epsilon value. + delta (Real): Delta value. + allow_zero (bool, optional): Whether to allow epsilon and delta to be zero. Default is False. + + Raises: + TypeError: If epsilon or delta is not a numeric type. + ValueError: If epsilon is negative, delta is outside [0, 1] range, or both epsilon and delta are zero. + """ if not isinstance(epsilon, Real) or not isinstance(delta, Real): raise TypeError("Epsilon and delta must be numeric") if epsilon < 0: @@ -38,6 +88,18 @@ def check_epsilon_delta(epsilon, delta, allow_zero=False): def check_params(epsilon, delta, sensitivity): + """ + Check the validity of epsilon, delta, and sensitivity parameters for differential privacy. + + Args: + epsilon (Real): Epsilon value. + delta (Real): Delta value. + sensitivity (Real): Sensitivity value. + + Raises: + TypeError: If epsilon, delta, or sensitivity is not a numeric type. + ValueError: If epsilon is negative, delta is outside [0, 1] range, or sensitivity is negative. + """ check_epsilon_delta(epsilon, delta, allow_zero=False) if not isinstance(sensitivity, Real): raise TypeError("Sensitivity must be numeric") diff --git a/python/fedml/core/dp/fedml_differential_privacy.py b/python/fedml/core/dp/fedml_differential_privacy.py index de76817a48..5b81b5470a 100644 --- a/python/fedml/core/dp/fedml_differential_privacy.py +++ b/python/fedml/core/dp/fedml_differential_privacy.py @@ -11,6 +11,33 @@ class FedMLDifferentialPrivacy: + """ + A class for managing Differential Privacy in Federated Learning. + + Attributes: + enable_rdp_accountant (bool): Flag indicating if RDP accountant is enabled. + max_grad_norm (float): Maximum gradient norm for clipping. + dp_solution_type (str): Type of differential privacy solution (e.g., 'gaussian', 'laplace'). + dp_solution: An instance of the differential privacy solution. + dp_accountant: An instance of the differential privacy accountant. + is_enabled (bool): Flag indicating if differential privacy is enabled. + privacy_engine: The privacy engine used for differential privacy. + current_round (int): Current federated learning round. 
+
+        accountant: An accountant for tracking privacy budget consumption.
+        delta (float): Delta value for differential privacy.
+
+    Methods:
+        init(args): Initialize the differential privacy settings based on command-line arguments.
+        is_dp_enabled(): Check if differential privacy is enabled.
+        is_local_dp_enabled(): Check if local differential privacy is enabled.
+        is_global_dp_enabled(): Check if global differential privacy is enabled.
+        is_clipping(): Check if gradient clipping is enabled.
+        to_compute_params_in_aggregation_enabled(): Check if computing parameters in aggregation is enabled.
+        global_clip(raw_client_model_or_grad_list): Apply global gradient clipping.
+        add_local_noise(local_grad): Add local noise to gradients.
+        add_global_noise(global_model): Add global noise to the global model.
+        set_params_for_dp(raw_client_model_or_grad_list): Set parameters for differential privacy.
+    """
     _dp_instance = None

     @staticmethod
     def get_instance():
@@ -20,6 +47,12 @@ def get_instance():
         return FedMLDifferentialPrivacy._dp_instance

     def __init__(self):
+        """
+        Initialize default attribute values for the singleton; the actual
+        configuration is applied later via init(args).
+        """
         self.enable_rdp_accountant = False
         self.max_grad_norm = None
         self.dp_solution_type = None
@@ -33,7 +66,8 @@ def __init__(self):

     def init(self, args):
         if hasattr(args, "enable_dp") and args.enable_dp:
-            logging.info(".......init dp......." + args.dp_solution_type + "-" + args.dp_solution_type)
+            logging.info(".......init dp......." +
+                         args.dp_solution_type + "-" + args.dp_solution_type)
             self.is_enabled = True
             self.dp_solution_type = args.dp_solution_type.strip()
             if hasattr(args, "max_grad_norm"):
@@ -67,6 +101,12 @@ def init(self, args):
             self.is_enabled = False

     def is_dp_enabled(self):
+        """
+        Check if differential privacy is enabled.
+
+        Returns:
+            bool: True if differential privacy is enabled, False otherwise.
+        """
         return self.is_enabled

     def is_local_dp_enabled(self):
@@ -101,4 +141,3 @@ def is_local_dp_enabled(self):
     def set_params_for_dp(self, raw_client_model_or_grad_list: List[Tuple[float, OrderedDict]]):
         if self.dp_solution is None:
             raise Exception("dp solution is not initialized!")
         self.dp_solution.set_params_for_dp(raw_client_model_or_grad_list)
-
diff --git a/python/fedml/core/dp/frames/NbAFL.py b/python/fedml/core/dp/frames/NbAFL.py
index 8c7d3c3f33..83c01eb816 100644
--- a/python/fedml/core/dp/frames/NbAFL.py
+++ b/python/fedml/core/dp/frames/NbAFL.py
@@ -12,7 +12,33 @@
 class NbAFL_DP(BaseDPFrame):
+    """
+    Noising-before-Aggregation Federated Learning (NbAFL) with a Differential Privacy mechanism.
+
+    Attributes:
+        args: A namespace containing the configuration arguments for the mechanism.
+        big_C_clipping (float): A clipping threshold for bounding model weights.
+        total_round_num (int): The total number of communication rounds.
+        small_c_constant (float): A constant used in the mechanism.
+        client_num_per_round (int): The number of clients participating in each round.
+        client_num_in_total (int): The total number of clients.
+        epsilon (float): The privacy parameter epsilon.
+        m (int): The minimum size of local datasets.
+
+    Methods:
+        __init__(self, args): Initialize the NbAFL_DP mechanism.
+        add_local_noise(self, local_grad: OrderedDict): Add local noise to the gradients.
+        add_global_noise(self, global_model: OrderedDict): Add global noise to the global model.
+        set_params_for_dp(self, raw_client_model_or_grad_list: List[Tuple[float, OrderedDict]]): Set parameters for DP.
+    """
+
     def __init__(self, args):
+        """
+        Initialize the NbAFL_DP mechanism.
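+
+        Implements the noising-before-aggregation (NbAFL) scheme of Wei et
+        al., "Federated Learning with Differential Privacy: Algorithms and
+        Performance Analysis"; noise is injected on the uplink (client) side
+        and, when the round count is large enough, on the downlink (server)
+        side as well.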
+ + Args: + args: A namespace containing the configuration arguments for the mechanism. + """ super().__init__(args) self.set_ldp( DPMechanism( @@ -22,39 +48,64 @@ def __init__(self, args): ) ) """ - In the experiments, the authors choosed C by taking the median of the norms of the unclipped parameters. - This is not practical in reality. The server can not obtain unclipped plaintext parameters. It can only - get noised clipped parameters. So here we set C as a parameter that indicated by users. + In the experiments, the authors chose C by taking the median of the norms of the unclipped parameters. + This is not practical in reality. The server cannot obtain unclipped plaintext parameters. It can only + get noised clipped parameters. So here we set C as a parameter indicated by users. """ self.big_C_clipping = args.C # C: a clipping threshold for bounding w_i self.total_round_num = args.comm_round # T in the paper self.small_c_constant = np.sqrt( - 2 * math.log(1.25 / args.delta)) # the author indicated c>= sqrt(2ln(1.25/delta) - self.client_num_per_round = args.client_num_per_round # L in the paper - self.client_num_in_total = args.client_num_in_total # N in the paper + 2 * math.log(1.25 / args.delta)) # the author indicated c >= sqrt(2ln(1.25/delta) + self.client_num_per_round = args.client_num_per_round # L in the paper + self.client_num_in_total = args.client_num_in_total # N in the paper self.epsilon = args.epsilon # 0 < epsilon < 1 - """ The author said ''m is the minimum size of the local datasets''. + """ The author said ''m is the minimum size of the local datasets''. In their paper, clients did not sample local data for training; In our setting, we set m to the minimum sample num of each round.""" self.m = 0 # the minimum size of the local datasets def add_local_noise(self, local_grad: OrderedDict): + """ + Add local noise to the gradients. + + Args: + local_grad (OrderedDict): Local gradients. + + Returns: + OrderedDict: Local gradients with added noise. + """ for k in local_grad.keys(): # Clip weight local_grad[k] = local_grad[k] / torch.max(torch.ones(size=local_grad[k].shape), torch.abs(local_grad[k]) / self.big_C_clipping) return super().add_local_noise(local_grad=local_grad) def add_global_noise(self, global_model: OrderedDict): + """ + Add global noise to the global model. + + Args: + global_model (OrderedDict): Global model parameters. + + Returns: + OrderedDict: Global model parameters with added noise. + """ if self.total_round_num > np.sqrt(self.client_num_in_total) * self.client_num_per_round: - scale_d = 2 * self.small_c_constant * self.big_C_clipping * np.sqrt(np.power(self.total_round_num, 2) - - np.power(self.client_num_per_round, - 2) * self.client_num_in_total) / ( - self.m * self.client_num_in_total * self.epsilon) + scale_d = 2 * self.small_c_constant * self.big_C_clipping * np.sqrt( + np.power(self.total_round_num, 2) - + np.power(self.client_num_per_round, 2) * self.client_num_in_total) / ( + self.m * self.client_num_in_total * self.epsilon) for k in global_model.keys(): - global_model[k] = Gaussian.compute_noise_using_sigma(scale_d, global_model[k].shape) + global_model[k] = Gaussian.compute_noise_using_sigma( + scale_d, global_model[k].shape) return global_model def set_params_for_dp(self, raw_client_model_or_grad_list: List[Tuple[float, OrderedDict]]): + """ + Set parameters for Differential Privacy. + + Args: + raw_client_model_or_grad_list (List[Tuple[float, OrderedDict]]): List of tuples containing sample numbers and gradients/models. 
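+
+        Example:
+            With updates [(12, w1), (30, w2)] the minimum local sample count
+            m becomes 12, which feeds the downlink noise scale.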
+ """ smallest_sample_num, _ = raw_client_model_or_grad_list[0] for (sample_num, _) in raw_client_model_or_grad_list: if smallest_sample_num > sample_num: diff --git a/python/fedml/core/dp/frames/base_dp_solution.py b/python/fedml/core/dp/frames/base_dp_solution.py index e0647eb903..de3f91d097 100644 --- a/python/fedml/core/dp/frames/base_dp_solution.py +++ b/python/fedml/core/dp/frames/base_dp_solution.py @@ -6,7 +6,34 @@ class BaseDPFrame(ABC): + """ + Abstract base class for Differential Privacy mechanisms. + + Attributes: + cdp: A DPMechanism instance for global differential privacy. + ldp: A DPMechanism instance for local differential privacy. + args: A namespace containing the configuration arguments for the mechanism. + is_rdp_accountant_enabled: A boolean indicating whether RDP accountant is enabled. + max_grad_norm: Maximum gradient norm for gradient clipping. + + Methods: + __init__(self, args=None): Initialize the BaseDPFrame instance. + set_cdp(self, dp_mechanism: DPMechanism): Set the global differential privacy mechanism. + set_ldp(self, dp_mechanism: DPMechanism): Set the local differential privacy mechanism. + add_local_noise(self, local_grad: OrderedDict): Add local noise to local gradients. + add_global_noise(self, global_model: OrderedDict): Add global noise to global model parameters. + set_params_for_dp(self, raw_client_model_or_grad_list): Set parameters for differential privacy mechanism. + get_rdp_accountant_val(self): Get the differential privacy parameter for RDP accountant. + global_clip(self, raw_client_model_or_grad_list): Apply gradient clipping to global gradients. + """ + def __init__(self, args=None): + """ + Initialize the BaseDPFrame instance. + + Args: + args: A namespace containing the configuration arguments for the mechanism. + """ self.cdp = None self.ldp = None self.args = args @@ -17,21 +44,66 @@ def __init__(self, args=None): self.max_grad_norm = None def set_cdp(self, dp_mechanism: DPMechanism): + """ + Set the global differential privacy mechanism. + + Args: + dp_mechanism (DPMechanism): A DPMechanism instance for global differential privacy. + """ self.cdp = dp_mechanism def set_ldp(self, dp_mechanism: DPMechanism): + """ + Set the local differential privacy mechanism. + + Args: + dp_mechanism (DPMechanism): A DPMechanism instance for local differential privacy. + """ self.ldp = dp_mechanism + @abstractmethod def add_local_noise(self, local_grad: OrderedDict): - return self.ldp.add_noise(grad=local_grad) + """ + Add local noise to local gradients. + + Args: + local_grad (OrderedDict): Local gradients. + Returns: + OrderedDict: Local gradients with added noise. + """ + pass + + @abstractmethod def add_global_noise(self, global_model: OrderedDict): - return self.cdp.add_noise(grad=global_model) + """ + Add global noise to global model parameters. + + Args: + global_model (OrderedDict): Global model parameters. - def set_params_for_dp(self, raw_client_model_or_grad_list: List[Tuple[float, OrderedDict]]): + Returns: + OrderedDict: Global model parameters with added noise. + """ + pass + + @abstractmethod + def set_params_for_dp(self, raw_client_model_or_grad_list): + """ + Set parameters for differential privacy mechanism. + + Args: + raw_client_model_or_grad_list: List of raw client models or gradients. + """ pass def get_rdp_accountant_val(self): + """ + Get the differential privacy parameter for RDP accountant. + + Returns: + float: Differential privacy parameter. 
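+
+        For a Gaussian mechanism this is the noise standard deviation sigma,
+        as returned by get_rdp_scale() on the underlying mechanism.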
+ """ if self.cdp is not None: dp_param = self.cdp.get_rdp_scale() elif self.ldp is not None: @@ -40,7 +112,16 @@ def get_rdp_accountant_val(self): raise Exception("can not create rdp accountant") return dp_param - def global_clip(self, raw_client_model_or_grad_list: List[Tuple[float, OrderedDict]]): + def global_clip(self, raw_client_model_or_grad_list): + """ + Apply gradient clipping to global gradients. + + Args: + raw_client_model_or_grad_list: List of raw client models or gradients. + + Returns: + List: List of clipped client models or gradients. + """ if self.max_grad_norm is None: return raw_client_model_or_grad_list new_grad_list = [] @@ -54,6 +135,3 @@ def global_clip(self, raw_client_model_or_grad_list: List[Tuple[float, OrderedDi local_grad[k].mul_(clip_coef_clamped) new_grad_list.append((num, local_grad)) return new_grad_list - - - diff --git a/python/fedml/core/dp/frames/cdp.py b/python/fedml/core/dp/frames/cdp.py index ccc65b5822..58bfa7ca18 100644 --- a/python/fedml/core/dp/frames/cdp.py +++ b/python/fedml/core/dp/frames/cdp.py @@ -6,16 +6,49 @@ class GlobalDP(BaseDPFrame): + """ + Differential Privacy mechanism with global noise. + + Attributes: + args: A namespace containing the configuration arguments for the mechanism. + enable_rdp_accountant: A boolean indicating whether RDP accountant is enabled. + is_rdp_accountant_enabled: A boolean indicating whether RDP accountant is enabled. + sample_rate: Sample rate for RDP accountant. + accountant: RDP accountant for privacy analysis. + + Methods: + __init__(self, args): Initialize the GlobalDP mechanism. + add_global_noise(self, global_model: OrderedDict): Add global noise to the global model parameters. + """ + def __init__(self, args): + """ + Initialize the GlobalDP mechanism. + + Args: + args: A namespace containing the configuration arguments for the mechanism. + """ super().__init__(args) - self.set_cdp(DPMechanism(args.mechanism_type, args.epsilon, args.delta, args.sensitivity)) + self.set_cdp(DPMechanism(args.mechanism_type, + args.epsilon, args.delta, args.sensitivity)) self.enable_rdp_accountant = False if hasattr(args, "enable_rdp_accountant") and args.enable_rdp_accountant: self.is_rdp_accountant_enabled = True self.sample_rate = args.client_num_per_round / args.client_num_in_total - self.accountant = RDP_Accountant(alpha=args.rdp_alpha, dp_mechanism=args.mechanism_type, args=args) + self.accountant = RDP_Accountant( + alpha=args.rdp_alpha, dp_mechanism=args.mechanism_type, args=args) def add_global_noise(self, global_model: OrderedDict): + """ + Add global noise to the global model parameters. + + Args: + global_model (OrderedDict): Global model parameters. + + Returns: + OrderedDict: Global model parameters with added global noise. + """ if self.is_rdp_accountant_enabled: - self.accountant.step(noise_multiplier=self.cdp.get_rdp_scale(), sample_rate=self.sample_rate) # todo: ask??? - return super().add_global_noise(global_model=global_model) \ No newline at end of file + self.accountant.step( + noise_multiplier=self.cdp.get_rdp_scale(), sample_rate=self.sample_rate) + return super().add_global_noise(global_model=global_model) diff --git a/python/fedml/core/dp/frames/dp_clip.py b/python/fedml/core/dp/frames/dp_clip.py index b6e7d02b65..eb44037e53 100644 --- a/python/fedml/core/dp/frames/dp_clip.py +++ b/python/fedml/core/dp/frames/dp_clip.py @@ -13,49 +13,119 @@ """ class DP_Clip(BaseDPFrame): + """ + Differential Privacy mechanism with gradient clipping. 
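+
+    The noise standard deviation is calibrated as clipping_norm *
+    noise_multiplier (see __init__ below), in the style of DP-SGD.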
+ + Attributes: + args: A namespace containing the configuration arguments for the mechanism. + + Methods: + __init__(self, args): Initialize the DP_Clip mechanism. + clip_local_update(self, local_grad, norm_type: float = 2.0): Clip local gradients. + add_local_noise(self, local_grad: OrderedDict, extra_auxiliary_info: Any = None): Add local noise to gradients. + add_global_noise(self, global_model: OrderedDict): Add global noise to the global model parameters. + get_global_params(self): Get global parameters. + compute_noise(self, size, qw): Compute noise. + add_noise(self, w_global, qw): Add noise to global parameters. + """ + def __init__(self, args): + """ + Initialize the DP_Clip mechanism. + + Args: + args: A namespace containing the configuration arguments for the mechanism. + """ super().__init__(args) self.clipping_norm = args.clipping_norm self.train_data_num_in_total = args.train_data_num_in_total self._scale = args.clipping_norm * args.noise_multiplier def clip_local_update(self, local_grad, norm_type: float = 2.0): - total_norm = torch.norm(torch.stack([torch.norm(local_grad[k], norm_type) for k in local_grad.keys()]), norm_type) + """ + Clip local gradients. + + Args: + local_grad (OrderedDict): Local gradients. + norm_type (float): Type of norm to compute (default is 2.0). + + Returns: + OrderedDict: Clipped local gradients. + """ + total_norm = torch.norm(torch.stack( + [torch.norm(local_grad[k], norm_type) for k in local_grad.keys()]), norm_type) clip_coef = self.clipping_norm / (total_norm + 1e-6) clip_coef_clamped = torch.clamp(clip_coef, max=1.0) for k in local_grad.keys(): local_grad[k].mul_(clip_coef_clamped) return local_grad - def add_local_noise(self, local_grad: OrderedDict, extra_auxiliary_info: Any = None,): + def add_local_noise(self, local_grad: OrderedDict, extra_auxiliary_info: Any = None): + """ + Add local noise to gradients. + + Args: + local_grad (OrderedDict): Local gradients. + extra_auxiliary_info (Any): Extra auxiliary information (not used). + + Returns: + OrderedDict: Local gradients with added noise. + """ global_model_params = extra_auxiliary_info for k in global_model_params.keys(): local_grad[k] = local_grad[k] - global_model_params[k] return self.clip_local_update(local_grad, self.clipping_norm) def add_global_noise(self, global_model: OrderedDict): - qw = self.train_data_num_in_total * (self.args.client_num_per_round / self.args.client_num_in_total) - for k in global_model.keys(): - global_model[k] = global_model[k] / qw - w_global = self.add_noise( - global_model, qw - ) - for k in w_global.keys(): - w_global[k] = w_global[k] + global_model[k] + """ + Add global noise to the global model parameters (not implemented). + + Args: + global_model (OrderedDict): Global model parameters. + + Raises: + NotImplementedError: This method is not implemented. + """ + raise NotImplementedError( + "add_global_noise method is not implemented.") def get_global_params(self): - pass + """ + Get global parameters (not implemented). + + Raises: + NotImplementedError: This method is not implemented. + """ + raise NotImplementedError( + "get_global_params method is not implemented.") def compute_noise(self, size, qw): + """ + Compute noise for differential privacy. + + Args: + size: Size of the noise. + qw: Noise scaling factor. + + Returns: + torch.Tensor: Noise tensor. + """ self._scale = self._scale / qw return torch.normal(mean=0, std=self._scale, size=size) def add_noise(self, w_global, qw): + """ + Add noise to global parameters for differential privacy. 
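+
+        Each tensor receives independent Gaussian noise with standard
+        deviation scale / qw, as produced by compute_noise() above.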
+ + Args: + w_global (OrderedDict): Global model parameters. + qw: Noise scaling factor. + + Returns: + OrderedDict: Global model parameters with added noise. + """ new_params = OrderedDict() for k in w_global.keys(): - new_params[k] = self.compute_noise(w_global[k].shape, qw) + w_global[k] + new_params[k] = self.compute_noise( + w_global[k].shape, qw) + w_global[k] return new_params - - - - diff --git a/python/fedml/core/dp/frames/ldp.py b/python/fedml/core/dp/frames/ldp.py index 94f4443431..c26cb8e131 100644 --- a/python/fedml/core/dp/frames/ldp.py +++ b/python/fedml/core/dp/frames/ldp.py @@ -5,9 +5,36 @@ class LocalDP(BaseDPFrame): + """ + Local Differential Privacy mechanism. + + Attributes: + args: A namespace containing the configuration arguments for the mechanism. + + Methods: + __init__(self, args): Initialize the LocalDP mechanism. + add_local_noise(self, local_grad: OrderedDict): Add local noise to the gradients. + """ + def __init__(self, args): + """ + Initialize the LocalDP mechanism. + + Args: + args: A namespace containing the configuration arguments for the mechanism. + """ super().__init__(args) - self.set_ldp(DPMechanism(args.mechanism_type, args.epsilon, args.delta, args.sensitivity)) + self.set_ldp(DPMechanism(args.mechanism_type, + args.epsilon, args.delta, args.sensitivity)) def add_local_noise(self, local_grad: OrderedDict): + """ + Add local noise to the gradients. + + Args: + local_grad (OrderedDict): Local gradients. + + Returns: + OrderedDict: Local gradients with added noise. + """ return super().add_local_noise(local_grad=local_grad) diff --git a/python/fedml/core/dp/mechanisms/dp_mechanism.py b/python/fedml/core/dp/mechanisms/dp_mechanism.py index ba64fe3eb6..b3cbfcc366 100644 --- a/python/fedml/core/dp/mechanisms/dp_mechanism.py +++ b/python/fedml/core/dp/mechanisms/dp_mechanism.py @@ -1,3 +1,5 @@ +from .gaussian import Gaussian +from .laplace import Laplace from fedml.core.dp.mechanisms import Gaussian, Laplace import torch from typing import Union, Iterable @@ -9,7 +11,33 @@ class DPMechanism: + """ + A class representing a Differential Privacy Mechanism. + + Attributes: + mechanism_type (str): The type of differential privacy mechanism ('laplace' or 'gaussian'). + epsilon (float): The privacy parameter epsilon. + delta (float): The privacy parameter delta. + sensitivity (float, optional): The sensitivity of the mechanism (default is 1). + + Methods: + __init__(self, mechanism_type, epsilon, delta, sensitivity=1): Initialize the DP mechanism. + add_noise(self, grad): Add noise to a gradient. + _compute_new_grad(self, grad): Compute a new gradient by adding noise. + add_a_noise_to_local_data(self, local_data): Add noise to local data. + get_rdp_scale(self): Get the RDP (Rényi Differential Privacy) scale of the mechanism. + """ + def __init__(self, mechanism_type, epsilon, delta, sensitivity=1): + """ + Initialize the Differential Privacy Mechanism. + + Args: + mechanism_type (str): The type of differential privacy mechanism ('laplace' or 'gaussian'). + epsilon (float): The privacy parameter epsilon. + delta (float): The privacy parameter delta. + sensitivity (float, optional): The sensitivity of the mechanism (default is 1). + """ mechanism_type = mechanism_type.lower() if mechanism_type == "laplace": self.dp = Laplace( @@ -21,28 +49,57 @@ def __init__(self, mechanism_type, epsilon, delta, sensitivity=1): raise NotImplementedError("DP mechanism not implemented!") def add_noise(self, grad): + """ + Add noise to a gradient. 
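+
+        Keys and tensor shapes are preserved; only values are perturbed, e.g.
+        DPMechanism("gaussian", epsilon=1.0, delta=1e-5).add_noise(state_dict).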
+ + Args: + grad (OrderedDict): The gradient to which noise will be added. + + Returns: + OrderedDict: A new gradient with added noise. + """ new_grad = OrderedDict() for k in grad.keys(): new_grad[k] = self._compute_new_grad(grad[k]) return new_grad def _compute_new_grad(self, grad): + """ + Compute a new gradient by adding noise. + + Args: + grad (torch.Tensor): The gradient tensor. + + Returns: + torch.Tensor: A new gradient tensor with added noise. + """ noise = self.dp.compute_noise(grad.shape) return noise + grad def add_a_noise_to_local_data(self, local_data): + """ + Add noise to local data. + + Args: + local_data (list of tuples): Local data where each tuple represents a data point. + + Returns: + list of tuples: Local data with added noise. + """ new_data = [] for i in range(len(local_data)): - list = [] + data_tuple = [] for x in local_data[i]: - y = self._compute_new_grad(x) - list.append(y) - new_data.append(tuple(list)) + noisy_data = self._compute_new_grad(x) + data_tuple.append(noisy_data) + new_data.append(tuple(data_tuple)) return new_data def get_rdp_scale(self): - return self.dp.get_rdp_scale() - - - + """ + Get the RDP (Rényi Differential Privacy) scale of the mechanism. + Returns: + float: The RDP scale of the mechanism. + """ + return self.dp.get_rdp_scale() diff --git a/python/fedml/core/dp/mechanisms/gaussian.py b/python/fedml/core/dp/mechanisms/gaussian.py index 93b0ad56d5..3074ec6070 100644 --- a/python/fedml/core/dp/mechanisms/gaussian.py +++ b/python/fedml/core/dp/mechanisms/gaussian.py @@ -5,7 +5,30 @@ class Gaussian(BaseDPMechanism): + """ + The Gaussian mechanism in differential privacy. + + Attributes: + epsilon (float): The privacy parameter epsilon. + delta (float): The privacy parameter delta (default is 0.0). + sensitivity (float): The sensitivity of the mechanism (default is 1). + + Methods: + __init__(self, epsilon, delta=0.0, sensitivity=1): Initialize the Gaussian mechanism. + compute_noise(self, size): Generate Gaussian noise. + compute_noise_using_sigma(cls, sigma, size): Generate Gaussian noise with a given standard deviation. + get_rdp_scale(self): Get the RDP (Rényi Differential Privacy) scale of the mechanism. + """ + def __init__(self, epsilon, delta=0.0, sensitivity=1): + """ + Initialize the Gaussian mechanism. + + Args: + epsilon (float): The privacy parameter epsilon. + delta (float, optional): The privacy parameter delta (default is 0.0). + sensitivity (float, optional): The sensitivity of the mechanism (default is 1). + """ check_params(epsilon, delta, sensitivity) if epsilon == 0 or delta == 0: raise ValueError("Neither Epsilon nor Delta can be zero") @@ -15,19 +38,44 @@ def __init__(self, epsilon, delta=0.0, sensitivity=1): ) self.scale = ( - np.sqrt(2 * np.log(1.25 / float(delta))) - * float(sensitivity) - / float(epsilon) + np.sqrt(2 * np.log(1.25 / float(delta))) + * float(sensitivity) + / float(epsilon) ) @classmethod def compute_noise_using_sigma(cls, sigma, size): + """ + Generate Gaussian noise with a given standard deviation. + + Args: + sigma (float): The standard deviation of the Gaussian noise. + size (int or tuple): The size of the noise vector. + + Returns: + torch.Tensor: A tensor containing Gaussian noise. + """ if not isinstance(sigma, float): raise ValueError("sigma should be a float") return torch.normal(mean=0, std=sigma, size=size) def compute_noise(self, size): + """ + Generate Gaussian noise. + + Args: + size (int or tuple): The size of the noise vector. 
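+
+        Example:
+            The standard deviation follows
+            sigma = sqrt(2 * ln(1.25 / delta)) * sensitivity / epsilon::
+
+                import numpy as np
+                sigma = np.sqrt(2 * np.log(1.25 / 1e-5)) * 1.0 / 1.0
+                # sigma ~= 4.84 for epsilon=1.0, delta=1e-5, sensitivity=1.0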
+ + Returns: + torch.Tensor: A tensor containing Gaussian noise. + """ return torch.normal(mean=0, std=self.scale, size=size) def get_rdp_scale(self): + """ + Get the RDP (Rényi Differential Privacy) scale of the mechanism. + + Returns: + float: The RDP scale of the mechanism. + """ return self.scale diff --git a/python/fedml/core/dp/mechanisms/laplace.py b/python/fedml/core/dp/mechanisms/laplace.py index cc4fabd95f..4b304eab0e 100644 --- a/python/fedml/core/dp/mechanisms/laplace.py +++ b/python/fedml/core/dp/mechanisms/laplace.py @@ -7,15 +7,49 @@ class Laplace(BaseDPMechanism): """ The classical Laplace mechanism in differential privacy. + + Attributes: + epsilon (float): The privacy parameter epsilon. + delta (float): The privacy parameter delta (default is 0.0). + sensitivity (float): The sensitivity of the mechanism (default is 1). + + Methods: + __init__(self, epsilon, delta=0.0, sensitivity=1): Initialize the Laplace mechanism. + compute_noise(self, size): Generate Laplace noise. + get_rdp_scale(self): Get the RDP (Rényi Differential Privacy) scale of the mechanism. """ def __init__(self, epsilon, delta=0.0, sensitivity=1): + """ + Initialize the Laplace mechanism. + + Args: + epsilon (float): The privacy parameter epsilon. + delta (float, optional): The privacy parameter delta (default is 0.0). + sensitivity (float, optional): The sensitivity of the mechanism (default is 1). + """ check_params(epsilon, delta, sensitivity) - self.scale = float(sensitivity) / (float(epsilon) - np.log(1 - float(delta))) + self.scale = float(sensitivity) / \ + (float(epsilon) - np.log(1 - float(delta))) self.sensitivity = sensitivity def compute_noise(self, size): + """ + Generate Laplace noise. + + Args: + size (int or tuple): The size of the noise vector. + + Returns: + torch.Tensor: A tensor containing Laplace noise. + """ return torch.tensor(np.random.laplace(loc=0.0, scale=self.scale, size=size)) def get_rdp_scale(self): - return self.scale/self.sensitivity + """ + Get the RDP (Rényi Differential Privacy) scale of the mechanism. + + Returns: + float: The RDP scale of the mechanism. + """ + return self.scale / self.sensitivity diff --git a/python/fedml/core/dp/test/test_fed_privacy_mechanism.py b/python/fedml/core/dp/test/test_fed_privacy_mechanism.py index 99d252dffb..19aa39697e 100644 --- a/python/fedml/core/dp/test/test_fed_privacy_mechanism.py +++ b/python/fedml/core/dp/test/test_fed_privacy_mechanism.py @@ -13,6 +13,12 @@ def add_gaussian_args(): + """ + Define and parse command-line arguments for Gaussian differential privacy mechanism. + + Returns: + argparse.Namespace: Parsed command-line arguments. + """ parser = argparse.ArgumentParser(description="FedML") parser.add_argument( "--yaml_config_file", @@ -34,6 +40,12 @@ def add_gaussian_args(): def add_laplace_args(): + """ + Define and parse command-line arguments for Laplace differential privacy mechanism. + + Returns: + argparse.Namespace: Parsed command-line arguments. + """ parser = argparse.ArgumentParser(description="FedML") parser.add_argument( "--yaml_config_file", @@ -53,6 +65,9 @@ def add_laplace_args(): def test_FedMLDifferentialPrivacy_gaussian(): + """ + Test the FedMLDifferentialPrivacy class with the Gaussian mechanism. 
+ """ print("----------- test_FedMLDifferentialPrivacy - gaussian mechanism -----------") FedMLDifferentialPrivacy.get_instance().init(add_gaussian_args()) print(f"grad = {a_local_w}") @@ -60,6 +75,9 @@ def test_FedMLDifferentialPrivacy_gaussian(): def test_FedMLDifferentialPrivacy_laplace(): + """ + Test the FedMLDifferentialPrivacy class with the Laplace mechanism. + """ print("----------- test_FedMLDifferentialPrivacy - laplace mechanism -----------") FedMLDifferentialPrivacy.get_instance().init(add_laplace_args()) print(f"grad = {a_local_w}") diff --git a/python/fedml/core/mlops/mlops_configs.py b/python/fedml/core/mlops/mlops_configs.py index b0ea899ab1..a43851689e 100644 --- a/python/fedml/core/mlops/mlops_configs.py +++ b/python/fedml/core/mlops/mlops_configs.py @@ -39,6 +39,13 @@ def get_instance(args): return MLOpsConfigs._config_instance def get_request_params(self): + """ + Get the request parameters for fetching configurations. + + Returns: + str: The URL for configuration retrieval. + str: The path to the certificate file, if applicable. + """ url = "https://open.fedml.ai/fedmlOpsServer/configs/fetch" config_version = "release" if ( @@ -55,7 +62,8 @@ def get_request_params(self): url = "https://open-dev.fedml.ai/fedmlOpsServer/configs/fetch" elif self.args.config_version == "local": if hasattr(self.args, "local_server") and self.args.local_server is not None: - url = "http://{}:9000/fedmlOpsServer/configs/fetch".format(self.args.local_server) + url = "http://{}:9000/fedmlOpsServer/configs/fetch".format( + self.args.local_server) else: url = "http://localhost:9000/fedmlOpsServer/configs/fetch" @@ -78,7 +86,8 @@ def get_request_params_with_version(self, version): url = "https://open-dev.fedml.ai/fedmlOpsServer/configs/fetch" elif version == "local": if hasattr(self.args, "local_server") and self.args.local_server is not None: - url = "http://{}:9000/fedmlOpsServer/configs/fetch".format(self.args.local_server) + url = "http://{}:9000/fedmlOpsServer/configs/fetch".format( + self.args.local_server) else: url = "http://localhost:9000/fedmlOpsServer/configs/fetch" @@ -93,6 +102,12 @@ def get_request_params_with_version(self, version): @staticmethod def get_root_ca_path(): + """ + Get the file path to the root CA certificate. + + Returns: + str: The file path to the root CA certificate. + """ cur_source_dir = os.path.dirname(__file__) cert_path = os.path.join( cur_source_dir, "ssl", "open-root-ca.crt" @@ -101,6 +116,14 @@ def get_root_ca_path(): @staticmethod def install_root_ca_file(): + """ + Install the root CA certificate file. + + This method appends the root CA certificate to the CA file used by the requests library. + + Raises: + FileNotFoundError: If the root CA certificate file is not found. + """ ca_file = certifi.where() open_root_ca_path = MLOpsConfigs.get_root_ca_path() with open(open_root_ca_path, 'rb') as infile: @@ -109,6 +132,16 @@ def install_root_ca_file(): outfile.write(open_root_ca_file) def fetch_configs(self): + """ + Fetch device configurations. + + Returns: + dict: MQTT configuration. + dict: S3 configuration. + + Raises: + Exception: If fetching device configurations fails. 
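+
+        Example:
+            The request body mirrors the implementation below::
+
+                import time
+                payload = {
+                    "config_name": ["mqtt_config", "s3_config", "ml_ops_config"],
+                    "device_send_time": int(time.time() * 1000),
+                }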
+ """ url, cert_path = self.get_request_params() json_params = {"config_name": ["mqtt_config", "s3_config", "ml_ops_config"], "device_send_time": int(time.time() * 1000)} @@ -126,7 +159,8 @@ def fetch_configs(self): ) else: response = requests.post( - url, json=json_params, headers={"content-type": "application/json", "Connection": "close"} + url, json=json_params, headers={ + "content-type": "application/json", "Connection": "close"} ) status_code = response.json().get("code") @@ -140,6 +174,16 @@ def fetch_configs(self): return mqtt_config, s3_config def fetch_web3_configs(self): + """ + Fetch MQTT, Web3, and ML Ops configurations. + + Returns: + dict: MQTT configuration. + dict: Web3 configuration. + + Raises: + Exception: If fetching device configurations fails. + """ url, cert_path = self.get_request_params() json_params = {"config_name": ["mqtt_config", "web3_config", "ml_ops_config"], "device_send_time": int(time.time() * 1000)} @@ -157,7 +201,8 @@ def fetch_web3_configs(self): ) else: response = requests.post( - url, json=json_params, headers={"content-type": "application/json", "Connection": "close"} + url, json=json_params, headers={ + "content-type": "application/json", "Connection": "close"} ) status_code = response.json().get("code") @@ -171,6 +216,17 @@ def fetch_web3_configs(self): return mqtt_config, web3_config def fetch_thetastore_configs(self): + """ + Fetch MQTT, ThetaStore, and ML Ops configurations. + + Returns: + dict: MQTT configuration. + dict: ThetaStore configuration. + + Raises: + Exception: If fetching device configurations fails. + """ + url, cert_path = self.get_request_params() json_params = {"config_name": ["mqtt_config", "thetastore_config", "ml_ops_config"], "device_send_time": int(time.time() * 1000)} @@ -188,7 +244,8 @@ def fetch_thetastore_configs(self): ) else: response = requests.post( - url, json=json_params, headers={"content-type": "application/json", "Connection": "close"} + url, json=json_params, headers={ + "content-type": "application/json", "Connection": "close"} ) status_code = response.json().get("code") @@ -202,6 +259,18 @@ def fetch_thetastore_configs(self): return mqtt_config, thetastore_config def fetch_all_configs(self): + """ + Fetch all configurations including MQTT, S3, ML Ops, and Docker configurations. + + Returns: + dict: MQTT configuration. + dict: S3 configuration. + dict: ML Ops configuration. + dict: Docker configuration. + + Raises: + Exception: If fetching device configurations fails. + """ url, cert_path = self.get_request_params() json_params = { "config_name": ["mqtt_config", "s3_config", "ml_ops_config", "docker_config"], @@ -221,7 +290,8 @@ def fetch_all_configs(self): ) else: response = requests.post( - url, json=json_params, headers={"content-type": "application/json", "Connection": "close"} + url, json=json_params, headers={ + "content-type": "application/json", "Connection": "close"} ) status_code = response.json().get("code") @@ -238,6 +308,21 @@ def fetch_all_configs(self): @staticmethod def fetch_all_configs_with_version(version): + """ + Fetch all configurations with a specific version. + + Args: + version (str): The version to fetch configurations for. + + Returns: + dict: MQTT configuration. + dict: S3 configuration. + dict: ML Ops configuration. + dict: Docker configuration. + + Raises: + Exception: If fetching device configurations fails. 
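+
+        Example:
+            URL selection, as in the body below::
+
+                version = "dev"
+                url = "https://open{}.fedml.ai/fedmlOpsServer/configs/fetch".format(
+                    "" if version == "release" else "-" + version)
+                # -> "https://open-dev.fedml.ai/fedmlOpsServer/configs/fetch"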
+ """ url = "https://open{}.fedml.ai/fedmlOpsServer/configs/fetch".format( "" if version == "release" else "-"+version) cert_path = None @@ -265,7 +350,8 @@ def fetch_all_configs_with_version(version): ) else: response = requests.post( - url, json=json_params, headers={"content-type": "application/json", "Connection": "close"} + url, json=json_params, headers={ + "content-type": "application/json", "Connection": "close"} ) status_code = response.json().get("code") diff --git a/python/fedml/core/mlops/mlops_device_perfs.py b/python/fedml/core/mlops/mlops_device_perfs.py index de8c8e9b69..5394afe233 100644 --- a/python/fedml/core/mlops/mlops_device_perfs.py +++ b/python/fedml/core/mlops/mlops_device_perfs.py @@ -16,6 +16,12 @@ class MLOpsDevicePerfStats(object): + """ + Class for reporting device performance statistics to MLOps. + + This class handles the reporting of device performance statistics to MLOps using MQTT. + """ + def __init__(self): self.device_realtime_stats_process = None self.device_realtime_stats_event = None @@ -25,23 +31,61 @@ def __init__(self): self.edge_id = None def report_device_realtime_stats(self, sys_args): + """ + Report device real-time statistics to MLOps. + + Args: + sys_args: The system arguments passed to the device. + + This method sets up and starts a process to report real-time device statistics to MLOps. + + Returns: + None + """ + self.setup_realtime_stats_process(sys_args) def stop_device_realtime_stats(self): + """ + Stop reporting device real-time statistics. + + This method sets the event to stop reporting device real-time statistics. + + Returns: + None + """ if self.device_realtime_stats_event is not None: self.device_realtime_stats_event.set() def should_stop_device_realtime_stats(self): + """ + Check if reporting of device real-time statistics should stop. + + Returns: + bool: True if reporting should stop, False otherwise. + """ if self.device_realtime_stats_event is not None and self.device_realtime_stats_event.is_set(): return True return False def setup_realtime_stats_process(self, sys_args): + """ + Set up the process for reporting real-time device statistics. + + Args: + sys_args: The system arguments passed to the device. + + This method sets up the process for reporting real-time device statistics to MLOps. + + Returns: + None + """ perf_stats = MLOpsDevicePerfStats() perf_stats.args = sys_args perf_stats.edge_id = getattr(sys_args, "edge_id", None) - perf_stats.edge_id = getattr(sys_args, "client_id", None) if perf_stats.edge_id is None else perf_stats.edge_id + perf_stats.edge_id = getattr( + sys_args, "client_id", None) if perf_stats.edge_id is None else perf_stats.edge_id perf_stats.edge_id = 0 if perf_stats.edge_id is None else perf_stats.edge_id perf_stats.device_id = getattr(sys_args, "device_id", 0) perf_stats.run_id = getattr(sys_args, "run_id", 0) @@ -56,6 +100,17 @@ def setup_realtime_stats_process(self, sys_args): self.device_realtime_stats_process.start() def report_device_realtime_stats_entry(self, sys_event): + """ + Entry point for reporting real-time device statistics. + + Args: + sys_event: The system event used to control reporting. + + This method is the entry point for reporting real-time device statistics to MLOps. 
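+
+        The loop publishes device stats roughly every 10 seconds until the
+        stop event is set, using an MQTT client id of the form
+        "FedML_Metrics_DevicePerf_<device>_<edge>_<uuid>" (see below).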
+ + Returns: + None + """ self.device_realtime_stats_event = sys_event mqtt_mgr = MqttManager( self.args.mqtt_config_path["BROKER_HOST"], @@ -63,7 +118,8 @@ def report_device_realtime_stats_entry(self, sys_event): self.args.mqtt_config_path["MQTT_USER"], self.args.mqtt_config_path["MQTT_PWD"], 180, - "FedML_Metrics_DevicePerf_{}_{}_{}".format(str(self.args.device_id), str(self.edge_id), str(uuid.uuid4())) + "FedML_Metrics_DevicePerf_{}_{}_{}".format( + str(self.args.device_id), str(self.edge_id), str(uuid.uuid4())) ) mqtt_mgr.connect() mqtt_mgr.loop_start() @@ -74,9 +130,11 @@ def report_device_realtime_stats_entry(self, sys_event): # Notify MLOps with system information. while not self.should_stop_device_realtime_stats(): try: - MLOpsDevicePerfStats.report_gpu_device_info(self.edge_id, mqtt_mgr=mqtt_mgr) + MLOpsDevicePerfStats.report_gpu_device_info( + self.edge_id, mqtt_mgr=mqtt_mgr) except Exception as e: - logging.debug("exception when reporting device pref: {}.".format(traceback.format_exc())) + logging.debug("exception when reporting device pref: {}.".format( + traceback.format_exc())) pass time.sleep(10) @@ -87,6 +145,18 @@ def report_device_realtime_stats_entry(self, sys_event): @staticmethod def report_gpu_device_info(edge_id, mqtt_mgr=None): + """ + Report GPU device information to MLOps. + + Args: + edge_id: The ID of the edge device. + mqtt_mgr: The MQTT manager for communication. + + This method reports GPU device information to MLOps using MQTT. + + Returns: + None + """ total_mem, free_mem, total_disk_size, free_disk_size, cup_utilization, cpu_cores, gpu_cores_total, \ gpu_cores_available, sent_bytes, recv_bytes, gpu_available_ids = sys_utils.get_sys_realtime_stats() diff --git a/python/fedml/core/mlops/mlops_job_perfs.py b/python/fedml/core/mlops/mlops_job_perfs.py index 2511e1ea9a..6eb3a9c659 100644 --- a/python/fedml/core/mlops/mlops_job_perfs.py +++ b/python/fedml/core/mlops/mlops_job_perfs.py @@ -15,6 +15,9 @@ class MLOpsJobPerfStats(object): def __init__(self): + """ + Initialize MLOpsJobPerfStats object. + """ self.job_stats_process = None self.job_stats_event = None self.args = None @@ -25,11 +28,28 @@ def __init__(self): self.job_stats_obj_map = dict() def add_job(self, job_id, process_id): + """ + Add a job to be tracked for performance statistics. + + Args: + job_id (str): The ID of the job. + process_id (int): The process ID of the job. + """ self.job_process_id_map[job_id] = process_id @staticmethod def report_system_metric(run_id, edge_id, metric_json=None, mqtt_mgr=None, sys_stats_obj=None): + """ + Report system performance metrics to MLOps. + + Args: + run_id (int): The ID of the run. + edge_id (int): The ID of the edge device. + metric_json (dict, optional): The system performance metrics in JSON format. + mqtt_mgr (MqttManager, optional): The MQTT manager for communication. + sys_stats_obj (SysStats, optional): The SysStats object for collecting system stats. + """ # if not self.comm_sanity_check(): # return topic_name = "fl_client/mlops/system_performance" @@ -91,23 +111,40 @@ def report_system_metric(run_id, edge_id, metric_json=None, mqtt_mgr.send_message_json(topic_name, message_json) def stop_job_stats(self): + """ + Stop tracking job performance statistics. + """ + if self.job_stats_event is not None: self.job_stats_event.set() def should_stop_job_stats(self): + """ + Check if job performance statistics tracking should be stopped. + + Returns: + bool: True if job performance statistics tracking should be stopped, otherwise False. 
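+
+        Example:
+            Cooperative shutdown (illustrative)::
+
+                perf.stop_job_stats()           # sets the event
+                perf.should_stop_job_stats()    # now True inside the report loop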
+ """ if self.job_stats_event is not None and self.job_stats_event.is_set(): return True return False def setup_job_stats_process(self, sys_args): + """ + Set up the process for tracking job performance statistics. + + Args: + sys_args (object): The system arguments. + """ if self.job_stats_process is not None and psutil.pid_exists(self.job_stats_process.pid): return perf_stats = MLOpsJobPerfStats() perf_stats.args = sys_args perf_stats.edge_id = getattr(sys_args, "edge_id", None) - perf_stats.edge_id = getattr(sys_args, "client_id", None) if perf_stats.edge_id is None else perf_stats.edge_id + perf_stats.edge_id = getattr( + sys_args, "client_id", None) if perf_stats.edge_id is None else perf_stats.edge_id perf_stats.edge_id = 0 if perf_stats.edge_id is None else perf_stats.edge_id perf_stats.device_id = getattr(sys_args, "device_id", 0) perf_stats.run_id = getattr(sys_args, "run_id", 0) @@ -122,9 +159,21 @@ def setup_job_stats_process(self, sys_args): self.job_stats_process.start() def report_job_stats(self, sys_args): + """ + Report job performance statistics. + + Args: + sys_args (object): The system arguments. + """ self.setup_job_stats_process(sys_args) def report_job_stats_entry(self, sys_event): + """ + Report job performance statistics entry point. + + Args: + sys_event (multiprocessing.Event): The system event for signaling the process. + """ self.job_stats_event = sys_event mqtt_mgr = MqttManager( self.args.mqtt_config_path["BROKER_HOST"], @@ -132,7 +181,8 @@ def report_job_stats_entry(self, sys_event): self.args.mqtt_config_path["MQTT_USER"], self.args.mqtt_config_path["MQTT_PWD"], 180, - "FedML_Metrics_JobPerf_{}_{}_{}".format(str(self.device_id), str(self.edge_id), str(uuid.uuid4())) + "FedML_Metrics_JobPerf_{}_{}_{}".format( + str(self.device_id), str(self.edge_id), str(uuid.uuid4())) ) mqtt_mgr.connect() mqtt_mgr.loop_start() @@ -142,13 +192,15 @@ def report_job_stats_entry(self, sys_event): for job_id, process_id in self.job_process_id_map.items(): try: if self.job_stats_obj_map.get(job_id, None) is None: - self.job_stats_obj_map[job_id] = SysStats(process_id=process_id) + self.job_stats_obj_map[job_id] = SysStats( + process_id=process_id) MLOpsJobPerfStats.report_system_metric(job_id, self.edge_id, mqtt_mgr=mqtt_mgr, sys_stats_obj=self.job_stats_obj_map[job_id]) except Exception as e: - logging.debug("exception when reporting job pref: {}.".format(traceback.format_exc())) + logging.debug("exception when reporting job pref: {}.".format( + traceback.format_exc())) pass time.sleep(10) diff --git a/python/fedml/core/mlops/mlops_metrics.py b/python/fedml/core/mlops/mlops_metrics.py index 2e4109a823..528a92498f 100644 --- a/python/fedml/core/mlops/mlops_metrics.py +++ b/python/fedml/core/mlops/mlops_metrics.py @@ -13,6 +13,17 @@ class MLOpsMetrics(object): def __new__(cls, *args, **kw): + """ + Create a singleton instance of MLOpsMetrics. + + Args: + cls: The class. + *args: Variable-length argument list. + **kw: Keyword arguments. + + Returns: + MLOpsMetrics: The MLOpsMetrics instance. + """ if not hasattr(cls, "_instance"): orig = super(MLOpsMetrics, cls) cls._instance = orig.__new__(cls, *args, **kw) @@ -20,9 +31,16 @@ def __new__(cls, *args, **kw): return cls._instance def __init__(self): + """ + Initialize the MLOpsMetrics object. + """ + pass def init(self): + """ + Initialize the MLOpsMetrics object attributes. 
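+
+        Resets messenger, args, run/edge identifiers and the sys/job/device
+        performance helpers to their defaults (see the assignments below).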
+ """ self.messenger = None self.args = None self.run_id = None @@ -35,6 +53,13 @@ def init(self): self.device_perfs = MLOpsDevicePerfStats() def set_messenger(self, msg_messenger, args=None): + """ + Set the messenger for communication. + + Args: + msg_messenger: The message messenger. + args: The system arguments. + """ self.messenger = msg_messenger if args is not None: self.args = args @@ -62,6 +87,12 @@ def set_messenger(self, msg_messenger, args=None): self.server_agent_id = self.edge_id def comm_sanity_check(self): + """ + Check if communication is set up properly. + + Returns: + bool: True if communication is set up, otherwise False. + """ if self.messenger is None: logging.info("self.messenger is Null") return False @@ -69,6 +100,16 @@ def comm_sanity_check(self): return True def report_client_training_status(self, edge_id, status, running_json=None, is_from_model=False, in_run_id=None): + """ + Report client training status to various components. + + Args: + edge_id: The ID of the edge device. + status: The status of the training. + running_json: The running JSON information. + is_from_model: Whether the report is from the model. + in_run_id: The run ID. + """ run_id = 0 if self.run_id is not None: run_id = self.run_id @@ -84,14 +125,20 @@ def report_client_training_status(self, edge_id, status, running_json=None, is_f if is_from_model: from ...computing.scheduler.model_scheduler.device_client_data_interface import FedMLClientDataInterface - FedMLClientDataInterface.get_instance().save_job(run_id, edge_id, status, running_json) + FedMLClientDataInterface.get_instance().save_job( + run_id, edge_id, status, running_json) else: from ...computing.scheduler.slave.client_data_interface import FedMLClientDataInterface - FedMLClientDataInterface.get_instance().save_job(run_id, edge_id, status, running_json) + FedMLClientDataInterface.get_instance().save_job( + run_id, edge_id, status, running_json) def report_client_device_status_to_web_ui(self, edge_id, status): """ - this is used for notifying the client device status to MLOps Frontend + Report the client device status to MLOps Frontend. + + Args: + edge_id (int): The ID of the edge device. + status (str): The status of the client device. """ if status == ClientConstants.MSG_MLOPS_CLIENT_STATUS_IDLE: return @@ -100,9 +147,11 @@ def report_client_device_status_to_web_ui(self, edge_id, status): if self.run_id is not None: run_id = self.run_id topic_name = "fl_client/mlops/status" - msg = {"edge_id": edge_id, "run_id": run_id, "status": status, "version": "v1.0"} + msg = {"edge_id": edge_id, "run_id": run_id, + "status": status, "version": "v1.0"} message_json = json.dumps(msg) - logging.info("report_client_device_status. message_json = %s" % message_json) + logging.info( + "report_client_device_status. message_json = %s" % message_json) MLOpsStatus.get_instance().set_client_status(edge_id, status) self.messenger.send_message_json(topic_name, message_json) @@ -111,7 +160,11 @@ def common_report_client_training_status(self, edge_id, status): # logging.info("comm_sanity_check at report_client_training_status.") # return """ - this is used for notifying the client status to MLOps (both FedML CLI and backend can consume it) + Common method for reporting client training status to MLOps. + + Args: + edge_id (int): The ID of the edge device. + status (str): The status of the client device. 
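+
+        Example:
+            The JSON payload has the same shape as the web-UI report above
+            (IDs and status value are hypothetical)::
+
+                {"edge_id": 17, "run_id": 42, "status": "TRAINING"}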
""" run_id = 0 if self.run_id is not None: @@ -127,7 +180,12 @@ def broadcast_client_training_status(self, edge_id, status, is_from_model=False) # if not self.comm_sanity_check(): # return """ - this is used for broadcasting the client status to MLOps (backend can consume it) + Broadcast client training status to MLOps. + + Args: + edge_id (int): The ID of the edge device. + status (str): The status of the client device. + is_from_model (bool): Whether the report is from the model. """ run_id = 0 if self.run_id is not None: @@ -147,7 +205,11 @@ def common_broadcast_client_training_status(self, edge_id, status): # if not self.comm_sanity_check(): # return """ - this is used for broadcasting the client status to MLOps (backend can consume it) + Common method for broadcasting client training status to MLOps. + + Args: + edge_id (int): The ID of the edge device. + status (str): The status of the client device. """ run_id = 0 if self.run_id is not None: @@ -155,22 +217,43 @@ def common_broadcast_client_training_status(self, edge_id, status): topic_name = "fl_run/fl_client/mlops/status" msg = {"edge_id": edge_id, "run_id": run_id, "status": status} message_json = json.dumps(msg) - logging.info("report_client_training_status. message_json = %s" % message_json) + logging.info( + "report_client_training_status. message_json = %s" % message_json) self.messenger.send_message_json(topic_name, message_json) def client_send_exit_train_msg(self, run_id, edge_id, status, msg=None): - topic_exit_train_with_exception = "flserver_agent/" + str(run_id) + "/client_exit_train_with_exception" - msg = {"run_id": run_id, "edge_id": edge_id, "status": status, "msg": msg if msg is not None else ""} + """ + Send an exit train message for a client. + + Args: + run_id (int): The ID of the training run. + edge_id (int): The ID of the edge device. + status (str): The status of the client. + msg (str, optional): Additional message (default is None). + """ + topic_exit_train_with_exception = "flserver_agent/" + \ + str(run_id) + "/client_exit_train_with_exception" + msg = {"run_id": run_id, "edge_id": edge_id, + "status": status, "msg": msg if msg is not None else ""} message_json = json.dumps(msg) logging.info("client_send_exit_train_msg.") - self.messenger.send_message_json(topic_exit_train_with_exception, message_json) + self.messenger.send_message_json( + topic_exit_train_with_exception, message_json) def report_client_id_status(self, run_id, edge_id, status, running_json=None, is_from_model=False, server_id="0"): # if not self.comm_sanity_check(): # return """ - this is used for communication between client agent (FedML cli module) and client + Report client ID status to MLOps. + + Args: + run_id (int): The ID of the training run. + edge_id (int): The ID of the edge device. + status (str): The status of the client. + running_json: JSON information about the running state (default is None). + is_from_model (bool): Whether the report is from the model (default is False). + server_id (str): The ID of the server (default is "0"). 
""" self.common_report_client_id_status(run_id, edge_id, status, server_id) @@ -178,24 +261,43 @@ def report_client_id_status(self, run_id, edge_id, status, running_json=None, if is_from_model: from ...computing.scheduler.model_scheduler.device_client_data_interface import FedMLClientDataInterface - FedMLClientDataInterface.get_instance().save_job(run_id, edge_id, status, running_json) + FedMLClientDataInterface.get_instance().save_job( + run_id, edge_id, status, running_json) else: from ...computing.scheduler.slave.client_data_interface import FedMLClientDataInterface - FedMLClientDataInterface.get_instance().save_job(run_id, edge_id, status, running_json) + FedMLClientDataInterface.get_instance().save_job( + run_id, edge_id, status, running_json) def common_report_client_id_status(self, run_id, edge_id, status, server_id="0"): # if not self.comm_sanity_check(): # return """ - this is used for communication between client agent (FedML cli module) and client + Common method for reporting client ID status to MLOps. + + Args: + run_id (int): The ID of the training run. + edge_id (int): The ID of the edge device. + status (str): The status of the client device. + server_id (str): The ID of the server (default is "0"). """ topic_name = "fl_client/flclient_agent_" + str(edge_id) + "/status" - msg = {"run_id": run_id, "edge_id": edge_id, "status": status, "server_id": server_id} + msg = {"run_id": run_id, "edge_id": edge_id, + "status": status, "server_id": server_id} message_json = json.dumps(msg) # logging.info("report_client_id_status. message_json = %s" % message_json) self.messenger.send_message_json(topic_name, message_json) def report_server_training_status(self, run_id, status, role=None, running_json=None, is_from_model=False): + """ + Report server training status to MLOps. + + Args: + run_id (int): The ID of the training run. + status (str): The status of the server. + role (str, optional): The role of the server (default is None). + running_json: JSON information about the running state (default is None). + is_from_model (bool): Whether the report is from the model (default is False). + """ # if not self.comm_sanity_check(): # return self.common_report_server_training_status(run_id, status, role) @@ -204,14 +306,21 @@ def report_server_training_status(self, run_id, status, role=None, running_json= if is_from_model: from ...computing.scheduler.model_scheduler.device_server_data_interface import FedMLServerDataInterface - FedMLServerDataInterface.get_instance().save_job(run_id, self.edge_id, status, running_json) + FedMLServerDataInterface.get_instance().save_job( + run_id, self.edge_id, status, running_json) else: from ...computing.scheduler.master.server_data_interface import FedMLServerDataInterface - FedMLServerDataInterface.get_instance().save_job(run_id, self.edge_id, status, running_json) + FedMLServerDataInterface.get_instance().save_job( + run_id, self.edge_id, status, running_json) def report_server_device_status_to_web_ui(self, run_id, status, role=None): """ - this is used for notifying the server device status to MLOps Frontend + Report server device status to MLOps Frontend. + + Args: + run_id (int): The ID of the training run. + status (str): The status of the server device. + role (str, optional): The role of the server (default is None). 
""" if status == ServerConstants.MSG_MLOPS_DEVICE_STATUS_IDLE: return @@ -232,6 +341,14 @@ def report_server_device_status_to_web_ui(self, run_id, status, role=None): self.messenger.send_message_json(topic_name, message_json) def common_report_server_training_status(self, run_id, status, role=None): + """ + Common method for reporting server training status to MLOps. + + Args: + run_id (int): The ID of the training run. + status (str): The status of the server. + role (str, optional): The role of the server (default is None). + """ # if not self.comm_sanity_check(): # return topic_name = "fl_run/fl_server/mlops/status" @@ -250,6 +367,16 @@ def common_report_server_training_status(self, run_id, status, role=None): self.report_server_id_status(run_id, status) def broadcast_server_training_status(self, run_id, status, role=None, is_from_model=False, edge_id=None): + """ + Broadcast server training status to MLOps. + + Args: + run_id (int): The ID of the training run. + status (str): The status of the server. + role (str, optional): The role of the server (default is None). + is_from_model (bool): Whether the report is from the model (default is False). + edge_id (int, optional): The ID of the edge device (default is None). + """ if self.messenger is None: return topic_name = "fl_run/fl_server/mlops/status" @@ -275,37 +402,71 @@ def broadcast_server_training_status(self, run_id, status, role=None, is_from_mo FedMLServerDataInterface.get_instance().save_job(run_id, self.edge_id, status) def report_server_id_status(self, run_id, status, edge_id=None, server_id=None, server_agent_id=None): + """ + Report server ID status to MLOps. + + Args: + run_id (int): The ID of the training run. + status (str): The status of the server. + edge_id (int, optional): The ID of the edge device (default is None). + server_id (str, optional): The ID of the server (default is None). + server_agent_id (int, optional): The ID of the server agent (default is None). + """ # if not self.comm_sanity_check(): # return topic_name = "fl_server/flserver_agent_" + str(server_agent_id if server_agent_id is not None else self.server_agent_id) + "/status" - msg = {"run_id": run_id, "edge_id": edge_id if edge_id is not None else self.edge_id, "status": status} + msg = {"run_id": run_id, + "edge_id": edge_id if edge_id is not None else self.edge_id, "status": status} if server_id is not None: msg["server_id"] = server_id message_json = json.dumps(msg) # logging.info("report_server_id_status server id {}".format(server_agent_id)) - logging.info("report_server_id_status. message_json = %s" % message_json) + logging.info("report_server_id_status. message_json = %s" % + message_json) self.messenger.send_message_json(topic_name, message_json) self.report_server_device_status_to_web_ui(run_id, status) def report_client_training_metric(self, metric_json): + """ + Report client training metrics to MLOps. + + Args: + metric_json (dict): JSON containing client training metrics. + """ + # if not self.comm_sanity_check(): # return topic_name = "fl_client/mlops/training_metrics" - logging.info("report_client_training_metric. message_json = %s" % metric_json) + logging.info( + "report_client_training_metric. message_json = %s" % metric_json) message_json = json.dumps(metric_json) self.messenger.send_message_json(topic_name, message_json) def report_server_training_metric(self, metric_json): + """ + Report server training metrics to MLOps. + + Args: + metric_json (dict): JSON containing server training metrics. 
+
+        """
        # if not self.comm_sanity_check():
        #     return
        topic_name = "fl_server/mlops/training_progress_and_eval"
-        logging.info("report_server_training_metric. message_json = %s" % metric_json)
+        logging.info(
+            "report_server_training_metric. message_json = %s" % metric_json)
        message_json = json.dumps(metric_json)
        self.messenger.send_message_json(topic_name, message_json)
 
     def report_server_training_round_info(self, round_info):
+        """
+        Report server training round information to MLOps.
+
+        Args:
+            round_info (dict): JSON containing server training round information.
+        """
+
        # if not self.comm_sanity_check():
        #     return
        topic_name = "fl_server/mlops/training_roundx"
@@ -313,6 +474,12 @@ def report_server_training_round_info(self, round_info):
        self.messenger.send_message_json(topic_name, message_json)
 
     def report_client_model_info(self, model_info_json):
+        """
+        Report client model information to MLOps.
+
+        Args:
+            model_info_json (dict): JSON containing client model information.
+        """
        # if not self.comm_sanity_check():
        #     return
        topic_name = "fl_server/mlops/client_model"
@@ -320,6 +487,13 @@ def report_client_model_info(self, model_info_json):
        self.messenger.send_message_json(topic_name, message_json)
 
     def report_aggregated_model_info(self, model_info_json):
+        """
+        Report aggregated model information to MLOps.
+
+        Args:
+            model_info_json (dict): JSON containing aggregated model information.
+        """
+
        # if not self.comm_sanity_check():
        #     return
        topic_name = "fl_server/mlops/global_aggregated_model"
@@ -327,6 +501,12 @@ def report_aggregated_model_info(self, model_info_json):
        self.messenger.send_message_json(topic_name, message_json)
 
     def report_training_model_net_info(self, model_net_info_json):
+        """
+        Report training model network information to MLOps.
+
+        Args:
+            model_net_info_json (dict): JSON containing training model network information.
+        """
        # if not self.comm_sanity_check():
        #     return
        topic_name = "fl_server/mlops/training_model_net"
@@ -334,6 +514,12 @@ def report_training_model_net_info(self, model_net_info_json):
        self.messenger.send_message_json(topic_name, message_json)
 
     def report_llm_record(self, metric_json):
+        """
+        Report a large language model (LLM) input-output record to MLOps.
+
+        Args:
+            metric_json (dict): JSON containing the large language model input-output record.
+        """
        # if not self.comm_sanity_check():
        #     return
        topic_name = "model_serving/mlops/llm_input_output_record"
@@ -345,8 +531,17 @@ def report_llm_record(self, metric_json):
     def report_edge_job_computing_cost(self, job_id, edge_id,
                                        computing_started_time, computing_ended_time,
                                        user_id, api_key):
         """
-        this is used for reporting the computing cost of a job running on an edge to MLOps
+        Report the computing cost of a job running on an edge to MLOps.
+
+        Args:
+            job_id (str): The ID of the job.
+            edge_id (str): The ID of the edge device.
+            computing_started_time (float): The timestamp when computing started.
+            computing_ended_time (float): The timestamp when computing ended.
+            user_id (str): The user ID.
+            api_key (str): The API key.
         """
+
        topic_name = "ml_client/mlops/job_computing_cost"
        duration = computing_ended_time - computing_started_time
        if duration < 0:
@@ -360,6 +555,12 @@ def report_edge_job_computing_cost(self, job_id, edge_id,
        # logging.info("report_job_computing_cost. message_json = %s" % message_json)
 
     def report_logs_updated(self, run_id):
+        """
+        Report that runtime logs have been updated to MLOps.
+
+        Args:
+            run_id (int): The ID of the training run.
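+
+        Example:
+            Publishes to the per-run topic ``mlops/runtime_logs/<run_id>``;
+            an illustrative call::
+
+                metrics.report_logs_updated(run_id=100)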
+ """ # if not self.comm_sanity_check(): # return topic_name = "mlops/runtime_logs/" + str(run_id) @@ -372,6 +573,20 @@ def report_artifact_info(self, job_id, edge_id, artifact_name, artifact_type, artifact_local_path, artifact_url, artifact_ext_info, artifact_desc, timestamp): + """ + Report artifact information to MLOps. + + Args: + job_id (str): The ID of the job associated with the artifact. + edge_id (str): The ID of the edge device where the artifact is generated. + artifact_name (str): The name of the artifact. + artifact_type (str): The type of the artifact. + artifact_local_path (str): The local path to the artifact. + artifact_url (str): The URL of the artifact. + artifact_ext_info (dict): Additional information about the artifact. + artifact_desc (str): A description of the artifact. + timestamp (float): The timestamp when the artifact was generated. + """ topic_name = "launch_device/mlops/artifacts" artifact_info_json = { "job_id": job_id, @@ -388,30 +603,68 @@ def report_artifact_info(self, job_id, edge_id, artifact_name, artifact_type, self.messenger.send_message_json(topic_name, message_json) def report_sys_perf(self, sys_args, mqtt_config): + """ + Report system performance metrics to MLOps. + + Args: + sys_args (object): System arguments object containing performance metrics. + mqtt_config (str): Path to the MQTT configuration. + """ setattr(sys_args, "mqtt_config_path", mqtt_config) run_id = getattr(sys_args, "run_id", 0) self.fl_job_perf.add_job(run_id, os.getpid()) self.fl_job_perf.report_job_stats(sys_args) def stop_sys_perf(self): + """ + Stop reporting system performance metrics to MLOps. + """ self.fl_job_perf.stop_job_stats() def report_job_perf(self, sys_args, mqtt_config, job_process_id): + """ + Report job performance metrics to MLOps. + + Args: + sys_args (object): System arguments object containing job performance metrics. + mqtt_config (str): Path to the MQTT configuration. + job_process_id (int): The process ID of the job. + """ setattr(sys_args, "mqtt_config_path", mqtt_config) run_id = getattr(sys_args, "run_id", 0) self.job_perfs.add_job(run_id, job_process_id) self.job_perfs.report_job_stats(sys_args) def stop_job_perf(self): + """ + Stop reporting job performance metrics to MLOps. + """ self.job_perfs.stop_job_stats() def report_device_realtime_perf(self, sys_args, mqtt_config): + """ + Report real-time device performance metrics to MLOps. + + Args: + sys_args (object): System arguments object containing real-time device performance metrics. + mqtt_config (str): Path to the MQTT configuration. + """ setattr(sys_args, "mqtt_config_path", mqtt_config) self.device_perfs.report_device_realtime_stats(sys_args) def stop_device_realtime_perf(self): + """ + Stop reporting real-time device performance metrics to MLOps. + """ + self.device_perfs.stop_device_realtime_stats() def report_json_message(self, topic, payload): - self.messenger.send_message_json(topic, payload) + """ + Report a JSON message to a specified topic. + Args: + topic (str): The MQTT topic to publish the message to. + payload (dict): The JSON payload to be sent. 
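+
+        Example:
+            The payload is forwarded unchanged to ``send_message_json``;
+            elsewhere in this class payloads are pre-serialized with
+            ``json.dumps``, so a JSON string is the safest input::
+
+                import json
+
+                metrics.report_json_message(
+                    "fl_server/mlops/status",
+                    json.dumps({"status": "RUNNING"}))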
+ """ + self.messenger.send_message_json(topic, payload) diff --git a/python/fedml/core/mlops/mlops_profiler_event.py b/python/fedml/core/mlops/mlops_profiler_event.py index 73aa151054..bafdafa7be 100644 --- a/python/fedml/core/mlops/mlops_profiler_event.py +++ b/python/fedml/core/mlops/mlops_profiler_event.py @@ -22,6 +22,12 @@ def __new__(cls, *args, **kwargs): return MLOpsProfilerEvent._instance def __init__(self, args): + """ + Initialize the MLOpsProfilerEvent. + + Args: + args: The system arguments containing configuration settings. + """ self.args = args if args is not None and hasattr(args, "enable_wandb") and args.enable_wandb is not None: self.enable_wandb = args.enable_wandb @@ -37,6 +43,13 @@ def __init__(self, args): self.run_id = 0 def set_messenger(self, msg_messenger, args=None): + """ + Set the messenger for communication. + + Args: + msg_messenger: The messenger for communication. + args: The system arguments containing configuration settings. + """ self.com_manager = msg_messenger if args is None: return @@ -59,19 +72,39 @@ def set_messenger(self, msg_messenger, args=None): @classmethod def enable_wandb_tracking(cls): + """ + Enable W&B (Weights and Biases) tracking. + """ cls._enable_wandb = True @classmethod def enable_sys_perf_profiling(cls): + """ + Enable system performance profiling. + """ cls._sys_perf_profiling = True @classmethod def log_to_wandb(cls, metric): + """ + Log a metric to W&B (Weights and Biases). + + Args: + metric: The metric to log. + """ if cls._enable_wandb: import wandb wandb.log(metric) def log_event_started(self, event_name, event_value=None, event_edge_id=None): + """ + Log the start of an event. + + Args: + event_name: The name of the event. + event_value: The value associated with the event. + event_edge_id: The ID of the edge device associated with the event. + """ if event_value is None: event_value_passed = "" else: @@ -95,6 +128,14 @@ def log_event_started(self, event_name, event_value=None, event_edge_id=None): self.com_manager.send_message_json(event_topic, event_msg_str) def log_event_ended(self, event_name, event_value=None, event_edge_id=None): + """ + Log the end of an event. + + Args: + event_name: The name of the event. + event_value: The value associated with the event. + event_edge_id: The ID of the edge device associated with the event. + """ if event_value is None: event_value_passed = "" else: @@ -120,6 +161,20 @@ def log_event_ended(self, event_name, event_value=None, event_edge_id=None): @staticmethod def __build_event_mqtt_msg(run_id, edge_id, event_type, event_name, event_value): + """ + Build an MQTT message for an event. + + Args: + run_id: The ID of the run. + edge_id: The ID of the edge device. + event_type: The type of the event (started or ended). + event_name: The name of the event. + event_value: The value associated with the event. + + Returns: + event_topic: The MQTT topic for the event. + event_msg: The MQTT message for the event. 
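+
+        Example:
+            Events are published to the ``mlops/events`` topic; an
+            illustrative (hypothetical) message shape for a started event::
+
+                {"run_id": 100, "edge_id": 1, "event_name": "train",
+                 "event_value": "", "started_time": 1693555200000}
+
+            The exact keys differ between ``EVENT_TYPE_STARTED`` and
+            ``EVENT_TYPE_ENDED`` events.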
+ """ event_topic = "mlops/events" event_msg = {} if event_type == MLOpsProfilerEvent.EVENT_TYPE_STARTED: diff --git a/python/fedml/core/mlops/mlops_runtime_log.py b/python/fedml/core/mlops/mlops_runtime_log.py index 7ebcc43ade..93b2cea126 100644 --- a/python/fedml/core/mlops/mlops_runtime_log.py +++ b/python/fedml/core/mlops/mlops_runtime_log.py @@ -33,7 +33,8 @@ def handle_exception(exc_type, exc_value, exc_traceback): sys.__excepthook__(exc_type, exc_value, exc_traceback) return - logging.error("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback)) + logging.error("Uncaught exception", exc_info=( + exc_type, exc_value, exc_traceback)) if MLOpsRuntimeLog._log_sdk_instance is not None and \ hasattr(MLOpsRuntimeLog._log_sdk_instance, "args") and \ @@ -48,6 +49,23 @@ def handle_exception(exc_type, exc_value, exc_traceback): mlops.send_exit_train_msg() def __init__(self, args): + """ + Initialize the MLOpsRuntimeLog. + + Args: + args: Input arguments. + + Attributes: + logger: Logger instance for logging. + args: Input arguments. + should_write_log_file: Boolean indicating whether log files should be written. + log_file_dir: Directory where log files are stored. + log_file: File handle for the log file. + run_id: The ID of the current run. + edge_id: The ID of the edge device (server or client). + origin_log_file_path: Path to the original log file. + + """ self.logger = None self.args = args if hasattr(args, "using_mlops"): @@ -92,13 +110,31 @@ def __init__(self, args): @staticmethod def get_instance(args): + """ + Get an instance of the MLOpsRuntimeLog. + + Args: + args: Input arguments. + + Returns: + MLOpsRuntimeLog: An instance of the log handler. + + """ if MLOpsRuntimeLog._log_sdk_instance is None: MLOpsRuntimeLog._log_sdk_instance = MLOpsRuntimeLog(args) return MLOpsRuntimeLog._log_sdk_instance def init_logs(self, show_stdout_log=True): - log_file_path, program_prefix = MLOpsRuntimeLog.build_log_file_path(self.args) + """ + Initialize logging. + + Args: + show_stdout_log (bool): Flag to control whether to show log messages on stdout. + + """ + log_file_path, program_prefix = MLOpsRuntimeLog.build_log_file_path( + self.args) logging.raiseExceptions = True self.logger = logging.getLogger(log_file_path) @@ -118,7 +154,8 @@ def formatTime(self, record, datefmt=None): if self.ntp_offset is None: self.ntp_offset = 0.0 - log_ntp_time = int((log_time * 1000 + self.ntp_offset) / 1000.0) + log_ntp_time = int( + (log_time * 1000 + self.ntp_offset) / 1000.0) ct = self.converter(log_ntp_time) if datefmt: s = ct.strftime(datefmt) @@ -156,6 +193,17 @@ def formatTime(self, record, datefmt=None): @staticmethod def build_log_file_path(in_args): + """ + Build the log file path based on input arguments. + + Args: + in_args: Input arguments. + + Returns: + str: Log file path. + str: Program prefix. 
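+
+        Example:
+            For a client, the returned prefix has the form
+            ``FedML-Client @device-id-<edge_id>``; an illustrative call::
+
+                log_file_path, prefix = MLOpsRuntimeLog.build_log_file_path(args)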
+ + """ if in_args.role == "server": if hasattr(in_args, "server_id"): edge_id = in_args.server_id @@ -182,7 +230,8 @@ def build_log_file_path(in_args): edge_id = in_args.edge_id else: edge_id = 0 - program_prefix = "FedML-Client @device-id-{edge}".format(edge=edge_id) + program_prefix = "FedML-Client @device-id-{edge}".format( + edge=edge_id) if not os.path.exists(in_args.log_file_dir): os.makedirs(in_args.log_file_dir, exist_ok=True) @@ -196,7 +245,8 @@ def build_log_file_path(in_args): if __name__ == "__main__": - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--log_file_dir", "-log", help="log file dir") parser.add_argument("--run_id", "-ri", type=str, help='run id') diff --git a/python/fedml/core/mlops/mlops_runtime_log_daemon.py b/python/fedml/core/mlops/mlops_runtime_log_daemon.py index 905c5287b0..dc4fa1c9ea 100644 --- a/python/fedml/core/mlops/mlops_runtime_log_daemon.py +++ b/python/fedml/core/mlops/mlops_runtime_log_daemon.py @@ -19,13 +19,25 @@ class MLOpsRuntimeLogProcessor: FEDML_RUN_LOG_STATUS_DIR = "run_log_status" def __init__(self, using_mlops, log_run_id, log_device_id, log_file_dir, log_server_url, in_args=None): + """ + Initialize the MLOpsRuntimeLogProcessor. + + Args: + using_mlops: Whether MLOps is being used. + log_run_id: The ID of the log run. + log_device_id: The ID of the log device. + log_file_dir: The directory where log files are stored. + log_server_url: The URL of the log server. + in_args: Input arguments (system configuration). + """ self.args = in_args self.is_log_reporting = False self.log_reporting_status_file = os.path.join(log_file_dir, MLOpsRuntimeLogProcessor.FEDML_RUN_LOG_STATUS_DIR, MLOpsRuntimeLogProcessor.FEDML_LOG_REPORTING_STATUS_FILE_NAME + "-" + str(log_run_id) + ".conf") - os.makedirs(os.path.join(log_file_dir, MLOpsRuntimeLogProcessor.FEDML_RUN_LOG_STATUS_DIR), exist_ok=True) + os.makedirs(os.path.join( + log_file_dir, MLOpsRuntimeLogProcessor.FEDML_RUN_LOG_STATUS_DIR), exist_ok=True) self.logger = None self.should_upload_log_file = using_mlops self.log_file_dir = log_file_dir @@ -53,12 +65,28 @@ def __init__(self, using_mlops, log_run_id, log_device_id, log_file_dir, log_ser self.log_process_event = None def set_log_source(self, source): + """ + Set the source of the log. + + Args: + source: The source of the log. + """ self.log_source = source if source is not None: self.log_source = str(self.log_source).replace(' ', '') @staticmethod def build_log_file_path(in_args): + """ + Build the log file path based on input arguments. + + Args: + in_args: Input arguments (system configuration). + + Returns: + log_file_path: The path to the log file. + program_prefix: The prefix for the program's log. 
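+
+        Example:
+            ``in_args.rank == 0`` selects the server prefix, any other rank
+            the client prefix (the device IDs below are illustrative)::
+
+                # rank 0 -> "FedML-Server(0) @device-id-1"
+                # rank 2 -> "FedML-Client(2) @device-id-3"
+                path, prefix = MLOpsRuntimeLogProcessor.build_log_file_path(in_args)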
+ """ if in_args.rank == 0: if hasattr(in_args, "server_id"): log_device_id = in_args.server_id @@ -67,7 +95,8 @@ def build_log_file_path(in_args): log_device_id = in_args.edge_id else: log_device_id = 0 - program_prefix = "FedML-Server({}) @device-id-{}".format(in_args.rank, log_device_id) + program_prefix = "FedML-Server({}) @device-id-{}".format( + in_args.rank, log_device_id) else: if hasattr(in_args, "client_id"): log_device_id = in_args.client_id @@ -82,7 +111,8 @@ def build_log_file_path(in_args): log_device_id = in_args.edge_id else: log_device_id = 0 - program_prefix = "FedML-Client({}) @device-id-{}".format(in_args.rank, log_device_id) + program_prefix = "FedML-Client({}) @device-id-{}".format( + in_args.rank, log_device_id) if not os.path.exists(in_args.log_file_dir): os.makedirs(in_args.log_file_dir, exist_ok=True) @@ -95,6 +125,13 @@ def build_log_file_path(in_args): return log_file_path, program_prefix def log_upload(self, run_id, device_id): + """ + Upload logs to the log server. + + Args: + run_id: The ID of the run. + device_id: The ID of the device. + """ # read log data from local log file log_lines = self.log_read() if log_lines is None or len(log_lines) <= 0: @@ -131,7 +168,8 @@ def log_upload(self, run_id, device_id): prev_line_prefix_list[2]) if not str(log_lines[index]).startswith('[FedML-'): - log_line = "{} {}".format(prev_line_prefix, log_lines[index]) + log_line = "{} {}".format( + prev_line_prefix, log_lines[index]) log_lines[index] = log_line index += 1 @@ -146,7 +184,8 @@ def log_upload(self, run_id, device_id): for log_index in range(len(upload_lines)): log_line = str(upload_lines[log_index]) if log_line.find(' [ERROR] ') != -1: - err_line_dict = {"errLine": self.log_uploaded_line_index + log_index, "errMsg": log_line} + err_line_dict = { + "errLine": self.log_uploaded_line_index + log_index, "errMsg": log_line} err_list.append(err_line_dict) log_upload_request = { @@ -165,10 +204,12 @@ def log_upload(self, run_id, device_id): if self.log_source is not None and self.log_source != "": log_upload_request["source"] = self.log_source - log_headers = {'Content-Type': 'application/json', 'Connection': 'close'} + log_headers = {'Content-Type': 'application/json', + 'Connection': 'close'} # send log data to the log server - _, cert_path = MLOpsConfigs.get_instance(self.args).get_request_params() + _, cert_path = MLOpsConfigs.get_instance( + self.args).get_request_params() if cert_path is not None: try: requests.session().verify = cert_path @@ -187,7 +228,8 @@ def log_upload(self, run_id, device_id): # logging.info(f"FedMLDebug POST log to server run_id {run_id}, device_id {device_id}. response.status_code: {response.status_code}") else: # logging.info(f"FedMLDebug POST log to server. run_id {run_id}, device_id {device_id}") - response = requests.post(self.log_server_url, headers=log_headers, json=log_upload_request) + response = requests.post( + self.log_server_url, headers=log_headers, json=log_upload_request) # logging.info(f"FedMLDebug POST log to server. run_id {run_id}, device_id {device_id}. response.status_code: {response.status_code}") if response.status_code != 200: pass @@ -201,6 +243,15 @@ def log_upload(self, run_id, device_id): @staticmethod def should_ignore_log_line(log_line): + """ + Determine whether to ignore a log line. + + Args: + log_line: The log line to check. + + Returns: + True if the log line should be ignored, False otherwise. 
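+
+        Example:
+            Blank lines are always ignored; further filtering rules apply
+            in the body below::
+
+                >>> MLOpsRuntimeLogProcessor.should_ignore_log_line('')
+                True
+                >>> MLOpsRuntimeLogProcessor.should_ignore_log_line('\n')
+                True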
+ """ # if str is empty, then continue, will move it later if str(log_line) == '' or str(log_line) == '\n': return True @@ -215,6 +266,12 @@ def should_ignore_log_line(log_line): return False def log_process(self, process_event): + """ + Continuously upload log data to the log server. + + Args: + process_event: Event object to control the log processing loop. + """ self.log_process_event = process_event while not self.should_stop(): @@ -228,6 +285,9 @@ def log_process(self, process_event): print("FedDebug log_process STOPPED") def log_relocation(self): + """ + Relocate the log file pointer to the last uploaded log line. + """ log_line_count = self.log_line_index self.log_uploaded_line_index = self.log_line_index while log_line_count > 0: @@ -244,6 +304,9 @@ def log_relocation(self): self.log_line_index = 0 def log_open(self): + """ + Open the log file for reading. + """ try: shutil.copyfile(self.origin_log_file_path, self.log_file_path) if self.log_file is None: @@ -253,6 +316,13 @@ def log_open(self): pass def log_read(self): + """ + Read log data from the log file. + + Returns: + log_lines: A list of log lines read from the file. + """ + self.log_open() if self.log_file is None: @@ -272,6 +342,13 @@ def log_read(self): @staticmethod def __generate_yaml_doc(log_config_object, yaml_file): + """ + Generate a YAML document from a configuration object and save it to a file. + + Args: + log_config_object: The configuration object to serialize. + yaml_file: The path to the YAML file to save. + """ try: file = open(yaml_file, "w", encoding="utf-8") yaml.dump(log_config_object, file) @@ -281,7 +358,15 @@ def __generate_yaml_doc(log_config_object, yaml_file): @staticmethod def __load_yaml_config(yaml_path): - """Helper function to load a yaml config file""" + """ + Load a YAML configuration file. + + Args: + yaml_path: The path to the YAML configuration file. + + Returns: + config_data: The loaded configuration data. + """ with open(yaml_path, "r") as stream: try: return yaml.safe_load(stream) @@ -289,23 +374,50 @@ def __load_yaml_config(yaml_path): raise ValueError("Yaml error - check yaml file") def save_log_config(self): + """ + Save the log configuration to a YAML file, including the log line index. + + This method saves the log line index to the log configuration YAML file + for resuming log processing where it left off. + + Raises: + Exception: If there is an error while saving the configuration. + """ try: - log_config_key = "log_config_{}_{}".format(self.run_id, self.device_id) + log_config_key = "log_config_{}_{}".format( + self.run_id, self.device_id) self.log_config[log_config_key] = dict() self.log_config[log_config_key]["log_line_index"] = self.log_line_index - MLOpsRuntimeLogProcessor.__generate_yaml_doc(self.log_config, self.log_config_file) + MLOpsRuntimeLogProcessor.__generate_yaml_doc( + self.log_config, self.log_config_file) except Exception as e: pass def load_log_config(self): + """ + Load the log configuration from a YAML file. + + This method loads the log configuration, including the log line index, + from the log configuration YAML file. + + Raises: + Exception: If there is an error while loading the configuration. 
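+
+        Example:
+            The YAML file written by ``save_log_config`` has this shape
+            (the run and device IDs below are illustrative)::
+
+                log_config_100_1:
+                    log_line_index: 42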
+ """ try: - log_config_key = "log_config_{}_{}".format(self.run_id, self.device_id) + log_config_key = "log_config_{}_{}".format( + self.run_id, self.device_id) self.log_config = self.__load_yaml_config(self.log_config_file) self.log_line_index = self.log_config[log_config_key]["log_line_index"] except Exception as e: pass def should_stop(self): + """ + Check if the log processing should stop. + + Returns: + bool: True if the log processing should stop; False otherwise. + """ if self.log_process_event is not None and self.log_process_event.is_set(): return True @@ -324,6 +436,22 @@ def __new__(cls, *args, **kwargs): return MLOpsRuntimeLogDaemon._instance def __init__(self, in_args): + """ + Initialize the MLOpsRuntimeLogDaemon. + + Args: + in_args: Input arguments passed to the daemon. + + Attributes: + args: Input arguments. + edge_id: The ID of the edge device (server or client). + log_server_url: The URL for the log server. + log_file_dir: Directory where log files are stored. + log_child_process_list: List to keep track of child log processing processes. + log_child_process: Reference to the child log processing process. + log_process_event: Event to control log processing. + + """ self.args = in_args if in_args.role == "server": @@ -364,16 +492,43 @@ def __init__(self, in_args): @staticmethod def get_instance(args): + """ + Get an instance of the MLOpsRuntimeLogDaemon. + + Args: + args: Input arguments. + + Returns: + MLOpsRuntimeLogDaemon: An instance of the log daemon. + + """ if MLOpsRuntimeLogDaemon._log_sdk_instance is None: - MLOpsRuntimeLogDaemon._log_sdk_instance = MLOpsRuntimeLogDaemon(args) + MLOpsRuntimeLogDaemon._log_sdk_instance = MLOpsRuntimeLogDaemon( + args) MLOpsRuntimeLogDaemon._log_sdk_instance.log_source = None return MLOpsRuntimeLogDaemon._log_sdk_instance def set_log_source(self, source): + """ + Set the source of log messages. + + Args: + source (str): The source of log messages. + + """ self.log_source = source def start_log_processor(self, log_run_id, log_device_id): + """ + Start a log processor for a specific run and device. + + Args: + log_run_id: The ID of the log run. + log_device_id: The ID of the log device. + + """ + log_processor = MLOpsRuntimeLogProcessor(self.args.using_mlops, log_run_id, log_device_id, self.log_file_dir, self.log_server_url, @@ -389,11 +544,21 @@ def start_log_processor(self, log_run_id, log_device_id): if self.log_child_process is not None: self.log_child_process.start() try: - self.log_child_process_list.index((self.log_child_process, log_run_id, log_device_id)) + self.log_child_process_list.index( + (self.log_child_process, log_run_id, log_device_id)) except ValueError as ex: - self.log_child_process_list.append((self.log_child_process, log_run_id, log_device_id)) + self.log_child_process_list.append( + (self.log_child_process, log_run_id, log_device_id)) def stop_log_processor(self, log_run_id, log_device_id): + """ + Stop a log processor for a specific run and device. + + Args: + log_run_id: The ID of the log run. + log_device_id: The ID of the log device. + + """ if log_run_id is None or log_device_id is None: return @@ -407,6 +572,10 @@ def stop_log_processor(self, log_run_id, log_device_id): break def stop_all_log_processor(self): + """ + Stop all running log processors. 
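+
+        Example:
+            An illustrative start/stop cycle (run and device IDs are
+            arbitrary)::
+
+                daemon = MLOpsRuntimeLogDaemon.get_instance(args)
+                daemon.start_log_processor(100, 1)
+                # ... run the job ...
+                daemon.stop_all_log_processor()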
+ + """ for (log_child_process, _, _) in self.log_child_process_list: if self.log_process_event is not None: self.log_process_event.set() @@ -415,11 +584,13 @@ def stop_all_log_processor(self): if __name__ == "__main__": - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--log_file_dir", "-log", help="log file dir") parser.add_argument("--rank", "-r", type=str, default="1") parser.add_argument("--client_id_list", "-cil", type=str, default="[]") - parser.add_argument("--log_server_url", "-lsu", type=str, default="http://") + parser.add_argument("--log_server_url", "-lsu", + type=str, default="http://") args = parser.parse_args() setattr(args, "using_mlops", True) @@ -427,7 +598,8 @@ def stop_all_log_processor(self): run_id = 9998 device_id = 1 - MLOpsRuntimeLogDaemon.get_instance(args).start_log_processor(run_id, device_id) + MLOpsRuntimeLogDaemon.get_instance( + args).start_log_processor(run_id, device_id) while True: time.sleep(1) diff --git a/python/fedml/core/mlops/mlops_status.py b/python/fedml/core/mlops/mlops_status.py index 1b166aca91..0146da384c 100644 --- a/python/fedml/core/mlops/mlops_status.py +++ b/python/fedml/core/mlops/mlops_status.py @@ -5,6 +5,21 @@ class MLOpsStatus(Singleton): _status_instance = None def __init__(self): + """ + Initialize an instance of MLOpsStatus. + + This class is a Singleton and should not be instantiated directly. + Use the `get_instance` method to obtain the Singleton instance. + + Attributes: + messenger: Messenger object for communication. + run_id: The ID of the current run. + edge_id: The ID of the edge device. + client_agent_status: A dictionary to store client agent status. + server_agent_status: A dictionary to store server agent status. + client_status: A dictionary to store client status. + server_status: A dictionary to store server status. + """ self.messenger = None self.run_id = None self.edge_id = None @@ -15,31 +30,101 @@ def __init__(self): @staticmethod def get_instance(): + """ + Get the Singleton instance of MLOpsStatus. + + Returns: + MLOpsStatus: The Singleton instance of MLOpsStatus. + """ if MLOpsStatus._status_instance is None: MLOpsStatus._status_instance = MLOpsStatus() return MLOpsStatus._status_instance def set_client_agent_status(self, edge_id, status): + """ + Set the status of a client agent. + + Args: + edge_id (int): The ID of the edge device. + status (str): The status of the client agent. + """ self.client_agent_status[edge_id] = status def set_server_agent_status(self, edge_id, status): + """ + Set the status of a server agent. + + Args: + edge_id (int): The ID of the edge device. + status (str): The status of the server agent. + """ self.server_agent_status[edge_id] = status def set_client_status(self, edge_id, status): + """ + Set the status of a client. + + Args: + edge_id (int): The ID of the edge device. + status (str): The status of the client. + """ self.client_status[edge_id] = status def set_server_status(self, edge_id, status): + """ + Set the status of a server. + + Args: + edge_id (int): The ID of the edge device. + status (str): The status of the server. + """ self.server_status[edge_id] = status def get_client_agent_status(self, edge_id): + """ + Get the status of a client agent. + + Args: + edge_id (int): The ID of the edge device. + + Returns: + str or None: The status of the client agent, or None if not found. 
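+
+        Example:
+            Statuses are plain strings keyed by edge ID (the value below is
+            illustrative)::
+
+                >>> status = MLOpsStatus.get_instance()
+                >>> status.set_client_agent_status(1, "IDLE")
+                >>> status.get_client_agent_status(1)
+                'IDLE'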
+ """ return self.client_agent_status.get(edge_id, None) def get_server_agent_status(self, edge_id): + """ + Get the status of a server agent. + + Args: + edge_id (int): The ID of the edge device. + + Returns: + str or None: The status of the server agent, or None if not found. + """ return self.server_agent_status.get(edge_id, None) def get_client_status(self, edge_id): + """ + Get the status of a client. + + Args: + edge_id (int): The ID of the edge device. + + Returns: + str or None: The status of the client, or None if not found. + """ return self.client_status.get(edge_id, None) def get_server_status(self, edge_id): + """ + Get the status of a server. + + Args: + edge_id (int): The ID of the edge device. + + Returns: + str or None: The status of the server, or None if not found. + """ return self.server_status.get(edge_id, None) diff --git a/python/fedml/core/mlops/mlops_utils.py b/python/fedml/core/mlops/mlops_utils.py index e8d63088bf..a59d39aa00 100644 --- a/python/fedml/core/mlops/mlops_utils.py +++ b/python/fedml/core/mlops/mlops_utils.py @@ -4,11 +4,23 @@ class MLOpsUtils: + """ + Class for MLOps utilities. + """ _ntp_offset = None BYTES_TO_GB = 1 / (1024 * 1024 * 1024) @staticmethod def calc_ntp_from_config(mlops_config): + """ + Calculate NTP time offset from MLOps configuration. + + Args: + mlops_config (dict): MLOps configuration containing NTP response data. + + Returns: + None: If the necessary NTP response data is missing or invalid. + """ if mlops_config is None: return @@ -25,7 +37,8 @@ def calc_ntp_from_config(mlops_config): return # calculate the time offset(int) - ntp_time = (server_recv_time + server_send_time + device_recv_time - device_send_time) // 2 + ntp_time = (server_recv_time + server_send_time + + device_recv_time - device_send_time) // 2 ntp_offset = ntp_time - device_recv_time # set the time offset @@ -33,20 +46,44 @@ def calc_ntp_from_config(mlops_config): @staticmethod def set_ntp_offset(ntp_offset): + """ + Set the NTP time offset. + + Args: + ntp_offset (int): The NTP time offset. + """ MLOpsUtils._ntp_offset = ntp_offset @staticmethod def get_ntp_time(): + """ + Get the current time adjusted by the NTP offset. + + Returns: + int: The NTP-adjusted current time in milliseconds. + """ if MLOpsUtils._ntp_offset is not None: return int(time.time() * 1000) + MLOpsUtils._ntp_offset return int(time.time() * 1000) @staticmethod def get_ntp_offset(): + """ + Get the current NTP time offset. + + Returns: + int: The NTP time offset. + """ return MLOpsUtils._ntp_offset @staticmethod def write_log_trace(log_trace): + """ + Write a log trace to a file in the "fedml_log" directory. + + Args: + log_trace (str): The log trace to write. + """ log_trace_dir = os.path.join(expanduser("~"), "fedml_log") if not os.path.exists(log_trace_dir): os.makedirs(log_trace_dir, exist_ok=True) diff --git a/python/fedml/core/mlops/stats_impl.py b/python/fedml/core/mlops/stats_impl.py index 51e59e48b9..f1ab974609 100644 --- a/python/fedml/core/mlops/stats_impl.py +++ b/python/fedml/core/mlops/stats_impl.py @@ -28,6 +28,16 @@ def gpu_in_use_by_this_process(gpu_handle: GPUHandle, pid: int) -> bool: + """ + Check if a GPU is in use by a specified process. + + Args: + gpu_handle (GPUHandle): Handle to the GPU to check. + pid (int): The process ID of the target process. + + Returns: + bool: True if the GPU is in use by the specified process; False otherwise. 
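+
+    Example:
+        An illustrative check against the current process; assumes an
+        NVIDIA GPU is present and that NVML can be initialized::
+
+            import os
+            import pynvml
+
+            pynvml.nvmlInit()
+            handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+            gpu_in_use_by_this_process(handle, pid=os.getpid())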
+ """ if not psutil: return False @@ -67,6 +77,16 @@ class WandbSystemStats: gpu_count: int def __init__(self, settings: SettingsStatic, interface: InterfaceQueue) -> None: + """ + Initialize the WandbSystemStats instance. + + Args: + settings (SettingsStatic): Settings for system stats tracking. + interface (InterfaceQueue): Interface for publishing stats. + + Raises: + Exception: An exception is raised if GPU initialization fails. + """ try: pynvml.nvmlInit() self.gpu_count = pynvml.nvmlDeviceGetCount() @@ -82,7 +102,8 @@ def __init__(self, settings: SettingsStatic, interface: InterfaceQueue) -> None: self._telem = telemetry.TelemetryRecord() if psutil: net = psutil.net_io_counters() - self.network_init = {"sent": net.bytes_sent, "recv": net.bytes_recv} + self.network_init = { + "sent": net.bytes_sent, "recv": net.bytes_recv} else: wandb.termlog( "psutil not installed, only GPU stats will be reported. Install with pip install psutil" @@ -105,6 +126,9 @@ def __init__(self, settings: SettingsStatic, interface: InterfaceQueue) -> None: wandb.termlog("Error initializing IPUProfiler: " + str(e)) def start(self) -> None: + """ + Start the system stats tracking thread. + """ if self._thread is None: self._shutdown = False self._thread = threading.Thread(target=self._thread_body) @@ -117,23 +141,42 @@ def start(self) -> None: @property def proc(self) -> psutil.Process: + """ + Get the process associated with the current PID. + + Returns: + psutil.Process: A process object for the current PID. + """ return psutil.Process(pid=self._pid) @property def sample_rate_seconds(self) -> float: - """Sample system stats every this many seconds, defaults to 2, min is 0.5""" + """ + Get the system stats sampling rate in seconds. + + Returns: + float: The system stats sampling rate in seconds. + """ sample_rate = self._settings._stats_sample_rate_seconds # TODO: handle self._api.dynamic_settings["system_sample_seconds"] return max(0.5, sample_rate) @property def samples_to_average(self) -> int: - """The number of samples to average before pushing, defaults to 15 valid range (2:30)""" + """ + Get the number of samples to average before pushing. + + Returns: + int: The number of samples to average. + """ samples = self._settings._stats_samples_to_average # TODO: handle self._api.dynamic_settings["system_samples"] return min(30, max(2, samples)) def _thread_body(self) -> None: + """ + Body of the system stats tracking thread. + """ while True: stats = self.stats() for stat, value in stats.items(): @@ -154,6 +197,9 @@ def _thread_body(self) -> None: return def shutdown(self) -> None: + """ + Shutdown the system stats tracking thread. + """ self._shutdown = True try: if self._thread is not None: @@ -164,6 +210,9 @@ def shutdown(self) -> None: self._tpu_profiler.stop() def flush(self) -> None: + """ + Flush and publish system stats. + """ stats = self.stats() for stat, value in stats.items(): # TODO: a bit hacky, we assume all numbers should be averaged. 
If you want @@ -189,7 +238,8 @@ def stats(self) -> StatsDict: temp = pynvml.nvmlDeviceGetTemperature( handle, pynvml.NVML_TEMPERATURE_GPU ) - in_use_by_us = gpu_in_use_by_this_process(handle, pid=self._pid) + in_use_by_us = gpu_in_use_by_this_process( + handle, pid=self._pid) stats["gpu.{}.{}".format(i, "gpu")] = utilz.gpu stats["gpu.{}.{}".format(i, "memory")] = utilz.memory @@ -200,7 +250,8 @@ def stats(self) -> StatsDict: if in_use_by_us: stats["gpu.process.{}.{}".format(i, "gpu")] = utilz.gpu - stats["gpu.process.{}.{}".format(i, "memory")] = utilz.memory + stats["gpu.process.{}.{}".format( + i, "memory")] = utilz.memory stats["gpu.process.{}.{}".format(i, "memoryAllocated")] = ( memory.used / float(memory.total) ) * 100 @@ -208,17 +259,23 @@ def stats(self) -> StatsDict: # Some GPUs don't provide information about power usage try: - power_watts = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0 + power_watts = pynvml.nvmlDeviceGetPowerUsage( + handle) / 1000.0 power_capacity_watts = ( - pynvml.nvmlDeviceGetEnforcedPowerLimit(handle) / 1000.0 + pynvml.nvmlDeviceGetEnforcedPowerLimit( + handle) / 1000.0 ) - power_usage = (power_watts / power_capacity_watts) * 100 + power_usage = ( + power_watts / power_capacity_watts) * 100 - stats["gpu.{}.{}".format(i, "powerWatts")] = power_watts - stats["gpu.{}.{}".format(i, "powerPercent")] = power_usage + stats["gpu.{}.{}".format( + i, "powerWatts")] = power_watts + stats["gpu.{}.{}".format( + i, "powerPercent")] = power_usage if in_use_by_us: - stats["gpu.process.{}.{}".format(i, "powerWatts")] = power_watts + stats["gpu.process.{}.{}".format( + i, "powerWatts")] = power_watts stats[ "gpu.process.{}.{}".format(i, "powerPercent") ] = power_usage @@ -238,9 +295,11 @@ def stats(self) -> StatsDict: and self.gpu_count == 0 ): try: - out = subprocess.check_output([util.apple_gpu_stats_binary(), "--json"]) + out = subprocess.check_output( + [util.apple_gpu_stats_binary(), "--json"]) m1_stats = json.loads(out.split(b"\n")[0]) - stats["gpu.0.memory"] = m1_stats["mem_used"] / float(m1_stats["utilization"]/100) + stats["gpu.0.memory"] = m1_stats["mem_used"] / \ + float(m1_stats["utilization"]/100) stats["gpu.0.gpu"] = m1_stats["utilization"] stats["gpu.0.memoryAllocated"] = m1_stats["mem_used"] stats["gpu.0.temp"] = m1_stats["temperature"] @@ -274,7 +333,8 @@ def stats(self) -> StatsDict: stats["disk"] = psutil.disk_usage("/").percent stats["proc.memory.availableMB"] = sysmem.available / 1048576.0 try: - stats["proc.memory.rssMB"] = self.proc.memory_info().rss / 1048576.0 + stats["proc.memory.rssMB"] = self.proc.memory_info().rss / \ + 1048576.0 stats["proc.memory.percent"] = self.proc.memory_percent() stats["proc.cpu.threads"] = self.proc.num_threads() except psutil.NoSuchProcess: diff --git a/python/fedml/core/mlops/system_stats.py b/python/fedml/core/mlops/system_stats.py index bdbd9e7f55..8e82182b7f 100755 --- a/python/fedml/core/mlops/system_stats.py +++ b/python/fedml/core/mlops/system_stats.py @@ -6,6 +6,28 @@ class SysStats: def __init__(self, process_id=None): + """ + Initialize the SysStats object. + + Args: + process_id (int): Optional process ID. Defaults to None. + + Attributes: + sys_stats_impl (WandbSystemStats): Instance of WandbSystemStats for collecting system statistics. + gpu_time_spent_accessing_memory (float): GPU time spent accessing memory. + gpu_power_usage (float): GPU power usage. + gpu_temp (float): GPU temperature. + gpu_memory_allocated (float): GPU memory allocated. + gpu_utilization (float): GPU utilization. 
+            network_traffic (float): Network traffic.
+            disk_utilization (float): Disk utilization.
+            process_cpu_threads_in_use (int): Number of CPU threads in use by the process.
+            process_memory_available (float): Available process memory.
+            process_memory_in_use (float): Process memory in use.
+            process_memory_in_use_size (float): Process memory in use (size).
+            system_memory_utilization (float): System memory utilization.
+            cpu_utilization (float): CPU utilization.
+        """
         settings = SettingsStatic(d={"_stats_pid": os.getpid() if process_id is None else process_id})
         self.sys_stats_impl = WandbSystemStats(settings=settings, interface=None)
         self.gpu_time_spent_accessing_memory = 0.0
@@ -23,6 +45,9 @@ def __init__(self, process_id=None):
         self.cpu_utilization = 0.0
 
     def produce_info(self):
+        """
+        Collect system statistics and update attributes.
+        """
         stats = self.sys_stats_impl.stats()
         self.cpu_utilization = stats.get("cpu", 0.0)
 
From 50a4b9b817a71036fa82f44bfaf840156226cf3c Mon Sep 17 00:00:00 2001
From: rajveer43
Date: Sat, 23 Sep 2023 11:57:33 +0530
Subject: [PATCH 32/70] add docstrings

---
 .../communication/grpc/grpc_server.py         |  32 +++
 .../communication/grpc/ip_config_utils.py     |   9 +
 .../core/distributed/communication/message.py | 100 +++++++
 .../mqtt_thetastore_comm_manager.py           | 197 +++++++++++--
 .../communication/s3/remote_storage.py        | 263 ++++++++++++++----
 .../communication/s3/remote_storage_mnn.py    |  90 +++++-
 .../distributed/communication/s3/utils.py     |  65 ++++-
 .../communication/trpc/trpc_comm_manager.py   | 100 ++++++-
 .../communication/trpc/trpc_server.py         |  32 +++
 .../distributed/communication/trpc/utils.py   |  22 +-
 .../core/distributed/communication/utils.py   |  32 +++
 .../core/distributed/crypto/crypto_api.py     |  37 ++-
 .../theta_storage/theta_storage.py            |  77 +++--
 .../web3_storage/web3_storage.py              |  35 ++-
 .../core/distributed/fedml_comm_manager.py    | 150 +++++++++-
 .../core/distributed/flow/fedml_executor.py   | 112 +++++++-
 .../fedml/core/distributed/flow/fedml_flow.py | 198 +++++++++++++
 .../core/distributed/flow/test_fedml_flow.py  |  72 +++++
 .../topology/asymmetric_topology_manager.py   |  54 ++++
 .../topology/symmetric_topology_manager.py    |  53 ++++
 20 files changed, 1589 insertions(+), 141 deletions(-)

diff --git a/python/fedml/core/distributed/communication/grpc/grpc_server.py b/python/fedml/core/distributed/communication/grpc/grpc_server.py
index de169295aa..67d182cc29 100644
--- a/python/fedml/core/distributed/communication/grpc/grpc_server.py
+++ b/python/fedml/core/distributed/communication/grpc/grpc_server.py
@@ -10,6 +10,18 @@
 class GRPCCOMMServicer(grpc_comm_manager_pb2_grpc.gRPCCommManagerServicer):
     def __init__(self, host, port, client_num, client_id):
+        """
+        Initializes the gRPC Communication Servicer.
+
+        Args:
+            host (str): The IP address of the server.
+            port (int): The port number.
+            client_num (int): The number of clients.
+            client_id (int): The client ID.
+
+        Returns:
+            None
+        """
         # host is the ip address of server
         self.host = host
         self.port = port
@@ -24,6 +36,16 @@ def __init__(self, host, port, client_num, client_id):
         self.message_q = queue.Queue()
 
     def sendMessage(self, request, context):
+        """
+        Handles the gRPC sendMessage request.
+
+        Args:
+            request (grpc_comm_manager_pb2.CommRequest): The request message.
+            context (grpc.ServicerContext): The context of the request.
+
+        Returns:
+            grpc_comm_manager_pb2.CommResponse: The response message.
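+
+        Example:
+            The servicer enqueues the incoming message and returns an
+            acknowledgement; an illustrative client-side call, assuming the
+            stub name generated by protoc follows the usual convention::
+
+                stub = grpc_comm_manager_pb2_grpc.gRPCCommManagerStub(channel)
+                response = stub.sendMessage(request)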
+ """ context_ip = context.peer().split(":")[1] logging.info( "client_{} got something from client_{} from ip address {}".format( @@ -39,4 +61,14 @@ def sendMessage(self, request, context): return response def handleReceiveMessage(self, request, context): + """ + Handles the gRPC handleReceiveMessage request. + + Args: + request (grpc_comm_manager_pb2.CommRequest): The request message. + context (grpc.ServicerContext): The context of the request. + + Returns: + None + """ pass diff --git a/python/fedml/core/distributed/communication/grpc/ip_config_utils.py b/python/fedml/core/distributed/communication/grpc/ip_config_utils.py index 1ebedfd73a..77df94701a 100644 --- a/python/fedml/core/distributed/communication/grpc/ip_config_utils.py +++ b/python/fedml/core/distributed/communication/grpc/ip_config_utils.py @@ -2,6 +2,15 @@ def build_ip_table(path): + """ + Builds an IP table from a CSV file. + + Args: + path (str): The path to the CSV file containing receiver IDs and IP addresses. + + Returns: + dict: A dictionary mapping receiver IDs to IP addresses. + """ ip_config = dict() with open(path, newline="") as csv_file: csv_reader = csv.reader(csv_file) diff --git a/python/fedml/core/distributed/communication/message.py b/python/fedml/core/distributed/communication/message.py index 7d465461e5..df2c2a66a0 100644 --- a/python/fedml/core/distributed/communication/message.py +++ b/python/fedml/core/distributed/communication/message.py @@ -3,6 +3,9 @@ class Message(object): + """ + A class for representing and working with messages in a communication system. + """ MSG_ARG_KEY_OPERATION = "operation" MSG_ARG_KEY_TYPE = "msg_type" @@ -19,6 +22,14 @@ class Message(object): MSG_ARG_KEY_MODEL_PARAMS_KEY = "model_params_key" def __init__(self, type="default", sender_id=0, receiver_id=0): + """ + Initialize a Message instance. + + Args: + type (str): The type of the message. + sender_id (int): The ID of the sender. + receiver_id (int): The ID of the receiver. + """ self.type = str(type) self.sender_id = sender_id self.receiver_id = receiver_id @@ -28,56 +39,145 @@ def __init__(self, type="default", sender_id=0, receiver_id=0): self.msg_params[Message.MSG_ARG_KEY_RECEIVER] = receiver_id def init(self, msg_params): + """ + Initialize the message with the provided message parameters. + + Args: + msg_params (dict): A dictionary of message parameters. + """ self.msg_params = msg_params def init_from_json_string(self, json_string): + """ + Initialize the message from a JSON string. + + Args: + json_string (str): A JSON string representing the message. + """ self.msg_params = json.loads(json_string) self.type = self.msg_params[Message.MSG_ARG_KEY_TYPE] self.sender_id = self.msg_params[Message.MSG_ARG_KEY_SENDER] self.receiver_id = self.msg_params[Message.MSG_ARG_KEY_RECEIVER] def init_from_json_object(self, json_object): + """ + Initialize the message from a JSON object. + + Args: + json_object (dict): A JSON object representing the message. + """ self.msg_params = json_object self.type = self.msg_params[Message.MSG_ARG_KEY_TYPE] self.sender_id = self.msg_params[Message.MSG_ARG_KEY_SENDER] self.receiver_id = self.msg_params[Message.MSG_ARG_KEY_RECEIVER] def get_sender_id(self): + """ + Get the ID of the sender. + + Returns: + int: The sender's ID. + """ return self.sender_id def get_receiver_id(self): + """ + Get the ID of the receiver. + + Returns: + int: The receiver's ID. + """ return self.receiver_id def add_params(self, key, value): + """ + Add a parameter to the message. 
+ + Args: + key (str): The key of the parameter. + value (any): The value of the parameter. + """ self.msg_params[key] = value def get_params(self): + """ + Get all the parameters of the message. + + Returns: + dict: A dictionary of message parameters. + """ return self.msg_params def add(self, key, value): + """ + Add a parameter to the message (alias for add_params). + + Args: + key (str): The key of the parameter. + value (any): The value of the parameter. + """ self.msg_params[key] = value def get(self, key): + """ + Get the value of a parameter by its key. + + Args: + key (str): The key of the parameter. + + Returns: + any: The value of the parameter or None if not found. + """ if key not in self.msg_params.keys(): return None return self.msg_params[key] def get_type(self): + """ + Get the type of the message. + + Returns: + str: The type of the message. + """ return self.msg_params[Message.MSG_ARG_KEY_TYPE] def to_string(self): + """ + Convert the message to a string representation. + + Returns: + dict: A dictionary representing the message. + """ return self.msg_params def to_json(self): + """ + Serialize the message to a JSON string. + + Returns: + str: A JSON string representing the message. + """ json_string = json.dumps(self.msg_params) print("json string size = " + str(sys.getsizeof(json_string))) return json_string def get_content(self): + """ + Get a human-readable representation of the message. + + Returns: + str: A string representing the message content. + """ print_dict = self.msg_params.copy() msg_str = str(self.__to_msg_type_string()) + ": " + str(print_dict) return msg_str def __to_msg_type_string(self): + """ + Get a string representation of the message type. + + Returns: + str: A string representing the message type. + """ type = self.msg_params[Message.MSG_ARG_KEY_TYPE] return type diff --git a/python/fedml/core/distributed/communication/mqtt_thetastore/mqtt_thetastore_comm_manager.py b/python/fedml/core/distributed/communication/mqtt_thetastore/mqtt_thetastore_comm_manager.py index b1e2cfa3a4..d21a92afda 100755 --- a/python/fedml/core/distributed/communication/mqtt_thetastore/mqtt_thetastore_comm_manager.py +++ b/python/fedml/core/distributed/communication/mqtt_thetastore/mqtt_thetastore_comm_manager.py @@ -28,6 +28,20 @@ def __init__( client_num=0, args=None ): + """ + Initializes an MQTT-based ThetaStore Communication Manager. + + Args: + config_path (str): The path to the MQTT configuration file. + thetastore_config_path (str): The path to the ThetaStore configuration file. + topic (str, optional): The MQTT topic. Defaults to "fedml". + client_rank (int, optional): The client rank. Defaults to 0. + client_num (int, optional): The number of clients. Defaults to 0. + args (object, optional): Additional arguments. 
+
+        Returns:
+            None
+        """
         self.broker_port = None
         self.broker_host = None
         self.mqtt_user = None
@@ -44,7 +58,8 @@ def __init__(
         self.client_real_ids = []
         if args.client_id_list is not None:
             logging.info(
-                "MqttThetastoreCommManager args client_id_list: " + str(args.client_id_list)
+                "MqttThetastoreCommManager args client_id_list: "
+                + str(args.client_id_list)
             )
             self.client_real_ids = json.loads(args.client_id_list)
 
@@ -91,7 +106,8 @@ def __init__(
         if args.rank == 0:
             self.top_active_msg = CommunicationConstants.SERVER_TOP_ACTIVE_MSG
             self.topic_last_will_msg = CommunicationConstants.SERVER_TOP_LAST_WILL_MSG
-            self.last_will_msg = json.dumps({"ID": self.edge_id, "status": CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE})
+            self.last_will_msg = json.dumps(
+                {"ID": self.edge_id, "status": CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE})
         self.mqtt_mgr = MqttManager(self.broker_host, self.broker_port, self.mqtt_user, self.mqtt_pwd,
                                     self.keepalive_time,
                                     self._client_id, self.topic_last_will_msg,
@@ -104,6 +120,12 @@ def __init__(
 
     @property
     def client_id(self):
+        """
+        Get the MQTT client identifier.
+
+        Returns:
+            str: The client ID used for the MQTT connection.
+        """
         return self._client_id
 
     @property
@@ -115,6 +137,14 @@ def run_loop_forever(self):
 
     def on_connected(self, mqtt_client_object):
         """
+        Callback invoked when the MQTT client connects.
+
+        Args:
+            mqtt_client_object (MqttManager): The MQTT client object.
+
+        Returns:
+            None
+
        [server]
        sending message topic (publish): serverID_clientID
        receiving message topic (subscribe): clientID
@@ -135,7 +165,8 @@ def on_connected(self, mqtt_client_object):
 
         # logging.info("self.client_real_ids = {}".format(self.client_real_ids))
         for client_rank in range(0, self.client_num):
-            real_topic = self._topic + str(self.client_real_ids[client_rank])
+            real_topic = self._topic + \
+                str(self.client_real_ids[client_rank])
             result, mid = mqtt_client_object.subscribe(real_topic, qos=2)
 
             # logging.info(
@@ -146,7 +177,8 @@ def on_connected(self, mqtt_client_object):
             self._notify_connection_ready()
         else:
             # client
-            real_topic = self._topic + str(self.server_id) + "_" + str(self.client_real_ids[0])
+            real_topic = self._topic + \
+                str(self.server_id) + "_" + str(self.client_real_ids[0])
             result, mid = mqtt_client_object.subscribe(real_topic, qos=2)
 
             self._notify_connection_ready()
@@ -158,12 +190,39 @@ def on_connected(self, mqtt_client_object):
         self.is_connected = True
 
     def on_disconnected(self, mqtt_client_object):
+        """
+        Callback invoked when the MQTT client disconnects.
+
+        Args:
+            mqtt_client_object (MqttManager): The MQTT client object.
+
+        Returns:
+            None
+        """
         self.is_connected = False
 
     def add_observer(self, observer: Observer):
+        """
+        Adds an observer to the communication manager.
+
+        Args:
+            observer (Observer): The observer to be added.
+
+        Returns:
+            None
+        """
         self._observers.append(observer)
 
     def remove_observer(self, observer: Observer):
+        """
+        Removes an observer from the communication manager.
+
+        Args:
+            observer (Observer): The observer to be removed.
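+
+        Once removed, the observer no longer receives callbacks when new
+        messages are dispatched to registered observers.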
+ + Returns: + None + """ self._observers.remove(observer) def _notify_connection_ready(self): @@ -185,7 +244,8 @@ def _on_message_impl(self, msg): payload_obj = json.loads(json_payload) sender_id = payload_obj.get(Message.MSG_ARG_KEY_SENDER, "") receiver_id = payload_obj.get(Message.MSG_ARG_KEY_RECEIVER, "") - thetastore_key_str = payload_obj.get(Message.MSG_ARG_KEY_MODEL_PARAMS, "") + thetastore_key_str = payload_obj.get( + Message.MSG_ARG_KEY_MODEL_PARAMS, "") thetastore_key_str = str(thetastore_key_str).strip(" ") if thetastore_key_str != "": @@ -195,10 +255,12 @@ def _on_message_impl(self, msg): model_params = self.theta_storage.read_model(thetastore_key_str) Context().add("received_model_cid", thetastore_key_str) - logging.info("Received model cid {}".format(Context().get("received_model_cid"))) + logging.info("Received model cid {}".format( + Context().get("received_model_cid"))) logging.info( - "mqtt_thetastore.on_message: model params length %d" % len(model_params) + "mqtt_thetastore.on_message: model params length %d" % len( + model_params) ) # replace the thetastore object key with raw model params @@ -213,6 +275,14 @@ def _on_message(self, msg): def send_message(self, msg: Message): """ + Sends a message using MQTT. + + Args: + msg (Message): The message to be sent. + + Returns: + None + [server] sending message topic (publish): fedml_runid_serverID_clientID receiving message topic (subscribe): fedml_runid_clientID @@ -227,16 +297,20 @@ def send_message(self, msg: Message): if self.client_id == 0: # topic = "fedml" + "_" + "run_id" + "_0" + "_" + "client_id" topic = self._topic + str(self.server_id) + "_" + str(receiver_id) - logging.info("mqtt_thetastore.send_message: msg topic = %s" % str(topic)) + logging.info( + "mqtt_thetastore.send_message: msg topic = %s" % str(topic)) payload = msg.get_params() - model_params_obj = payload.get(Message.MSG_ARG_KEY_MODEL_PARAMS, "") + model_params_obj = payload.get( + Message.MSG_ARG_KEY_MODEL_PARAMS, "") if model_params_obj != "": # thetastore logging.info("mqtt_thetastore.send_message: to python client.") - message_key = model_url = self.theta_storage.write_model(model_params_obj) + message_key = model_url = self.theta_storage.write_model( + model_params_obj) Context().add("sent_model_cid", model_url) - logging.info("Sent model cid {}".format(Context().get("sent_model_cid"))) + logging.info("Sent model cid {}".format( + Context().get("sent_model_cid"))) logging.info( "mqtt_thetastore.send_message: thetastore+MQTT msg sent, thetastore message key = %s" % message_key @@ -261,12 +335,15 @@ def send_message(self, msg: Message): topic = self._topic + str(msg.get_sender_id()) payload = msg.get_params() - model_params_obj = payload.get(Message.MSG_ARG_KEY_MODEL_PARAMS, "") + model_params_obj = payload.get( + Message.MSG_ARG_KEY_MODEL_PARAMS, "") if model_params_obj != "": # thetastore - message_key = model_url = self.theta_storage.write_model(model_params_obj) + message_key = model_url = self.theta_storage.write_model( + model_params_obj) Context().add("sent_model_cid", model_url) - logging.info("Sent model cid {}".format(Context().get("sent_model_cid"))) + logging.info("Sent model cid {}".format( + Context().get("sent_model_cid"))) logging.info( "mqtt_thetastore.send_message: thetastore+MQTT msg sent, message_key = %s" % message_key @@ -286,20 +363,52 @@ def send_message(self, msg: Message): self.mqtt_mgr.send_message(topic, json.dumps(payload)) def send_message_json(self, topic_name, json_message): + """ + Sends a JSON message using MQTT. 
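+
+        A hypothetical call (topic and payload are illustrative):
+            >>> comm.send_message_json("fedml_status", json.dumps({"status": "online"}))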
+ + Args: + topic_name (str): The MQTT topic name. + json_message (str): The JSON message to be sent. + + Returns: + None + """ self.mqtt_mgr.send_message_json(topic_name, json_message) def handle_receive_message(self): + """ + Handles the reception of messages. + + Returns: + None + """ start_listening_time = time.time() MLOpsProfilerEvent.log_to_wandb({"ListenStart": start_listening_time}) self.run_loop_forever() - MLOpsProfilerEvent.log_to_wandb({"TotalTime": time.time() - start_listening_time}) + MLOpsProfilerEvent.log_to_wandb( + {"TotalTime": time.time() - start_listening_time}) def stop_receive_message(self): + """ + Stops the reception of messages and disconnects the MQTT client. + + Returns: + None + """ logging.info("mqtt_thetastore.stop_receive_message: stopping...") self.mqtt_mgr.loop_stop() self.mqtt_mgr.disconnect() def set_config_from_file(self, config_file_path): + """ + Sets the MQTT configuration from a file. + + Args: + config_file_path (str): The path to the MQTT configuration file. + + Returns: + None + """ try: with open(config_file_path, "r") as f: config = yaml.load(f, Loader=yaml.FullLoader) @@ -315,6 +424,15 @@ def set_config_from_file(self, config_file_path): pass def set_config_from_objects(self, mqtt_config): + """ + Sets the MQTT configuration from an object. + + Args: + mqtt_config (dict): The MQTT configuration. + + Returns: + None + """ self.broker_host = mqtt_config["BROKER_HOST"] self.broker_port = mqtt_config["BROKER_PORT"] self.mqtt_user = None @@ -325,21 +443,49 @@ def set_config_from_objects(self, mqtt_config): self.mqtt_pwd = mqtt_config["MQTT_PWD"] def callback_client_last_will_msg(self, topic, payload): + """ + Callback function for processing client last will messages. + + Args: + topic (str): The MQTT topic. + payload (str): The message payload. + + Returns: + None + """ msg = json.loads(payload) edge_id = msg.get("ID", None) - status = msg.get("status", CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE) + status = msg.get( + "status", CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE) if edge_id is not None and status == CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE: if self.client_active_list.get(edge_id, None) is not None: self.client_active_list.pop(edge_id) def callback_client_active_msg(self, topic, payload): + """ + Callback function for processing client active status messages. + + Args: + topic (str): The MQTT topic. + payload (str): The message payload. + + Returns: + None + """ msg = json.loads(payload) edge_id = msg.get("ID", None) - status = msg.get("status", CommunicationConstants.MSG_CLIENT_STATUS_IDLE) + status = msg.get( + "status", CommunicationConstants.MSG_CLIENT_STATUS_IDLE) if edge_id is not None: self.client_active_list[edge_id] = status def subscribe_client_status_message(self): + """ + Subscribes to client status messages. + + Returns: + None + """ # Setup MQTT message listener to the last will message form the client. self.mqtt_mgr.add_message_listener(CommunicationConstants.CLIENT_TOP_LAST_WILL_MSG, self.callback_client_last_will_msg) @@ -349,11 +495,26 @@ def subscribe_client_status_message(self): self.callback_client_active_msg) def get_client_status(self, client_id): + """ + Gets the status of a specific client. + + Args: + client_id (int): The client ID. + + Returns: + str: The status of the client. + """ return self.client_active_list.get(client_id, CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE) def get_client_list_status(self): + """ + Gets the status of all clients. 
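+
+        Statuses are the string constants reported by the clients themselves
+        (e.g. ``CommunicationConstants.MSG_CLIENT_STATUS_IDLE``), keyed by edge ID.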
+ + Returns: + dict: A dictionary of client statuses. + """ return self.client_active_list if __name__ == "__main__": - pass \ No newline at end of file + pass diff --git a/python/fedml/core/distributed/communication/s3/remote_storage.py b/python/fedml/core/distributed/communication/s3/remote_storage.py index f9e3416b34..22fc82b780 100644 --- a/python/fedml/core/distributed/communication/s3/remote_storage.py +++ b/python/fedml/core/distributed/communication/s3/remote_storage.py @@ -26,6 +26,15 @@ class S3Storage: def __init__(self, s3_config_path): + """ + Initializes an S3MNNStorage instance with S3 configuration. + + Args: + s3_config_path (str): The path to the S3 configuration file. + + Returns: + None + """ self.bucket_name = None self.cn_region_name = None self.cn_s3_sak = None @@ -49,6 +58,16 @@ def __init__(self, s3_config_path): ) def write_model(self, message_key, model): + """ + Writes a machine learning model to S3 storage. + + Args: + message_key (str): The key to identify the stored model in S3. + model: The machine learning model to be stored. + + Returns: + str: The URL of the stored model in S3. + """ global aws_s3_client pickle_dump_start_time = time.time() MLOpsProfilerEvent.log_to_wandb( @@ -62,19 +81,23 @@ def write_model(self, message_key, model): model_file_size = len(model_to_send) model_file_transfered = 0 prev_progress = 0 + def upload_model_progress(bytes_transferred): nonlocal model_file_transfered nonlocal model_file_size - nonlocal prev_progress # since the callback is stateless, we need to keep the previous progress + # since the callback is stateless, we need to keep the previous progress + nonlocal prev_progress model_file_transfered += bytes_transferred uploaded_kb = format(model_file_transfered / 1024, '.2f') - progress = (model_file_transfered / model_file_size * 100) if model_file_size != 0 else 0 + progress = (model_file_transfered / model_file_size * + 100) if model_file_size != 0 else 0 progress_format_int = int(progress) # print the process every 5% if progress_format_int % 5 == 0 and progress_format_int != prev_progress: - logging.info("model uploaded to S3 size {} KB, progress {}%".format(uploaded_kb, progress_format_int)) + logging.info("model uploaded to S3 size {} KB, progress {}%".format( + uploaded_kb, progress_format_int)) prev_progress = progress_format_int - + aws_s3_client.upload_fileobj( Fileobj=io.BytesIO(model_to_send), Bucket=self.bucket_name, Key=message_key, Callback=upload_model_progress, @@ -90,6 +113,16 @@ def upload_model_progress(bytes_transferred): return model_url def write_model_net(self, message_key, model, dummy_input_tensor, local_model_cache_path): + """ + Writes a machine learning model to S3 storage. + + Args: + message_key (str): The key to identify the stored model in S3. + model: The machine learning model to be stored. + + Returns: + str: The URL of the stored model in S3. 
+ """ global aws_s3_client pickle_dump_start_time = time.time() MLOpsProfilerEvent.log_to_wandb( @@ -117,21 +150,25 @@ def write_model_net(self, message_key, model, dummy_input_tensor, local_model_ca model_to_send.seek(0, 0) net_file_transfered = 0 prev_progress = 0 + def upload_model_net_progress(bytes_transferred): nonlocal net_file_transfered nonlocal net_file_size - nonlocal prev_progress # since the callback is stateless, we need to keep the previous progress + # since the callback is stateless, we need to keep the previous progress + nonlocal prev_progress net_file_transfered += bytes_transferred uploaded_kb = format(net_file_transfered / 1024, '.2f') - progress = (net_file_transfered / net_file_size * 100) if net_file_size != 0 else 0 + progress = (net_file_transfered / net_file_size * + 100) if net_file_size != 0 else 0 progress_format_int = int(progress) # print the process every 5% if progress_format_int % 5 == 0 and progress_format_int != prev_progress: - logging.info("model net uploaded to S3 size {} KB, progress {}%".format(uploaded_kb, progress_format_int)) + logging.info("model net uploaded to S3 size {} KB, progress {}%".format( + uploaded_kb, progress_format_int)) prev_progress = progress_format_int aws_s3_client.upload_fileobj( Fileobj=model_to_send, Bucket=self.bucket_name, Key=message_key, - Callback= upload_model_net_progress, + Callback=upload_model_net_progress, ) MLOpsProfilerEvent.log_to_wandb( {"Comm/send_delay": time.time() - s3_upload_start_time} @@ -144,6 +181,18 @@ def upload_model_net_progress(bytes_transferred): return model_url def write_model_input(self, message_key, input_size, input_type, local_model_cache_path): + """ + Writes model input information to S3 storage. + + Args: + message_key (str): The key to identify the stored input information in S3. + input_size: The size of the model input. + input_type: The type of the model input. + local_model_cache_path (str): The local cache path for input information storage. + + Returns: + str: The URL of the stored input information in S3. + """ global aws_s3_client if not os.path.exists(local_model_cache_path): @@ -157,7 +206,8 @@ def write_model_input(self, message_key, input_size, input_type, local_model_cac json.dump(model_input_dict, f) with open(model_input_path, 'rb') as f: - aws_s3_client.upload_fileobj(f, Bucket=self.bucket_name, Key=message_key) + aws_s3_client.upload_fileobj( + f, Bucket=self.bucket_name, Key=message_key) model_input_url = aws_s3_client.generate_presigned_url("get_object", ExpiresIn=60 * 60 * 24 * 5, @@ -165,6 +215,16 @@ def write_model_input(self, message_key, input_size, input_type, local_model_cac return model_input_url def write_model_web(self, message_key, model): + """ + Writes a machine learning model to S3 storage in web format. + + Args: + message_key (str): The key to identify the stored model in S3. + model: The machine learning model to be stored. + + Returns: + str: The URL of the stored model in S3. + """ global aws_s3_client pickle_dump_start_time = time.time() MLOpsProfilerEvent.log_to_wandb( @@ -189,6 +249,15 @@ def write_model_web(self, message_key, model): return model_url def read_model(self, message_key): + """ + Reads a machine learning model from S3 storage. + + Args: + message_key (str): The key to identify the stored model in S3. + + Returns: + model: The machine learning model retrieved from S3. 
+ """ global aws_s3_client message_handler_start_time = time.time() @@ -200,7 +269,8 @@ def read_model(self, message_key): os.makedirs(cache_dir) except Exception as e: pass - temp_base_file_path = os.path.join(cache_dir, str(os.getpid()) + "@" + str(uuid.uuid4())) + temp_base_file_path = os.path.join( + cache_dir, str(os.getpid()) + "@" + str(uuid.uuid4())) if not os.path.exists(temp_base_file_path): try: os.makedirs(temp_base_file_path) @@ -211,22 +281,25 @@ def read_model(self, message_key): logging.info("temp_file_path = {}".format(temp_file_path)) model_file_transfered = 0 prev_progress = 0 + def read_model_progress(bytes_transferred): nonlocal model_file_transfered nonlocal object_size nonlocal prev_progress model_file_transfered += bytes_transferred readed_kb = format(model_file_transfered / 1024, '.2f') - progress = (model_file_transfered / object_size * 100) if object_size != 0 else 0 + progress = (model_file_transfered / object_size * + 100) if object_size != 0 else 0 progress_format_int = int(progress) # print the process every 5% if progress_format_int % 5 == 0 and progress_format_int != prev_progress: - logging.info("model readed from S3 size {} KB, progress {}%".format(readed_kb, progress_format_int)) + logging.info("model readed from S3 size {} KB, progress {}%".format( + readed_kb, progress_format_int)) prev_progress = progress_format_int with open(temp_file_path, 'wb') as f: aws_s3_client.download_fileobj(Bucket=self.bucket_name, Key=message_key, Fileobj=f, - Callback=read_model_progress) + Callback=read_model_progress) MLOpsProfilerEvent.log_to_wandb( {"Comm/recieve_delay_s3": time.time() - message_handler_start_time} ) @@ -242,6 +315,16 @@ def read_model_progress(bytes_transferred): return model def read_model_net(self, message_key, local_model_cache_path): + """ + Reads a machine learning model in net format from S3 storage. + + Args: + message_key (str): The key to identify the stored model in S3. + local_model_cache_path (str): The local cache path for model storage. + + Returns: + model: The machine learning model retrieved from S3. 
+ """ global aws_s3_client message_handler_start_time = time.time() @@ -259,21 +342,25 @@ def read_model_net(self, message_key, local_model_cache_path): logging.info("temp_file_path = {}".format(temp_file_path)) model_file_transfered = 0 prev_progress = 0 + def read_model_net_progress(bytes_transferred): nonlocal model_file_transfered nonlocal object_size - nonlocal prev_progress # since the callback is stateless, we need to keep the previous progress + # since the callback is stateless, we need to keep the previous progress + nonlocal prev_progress model_file_transfered += bytes_transferred readed_kb = format(model_file_transfered / 1024, '.2f') - progress = (model_file_transfered / object_size * 100) if object_size != 0 else 0 + progress = (model_file_transfered / object_size * + 100) if object_size != 0 else 0 progress_format_int = int(progress) # print the process every 5% if progress_format_int % 5 == 0 and progress_format_int != prev_progress: - logging.info("model net readed from S3 size {} KB, progress {}%".format(readed_kb, progress_format_int)) + logging.info("model net readed from S3 size {} KB, progress {}%".format( + readed_kb, progress_format_int)) prev_progress = progress_format_int with open(temp_file_path, 'wb') as f: aws_s3_client.download_fileobj(Bucket=self.bucket_name, Key=message_key, Fileobj=f, - Callback=read_model_net_progress) + Callback=read_model_net_progress) MLOpsProfilerEvent.log_to_wandb( {"Comm/recieve_delay_s3": time.time() - message_handler_start_time} ) @@ -291,6 +378,17 @@ def read_model_net_progress(bytes_transferred): return model def read_model_input(self, message_key, local_model_cache_path): + """ + Reads model input information from S3 storage. + + Args: + message_key (str): The key to identify the stored input information in S3. + local_model_cache_path (str): The local cache path for input information storage. + + Returns: + input_size: The size of the model input. + input_type: The type of the model input. + """ global aws_s3_client temp_base_file_path = local_model_cache_path @@ -304,7 +402,8 @@ def read_model_input(self, message_key, local_model_cache_path): os.remove(temp_file_path) logging.info("temp_file_path = {}".format(temp_file_path)) with open(temp_file_path, 'wb') as f: - aws_s3_client.download_fileobj(Bucket=self.bucket_name, Key=message_key, Fileobj=f) + aws_s3_client.download_fileobj( + Bucket=self.bucket_name, Key=message_key, Fileobj=f) with open(temp_file_path, 'r') as f: model_input_dict = json.load(f) @@ -316,9 +415,21 @@ def read_model_input(self, message_key, local_model_cache_path): # TODO: added python torch model to align the Tensorflow parameters from browser def read_model_web(self, message_key, py_model: nn.Module): + """ + Reads a machine learning model in web format from S3 storage. + + Args: + message_key (str): The key to identify the stored model in S3. + py_model (nn.Module): The PyTorch model to align Tensorflow parameters from the browser. + + Returns: + model: The machine learning model retrieved from S3. 
+ """ + global aws_s3_client message_handler_start_time = time.time() - obj = aws_s3_client.get_object(Bucket=self.bucket_name, Key=message_key) + obj = aws_s3_client.get_object( + Bucket=self.bucket_name, Key=message_key) model_json = obj["Body"].read() if type(model_json) == list: model = load_params_from_tf(py_model, model_json) @@ -368,16 +479,21 @@ def read_model_web(self, message_key, py_model: nn.Module): def upload_file(self, src_local_path, message_key): """ - upload file - :param src_local_path: - :param message_key: - :return: + Uploads a file to S3 storage. + + Args: + src_local_path (str): The local path to the file to be uploaded. + message_key (str): The key to identify the stored file in S3. + + Returns: + str: The URL of the uploaded file. """ try: with open(src_local_path, "rb") as f: global aws_s3_client aws_s3_client.upload_fileobj( - f, self.bucket_name, message_key, ExtraArgs={"ACL": "public-read"} + f, self.bucket_name, message_key, ExtraArgs={ + "ACL": "public-read"} ) model_url = aws_s3_client.generate_presigned_url( @@ -398,10 +514,14 @@ def upload_file(self, src_local_path, message_key): def download_file(self, message_key, path_local): """ - download file - :param message_key: s3 key - :param path_local: local path - :return: + Downloads a file from S3 storage to the local filesystem. + + Args: + message_key (str): The key to identify the file in S3. + path_local (str): The local path where the file should be saved. + + Returns: + None """ retry = 0 while retry < 3: @@ -410,7 +530,8 @@ def download_file(self, message_key, path_local): ) try: global aws_s3_client - aws_s3_client.download_file(self.bucket_name, message_key, path_local) + aws_s3_client.download_file( + self.bucket_name, message_key, path_local) file_size = os.path.getsize(path_local) logging.info( f"Downloading completed. | size: {round(file_size / 1048576, 2)} MB" @@ -425,12 +546,16 @@ def download_file(self, message_key, path_local): def upload_file_with_progress(self, src_local_path, dest_s3_path, out_progress_to_err=True, progress_desc=None): """ - upload file - :param out_progress_to_err: - :param progress_desc: - :param src_local_path: - :param dest_s3_path: - :return: + Uploads a file to S3 storage with progress tracking. + + Args: + src_local_path (str): The local path to the file to be uploaded. + dest_s3_path (str): The key to identify the stored file in S3. + out_progress_to_err (bool): Whether to output progress to stderr. + progress_desc (str): A description for the progress tracking. + + Returns: + str: The URL of the uploaded file. 
""" file_uploaded_url = "" progress_desc_text = "Uploading Package to AWS S3" @@ -447,8 +572,10 @@ def upload_file_with_progress(self, src_local_path, dest_s3_path, file=sys.stderr if out_progress_to_err else sys.stdout, desc=progress_desc_text) as pbar: aws_s3_client.upload_fileobj( - f, self.bucket_name, dest_s3_path, ExtraArgs={"ACL": "public-read"}, - Callback=lambda bytes_transferred: pbar.update(bytes_transferred), + f, self.bucket_name, dest_s3_path, ExtraArgs={ + "ACL": "public-read"}, + Callback=lambda bytes_transferred: pbar.update( + bytes_transferred), ) file_uploaded_url = aws_s3_client.generate_presigned_url( @@ -469,12 +596,16 @@ def upload_file_with_progress(self, src_local_path, dest_s3_path, def download_file_with_progress(self, path_s3, path_local, out_progress_to_err=True, progress_desc=None): """ - download file - :param out_progress_to_err: - :param progress_desc: - :param path_s3: s3 key - :param path_local: local path - :return: + Downloads a file from S3 storage to the local filesystem with progress tracking. + + Args: + path_s3 (str): The key to identify the file in S3. + path_local (str): The local path where the file should be saved. + out_progress_to_err (bool): Whether to output progress to stderr. + progress_desc (str): A description for the progress tracking. + + Returns: + None """ retry = 0 progress_desc_text = "Downloading Package from AWS S3" @@ -487,7 +618,8 @@ def download_file_with_progress(self, path_s3, path_local, try: global aws_s3_client kwargs = {"Bucket": self.bucket_name, "Key": path_s3} - object_size = aws_s3_client.head_object(**kwargs)["ContentLength"] + object_size = aws_s3_client.head_object( + **kwargs)["ContentLength"] with tqdm.tqdm(total=object_size, unit="B", unit_scale=True, file=sys.stderr if out_progress_to_err else sys.stdout, desc=progress_desc_text) as pbar: @@ -504,10 +636,14 @@ def download_file_with_progress(self, path_s3, path_local, def test_s3_base_cmds(self, message_key, message_body): """ - test_s3_base_cmds - :param file_key: s3 message key - :param file_key: s3 message body - :return: + Tests basic S3 commands by uploading and downloading a message. + + Args: + message_key (str): The key to identify the stored message in S3. + message_body: The message body to be stored and retrieved. + + Returns: + bool: True if the test is successful, False otherwise. """ retry = 0 while retry < 3: @@ -517,7 +653,8 @@ def test_s3_base_cmds(self, message_key, message_body): aws_s3_client.put_object( Body=message_pkl, Bucket=self.bucket_name, Key=message_key, ACL="public-read", ) - obj = aws_s3_client.get_object(Bucket=self.bucket_name, Key=message_key) + obj = aws_s3_client.get_object( + Bucket=self.bucket_name, Key=message_key) message_pkl_downloaded = obj["Body"].read() message_downloaded = pickle.loads(message_pkl_downloaded) if str(message_body) == str(message_downloaded): @@ -534,15 +671,28 @@ def test_s3_base_cmds(self, message_key, message_body): def delete_s3_zip(self, path_s3): """ - delete s3 object - :param path_s3: s3 key - :return: + Deletes an object from S3 storage. + + Args: + path_s3 (str): The key to identify the object in S3. + + Returns: + None """ global aws_s3_client aws_s3_client.delete_object(Bucket=self.bucket_name, Key=path_s3) logging.info(f"Delete s3 file Successful. | path_s3 = {path_s3}") def set_config_from_file(self, config_file_path): + """ + Sets the S3 configuration from a file. + + Args: + config_file_path (str): The path to the configuration file. 
+ + Returns: + None + """ try: with open(config_file_path, "r") as f: config = yaml.load(f, Loader=yaml.FullLoader) @@ -554,6 +704,15 @@ def set_config_from_file(self, config_file_path): pass def set_config_from_objects(self, s3_config): + """ + Sets the S3 configuration from a dictionary of S3 configuration values. + + Args: + s3_config (dict): A dictionary containing S3 configuration values. + + Returns: + None + """ self.cn_s3_aki = s3_config["CN_S3_AKI"] self.cn_s3_sak = s3_config["CN_S3_SAK"] self.cn_region_name = s3_config["CN_REGION_NAME"] diff --git a/python/fedml/core/distributed/communication/s3/remote_storage_mnn.py b/python/fedml/core/distributed/communication/s3/remote_storage_mnn.py index f6b0b17a9f..9f4068fe68 100644 --- a/python/fedml/core/distributed/communication/s3/remote_storage_mnn.py +++ b/python/fedml/core/distributed/communication/s3/remote_storage_mnn.py @@ -13,39 +13,113 @@ def __init__(self, s3_config_path): def upload_model_file(self, message_key, model_file_path): """ - this is used for Mobile Platform (MNN) - :param message_key: - :param model_file_path: - :return: + Uploads a model file to S3 storage for Mobile Platform (MNN). + + Args: + message_key (str): The key to identify the uploaded model in S3. + model_file_path (str): The local file path of the model to be uploaded. + + Returns: + bool: True if the upload was successful, False otherwise. """ return self.s3_storage.upload_file(model_file_path, message_key) def download_model_file(self, message_key, model_file_path): """ - this is used for Mobile Platform (MNN) - :param message_key: - :param model_file_path: - :return: + Downloads a model file from S3 storage for Mobile Platform (MNN). + + Args: + message_key (str): The key identifying the model to be downloaded from S3. + model_file_path (str): The local file path where the downloaded model will be saved. + + Returns: + None """ self.s3_storage.download_file(message_key, model_file_path) def write_model(self, message_key, model): + """ + Writes a model object to S3 storage. + + Args: + message_key (str): The key to identify the stored model in S3. + model: The model object to be stored. + + Returns: + None + """ self.s3_storage.write_model(message_key, model) def read_model(self, message_key): + """ + Reads a model object from S3 storage. + + Args: + message_key (str): The key identifying the model to be read from S3. + + Returns: + object: The model object read from S3. + """ return self.s3_storage.read_model(message_key) def upload_file(self, src_local_path, dest_s3_path): + """ + Uploads a file from the local system to S3 storage. + + Args: + src_local_path (str): The local file path of the file to be uploaded. + dest_s3_path (str): The S3 destination path for the uploaded file. + + Returns: + bool: True if the upload was successful, False otherwise. + """ return self.s3_storage.upload_file(src_local_path, dest_s3_path) def download_file(self, path_s3, path_local): + """ + Downloads a file from S3 storage to the local system. + + Args: + path_s3 (str): The S3 path of the file to be downloaded. + path_local (str): The local file path where the downloaded file will be saved. + + Returns: + None + """ self.s3_storage.download_file(path_s3, path_local) def delete_s3_zip(self, path_s3): + """ + Deletes a ZIP file from S3 storage. + + Args: + path_s3 (str): The S3 path of the ZIP file to be deleted. 
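+
+        A hypothetical call: ``storage.delete_s3_zip("run_1/package.zip")``
+        (the key shown is a placeholder). This delegates to
+        ``S3Storage.delete_s3_zip`` on the wrapped storage object.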
+ + Returns: + None + """ self.s3_storage.delete_s3_zip(path_s3) def set_config_from_file(self, config_file_path): + """ + Sets the S3 configuration from a configuration file. + + Args: + config_file_path (str): The path to the S3 configuration file. + + Returns: + None + """ self.s3_storage.set_config_from_file(config_file_path) def set_config_from_objects(self, s3_config): + """ + Sets the S3 configuration from configuration objects. + + Args: + s3_config: Configuration objects for S3 storage. + + Returns: + None + """ self.s3_storage.set_config_from_objects(s3_config) diff --git a/python/fedml/core/distributed/communication/s3/utils.py b/python/fedml/core/distributed/communication/s3/utils.py index a92d5b8aaa..00be45480b 100644 --- a/python/fedml/core/distributed/communication/s3/utils.py +++ b/python/fedml/core/distributed/communication/s3/utils.py @@ -5,19 +5,23 @@ def load_params_from_tf(py_model:nn.Module, tf_model:list): """ - Load and update the parameters from tensorflow.js to pytorch nn.Module + Load and update the parameters from TensorFlow.js to PyTorch nn.Module. Args: - py_model: An nn.Moudule network structure from pytorch - tf_module: A list read from JSON file which stored the meta data of tensorflow.js model - (length is number of layers, and has two keys in each layer, 'model' and 'params' respectively) + py_model (nn.Module): A PyTorch neural network structure. + tf_model (list): A list read from a JSON file containing metadata for the TensorFlow.js model. Returns: - An updated nn.Module network structure + nn.Module: An updated PyTorch neural network structure. Raises: - Exception: Certain layer structure is not aligned - KeyError: Model layer is not aligned + Exception: If certain layer structures do not align between PyTorch and TensorFlow.js. + KeyError: If a model layer is not aligned. + + This function loads and updates the parameters from a TensorFlow.js model to a PyTorch nn.Module. + It compares layer names between the two models and assigns the TensorFlow.js parameters to the + corresponding layers in the PyTorch model. + """ state_dict = py_model.state_dict() py_layers = list(state_dict.keys()) @@ -41,6 +45,22 @@ def load_params_from_tf(py_model:nn.Module, tf_model:list): raise TypeError("The model structure of pytorch and tensorflow.js is not aligned! Cannot transfer parameters accordingly.") def process_state_dict(state_dict): + """ + Process a PyTorch state dictionary to convert it into a Python dictionary. + + Args: + state_dict (dict): A PyTorch state dictionary containing model parameters. + + Returns: + dict: A Python dictionary where keys are parameter names and values are + NumPy arrays representing the parameter values. + + This function takes a PyTorch state dictionary, which typically contains the + parameters of a neural network model, and converts it into a Python dictionary. + Each key in the resulting dictionary corresponds to a parameter's name, and the + corresponding value is a NumPy array containing the parameter's values. 
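+
+    Note that the values are produced via ``numpy().tolist()``, so they are
+    nested Python lists (JSON-serializable) rather than ``ndarray`` objects.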
+ + """ lr_py = {} for key, value in state_dict.items(): lr_py[key] = value.cpu().detach().numpy().tolist() @@ -48,12 +68,31 @@ def process_state_dict(state_dict): class LogisticRegression(torch.nn.Module): - def __init__(self, input_dim, output_dim): - super(LogisticRegression, self).__init__() - self.linear = torch.nn.Linear(input_dim, output_dim) - def forward(self, x): - outputs = torch.sigmoid(self.linear(x)) - return outputs + def __init__(self, input_dim, output_dim): + """ + Initialize a logistic regression model. + + Args: + input_dim (int): The input dimension. + output_dim (int): The output dimension. + + """ + super(LogisticRegression, self).__init__() + self.linear = torch.nn.Linear(input_dim, output_dim) + + def forward(self, x): + """ + Forward pass of the logistic regression model. + + Args: + x (Tensor): Input tensor. + + Returns: + Tensor: Output tensor after applying the sigmoid function. + + """ + outputs = torch.sigmoid(self.linear(x)) + return outputs class CNN_WEB(nn.Module): diff --git a/python/fedml/core/distributed/communication/trpc/trpc_comm_manager.py b/python/fedml/core/distributed/communication/trpc/trpc_comm_manager.py index 0cbe0bccf1..4f91223213 100644 --- a/python/fedml/core/distributed/communication/trpc/trpc_comm_manager.py +++ b/python/fedml/core/distributed/communication/trpc/trpc_comm_manager.py @@ -20,6 +20,19 @@ class TRPCCommManager(BaseCommunicationManager): def __init__(self, trpc_master_config_path, process_id=0, world_size=0, args=None): + """ + Initialize a TRPC communication manager. + + Args: + trpc_master_config_path (str): Path to the TRPC master configuration file. + process_id (int): The ID of the current process. + world_size (int): The total number of processes in the world. + args (Optional): Additional arguments. + + Returns: + None + """ + logging.info("using TRPC backend") with open(trpc_master_config_path, newline="") as csv_file: csv_reader = csv.reader(csv_file) @@ -40,19 +53,33 @@ def __init__(self, trpc_master_config_path, process_id=0, world_size=0, args=Non logging.info(f"Worker rank {process_id} initializing RPC") - self.trpc_servicer = TRPCCOMMServicer(master_address, master_port, self.world_size, process_id) + self.trpc_servicer = TRPCCOMMServicer( + master_address, master_port, self.world_size, process_id) logging.info(os.getcwd()) os.environ["MASTER_ADDR"] = self.master_address os.environ["MASTER_PORT"] = self.master_port - self._init_torch_rpc_tp(master_address, master_port, process_id, self.world_size) + self._init_torch_rpc_tp( + master_address, master_port, process_id, self.world_size) self.is_running = True logging.info("server started. master address: " + str(master_address)) def _init_torch_rpc_tp( self, master_addr, master_port, worker_idx, worker_num, ): + """ + Initialize the Torch RPC using TensorPipe backend. + + Args: + master_addr (str): The address of the RPC master. + master_port (str): The port of the RPC master. + worker_idx (int): The index of the current worker. + worker_num (int): The total number of workers. + + Returns: + None + """ # https://github.com/pytorch/pytorch/issues/55615 # [BC-Breaking][RFC] Retire ProcessGroup Backend for RPC #55615 str_init_method = "tcp://" + str(master_addr) + ":" + str(master_port) @@ -73,6 +100,15 @@ def _init_torch_rpc_tp( logging.info("_init_torch_rpc_tp finished.") def send_message(self, msg: Message): + """ + Send a message to the specified receiver. + + Args: + msg (Message): The message to be sent. 
+ + Returns: + None + """ receiver_id = msg.get_receiver_id() logging.info("sending message to {}".format(receiver_id)) @@ -82,21 +118,52 @@ def send_message(self, msg: Message): rpc.rpc_sync( WORKER_NAME.format(receiver_id), TRPCCOMMServicer.sendMessage, args=(self.process_id, msg), ) - MLOpsProfilerEvent.log_to_wandb({"Comm/send_delay": time.time() - tick}) + MLOpsProfilerEvent.log_to_wandb( + {"Comm/send_delay": time.time() - tick}) logging.debug("sent") def add_observer(self, observer: Observer): + """ + Add an observer to the communication manager. + + Args: + observer (Observer): The observer to be added. + + Returns: + None + """ self._observers.append(observer) def remove_observer(self, observer: Observer): + """ + Remove an observer from the communication manager. + + Args: + observer (Observer): The observer to be removed. + + Returns: + None + """ self._observers.remove(observer) def handle_receive_message(self): + """ + Handle receiving messages in a separate thread. + + Returns: + None + """ thread = threading.Thread(target=self.message_handling_subroutine) thread.start() self._notify_connection_ready() def message_handling_subroutine(self): + """ + Subroutine for handling received messages. + + Returns: + None + """ start_listening_time = time.time() MLOpsProfilerEvent.log_to_wandb({"ListenStart": start_listening_time}) while self.is_running: @@ -105,21 +172,44 @@ def message_handling_subroutine(self): message_handler_start_time = time.time() msg = self.trpc_servicer.message_q.get() self.notify(msg) - MLOpsProfilerEvent.log_to_wandb({"BusyTime": time.time() - message_handler_start_time}) + MLOpsProfilerEvent.log_to_wandb( + {"BusyTime": time.time() - message_handler_start_time}) lock.release() - MLOpsProfilerEvent.log_to_wandb({"TotalTime": time.time() - start_listening_time}) + MLOpsProfilerEvent.log_to_wandb( + {"TotalTime": time.time() - start_listening_time}) return def stop_receive_message(self): + """ + Stop receiving messages and shutdown the communication manager. + + Returns: + None + """ rpc.shutdown() self.is_running = False def notify(self, message: Message): + """ + Notify observers about a received message. + + Args: + message (Message): The received message. + + Returns: + None + """ msg_type = message.get_type() for observer in self._observers: observer.receive_message(msg_type, message) def _notify_connection_ready(self): + """ + Notify observers that the connection is ready. + + Returns: + None + """ msg_params = Message() msg_params.sender_id = self.rank msg_params.receiver_id = self.rank diff --git a/python/fedml/core/distributed/communication/trpc/trpc_server.py b/python/fedml/core/distributed/communication/trpc/trpc_server.py index 96d0969a38..ef41649f18 100644 --- a/python/fedml/core/distributed/communication/trpc/trpc_server.py +++ b/python/fedml/core/distributed/communication/trpc/trpc_server.py @@ -9,6 +9,18 @@ class TRPCCOMMServicer: _instance = None def __new__(cls, master_address, master_port, client_num, client_id): + """ + Create a new instance of the TRPCCOMMServicer class if it does not exist, otherwise return the existing instance. + + Args: + master_address (str): The address of the RPC master. + master_port (str): The port of the RPC master. + client_num (int): The total number of clients. + client_id (int): The ID of the current client. + + Returns: + TRPCCOMMServicer: An instance of the TRPCCOMMServicer class. 
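+
+        Note:
+            This class is a singleton; subsequent calls return the instance
+            created first (stored in ``cls._instance``).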
+ """ cls.master_address = None cls.master_port = None cls.client_num = None @@ -31,6 +43,16 @@ def __new__(cls, master_address, master_port, client_num, client_id): return cls._instance def receiveMessage(self, client_id, message): + """ + Receive a message from another client. + + Args: + client_id (int): The ID of the client sending the message. + message (Message): The received message. + + Returns: + str: A response indicating that the message was received. + """ logging.info( "client_{} got something from client_{}".format( self.client_id, @@ -51,4 +73,14 @@ def receiveMessage(self, client_id, message): @classmethod def sendMessage(cls, clint_id, message): + """ + Send a message to another client. + + Args: + clint_id (int): The ID of the target client. + message (Message): The message to be sent. + + Returns: + None + """ cls._instance.receiveMessage(clint_id, message) \ No newline at end of file diff --git a/python/fedml/core/distributed/communication/trpc/utils.py b/python/fedml/core/distributed/communication/trpc/utils.py index 636750edf2..f8c0b86c85 100644 --- a/python/fedml/core/distributed/communication/trpc/utils.py +++ b/python/fedml/core/distributed/communication/trpc/utils.py @@ -5,8 +5,26 @@ # Generate Device Map for Cuda RPC def set_device_map(options, worker_idx, device_list): + """ + Set the device mapping for PyTorch RPC communication between workers. + + Args: + options (rpc.TensorPipeRpcBackendOptions): The RPC backend options to configure. + worker_idx (int): The index of the current worker. + device_list (list of str): A list of device identifiers for all workers. + + Example: + Suppose you have two workers with GPUs, and `device_list` is ['cuda:0', 'cuda:1']. + If `worker_idx` is 0, this function will set the device mapping for worker 0 as follows: + {WORKER_NAME.format(1): 'cuda:1'} to communicate with worker 1 using 'cuda:1'. + + Returns: + None + """ local_device = device_list[worker_idx] for index, remote_device in enumerate(device_list): - logging.warn(f"Setting device map for client {index} as {remote_device}") + logging.warn( + f"Setting device map for client {index} as {remote_device}") if index != worker_idx: - options.set_device_map(WORKER_NAME.format(index), {local_device: remote_device}) \ No newline at end of file + options.set_device_map(WORKER_NAME.format( + index), {local_device: remote_device}) diff --git a/python/fedml/core/distributed/communication/utils.py b/python/fedml/core/distributed/communication/utils.py index 8bc610309b..521b5cff1d 100755 --- a/python/fedml/core/distributed/communication/utils.py +++ b/python/fedml/core/distributed/communication/utils.py @@ -3,6 +3,14 @@ def log_communication_tick(sender, receiver, timestamp=None): + """ + Log a benchmark tick event from sender to receiver. + + Args: + sender (str): Sender's identifier. + receiver (str): Receiver's identifier. + timestamp (float): Timestamp for the event (default is current time). + """ logging.info( "--Benchmark tick from {} to {} at {}".format( sender, receiver, timestamp or time() @@ -11,6 +19,14 @@ def log_communication_tick(sender, receiver, timestamp=None): def log_communication_tock(sender, receiver, timestamp=None): + """ + Log a benchmark tock event from sender to receiver. + + Args: + sender (str): Sender's identifier. + receiver (str): Receiver's identifier. + timestamp (float): Timestamp for the event (default is current time). 
+ """ logging.info( "--Benchmark tock from {} to {} at {}".format( sender, receiver, timestamp or time() @@ -19,6 +35,14 @@ def log_communication_tock(sender, receiver, timestamp=None): def log_round_start(client_idx, round_number, timestamp=None): + """ + Log the start of a benchmark round for a client. + + Args: + client_idx (int): Client's index or identifier. + round_number (int): Round number. + timestamp (float): Timestamp for the event (default is current time). + """ logging.info( "--Benchmark start round {} for {} at {}".format( round_number, client_idx, timestamp or time() @@ -27,6 +51,14 @@ def log_round_start(client_idx, round_number, timestamp=None): def log_round_end(client_idx, round_number, timestamp=None): + """ + Log the end of a benchmark round for a client. + + Args: + client_idx (int): Client's index or identifier. + round_number (int): Round number. + timestamp (float): Timestamp for the event (default is current time). + """ logging.info( "--Benchmark end round {} for {} at {}".format( round_number, client_idx, timestamp or time() diff --git a/python/fedml/core/distributed/crypto/crypto_api.py b/python/fedml/core/distributed/crypto/crypto_api.py index e7283aca57..284a39921b 100644 --- a/python/fedml/core/distributed/crypto/crypto_api.py +++ b/python/fedml/core/distributed/crypto/crypto_api.py @@ -6,31 +6,42 @@ def export_public_key(private_key_hex: str) -> bytes: - """Export public key for contract join request. + """ + Export the public key for a contract join request. Args: - private_key: hex string representing private key + private_key_hex (str): Hex string representing the private key. Returns: - 32 bytes representing public key + bytes: 32 bytes representing the public key. """ def _hex_to_bytes(hex: str) -> bytes: + """ + Convert a hex string to bytes. + + Args: + hex (str): Hex string. + + Returns: + bytes: Bytes representation of the hex string. + """ return bytes.fromhex(hex[2:] if hex[:2] == "0x" else hex) return bytes(PrivateKey(_hex_to_bytes(private_key_hex)).public_key) def encrypt_nacl(public_key: bytes, data: bytes) -> bytes: - """Encryption function using NaCl box compatible with MetaMask + """ + Encrypt data using NaCl box compatible with MetaMask. For implementation used in MetaMask look into: https://github.com/MetaMask/eth-sig-util Args: - public_key: public key of recipient - data: message data + public_key (bytes): Public key of the recipient. + data (bytes): Message data to be encrypted. Returns: - encrypted data + bytes: Encrypted data. """ emph_key = PrivateKey.generate() enc_box = Box(emph_key, PublicKey(public_key)) @@ -57,7 +68,17 @@ def decrypt_nacl(private_key: bytes, data: bytes) -> bytes: def get_current_secret(secret: bytes, entry_key_turn: int, key_turn: int) -> bytes: - """Calculate shared secret at current state.""" + """ + Calculate the shared secret at the current state. + + Args: + secret (bytes): Initial secret. + entry_key_turn (int): Entry key turn. + key_turn (int): Key turn. + + Returns: + bytes: The calculated shared secret. 
+ """ for _ in range(entry_key_turn, key_turn): secret = hashlib.sha256(secret).digest() return secret diff --git a/python/fedml/core/distributed/distributed_storage/theta_storage/theta_storage.py b/python/fedml/core/distributed/distributed_storage/theta_storage/theta_storage.py index 2add92cd7c..ff894a9e9e 100644 --- a/python/fedml/core/distributed/distributed_storage/theta_storage/theta_storage.py +++ b/python/fedml/core/distributed/distributed_storage/theta_storage/theta_storage.py @@ -14,19 +14,45 @@ class ThetaStorage: - def __init__( - self, thetasotre_config): + def __init__(self, thetasotre_config): + """ + Initialize a ThetaStorage instance. + + Args: + thetasotre_config (dict): Configuration parameters for ThetaStore. + + Attributes: + ipfs_config (dict): ThetaStore configuration dictionary. + store_home_dir (str): Home directory for ThetaStore. + ipfs_upload_uri (str): URI for uploading files to ThetaStore. + ipfs_download_uri (str): URI for downloading files from ThetaStore. + + """ self.ipfs_config = thetasotre_config - self.store_home_dir = thetasotre_config.get("store_home_dir", "~/edge-store-playground") + self.store_home_dir = thetasotre_config.get( + "store_home_dir", "~/edge-store-playground") if str(self.store_home_dir).startswith("~"): home_dir = expanduser("~") - new_store_dir = str(self.store_home_dir).replace('\\', os.sep).replace('/', os.sep) + new_store_dir = str(self.store_home_dir).replace( + '\\', os.sep).replace('/', os.sep) strip_dir = new_store_dir.lstrip('~').lstrip(os.sep) self.store_home_dir = os.path.join(home_dir, strip_dir) - self.ipfs_upload_uri = thetasotre_config.get("upload_uri", "http://localhost:19888/rpc") - self.ipfs_download_uri = thetasotre_config.get("download_uri", "http://localhost:19888/rpc") + self.ipfs_upload_uri = thetasotre_config.get( + "upload_uri", "http://localhost:19888/rpc") + self.ipfs_download_uri = thetasotre_config.get( + "download_uri", "http://localhost:19888/rpc") def write_model(self, model): + """ + Serialize and upload a machine learning model to ThetaStore. + + Args: + model: The machine learning model to be uploaded. + + Returns: + str: The IPFS key where the model is stored. + + """ pickle_dump_start_time = time.time() model_pkl = pickle.dumps(model) secret_key = Context().get("ipfs_secret_key") @@ -43,7 +69,17 @@ def write_model(self, model): ) return model_url - def read_model(self, message_key): + def read_model(self, message_key): + """ + Download and deserialize a machine learning model from ThetaStore. + + Args: + message_key: The ThetaStore key of the model to be retrieved. + + Returns: + model: The deserialized machine learning model. + + """ message_handler_start_time = time.time() model_pkl, _ = self.storage_ipfs_download_file(message_key) secret_key = Context().get("ipfs_secret_key") @@ -61,13 +97,15 @@ def read_model(self, message_key): return model def storage_ipfs_upload_file(self, file_obj): - """Upload file to IPFS using web3.storage. + """ + Upload a file to ThetaStore using Theta's RPC. Args: - file_obj: file-like object in byte mode + file_obj: A file-like object in byte mode. Returns: - Response: (Successful, cid or error message) + tuple: A tuple containing a boolean indicating success, and either the ThetaStore key or an error message. 
+ """ # Request: upload a file # curl -X POST -H 'Content-Type: application/json' --data '{"jsonrpc":"2.0","method":"edgestore.PutFile","params":[{"path": "theta-edge-store-demos/demos/image/data/smiley_explorer.png"}],"id":1}' http://localhost:19888/rpc @@ -89,10 +127,10 @@ def storage_ipfs_upload_file(self, file_obj): with open(file_path, "wb") as file_handle: file_handle.write(file_obj) - request_data = {"jsonrpc":"2.0", - "method":"edgestore.PutFile", - "params":[{"path": file_path}], - "id":1} + request_data = {"jsonrpc": "2.0", + "method": "edgestore.PutFile", + "params": [{"path": file_path}], + "id": 1} res = httpx.post( self.ipfs_upload_uri, headers={"Content-Type": "application/json"}, @@ -133,10 +171,10 @@ def storage_ipfs_download_file(self, ipfs_cid, output_path=None): # } # } - request_data = {"jsonrpc":"2.0", - "method":"edgestore.GetFile", - "params":[{"key": ipfs_cid}], - "id":1} + request_data = {"jsonrpc": "2.0", + "method": "edgestore.GetFile", + "params": [{"key": ipfs_cid}], + "id": 1} res = httpx.post( self.ipfs_download_uri, headers={"Content-Type": "application/json"}, @@ -154,7 +192,8 @@ def storage_ipfs_download_file(self, ipfs_cid, output_path=None): if download_path is None: return False, "Failed to download file(path is none)." else: - download_path = os.path.join(self.store_home_dir, download_path) + download_path = os.path.join( + self.store_home_dir, download_path) output_file_obj = None file_content = None diff --git a/python/fedml/core/distributed/distributed_storage/web3_storage/web3_storage.py b/python/fedml/core/distributed/distributed_storage/web3_storage/web3_storage.py index f5f5b1a299..a86d622fd6 100644 --- a/python/fedml/core/distributed/distributed_storage/web3_storage/web3_storage.py +++ b/python/fedml/core/distributed/distributed_storage/web3_storage/web3_storage.py @@ -10,13 +10,34 @@ class Web3Storage: - def __init__( - self, ipfs_config): + def __init__(self, ipfs_config): + """ + Initialize a Web3Storage instance. + + Args: + ipfs_config (dict): Configuration parameters for IPFS. + + Attributes: + ipfs_config (dict): IPFS configuration dictionary. + ipfs_upload_uri (str): URI for uploading files to IPFS. + ipfs_download_uri (str): URI for downloading files from IPFS. + + """ self.ipfs_config = ipfs_config self.ipfs_upload_uri = ipfs_config.get("upload_uri", "https://api.web3.storage/upload") self.ipfs_download_uri = ipfs_config.get("download_uri", "ipfs.w3s.link2") def write_model(self, model): + """ + Serialize and upload a machine learning model to IPFS. + + Args: + model: The machine learning model to be uploaded. + + Returns: + str: The IPFS URL where the model is stored. + + """ pickle_dump_start_time = time.time() model_pkl = pickle.dumps(model) secret_key = Context().get("ipfs_secret_key") @@ -34,6 +55,16 @@ def write_model(self, model): return model_url def read_model(self, message_key): + """ + Download and deserialize a machine learning model from IPFS. + + Args: + message_key: The IPFS key of the model to be retrieved. + + Returns: + model: The deserialized machine learning model. 
+ + """ message_handler_start_time = time.time() model_pkl, _ = self.storage_ipfs_download_file(message_key) secret_key = Context().get("ipfs_secret_key") diff --git a/python/fedml/core/distributed/fedml_comm_manager.py b/python/fedml/core/distributed/fedml_comm_manager.py index 5959f175ac..9a40f7398b 100644 --- a/python/fedml/core/distributed/fedml_comm_manager.py +++ b/python/fedml/core/distributed/fedml_comm_manager.py @@ -9,7 +9,55 @@ class FedMLCommManager(Observer): + """ + Communication manager for Federated Machine Learning (FedML). + + Args: + args: Command-line arguments. + comm: The communication backend. + rank: The rank of the current node. + size: The total number of nodes in the communication group. + backend: The communication backend used (e.g., "MPI", "MQTT", "MQTT_S3"). + + Attributes: + args: Command-line arguments. + size: The total number of nodes in the communication group. + rank: The rank of the current node. + backend: The communication backend used. + comm: The communication object. + com_manager: The communication manager. + message_handler_dict: A dictionary to register message handlers. + + Methods: + register_comm_manager(comm_manager): Register a communication manager. + run(): Start the communication manager. + get_sender_id(): Get the sender's ID. + receive_message(msg_type, msg_params): Receive a message and handle it. + send_message(message): Send a message. + send_message_json(topic_name, json_message): Send a JSON message. + register_message_receive_handlers(): Register message receive handlers. + register_message_receive_handler(msg_type, handler_callback_func): Register a message receive handler. + finish(): Finish the communication manager. + get_training_mqtt_s3_config(): Get MQTT and S3 configurations for training. + get_training_mqtt_web3_config(): Get MQTT and Web3 configurations for training. + get_training_mqtt_thetastore_config(): Get MQTT and Thetastore configurations for training. + _init_manager(): Initialize the communication manager based on the selected backend. + """ + def __init__(self, args, comm=None, rank=0, size=0, backend="MPI"): + """ + Initialize the FedMLCommManager. + + Args: + args: Command-line arguments. + comm: The communication backend. + rank: The rank of the current node. + size: The total number of nodes in the communication group. + backend: The communication backend used (e.g., "MPI", "MQTT", "MQTT_S3"). + + Returns: + None + """ self.args = args self.size = size self.rank = int(rank) @@ -20,21 +68,54 @@ def __init__(self, args, comm=None, rank=0, size=0, backend="MPI"): self._init_manager() def register_comm_manager(self, comm_manager: BaseCommunicationManager): + """ + Register a communication manager. + + Args: + comm_manager (BaseCommunicationManager): The communication manager to register. + + Returns: + None + """ self.com_manager = comm_manager def run(self): + """ + Start the communication manager. + + Returns: + None + """ self.register_message_receive_handlers() logging.info("running") self.com_manager.handle_receive_message() logging.info("finished...") def get_sender_id(self): + """ + Get the sender's ID. + + Returns: + int: The sender's ID (rank). + + """ return self.rank def receive_message(self, msg_type, msg_params) -> None: + """ + Receive a message and handle it. + + Args: + msg_type (str): The type of the received message. + msg_params: Parameters associated with the received message. 
+ + Returns: + None + """ if msg_params.get_sender_id() == msg_params.get_receiver_id(): - logging.info("communication backend is alive (loop_forever, sender 0 to receiver 0)") + logging.info( + "communication backend is alive (loop_forever, sender 0 to receiver 0)") else: logging.info( "receive_message. msg_type = %s, sender_id = %d, receiver_id = %d" @@ -51,19 +132,64 @@ def receive_message(self, msg_type, msg_params) -> None: ) def send_message(self, message): + """ + Send a message. + + Args: + message: The message to send. + + Returns: + None + """ self.com_manager.send_message(message) def send_message_json(self, topic_name, json_message): + """ + Send a JSON message. + + Args: + topic_name (str): The name of the message topic. + json_message: The JSON message to send. + + Returns: + None + """ self.com_manager.send_message_json(topic_name, json_message) @abstractmethod def register_message_receive_handlers(self) -> None: + """ + Register message receive handlers. + + This method should be implemented in derived classes. + + Returns: + None + """ pass def register_message_receive_handler(self, msg_type, handler_callback_func): + """ + Register a message receive handler. + + Args: + msg_type (str): The type of the message to handle. + handler_callback_func: The callback function to handle the message. + + Returns: + None + """ self.message_handler_dict[msg_type] = handler_callback_func def finish(self): + """ + Finish the communication manager. + + Depending on the backend used, this method may perform specific actions to terminate the communication. + + Returns: + None + """ logging.info("__finish") if self.backend == "MPI": from mpi4py import MPI @@ -81,6 +207,13 @@ def finish(self): self.com_manager.stop_receive_message() def get_training_mqtt_s3_config(self): + """ + Get MQTT and S3 configurations for training. + + Returns: + tuple: A tuple containing MQTT configuration and S3 configuration. 
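+
+        Example (illustrative; ``manager`` is a hypothetical FedMLCommManager
+        subclass instance whose args may carry customized configs):
+
+        >>> mqtt_config, s3_config = manager.get_training_mqtt_s3_config()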
+ + """ mqtt_config = None s3_config = None if hasattr(self.args, "customized_training_mqtt_config") and self.args.customized_training_mqtt_config != "": @@ -88,7 +221,8 @@ def get_training_mqtt_s3_config(self): if hasattr(self.args, "customized_training_s3_config") and self.args.customized_training_s3_config != "": s3_config = self.args.customized_training_s3_config if mqtt_config is None or s3_config is None: - mqtt_config_from_cloud, s3_config_from_cloud = MLOpsConfigs.get_instance(self.args).fetch_configs() + mqtt_config_from_cloud, s3_config_from_cloud = MLOpsConfigs.get_instance( + self.args).fetch_configs() if mqtt_config is None: mqtt_config = mqtt_config_from_cloud if s3_config is None: @@ -104,7 +238,8 @@ def get_training_mqtt_web3_config(self): if hasattr(self.args, "customized_training_web3_config") and self.args.customized_training_web3_config != "": web3_config = self.args.customized_training_web3_config if mqtt_config is None or web3_config is None: - mqtt_config_from_cloud, web3_config_from_cloud = MLOpsConfigs.get_instance(self.args).fetch_web3_configs() + mqtt_config_from_cloud, web3_config_from_cloud = MLOpsConfigs.get_instance( + self.args).fetch_web3_configs() if mqtt_config is None: mqtt_config = mqtt_config_from_cloud if web3_config is None: @@ -120,7 +255,8 @@ def get_training_mqtt_thetastore_config(self): if hasattr(self.args, "customized_training_thetastore_config") and self.args.customized_training_thetastore_config != "": thetastore_config = self.args.customized_training_thetastore_config if mqtt_config is None or thetastore_config is None: - mqtt_config_from_cloud, thetastore_config_from_cloud = MLOpsConfigs.get_instance(self.args).fetch_thetastore_configs() + mqtt_config_from_cloud, thetastore_config_from_cloud = MLOpsConfigs.get_instance( + self.args).fetch_thetastore_configs() if mqtt_config is None: mqtt_config = mqtt_config_from_cloud if thetastore_config is None: @@ -133,7 +269,8 @@ def _init_manager(self): if self.backend == "MPI": from .communication.mpi.com_manager import MpiCommunicationManager - self.com_manager = MpiCommunicationManager(self.comm, self.rank, self.size) + self.com_manager = MpiCommunicationManager( + self.comm, self.rank, self.size) elif self.backend == "MQTT_S3": from .communication.mqtt_s3.mqtt_s3_multi_clients_comm_manager import MqttS3MultiClientsCommManager @@ -202,7 +339,8 @@ def _init_manager(self): ) else: if self.com_manager is None: - raise Exception("no such backend: {}. Please check the comm_backend spelling.".format(self.backend)) + raise Exception( + "no such backend: {}. Please check the comm_backend spelling.".format(self.backend)) else: logging.info("using self-defined communication backend") diff --git a/python/fedml/core/distributed/flow/fedml_executor.py b/python/fedml/core/distributed/flow/fedml_executor.py index 36b44cb4e2..9b7f45c34b 100644 --- a/python/fedml/core/distributed/flow/fedml_executor.py +++ b/python/fedml/core/distributed/flow/fedml_executor.py @@ -2,32 +2,128 @@ class FedMLExecutor(abc.ABC): + """ + Abstract base class for Federated Machine Learning Executors. + + This class defines the basic structure and methods for a FedML executor. + + Args: + id (str): Identifier for the executor. + neighbor_id_list (List[str]): List of neighbor executor IDs. + + Attributes: + id (str): Identifier for the executor. + neighbor_id_list (List[str]): List of neighbor executor IDs. + params (Any): Parameters associated with the executor. + context (Any): Context or environment information. 
+ + Methods: + get_context() -> Any: + Get the context or environment information associated with the executor. + + set_context(context: Any) -> None: + Set the context or environment information for the executor. + + get_params() -> Any: + Get the parameters associated with the executor. + + set_params(params: Any) -> None: + Set the parameters for the executor. + + set_id(id: str) -> None: + Set the identifier for the executor. + + set_neighbor_id_list(neighbor_id_list: List[str]) -> None: + Set the list of neighbor executor IDs. + + get_id() -> str: + Get the identifier of the executor. + + get_neighbor_id_list() -> List[str]: + Get the list of neighbor executor IDs. + """ + def __init__(self, id, neighbor_id_list): + """ + Initialize a FedMLExecutor. + + Args: + id (str): Identifier for the executor. + neighbor_id_list (List[str]): List of neighbor executor IDs. + """ self.id = id self.neighbor_id_list = neighbor_id_list self.params = None self.context = None - def get_context(self): + def get_context(self) -> Any: + """ + Get the context or environment information associated with the executor. + + Returns: + Any: The context or environment information. + """ return self.context - def set_context(self, context): + def set_context(self, context: Any) -> None: + """ + Set the context or environment information for the executor. + + Args: + context (Any): The context or environment information. + """ self.context = context - def get_params(self): + def get_params(self) -> Any: + """ + Get the parameters associated with the executor. + + Returns: + Any: The parameters. + """ return self.params - def set_params(self, params): + def set_params(self, params: Any) -> None: + """ + Set the parameters for the executor. + + Args: + params (Any): The parameters. + """ self.params = params - def set_id(self, id): + def set_id(self, id: str) -> None: + """ + Set the identifier for the executor. + + Args: + id (str): The identifier. + """ self.id = id - def set_neighbor_id_list(self, neighbor_id_list): + def set_neighbor_id_list(self, neighbor_id_list: List[str]) -> None: + """ + Set the list of neighbor executor IDs. + + Args: + neighbor_id_list (List[str]): List of neighbor executor IDs. + """ self.neighbor_id_list = neighbor_id_list - def get_id(self): + def get_id(self) -> str: + """ + Get the identifier of the executor. + + Returns: + str: The identifier. + """ return self.id - def get_neighbor_id_list(self): + def get_neighbor_id_list(self) -> List[str]: + """ + Get the list of neighbor executor IDs. + + Returns: + List[str]: List of neighbor executor IDs. + """ return self.neighbor_id_list diff --git a/python/fedml/core/distributed/flow/fedml_flow.py b/python/fedml/core/distributed/flow/fedml_flow.py index 0bf7dabb5f..1ab2ba5e5f 100644 --- a/python/fedml/core/distributed/flow/fedml_flow.py +++ b/python/fedml/core/distributed/flow/fedml_flow.py @@ -18,6 +18,46 @@ class FedMLAlgorithmFlow(FedMLCommManager): + """ + Base class for defining the flow of a federated machine learning algorithm. + + Args: + args: Arguments for initializing the algorithm flow. + executor (FedMLExecutor): An instance of a FedMLExecutor class to execute tasks within the flow. + + Attributes: + ONCE (str): Flow tag indicating that the flow should run once. + FINISH (str): Flow tag indicating the end of the flow. + executor (FedMLExecutor): An instance of a FedMLExecutor class. + executor_cls_name (str): Name of the executor class. + flow_index (int): Index to keep track of flow sequences. 
+ flow_sequence_original (list): List to store the original flow sequence. + flow_sequence_current_map (dict): Mapping of current flow sequences. + flow_sequence_next_map (dict): Mapping of next flow sequences. + flow_sequence_executed (list): List to store executed flow sequences. + neighbor_node_online_map (dict): Mapping of neighbor node online status. + is_all_neighbor_connected (bool): Flag to indicate if all neighbor nodes are connected. + + Methods: + register_message_receive_handlers(): Register message receive handlers for different message types. + add_flow(flow_name, executor_task, flow_tag=ONCE): Add a flow to the algorithm. + run(): Start running the algorithm flow. + build(): Build the flow sequence and prepare for execution. + _on_ready_to_run_flow(): Handle when the algorithm is ready to run. + _handle_message_received(msg_params): Handle received messages within the flow. + _execute_flow(flow_params, flow_name, executor_task, executor_task_cls_name, flow_tag): Execute a flow task. + __direct_to_next_flow(flow_name, flow_tag): Get the details of the next flow in the sequence. + _send_msg(flow_name, params): Send a message to other nodes. + _handle_flow_finish(msg_params): Handle the finish of the algorithm flow. + __shutdown(): Shutdown the algorithm flow. + _pass_message_locally(flow_name, params): Pass a message to a locally executed flow. + _handle_connection_ready(msg_params): Handle the readiness of the algorithm to run. + _handle_neighbor_report_node_status(msg_params): Handle neighbor nodes reporting their online status. + _handle_neighbor_check_node_status(msg_params): Handle checking of neighbor node status. + _send_message_to_check_neighbor_node_status(receiver_id): Send a message to check neighbor node status. + _send_message_to_report_node_status(receiver_id): Send a message to report node status. + _get_class_that_defined_method(meth): Get the class that defined a method. + """ ONCE = "FLOW_TAG_ONCE" FINISH = "FLOW_TAG_FINISH" @@ -39,6 +79,14 @@ def __init__(self, args, executor: FedMLExecutor): self.is_all_neighbor_connected = False def register_message_receive_handlers(self) -> None: + """ + Register message receive handlers for various message types. + + This method registers message handlers for messages related to the algorithm flow. + + Returns: + None + """ self.register_message_receive_handler(MSG_TYPE_CONNECTION_IS_READY, self._handle_connection_ready) self.register_message_receive_handler( MSG_TYPE_NEIGHBOR_CHECK_NODE_STATUS, self._handle_neighbor_check_node_status, @@ -64,6 +112,17 @@ def register_message_receive_handlers(self) -> None: self.register_message_receive_handler(flow_name, self._handle_message_received) def add_flow(self, flow_name, executor_task: Callable, flow_tag=ONCE): + """ + Add a flow to the algorithm's flow sequence. + + Args: + flow_name (str): Name of the flow. + executor_task (Callable): Callable function representing the task to be executed in the flow. + flow_tag (str): Tag indicating the type of flow (ONCE or FINISH). + + Returns: + None + """ logging.info("flow_name = {}, executor_task = {}".format(flow_name, executor_task)) executor_task_cls_name = self._get_class_that_defined_method(executor_task) @@ -72,9 +131,21 @@ def add_flow(self, flow_name, executor_task: Callable, flow_tag=ONCE): self.flow_index += 1 def run(self): + """ + Start running the algorithm flow. + + Returns: + None + """ super().run() def build(self): + """ + Build the flow sequence and prepare for execution. 
+ + Returns: + None + """ logging.info("self.flow_sequence = {}".format(self.flow_sequence_original)) (flow_name, executor_task, executor_task_cls_name, flow_tag,) = self.flow_sequence_original[ len(self.flow_sequence_original) - 1 @@ -114,6 +185,12 @@ def build(self): logging.info("self.flow_sequence_next_map = {}".format(self.flow_sequence_next_map)) def _on_ready_to_run_flow(self): + """ + Handle when the algorithm is ready to run. + + Returns: + None + """ logging.info("#######_on_ready_to_run_flow#######") ( flow_name_current, @@ -127,6 +204,15 @@ def _on_ready_to_run_flow(self): ) def _handle_message_received(self, msg_params): + """ + Handle received messages within the flow. + + Args: + msg_params (Params): Parameters received in the message. + + Returns: + None + """ flow_name = msg_params.get_type() flow_params = Params() @@ -141,6 +227,19 @@ def _handle_message_received(self, msg_params): self._execute_flow(flow_params, flow_name_next, executor_task_next, executor_task_cls_name_next, flow_tag_next) def _execute_flow(self, flow_params, flow_name, executor_task, executor_task_cls_name, flow_tag): + """ + Execute a flow task. + + Args: + flow_params (Params): Parameters for the flow. + flow_name (str): Name of the flow. + executor_task (Callable): Callable function representing the task to be executed. + executor_task_cls_name (str): Name of the executor task's class. + flow_tag (str): Tag indicating the type of flow (ONCE or FINISH). + + Returns: + None + """ logging.info( "\n\n###########_execute_flow (START). flow_name = {}, executor_task name = {}() #######".format( flow_name, executor_task.__name__ @@ -183,6 +282,16 @@ def _execute_flow(self, flow_params, flow_name, executor_task, executor_task_cls self._send_msg(flow_name, params) def __direct_to_next_flow(self, flow_name, flow_tag): + """ + Determine the next flow to execute based on the current flow. + + Args: + flow_name (str): Name of the current flow. + flow_tag (str): Tag indicating the type of flow (ONCE or FINISH). + + Returns: + Tuple: A tuple containing the name, executor task, executor task class name, and flow tag of the next flow. + """ ( flow_name_next, executor_task_next, @@ -197,6 +306,16 @@ def __direct_to_next_flow(self, flow_name, flow_tag): ) def _send_msg(self, flow_name, params: Params): + """ + Send a message to one or more receivers. + + Args: + flow_name (str): Name of the flow. + params (Params): Parameters to be included in the message. + + Returns: + None + """ sender_id = params.get(PARAMS_KEY_SENDER_ID) receiver_id = params.get(PARAMS_KEY_RECEIVER_ID) logging.info("sender_id = {}, receiver_id = {}".format(sender_id, receiver_id)) @@ -211,9 +330,24 @@ def _send_msg(self, flow_name, params: Params): self.send_message(message) def _handle_flow_finish(self, msg_params): + """ + Handle the completion of the algorithm flow. + + Args: + msg_params (Params): Parameters received in the completion message. + + Returns: + None + """ self.__shutdown() def __shutdown(self): + """ + Shutdown the algorithm flow and terminate communication. + + Returns: + None + """ for rid in self.executor.get_neighbor_id_list(): message = Message(MSG_TYPE_FLOW_FINISH, self.executor.get_id(), rid) self.send_message(message) @@ -221,6 +355,16 @@ def __shutdown(self): self.finish() def _pass_message_locally(self, flow_name, params: Params): + """ + Pass a message locally to be handled within the algorithm. + + Args: + flow_name (str): Name of the flow. + params (Params): Parameters to be included in the message. 
+ + Returns: + None + """ sender_id = params.get(PARAMS_KEY_SENDER_ID) receiver_id = params.get(PARAMS_KEY_RECEIVER_ID) logging.info("sender_id = {}, receiver_id = {}".format(sender_id, receiver_id)) @@ -235,6 +379,15 @@ def _pass_message_locally(self, flow_name, params: Params): self._handle_message_received(message) def _handle_connection_ready(self, msg_params): + """ + Handle the readiness of connections with neighbors. + + Args: + msg_params (Params): Parameters received indicating connection readiness. + + Returns: + None + """ if self.is_all_neighbor_connected: return logging.info("_handle_connection_ready") @@ -243,6 +396,15 @@ def _handle_connection_ready(self, msg_params): self._send_message_to_report_node_status(receiver_id) def _handle_neighbor_report_node_status(self, msg_params): + """ + Handle the reporting of neighbor node statuses. + + Args: + msg_params (Params): Parameters received with neighbor node status information. + + Returns: + None + """ sender_id = msg_params.get_sender_id() logging.info( "_handle_neighbor_report_node_status. node_id = {}, neighbor_id = {} is online".format( @@ -262,10 +424,28 @@ def _handle_neighbor_report_node_status(self, msg_params): self._on_ready_to_run_flow() def _handle_neighbor_check_node_status(self, msg_params): + """ + Handle a message to check the status of a neighbor node. + + Args: + msg_params (Params): Parameters received in the check node status message. + + Returns: + None + """ sender_id = msg_params.get_sender_id() self._send_message_to_report_node_status(sender_id) def _send_message_to_check_neighbor_node_status(self, receiver_id): + """ + Send a message to check the status of a neighbor node. + + Args: + receiver_id (int): ID of the receiver neighbor node. + + Returns: + None + """ message = Message(MSG_TYPE_NEIGHBOR_CHECK_NODE_STATUS, self.executor.get_id(), receiver_id) logging.info( "_send_message_to_check_neighbor_node_status. node_id = {}, neighbor_id = {} is online".format( @@ -275,10 +455,28 @@ def _send_message_to_check_neighbor_node_status(self, receiver_id): self.send_message(message) def _send_message_to_report_node_status(self, receiver_id): + """ + Send a message to report the node status to a neighbor node. + + Args: + receiver_id (int): ID of the receiver neighbor node. + + Returns: + None + """ message = Message(MSG_TYPE_NEIGHBOR_REPORT_NODE_STATUS, self.executor.get_id(), receiver_id) self.send_message(message) def _get_class_that_defined_method(self, meth): + """ + Get the name of the class that defines a method. + + Args: + meth (method/function): The method or function to determine the defining class. + + Returns: + str: The name of the defining class. + """ if inspect.ismethod(meth): for cls in inspect.getmro(meth.__self__.__class__): if cls.__dict__.get(meth.__name__) is meth: diff --git a/python/fedml/core/distributed/flow/test_fedml_flow.py b/python/fedml/core/distributed/flow/test_fedml_flow.py index f1e0ea86de..8e777a24e9 100644 --- a/python/fedml/core/distributed/flow/test_fedml_flow.py +++ b/python/fedml/core/distributed/flow/test_fedml_flow.py @@ -7,6 +7,16 @@ class Client(FedMLExecutor): def __init__(self, args): + """ + Initialize the Client object. + + Args: + args: Command-line arguments or configuration settings. + + Returns: + None + """ + self.args = args id = args.rank neighbor_id_list = [0] @@ -17,17 +27,40 @@ def __init__(self, args): self.model = None def init(self, device, dataset, model): + """ + Initialize the client with device, dataset, and model. 
+ + Args: + device: The device (e.g., CPU or GPU) for training. + dataset: The dataset used for training. + model: The machine learning model used for training. + + Returns: + None + """ self.device = device self.dataset = dataset self.model = model def local_training(self): + """ + Perform local training on the client. + + Returns: + Params: Parameters containing model updates or other relevant information. + """ logging.info("local_training start") params = self.get_params() model_params = params.get(Params.KEY_MODEL_PARAMS) return params def handle_init_global_model(self): + """ + Handle the initialization of the global model on the client. + + Returns: + Params: Parameters containing the model parameters. + """ received_params = self.get_params() model_params = received_params.get(Params.KEY_MODEL_PARAMS) @@ -38,6 +71,15 @@ def handle_init_global_model(self): class Server(FedMLExecutor): def __init__(self, args): + """ + Initialize the Server object. + + Args: + args: Command-line arguments or configuration settings. + + Returns: + None + """ self.args = args id = args.rank neighbor_id_list = [1, 2] @@ -53,17 +95,41 @@ def __init__(self, args): self.client_num = 2 def init(self, device, dataset, model): + """ + Initialize the server with device, dataset, and model. + + Args: + device: The device (e.g., CPU or GPU) for server operations. + dataset: The dataset used for server operations. + model: The machine learning model used for server operations. + + Returns: + None + """ + self.device = device self.dataset = dataset self.model = model def init_global_model(self): + """ + Initialize the global model on the server. + + Returns: + Params: Parameters containing the initial model parameters. + """ logging.info("init_global_model") params = Params() params.add(Params.KEY_MODEL_PARAMS, self.model.state_dict()) return params def server_aggregate(self): + """ + Perform server-side aggregation of client updates. + + Returns: + Params: Parameters containing the aggregated model updates. + """ logging.info("server_aggregate") params = self.get_params() model_params = params.get(Params.KEY_MODEL_PARAMS) @@ -77,6 +143,12 @@ def server_aggregate(self): return params def final_eval(self): + """ + Perform final evaluation or operations on the server. + + Returns: + None + """ logging.info("final_eval") diff --git a/python/fedml/core/distributed/topology/asymmetric_topology_manager.py b/python/fedml/core/distributed/topology/asymmetric_topology_manager.py index c85737608a..2f0abcfab3 100644 --- a/python/fedml/core/distributed/topology/asymmetric_topology_manager.py +++ b/python/fedml/core/distributed/topology/asymmetric_topology_manager.py @@ -15,12 +15,29 @@ class AsymmetricTopologyManager(BaseTopologyManager): """ def __init__(self, n, undirected_neighbor_num=3, out_directed_neighbor=3): + """ + Initialize the AsymmetricTopologyManager. + + Args: + n (int): Number of nodes in the topology. + undirected_neighbor_num (int): Number of undirected (symmetric) neighbors for each node. + out_directed_neighbor (int): Number of out (asymmetric) neighbors for each node. + + Returns: + None + """ self.n = n self.undirected_neighbor_num = undirected_neighbor_num self.out_directed_neighbor = out_directed_neighbor self.topology = [] def generate_topology(self): + """ + Generate the topology based on the specified parameters. 
+ + Returns: + None + """ # randomly add some links for each node (symmetric) k = self.undirected_neighbor_num # print("neighbors = " + str(k)) @@ -81,6 +98,15 @@ def generate_topology(self): self.topology = topology_ring def get_in_neighbor_weights(self, node_index): + """ + Get the weights of incoming neighbors for a given node. + + Args: + node_index (int): Index of the node. + + Returns: + List[float]: List of weights for incoming neighbors. + """ if node_index >= self.n: return [] in_neighbor_weights = [] @@ -89,11 +115,29 @@ def get_in_neighbor_weights(self, node_index): return in_neighbor_weights def get_out_neighbor_weights(self, node_index): + """ + Get the weights of outgoing neighbors for a given node. + + Args: + node_index (int): Index of the node. + + Returns: + List[float]: List of weights for outgoing neighbors. + """ if node_index >= self.n: return [] return self.topology[node_index] def get_in_neighbor_idx_list(self, node_index): + """ + Get the indices of incoming neighbors for a given node. + + Args: + node_index (int): Index of the node. + + Returns: + List[int]: List of indices for incoming neighbors. + """ neighbor_in_idx_list = [] neighbor_weights = self.get_in_neighbor_weights(node_index) for idx, neighbor_w in enumerate(neighbor_weights): @@ -102,6 +146,16 @@ def get_in_neighbor_idx_list(self, node_index): return neighbor_in_idx_list def get_out_neighbor_idx_list(self, node_index): + """ + Get the indices of outgoing neighbors for a given node. + + Args: + node_index (int): Index of the node. + + Returns: + List[int]: List of indices for outgoing neighbors. + """ + neighbor_out_idx_list = [] neighbor_weights = self.get_out_neighbor_weights(node_index) for idx, neighbor_w in enumerate(neighbor_weights): diff --git a/python/fedml/core/distributed/topology/symmetric_topology_manager.py b/python/fedml/core/distributed/topology/symmetric_topology_manager.py index 07d90525e4..9f5326ab08 100644 --- a/python/fedml/core/distributed/topology/symmetric_topology_manager.py +++ b/python/fedml/core/distributed/topology/symmetric_topology_manager.py @@ -14,11 +14,28 @@ class SymmetricTopologyManager(BaseTopologyManager): """ def __init__(self, n, neighbor_num=2): + """ + Initialize the SymmetricTopologyManager. + + Args: + n (int): Number of nodes in the topology. + neighbor_num (int): Number of neighbors for each node. + + Returns: + None + """ self.n = n self.neighbor_num = neighbor_num self.topology = [] def generate_topology(self): + """ + Generate the symmetric topology based on the specified parameters. + + Returns: + None + """ + # first generate a ring topology topology_ring = np.array( nx.to_numpy_matrix(nx.watts_strogatz_graph(self.n, 2, 0)), dtype=np.float32 @@ -56,16 +73,43 @@ def generate_topology(self): self.topology = topology_symmetric def get_in_neighbor_weights(self, node_index): + """ + Get the weights of incoming neighbors for a given node. + + Args: + node_index (int): Index of the node. + + Returns: + List[float]: List of weights for incoming neighbors. + """ if node_index >= self.n: return [] return self.topology[node_index] def get_out_neighbor_weights(self, node_index): + """ + Get the weights of outgoing neighbors for a given node. + + Args: + node_index (int): Index of the node. + + Returns: + List[float]: List of weights for outgoing neighbors. + """ if node_index >= self.n: return [] return self.topology[node_index] def get_in_neighbor_idx_list(self, node_index): + """ + Get the indices of incoming neighbors for a given node. 
+ + Args: + node_index (int): Index of the node. + + Returns: + List[int]: List of indices for incoming neighbors. + """ neighbor_in_idx_list = [] neighbor_weights = self.get_in_neighbor_weights(node_index) for idx, neighbor_w in enumerate(neighbor_weights): @@ -74,6 +118,15 @@ def get_in_neighbor_idx_list(self, node_index): return neighbor_in_idx_list def get_out_neighbor_idx_list(self, node_index): + """ + Get the indices of outgoing neighbors for a given node. + + Args: + node_index (int): Index of the node. + + Returns: + List[int]: List of indices for outgoing neighbors. + """ neighbor_out_idx_list = [] neighbor_weights = self.get_out_neighbor_weights(node_index) for idx, neighbor_w in enumerate(neighbor_weights): From c80d0d89ce22c12efe84b06ad71a87ec7b0c4622 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Sun, 24 Sep 2023 09:30:36 +0530 Subject: [PATCH 33/70] thread --- .../communication/grpc/grpc_comm_manager.py | 133 ++++++-- .../grpc/grpc_comm_manager_pb2_grpc.py | 90 +++++- .../communication/mpi/com_manager.py | 64 +++- .../communication/mpi/mpi_receive_thread.py | 43 ++- .../communication/mpi/mpi_send_thread.py | 45 ++- .../communication/mqtt/mqtt_manager.py | 36 +++ .../mqtt_s3_multi_clients_comm_manager.py | 288 ++++++++++++++++-- .../mqtt_s3_mnn/mqtt_s3_comm_manager.py | 252 +++++++++++++-- 8 files changed, 869 insertions(+), 82 deletions(-) diff --git a/python/fedml/core/distributed/communication/grpc/grpc_comm_manager.py b/python/fedml/core/distributed/communication/grpc/grpc_comm_manager.py index 6eb9fe613e..3931120f49 100644 --- a/python/fedml/core/distributed/communication/grpc/grpc_comm_manager.py +++ b/python/fedml/core/distributed/communication/grpc/grpc_comm_manager.py @@ -1,3 +1,12 @@ +import csv +import logging +from ...communication.grpc.grpc_server import GRPCCOMMServicer +import time +from fedml.core.mlops.mlops_profiler_event import MLOpsProfilerEvent +from ..constants import CommunicationConstants +from ...communication.observer import Observer +from ...communication.message import Message +from ...communication.base_com_manager import BaseCommunicationManager import os import pickle import threading @@ -10,21 +19,8 @@ lock = threading.Lock() -from ...communication.base_com_manager import BaseCommunicationManager -from ...communication.message import Message -from ...communication.observer import Observer -from ..constants import CommunicationConstants - -from fedml.core.mlops.mlops_profiler_event import MLOpsProfilerEvent - -import time # Check Service or serve? -from ...communication.grpc.grpc_server import GRPCCOMMServicer - -import logging - -import csv class GRPCCommManager(BaseCommunicationManager): @@ -37,6 +33,17 @@ def __init__( client_id=0, client_num=0, ): + """ + Initialize the GRPCCommManager. + + Args: + host (str): The IP address of the server. + port (int): The port number to listen on. + ip_config_path (str): The path to the IP configuration file. + topic (str, optional): The communication topic. Default is "fedml". + client_id (int, optional): The client's ID. Default is 0. + client_num (int, optional): The number of clients. Default is 0. 
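+
+        Example (minimal sketch; the port and CSV path are hypothetical, and
+        the CSV maps receiver IDs to IP addresses):
+
+        >>> comm = GRPCCommManager(
+        ...     "127.0.0.1", 8890, "grpc_ipconfig.csv",
+        ...     topic="fedml", client_id=1, client_num=2)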
+ """ # host is the ip address of server self.host = host self.port = str(port) @@ -61,7 +68,8 @@ def __init__( futures.ThreadPoolExecutor(max_workers=client_num), options=self.opts, ) - self.grpc_servicer = GRPCCOMMServicer(host, port, client_num, client_id) + self.grpc_servicer = GRPCCOMMServicer( + host, port, client_num, client_id) grpc_comm_manager_pb2_grpc.add_gRPCCommManagerServicer_to_server( self.grpc_servicer, self.grpc_server ) @@ -76,13 +84,23 @@ def __init__( logging.info("grpc server started. Listening on port " + str(port)) def send_message(self, msg: Message): + """ + Send a message using gRPC to a specified receiver. + + Args: + msg (Message): The message to send. + + Returns: + None + """ logging.info("msg = {}".format(msg)) # payload = msg.to_json() logging.info("pickle.dumps(msg) START") pickle_dump_start_time = time.time() msg_pkl = pickle.dumps(msg) - MLOpsProfilerEvent.log_to_wandb({"PickleDumpsTime": time.time() - pickle_dump_start_time}) + MLOpsProfilerEvent.log_to_wandb( + {"PickleDumpsTime": time.time() - pickle_dump_start_time}) logging.info("pickle.dumps(msg) END") receiver_id = msg.get_receiver_id() @@ -103,27 +121,62 @@ def send_message(self, msg: Message): tick = time.time() stub.sendMessage(request) - MLOpsProfilerEvent.log_to_wandb({"Comm/send_delay": time.time() - tick}) + MLOpsProfilerEvent.log_to_wandb( + {"Comm/send_delay": time.time() - tick}) logging.debug("sent successfully") channel.close() def add_observer(self, observer: Observer): + """ + Add an observer to the communication manager. + + Args: + observer (Observer): The observer to add. + + Returns: + None + """ self._observers.append(observer) def remove_observer(self, observer: Observer): + """ + Remove an observer from the communication manager. + + Args: + observer (Observer): The observer to remove. + + Returns: + None + """ self._observers.remove(observer) def handle_receive_message(self): + """ + Start handling received messages. + + This method initiates the process of receiving and handling messages. + + Returns: + None + """ self._notify_connection_ready() self.message_handling_subroutine() # Cannont run message_handling_subroutine in new thread # Related https://stackoverflow.com/a/70705165 - + # thread = threading.Thread(target=self.message_handling_subroutine) # thread.start() def message_handling_subroutine(self): + """ + Message handling subroutine. + + This method continuously processes received messages. 
+ + Returns: + None + """ start_listening_time = time.time() MLOpsProfilerEvent.log_to_wandb({"ListenStart": start_listening_time}) while self.is_running: @@ -134,29 +187,58 @@ def message_handling_subroutine(self): logging.info("unpickle START") unpickle_start_time = time.time() msg = pickle.loads(msg_pkl) - MLOpsProfilerEvent.log_to_wandb({"UnpickleTime": time.time() - unpickle_start_time}) + MLOpsProfilerEvent.log_to_wandb( + {"UnpickleTime": time.time() - unpickle_start_time}) logging.info("unpickle END") msg_type = msg.get_type() for observer in self._observers: _message_handler_start_time = time.time() observer.receive_message(msg_type, msg) - MLOpsProfilerEvent.log_to_wandb({"MessageHandlerTime": time.time() - _message_handler_start_time}) - MLOpsProfilerEvent.log_to_wandb({"BusyTime": time.time() - busy_time_start_time}) + MLOpsProfilerEvent.log_to_wandb( + {"MessageHandlerTime": time.time() - _message_handler_start_time}) + MLOpsProfilerEvent.log_to_wandb( + {"BusyTime": time.time() - busy_time_start_time}) lock.release() time.sleep(0.0001) - MLOpsProfilerEvent.log_to_wandb({"TotalTime": time.time() - start_listening_time}) + MLOpsProfilerEvent.log_to_wandb( + {"TotalTime": time.time() - start_listening_time}) return def stop_receive_message(self): + """ + Stop receiving and processing messages. + + This method stops the communication manager. + + Returns: + None + """ self.grpc_server.stop(None) self.is_running = False def notify(self, message: Message): + """ + Notify observers with a message. + + Args: + message (Message): The message to notify observers with. + + Returns: + None + """ msg_type = message.get_type() for observer in self._observers: observer.receive_message(msg_type, message) def _notify_connection_ready(self): + """ + Notify observers that the connection is ready. + + This method notifies observers that the communication connection is ready. + + Returns: + None + """ msg_params = Message() msg_params.sender_id = self.rank msg_params.receiver_id = self.rank @@ -165,6 +247,15 @@ def _notify_connection_ready(self): observer.receive_message(msg_type, msg_params) def _build_ip_table(self, path): + """ + Build an IP configuration table from a CSV file. + + Args: + path (str): The path to the CSV file containing IP configuration data. + + Returns: + dict: A dictionary mapping receiver IDs to their corresponding IP addresses. + """ ip_config = dict() with open(path, newline="") as csv_file: csv_reader = csv.reader(csv_file) diff --git a/python/fedml/core/distributed/communication/grpc/grpc_comm_manager_pb2_grpc.py b/python/fedml/core/distributed/communication/grpc/grpc_comm_manager_pb2_grpc.py index ec24e39df6..063167a020 100644 --- a/python/fedml/core/distributed/communication/grpc/grpc_comm_manager_pb2_grpc.py +++ b/python/fedml/core/distributed/communication/grpc/grpc_comm_manager_pb2_grpc.py @@ -6,10 +6,15 @@ class gRPCCommManagerStub(object): - """Missing associated documentation comment in .proto file.""" + """ + gRPC Communication Manager Stub. + + This class provides a client-side stub for interacting with the gRPC communication manager service. + """ def __init__(self, channel): - """Constructor. + """ + Initialize the gRPCCommManagerStub. Args: channel: A grpc.Channel. @@ -27,22 +32,53 @@ def __init__(self, channel): class gRPCCommManagerServicer(object): - """Missing associated documentation comment in .proto file.""" + """ + gRPC Communication Manager Servicer. + + This class defines the gRPC service methods for the communication manager. 
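+
+    Example (sketch only; ``MyServicer`` is a hypothetical subclass, much like
+    the concrete GRPCCOMMServicer used by GRPCCommManager):
+
+    >>> class MyServicer(gRPCCommManagerServicer):
+    ...     def sendMessage(self, request, context):
+    ...         ...  # handle the request and return the reply message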
+ """ def sendMessage(self, request, context): - """Missing associated documentation comment in .proto file.""" + """ + Handle the sendMessage gRPC service method. + + Args: + request: The request message. + context: The gRPC context. + + Raises: + NotImplementedError: This method is not implemented. + """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") def handleReceiveMessage(self, request, context): - """Missing associated documentation comment in .proto file.""" + """ + Handle the handleReceiveMessage gRPC service method. + + Args: + request: The request message. + context: The gRPC context. + + Raises: + NotImplementedError: This method is not implemented. + """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") def add_gRPCCommManagerServicer_to_server(servicer, server): + """ + Add a gRPC Communication Manager Servicer to a gRPC server. + + This function registers the gRPC service methods provided by the servicer to the gRPC server. + + Args: + servicer: The gRPC Communication Manager Servicer instance. + server: The gRPC server instance to which the servicer will be added. + """ rpc_method_handlers = { "sendMessage": grpc.unary_unary_rpc_method_handler( servicer.sendMessage, @@ -63,7 +99,13 @@ def add_gRPCCommManagerServicer_to_server(servicer, server): # This class is part of an EXPERIMENTAL API. class gRPCCommManager(object): - """Missing associated documentation comment in .proto file.""" + """ + gRPC Communication Manager. + + This class provides static methods for making gRPC calls to the Communication Manager service. + + Note: This class is part of an experimental API. + """ @staticmethod def sendMessage( @@ -78,6 +120,24 @@ def sendMessage( timeout=None, metadata=None, ): + """ + Send a gRPC sendMessage request. + + Args: + request: The request message. + target: The target server to send the request. + options: Additional gRPC options. + channel_credentials: Channel credentials. + call_credentials: Call credentials. + insecure: Whether to use an insecure channel. + compression: Compression method to use. + wait_for_ready: Wait for the server to become ready. + timeout: Request timeout. + metadata: Request metadata. + + Returns: + grpc.Call: A gRPC call instance. + """ return grpc.experimental.unary_unary( request, target, @@ -107,6 +167,24 @@ def handleReceiveMessage( timeout=None, metadata=None, ): + """ + Send a gRPC handleReceiveMessage request. + + Args: + request: The request message. + target: The target server to send the request. + options: Additional gRPC options. + channel_credentials: Channel credentials. + call_credentials: Call credentials. + insecure: Whether to use an insecure channel. + compression: Compression method to use. + wait_for_ready: Wait for the server to become ready. + timeout: Request timeout. + metadata: Request metadata. + + Returns: + grpc.Call: A gRPC call instance. + """ return grpc.experimental.unary_unary( request, target, diff --git a/python/fedml/core/distributed/communication/mpi/com_manager.py b/python/fedml/core/distributed/communication/mpi/com_manager.py index 030b8793ad..e02666ea65 100644 --- a/python/fedml/core/distributed/communication/mpi/com_manager.py +++ b/python/fedml/core/distributed/communication/mpi/com_manager.py @@ -12,7 +12,21 @@ class MpiCommunicationManager(BaseCommunicationManager): + """ + MPI Communication Manager. 
+ + This class manages communication using MPI (Message Passing Interface) for federated learning. + """ + def __init__(self, comm, rank, size): + """ + Initialize the MPI Communication Manager. + + Args: + comm: The MPI communicator. + rank: The rank of the current process. + size: The total number of processes in the communicator. + """ self.comm = comm self.rank = rank self.size = size @@ -39,6 +53,12 @@ def __init__(self, comm, rank, size): # assert False def init_server_communication(self): + """ + Initialize server-side communication components. + + Returns: + Tuple: A tuple containing server send and receive queues. + """ server_send_queue = queue.Queue(0) # self.server_send_thread = MPISendThread( # self.comm, self.rank, self.size, "ServerSendThread", server_send_queue @@ -54,6 +74,12 @@ def init_server_communication(self): return server_send_queue, server_receive_queue def init_client_communication(self): + """ + Initialize client-side communication components. + + Returns: + Tuple: A tuple containing client send and receive queues. + """ # SEND client_send_queue = queue.Queue(0) # self.client_send_thread = MPISendThread( @@ -75,19 +101,43 @@ def init_client_communication(self): # self.q_sender.put(msg) def send_message(self, msg: Message): + """ + Send a message using MPI. + + Args: + msg: The message to be sent. + """ # self.q_sender.put(msg) dest_id = msg.get(Message.MSG_ARG_KEY_RECEIVER) tick = time.time() self.comm.send(msg, dest=dest_id) - MLOpsProfilerEvent.log_to_wandb({"Comm/send_delay": time.time() - tick}) + MLOpsProfilerEvent.log_to_wandb( + {"Comm/send_delay": time.time() - tick}) def add_observer(self, observer: Observer): + """ + Add an observer to the list of observers. + + Args: + observer: The observer to be added. + """ self._observers.append(observer) def remove_observer(self, observer: Observer): + """ + Remove an observer from the list of observers. + + Args: + observer: The observer to be removed. + """ self._observers.remove(observer) def handle_receive_message(self): + """ + Handle receiving messages using MPI. + + This function continuously listens for incoming messages and notifies observers when a message is received. + """ self.is_running = True # the first message after connection, aligned the protocol with MQTT + S3 self._notify_connection_ready() @@ -108,6 +158,9 @@ def handle_receive_message(self): logging.info("!!!!!!handle_receive_message stopped!!!") def stop_receive_message(self): + """ + Stop receiving messages and threads. + """ self.is_running = False # self.__stop_thread(self.server_send_thread) self.__stop_thread(self.server_receive_thread) @@ -117,11 +170,20 @@ def stop_receive_message(self): self.__stop_thread(self.client_collective_thread) def notify(self, msg_params): + """ + Notify observers with the received message. + + Args: + msg_params: The received message. + """ msg_type = msg_params.get_type() for observer in self._observers: observer.receive_message(msg_type, msg_params) def _notify_connection_ready(self): + """ + Notify observers that the connection is ready. 
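+
+        Example (illustrative; any Observer registered via ``add_observer``
+        receives the synthetic connection-ready message):
+
+        >>> class ReadyLogger(Observer):
+        ...     def receive_message(self, msg_type, msg_params):
+        ...         print("ready:", msg_type)
+        >>> manager.add_observer(ReadyLogger())  # ``manager`` is hypothetical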
+ """ msg_params = Message() msg_params.sender_id = self.rank msg_params.receiver_id = self.rank diff --git a/python/fedml/core/distributed/communication/mpi/mpi_receive_thread.py b/python/fedml/core/distributed/communication/mpi/mpi_receive_thread.py index b10f4d52ff..63ca19fbe0 100644 --- a/python/fedml/core/distributed/communication/mpi/mpi_receive_thread.py +++ b/python/fedml/core/distributed/communication/mpi/mpi_receive_thread.py @@ -7,7 +7,23 @@ class MPIReceiveThread(threading.Thread): + """ + MPI Receive Thread. + + This thread is responsible for receiving messages using MPI. + """ + def __init__(self, comm, rank, size, name, q): + """ + Initialize the MPI Receive Thread. + + Args: + comm: The MPI communicator. + rank: The rank of the current process. + size: The total number of processes in the communicator. + name: The name of the thread. + q: The message queue to store received messages. + """ super(MPIReceiveThread, self).__init__() self._stop_event = threading.Event() self.comm = comm @@ -17,28 +33,44 @@ def __init__(self, comm, rank, size, name, q): self.q = q def run(self): + """ + Run the MPI Receive Thread. + + This method continuously listens for incoming messages and puts them into the message queue. + """ logging.debug( "Starting Thread:" + self.name + ". Process ID = " + str(self.rank) ) while True: try: msg = self.comm.recv() - # Ugly delete comments - # msg_str = self.comm.recv() - # msg = Message() - # msg.init(msg_str) self.q.put(msg) except Exception: traceback.print_exc() raise Exception("MPI failed!") def stop(self): + """ + Stop the MPI Receive Thread. + """ self._stop_event.set() def stopped(self): + """ + Check if the MPI Receive Thread is stopped. + + Returns: + bool: True if the thread is stopped, False otherwise. + """ return self._stop_event.is_set() def get_id(self): + """ + Get the ID of the thread. + + Returns: + int: The ID of the thread. + """ # returns id of the respective thread if hasattr(self, "_thread_id"): return self._thread_id @@ -47,6 +79,9 @@ def get_id(self): return id def raise_exception(self): + """ + Raise an exception in the MPI Receive Thread to stop it. + """ thread_id = self.get_id() res = ctypes.pythonapi.PyThreadState_SetAsyncExc( thread_id, ctypes.py_object(SystemExit) diff --git a/python/fedml/core/distributed/communication/mpi/mpi_send_thread.py b/python/fedml/core/distributed/communication/mpi/mpi_send_thread.py index 39ebb5599a..6d67938fd4 100644 --- a/python/fedml/core/distributed/communication/mpi/mpi_send_thread.py +++ b/python/fedml/core/distributed/communication/mpi/mpi_send_thread.py @@ -1,6 +1,3 @@ -# Ugly delete file - - import ctypes import logging import threading @@ -11,7 +8,23 @@ class MPISendThread(threading.Thread): + """ + MPI Send Thread. + + This thread is responsible for sending messages using MPI. + """ + def __init__(self, comm, rank, size, name, q): + """ + Initialize the MPI Send Thread. + + Args: + comm: The MPI communicator. + rank: The rank of the current process. + size: The total number of processes in the communicator. + name: The name of the thread. + q: The message queue to get messages to send. + """ super(MPISendThread, self).__init__() self._stop_event = threading.Event() self.comm = comm @@ -21,7 +34,13 @@ def __init__(self, comm, rank, size, name, q): self.q = q def run(self): - logging.debug("Starting " + self.name + ". Process ID = " + str(self.rank)) + """ + Run the MPI Send Thread. + + This method continuously checks the message queue and sends messages to the specified destination. 
+ """ + logging.debug("Starting " + self.name + + ". Process ID = " + str(self.rank)) while True: try: if not self.q.empty(): @@ -35,12 +54,27 @@ def run(self): raise Exception("MPI failed!") def stop(self): + """ + Stop the MPI Send Thread. + """ self._stop_event.set() def stopped(self): + """ + Check if the MPI Send Thread is stopped. + + Returns: + bool: True if the thread is stopped, False otherwise. + """ return self._stop_event.is_set() def get_id(self): + """ + Get the ID of the thread. + + Returns: + int: The ID of the thread. + """ # returns id of the respective thread if hasattr(self, "_thread_id"): return self._thread_id @@ -49,6 +83,9 @@ def get_id(self): return id def raise_exception(self): + """ + Raise an exception in the MPI Send Thread to stop it. + """ thread_id = self.get_id() res = ctypes.pythonapi.PyThreadState_SetAsyncExc( thread_id, ctypes.py_object(SystemExit) diff --git a/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py b/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py index 2cf9f11b3e..18c1064165 100644 --- a/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py +++ b/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py @@ -13,6 +13,19 @@ class MqttManager(object): def __init__(self, host, port, user, pwd, keepalive_time, client_id, last_will_topic=None, last_will_msg=None): + """ + MQTT Manager for handling MQTT connections, sending, and receiving messages. + + Args: + host (str): MQTT broker host. + port (int): MQTT broker port. + user (str): MQTT username. + pwd (str): MQTT password. + keepalive_time (int): Keepalive time for the MQTT connection. + client_id (str): Client ID for the MQTT client. + last_will_topic (str, optional): Last will topic for the MQTT client. + last_will_msg (str, optional): Last will message for the MQTT client. + """ self._client = None self.mqtt_connection_id = None self._host = host @@ -44,6 +57,7 @@ def __del__(self): self._client = None def init_connect(self): + self.mqtt_connection_id = "{}_{}".format(self._client_id, "ID") self._client = mqtt.Client(client_id=self.mqtt_connection_id, clean_session=False) self._client.connected_flag = False @@ -90,6 +104,17 @@ def loop_forever(self): self._client.loop_forever(retry_first_connection=True) def send_message(self, topic, message, publish_single_message=False): + """ + Send an MQTT message. + + Args: + topic (str): The MQTT topic to which the message will be sent. + message (str): The message to send. + publish_single_message (bool, optional): If True, publish as a single message; otherwise, use MQTT publish. + + Returns: + bool: True if the message was successfully sent, False otherwise. + """ logging.info( f"FedMLDebug - Send: topic ({topic}), message ({message})" ) @@ -110,6 +135,17 @@ def send_message(self, topic, message, publish_single_message=False): return True def send_message_json(self, topic, message, publish_single_message=False): + """ + Send an MQTT message as JSON. + + Args: + topic (str): The MQTT topic to which the message will be sent. + message (str): The message to send as JSON. + publish_single_message (bool, optional): If True, publish as a single message; otherwise, use MQTT publish. + + Returns: + bool: True if the message was successfully sent, False otherwise. 
+ """ logging.info( f"FedMLDebug - Send: topic ({topic}), message ({message})" ) diff --git a/python/fedml/core/distributed/communication/mqtt_s3/mqtt_s3_multi_clients_comm_manager.py b/python/fedml/core/distributed/communication/mqtt_s3/mqtt_s3_multi_clients_comm_manager.py index 50f0908ad2..454fef71c8 100755 --- a/python/fedml/core/distributed/communication/mqtt_s3/mqtt_s3_multi_clients_comm_manager.py +++ b/python/fedml/core/distributed/communication/mqtt_s3/mqtt_s3_multi_clients_comm_manager.py @@ -18,6 +18,44 @@ class MqttS3MultiClientsCommManager(BaseCommunicationManager): + """ + MQTT communication manager for multi-client federated learning. + + This class provides an MQTT-based communication manager for multi-client federated learning scenarios. + It supports communication between a central server and multiple client devices. + + Args: + config_path (str): Path to the MQTT configuration file. + s3_config_path (str): Path to the S3 storage configuration file. + topic (str): The MQTT topic prefix. + client_rank (int): The rank or ID of the client. + client_num (int): The total number of clients. + args (object): Additional configuration arguments. + + Attributes: + client_id (str): The unique ID of the MQTT client. + topic (str): The MQTT topic. + is_connected (bool): Indicates if the MQTT client is connected to the broker. + client_active_list (dict): A dictionary to store the status of connected clients. + + Methods: + run_loop_forever(): Run the MQTT loop forever to handle incoming messages. + on_connected(mqtt_client_object): MQTT on_connected callback. + on_disconnected(mqtt_client_object): MQTT on_disconnected callback. + add_observer(observer): Add an observer to receive messages. + remove_observer(observer): Remove an observer. + send_message(msg, wait_for_publish=False): Send a message using MQTT. + send_message_json(topic_name, json_message): Send a JSON message using MQTT. + handle_receive_message(): Start handling received messages by running the MQTT loop. + stop_receive_message(): Stop receiving messages and disconnect from MQTT. + set_config_from_file(config_file_path): Load MQTT configuration from a file. + set_config_from_objects(mqtt_config): Set MQTT configuration from objects. + callback_client_last_will_msg(topic, payload): Callback for client last will message. + callback_client_active_msg(topic, payload): Callback for client active message. + subscribe_client_status_message(): Subscribe to client status messages. + get_client_status(client_id): Get the status of a specific client. + get_client_list_status(): Get the status of all clients. + """ def __init__( self, @@ -28,6 +66,17 @@ def __init__( client_num=0, args=None ): + """ + Initialize the MQTT communication manager. + + Args: + config_path (str): Path to the MQTT configuration file. + s3_config_path (str): Path to the S3 storage configuration file. + topic (str): The MQTT topic prefix. + client_rank (int): The rank or ID of the client. + client_num (int): The total number of clients. + args (object): Additional configuration arguments. 
+ """ self.args = args self.broker_port = None self.broker_host = None @@ -51,7 +100,8 @@ def __init__( self.client_real_ids = [] if args.client_id_list is not None: logging.info( - "MqttS3CommManager args client_id_list: " + str(args.client_id_list) + "MqttS3CommManager args client_id_list: " + + str(args.client_id_list) ) self.client_real_ids = json.loads(args.client_id_list) @@ -82,7 +132,8 @@ def __init__( self._observers: List[Observer] = [] - self._client_id = "FedML_CS_{}_{}_{}".format(str(args.run_id), str(self.edge_id), str(uuid.uuid4())) + self._client_id = "FedML_CS_{}_{}_{}".format( + str(args.run_id), str(self.edge_id), str(uuid.uuid4())) self.client_num = client_num logging.info("mqtt_s3.init: client_num = %d" % client_num) @@ -95,7 +146,8 @@ def __init__( if args.rank == 0: self.top_active_msg = CommunicationConstants.SERVER_TOP_ACTIVE_MSG self.topic_last_will_msg = CommunicationConstants.SERVER_TOP_LAST_WILL_MSG - self.last_will_msg = json.dumps({"ID": self.edge_id, "status": CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE}) + self.last_will_msg = json.dumps( + {"ID": self.edge_id, "status": CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE}) self.mqtt_mgr = MqttManager( config_path["BROKER_HOST"], config_path["BROKER_PORT"], @@ -114,25 +166,39 @@ def __init__( @property def client_id(self): + """ + Get the client ID. + + Returns: + str: The client ID. + """ return self._client_id @property def topic(self): + """ + Get the MQTT topic. + + Returns: + str: The MQTT topic. + """ return self._topic def run_loop_forever(self): + """ + Run the MQTT loop forever to handle incoming messages. + """ self.mqtt_mgr.loop_forever() def on_connected(self, mqtt_client_object): """ - [server] - sending message topic (publish): serverID_clientID - receiving message topic (subscribe): clientID + MQTT on_connected callback. - [client] - sending message topic (publish): clientID - receiving message topic (subscribe): serverID_clientID + This method is called when the MQTT client is connected to the broker. It handles + subscription to topics based on whether the current instance is a server or client. + Args: + mqtt_client_object: The MQTT client object. """ self.mqtt_mgr.add_message_passthrough_listener(self._on_message) @@ -143,7 +209,8 @@ def on_connected(self, mqtt_client_object): # logging.info("self.client_real_ids = {}".format(self.client_real_ids)) for client_rank in range(0, self.client_num): - real_topic = self._topic + str(self.client_real_ids[client_rank]) + real_topic = self._topic + \ + str(self.client_real_ids[client_rank]) result, mid = mqtt_client_object.subscribe(real_topic, qos=2) # logging.info( @@ -155,7 +222,8 @@ def on_connected(self, mqtt_client_object): self._notify_connection_ready() else: # client - real_topic = self._topic + str(self.server_id) + "_" + str(self.client_real_ids[0]) + real_topic = self._topic + \ + str(self.server_id) + "_" + str(self.client_real_ids[0]) result, mid = mqtt_client_object.subscribe(real_topic, qos=2) self._notify_connection_ready() @@ -167,21 +235,56 @@ def on_connected(self, mqtt_client_object): self.is_connected = True def on_disconnected(self, mqtt_client_object): + """ + MQTT on_disconnected callback. + + This method is called when the MQTT client is disconnected from the broker. + + Args: + mqtt_client_object: The MQTT client object. + """ self.is_connected = False def add_observer(self, observer: Observer): + """ + Add an observer to receive messages. + + Args: + observer (Observer): The observer to add. 
+ """ self._observers.append(observer) def remove_observer(self, observer: Observer): + """ + Remove an observer. + + Args: + observer (Observer): The observer to remove. + """ self._observers.remove(observer) def _notify_connection_ready(self): + """ + Notify observers that the connection is ready. + """ msg_params = Message() msg_type = CommunicationConstants.MSG_TYPE_CONNECTION_IS_READY for observer in self._observers: observer.receive_message(msg_type, msg_params) def _notify(self, msg_obj): + """ + Notify registered observers with a received message. + + This method parses the incoming message, extracts its type, and notifies all registered + observers with the message type and parameters. + + Args: + msg_obj (dict): The received message object. + + Returns: + None + """ msg_params = Message() msg_params.init_from_json_object(msg_obj) msg_type = msg_params.get_type() @@ -190,6 +293,18 @@ def _notify(self, msg_obj): observer.receive_message(msg_type, msg_params) def _on_message_impl(self, msg): + """ + Handle incoming MQTT messages. + + This method is called when an MQTT message is received. It parses the message payload, + processes it, and notifies observers with the received message. + + Args: + msg (paho.mqtt.client.MQTTMessage): The received MQTT message. + + Returns: + None + """ json_payload = str(msg.payload, encoding="utf-8") payload_obj = json.loads(json_payload) logging.info( @@ -218,16 +333,19 @@ def _on_message_impl(self, msg): elif self.dataSetType == 'cifar10': py_model = CNN_WEB() - model_params = self.s3_storage.read_model_web(s3_key_str, py_model) + model_params = self.s3_storage.read_model_web( + s3_key_str, py_model) else: model_params = self.s3_storage.read_model(s3_key_str) if not hasattr(self.args, "fa_task"): logging.info( - "mqtt_s3.on_message: model params length %d" % len(model_params) + "mqtt_s3.on_message: model params length %d" % len( + model_params) ) - model_url = payload_obj.get(Message.MSG_ARG_KEY_MODEL_PARAMS_URL, "") + model_url = payload_obj.get( + Message.MSG_ARG_KEY_MODEL_PARAMS_URL, "") logging.info("mqtt_s3.on_message: model url {}".format(model_url)) # replace the S3 object key with raw model params @@ -239,6 +357,19 @@ def _on_message_impl(self, msg): self._notify(payload_obj) def _on_message(self, msg): + """ + Send a message using MQTT. + + This method sends a message to the specified recipient using MQTT. The topic for publishing + the message is determined based on whether the current instance is a server or client. + + Args: + msg (Message): The message to be sent. + wait_for_publish (bool): Whether to wait for the message to be published. + + Returns: + bool: True if the message was sent successfully, False otherwise. 
+ """ self._on_message_impl(msg) def send_message(self, msg: Message, wait_for_publish=False): @@ -259,7 +390,8 @@ def send_message(self, msg: Message, wait_for_publish=False): logging.info("mqtt_s3.send_message: msg topic = %s" % str(topic)) payload = msg.get_params() - model_params_obj = payload.get(Message.MSG_ARG_KEY_MODEL_PARAMS, "") + model_params_obj = payload.get( + Message.MSG_ARG_KEY_MODEL_PARAMS, "") model_url = payload.get(Message.MSG_ARG_KEY_MODEL_PARAMS_URL, "") model_key = payload.get(Message.MSG_ARG_KEY_MODEL_PARAMS_KEY, "") if model_params_obj != "": @@ -267,9 +399,11 @@ def send_message(self, msg: Message, wait_for_publish=False): if model_url == "": model_key = topic + "_" + str(uuid.uuid4()) if self.isBrowser: - model_url = self.s3_storage.write_model_web(model_key, model_params_obj) + model_url = self.s3_storage.write_model_web( + model_key, model_params_obj) else: - model_url = self.s3_storage.write_model(model_key, model_params_obj) + model_url = self.s3_storage.write_model( + model_key, model_params_obj) logging.info( "mqtt_s3.send_message: S3+MQTT msg sent, s3 message key = %s" @@ -290,14 +424,16 @@ def send_message(self, msg: Message, wait_for_publish=False): message_key = topic + "_" + str(uuid.uuid4()) payload = msg.get_params() - model_params_obj = payload.get(Message.MSG_ARG_KEY_MODEL_PARAMS, "") + model_params_obj = payload.get( + Message.MSG_ARG_KEY_MODEL_PARAMS, "") if model_params_obj != "": # S3 logging.info( "mqtt_s3.send_message: S3+MQTT msg sent, message_key = %s" % message_key ) - model_url = self.s3_storage.write_model(message_key, model_params_obj) + model_url = self.s3_storage.write_model( + message_key, model_params_obj) model_params_key_url = { "key": message_key, "url": model_url, @@ -319,20 +455,60 @@ def send_message(self, msg: Message, wait_for_publish=False): return True def send_message_json(self, topic_name, json_message): + """ + Send a JSON message to a specified MQTT topic. + + Args: + topic_name (str): The MQTT topic to which the message will be sent. + json_message (str): The JSON-formatted message to send. + + Returns: + bool: True if the message was sent successfully, False otherwise. + """ return self.mqtt_mgr.send_message_json(topic_name, json_message) def handle_receive_message(self): + """ + Start listening for incoming MQTT messages and handle them. + + This method initiates the process of receiving and handling MQTT messages. + It runs a loop to continuously listen for messages and processes them until stopped. + + Returns: + None + """ start_listening_time = time.time() MLOpsProfilerEvent.log_to_wandb({"ListenStart": start_listening_time}) self.run_loop_forever() - MLOpsProfilerEvent.log_to_wandb({"TotalTime": time.time() - start_listening_time}) + MLOpsProfilerEvent.log_to_wandb( + {"TotalTime": time.time() - start_listening_time}) def stop_receive_message(self): + """ + Stop listening for incoming MQTT messages and disconnect from the MQTT broker. + + This method stops the MQTT message listening loop and disconnects from the MQTT broker. + + Returns: + None + """ logging.info("mqtt_s3.stop_receive_message: stopping...") self.mqtt_mgr.loop_stop() self.mqtt_mgr.disconnect() def set_config_from_file(self, config_file_path): + """ + Load MQTT configuration settings from a YAML file. + + This method reads MQTT configuration settings, including the broker host, port, username, and + password, from a YAML file and updates the instance variables accordingly. 
+ + Args: + config_file_path (str): The path to the YAML configuration file. + + Returns: + None + """ try: with open(config_file_path, "r") as f: config = yaml.load(f, Loader=yaml.FullLoader) @@ -348,6 +524,18 @@ def set_config_from_file(self, config_file_path): pass def set_config_from_objects(self, mqtt_config): + """ + Set MQTT configuration settings from a dictionary. + + This method sets the MQTT configuration settings, including the broker host, port, username, + and password, from a dictionary object. + + Args: + mqtt_config (dict): A dictionary containing MQTT configuration settings. + + Returns: + None + """ self.broker_host = mqtt_config["BROKER_HOST"] self.broker_port = mqtt_config["BROKER_PORT"] self.mqtt_user = None @@ -358,21 +546,58 @@ def set_config_from_objects(self, mqtt_config): self.mqtt_pwd = mqtt_config["MQTT_PWD"] def callback_client_last_will_msg(self, topic, payload): + """ + Handle the last will message from a client. + + This method processes the last will message received from a client and updates the client's + status accordingly. + + Args: + topic (str): The MQTT topic on which the last will message was received. + payload (str): The payload of the last will message. + + Returns: + None + """ msg = json.loads(payload) edge_id = msg.get("ID", None) - status = msg.get("status", CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE) + status = msg.get( + "status", CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE) if edge_id is not None and status == CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE: if self.client_active_list.get(edge_id, None) is not None: self.client_active_list.pop(edge_id) def callback_client_active_msg(self, topic, payload): + """ + Handle the active status message from a client. + + This method processes the active status message received from a client and updates the client's + status in the active list. + + Args: + topic (str): The MQTT topic on which the active status message was received. + payload (str): The payload of the active status message. + + Returns: + None + """ msg = json.loads(payload) edge_id = msg.get("ID", None) - status = msg.get("status", CommunicationConstants.MSG_CLIENT_STATUS_IDLE) + status = msg.get( + "status", CommunicationConstants.MSG_CLIENT_STATUS_IDLE) if edge_id is not None: self.client_active_list[edge_id] = status def subscribe_client_status_message(self): + """ + Subscribe to client status messages. + + This method sets up MQTT message listeners to handle both last will messages and active status + messages from clients. + + Returns: + None + """ # Setup MQTT message listener to the last will message form the client. self.mqtt_mgr.add_message_listener(self.topic_last_will_msg, self.callback_client_last_will_msg) @@ -382,7 +607,26 @@ def subscribe_client_status_message(self): self.callback_client_active_msg) def get_client_status(self, client_id): + """ + Get the status of a specific client. + + This method retrieves the status of a client based on its ID from the client active list. + + Args: + client_id (str): The ID of the client. + + Returns: + str: The status of the client, e.g., 'offline' or 'idle'. + """ return self.client_active_list.get(client_id, CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE) def get_client_list_status(self): + """ + Get the status of all connected clients. + + This method returns the entire client active list, containing the statuses of all connected clients. + + Returns: + dict: A dictionary mapping client IDs to their statuses. 
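# The two callbacks above maintain client_active_list from "active" heartbeats
# and broker-delivered last-will messages. A broker-free sketch of that status
# bookkeeping, using plain strings for the status constants:
import json

OFFLINE = "OFFLINE"
IDLE = "IDLE"
client_active_list = {}

def on_client_active(payload):
    msg = json.loads(payload)
    edge_id = msg.get("ID")
    if edge_id is not None:
        client_active_list[edge_id] = msg.get("status", IDLE)

def on_client_last_will(payload):
    # The broker publishes the last will for a client that died without a clean
    # disconnect; treat it as OFFLINE and drop it from the active list.
    msg = json.loads(payload)
    if msg.get("status", OFFLINE) == OFFLINE:
        client_active_list.pop(msg.get("ID"), None)

on_client_active(json.dumps({"ID": 7, "status": IDLE}))
assert client_active_list == {7: IDLE}
on_client_last_will(json.dumps({"ID": 7, "status": OFFLINE}))
assert client_active_list == {}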
+ """ return self.client_active_list diff --git a/python/fedml/core/distributed/communication/mqtt_s3_mnn/mqtt_s3_comm_manager.py b/python/fedml/core/distributed/communication/mqtt_s3_mnn/mqtt_s3_comm_manager.py index 9361bb2bf5..6c5aa9b188 100755 --- a/python/fedml/core/distributed/communication/mqtt_s3_mnn/mqtt_s3_comm_manager.py +++ b/python/fedml/core/distributed/communication/mqtt_s3_mnn/mqtt_s3_comm_manager.py @@ -17,6 +17,67 @@ class MqttS3MNNCommManager(BaseCommunicationManager): + """ + MQTT-S3-based Communication Manager for Federated Learning. + + This communication manager uses MQTT-S3 for message communication and S3 for model storage. + + Args: + config_path (str): Path to the configuration file. + s3_config_path (str): Path to the S3 configuration file. + topic (str, optional): MQTT topic. Default is "fedml". + client_id (int, optional): Client ID. Default is 0. + client_num (int, optional): Number of clients. Default is 0. + args (Namespace, optional): Command-line arguments. + bind_port (int, optional): Port to bind. Default is 0. + + Attributes: + mqtt_pwd (str): MQTT password. + mqtt_user (str): MQTT username. + broker_port (int): MQTT broker port. + broker_host (str): MQTT broker host. + keepalive_time (int): MQTT keepalive time. + args (Namespace): Command-line arguments. + rank (int): Client rank. + _topic (str): MQTT topic. + s3_storage (S3MNNStorage): S3 storage. + client_real_ids (list): List of real client IDs. + group_server_id_list (str): Group server ID list. + edge_id (int): Edge ID. + server_id (int): Server ID. + _observers (list): List of observers. + _client_id (str): Client ID. + client_num (int): Number of clients. + client_active_list (dict): Dictionary to track client activity status. + top_active_msg (str): Top-level active message topic. + topic_last_will_msg (str): Topic for last will message. + last_will_msg (str): Last will message. + mqtt_mgr (MqttManager): MQTT manager. + + Methods: + run_loop_forever(self): Run the MQTT loop indefinitely. + __del__(self): Destructor to stop the MQTT loop and disconnect. + on_connected(self, mqtt_client_object): MQTT on_connected callback. + on_disconnected(self, mqtt_client_object): MQTT on_disconnected callback. + add_observer(self, observer: Observer): Add an observer to receive messages. + remove_observer(self, observer: Observer): Remove an observer. + _notify(self, msg_obj): Notify observers with a message. + _on_message_impl(self, msg): Handle incoming MQTT messages. + _on_message(self, msg): Wrapper for handling incoming MQTT messages. + send_message(self, msg: Message): Send a message using MQTT. + send_message_json(self, topic_name, json_message): Send a JSON message using MQTT. + handle_receive_message(self): Start handling received messages. + stop_receive_message(self): Stop receiving messages and disconnect from MQTT. + set_config_from_file(self, config_file_path): Load MQTT configuration from a file. + set_config_from_objects(self, mqtt_config): Set MQTT configuration from objects. + _notify_connection_ready(self): Notify observers that the connection is ready. + callback_client_last_will_msg(self, topic, payload): Callback for client last will message. + callback_client_active_msg(self, topic, payload): Callback for client active message. + subscribe_client_status_message(self): Subscribe to client status messages. + get_client_status(self, client_id): Get the status of a specific client. + get_client_list_status(self): Get the status of all clients. 
+ """ + def __init__( self, config_path, @@ -27,6 +88,18 @@ def __init__( args=None, bind_port=0, ): + """ + Initialize the MqttS3MNNCommManager. + + Args: + config_path (str): Path to the configuration file. + s3_config_path (str): Path to the S3 configuration file. + topic (str, optional): MQTT topic. Default is "fedml". + client_id (int, optional): Client ID. Default is 0. + client_num (int, optional): Number of clients. Default is 0. + args (Namespace, optional): Command-line arguments. + bind_port (int, optional): Port to bind. Default is 0. + """ self.mqtt_pwd = None self.mqtt_user = None self.broker_port = None @@ -39,7 +112,8 @@ def __init__( self.s3_storage = S3MNNStorage(s3_config_path) self.client_real_ids = [] logging.info( - "MqttS3CommManager args client_id_list: " + str(args.client_id_list) + "MqttS3CommManager args client_id_list: " + + str(args.client_id_list) ) if args.client_id_list is not None: self.client_real_ids = json.loads(args.client_id_list) @@ -70,7 +144,8 @@ def __init__( self.edge_id = 0 self._observers: List[Observer] = [] - self._client_id = "FedML_CS_{}_{}_{}".format(str(args.run_id), str(self.edge_id), str(uuid.uuid4())) + self._client_id = "FedML_CS_{}_{}_{}".format( + str(args.run_id), str(self.edge_id), str(uuid.uuid4())) self.client_num = client_num logging.info("mqtt_s3.init: client_num = %d" % client_num) @@ -83,7 +158,8 @@ def __init__( if args.rank == 0: self.top_active_msg = CommunicationConstants.SERVER_TOP_ACTIVE_MSG self.topic_last_will_msg = CommunicationConstants.SERVER_TOP_LAST_WILL_MSG - self.last_will_msg = json.dumps({"ID": self.edge_id, "status": CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE}) + self.last_will_msg = json.dumps( + {"ID": self.edge_id, "status": CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE}) self.mqtt_mgr = MqttManager( config_path["BROKER_HOST"], config_path["BROKER_PORT"], @@ -99,30 +175,47 @@ def __init__( self.mqtt_mgr.connect() def run_loop_forever(self): + """ + Run the MQTT loop forever to handle incoming messages. + """ self.mqtt_mgr.loop_forever() def __del__(self): + """ + Destructor to stop the MQTT loop and disconnect from the broker. + """ self.mqtt_mgr.loop_stop() self.mqtt_mgr.disconnect() @property def client_id(self): + """ + Get the client ID. + + Returns: + str: The client ID. + """ return self._client_id @property def topic(self): + """ + Get the MQTT topic. + + Returns: + str: The MQTT topic. + """ return self._topic def on_connected(self, mqtt_client_object): """ - [server] - sending message topic (publish): serverID_clientID - receiving message topic (subscribe): clientID + MQTT on_connected callback. - [client] - sending message topic (publish): clientID - receiving message topic (subscribe): serverID_clientID + This method is called when the MQTT client is connected to the broker. It handles + subscription to topics based on whether the current instance is a server or client. + Args: + mqtt_client_object: The MQTT client object. 
""" self.mqtt_mgr.add_message_passthrough_listener(self._on_message) @@ -132,7 +225,8 @@ def on_connected(self, mqtt_client_object): self.subscribe_client_status_message() for client_ID in range(1, self.client_num + 1): - real_topic = self._topic + str(self.client_real_ids[client_ID - 1]) + real_topic = self._topic + \ + str(self.client_real_ids[client_ID - 1]) result, mid = mqtt_client_object.subscribe(real_topic, qos=2) logging.info( @@ -143,7 +237,8 @@ def on_connected(self, mqtt_client_object): self._notify_connection_ready() else: # client - real_topic = self._topic + str(self.server_id) + "_" + str(self.client_real_ids[0]) + real_topic = self._topic + \ + str(self.server_id) + "_" + str(self.client_real_ids[0]) result, mid = mqtt_client_object.subscribe(real_topic, qos=2) logging.info( @@ -153,15 +248,42 @@ def on_connected(self, mqtt_client_object): self._notify_connection_ready() def on_disconnected(self, mqtt_client_object): + """ + MQTT on_connected callback. + + This method is called when the MQTT client is connected to the broker. It handles + subscription to topics based on whether the current instance is a server or client. + + Args: + mqtt_client_object: The MQTT client object. + """ pass def add_observer(self, observer: Observer): + """ + Add an observer to receive messages. + + Args: + observer (Observer): The observer to add. + """ self._observers.append(observer) def remove_observer(self, observer: Observer): + """ + Remove an observer. + + Args: + observer (Observer): The observer to remove. + """ self._observers.remove(observer) def _notify(self, msg_obj): + """ + Notify observers with a message object. + + Args: + msg_obj: The message object to notify observers with. + """ msg_params = Message() msg_params.init_from_json_object(msg_obj) msg_type = msg_params.get_type() @@ -170,6 +292,15 @@ def _notify(self, msg_obj): observer.receive_message(msg_type, msg_params) def _on_message_impl(self, msg): + """ + Handle incoming MQTT messages. + + This method processes incoming MQTT messages, including downloading model files from S3 + if needed. + + Args: + msg: The incoming MQTT message. + """ json_payload = str(msg.payload, encoding="utf-8") payload_obj = json.loads(json_payload) logging.info("mqtt_s3.on_message: payload_obj %s" % payload_obj) @@ -182,7 +313,8 @@ def _on_message_impl(self, msg): model_file_path = self.args.model_file_cache_folder + "/" + s3_key_str self.s3_storage.download_model_file(s3_key_str, model_file_path) - logging.info("mqtt_s3.on_message: downloaded model file {}".format(model_file_path)) + logging.info( + "mqtt_s3.on_message: downloaded model file {}".format(model_file_path)) # replace the S3 object key with raw model params payload_obj[Message.MSG_ARG_KEY_MODEL_PARAMS] = model_file_path @@ -193,22 +325,30 @@ def _on_message_impl(self, msg): self._notify(payload_obj) def _on_message(self, msg): + """ + Wrapper for handling incoming MQTT messages. + + This method wraps the _on_message_impl method and handles exceptions. + + Args: + msg: The incoming MQTT message. + """ try: self._on_message_impl(msg) except Exception as e: - logging.error("mqtt_s3.on_message exception: {}".format(traceback.format_exc())) + logging.error("mqtt_s3.on_message exception: {}".format( + traceback.format_exc())) def send_message(self, msg: Message): """ - [server] - sending message topic (publish): fedml_runid_serverID_clientID - receiving message topic (subscribe): fedml_runid_clientID + Send a message using MQTT. 
- [client] - sending message topic (publish): fedml_runid_clientID - receiving message topic (subscribe): fedml_runid_serverID_clientID + This method sends a message using MQTT, including handling S3 storage if required. + Args: + msg (Message): The message to send. """ + if self.rank == 0: # server receiver_id = msg.get_receiver_id() @@ -218,14 +358,16 @@ def send_message(self, msg: Message): logging.info("mqtt_s3.send_message: msg topic = %s" % str(topic)) payload = msg.get_params() - model_params_obj = payload.get(Message.MSG_ARG_KEY_MODEL_PARAMS, "") + model_params_obj = payload.get( + Message.MSG_ARG_KEY_MODEL_PARAMS, "") model_url = payload.get(Message.MSG_ARG_KEY_MODEL_PARAMS_URL, "") model_key = payload.get(Message.MSG_ARG_KEY_MODEL_PARAMS_KEY, "") if model_params_obj != "": # S3 if model_url == "": model_key = topic + "_" + str(uuid.uuid4()) - model_url = self.s3_storage.upload_model_file(model_key, model_params_obj) + model_url = self.s3_storage.upload_model_file( + model_key, model_params_obj) logging.info( "mqtt_s3.send_message: S3+MQTT msg sent, s3 message key = %s" @@ -244,17 +386,36 @@ def send_message(self, msg: Message): raise Exception("This is only used for the server") def send_message_json(self, topic_name, json_message): + """ + Send a JSON message using MQTT. + + Args: + topic_name (str): The topic to send the message to. + json_message: The JSON message to send. + """ self.mqtt_mgr.send_message_json(topic_name, json_message) def handle_receive_message(self): + """ + Start handling received messages by running the MQTT loop. + """ self.run_loop_forever() def stop_receive_message(self): + """ + Stop receiving messages and disconnect from MQTT. + """ logging.info("mqtt_s3.stop_receive_message: stopping...") self.mqtt_mgr.loop_stop() self.mqtt_mgr.disconnect() def set_config_from_file(self, config_file_path): + """ + Load MQTT configuration from a file. + + Args: + config_file_path (str): Path to the configuration file. + """ try: with open(config_file_path, "r") as f: config = yaml.load(f, Loader=yaml.FullLoader) @@ -270,6 +431,12 @@ def set_config_from_file(self, config_file_path): pass def set_config_from_objects(self, mqtt_config): + """ + Set MQTT configuration from objects. + + Args: + mqtt_config: MQTT configuration object. + """ self.broker_host = mqtt_config["BROKER_HOST"] self.broker_port = mqtt_config["BROKER_PORT"] self.mqtt_user = None @@ -280,27 +447,49 @@ def set_config_from_objects(self, mqtt_config): self.mqtt_pwd = mqtt_config["MQTT_PWD"] def _notify_connection_ready(self): + """ + Notify observers that the connection is ready. + """ msg_params = Message() msg_type = CommunicationConstants.MSG_TYPE_CONNECTION_IS_READY for observer in self._observers: observer.receive_message(msg_type, msg_params) def callback_client_last_will_msg(self, topic, payload): + """ + Callback for client last will message. + + Args: + topic (str): MQTT topic. + payload: The payload of the message. + """ msg = json.loads(payload) edge_id = msg.get("ID", None) - status = msg.get("status", CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE) + status = msg.get( + "status", CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE) if edge_id is not None and status == CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE: if self.client_active_list.get(edge_id, None) is not None: self.client_active_list.pop(edge_id) def callback_client_active_msg(self, topic, payload): + """ + Callback for client active message. + + Args: + topic (str): MQTT topic. + payload: The payload of the message. 
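# set_config_from_file above reads BROKER_HOST and BROKER_PORT (plus optional
# MQTT_USER and MQTT_PWD) from a YAML file. A sketch of a matching config file
# and the load step; the host and credentials below are placeholders:
import tempfile
import yaml

config_text = """
BROKER_HOST: broker.example.com
BROKER_PORT: 1883
MQTT_USER: fedml
MQTT_PWD: change-me
"""

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(config_text)
    config_path = f.name

with open(config_path, "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)   # same loader the manager uses
assert config["BROKER_PORT"] == 1883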
+ """ msg = json.loads(payload) edge_id = msg.get("ID", None) - status = msg.get("status", CommunicationConstants.MSG_CLIENT_STATUS_IDLE) + status = msg.get( + "status", CommunicationConstants.MSG_CLIENT_STATUS_IDLE) if edge_id is not None: self.client_active_list[edge_id] = status def subscribe_client_status_message(self): + """ + Subscribe to client status messages. + """ # Setup MQTT message listener to the last will message form the client. self.mqtt_mgr.add_message_listener(CommunicationConstants.CLIENT_TOP_LAST_WILL_MSG, self.callback_client_last_will_msg) @@ -310,7 +499,22 @@ def subscribe_client_status_message(self): self.callback_client_active_msg) def get_client_status(self, client_id): + """ + Get the status of a specific client. + + Args: + client_id: The ID of the client. + + Returns: + str: The status of the client. + """ return self.client_active_list.get(client_id, CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE) def get_client_list_status(self): - return self.client_active_list \ No newline at end of file + """ + Get the status of all clients. + + Returns: + dict: A dictionary containing the status of all clients. + """ + return self.client_active_list From fe18ff077970fa18f0519d94f578c22914b7e789 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Tue, 26 Sep 2023 22:03:23 +0530 Subject: [PATCH 34/70] Update mqtt_manager.py --- .../communication/mqtt/mqtt_manager.py | 188 ++++++++++++++++++ 1 file changed, 188 insertions(+) diff --git a/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py b/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py index 18c1064165..c4e1492e95 100644 --- a/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py +++ b/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py @@ -164,6 +164,15 @@ def send_message_json(self, topic, message, publish_single_message=False): return True def on_connect(self, client, userdata, flags, rc): + """ + Callback function for the MQTT on_connect event. + + Args: + client: The MQTT client instance. + userdata: User data. + flags: Connection flags. + rc: Return code from the MQTT broker. + """ if rc == 0: client.connected_flag = True client.bad_conn_flag = False @@ -200,16 +209,43 @@ def on_connect(self, client, userdata, flags, rc): self.mqtt_connection_id, rc)) def is_connected(self): + """ + Check if the MQTT client is connected. + + Returns: + bool: True if the client is connected, False otherwise. + """ return self._client.is_connected() def subscribe_will_set_msg(self, client): + """ + Subscribe to the last will message topic and set a callback. + + Args: + client: The MQTT client instance. + """ self.add_message_listener(self.last_will_topic, self.callback_will_set_msg) client.subscribe(self.last_will_topic, qos=2) def callback_will_set_msg(self, topic, payload): + """ + Callback function for handling the last will message. + + Args: + topic (str): The MQTT topic. + payload (str): The message payload. + """ logging.info(f"MQTT client will be disconnected, id: {self._client_id}, topic: {topic}, payload: {payload}") def on_message(self, client, userdata, msg): + """ + Callback function for the MQTT on_message event. + + Args: + client: The MQTT client instance. + userdata: User data. + msg: The received MQTT message. 
+ """ # logging.info("on_message: msg.topic {}, msg.retain {}".format(msg.topic, msg.retain)) if msg.retain: @@ -230,98 +266,250 @@ def on_message(self, client, userdata, msg): MLOpsProfilerEvent.log_to_wandb({"BusyTime": time.time() - message_handler_start_time}) def on_publish(self, client, obj, mid): + """ + Callback function for the MQTT on_publish event. + + Args: + client: The MQTT client instance. + obj: Object. + mid: Message ID. + """ self.callback_published_listener(client) def on_disconnect(self, client, userdata, rc): + """ + Callback function for the MQTT on_disconnect event. + + Args: + client: The MQTT client instance. + userdata: User data. + rc: Return code from the MQTT broker. + """ client.connected_flag = False client.bad_conn_flag = True self.callback_disconnected_listener(client) def _on_subscribe(self, client, userdata, mid, granted_qos): + """ + Callback function for the MQTT on_subscribe event. + + Args: + client: The MQTT client instance. + userdata: User data. + mid: Message ID. + granted_qos: Granted QoS levels. + """ self.callback_subscribed_listener(client) def _on_log(self, client, userdata, level, buf): + """ + Callback function for MQTT logging. + + Args: + client: The MQTT client instance. + userdata: User data. + level: Logging level. + buf: Log message buffer. + """ logging.info("mqtt log {}, client id {}.".format(buf, self.mqtt_connection_id)) def add_message_listener(self, topic, listener): + """ + Add a message listener to handle messages received on a specific topic. + + Args: + topic (str): The MQTT topic to listen to. + listener (callable): The callback function to handle the received messages. + """ self._listeners[topic] = listener def remove_message_listener(self, topic): + """ + Remove a message listener for a specific topic. + + Args: + topic (str): The MQTT topic to remove the listener from. + """ try: del self._listeners[topic] except Exception as e: pass def add_message_passthrough_listener(self, listener): + """ + Add a message passthrough listener to handle all incoming messages. + + Args: + listener (callable): The callback function to handle incoming messages. + """ + # if not callable(listener): + # raise Exception("listener must be callable!") + # self.__message_passthrough_listener = listener self.remove_message_passthrough_listener(listener) self._passthrough_listeners.append(listener) def remove_message_passthrough_listener(self, listener): + """ + Remove a message passthrough listener. + + Args: + listener (callable): The passthrough listener to remove. + """ + # if hasattr(self,'__message_passthrough_listener') and \ + # self.__message_passthrough_listener is not None: + # self.__message_passthrough_listener = None + # if isinstance(listener,(list)): + # for l in listener: + # self._passthrough_listeners.remove(l) + # else: + # self._passthrough_listeners.remove(listener) try: self._passthrough_listeners.remove(listener) except Exception as e: pass def add_connected_listener(self, listener): + """ + Add a listener to handle the MQTT client's connection event. + + Args: + listener (callable): The callback function to handle the connection event. + """ self._connected_listeners.append(listener) def remove_connected_listener(self, listener): + """ + Remove a connected listener. + + Args: + listener (callable): The connected listener to remove. + """ try: self._connected_listeners.remove(listener) except Exception as e: pass def callback_connected_listener(self, client): + """ + Callback function for handling connected listeners. 
+ + Args: + client: The MQTT client instance. + """ for listener in self._connected_listeners: if listener is not None and callable(listener): listener(client) def add_disconnected_listener(self, listener): + """ + Add a listener to handle the MQTT client's disconnection event. + + Args: + listener (callable): The callback function to handle the disconnection event. + """ self._disconnected_listeners.append(listener) def remove_disconnected_listener(self, listener): + """ + Remove a disconnected listener. + + Args: + listener (callable): The disconnected listener to remove. + """ try: self._disconnected_listeners.remove(listener) except Exception as e: pass def callback_disconnected_listener(self, client): + """ + Callback function for handling disconnected listeners. + + Args: + client: The MQTT client instance. + """ for listener in self._disconnected_listeners: if listener is not None and callable(listener): listener(client) def add_subscribed_listener(self, listener): + """ + Add a listener to handle the MQTT client's subscription event. + + Args: + listener (callable): The callback function to handle the subscription event. + """ self._subscribed_listeners.append(listener) def remove_subscribed_listener(self, listener): + """ + Remove a subscribed listener. + + Args: + listener (callable): The subscribed listener to remove. + """ try: self._subscribed_listeners.remove(listener) except Exception as e: pass def callback_subscribed_listener(self, client): + """ + Callback function for handling subscribed listeners. + + Args: + client: The MQTT client instance. + """ for listener in self._subscribed_listeners: if listener is not None and callable(listener): listener(client) def add_published_listener(self, listener): + """ + Add a listener to handle the MQTT client's message publishing event. + + Args: + listener (callable): The callback function to handle the publishing event. + """ self._published_listeners.append(listener) def remove_published_listener(self, listener): + """ + Remove a published listener. + + Args: + listener (callable): The published listener to remove. + """ try: self._published_listeners.remove(listener) except Exception as e: pass def callback_published_listener(self, client): + """ + Callback function for handling published listeners. + + Args: + client: The MQTT client instance. + """ for listener in self._published_listeners: if listener is not None and callable(listener): listener(client) def subscribe_msg(self, topic): + """ + Subscribe to an MQTT topic with a QoS level of 2. + + Args: + topic (str): The MQTT topic to subscribe to. + """ self._client.subscribe(topic, qos=2) def check_connection(self): + """ + Check the MQTT client's connection status and wait for a connection if not connected. + Raises an exception if the connection fails. 
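# check_connection (whose body follows) spins on the connection flags that the
# on_connect callback flips, giving up after a bounded number of attempts. A
# dependency-free sketch of that wait-for-flag pattern with a dummy client:
import time

class DummyClient:
    connected_flag = False
    bad_conn_flag = True

def wait_for_connection(client, max_tries=30, interval=0.01):
    # Poll until on_connect marks the client connected; raise if the broker
    # never becomes reachable.
    for _ in range(max_tries):
        if client.connected_flag and not client.bad_conn_flag:
            return
        time.sleep(interval)
    raise ConnectionError("MQTT broker not reachable after %d tries" % max_tries)

client = DummyClient()
client.connected_flag, client.bad_conn_flag = True, False
wait_for_connection(client)   # returns immediately once the flags are set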
+ """ count = 0 while not self._client.connected_flag and self._client.bad_conn_flag: if count >= 30: From 541703f633287933df77cbf969d2f83bb88de1ff Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Wed, 27 Sep 2023 12:01:31 +0530 Subject: [PATCH 35/70] Update mqtt_manager.py --- .../distributed/communication/mqtt/mqtt_manager.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py b/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py index c4e1492e95..d092389bd4 100644 --- a/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py +++ b/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py @@ -530,12 +530,26 @@ def check_connection(self): def test_msg_callback(topic, payload): + """ + Callback function to handle received MQTT messages for testing purposes. + + Args: + topic (str): The MQTT topic on which the message was received. + payload (str): The payload of the received message. + """ global received_msg_count received_msg_count += 1 logging.info("Received the topic: {}, message: {}, count {}.".format(topic, payload, received_msg_count)) def test_last_will_callback(topic, payload): + """ + Callback function to handle last will messages received for testing purposes. + + Args: + topic (str): The MQTT topic on which the last will message was received. + payload (str): The payload of the received last will message. + """ logging.info("Received the topic: {}, message: {}.".format(topic, payload)) From 803f21e2f3778ad8e317839e61e4023cfb2337a8 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Fri, 1 Sep 2023 12:58:44 +0530 Subject: [PATCH 36/70] Update data_loader.py --- research/SpreadGNN/data/data_loader.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/research/SpreadGNN/data/data_loader.py b/research/SpreadGNN/data/data_loader.py index efbc5ba995..b85a812fec 100644 --- a/research/SpreadGNN/data/data_loader.py +++ b/research/SpreadGNN/data/data_loader.py @@ -102,6 +102,22 @@ def create_non_uniform_split(args, idxs, client_number, is_train=True): def partition_data_by_sample_size( args, path, client_number, uniform=True, compact=True ): + """ + Partition dataset into multiple clients based on sample size. + + Args: + args (list): Arguments. + path (str): Path to the dataset. + client_number (int): Number of clients to partition the dataset into. + uniform (bool, optional): If True, create uniform partitions. If False, create non-uniform partitions. + compact (bool, optional): Whether to use compact representation. + + Returns: + tuple: A tuple containing global_data_dict and partition_dicts. + + global_data_dict (dict): A dictionary containing global datasets (train, val, test). + partition_dicts (list): A list of dictionaries containing partitioned datasets for each client. 
+ """ ( train_adj_matrices, train_feature_matrices, From f1ce7861c98238df90edac5ec8d51e5562621fe4 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Wed, 6 Sep 2023 12:15:05 +0530 Subject: [PATCH 37/70] code and docstring update --- research/SpreadGNN/data/data_loader.py | 163 ++++++++++++++++++ research/SpreadGNN/data/datasets.py | 58 ++++++- research/SpreadGNN/data/utils.py | 50 ++++++ research/SpreadGNN/model/gat_readout.py | 111 +++++++++++- research/SpreadGNN/model/sage_readout.py | 62 ++++++- .../SpreadGNN/trainer/gat_readout_trainer.py | 55 ++++++ .../trainer/gat_readout_trainer_regression.py | 65 +++++++ .../SpreadGNN/trainer/sage_readout_trainer.py | 54 ++++++ .../sage_readout_trainer_regression.py | 53 ++++++ 9 files changed, 662 insertions(+), 9 deletions(-) diff --git a/research/SpreadGNN/data/data_loader.py b/research/SpreadGNN/data/data_loader.py index b85a812fec..889da69c37 100644 --- a/research/SpreadGNN/data/data_loader.py +++ b/research/SpreadGNN/data/data_loader.py @@ -10,6 +10,21 @@ def get_data(path): + """ + Load data from the specified path. + + Args: + path (str): The path to the directory containing data files. + + Returns: + tuple: A tuple containing the following elements: + - adj_matrices (list): A list of adjacency matrices. + - feature_matrices (list): A list of feature matrices. + - labels (numpy.ndarray): An array of labels. + + Raises: + FileNotFoundError: If any of the required data files are not found. + """ with open(path + "/adjacency_matrices.pkl", "rb") as f: adj_matrices = pickle.load(f) @@ -21,6 +36,27 @@ def get_data(path): return adj_matrices, feature_matrices, labels def create_random_split(path): + """ + Create a random 80/10/10 split of data from the specified path. + + Args: + path (str): The path to the directory containing data files. + + Returns: + tuple: A tuple containing the following elements for training, validation, and testing sets: + - train_adj_matrices (list): A list of adjacency matrices for training. + - train_feature_matrices (list): A list of feature matrices for training. + - train_labels (list): A list of labels for training. + - val_adj_matrices (list): A list of adjacency matrices for validation. + - val_feature_matrices (list): A list of feature matrices for validation. + - val_labels (list): A list of labels for validation. + - test_adj_matrices (list): A list of adjacency matrices for testing. + - test_feature_matrices (list): A list of feature matrices for testing. + - test_labels (list): A list of labels for testing. + + Raises: + FileNotFoundError: If any of the required data files are not found. + """ adj_matrices, feature_matrices, labels = get_data(path) # Random 80/10/10 split as suggested in the MoleculeNet whitepaper @@ -74,6 +110,27 @@ def create_random_split(path): ) def create_non_uniform_split(args, idxs, client_number, is_train=True): + """ + Create a non-uniform split of data indices among clients based on the Dirichlet distribution. + + Args: + args: An object containing relevant parameters. + idxs (list): A list of data indices to be split. + client_number (int): The number of clients. + is_train (bool): A flag indicating whether the split is for training data. + + Returns: + list: A list of lists where each sublist contains data indices assigned to a client. + + Logging: + This function logs information about the data split process. + + Note: + This function relies on the `partition_class_samples_with_dirichlet_distribution` function. 
+ + Raises: + None + """ logging.info("create_non_uniform_split------------------------------------------") N = len(idxs) alpha = args.partition_alpha @@ -249,6 +306,25 @@ def partition_data_by_sample_size( # For centralized training def get_dataloader(path, compact=True, normalize_features=False, normalize_adj=False): + """ + Get data loaders for training, validation, and testing sets. + + Args: + path (str): The path to the directory containing data files. + compact (bool, optional): Whether to use compact data format. Defaults to True. + normalize_features (bool, optional): Whether to normalize features. Defaults to False. + normalize_adj (bool, optional): Whether to normalize adjacency matrices. Defaults to False. + + Returns: + tuple: A tuple containing data loaders for training, validation, and testing sets. + + Note: + This function utilizes the `MoleculesDataset` class and data collators to create data loaders. + Each batch size is set to 1 to ensure that each batch represents an entire molecule. + + Raises: + None + """ ( train_adj_matrices, train_feature_matrices, @@ -318,6 +394,39 @@ def load_partition_data( normalize_features=False, normalize_adj=False, ): + """ + Load and partition data for federated learning among multiple clients. + + Args: + args: An object containing relevant parameters. + path (str): The path to the directory containing data files. + client_number (int): The number of clients. + uniform (bool, optional): Whether to use uniform data partitioning. Defaults to True. + global_test (bool, optional): Whether to use a global test dataset. Defaults to True. + compact (bool, optional): Whether to use compact data format. Defaults to True. + normalize_features (bool, optional): Whether to normalize features. Defaults to False. + normalize_adj (bool, optional): Whether to normalize adjacency matrices. Defaults to False. + + Returns: + tuple: A tuple containing information about the loaded data for federated learning. The tuple includes: + - train_data_num (int): Total number of training samples in the global dataset. + - val_data_num (int): Total number of validation samples in the global dataset. + - test_data_num (int): Total number of testing samples in the global dataset. + - train_data_global (data.DataLoader): DataLoader for the global training dataset. + - val_data_global (data.DataLoader): DataLoader for the global validation dataset. + - test_data_global (data.DataLoader): DataLoader for the global testing dataset. + - data_local_num_dict (dict): A dictionary mapping client IDs to the number of local samples. + - train_data_local_dict (dict): A dictionary mapping client IDs to their DataLoader for training. + - val_data_local_dict (dict): A dictionary mapping client IDs to their DataLoader for validation. + - test_data_local_dict (dict): A dictionary mapping client IDs to their DataLoader for testing. + + Note: + This function relies on data partitioning using the `partition_data_by_sample_size` function. + Each batch size is set to 1 to represent each molecule as a batch. + + Raises: + None + """ global_data_dict, partition_dicts = partition_data_by_sample_size( args, path, client_number, uniform, compact=compact ) @@ -415,6 +524,33 @@ def load_partition_data( def load_partition_data_distributed(process_id, path, client_number, uniform=True): + """ + Load and partition data for distributed federated learning. + + Args: + process_id (int): The ID of the current process. + path (str): The path to the directory containing data files. 
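# get_dataloader above wraps each split in a DataLoader with batch_size=1 so
# every batch is a single molecule, densified by a collator. A minimal torch
# sketch of that setup with a toy dataset and collate function:
import torch
from torch.utils import data

class ToyMolecules(data.Dataset):
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        # (adjacency, features, label) for a 3-atom "molecule"
        return torch.eye(3), torch.randn(3, 4), torch.tensor([float(i % 2)])

def molecule_collator(batch):
    # batch_size=1, so batch holds exactly one molecule tuple.
    adj, feats, label = batch[0]
    return adj, feats, label

loader = data.DataLoader(ToyMolecules(5), batch_size=1, shuffle=True,
                         collate_fn=molecule_collator)
adj, feats, label = next(iter(loader))
assert adj.shape == (3, 3) and feats.shape == (3, 4)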
+ client_number (int): The number of clients. + uniform (bool, optional): Whether to use uniform data partitioning. Defaults to True. + + Returns: + tuple: A tuple containing information about the loaded data for distributed federated learning. The tuple includes: + - train_data_num (int): Total number of training samples in the global dataset. + - train_data_global (data.DataLoader): DataLoader for the global training dataset (for process_id 0). + - val_data_global (data.DataLoader): DataLoader for the global validation dataset (for process_id 0). + - test_data_global (data.DataLoader): DataLoader for the global testing dataset (for process_id 0). + - local_data_num (int): Total number of local samples for the current process. + - train_data_local (data.DataLoader): DataLoader for the local training dataset (for process_id > 0). + - val_data_local (data.DataLoader): DataLoader for the local validation dataset (for process_id > 0). + - test_data_local (data.DataLoader): DataLoader for the local testing dataset (for process_id > 0). + + Note: + This function relies on data partitioning using the `partition_data_by_sample_size` function. + Each batch size is set to 1 to represent each molecule as a batch. + + Raises: + None + """ global_data_dict, partition_dicts = partition_data_by_sample_size( path, client_number, uniform ) @@ -490,6 +626,33 @@ def load_partition_data_distributed(process_id, path, client_number, uniform=Tru def load_moleculenet(args, dataset_name): + """ + Load a MoleculeNet dataset and partition it for federated learning. + + Args: + args: An object containing relevant parameters (e.g., data directory and partition settings). + dataset_name (str): Name of the MoleculeNet dataset; one of "sider", "tox21", "muv", or "qm8". + + Returns: + tuple: The partitioned federated dataset for the requested MoleculeNet dataset, together with + the number of label categories (num_cats) and the input feature dimension (feat_dim). + + Raises: + Exception: If dataset_name is not one of the supported datasets. + """ num_cats, feat_dim = 0, 0 if dataset_name not in ["sider", "tox21", "muv","qm8" ]: raise Exception("no such dataset!") diff --git a/research/SpreadGNN/data/datasets.py b/research/SpreadGNN/data/datasets.py index a76390403d..443cd76e09 100644 --- a/research/SpreadGNN/data/datasets.py +++ b/research/SpreadGNN/data/datasets.py @@ -13,12 +13,21 @@ # From GTTF, need to cite once paper is officially accepted to ICLR 2021 class CompactAdjacency: def __init__(self, adj, precomputed=None, subset=None): - """Constructs CompactAdjacency. + """ + Constructs a CompactAdjacency object.
Args: - adj: scipy sparse matrix containing full adjacency. - precomputed: If given, must be a tuple (compact_adj, degrees). - In this case, adj must be None. If supplied, subset will be ignored. + adj: scipy sparse matrix containing the full adjacency. + precomputed: If given, must be a tuple (compact_adj, degrees). + In this case, adj must be None. If supplied, subset will be ignored. + subset: Optional set of node indices to consider in the adjacency matrix. + + Note: + This constructor initializes a CompactAdjacency object based on the provided arguments. + If 'precomputed' is provided, 'adj' and 'subset' will be ignored. + + Raises: + ValueError: If both 'adj' and 'precomputed' are set. """ if adj is None: return @@ -114,6 +123,25 @@ def __init__( fanouts=[2, 2], split="train", ): + """ + Constructs a dataset for molecules with adjacency matrices, feature matrices, and labels. + + Args: + adj_matrices (list): A list of adjacency matrices. + feature_matrices (list): A list of feature matrices. + labels (list): A list of labels. + path (str): The path to the directory containing data files. + compact (bool, optional): Whether to use compact adjacency matrices. Defaults to True. + fanouts (list, optional): A list of fanout values for each adjacency matrix. Defaults to [2, 2]. + split (str, optional): The dataset split ('train', 'val', or 'test'). Defaults to 'train'. + + Note: + This constructor initializes a MoleculesDataset object based on the provided arguments. + If 'compact' is set to True, it uses compact adjacency matrices. + + Raises: + None + """ if compact: # filename = path + '/train_comp_adjs.pkl' # if split == 'val': @@ -143,6 +171,19 @@ def __init__( self.fanouts = [fanouts] * len(adj_matrices) def __getitem__(self, index): + """ + Retrieves an item from the dataset. + + Args: + index (int): The index of the item to retrieve. + + Returns: + tuple: A tuple containing the following elements: + - adj_matrix: The adjacency matrix. + - feature_matrix: The feature matrix. + - label: The label. + - fanouts: The list of fanout values. + """ return ( self.adj_matrices[index], self.feature_matrices[index], @@ -151,4 +192,13 @@ def __getitem__(self, index): ) def __len__(self): + """ + Returns the total number of items in the dataset. + + Args: + None + + Returns: + int: The number of items in the dataset. + """ return len(self.adj_matrices) diff --git a/research/SpreadGNN/data/utils.py b/research/SpreadGNN/data/utils.py index 47a41b414a..2fd21df36c 100644 --- a/research/SpreadGNN/data/utils.py +++ b/research/SpreadGNN/data/utils.py @@ -5,6 +5,17 @@ def np_uniform_sample_next(compact_adj, tree, fanout): + """ + Uniformly sample next neighbors for a given compact adjacency matrix and traversal tree. + + Args: + compact_adj (CompactAdjacency): The compact adjacency matrix. + tree (list): The traversal tree. + fanout (int): The number of neighbors to sample for each node. + + Returns: + np.ndarray: An array containing the sampled neighbor indices. + """ last_level = tree[-1] # [batch, f^depth] batch_lengths = compact_adj.degrees[last_level] nodes = np.repeat(last_level, fanout, axis=1) @@ -27,6 +38,21 @@ def np_uniform_sample_next(compact_adj, tree, fanout): def np_traverse( compact_adj, seed_nodes, fanouts=(1,), sample_fn=np_uniform_sample_next ): + """ + Traverse a compact adjacency matrix. + + Args: + compact_adj (CompactAdjacency): The compact adjacency matrix. + seed_nodes (np.ndarray): An array of seed node indices. + fanouts (tuple): A tuple of fanout values. 
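# np_uniform_sample_next above grows the walk forest one level at a time by
# drawing, for every node in the last level, `fanout` neighbors uniformly at
# random. A small numpy sketch of one such step over a toy adjacency list:
import numpy as np

def uniform_sample_next(neighbors, last_level, fanout, rng):
    # last_level: [batch, width] node ids; returns [batch, width * fanout].
    nodes = np.repeat(last_level, fanout, axis=1)
    sampled = np.empty_like(nodes)
    for i, row in enumerate(nodes):
        for j, v in enumerate(row):
            nbrs = neighbors[v]
            sampled[i, j] = nbrs[rng.integers(len(nbrs))]
    return sampled

neighbors = {0: [1, 2], 1: [0, 2], 2: [0, 1, 3], 3: [2]}
rng = np.random.default_rng(0)
seeds = np.array([[0], [3]])
print(uniform_sample_next(neighbors, seeds, fanout=2, rng=rng))   # shape (2, 2)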
+ sample_fn (function): A function for sampling neighbors. + + Returns: + list: A list containing the traversal tree. + + Raises: + ValueError: If the input seed_nodes format is incorrect. + """ if not isinstance(seed_nodes, np.ndarray): raise ValueError("Seed must a numpy array") @@ -53,6 +79,18 @@ def np_traverse( class WalkForestCollator(object): def __init__(self, normalize_features=False): + """ + Collator for walk-forest based molecular data. + + When called on a molecule tuple, it builds the sampled traversal forest and the + associated tensors. + + Args: + normalize_features (bool, optional): Whether to row-normalize node feature + matrices. Defaults to False. + """ self.normalize_features = normalize_features def __call__(self, molecule): @@ -88,6 +126,18 @@ def __call__(self, molecule): class DefaultCollator(object): + """ + Default collate function for molecular data. + + When called, converts a molecule tuple into collated tensors. + + Args: + molecule (tuple): A tuple containing the molecular data. + + Returns: + tuple: A tuple containing collated data. + """ def __init__(self, normalize_features=True, normalize_adj=True): self.normalize_features = normalize_features self.normalize_adj = normalize_adj diff --git a/research/SpreadGNN/model/gat_readout.py b/research/SpreadGNN/model/gat_readout.py index d1cd3a3d0e..51fadf9fab 100644 --- a/research/SpreadGNN/model/gat_readout.py +++ b/research/SpreadGNN/model/gat_readout.py @@ -7,6 +7,22 @@ class GraphAttentionLayer(nn.Module): """ Simple GAT layer, similar to https://arxiv.org/abs/1710.10903 """ + """ + A single Graph Attention Layer (GAT) module. + + Args: + in_features (int): The number of input features. + out_features (int): The number of output features. + dropout (float): Dropout probability for attention coefficients. + alpha (float): LeakyReLU slope parameter. + concat (bool): Whether to concatenate the multi-head results or not. + + Attributes: + W (nn.Parameter): Learnable weight matrix. + a (nn.Parameter): Learnable attention parameter matrix. + leakyrelu (nn.LeakyReLU): LeakyReLU activation with slope alpha. + + """ def __init__(self, in_features, out_features, dropout, alpha, concat=True): super(GraphAttentionLayer, self).__init__() @@ -24,6 +40,17 @@ def __init__(self, in_features, out_features, dropout, alpha, concat=True): self.leakyrelu = nn.LeakyReLU(self.alpha) def forward(self, h, adj): + """ + Forward pass for the GAT layer. + + Args: + h (torch.Tensor): Input feature tensor. + adj (torch.Tensor): Adjacency matrix. + + Returns: + torch.Tensor: Output feature tensor. + + """ Wh = torch.mm( h, self.W ) # h.shape: (N, in_features), Wh.shape: (N, out_features) @@ -52,6 +79,13 @@ def _prepare_attentional_mechanism_input(self, Wh): return all_combinations_matrix.view(N, N, 2 * self.out_features) def __repr__(self): + """ + String representation of the GAT layer. + + Returns: + str: A string representing the layer. + + """ return ( self.__class__.__name__ + " (" @@ -63,6 +97,23 @@ def __repr__(self): class GAT(nn.Module): + """ + Graph Attention Network (GAT) model for node classification. + + Args: + feat_dim (int): Number of input features. + hidden_dim1 (int): Number of hidden units in the first GAT layer. + hidden_dim2 (int): Number of hidden units in the second GAT layer. + dropout (float): Dropout probability for attention coefficients. + alpha (float): LeakyReLU slope parameter. + nheads (int): Number of attention heads. + + Attributes: + dropout (float): Dropout probability. + attentions (nn.ModuleList): List of GAT layers with multiple heads. + out_att (GraphAttentionLayer): Final GAT layer.
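# The attention mechanism above scores each edge as
# e_ij = LeakyReLU(a^T [W h_i || W h_j]), masks non-edges, and softmax-normalizes
# per node. A compact torch sketch of one single-head attention pass:
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, in_feat, out_feat, alpha = 4, 5, 3, 0.2
h = torch.randn(N, in_feat)
adj = torch.tensor([[1., 1., 0., 0.],
                    [1., 1., 1., 0.],
                    [0., 1., 1., 1.],
                    [0., 0., 1., 1.]])
W = torch.randn(in_feat, out_feat)
a = torch.randn(2 * out_feat, 1)

Wh = h @ W                                         # (N, out_feat)
# Pairwise concatenation [Wh_i || Wh_j] for all i, j -> (N, N, 2*out_feat)
pairs = torch.cat([Wh.unsqueeze(1).expand(N, N, out_feat),
                   Wh.unsqueeze(0).expand(N, N, out_feat)], dim=-1)
e = F.leaky_relu(pairs @ a, negative_slope=alpha).squeeze(-1)   # (N, N) edge scores
e = e.masked_fill(adj == 0, float("-inf"))         # only attend over real edges
attention = F.softmax(e, dim=1)
h_prime = attention @ Wh                           # new node embeddings (N, out_feat)
assert h_prime.shape == (N, out_feat)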
+ + """ def __init__(self, feat_dim, hidden_dim1, hidden_dim2, dropout, alpha, nheads): """Dense version of GAT.""" super(GAT, self).__init__() @@ -85,6 +136,17 @@ def __init__(self, feat_dim, hidden_dim1, hidden_dim2, dropout, alpha, nheads): ) def forward(self, x, adj): + """ + Forward pass for the GAT model. + + Args: + x (torch.Tensor): Input feature tensor. + adj (torch.Tensor): Adjacency matrix. + + Returns: + torch.Tensor: Node embeddings. + + """ x = F.dropout(x, self.dropout, training=self.training) x = torch.cat([att(x, adj) for att in self.attentions], dim=1) x = F.dropout(x, self.dropout, training=self.training) @@ -95,7 +157,25 @@ def forward(self, x, adj): class Readout(nn.Module): """ - This module learns a single graph level representation for a molecule given GNN generated node embeddings + This module learns a single graph-level representation for a molecule given GNN-generated node embeddings. + + Args: + attr_dim (int): Dimension of node attributes. + embedding_dim (int): Dimension of node embeddings. + hidden_dim (int): Dimension of the hidden layer. + output_dim (int): Dimension of the output layer. + num_cats (int): Number of categories for classification. + + Attributes: + attr_dim (int): Dimension of node attributes. + hidden_dim (int): Dimension of the hidden layer. + output_dim (int): Dimension of the output layer. + num_cats (int): Number of categories for classification. + layer1 (nn.Linear): First linear layer. + layer2 (nn.Linear): Second linear layer. + output (nn.Linear): Output layer. + act (nn.ReLU): ReLU activation function. + """ def __init__(self, attr_dim, embedding_dim, hidden_dim, output_dim, num_cats): @@ -111,6 +191,17 @@ def __init__(self, attr_dim, embedding_dim, hidden_dim, output_dim, num_cats): self.act = nn.ReLU() def forward(self, node_features, node_embeddings): + """ + Forward pass for the Readout module. + + Args: + node_features (torch.Tensor): Node attributes. + node_embeddings (torch.Tensor): Node embeddings. + + Returns: + torch.Tensor: Logits for multilabel classification. + + """ combined_rep = torch.cat( (node_features, node_embeddings), dim=1 ) # Concat initial node attributed with embeddings from sage @@ -128,7 +219,23 @@ def forward(self, node_features, node_embeddings): class GatMoleculeNet(nn.Module): """ - Network that consolidates GAT + Readout into a single nn.Module + Neural network that combines GAT (Graph Attention Network) and Readout into a single module for molecular data. + + Args: + feat_dim (int): Dimension of input node features. + gat_hidden_dim1 (int): Dimension of the hidden layer in the GAT model. + node_embedding_dim (int): Dimension of node embeddings. + gat_dropout (float): Dropout probability for GAT layers. + gat_alpha (float): LeakyReLU slope parameter for GAT. + gat_nheads (int): Number of attention heads in GAT. + readout_hidden_dim (int): Dimension of the hidden layer in the Readout module. + graph_embedding_dim (int): Dimension of the graph-level embedding. + num_categories (int): Number of categories for classification. + + Attributes: + gat (GAT): GAT (Graph Attention Network) module. + readout (Readout): Readout module for graph-level representation. 
+ """ def __init__( diff --git a/research/SpreadGNN/model/sage_readout.py b/research/SpreadGNN/model/sage_readout.py index e8df8f7235..659cfcdd41 100644 --- a/research/SpreadGNN/model/sage_readout.py +++ b/research/SpreadGNN/model/sage_readout.py @@ -7,7 +7,27 @@ class GraphSage(nn.Module): GraphSAGE model (https://arxiv.org/abs/1706.02216) to learn the role of atoms in the molecules inductively. Transforms input features into a fixed length embedding in a vector space. The embedding captures the role. """ - + """ + GraphSAGE model to learn the role of atoms in molecules inductively. + + GraphSAGE (Graph Sample and Aggregated) transforms input features into a fixed-length embedding in a vector space. + The resulting embedding captures the role of atoms in molecules. + + Args: + feat_dim (int): Dimension of input node features. + hidden_dim1 (int): Dimension of the first hidden layer. + hidden_dim2 (int): Dimension of the second hidden layer. + dropout (float): Dropout probability. + + Attributes: + feat_dim (int): Dimension of input node features. + hidden_dim1 (int): Dimension of the first hidden layer. + hidden_dim2 (int): Dimension of the second hidden layer. + layer1 (nn.Linear): First linear layer for feature transformation. + layer2 (nn.Linear): Second linear layer for feature transformation. + relu (nn.ReLU): ReLU activation function. + dropout (nn.Dropout): Dropout layer. + """ def __init__(self, feat_dim, hidden_dim1, hidden_dim2, dropout): super(GraphSage, self).__init__() @@ -61,7 +81,27 @@ def forward(self, forest, feature_matrix): class Readout(nn.Module): """ - This module learns a single graph level representation for a molecule given GraphSAGE generated embeddings + This module learns a single graph-level representation for a molecule using GraphSAGE-generated embeddings. + + The Readout module combines node-level features and GraphSAGE-generated embeddings to produce a graph-level representation of a molecule. This representation can be used for various downstream tasks, such as multi-label classification. + + Args: + attr_dim (int): Dimension of initial node attributes. + embedding_dim (int): Dimension of GraphSAGE-generated node embeddings. + hidden_dim (int): Dimension of the hidden layer. + output_dim (int): Dimension of the output layer. + num_cats (int): Number of categories for classification. + + Attributes: + attr_dim (int): Dimension of initial node attributes. + hidden_dim (int): Dimension of the hidden layer. + output_dim (int): Dimension of the output layer. + num_cats (int): Number of categories for classification. + layer1 (nn.Linear): First linear layer for feature transformation. + layer2 (nn.Linear): Second linear layer for feature transformation. + output (nn.Linear): Output linear layer. + act (nn.ReLU): ReLU activation function. + """ def __init__(self, attr_dim, embedding_dim, hidden_dim, output_dim, num_cats): @@ -94,7 +134,23 @@ def forward(self, node_features, node_embeddings): class SageMoleculeNet(nn.Module): """ - Network that consolidates Sage + Readout into a single nn.Module + Network that combines Sage (GraphSAGE) and Readout into a single neural network module. + + The SageMoleculeNet module integrates GraphSAGE for node-level embedding generation and a Readout module for graph-level representation of molecules. It is designed for tasks such as multi-label classification on molecular graphs. + + Args: + feat_dim (int): Dimension of node features. + sage_hidden_dim1 (int): Dimension of the first hidden layer in GraphSAGE. 
+ node_embedding_dim (int): Dimension of node embeddings generated by GraphSAGE. + sage_dropout (float): Dropout rate for GraphSAGE. + readout_hidden_dim (int): Dimension of the hidden layer in the Readout module. + graph_embedding_dim (int): Dimension of the graph-level embeddings. + num_categories (int): Number of categories for classification. + + Attributes: + sage (GraphSage): GraphSAGE module for node-level embedding generation. + readout (Readout): Readout module for generating graph-level representations. + """ def __init__( diff --git a/research/SpreadGNN/trainer/gat_readout_trainer.py b/research/SpreadGNN/trainer/gat_readout_trainer.py index 989a102bff..74f8895e37 100755 --- a/research/SpreadGNN/trainer/gat_readout_trainer.py +++ b/research/SpreadGNN/trainer/gat_readout_trainer.py @@ -12,14 +12,46 @@ class GatMoleculeNetTrainer(ClientTrainer): + """ + Trainer for the GatMoleculeNet model. + + This trainer is responsible for training and testing the GatMoleculeNet model on client devices. It implements methods for setting and retrieving model parameters, training the model, and evaluating its performance. + + Args: + model (GatMoleculeNet): The GatMoleculeNet model to be trained. + test_data (list of torch.Tensor): The test data for evaluating model performance. + """ def get_model_params(self): + """ + Get the model parameters. + + Returns: + dict: The model parameters. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (dict): The model parameters to be set. + """ logging.info("set_model_params") self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model. + + Args: + train_data (list of torch.Tensor): The training data. + device (torch.device): The device (CPU or GPU) to use for training. + args: Additional training arguments. + + Returns: + Tuple[float, dict]: A tuple containing the maximum test score and the best model parameters. + """ model = self.model model.to(device) @@ -85,6 +117,17 @@ def train(self, train_data, device, args): return max_test_score, best_model_params def test(self, test_data, device, args): + """ + Test the model. + + Args: + test_data (list of torch.Tensor): The test data. + device (torch.device): The device (CPU or GPU) to use for testing. + args: Additional testing arguments. + + Returns: + Tuple[float, model]: A tuple containing the test score and the model used for testing. + """ logging.info("----------test--------") model = self.model model.eval() @@ -138,6 +181,18 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Test the model on the server. + + Args: + train_data_local_dict (dict): A dictionary of training data for each client. + test_data_local_dict (dict): A dictionary of test data for each client. + device (torch.device): The device (CPU or GPU) to use for testing. + args: Additional testing arguments. + + Returns: + bool: True if testing on the server is successful. 
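+
+        Note:
+            A sketch of the aggregation carried out below, assuming
+            `score_list` collects one test score per client (the list is
+            initialized in the body of this method):
+
+                avg_score = np.mean(np.array(score_list))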
+ """ logging.info("----------test_on_the_server--------") model_list, score_list = [], [] diff --git a/research/SpreadGNN/trainer/gat_readout_trainer_regression.py b/research/SpreadGNN/trainer/gat_readout_trainer_regression.py index 88b66c5418..703c9bf22e 100755 --- a/research/SpreadGNN/trainer/gat_readout_trainer_regression.py +++ b/research/SpreadGNN/trainer/gat_readout_trainer_regression.py @@ -11,14 +11,46 @@ class GatMoleculeNetTrainer(ClientTrainer): + """ + Trainer for the GatMoleculeNet model. + + This trainer is responsible for training and testing the GatMoleculeNet model on client devices. It implements methods for setting and retrieving model parameters, training the model, and evaluating its performance. + + Args: + model (GatMoleculeNet): The GatMoleculeNet model to be trained. + test_data (list of torch.Tensor): The test data for evaluating model performance. + """ def get_model_params(self): + """ + Get the model parameters. + + Returns: + dict: The model parameters. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (dict): The model parameters to be set. + """ logging.info("set_model_params") self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model. + + Args: + train_data (list of torch.Tensor): The training data. + device (torch.device): The device (CPU or GPU) to use for training. + args: Additional training arguments. + + Returns: + Tuple[float, dict]: A tuple containing the minimum test score and the best model parameters. + """ model = self.model model.to(device) @@ -93,6 +125,17 @@ def train(self, train_data, device, args): return min_score, best_model_params def test(self, test_data, device, args): + """ + Test the model. + + Args: + test_data (list of torch.Tensor): The test data. + device (torch.device): The device (CPU or GPU) to use for testing. + args: Additional testing arguments. + + Returns: + Tuple[float, model]: A tuple containing the test score and the model used for testing. + """ logging.info("----------test--------") model = self.model model.eval() @@ -129,6 +172,18 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Test the model on the server. + + Args: + train_data_local_dict (dict): A dictionary of training data for each client. + test_data_local_dict (dict): A dictionary of test data for each client. + device (torch.device): The device (CPU or GPU) to use for testing. + args: Additional testing arguments. + + Returns: + bool: True if testing on the server is successful. + """ logging.info("----------test_on_the_server--------") # for client_idx in train_data_local_dict.keys(): # train_data = train_data_local_dict[client_idx] @@ -158,6 +213,16 @@ def test_on_the_server( return True def _compare_models(self, model_1, model_2): + """ + Compare two models to check if they match. + + Args: + model_1 (torch.nn.Module): The first model to compare. + model_2 (torch.nn.Module): The second model to compare. + + Raises: + Exception: If a mismatch is found between the two models. 
+ """ models_differ = 0 for key_item_1, key_item_2 in zip( model_1.state_dict().items(), model_2.state_dict().items() diff --git a/research/SpreadGNN/trainer/sage_readout_trainer.py b/research/SpreadGNN/trainer/sage_readout_trainer.py index c5ec0a3f57..1e3bea9174 100755 --- a/research/SpreadGNN/trainer/sage_readout_trainer.py +++ b/research/SpreadGNN/trainer/sage_readout_trainer.py @@ -12,14 +12,45 @@ class SageMoleculeNetTrainer(ClientTrainer): + """ + Trainer for the MoleculeNet model. This trainer handles training and testing the model on client devices. + + Args: + model (nn.Module): The MoleculeNet model to be trained. + test_data (list): The test data used for evaluating model performance. + """ def get_model_params(self): + """ + Get the model parameters. + + Returns: + dict: The model parameters. + """ + return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (dict): The model parameters to be set. + """ logging.info("set_model_params") self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model. + + Args: + train_data (list): The training data. + device (torch.device): The device (CPU or GPU) to use for training. + args: Additional training arguments. + + Returns: + Tuple[float, dict]: A tuple containing the minimum test score and the best model parameters. + """ model = self.model model.to(device) @@ -87,6 +118,17 @@ def train(self, train_data, device, args): return max_test_score, best_model_params def test(self, test_data, device, args): + """ + Test the model. + + Args: + test_data (list): The test data. + device (torch.device): The device (CPU or GPU) to use for testing. + args: Additional testing arguments. + + Returns: + Tuple[float, model]: A tuple containing the test score and the model used for testing. + """ logging.info("----------test--------") model = self.model model.eval() @@ -138,6 +180,18 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Test the model on the server using data from client devices. + + Args: + train_data_local_dict (dict): A dictionary of training data from client devices. + test_data_local_dict (dict): A dictionary of test data from client devices. + device (torch.device): The device (CPU or GPU) to use for testing. + args: Additional testing arguments. + + Returns: + bool: True if the testing is successful. + """ logging.info("----------test_on_the_server--------") model_list, score_list = [], [] diff --git a/research/SpreadGNN/trainer/sage_readout_trainer_regression.py b/research/SpreadGNN/trainer/sage_readout_trainer_regression.py index b0977b1086..4248427349 100755 --- a/research/SpreadGNN/trainer/sage_readout_trainer_regression.py +++ b/research/SpreadGNN/trainer/sage_readout_trainer_regression.py @@ -12,14 +12,44 @@ class SageMoleculeNetTrainer(ClientTrainer): + """ + Trainer for the MoleculeNet model. This trainer is responsible for training and testing the MoleculeNet model on client devices. + + Args: + model (nn.Module): The MoleculeNet model to be trained. + test_data (list): The test data for evaluating model performance. + """ def get_model_params(self): + """ + Get the model parameters. + + Returns: + dict: The model parameters. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. 
+ + Args: + model_parameters (dict): The model parameters to be set. + """ logging.info("set_model_params") self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model. + + Args: + train_data (list): The training data. + device (torch.device): The device (CPU or GPU) to use for training. + args: Additional training arguments. + + Returns: + Tuple[float, dict]: A tuple containing the minimum test score and the best model parameters. + """ model = self.model model.to(device) @@ -94,6 +124,17 @@ def train(self, train_data, device, args): return min_score, best_model_params def test(self, test_data, device, args): + """ + Test the model. + + Args: + test_data (list): The test data. + device (torch.device): The device (CPU or GPU) to use for testing. + args: Additional testing arguments. + + Returns: + Tuple[float, model]: A tuple containing the test score and the model used for testing. + """ logging.info("----------test--------") model = self.model model.eval() @@ -131,6 +172,18 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Test the model on the server. + + Args: + train_data_local_dict (dict): A dictionary of training data for each client. + test_data_local_dict (dict): A dictionary of test data for each client. + device (torch.device): The device (CPU or GPU) to use for testing. + args: Additional testing arguments. + + Returns: + bool: True if testing on the server is successful. + """ logging.info("----------test_on_the_server--------") # for client_idx in train_data_local_dict.keys(): # train_data = train_data_local_dict[client_idx] From a958d9e1232150afc5569ab2d33b00191aaf5632 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Wed, 6 Sep 2023 13:42:55 +0530 Subject: [PATCH 38/70] additon --- python/fedml/arguments.py | 49 +++++++++++++++ python/fedml/launch_simulation.py | 6 +- python/fedml/runner.py | 92 ++++++++++++++++++++++++++++ python/fedml/simulation/simulator.py | 37 +++++++++++ 4 files changed, 183 insertions(+), 1 deletion(-) diff --git a/python/fedml/arguments.py b/python/fedml/arguments.py index 6459b6c2c4..16d4ff25da 100755 --- a/python/fedml/arguments.py +++ b/python/fedml/arguments.py @@ -34,6 +34,12 @@ def add_args(): + """ + Create and parse command line arguments for FedML. + + Returns: + argparse.Namespace: A namespace containing the parsed arguments. + """ parser = argparse.ArgumentParser(description="FedML") parser.add_argument( "--yaml_config_file", @@ -76,6 +82,15 @@ class Arguments: """Argument class which contains all arguments from yaml config and constructs additional arguments""" def __init__(self, cmd_args, training_type=None, comm_backend=None, override_cmd_args=True): + """ + Initialize the Arguments class. + + Args: + cmd_args (argparse.Namespace): Command line arguments. + training_type (str, optional): The training platform type. Defaults to None. + comm_backend (str, optional): The communication backend type. Defaults to None. + override_cmd_args (bool, optional): Whether to override command line arguments. Defaults to True. + """ # set the command line arguments cmd_args_dict = cmd_args.__dict__ for arg_key, arg_val in cmd_args_dict.items(): @@ -87,6 +102,16 @@ def __init__(self, cmd_args, training_type=None, comm_backend=None, override_cmd for arg_key, arg_val in cmd_args_dict.items(): setattr(self, arg_key, arg_val) def load_yaml_config(self, yaml_path): + """ + Load a YAML configuration file. 
+ + Args: + yaml_path (str): Path to the YAML configuration file. + + Returns: + dict: Loaded configuration as a dictionary. + """ + try: with open(yaml_path, "r") as stream: try: @@ -97,6 +122,14 @@ def load_yaml_config(self, yaml_path): return None def get_default_yaml_config(self, cmd_args, training_type=None, comm_backend=None): + """ + Set default YAML configuration based on training type and communication backend. + + Args: + cmd_args (argparse.Namespace): Command line arguments. + training_type (str, optional): The training platform type. Defaults to None. + comm_backend (str, optional): The communication backend type. Defaults to None. + """ if cmd_args.yaml_config_file == "": path_current_file = path.abspath(path.dirname(__file__)) if ( @@ -191,12 +224,28 @@ def get_default_yaml_config(self, cmd_args, training_type=None, comm_backend=Non return configuration def set_attr_from_config(self, configuration): + """ + Set class attributes from a configuration dictionary. + + Args: + configuration (dict): Configuration dictionary. + """ for _, param_family in configuration.items(): for key, val in param_family.items(): setattr(self, key, val) def load_arguments(training_type=None, comm_backend=None): + """ + Load arguments from command line and YAML config file. + + Args: + training_type (str, optional): The training platform type. Defaults to None. + comm_backend (str, optional): The communication backend type. Defaults to None. + + Returns: + argparse.Namespace: Parsed arguments. + """ cmd_args = add_args() # Load all arguments from YAML config file args = Arguments(cmd_args, training_type, comm_backend) diff --git a/python/fedml/launch_simulation.py b/python/fedml/launch_simulation.py index 37335a2753..b8ca2cdfdf 100644 --- a/python/fedml/launch_simulation.py +++ b/python/fedml/launch_simulation.py @@ -7,8 +7,12 @@ def run_simulation(backend=FEDML_SIMULATION_TYPE_SP): + """ + Run a simulation of the FedML Parrot. - """FedML Parrot""" + Args: + backend (str): The communication backend to use for the simulation. Defaults to FEDML_SIMULATION_TYPE_SP. + """ fedml._global_training_type = FEDML_TRAINING_PLATFORM_SIMULATION fedml._global_comm_backend = backend diff --git a/python/fedml/runner.py b/python/fedml/runner.py index d536bab3cc..214bd8d659 100644 --- a/python/fedml/runner.py +++ b/python/fedml/runner.py @@ -17,6 +17,18 @@ class FedMLRunner: + """ + The main runner for different Federated Learning scenarios. + + Args: + args: The command line arguments. + device: The device (CPU or GPU) to use for training. + dataset: The dataset used for training. + model: The model to be trained. + client_trainer (ClientTrainer, optional): The client trainer for training clients. Defaults to None. + server_aggregator (ServerAggregator, optional): The server aggregator for aggregating client updates. Defaults to None. + algorithm_flow (FedMLAlgorithmFlow, optional): The pre-defined algorithm flow. Defaults to None. + """ def __init__( self, args, @@ -55,6 +67,20 @@ def __init__( def _init_simulation_runner( self, args, device, dataset, model, client_trainer=None, server_aggregator=None ): + """ + Initialize the runner for the simulation-based Federated Learning. + + Args: + args: The command line arguments. + device: The device (CPU or GPU) to use for training. + dataset: The dataset used for training. + model: The model to be trained. + client_trainer (ClientTrainer, optional): The client trainer for training clients. Defaults to None. 
+ server_aggregator (ServerAggregator, optional): The server aggregator for aggregating client updates. Defaults to None. + + Returns: + runner: The initialized simulation-based runner. + """ if hasattr(args, "backend") and args.backend == FEDML_SIMULATION_TYPE_SP: from .simulation.simulator import SimulatorSingleProcess @@ -81,6 +107,20 @@ def _init_simulation_runner( def _init_cross_silo_runner( self, args, device, dataset, model, client_trainer=None, server_aggregator=None ): + """ + Initialize the runner for the cross-silo Federated Learning. + + Args: + args: The command line arguments. + device: The device (CPU or GPU) to use for training. + dataset: The dataset used for training. + model: The model to be trained. + client_trainer (ClientTrainer, optional): The client trainer for training clients. Defaults to None. + server_aggregator (ServerAggregator, optional): The server aggregator for aggregating client updates. Defaults to None. + + Returns: + runner: The initialized cross-silo runner. + """ if args.scenario == "horizontal": if args.role == "client": from .cross_silo import Client @@ -118,6 +158,20 @@ def _init_cross_silo_runner( def _init_cheetah_runner( self, args, device, dataset, model, client_trainer=None, server_aggregator=None ): + """ + Initialize the runner for the Cheetah Federated Learning. + + Args: + args: The command line arguments. + device: The device (CPU or GPU) to use for training. + dataset: The dataset used for training. + model: The model to be trained. + client_trainer (ClientTrainer, optional): The client trainer for training clients. Defaults to None. + server_aggregator (ServerAggregator, optional): The server aggregator for aggregating client updates. Defaults to None. + + Returns: + runner: The initialized Cheetah runner. + """ if args.role == "client": from .cheetah import Client @@ -137,6 +191,20 @@ def _init_cheetah_runner( def _init_model_serving_runner( self, args, device, dataset, model, client_trainer=None, server_aggregator=None ): + """ + Initialize the runner for the model serving Federated Learning. + + Args: + args: The command line arguments. + device: The device (CPU or GPU) to use for training. + dataset: The dataset used for training. + model: The model to be trained. + client_trainer (ClientTrainer, optional): The client trainer for training clients. Defaults to None. + server_aggregator (ServerAggregator, optional): The server aggregator for aggregating client updates. Defaults to None. + + Returns: + runner: The initialized model serving runner. + """ if args.role == "client": from .serving import Client @@ -156,6 +224,20 @@ def _init_model_serving_runner( def _init_cross_device_runner( self, args, device, dataset, model, client_trainer=None, server_aggregator=None ): + """ + Initialize the runner for the cross-device Federated Learning. + + Args: + args: The command line arguments. + device: The device (CPU or GPU) to use for training. + dataset: The dataset used for training. + model: The model to be trained. + client_trainer (ClientTrainer, optional): The client trainer for training clients. Defaults to None. + server_aggregator (ServerAggregator, optional): The server aggregator for aggregating client updates. Defaults to None. + + Returns: + runner: The initialized cross-device runner. + """ if args.role == "server": from .cross_device import ServerMNN @@ -170,6 +252,11 @@ def _init_cross_device_runner( @staticmethod def log_runner_result(): + """ + Log the result of the runner to a file. 
+ + This method creates a log file containing the process ID and saves it to the "fedml_trace" directory. + """ log_runner_result_dir = os.path.join(expanduser("~"), "fedml_trace") if not os.path.exists(log_runner_result_dir): os.makedirs(log_runner_result_dir, exist_ok=True) @@ -179,6 +266,11 @@ def log_runner_result(): log_file_obj.close() def run(self): + """ + Run the initialized Federated Learning runner. + + This method executes the Federated Learning process using the selected runner. + """ self.runner.run() FedMLRunner.log_runner_result() diff --git a/python/fedml/simulation/simulator.py b/python/fedml/simulation/simulator.py index abf0394869..3740e60eba 100644 --- a/python/fedml/simulation/simulator.py +++ b/python/fedml/simulation/simulator.py @@ -26,6 +26,17 @@ class SimulatorSingleProcess: def __init__(self, args, device, dataset, model, client_trainer=None, server_aggregator=None): + """ + Initialize the SimulatorSingleProcess. + + Args: + args (argparse.Namespace): Parsed command-line arguments. + device (torch.device): The device to run simulations on. + dataset (object): The dataset used for training. + model (nn.Module): The machine learning model. + client_trainer (ClientTrainer, optional): The client trainer to use. Defaults to None. + server_aggregator (ServerAggregator, optional): The server aggregator to use. Defaults to None. + """ from .sp.classical_vertical_fl.vfl_api import VflFedAvgAPI from .sp.fedavg import FedAvgAPI from .sp.fedprox.fedprox_trainer import FedProxTrainer @@ -64,6 +75,9 @@ def __init__(self, args, device, dataset, model, client_trainer=None, server_agg raise Exception("Exception") def run(self): + """ + Run the federated training simulation. + """ self.fl_trainer.train() @@ -77,6 +91,18 @@ def __init__( client_trainer: ClientTrainer = None, server_aggregator: ServerAggregator = None, ): + """ + Initialize the SimulatorMPI. + + Args: + args (argparse.Namespace): Parsed command-line arguments. + device (torch.device): The device to run simulations on. + dataset (object): The dataset used for training. + model (nn.Module): The machine learning model. + client_trainer (ClientTrainer, optional): The client trainer to use. Defaults to None. + server_aggregator (ServerAggregator, optional): The server aggregator to use. Defaults to None. + """ + # Import various trainer classes based on the selected federated optimizer from .mpi.base_framework.algorithm_api import FedML_Base_distributed from .mpi.decentralized_framework.algorithm_api import FedML_Decentralized_Demo_distributed from .mpi.fedavg.FedAvgAPI import FedML_FedAvg_distributed @@ -217,6 +243,17 @@ def run(self): class SimulatorNCCL: def __init__(self, args, device, dataset, model, client_trainer=None, server_aggregator=None): + """ + Initialize the SimulatorNCCL. + + Args: + args (argparse.Namespace): Parsed command-line arguments. + device (torch.device): The device to run simulations on. + dataset (object): The dataset used for training. + model (nn.Module): The machine learning model. + client_trainer (ClientTrainer, optional): The client trainer to use. Defaults to None. + server_aggregator (ServerAggregator, optional): The server aggregator to use. Defaults to None. 
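+
+        Example:
+            A hypothetical end-to-end invocation, assuming SimulatorNCCL
+            exposes the same run() entry point as the other simulators in
+            this module:
+
+                simulator = SimulatorNCCL(args, device, dataset, model)
+                simulator.run()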
+ """ from .nccl.fedavg.FedAvgAPI import FedML_FedAvg_NCCL if args.federated_optimizer == "FedAvg": From b85e43e900372de90024845c32d34cc099781d53 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Wed, 6 Sep 2023 15:22:09 +0530 Subject: [PATCH 39/70] `fedml\simulation\sp\3 folder` update --- .../sp/classical_vertical_fl/client.py | 45 +++++++++ .../sp/classical_vertical_fl/party_models.py | 49 ++++++++++ .../sp/classical_vertical_fl/vfl.py | 69 +++++++++++++ .../sp/classical_vertical_fl/vfl_api.py | 66 ++++++++++++- .../sp/classical_vertical_fl/vfl_fixture.py | 33 +++++++ .../sp/decentralized/client_dsgd.py | 98 +++++++++++++++++++ .../sp/decentralized/client_pushsum.py | 53 ++++++++++ .../sp/decentralized/decentralized_fl_api.py | 25 +++++ .../sp/decentralized/topology_manager.py | 53 ++++++++++ python/fedml/simulation/sp/fedavg/client.py | 65 ++++++++++++ .../fedml/simulation/sp/fedavg/fedavg_api.py | 55 ++++++++++- 11 files changed, 607 insertions(+), 4 deletions(-) diff --git a/python/fedml/simulation/sp/classical_vertical_fl/client.py b/python/fedml/simulation/sp/classical_vertical_fl/client.py index 1f141efe98..9d096b08fb 100644 --- a/python/fedml/simulation/sp/classical_vertical_fl/client.py +++ b/python/fedml/simulation/sp/classical_vertical_fl/client.py @@ -2,6 +2,18 @@ class Client: def __init__( self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model_trainer, ): + """ + Initialize a federated learning client. + + Args: + client_idx (int): Index of the client. + local_training_data (dataset): Local training dataset for the client. + local_test_data (dataset): Local test dataset for the client. + local_sample_number (int): Number of samples in the local dataset. + args (argparse.Namespace): Parsed command-line arguments. + device (torch.device): The device to run training and inference on. + model_trainer (ModelTrainer): Trainer for the client's machine learning model. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -12,21 +24,54 @@ def __init__( self.model_trainer = model_trainer def update_local_dataset(self, client_idx, local_training_data, local_test_data, local_sample_number): + """ + Update the local dataset and client index. + + Args: + client_idx (int): New index of the client. + local_training_data (dataset): New local training dataset for the client. + local_test_data (dataset): New local test dataset for the client. + local_sample_number (int): New number of samples in the local dataset. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data self.local_sample_number = local_sample_number def get_sample_number(self): + """ + Get the number of samples in the local dataset. + + Returns: + int: Number of samples in the local dataset. + """ return self.local_sample_number def train(self, w_global): + """ + Train the client's machine learning model using global model parameters. + + Args: + w_global (list): Global model parameters. + + Returns: + list: Updated model parameters after training. + """ self.model_trainer.set_model_params(w_global) self.model_trainer.train(self.local_training_data, self.device, self.args) weights = self.model_trainer.get_model_params() return weights def local_test(self, b_use_test_dataset): + """ + Perform local testing on the client's machine learning model. + + Args: + b_use_test_dataset (bool): Whether to use the test dataset for testing. 
+ + Returns: + dict: Metrics obtained from local testing. + """ if b_use_test_dataset: test_data = self.local_test_data else: diff --git a/python/fedml/simulation/sp/classical_vertical_fl/party_models.py b/python/fedml/simulation/sp/classical_vertical_fl/party_models.py index 5fa8237cfa..8846022829 100644 --- a/python/fedml/simulation/sp/classical_vertical_fl/party_models.py +++ b/python/fedml/simulation/sp/classical_vertical_fl/party_models.py @@ -11,6 +11,12 @@ def sigmoid(x): class VFLGuestModel(object): def __init__(self, local_model): + """ + Initialize a VFL guest model. + + Args: + local_model (torch.nn.Module): Local machine learning model. + """ super(VFLGuestModel, self).__init__() self.localModel = local_model self.feature_dim = local_model.get_output_dim() @@ -24,14 +30,29 @@ def __init__(self, local_model): self.y = None def set_dense_model(self, dense_model): + """ + Set the dense model for the guest model. + + Args: + dense_model: New dense model to set. + """ self.dense_model = dense_model def set_batch(self, X, y, global_step): + """ + Set the batch data and global step for training. + + Args: + X: Input data for training. + y: Target labels for training. + global_step: Current global step in training. + """ self.X = X self.y = y self.current_global_step = global_step def _fit(self, X, y): + self.temp_K_Z = self.localModel.forward(X) self.K_U = self.dense_model.forward(self.temp_K_Z) @@ -39,6 +60,16 @@ def _fit(self, X, y): self._update_models(X, y) def predict(self, X, component_list): + """ + Predict using the guest model. + + Args: + X: Input data for prediction. + component_list: List of components to consider in the prediction. + + Returns: + Predicted values. + """ temp_K_Z = self.localModel.forward(X) U = self.dense_model.forward(temp_K_Z) for comp in component_list: @@ -46,6 +77,12 @@ def predict(self, X, component_list): return sigmoid(np.sum(U, axis=1)) def receive_components(self, component_list): + """ + Receive and store components from other parties. + + Args: + component_list: List of components to receive and store. + """ for party_component in component_list: self.parties_grad_component_list.append(party_component) @@ -67,6 +104,12 @@ def _compute_common_gradient_and_loss(self, y): self.loss = class_loss.item() def send_gradients(self): + """ + Send gradients to other parties. + + Returns: + Gradients to send. + """ return self.top_grads def _update_models(self, X, y): @@ -74,6 +117,12 @@ def _update_models(self, X, y): self.localModel.backward(X, back_grad) def get_loss(self): + """ + Get the loss value of the guest model. + + Returns: + Loss value. + """ return self.loss diff --git a/python/fedml/simulation/sp/classical_vertical_fl/vfl.py b/python/fedml/simulation/sp/classical_vertical_fl/vfl.py index dd421d32db..023377fdc2 100644 --- a/python/fedml/simulation/sp/classical_vertical_fl/vfl.py +++ b/python/fedml/simulation/sp/classical_vertical_fl/vfl.py @@ -1,5 +1,33 @@ class VerticalMultiplePartyLogisticRegressionFederatedLearning(object): + """ + Federated Learning class for logistic regression with multiple parties. + + Args: + party_A (VFLGuestModel): The party with labels (party A). + main_party_id (str, optional): The ID of the main party. Defaults to "_main". + + Methods: + set_debug(is_debug): + Set the debug mode for the federated learning. + get_main_party_id(): + Get the ID of the main party. + add_party(id, party_model): + Add a party to the federated learning. + + Attributes: + main_party_id (str): The ID of the main party. 
+ party_a (VFLGuestModel): The party with labels (party A). + party_dict (dict): A dictionary to store other parties without labels. + is_debug (bool): Flag to enable or disable debug mode. + """ def __init__(self, party_A, main_party_id="_main"): + """ + Initialize the VerticalMultiplePartyLogisticRegressionFederatedLearning. + + Args: + party_A (VFLGuestModel): The party with labels (party A). + main_party_id (str, optional): The ID of the main party. Defaults to "_main". + """ super(VerticalMultiplePartyLogisticRegressionFederatedLearning, self).__init__() self.main_party_id = main_party_id # party A is the parity with labels @@ -9,15 +37,46 @@ def __init__(self, party_A, main_party_id="_main"): self.is_debug = False def set_debug(self, is_debug): + """ + Set the debug mode for the federated learning. + + Args: + is_debug (bool): True to enable debug mode, False to disable. + """ self.is_debug = is_debug def get_main_party_id(self): + """ + Get the ID of the main party. + + Returns: + str: The ID of the main party. + """ return self.main_party_id def add_party(self, *, id, party_model): + """ + Add a party to the federated learning. + + Args: + id (str): The ID of the party. + party_model: The model associated with the party. + """ self.party_dict[id] = party_model def fit(self, X_A, y, party_X_dict, global_step): + """ + Perform the federated learning training. + + Args: + X_A: The batch data for party A (with labels). + y: The labels for party A. + party_X_dict (dict): A dictionary of batch data for other parties. + global_step: The global training step. + + Returns: + float: The loss after training. + """ if self.is_debug: print("==> start fit") @@ -54,6 +113,16 @@ def fit(self, X_A, y, party_X_dict, global_step): return loss def predict(self, X_A, party_X_dict): + """ + Perform predictions using the federated learning model. + + Args: + X_A: The input data for party A (with labels). + party_X_dict (dict): A dictionary of input data for other parties. + + Returns: + array: Predicted labels. + """ comp_list = [] for id, party_X in party_X_dict.items(): comp_list.append(self.party_dict[id].predict(party_X)) diff --git a/python/fedml/simulation/sp/classical_vertical_fl/vfl_api.py b/python/fedml/simulation/sp/classical_vertical_fl/vfl_api.py index 518612f5f8..9a910d46e1 100644 --- a/python/fedml/simulation/sp/classical_vertical_fl/vfl_api.py +++ b/python/fedml/simulation/sp/classical_vertical_fl/vfl_api.py @@ -11,6 +11,15 @@ class VflFedAvgAPI(object): + """ + Federated Learning using the FedAvg algorithm. + + Args: + args (Namespace): Command-line arguments and settings. + device (str): The device (e.g., 'cpu', 'cuda') for model training. + dataset (tuple): A tuple containing dataset information. + model (torch.nn.Module): The machine learning model used for federated learning. + """ def __init__(self, args, device, dataset, model): self.device = device self.args = args @@ -46,6 +55,15 @@ def __init__(self, args, device, dataset, model): def _setup_clients( self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict, model_trainer, ): + """ + Set up client instances for federated learning. + + Args: + train_data_local_num_dict (dict): A dictionary mapping client indexes to the number of local training samples. + train_data_local_dict (dict): A dictionary mapping client indexes to local training data. + test_data_local_dict (dict): A dictionary mapping client indexes to local test data. + model_trainer (ModelTrainer): The model trainer used for local client training. 
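+
+        Note:
+            A sketch of the per-slot construction this method performs (one
+            Client per round slot; the append call is assumed from the loop
+            that follows):
+
+                for client_idx in range(self.args.client_num_per_round):
+                    c = Client(client_idx, train_data_local_dict[client_idx],
+                               test_data_local_dict[client_idx],
+                               train_data_local_num_dict[client_idx],
+                               self.args, self.device, model_trainer)
+                    self.client_list.append(c)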
+ """ logging.info("############setup_clients (START)#############") for client_idx in range(self.args.client_num_per_round): c = Client( @@ -61,6 +79,9 @@ def _setup_clients( logging.info("############setup_clients (END)#############") def train(self): + """ + Perform federated learning using the FedAvg algorithm. + """ logging.info("self.model_trainer = {}".format(self.model_trainer)) w_global = self.model_trainer.get_model_params() for round_idx in range(self.args.comm_round): @@ -109,6 +130,17 @@ def train(self): self._local_test_on_all_clients(round_idx) def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Randomly sample a subset of clients for the federated learning round. + + Args: + round_idx (int): The current round index. + client_num_in_total (int): The total number of clients in the dataset. + client_num_per_round (int): The number of clients to sample per round. + + Returns: + list: List of client indexes for the current round. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -119,6 +151,12 @@ def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation dataset subset for testing. + + Args: + num_samples (int): The number of samples to include in the validation set. + """ test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) @@ -126,6 +164,15 @@ def _generate_validation_set(self, num_samples=10000): self.val_global = sample_testset def _aggregate(self, w_locals): + """ + Aggregate model weights from all clients using Federated Averaging (FedAvg). + + Args: + w_locals (list): List of local model weights and sample numbers for each client. + + Returns: + dict: Averaged global model weights. + """ training_num = 0 for idx in range(len(w_locals)): (sample_num, averaged_params) = w_locals[idx] @@ -144,10 +191,13 @@ def _aggregate(self, w_locals): def _aggregate_noniid_avg(self, w_locals): """ - The old aggregate method will impact the model performance when it comes to Non-IID setting + Aggregate model weights from all clients using non-IID averaging. + Args: - w_locals: + w_locals (list): List of local model weights for each client. + Returns: + dict: Averaged global model weights. """ (_, averaged_params) = w_locals[0] for k in averaged_params.keys(): @@ -158,6 +208,12 @@ def _aggregate_noniid_avg(self, w_locals): return averaged_params def _local_test_on_all_clients(self, round_idx): + """ + Perform local testing on all clients and log the results. + + Args: + round_idx (int): The current round index. + """ logging.info("################local_test_on_all_clients : {}".format(round_idx)) @@ -213,6 +269,12 @@ def _local_test_on_all_clients(self, round_idx): logging.info(stats) def _local_test_on_validation_set(self, round_idx): + """ + Perform local testing on a validation set and log the results. + + Args: + round_idx (int): The current round index. 
+ """ logging.info("################local_test_on_validation_set : {}".format(round_idx)) diff --git a/python/fedml/simulation/sp/classical_vertical_fl/vfl_fixture.py b/python/fedml/simulation/sp/classical_vertical_fl/vfl_fixture.py index 080a76ca52..b88c8c8b53 100644 --- a/python/fedml/simulation/sp/classical_vertical_fl/vfl_fixture.py +++ b/python/fedml/simulation/sp/classical_vertical_fl/vfl_fixture.py @@ -6,6 +6,20 @@ def compute_correct_prediction(*, y_targets, y_prob_preds, threshold=0.5): + """ + Compute correct predictions and counts based on probability predictions and threshold. + + Args: + y_targets (array-like): True labels. + y_prob_preds (array-like): Predicted probabilities. + threshold (float, optional): Threshold for binary classification. Defaults to 0.5. + + Returns: + Tuple: + - y_hat_lbls (numpy.ndarray): Predicted labels (0 or 1). + - [pred_pos_count, pred_neg_count, correct_count] (list): Counts of predicted positive, + predicted negative, and correct predictions. + """ y_hat_lbls = [] pred_pos_count = 0 pred_neg_count = 0 @@ -25,12 +39,31 @@ def compute_correct_prediction(*, y_targets, y_prob_preds, threshold=0.5): class FederatedLearningFixture(object): + """ + Fixture for performing federated learning with a specified model. + """ def __init__( self, federated_learning: VerticalMultiplePartyLogisticRegressionFederatedLearning, ): + """ + Initialize a Federated Learning Fixture. + + Args: + federated_learning (VerticalMultiplePartyLogisticRegressionFederatedLearning): + The federated learning instance to be used. + """ self.federated_learning = federated_learning def fit(self, train_data, test_data, epochs=50, batch_size=-1): + """ + Fit the federated learning model on the provided data. + + Args: + train_data (dict): Training data containing X and Y for each party. + test_data (dict): Testing data containing X and Y for each party. + epochs (int, optional): Number of training epochs. Defaults to 50. + batch_size (int, optional): Batch size for training. Defaults to -1 (no batching). + """ main_party_id = self.federated_learning.get_main_party_id() Xa_train = train_data[main_party_id]["X"] diff --git a/python/fedml/simulation/sp/decentralized/client_dsgd.py b/python/fedml/simulation/sp/decentralized/client_dsgd.py index f9891a9273..fd9cd28a1b 100644 --- a/python/fedml/simulation/sp/decentralized/client_dsgd.py +++ b/python/fedml/simulation/sp/decentralized/client_dsgd.py @@ -4,6 +4,53 @@ class ClientDSGD(object): + """ + Client for Distributed Stochastic Gradient Descent (DSGD). + + Args: + model: The machine learning model used by the client. + model_cache: The model cache used for temporary values. + client_id (int): The unique identifier of the client. + streaming_data (list): Streaming data for training. + topology_manager: The manager for defining communication topology. + iteration_number (int): The total number of iterations. + learning_rate (float): The learning rate for gradient descent. + batch_size (int): The batch size for training. + weight_decay (float): The weight decay for regularization. + latency (float): The communication latency. + b_symmetric (bool): Flag for symmetric or asymmetric communication topology. + + Methods: + train_local(iteration_id): + Train the client's model on local data for a specified iteration. + train(iteration_id): + Train the client's model on streaming data for a specified iteration. + get_regret(): + Get the regret (loss) for each iteration. 
+ send_local_gradient_to_neighbor(client_list): + Send local gradients to neighboring clients. + receive_neighbor_gradients(client_id, model_x, topo_weight): + Receive gradients from a neighboring client. + update_local_parameters(): + Update local model parameters based on received gradients. + + Attributes: + model: The machine learning model used by the client. + b_symmetric (bool): Flag for symmetric or asymmetric communication topology. + topology_manager: The manager for defining communication topology. + id (int): The unique identifier of the client. + streaming_data (list): Streaming data for training. + optimizer: The optimizer for training the model. + criterion: The loss criterion used for training. + learning_rate (float): The learning rate for gradient descent. + iteration_number (int): The total number of iterations. + latency (float): The communication latency. + batch_size (int): The batch size for training. + loss_in_each_iteration (list): List to store loss for each iteration. + model_x: The model cache for temporary values. + neighbors_weight_dict (dict): Dictionary to store neighboring client weights. + neighbors_topo_weight_dict (dict): Dictionary to store neighboring client topology weights. + """ def __init__( self, model, @@ -18,6 +65,22 @@ def __init__( latency, b_symmetric, ): + """ + Initialize the ClientDSGD object. + + Args: + model: The machine learning model used by the client. + model_cache: The model cache used for temporary values. + client_id (int): The unique identifier of the client. + streaming_data (list): Streaming data for training. + topology_manager: The manager for defining communication topology. + iteration_number (int): The total number of iterations. + learning_rate (float): The learning rate for gradient descent. + batch_size (int): The batch size for training. + weight_decay (float): The weight decay for regularization. + latency (float): The communication latency. + b_symmetric (bool): Flag for symmetric or asymmetric communication topology. + """ # logging.info("streaming_data = %s" % streaming_data) # Since we use logistic regression, the model size is small. @@ -56,6 +119,12 @@ def __init__( self.neighbors_topo_weight_dict = dict() def train_local(self, iteration_id): + """ + Train the client's model on local data for a specified iteration. + + Args: + iteration_id (int): The current iteration. + """ self.optimizer.zero_grad() train_x = torch.from_numpy(self.streaming_data[iteration_id]["x"]) train_y = torch.FloatTensor([self.streaming_data[iteration_id]["y"]]) @@ -66,6 +135,12 @@ def train_local(self, iteration_id): self.loss_in_each_iteration.append(loss) def train(self, iteration_id): + """ + Train the client's model on streaming data for a specified iteration. + + Args: + iteration_id (int): The current iteration. + """ self.optimizer.zero_grad() if iteration_id >= self.iteration_number: @@ -86,10 +161,22 @@ def train(self, iteration_id): self.loss_in_each_iteration.append(loss) def get_regret(self): + """ + Get the regret (loss) for each iteration. + + Returns: + list: A list containing the loss for each iteration. + """ return self.loss_in_each_iteration # simulation def send_local_gradient_to_neighbor(self, client_list): + """ + Send local gradients to neighboring clients for simulation. + + Args: + client_list (list): List of client objects representing neighbors. 
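+
+        Note:
+            An illustrative exchange with a single neighbor j, matching the
+            receive_neighbor_gradients signature defined below:
+
+                neighbor = client_list[j]
+                neighbor.receive_neighbor_gradients(self.id, self.model_x, self.topology[j])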
+ """ for index in range(len(self.topology)): if self.topology[index] != 0 and index != self.id: client = client_list[index] @@ -98,10 +185,21 @@ def send_local_gradient_to_neighbor(self, client_list): ) def receive_neighbor_gradients(self, client_id, model_x, topo_weight): + """ + Receive gradients from a neighboring client for simulation. + + Args: + client_id (int): The identifier of the neighboring client. + model_x: Model parameters from the neighboring client. + topo_weight (float): Topology weight associated with the neighboring client. + """ self.neighbors_weight_dict[client_id] = model_x self.neighbors_topo_weight_dict[client_id] = topo_weight def update_local_parameters(self): + """ + Update local model parameters based on received gradients. + """ # update x_{t+1/2} for x_paras in self.model_x.parameters(): x_paras.data.mul_(self.topology[self.id]) diff --git a/python/fedml/simulation/sp/decentralized/client_pushsum.py b/python/fedml/simulation/sp/decentralized/client_pushsum.py index 08da9bccf8..05e9dbc808 100644 --- a/python/fedml/simulation/sp/decentralized/client_pushsum.py +++ b/python/fedml/simulation/sp/decentralized/client_pushsum.py @@ -20,6 +20,23 @@ def __init__( b_symmetric, time_varying, ): + """ + Initialize a ClientPushsum instance. + + Args: + model: The client's model. + model_cache: Cache for the model parameters. + client_id (int): Identifier for the client. + streaming_data: Streaming data for training. + topology_manager: Topology manager for network topology. + iteration_number (int): Number of iterations. + learning_rate (float): Learning rate for optimization. + batch_size (int): Batch size for training. + weight_decay (float): Weight decay for optimization. + latency (float): Latency in communication. + b_symmetric (bool): Whether the topology is symmetric. + time_varying (bool): Whether the topology is time-varying. + """ # logging.info("streaming_data = %s" % streaming_data) # Since we use logistic regression, the model size is small. @@ -60,6 +77,12 @@ def __init__( self.neighbors_topo_weight_dict = dict() def train_local(self, iteration_id): + """ + Train the client's model using local data for a specific iteration. + + Args: + iteration_id (int): The iteration index. + """ self.optimizer.zero_grad() train_x = torch.from_numpy(self.streaming_data[iteration_id]["x"]) train_y = torch.FloatTensor([self.streaming_data[iteration_id]["y"]]) @@ -70,6 +93,12 @@ def train_local(self, iteration_id): self.loss_in_each_iteration.append(loss) def train(self, iteration_id): + """ + Train the client's model using data for a specific iteration. + + Args: + iteration_id (int): The iteration index. + """ self.optimizer.zero_grad() if iteration_id >= self.iteration_number: @@ -105,10 +134,22 @@ def train(self, iteration_id): self.loss_in_each_iteration.append(loss) def get_regret(self): + """ + Get the regret (loss) for each iteration. + + Returns: + list: A list containing the loss for each iteration. + """ return self.loss_in_each_iteration # simulation def send_local_gradient_to_neighbor(self, client_list): + """ + Send local gradients to neighboring clients for simulation. + + Args: + client_list (list): List of client objects representing neighbors. 
+ """ for index in range(len(self.topology)): if self.topology[index] != 0 and index != self.id: client = client_list[index] @@ -120,11 +161,23 @@ def send_local_gradient_to_neighbor(self, client_list): ) def receive_neighbor_gradients(self, client_id, model_x, topo_weight, omega): + """ + Receive gradients from a neighboring client for simulation. + + Args: + client_id (int): The identifier of the neighboring client. + model_x: Model parameters from the neighboring client. + topo_weight (float): Topology weight associated with the neighboring client. + omega (float): Omega value for push-sum. + """ self.neighbors_weight_dict[client_id] = model_x self.neighbors_topo_weight_dict[client_id] = topo_weight self.neighbors_omega_dict[client_id] = omega def update_local_parameters(self): + """ + Update local model parameters and omega based on received gradients. + """ # update x_{t+1/2} for x_paras in self.model_x.parameters(): x_paras.data.mul_(self.topology[self.id]) diff --git a/python/fedml/simulation/sp/decentralized/decentralized_fl_api.py b/python/fedml/simulation/sp/decentralized/decentralized_fl_api.py index 9ced125465..90cdadb53c 100644 --- a/python/fedml/simulation/sp/decentralized/decentralized_fl_api.py +++ b/python/fedml/simulation/sp/decentralized/decentralized_fl_api.py @@ -9,6 +9,17 @@ def cal_regret(client_list, client_number, t): + """ + Calculate the average regret across all clients. + + Args: + client_list (list): List of client objects. + client_number (int): Total number of clients. + t (int): Current iteration. + + Returns: + float: Average regret across all clients. + """ regret = 0 for client in client_list: regret += np.sum(client.get_regret()) @@ -20,6 +31,20 @@ def cal_regret(client_list, client_number, t): def FedML_decentralized_fl( client_number, client_id_list, streaming_data, model, model_cache, args ): + """ + Run decentralized federated learning with the specified configuration. + + Args: + client_number (int): Total number of clients. + client_id_list (list): List of client IDs. + streaming_data (list): List of streaming data for each client. + model: The federated learning model. + model_cache: Model cache for each client. + args: Additional arguments for configuration. + + Returns: + None + """ iteration_number_T = args.iteration_number lr_rate = args.learning_rate batch_size = args.batch_size diff --git a/python/fedml/simulation/sp/decentralized/topology_manager.py b/python/fedml/simulation/sp/decentralized/topology_manager.py index 906f6d77e7..7f4f3565ff 100644 --- a/python/fedml/simulation/sp/decentralized/topology_manager.py +++ b/python/fedml/simulation/sp/decentralized/topology_manager.py @@ -3,6 +3,29 @@ class TopologyManager: + """ + Manages the network topology for decentralized federated learning. + + Args: + n (int): Total number of clients. + b_symmetric (bool): Flag indicating symmetric or asymmetric topology. + undirected_neighbor_num (int): Number of undirected neighbors for symmetric topology. + out_directed_neighbor (int): Number of outgoing directed neighbors for asymmetric topology. + + Attributes: + n (int): Total number of clients. + b_symmetric (bool): Flag indicating symmetric or asymmetric topology. + undirected_neighbor_num (int): Number of undirected neighbors for symmetric topology. + out_directed_neighbor (int): Number of outgoing directed neighbors for asymmetric topology. + topology_symmetric (list): Symmetric topology information. + topology_asymmetric (list): Asymmetric topology information. 
+ b_fully_connected (bool): Flag indicating if the topology is fully connected. + + Methods: + generate_topology(): Generates the network topology. + get_symmetric_neighbor_list(client_idx): Gets symmetric neighbors for a client. + get_asymmetric_neighbor_list(client_idx): Gets asymmetric neighbors for a client. + """ def __init__( self, n, b_symmetric, undirected_neighbor_num=5, out_directed_neighbor=5 ): @@ -17,6 +40,9 @@ def __init__( self.b_fully_connected = True def generate_topology(self): + """ + Generates the network topology based on configuration. + """ if self.b_fully_connected: self.__fully_connected() return @@ -27,16 +53,37 @@ def generate_topology(self): self.__randomly_pick_neighbors_asymmetric() def get_symmetric_neighbor_list(self, client_idx): + """ + Gets the symmetric neighbor list for a client. + + Args: + client_idx (int): Index of the client. + + Returns: + list: List of symmetric neighbors for the specified client. + """ if client_idx >= self.n: return [] return self.topology_symmetric[client_idx] def get_asymmetric_neighbor_list(self, client_idx): + """ + Gets the asymmetric neighbor list for a client. + + Args: + client_idx (int): Index of the client. + + Returns: + list: List of asymmetric neighbors for the specified client. + """ if client_idx >= self.n: return [] return self.topology_asymmetric[client_idx] def __randomly_pick_neighbors_symmetric(self): + """ + Generates symmetric topology with randomly added links for each node. + """ # first generate a ring topology topology_ring = np.array( nx.to_numpy_matrix(nx.watts_strogatz_graph(self.n, 2, 0)), dtype=np.float32 @@ -74,6 +121,9 @@ def __randomly_pick_neighbors_symmetric(self): self.topology_symmetric = topology_symmetric def __randomly_pick_neighbors_asymmetric(self): + """ + Generates asymmetric topology with randomly added links for each node. + """ # randomly add some links for each node (symmetric) k = self.undirected_neighbor_num # print("neighbors = " + str(k)) @@ -134,6 +184,9 @@ def __randomly_pick_neighbors_asymmetric(self): self.topology_asymmetric = topology_ring def __fully_connected(self): + """ + Generates fully connected symmetric topology. + """ topology_fully_connected = np.array( nx.to_numpy_matrix(nx.watts_strogatz_graph(self.n, self.n - 1, 0)), dtype=np.float32, diff --git a/python/fedml/simulation/sp/fedavg/client.py b/python/fedml/simulation/sp/fedavg/client.py index cc74a9d932..12df31b681 100644 --- a/python/fedml/simulation/sp/fedavg/client.py +++ b/python/fedml/simulation/sp/fedavg/client.py @@ -1,4 +1,36 @@ class Client: + """ + Represents a client in a federated learning system. + + Args: + client_idx (int): The index of the client. + local_training_data (list): Local training data. + local_test_data (list): Local test data. + local_sample_number (int): Number of local samples. + args (object): Arguments for configuration. + device (str): The device (e.g., 'cpu' or 'cuda') for model training. + model_trainer (object): The model trainer object for training and testing. + + Attributes: + client_idx (int): The index of the client. + local_training_data (list): Local training data. + local_test_data (list): Local test data. + local_sample_number (int): Number of local samples. + args (object): Arguments for configuration. + device (str): The device (e.g., 'cpu' or 'cuda') for model training. + model_trainer (object): The model trainer object for training and testing. 
+ + Methods: + update_local_dataset(client_idx, local_training_data, local_test_data, local_sample_number): + Updates the local dataset for the client. + + get_sample_number(): Gets the number of local samples. + + train(w_global): Trains the client's model using the global model weights. + + local_test(b_use_test_dataset): Tests the client's model using local or test data. + + """ def __init__( self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model_trainer, ): @@ -12,6 +44,15 @@ def __init__( self.model_trainer = model_trainer def update_local_dataset(self, client_idx, local_training_data, local_test_data, local_sample_number): + """ + Updates the local dataset for the client. + + Args: + client_idx (int): The index of the client. + local_training_data (list): Updated local training data. + local_test_data (list): Updated local test data. + local_sample_number (int): Updated number of local samples. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -19,15 +60,39 @@ def update_local_dataset(self, client_idx, local_training_data, local_test_data, self.model_trainer.set_id(client_idx) def get_sample_number(self): + """ + Gets the number of local samples. + + Returns: + int: Number of local samples. + """ return self.local_sample_number def train(self, w_global): + """ + Trains the client's model using the global model weights. + + Args: + w_global (object): Global model weights. + + Returns: + object: Updated client model weights. + """ self.model_trainer.set_model_params(w_global) self.model_trainer.train(self.local_training_data, self.device, self.args) weights = self.model_trainer.get_model_params() return weights def local_test(self, b_use_test_dataset): + """ + Tests the client's model using local or test data. + + Args: + b_use_test_dataset (bool): Flag to use test dataset for testing. + + Returns: + object: Model evaluation metrics. + """ if b_use_test_dataset: test_data = self.local_test_data else: diff --git a/python/fedml/simulation/sp/fedavg/fedavg_api.py b/python/fedml/simulation/sp/fedavg/fedavg_api.py index a748186fb3..ded0d92efa 100644 --- a/python/fedml/simulation/sp/fedavg/fedavg_api.py +++ b/python/fedml/simulation/sp/fedavg/fedavg_api.py @@ -12,6 +12,33 @@ class FedAvgAPI(object): + """ + Federated Averaging API for federated learning. + + Args: + args (object): Arguments for configuration. + device (str): The device (e.g., 'cpu' or 'cuda') for model training. + dataset (tuple): A tuple containing dataset information. + + Attributes: + device (str): The device (e.g., 'cpu' or 'cuda') for model training. + args (object): Arguments for configuration. + train_global: Global training dataset. + test_global: Global test dataset. + val_global: Global validation dataset. + train_data_num_in_total (int): Total number of training samples. + test_data_num_in_total (int): Total number of test samples. + client_list (list): List of client instances. + train_data_local_num_dict (dict): Dictionary mapping client index to the number of local training samples. + train_data_local_dict (dict): Dictionary mapping client index to local training data. + test_data_local_dict (dict): Dictionary mapping client index to local test data. + model_trainer: Model trainer for federated learning. + model: The federated model. + + Methods: + train(): Train the federated model using federated averaging. 
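+
+    Example:
+        The `_aggregate` method below implements the standard FedAvg
+        sample-weighted average; as a sketch, with n_i local samples on
+        client i and n = sum_i n_i:
+
+            w_global[k] = sum_i (n_i / n) * w_i[k]   # for every parameter key k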
+
+    """
    def __init__(self, args, device, dataset, model):
        self.device = device
        self.args = args
@@ -49,6 +76,7 @@ def __init__(self, args, device, dataset, model):
    def _setup_clients(
        self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict, model_trainer,
    ):
+        """Set up client instances."""
        logging.info("############setup_clients (START)#############")
        for client_idx in range(self.args.client_num_per_round):
            c = Client(
@@ -125,6 +153,7 @@ def train(self):
        mlops.log_aggregation_finished_status()

    def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round):
+        """Sample clients for a communication round."""
        if client_num_in_total == client_num_per_round:
            client_indexes = [client_index for client_index in range(client_num_in_total)]
        else:
@@ -135,6 +164,7 @@ def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round)
        return client_indexes

    def _generate_validation_set(self, num_samples=10000):
+        """Generate a validation subset from the global test set."""
        test_data_num = len(self.test_global.dataset)
        sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num))
        subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices)
@@ -142,6 +172,7 @@ def _generate_validation_set(self, num_samples=10000):
        self.val_global = sample_testset

    def _aggregate(self, w_locals):
+        """Aggregate local model weights, weighted by local sample count."""
        training_num = 0
        for idx in range(len(w_locals)):
            (sample_num, averaged_params) = w_locals[idx]
@@ -160,10 +191,14 @@ def _aggregate(self, w_locals):

    def _aggregate_noniid_avg(self, w_locals):
        """
-        The old aggregate method will impact the model performance when it comes to Non-IID setting
+        Aggregate local model weights with a simple unweighted average, which is better suited to non-IID settings.
+
        Args:
-            w_locals:
+            w_locals (list): List of tuples containing (sample_num, local_weights).
+
        Returns:
+            dict: Averaged model parameters.
+
        """
        (_, averaged_params) = w_locals[0]
        for k in averaged_params.keys():
@@ -174,6 +209,16 @@ def _aggregate_noniid_avg(self, w_locals):
        return averaged_params

    def _local_test_on_all_clients(self, round_idx):
+        """
+        Evaluate the current global model on every client's local training
+        and test data, then log the averaged metrics.
+
+        Args:
+            round_idx (int): The current communication round index.
+
+        Returns:
+            None
+        """
        logging.info("################local_test_on_all_clients : {}".format(round_idx))

@@ -235,7 +280,13 @@ def _local_test_on_all_clients(self, round_idx):
        logging.info(stats)

    def _local_test_on_validation_set(self, round_idx):
+        """
+        Evaluate the current global model on the sampled validation set and log the results.
+
+        Args:
+            round_idx (int): The current communication round index.
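+
+        Note:
+            The validation split is built lazily; when self.val_global is None
+            it is expected to be created first via _generate_validation_set().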
+ + """ logging.info("################local_test_on_validation_set : {}".format(round_idx)) if self.val_global is None: From 84784ff8000a794d2948723dbf5aad43761662ea Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Wed, 6 Sep 2023 17:17:28 +0530 Subject: [PATCH 40/70] same as previous --- .../fedml/simulation/sp/feddyn/client copy.py | 54 +++++ .../simulation/sp/hierarchical_fl/client.py | 24 +++ .../simulation/sp/hierarchical_fl/group.py | 34 +++ .../simulation/sp/hierarchical_fl/trainer.py | 34 +++ python/fedml/simulation/sp/mime/client.py | 47 +++++ .../fedml/simulation/sp/mime/mime_trainer.py | 70 ++++++- python/fedml/simulation/sp/mime/opt_utils.py | 58 +++++- python/fedml/simulation/sp/scaffold/client.py | 46 +++++ .../sp/scaffold/scaffold_trainer.py | 59 ++++++ .../simulation/sp/turboaggregate/TA_client.py | 18 ++ .../sp/turboaggregate/TA_trainer.py | 65 ++++++ .../sp/turboaggregate/mpc_function.py | 194 +++++++++++++++++- 12 files changed, 697 insertions(+), 6 deletions(-) diff --git a/python/fedml/simulation/sp/feddyn/client copy.py b/python/fedml/simulation/sp/feddyn/client copy.py index 02d1b30333..1a7bc2bb50 100644 --- a/python/fedml/simulation/sp/feddyn/client copy.py +++ b/python/fedml/simulation/sp/feddyn/client copy.py @@ -4,6 +4,15 @@ def model_parameter_vector(model): + """ + Flatten the model's parameters into a single vector. + + Args: + model (torch.nn.Module): The neural network model. + + Returns: + torch.Tensor: A flattened vector containing all model parameters. + """ param = [p.view(-1) for p in model.parameters()] return torch.concat(param, dim=0) @@ -12,6 +21,18 @@ class Client: def __init__( self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model_trainer, ): + """ + Initialize a client for federated learning. + + Args: + client_idx (int): Index of the client. + local_training_data (torch.utils.data.DataLoader): Local training data. + local_test_data (torch.utils.data.DataLoader): Local test data. + local_sample_number (int): Number of samples in the local dataset. + args: Command-line arguments. + device (torch.device): Device for training. + model_trainer: Model trainer for training and testing. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -30,6 +51,15 @@ def __init__( def update_local_dataset(self, client_idx, local_training_data, local_test_data, local_sample_number): + """ + Update the local dataset for the client. + + Args: + client_idx (int): Index of the client. + local_training_data (torch.utils.data.DataLoader): Local training data. + local_test_data (torch.utils.data.DataLoader): Local test data. + local_sample_number (int): Number of samples in the local dataset. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -37,15 +67,39 @@ def update_local_dataset(self, client_idx, local_training_data, local_test_data, self.model_trainer.set_id(client_idx) def get_sample_number(self): + """ + Get the number of samples in the local dataset. + + Returns: + int: Number of samples. + """ return self.local_sample_number def train(self, w_global): + """ + Train the client's model using the global model parameters. + + Args: + w_global: Global model parameters. + + Returns: + tuple: A tuple containing the updated weights and gradients. 
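+
+        Example:
+            Illustrative sketch (the FedDyn trainer normally drives this call):
+
+                w_new, old_grad = client.train(w_global)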
+        """
        self.model_trainer.set_model_params(w_global)
        self.old_grad = self.model_trainer.train(self.local_training_data, self.device, self.args, self.old_grad)
        weights = self.model_trainer.get_model_params()
        return weights, self.old_grad

    def local_test(self, b_use_test_dataset):
+        """
+        Perform local testing on the client's dataset.
+
+        Args:
+            b_use_test_dataset (bool): Whether to use the test dataset or training dataset.
+
+        Returns:
+            dict: Metrics from the local test.
+        """
        if b_use_test_dataset:
            test_data = self.local_test_data
        else:
diff --git a/python/fedml/simulation/sp/hierarchical_fl/client.py b/python/fedml/simulation/sp/hierarchical_fl/client.py
index 48efe5ba6c..5eed5e402e 100644
--- a/python/fedml/simulation/sp/hierarchical_fl/client.py
+++ b/python/fedml/simulation/sp/hierarchical_fl/client.py
@@ -7,6 +7,19 @@


 class HFLClient(Client):
+    """
+    Represents a Hierarchical Federated Learning (HFL) client.
+
+    Args:
+        client_idx (int): Index of the client.
+        local_training_data: Local training data for the client.
+        local_test_data: Local test data for the client.
+        local_sample_number: Number of local samples.
+        args: Arguments for client configuration.
+        device: Device (e.g., 'cuda' or 'cpu') to perform computations.
+        model: The client's model.
+        model_trainer: Trainer for the client's model.
+    """
    def __init__(self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model,
                 model_trainer):

@@ -24,6 +37,17 @@ def __init__(self, client_idx, local_training_data, local_test_data, local_sampl
        self.criterion = nn.CrossEntropyLoss().to(device)

    def train(self, global_round_idx, group_round_idx, w):
+        """
+        Train the client's model using the Hierarchical Federated Learning (HFL) approach.
+
+        Args:
+            global_round_idx (int): Global round index.
+            group_round_idx (int): Group round index.
+            w: Model weights to initialize training.
+
+        Returns:
+            list: A list of tuples containing global epoch and model state dictionaries.
+        """
        self.model.load_state_dict(w)
        self.model.to(self.device)

diff --git a/python/fedml/simulation/sp/hierarchical_fl/group.py b/python/fedml/simulation/sp/hierarchical_fl/group.py
index adfa27d0bd..70c568fbc1 100644
--- a/python/fedml/simulation/sp/hierarchical_fl/group.py
+++ b/python/fedml/simulation/sp/hierarchical_fl/group.py
@@ -5,6 +5,20 @@


 class Group(FedAvgAPI):
+    """
+    Represents a group of clients in a federated learning setting.
+
+    Args:
+        idx (int): Index of the group.
+        total_client_indexes (list): List of client indexes in the group.
+        train_data_local_dict: Dictionary containing local training data for each client.
+        test_data_local_dict: Dictionary containing local test data for each client.
+        train_data_local_num_dict: Dictionary containing the number of local training samples for each client.
+        args: Arguments for group configuration.
+        device: Device (e.g., 'cuda' or 'cpu') to perform computations.
+        model: The shared model used by clients in the group.
+        model_trainer: Trainer for the shared model.
+    """
    def __init__(
        self,
        idx,
@@ -35,12 +49,32 @@ def __init__(

    def get_sample_number(self, sampled_client_indexes):
+        """
+        Calculate the total number of training samples in the group.
+
+        Args:
+            sampled_client_indexes (list): List of sampled client indexes.
+
+        Returns:
+            int: Total number of training samples in the group.
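+
+        Equivalent computation (sketch):
+
+            sum(self.train_data_local_num_dict[i] for i in sampled_client_indexes)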
+ """ self.group_sample_number = 0 for client_idx in sampled_client_indexes: self.group_sample_number += self.train_data_local_num_dict[client_idx] return self.group_sample_number def train(self, global_round_idx, w, sampled_client_indexes): + """ + Train the group of clients using federated learning. + + Args: + global_round_idx (int): Global round index. + w: Model weights to initialize training. + sampled_client_indexes (list): List of sampled client indexes. + + Returns: + list: A list of tuples containing global epoch and aggregated model weights. + """ sampled_client_list = [self.client_dict[client_idx] for client_idx in sampled_client_indexes] w_group = w w_group_list = [] diff --git a/python/fedml/simulation/sp/hierarchical_fl/trainer.py b/python/fedml/simulation/sp/hierarchical_fl/trainer.py index c0d1c05003..63085dd67e 100644 --- a/python/fedml/simulation/sp/hierarchical_fl/trainer.py +++ b/python/fedml/simulation/sp/hierarchical_fl/trainer.py @@ -8,6 +8,15 @@ class HierarchicalTrainer(FedAvgAPI): + """ + Represents a hierarchical federated learning trainer. + + Args: + train_data_local_num_dict: Dictionary containing the number of local training samples for each client. + train_data_local_dict: Dictionary containing local training data for each client. + test_data_local_dict: Dictionary containing local test data for each client. + model_trainer: Trainer for the shared model. + """ def _setup_clients( self, train_data_local_num_dict, @@ -15,6 +24,15 @@ def _setup_clients( test_data_local_dict, model_trainer, ): + """ + Set up client groups and maintain a dummy client for testing. + + Args: + train_data_local_num_dict: Dictionary containing the number of local training samples for each client. + train_data_local_dict: Dictionary containing local training data for each client. + test_data_local_dict: Dictionary containing local test data for each client. + model_trainer: Trainer for the shared model. + """ logging.info("############setup_clients (START)#############") if self.args.group_method == "random": self.group_indexes = np.random.randint( @@ -61,6 +79,17 @@ def _setup_clients( def _client_sampling( self, global_round_idx, client_num_in_total, client_num_per_round ): + """ + Sample clients for training in a hierarchical manner. + + Args: + global_round_idx (int): Global round index. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to sample per round. + + Returns: + dict: Dictionary mapping group indexes to sampled client indexes. + """ sampled_client_indexes = super()._client_sampling( global_round_idx, client_num_in_total, client_num_per_round ) @@ -76,6 +105,11 @@ def _client_sampling( return group_to_client_indexes def train(self): + """ + Train the hierarchical federated learning model. + + This method manages global communication rounds and client sampling. + """ w_global = self.model.state_dict() for global_round_idx in range(self.args.comm_round): logging.info( diff --git a/python/fedml/simulation/sp/mime/client.py b/python/fedml/simulation/sp/mime/client.py index 00df1b004e..e0b0553175 100644 --- a/python/fedml/simulation/sp/mime/client.py +++ b/python/fedml/simulation/sp/mime/client.py @@ -1,4 +1,16 @@ class Client: + """ + Represents a client in a federated learning setting. + + Args: + client_idx (int): Index of the client. + local_training_data: Local training data for the client. + local_test_data: Local test data for the client. + local_sample_number: Number of local samples. 
+ args: Arguments for client configuration. + device: Device (e.g., 'cuda' or 'cpu') to perform computations. + model_trainer: Trainer for the client's model. + """ def __init__( self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model_trainer, ): @@ -12,6 +24,15 @@ def __init__( self.model_trainer = model_trainer def update_local_dataset(self, client_idx, local_training_data, local_test_data, local_sample_number): + """ + Update the local dataset for the client. + + Args: + client_idx (int): Index of the client. + local_training_data: New local training data for the client. + local_test_data: New local test data for the client. + local_sample_number: New number of local samples. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -19,15 +40,41 @@ def update_local_dataset(self, client_idx, local_training_data, local_test_data, self.model_trainer.set_id(client_idx) def get_sample_number(self): + """ + Get the number of local samples. + + Returns: + int: Number of local samples. + """ return self.local_sample_number def train(self, w_global, grad_global, global_named_states): + """ + Train the client's model. + + Args: + w_global: Global model parameters. + grad_global: Global gradient. + global_named_states: Named states of the global optimizer. + + Returns: + tuple: A tuple containing local model weights and local gradients. + """ self.model_trainer.set_model_params(w_global) local_grad = self.model_trainer.train(self.local_training_data, self.device, self.args, grad_global, global_named_states) weights = self.model_trainer.get_model_params() return weights, local_grad def local_test(self, b_use_test_dataset): + """ + Perform local testing on the client's dataset. + + Args: + b_use_test_dataset (bool): Whether to use the test dataset. + + Returns: + dict: Metrics from the local test. + """ if b_use_test_dataset: test_data = self.local_test_data else: diff --git a/python/fedml/simulation/sp/mime/mime_trainer.py b/python/fedml/simulation/sp/mime/mime_trainer.py index ff7831a523..74eff35dfd 100644 --- a/python/fedml/simulation/sp/mime/mime_trainer.py +++ b/python/fedml/simulation/sp/mime/mime_trainer.py @@ -18,7 +18,20 @@ class MimeTrainer(object): + """ + Trainer for the Mime model on federated learning. + """ def __init__(self, dataset, model, device, args): + """ + Initialize the MimeTrainer. + + Args: + dataset: A list containing dataset information. + model: The Mime model. + device: The target device for training. + args: Training arguments. + """ + self.device = device self.args = args [ @@ -58,6 +71,15 @@ def __init__(self, dataset, model, device, args): def _setup_clients( self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict, model_trainer, ): + """ + Set up client instances for federated learning. + + Args: + train_data_local_num_dict: Dictionary containing local training data numbers for each client. + train_data_local_dict: Dictionary containing local training data for each client. + test_data_local_dict: Dictionary containing local test data for each client. + model_trainer: Model trainer for client instances. + """ logging.info("############setup_clients (START)#############") for client_idx in range(self.args.client_num_per_round): c = Client( @@ -73,11 +95,10 @@ def _setup_clients( logging.info("############setup_clients (END)#############") - - - - def train(self): + """ + Perform federated training using the Mime model. 
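+
+        One communication round, roughly (a sketch of the loop below, where
+        w_locals pairs each client's sample count with its returned weights):
+
+            for client in sampled_clients:
+                w, g = client.train(w_global, grad_global, named_states)
+                w_locals.append((client.get_sample_number(), w))
+            w_global = self._aggregate(w_locals)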
+ """ logging.info("self.model_trainer = {}".format(self.model_trainer)) w_global = self.model_trainer.get_model_params() mlops.log_training_status(mlops.ClientConstants.MSG_MLOPS_CLIENT_STATUS_TRAINING) @@ -142,6 +163,17 @@ def train(self): mlops.log_aggregation_finished_status() def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Perform client sampling for each communication round. + + Args: + round_idx: Index of the communication round. + client_num_in_total: Total number of clients. + client_num_per_round: Number of clients per round. + + Returns: + List: List of selected client indexes. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -152,6 +184,12 @@ def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set by sampling a subset of the test data. + + Args: + num_samples (int): The number of samples to include in the validation set. Default is 10,000. + """ test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) @@ -160,6 +198,9 @@ def _generate_validation_set(self, num_samples=10000): def _instanciate_opt(self): + """ + Initialize the optimizer for the MimeTrainer. + """ self.opt = OptRepo.name2cls(self.args.server_optimizer)( # self.model_global.parameters(), lr=self.args.server_lr self.model_trainer.model.parameters(), @@ -173,11 +214,26 @@ def _instanciate_opt(self): def _aggregate(self, w_locals): + """ + Aggregate the local model weights to obtain global model weights. + + Args: + w_locals: List of local model weights. + + Returns: + avg_params: Aggregated global model weights. + """ avg_params = FedMLAggOperator.agg(self.args, w_locals) return avg_params def _local_test_on_all_clients(self, round_idx): + """ + Perform local testing on all clients. + + Args: + round_idx: Index of the communication round. + """ logging.info("################local_test_on_all_clients : {}".format(round_idx)) @@ -253,6 +309,12 @@ def _local_test_on_all_clients(self, round_idx): logging.info(stats) def _local_test_on_validation_set(self, round_idx): + """ + Perform local testing on the validation set. + + Args: + round_idx: Index of the communication round. + """ logging.info("################local_test_on_validation_set : {}".format(round_idx)) diff --git a/python/fedml/simulation/sp/mime/opt_utils.py b/python/fedml/simulation/sp/mime/opt_utils.py index 303b386db2..28ac32a944 100644 --- a/python/fedml/simulation/sp/mime/opt_utils.py +++ b/python/fedml/simulation/sp/mime/opt_utils.py @@ -4,6 +4,12 @@ def show_opt_state(optimizer): + """ + Display selected optimizer's state information. + + Args: + optimizer: The optimizer to display state information for. + """ i = 0 for p in optimizer.state.keys(): # print(list(optimizer.state[p].keys())) @@ -18,6 +24,12 @@ def show_opt_state(optimizer): print(key, torch.norm((optimizer.state[p][key]))) def show_named_state(named_states): + """ + Display state information for a dictionary of named states. + + Args: + named_states (dict): A dictionary containing named states to display. 
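+
+    Example (sketch; assumes a model/optimizer pair wrapped by the
+    OptimizerLoader defined below):
+
+        loader = OptimizerLoader(model, optimizer)
+        show_named_state(loader.get_opt_state())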
+ """ i = 0 for name in named_states.keys(): # print(list(optimizer.state[p].keys())) @@ -34,8 +46,14 @@ def show_named_state(named_states): class OptimizerLoader(): - def __init__(self, model, optimizer): + """ + Initialize the OptimizerLoader. + + Args: + model: The model being optimized. + optimizer: The optimizer used for training. + """ self.optimizer = optimizer self.model = model self.named_states = {} @@ -50,9 +68,22 @@ def __init__(self, model, optimizer): # print(key, type(optimizer.state[p][key])) def get_opt_state(self): + """ + Get the optimizer's named states. + + Returns: + dict: A dictionary containing the optimizer's named states. + """ return self.named_states def set_opt_state(self, named_states, device="cpu"): + """ + Set the optimizer's named states. + + Args: + named_states (dict): A dictionary containing the named states to set. + device (str): The target device for the named states (default is "cpu"). + """ for p in self.optimizer.state.keys(): new_state = named_states[self.parameter_names[p]] # for key in self.optimizer.state[p].keys(): @@ -61,12 +92,25 @@ def set_opt_state(self, named_states, device="cpu"): # print(key, type(self.optimizer.state[p][key])) def get_grad(self): + """ + Get the gradients of the model's parameters. + + Returns: + dict: A dictionary containing the gradients of the model's parameters. + """ grad = {} for name, parameter in self.model.named_parameters(): grad[name] = parameter.grad return grad def set_grad(self, grad, device="cpu"): + """ + Set the gradients of the model's parameters. + + Args: + grad (dict): A dictionary containing the gradients to set. + device (str): The target device for the gradients (default is "cpu"). + """ for name, parameter in self.model.named_parameters(): # logging.info(f"parameter.grad: {type(parameter.grad)}, grad[name]: {type(grad[name])} ") # logging.info(f"parameter.grad.shape: {parameter.grad.shape}, grad[name].shape: {grad[name].shape} ") @@ -76,9 +120,21 @@ def set_grad(self, grad, device="cpu"): return grad def zero_grad(self): + """ + Zero out the gradients of the model's parameters. + """ self.optimizer.zero_grad() def update_opt_state(self, update_model=False): + """ + Update the optimizer's state after a step. + + Args: + update_model (bool): Whether to update the model's parameters as well (default is False). + + Returns: + dict: A dictionary containing the updated optimizer's named states. + """ if not update_model: origin_model_params = self.model.state_dict() self.optimizer.step() diff --git a/python/fedml/simulation/sp/scaffold/client.py b/python/fedml/simulation/sp/scaffold/client.py index eed328cc24..6b610122f6 100644 --- a/python/fedml/simulation/sp/scaffold/client.py +++ b/python/fedml/simulation/sp/scaffold/client.py @@ -6,6 +6,18 @@ class Client: def __init__( self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model_trainer, ): + """ + Initialize a client for federated learning. + + Args: + client_idx (int): The index of the client. + local_training_data (torch.utils.data.DataLoader): The DataLoader for local training data. + local_test_data (torch.utils.data.DataLoader): The DataLoader for local test data. + local_sample_number (int): The number of local training samples. + args: The arguments for the client. + device (torch.device): The device to perform computations on. + model_trainer: The model trainer used for training. 
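+
+        Example:
+            Illustrative sketch (the SCAFFOLD trainer normally builds clients):
+
+                c = Client(0, train_dl, test_dl, len(train_dl.dataset),
+                           args, device, trainer)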
+ """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -20,6 +32,15 @@ def __init__( def update_local_dataset(self, client_idx, local_training_data, local_test_data, local_sample_number): + """ + Update the local dataset for the client. + + Args: + client_idx (int): The index of the client. + local_training_data (torch.utils.data.DataLoader): The DataLoader for local training data. + local_test_data (torch.utils.data.DataLoader): The DataLoader for local test data. + local_sample_number (int): The number of local training samples. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -27,9 +48,25 @@ def update_local_dataset(self, client_idx, local_training_data, local_test_data, self.model_trainer.set_id(client_idx) def get_sample_number(self): + """ + Get the number of local training samples. + + Returns: + int: The number of local training samples. + """ return self.local_sample_number def train(self, w_global, c_model_global_param): + """ + Perform local training for the client. + + Args: + w_global: The global model parameters. + c_model_global_param: The global model parameters of the central model. + + Returns: + tuple: A tuple containing weights_delta and c_delta_para. + """ c_model_global_param = deepcopy(c_model_global_param) c_model_local_param = self.c_model_local.state_dict() @@ -56,6 +93,15 @@ def train(self, w_global, c_model_global_param): return weights_delta, c_delta_para def local_test(self, b_use_test_dataset): + """ + Perform local testing on the client's dataset. + + Args: + b_use_test_dataset (bool): If True, use the test dataset; if False, use the training dataset. + + Returns: + dict: A dictionary containing test metrics. + """ if b_use_test_dataset: test_data = self.local_test_data else: diff --git a/python/fedml/simulation/sp/scaffold/scaffold_trainer.py b/python/fedml/simulation/sp/scaffold/scaffold_trainer.py index 9a812b8a32..4caea49898 100644 --- a/python/fedml/simulation/sp/scaffold/scaffold_trainer.py +++ b/python/fedml/simulation/sp/scaffold/scaffold_trainer.py @@ -16,6 +16,15 @@ class ScaffoldTrainer(object): def __init__(self, dataset, model, device, args): + """ + Initialize the ScaffoldTrainer. + + Args: + dataset: A list of dataset components. + model: The model to be trained. + device: The computing device (e.g., 'cuda' or 'cpu'). + args: Training arguments. + """ self.device = device self.args = args [ @@ -58,6 +67,15 @@ def __init__(self, dataset, model, device, args): def _setup_clients( self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict, model_trainer, ): + """ + Set up client instances for training. + + Args: + train_data_local_num_dict: A dictionary of local training dataset sizes. + train_data_local_dict: A dictionary of local training datasets. + test_data_local_dict: A dictionary of local test datasets. + model_trainer: The model trainer instance. + """ logging.info("############setup_clients (START)#############") if self.args.initialize_all_clients: num_initialized_clients = self.args.client_num_in_total @@ -77,6 +95,9 @@ def _setup_clients( logging.info("############setup_clients (END)#############") def train(self): + """ + Perform the training process. 
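+
+        Each communication round, roughly (a sketch; the exact bookkeeping is
+        in the body below):
+
+            w_delta, c_delta = client.train(w_global, c_global_param)
+            # the server then folds the averaged deltas back into the global
+            # model weights and the global control variate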
+        """
        logging.info("self.model_trainer = {}".format(self.model_trainer))
        w_global = self.model_trainer.get_model_params()
        mlops.log_training_status(mlops.ClientConstants.MSG_MLOPS_CLIENT_STATUS_TRAINING)
@@ -155,6 +176,17 @@ def train(self):
        mlops.log_aggregation_finished_status()

    def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round):
+        """
+        Randomly sample clients for a federated learning communication round.
+
+        Args:
+            round_idx (int): The current communication round index.
+            client_num_in_total (int): Total number of clients.
+            client_num_per_round (int): Number of clients to select per round.
+
+        Returns:
+            list: List of client indexes selected for communication.
+        """
        if client_num_in_total == client_num_per_round:
            client_indexes = [client_index for client_index in range(client_num_in_total)]
        else:
@@ -165,6 +197,12 @@ def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round)
        return client_indexes

    def _generate_validation_set(self, num_samples=10000):
+        """
+        Generate a validation dataset subset for local validation.
+
+        Args:
+            num_samples (int): Number of samples in the validation dataset subset.
+        """
        test_data_num = len(self.test_global.dataset)
        sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num))
        subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices)
@@ -172,6 +210,15 @@ def _generate_validation_set(self, num_samples=10000):
        self.val_global = sample_testset

    def _aggregate(self, w_locals):
+        """
+        Aggregate local model updates using FedMLAggOperator.
+
+        Args:
+            w_locals: List of local model updates.
+
+        Returns:
+            tuple: Total aggregated model update and total client delta parameters.
+        """
        # training_num = 0
        # for idx in range(len(w_locals)):
        #     (sample_num, averaged_params) = w_locals[idx]
@@ -185,6 +232,12 @@ def _aggregate(self, w_locals):


    def _local_test_on_all_clients(self, round_idx):
+        """
+        Perform local testing on all clients for both training and test datasets.
+
+        Args:
+            round_idx (int): The current communication round index.
+        """
        logging.info("################local_test_on_all_clients : {}".format(round_idx))

@@ -246,6 +299,12 @@ def _local_test_on_all_clients(self, round_idx):
        logging.info(stats)

    def _local_test_on_validation_set(self, round_idx):
+        """
+        Perform local testing on a validation set for the specified round.
+
+        Args:
+            round_idx (int): The current communication round index.
+        """
        logging.info("################local_test_on_validation_set : {}".format(round_idx))

diff --git a/python/fedml/simulation/sp/turboaggregate/TA_client.py b/python/fedml/simulation/sp/turboaggregate/TA_client.py
index 0d617c3b24..4f6356a941 100644
--- a/python/fedml/simulation/sp/turboaggregate/TA_client.py
+++ b/python/fedml/simulation/sp/turboaggregate/TA_client.py
@@ -4,6 +4,18 @@


 class TA_Client(Client):
+    """
+    A Turbo-Aggregate (TA) client that extends Client with a dropout flag used by the Turbo-Aggregate protocol.
+
+    Args:
+        client_idx (int): The index of the client.
+        local_training_data: The local training data for the client.
+        local_test_data: The local test data for the client.
+        local_sample_number: The number of local samples.
+        args: Additional arguments.
+        device: The computing device (e.g., 'cuda' or 'cpu').
+        model_trainer: The model trainer for this client.
+    """
    def __init__(
        self,
        client_idx,
@@ -30,4 +42,10 @@ def __init__(
        # self.buffer_out = np.zeros(dtype='int')

    def set_dropout(self, isdrop):
+        """
+        Set the dropout flag for this client.
+ + Args: + isdrop (bool): Whether to enable dropout for this client. + """ self.isdrop = isdrop diff --git a/python/fedml/simulation/sp/turboaggregate/TA_trainer.py b/python/fedml/simulation/sp/turboaggregate/TA_trainer.py index 59283423be..90041e16df 100644 --- a/python/fedml/simulation/sp/turboaggregate/TA_trainer.py +++ b/python/fedml/simulation/sp/turboaggregate/TA_trainer.py @@ -10,6 +10,15 @@ class TurboAggregateTrainer(object): + """ + TurboAggregateTrainer for federated learning with Turbo-Aggregate protocol. + + Args: + dataset: A list containing dataset-related information. + model: The global model for training. + device: The computing device (e.g., 'cuda' or 'cpu'). + args: Additional training arguments. + """ def __init__(self, dataset, model, device, args): self.device = device self.args = args @@ -40,6 +49,14 @@ def __init__(self, dataset, model, device, args): self.setup_clients(data_local_num_dict, train_data_local_dict, test_data_local_dict) def setup_clients(self, data_local_num_dict, train_data_local_dict, test_data_local_dict): + """ + Set up the list of clients for federated learning. + + Args: + data_local_num_dict: A dictionary containing the number of local samples for each client. + train_data_local_dict: A dictionary containing local training data for each client. + test_data_local_dict: A dictionary containing local test data for each client. + """ logging.info("############setup_clients (START)#############") for client_idx in range(self.args.client_num_in_total): c = TA_Client( @@ -55,6 +72,9 @@ def setup_clients(self, data_local_num_dict, train_data_local_dict, test_data_lo logging.info("############setup_clients (END)#############") def train(self): + """ + Train the global model using the Turbo-Aggregate protocol. + """ for round_idx in range(self.args.comm_round): logging.info("Communication round : {}".format(round_idx)) w_global = self.model_trainer.get_model_params() @@ -94,6 +114,15 @@ def train(self): self.local_test(self.model_global, round_idx) def aggregate(self, w_locals): + """ + Aggregate the local model weights from clients using Turbo-Aggregate. + + Args: + w_locals: List of local model weights from clients. + + Returns: + Averaged global model weights. + """ logging.info("################aggregate: %d" % len(w_locals)) (sample_num, averaged_params) = w_locals[0] for k in averaged_params.keys(): @@ -107,6 +136,7 @@ def aggregate(self, w_locals): return averaged_params def TA_topology_vanilla(self): + # logging.info("################aggregate: %d" % len(w_locals)) # N = self.args.client_number @@ -119,10 +149,24 @@ def TA_topology_vanilla(self): pass def local_test(self, model_global, round_idx): + """ + Perform local testing on clients. + + Args: + model_global: The global model to evaluate. + round_idx: The communication round index. + """ self.local_test_on_training_data(model_global, round_idx) self.local_test_on_test_data(model_global, round_idx) def local_test_on_training_data(self, model_global, round_idx): + """ + Perform local testing on training data for clients. + + Args: + model_global: The global model to evaluate. + round_idx: The communication round index. + """ num_samples = [] tot_corrects = [] losses = [] @@ -148,6 +192,13 @@ def local_test_on_training_data(self, model_global, round_idx): logging.info(stats) def local_test_on_test_data(self, model_global, round_idx): + """ + Perform local testing on test data for clients. + + Args: + model_global: The global model to evaluate. + round_idx: The communication round index. 
+        """
        num_samples = []
        tot_corrects = []
        losses = []
@@ -172,6 +223,9 @@ def local_test_on_test_data(self, model_global, round_idx):
        logging.info(stats)

    def global_test(self):
+        """
+        Perform global testing using the global dataset and log the results.
+        """
        logging.info("################global_test")
        acc_train, num_sample, loss_train = self.test_using_global_dataset(
            self.model_global, self.train_global, self.device
@@ -190,6 +244,17 @@ def global_test(self):
        wandb.log({"Global Testing Accuracy": acc_test})

    def test_using_global_dataset(self, model_global, global_test_data, device):
+        """
+        Test the global model using the global test dataset.
+
+        Args:
+            model_global: The global model to evaluate.
+            global_test_data: The global test dataset.
+            device: The computing device (e.g., 'cuda' or 'cpu').
+
+        Returns:
+            Tuple of testing accuracy, total samples, and testing loss.
+        """
        model_global.eval()
        model_global.to(device)
        test_loss = test_acc = test_total = 0.0
diff --git a/python/fedml/simulation/sp/turboaggregate/mpc_function.py b/python/fedml/simulation/sp/turboaggregate/mpc_function.py
index e2ab80b983..3c4e976761 100644
--- a/python/fedml/simulation/sp/turboaggregate/mpc_function.py
+++ b/python/fedml/simulation/sp/turboaggregate/mpc_function.py
@@ -2,6 +2,16 @@


 def modular_inv(a, p):
+    """
+    Compute the modular inverse of 'a' modulo 'p'.
+
+    Args:
+        a (int): The number for which the modular inverse is calculated.
+        p (int): The prime modulo.
+
+    Returns:
+        int: The modular inverse of 'a' modulo 'p'.
+    """
    x, y, m = 1, 0, p
    while a > 1:
        q = a // m
@@ -19,6 +29,17 @@ def modular_inv(a, p):


 def divmod(_num, _den, _p):
+    """
+    Compute the result of `_num` divided by `_den` modulo prime `_p`.
+
+    Args:
+        _num (int): The numerator.
+        _den (int): The denominator.
+        _p (int): The prime modulo.
+
+    Returns:
+        int: The result of `_num / _den` modulo `_p`.
+    """
    # compute num / den modulo prime p
    _num = np.mod(_num, _p)
    _den = np.mod(_den, _p)
@@ -28,6 +49,16 @@ def divmod(_num, _den, _p):


 def PI(vals, p):  # upper-case PI -- product of inputs
+    """
+    Compute the product of a list of values modulo prime 'p'.
+
+    Args:
+        vals (list): List of integers to be multiplied.
+        p (int): The prime modulo.
+
+    Returns:
+        int: The product of the values in 'vals' modulo 'p'.
+    """
    accum = 1

    for v in vals:
@@ -37,6 +68,18 @@ def PI(vals, p):  # upper-case PI -- product of inputs


 def gen_Lagrange_coeffs(alpha_s, beta_s, p, is_K1=0):
+    """
+    Generate Lagrange coefficients for polynomial interpolation.
+
+    Args:
+        alpha_s (list): List of alpha values.
+        beta_s (list): List of beta values.
+        p (int): The prime modulo.
+        is_K1 (int): If 1, compute coefficients for a single alpha (the K = 1 case); otherwise for all of alpha_s.
+
+    Returns:
+        numpy.ndarray: A matrix of Lagrange coefficients.
+    """
    if is_K1 == 1:
        num_alpha = 1
    else:
@@ -60,6 +103,18 @@ def gen_Lagrange_coeffs(alpha_s, beta_s, p, is_K1=0):


 def BGW_encoding(X, N, T, p):
+    """
+    Perform BGW encoding of input data X.
+
+    Args:
+        X (numpy.ndarray): The input data of shape (m, d).
+        N (int): The number of evaluation points.
+        T (int): The privacy threshold, i.e. the degree of the random masking polynomial.
+        p (int): The prime modulo.
+
+    Returns:
+        numpy.ndarray: The BGW encoded data of shape (N, m, d).
+    """
    m = len(X)
    d = len(X[0])

@@ -76,6 +131,16 @@ def BGW_encoding(X, N, T, p):


 def gen_BGW_lambda_s(alpha_s, p):
+    """
+    Generate BGW lambda values for polynomial interpolation.
+
+    Args:
+        alpha_s (numpy.ndarray): The alpha values.
+        p (int): The prime modulo.
+
+    Returns:
+        numpy.ndarray: The reconstruction (lambda) coefficients, one per alpha.
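+
+    Hand-checkable sketch (assuming, as in standard BGW reconstruction, that
+    the secret is the polynomial's value at 0): for shares at alpha_s = [1, 2],
+    the Lagrange weights at 0 are 2/(2-1) = 2 and 1/(1-2) = -1, i.e.
+    [2, p - 1] modulo p.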
+ """ lambda_s = np.zeros((1, len(alpha_s)), dtype="int64") for i in range(len(alpha_s)): @@ -87,7 +152,19 @@ def gen_BGW_lambda_s(alpha_s, p): return lambda_s.astype("int64") -def BGW_decoding(f_eval, worker_idx, p): # decode the output from T+1 evaluation points +def BGW_decoding(f_eval, worker_idx, p): + """ + Decode the output from T+1 evaluation points using BGW decoding. + + Args: + f_eval (numpy.ndarray): The evaluation points of shape (RT, d). + worker_idx (numpy.ndarray): The worker indices of shape (1, RT). + p (int): The prime modulo. + + Returns: + numpy.ndarray: The decoded output of shape (1, d). + """ + # decode the output from T+1 evaluation points # f_eval : [RT X d ] # worker_idx : [ 1 X RT] # output : [ 1 X d ] @@ -109,6 +186,19 @@ def BGW_decoding(f_eval, worker_idx, p): # decode the output from T+1 evaluatio def LCC_encoding(X, N, K, T, p): + """ + Perform LCC encoding of input data X. + + Args: + X (numpy.ndarray): The input data of shape (m, d). + N (int): The number of encoding points. + K (int): The number of systematic points. + T (int): The number of redundant points. + p (int): The prime modulo. + + Returns: + numpy.ndarray: The LCC encoded data of shape (N, m//K, d). + """ m = len(X) d = len(X[0]) # print(m,d,m//K) @@ -135,6 +225,20 @@ def LCC_encoding(X, N, K, T, p): def LCC_encoding_w_Random(X, R_, N, K, T, p): + """ + Perform LCC encoding of input data X with random data R_. + + Args: + X (numpy.ndarray): The input data of shape (m, d). + R_ (numpy.ndarray): Random data of shape (T, m // K, d). + N (int): The number of encoding points. + K (int): The number of systematic points. + T (int): The number of redundant points. + p (int): The prime modulo. + + Returns: + numpy.ndarray: The LCC encoded data of shape (N, m//K, d). + """ m = len(X) d = len(X[0]) # print(m,d,m//K) @@ -165,6 +269,21 @@ def LCC_encoding_w_Random(X, R_, N, K, T, p): def LCC_encoding_w_Random_partial(X, R_, N, K, T, p, worker_idx): + """ + Perform partial LCC encoding of input data X with random data R_ for specific workers. + + Args: + X (numpy.ndarray): The input data of shape (m, d). + R_ (numpy.ndarray): Random data of shape (T, m // K, d). + N (int): The number of encoding points. + K (int): The number of systematic points. + T (int): The number of redundant points. + p (int): The prime modulo. + worker_idx (numpy.ndarray): Worker indices for partial encoding. + + Returns: + numpy.ndarray: The partial LCC encoded data of shape (N_out, m//K, d). + """ m = len(X) d = len(X[0]) # print(m,d,m//K) @@ -193,6 +312,21 @@ def LCC_encoding_w_Random_partial(X, R_, N, K, T, p, worker_idx): def LCC_decoding(f_eval, f_deg, N, K, T, worker_idx, p): + """ + Perform LCC decoding of the given evaluation points and worker indices. + + Args: + f_eval (numpy.ndarray): The evaluation points of shape (RT, d). + f_deg (int): The degree of the polynomial. + N (int): The number of encoding points. + K (int): The number of systematic points. + T (int): The number of redundant points. + worker_idx (numpy.ndarray): Worker indices for decoding. + p (int): The prime modulo. + + Returns: + numpy.ndarray: The decoded output of shape (1, d). + """ # RT_LCC = f_deg * (K + T - 1) + 1 n_beta = K # +T @@ -212,6 +346,17 @@ def LCC_decoding(f_eval, f_deg, N, K, T, worker_idx, p): def Gen_Additive_SS(d, n_out, p): + """ + Generate an additive secret sharing matrix. + + Args: + d (int): The dimension of the secret. + n_out (int): The number of output shares. + p (int): The prime modulo. 
+ + Returns: + numpy.ndarray: The additive secret sharing matrix. + """ # x_model should be one dimension temp = np.random.randint(0, p, size=(n_out - 1, d)) @@ -225,6 +370,18 @@ def Gen_Additive_SS(d, n_out, p): def LCC_encoding_with_points(X, alpha_s, beta_s, p): + """ + Perform LCC encoding of input data X using specific alpha and beta points. + + Args: + X (numpy.ndarray): The input data of shape (m, d). + alpha_s (numpy.ndarray): The alpha points. + beta_s (numpy.ndarray): The beta points. + p (int): The prime modulo. + + Returns: + numpy.ndarray: The LCC encoded data of shape (N, d). + """ m, d = np.shape(X) # print alpha_s @@ -247,6 +404,18 @@ def LCC_encoding_with_points(X, alpha_s, beta_s, p): def LCC_decoding_with_points(f_eval, eval_points, target_points, p): + """ + Perform LCC decoding of the given evaluation points and target points. + + Args: + f_eval (numpy.ndarray): The evaluation points of shape (RT, d). + eval_points (numpy.ndarray): The evaluation points for decoding. + target_points (numpy.ndarray): The target points for decoding. + p (int): The prime modulo. + + Returns: + numpy.ndarray: The decoded output of shape (1, d). + """ alpha_s_eval = eval_points beta_s = target_points @@ -261,6 +430,17 @@ def LCC_decoding_with_points(f_eval, eval_points, target_points, p): def my_pk_gen(my_sk, p, g): + """ + Generate a public key using the private key and prime modulo. + + Args: + my_sk (int): The private key. + p (int): The prime modulo. + g (int): An optional generator. + + Returns: + int: The public key. + """ # print 'my_pk_gen option: g=',g if g == 0: return my_sk @@ -269,6 +449,18 @@ def my_pk_gen(my_sk, p, g): def my_key_agreement(my_sk, u_pk, p, g): + """ + Perform key agreement using private key, public key, prime modulo, and an optional generator. + + Args: + my_sk (int): The private key. + u_pk (int): The other party's public key. + p (int): The prime modulo. + g (int): An optional generator. + + Returns: + int: The shared secret key. + """ if g == 0: return np.mod(my_sk * u_pk, p) else: From 38815123ef27c77ef7643433d6fb8b5093cc4600 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Wed, 6 Sep 2023 22:42:44 +0530 Subject: [PATCH 41/70] same --- python/fedml/simulation/sp/feddyn/client.py | 54 +++++++++++++ .../sp/feddyn/feddyn_trainer copy.py | 80 +++++++++++++++++++ .../simulation/sp/feddyn/feddyn_trainer.py | 80 +++++++++++++++++++ python/fedml/simulation/sp/fednova/client.py | 73 +++++++++++++++++ .../simulation/sp/fednova/comm_helpers.py | 52 +++++++----- python/fedml/simulation/sp/fednova/fednova.py | 19 ++++- .../simulation/sp/fednova/fednova_api.py | 80 ++++++++++++++++++- .../simulation/sp/fednova/fednova_trainer.py | 59 ++++++++++++++ python/fedml/simulation/sp/fedopt/client.py | 45 +++++++++++ .../fedml/simulation/sp/fedopt/fedopt_api.py | 77 ++++++++++++++++++ python/fedml/simulation/sp/fedopt/optrepo.py | 6 +- python/fedml/simulation/sp/fedprox/client.py | 49 +++++++++++- .../simulation/sp/fedprox/fedprox_trainer.py | 65 +++++++++++++++ 13 files changed, 713 insertions(+), 26 deletions(-) diff --git a/python/fedml/simulation/sp/feddyn/client.py b/python/fedml/simulation/sp/feddyn/client.py index 3b37dd73b1..6fa77ef55a 100644 --- a/python/fedml/simulation/sp/feddyn/client.py +++ b/python/fedml/simulation/sp/feddyn/client.py @@ -4,6 +4,15 @@ def model_parameter_vector(model): + """ + Flatten and concatenate the model parameters into a single vector. + + Args: + model (nn.Module): The PyTorch model. 
+ + Returns: + torch.Tensor: The concatenated parameter vector. + """ param = [p.view(-1) for p in model.parameters()] return torch.concat(param, dim=0) @@ -12,6 +21,18 @@ class Client: def __init__( self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model_trainer, ): + """ + Initialize a client for federated learning. + + Args: + client_idx (int): The index of the client. + local_training_data (torch.utils.data.DataLoader): The local training dataset. + local_test_data (torch.utils.data.DataLoader): The local test dataset. + local_sample_number (int): The number of local samples. + args: The command-line arguments. + device (torch.device): The device (e.g., "cuda" or "cpu") for computation. + model_trainer: The model trainer responsible for training and testing. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -30,6 +51,15 @@ def __init__( def update_local_dataset(self, client_idx, local_training_data, local_test_data, local_sample_number): + """ + Update the client's local dataset. + + Args: + client_idx (int): The index of the client. + local_training_data (torch.utils.data.DataLoader): The new local training dataset. + local_test_data (torch.utils.data.DataLoader): The new local test dataset. + local_sample_number (int): The number of local samples in the new dataset. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -37,9 +67,24 @@ def update_local_dataset(self, client_idx, local_training_data, local_test_data, self.model_trainer.set_id(client_idx) def get_sample_number(self): + """ + Get the number of local samples. + + Returns: + int: The number of local samples. + """ return self.local_sample_number def train(self, w_global): + """ + Train the client's model using global weights. + + Args: + w_global: The global model weights. + + Returns: + dict: The updated client's model weights. + """ self.model_trainer.set_model_params(w_global) self.old_grad = self.model_trainer.train(self.local_training_data, self.device, self.args, self.old_grad) weights = self.model_trainer.get_model_params() @@ -47,6 +92,15 @@ def train(self, w_global): return weights def local_test(self, b_use_test_dataset): + """ + Perform local testing on the client's model. + + Args: + b_use_test_dataset (bool): Whether to use the test dataset for testing. + + Returns: + dict: Test metrics including correctness, loss, and more. + """ if b_use_test_dataset: test_data = self.local_test_data else: diff --git a/python/fedml/simulation/sp/feddyn/feddyn_trainer copy.py b/python/fedml/simulation/sp/feddyn/feddyn_trainer copy.py index 60f50d2f3e..d04f6e07bb 100644 --- a/python/fedml/simulation/sp/feddyn/feddyn_trainer copy.py +++ b/python/fedml/simulation/sp/feddyn/feddyn_trainer copy.py @@ -17,6 +17,15 @@ class FedDynTrainer(object): def __init__(self, dataset, model, device, args): + """ + Initialize the FedDynTrainer. + + Args: + dataset: A tuple containing dataset information. + model: The model to be trained. + device: The device to run the training on (e.g., 'cpu' or 'cuda'). + args: Additional training configuration and hyperparameters. + """ self.device = device self.args = args [ @@ -59,6 +68,15 @@ def __init__(self, dataset, model, device, args): def _setup_clients( self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict, model_trainer, ): + """ + Set up client instances for training. 
+ + Args: + train_data_local_num_dict: A dictionary containing the number of samples for each local training dataset. + train_data_local_dict: A dictionary containing local training datasets. + test_data_local_dict: A dictionary containing local test datasets. + model_trainer: The model trainer instance. + """ logging.info("############setup_clients (START)#############") if self.args.initialize_all_clients: num_initialized_clients = self.args.client_num_in_total @@ -78,6 +96,11 @@ def _setup_clients( logging.info("############setup_clients (END)#############") def train(self): + """ + Train the federated dynamic model using FedDyn. + + This method performs the federated training loop, including client selection, training, aggregation, and testing. + """ logging.info("self.model_trainer = {}".format(self.model_trainer)) w_global = self.model_trainer.get_model_params() mlops.log_training_status(mlops.ClientConstants.MSG_MLOPS_CLIENT_STATUS_TRAINING) @@ -175,6 +198,17 @@ def train(self): mlops.log_aggregation_finished_status() def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Select a subset of clients for communication in each round. + + Args: + round_idx: The current communication round index. + client_num_in_total: The total number of clients. + client_num_per_round: The number of clients to select in each round. + + Returns: + A list of selected client indexes. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -185,6 +219,12 @@ def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation dataset from the test dataset. + + Args: + num_samples: The number of samples to include in the validation dataset. + """ test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) @@ -192,11 +232,31 @@ def _generate_validation_set(self, num_samples=10000): self.val_global = sample_testset def _aggregate(self, w_locals): + """ + Aggregate local model weights from all clients. + + Args: + w_locals: A list of tuples containing the number of samples and local model weights for each client. + + Returns: + The aggregated global model weights. + """ avg_params = FedMLAggOperator.agg(self.args, w_locals) return avg_params def _test(self, test_data, device, args): + """ + Perform testing on the test dataset. + + Args: + test_data: The test dataset. + device: The device to run the testing on (e.g., 'cpu' or 'cuda'). + args: Additional testing configuration and hyperparameters. + + Returns: + A dictionary containing testing metrics (e.g., test accuracy, test loss). + """ model = self.model_trainer.model model.to(device) @@ -222,6 +282,14 @@ def _test(self, test_data, device, args): return metrics def test(self, test_data, device, args): + """ + Perform testing on the test dataset and log testing metrics. + + Args: + test_data: The test dataset. + device: The device to run the testing on (e.g., 'cpu' or 'cuda'). + args: Additional testing configuration and hyperparameters. + """ # test data test_num_samples = [] test_tot_corrects = [] @@ -253,6 +321,12 @@ def test(self, test_data, device, args): def _local_test_on_all_clients(self, round_idx): + """ + Perform local testing on all clients and log the results. 
+ + Args: + round_idx: The current communication round index. + """ logging.info("################local_test_on_all_clients : {}".format(round_idx)) @@ -314,6 +388,12 @@ def _local_test_on_all_clients(self, round_idx): logging.info(stats) def _local_test_on_validation_set(self, round_idx): + """ + Perform local testing on the validation set for all clients. + + Args: + round_idx: The current communication round index. + """ logging.info("################local_test_on_validation_set : {}".format(round_idx)) diff --git a/python/fedml/simulation/sp/feddyn/feddyn_trainer.py b/python/fedml/simulation/sp/feddyn/feddyn_trainer.py index 42b3eaa79c..c25251586f 100644 --- a/python/fedml/simulation/sp/feddyn/feddyn_trainer.py +++ b/python/fedml/simulation/sp/feddyn/feddyn_trainer.py @@ -17,6 +17,15 @@ class FedDynTrainer(object): def __init__(self, dataset, model, device, args): + """ + Initialize the FedDynTrainer. + + Args: + dataset: A tuple containing dataset information. + model: The model to be trained. + device: The device to run the training on (e.g., 'cpu' or 'cuda'). + args: Additional training configuration and hyperparameters. + """ self.device = device self.args = args [ @@ -59,6 +68,15 @@ def __init__(self, dataset, model, device, args): def _setup_clients( self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict, model_trainer, ): + """ + Set up client instances for training. + + Args: + train_data_local_num_dict: A dictionary containing the number of samples for each local training dataset. + train_data_local_dict: A dictionary containing local training datasets. + test_data_local_dict: A dictionary containing local test datasets. + model_trainer: The model trainer instance. + """ logging.info("############setup_clients (START)#############") if self.args.initialize_all_clients: num_initialized_clients = self.args.client_num_in_total @@ -78,6 +96,11 @@ def _setup_clients( logging.info("############setup_clients (END)#############") def train(self): + """ + Train the federated dynamic model using FedDyn. + + This method performs the federated training loop, including client selection, training, aggregation, and testing. + """ logging.info("self.model_trainer = {}".format(self.model_trainer)) w_global = self.model_trainer.get_model_params() mlops.log_training_status(mlops.ClientConstants.MSG_MLOPS_CLIENT_STATUS_TRAINING) @@ -170,6 +193,17 @@ def train(self): mlops.log_aggregation_finished_status() def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Select a subset of clients for communication in each round. + + Args: + round_idx: The current communication round index. + client_num_in_total: The total number of clients. + client_num_per_round: The number of clients to select in each round. + + Returns: + A list of selected client indexes. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -180,6 +214,12 @@ def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation dataset from the test dataset. + + Args: + num_samples: The number of samples to include in the validation dataset. 
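+
+        Equivalent to (a sketch of the visible part of the body below):
+
+            idx = random.sample(range(len(self.test_global.dataset)),
+                                min(num_samples, len(self.test_global.dataset)))
+            subset = torch.utils.data.Subset(self.test_global.dataset, idx)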
+ """ test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) @@ -187,11 +227,31 @@ def _generate_validation_set(self, num_samples=10000): self.val_global = sample_testset def _aggregate(self, w_locals): + """ + Aggregate local model weights from all clients. + + Args: + w_locals: A list of tuples containing the number of samples and local model weights for each client. + + Returns: + The aggregated global model weights. + """ avg_params = FedMLAggOperator.agg(self.args, w_locals) return avg_params def _test(self, test_data, device, args): + """ + Perform testing on the test dataset. + + Args: + test_data: The test dataset. + device: The device to run the testing on (e.g., 'cpu' or 'cuda'). + args: Additional testing configuration and hyperparameters. + + Returns: + A dictionary containing testing metrics (e.g., test accuracy, test loss). + """ model = self.model_trainer.model model.to(device) @@ -217,6 +277,14 @@ def _test(self, test_data, device, args): return metrics def test(self, test_data, device, args): + """ + Perform testing on the test dataset and log testing metrics. + + Args: + test_data: The test dataset. + device: The device to run the testing on (e.g., 'cpu' or 'cuda'). + args: Additional testing configuration and hyperparameters. + """ # test data test_num_samples = [] test_tot_corrects = [] @@ -248,6 +316,12 @@ def test(self, test_data, device, args): def _local_test_on_all_clients(self, round_idx): + """ + Perform local testing on all clients and log the results. + + Args: + round_idx: The current communication round index. + """ logging.info("################local_test_on_all_clients : {}".format(round_idx)) @@ -309,6 +383,12 @@ def _local_test_on_all_clients(self, round_idx): logging.info(stats) def _local_test_on_validation_set(self, round_idx): + """ + Perform local testing on the validation set for all clients. + + Args: + round_idx: The current communication round index. + """ logging.info("################local_test_on_validation_set : {}".format(round_idx)) diff --git a/python/fedml/simulation/sp/fednova/client.py b/python/fedml/simulation/sp/fednova/client.py index 32703acc87..0cbd2b2762 100644 --- a/python/fedml/simulation/sp/fednova/client.py +++ b/python/fedml/simulation/sp/fednova/client.py @@ -16,6 +16,17 @@ def __init__( args, device, ): + """ + Initialize a client instance. + + Args: + client_idx (int): The index of the client. + local_training_data: The local training data for this client. + local_test_data: The local test data for this client. + local_sample_number: The number of samples in the local training data. + args: Command-line arguments. + device: The device (e.g., "cpu" or "cuda") on which to perform computations. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data @@ -39,15 +50,42 @@ def __init__( def update_local_dataset( self, client_idx, local_training_data, local_test_data, local_sample_number ): + """ + Update the local datasets for the client. + + Args: + client_idx (int): The index of the client. + local_training_data: The new local training data. + local_test_data: The new local test data. + local_sample_number: The number of samples in the new local training data. 
+ """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data self.local_sample_number = local_sample_number def get_sample_number(self): + """ + Get the number of samples in the local training data. + + Returns: + int: The number of samples in the local training data. + """ return self.local_sample_number def get_local_norm_grad(self, opt, cur_params, init_params, weight=0): + """ + Calculate the local normalized gradient. + + Args: + opt: The FedNova optimizer. + cur_params: The current parameters of the model. + init_params: The initial parameters of the model. + weight (float): Weight factor for the gradient calculation. + + Returns: + dict: A dictionary containing the local normalized gradients. + """ if weight == 0: weight = opt.ratio grad_dict = {} @@ -59,12 +97,27 @@ def get_local_norm_grad(self, opt, cur_params, init_params, weight=0): return grad_dict def get_local_tau_eff(self, opt): + """ + Calculate the local effective tau. + + Args: + opt: The FedNova optimizer. + + Returns: + float: The local effective tau. + """ if opt.mu != 0: return opt.local_steps * opt.ratio else: return opt.local_normalizing_vec * opt.ratio def reset_fednova_optimizer(self, opt): + """ + Reset the FedNova optimizer state for the client. + + Args: + opt: The FedNova optimizer. + """ opt.local_counter = 0 opt.local_normalizing_vec = 0 opt.local_steps = 0 @@ -77,6 +130,16 @@ def reset_fednova_optimizer(self, opt): param_state["momentum_buffer"].zero_() def train(self, net, ratio): + """ + Train the model on the local training data. + + Args: + net: The neural network model. + ratio: The ratio used in training. + + Returns: + tuple: A tuple containing the loss, gradients, and effective tau. + """ net.train() # train and update init_params = copy.deepcopy(net.state_dict()) @@ -120,6 +183,16 @@ def train(self, net, ratio): return sum(epoch_loss) / len(epoch_loss), norm_grad, tau_eff def local_test(self, model_global, b_use_test_dataset=False): + """ + Evaluate the performance of the global model on the local test or training dataset. + + Args: + model_global: The global model to evaluate. + b_use_test_dataset (bool): Whether to use the local test dataset. If False, uses the local training dataset. + + Returns: + dict: A dictionary containing evaluation metrics, including accuracy, loss, precision, recall, and total samples. + """ model_global.eval() model_global.to(self.device) metrics = { diff --git a/python/fedml/simulation/sp/fednova/comm_helpers.py b/python/fedml/simulation/sp/fednova/comm_helpers.py index 94b2dfa1da..2ff37fcfd2 100644 --- a/python/fedml/simulation/sp/fednova/comm_helpers.py +++ b/python/fedml/simulation/sp/fednova/comm_helpers.py @@ -7,16 +7,21 @@ def flatten_tensors(tensors): """ + Flatten a list of dense tensors into a contiguous 1D buffer. + Reference: https://github.com/facebookresearch/stochastic_gradient_push - Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of - same dense type. + This function takes a list of dense tensors and flattens them into a single + contiguous 1D buffer. It assumes that all input tensors are of the same dense type. + Since inputs are dense, the resulting tensor will be a concatenated 1D buffer. Element-wise operation on this buffer will be equivalent to operating individually. - Arguments: - tensors (Iterable[Tensor]): dense tensors to flatten. + + Args: + tensors (Iterable[Tensor]): The list of dense tensors to flatten. + Returns: - A 1D buffer containing input tensors. 
+ Tensor: A 1D buffer containing the flattened input tensors. """ if len(tensors) == 1: return tensors[0].view(-1).clone() @@ -27,15 +32,19 @@ def flatten_tensors(tensors): def unflatten_tensors(flat, tensors): """ Reference: https://github.com/facebookresearch/stochastic_gradient_push - View a flat buffer using the sizes of tensors. Assume that tensors are of - same dense type, and that flat is given by flatten_dense_tensors. - Arguments: - flat (Tensor): flattened dense tensors to unflatten. - tensors (Iterable[Tensor]): dense tensors whose sizes will be used to - unflatten flat. + Unflatten a flat buffer into a list of tensors using their original sizes. + + This function takes a flat buffer and unflattens it into a list of tensors using + the sizes of the original tensors. It assumes that all input tensors are of the + same dense type and that the flat buffer was generated using `flatten_tensors`. + + Args: + flat (Tensor): The flattened dense tensors to unflatten. + tensors (Iterable[Tensor]): The dense tensors whose sizes will be used to + unflatten the flat buffer. + Returns: - Unflattened dense tensors with sizes same as tensors and values from - flat. + tuple: Unflattened dense tensors with sizes same as `tensors` and values from `flat`. """ outputs = [] offset = 0 @@ -48,13 +57,18 @@ def unflatten_tensors(flat, tensors): def communicate(tensors, communication_op): """ + Communicate a list of tensors using a specified communication operation. + Reference: https://github.com/facebookresearch/stochastic_gradient_push - Communicate a list of tensors. - Arguments: - tensors (Iterable[Tensor]): list of tensors. - communication_op: a method or partial object which takes a tensor as - input and communicates it. It can be a partial object around - something like torch.distributed.all_reduce. + This function takes a list of tensors and communicates them using a specified + communication operation. It assumes that the communication_op can handle the + provided tensors appropriately, such as performing an all-reduce operation. + + Args: + tensors (Iterable[Tensor]): List of tensors to be communicated. + communication_op: A method or partial object which takes a tensor as input + and communicates it. It can be a partial object around something like + `torch.distributed.all_reduce`. """ flat_tensor = flatten_tensors(tensors) communication_op(tensor=flat_tensor) diff --git a/python/fedml/simulation/sp/fednova/fednova.py b/python/fedml/simulation/sp/fednova/fednova.py index 63e4662e60..a1590972c7 100644 --- a/python/fedml/simulation/sp/fednova/fednova.py +++ b/python/fedml/simulation/sp/fednova/fednova.py @@ -94,10 +94,15 @@ def __setstate__(self, state): group.setdefault("nesterov", False) def step(self, closure=None): - """Performs a single optimization step. - Arguments: + """ + Performs a single optimization step. + + Args: closure (callable, optional): A closure that reevaluates the model and returns the loss. + + Returns: + loss: The loss after the optimization step. """ loss = None @@ -169,6 +174,16 @@ def step(self, closure=None): return loss def average(self, weight=0, tau_eff=0): + """ + Averages accumulated local gradients across clients. + + Args: + weight (float, optional): Weight factor for averaging (default: 0). + tau_eff (float, optional): Effective tau value (default: 0). 
+ + Returns: + None + """ if weight == 0: weight = self.ratio if tau_eff == 0: diff --git a/python/fedml/simulation/sp/fednova/fednova_api.py b/python/fedml/simulation/sp/fednova/fednova_api.py index 543370c9d6..214c057be0 100644 --- a/python/fedml/simulation/sp/fednova/fednova_api.py +++ b/python/fedml/simulation/sp/fednova/fednova_api.py @@ -12,6 +12,15 @@ class FedAvgAPI(object): def __init__(self, args, device, dataset, model): + """ + Initialize the FedAvgAPI. + + Args: + args (object): Arguments object containing configuration settings. + device (object): Device on which to perform computations. + dataset (list): List containing dataset information. + model (object): Machine learning model. + """ self.device = device self.args = args [ @@ -46,6 +55,15 @@ def __init__(self, args, device, dataset, model): def _setup_clients( self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict, model_trainer, ): + """ + Set up the clients for federated training. + + Args: + train_data_local_num_dict (dict): Dictionary containing the number of local training samples for each client. + train_data_local_dict (dict): Dictionary containing local training data for each client. + test_data_local_dict (dict): Dictionary containing local test data for each client. + model_trainer (object): Model trainer object. + """ logging.info("############setup_clients (START)#############") for client_idx in range(self.args.client_num_per_round): c = Client( @@ -60,6 +78,9 @@ def _setup_clients( logging.info("############setup_clients (END)#############") def train(self): + """ + Perform federated training using the FedAvg algorithm. + """ logging.info("self.model_trainer = {}".format(self.model_trainer)) w_global = self.model_trainer.get_model_params() for round_idx in range(self.args.comm_round): @@ -108,6 +129,17 @@ def train(self): self._local_test_on_all_clients(round_idx) def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Sample a subset of clients for federated training. + + Args: + round_idx (int): Index of the communication round. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to select for the current round. + + Returns: + list: List of selected client indexes. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -118,6 +150,16 @@ def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set by randomly sampling from the global test dataset. + + Args: + num_samples (int, optional): Number of samples to include in the validation set. Default is 10,000. + + Note: + This function samples `num_samples` from the global test dataset and stores it as the validation set (`self.val_global`). + + """ test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) @@ -125,6 +167,16 @@ def _generate_validation_set(self, num_samples=10000): self.val_global = sample_testset def _aggregate(self, w_locals): + """ + Aggregate local model parameters weighted by the number of samples. + + Args: + w_locals (list): List of tuples, where each tuple contains the number of local samples and local model parameters. 
+ + Returns: + dict: Averaged global model parameters. + + """ training_num = 0 for idx in range(len(w_locals)): (sample_num, averaged_params) = w_locals[idx] @@ -143,10 +195,17 @@ def _aggregate(self, w_locals): def _aggregate_noniid_avg(self, w_locals): """ - The old aggregate method will impact the model performance when it comes to Non-IID setting + Aggregate local model parameters using a simple average, suitable for Non-IID settings. + Args: - w_locals: + w_locals (list): List of tuples, where each tuple contains the number of local samples and local model parameters. + Returns: + dict: Averaged global model parameters. + + Note: + In Non-IID settings, where the data distribution among clients is not identical, a simple average of local model parameters may be used for aggregation. This method averages the model parameters across clients for each parameter independently. + """ (_, averaged_params) = w_locals[0] for k in averaged_params.keys(): @@ -157,6 +216,16 @@ def _aggregate_noniid_avg(self, w_locals): return averaged_params def _local_test_on_all_clients(self, round_idx): + """ + Perform local testing on all clients and log the results. + + Args: + round_idx (int): The current communication round index. + + Note: + This function iterates over all clients and performs testing on both training and test datasets. It then logs the training and test accuracy along with losses. + + """ logging.info("################local_test_on_all_clients : {}".format(round_idx)) @@ -212,6 +281,13 @@ def _local_test_on_all_clients(self, round_idx): logging.info(stats) def _local_test_on_validation_set(self, round_idx): + """ + Perform local testing on the validation set and log the results. + + Args: + round_idx (int): The current communication round index. + + """ logging.info("################local_test_on_validation_set : {}".format(round_idx)) diff --git a/python/fedml/simulation/sp/fednova/fednova_trainer.py b/python/fedml/simulation/sp/fednova/fednova_trainer.py index bbf72182ef..539a1e47d1 100644 --- a/python/fedml/simulation/sp/fednova/fednova_trainer.py +++ b/python/fedml/simulation/sp/fednova/fednova_trainer.py @@ -10,6 +10,15 @@ class FedNovaTrainer(object): def __init__(self, dataset, model, device, args): + """ + Initialize the FedNovaTrainer. + + Args: + dataset (tuple): A tuple containing dataset information. + model (torch.nn.Module): The global model to be trained. + device (torch.device): The target device for model training. + args (argparse.Namespace): Command-line arguments. + """ self.device = device self.args = args [ @@ -41,6 +50,17 @@ def __init__(self, dataset, model, device, args): def setup_clients( self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict ): + """ + Set up client instances for federated training. + + Args: + train_data_local_num_dict (dict): Dictionary containing local training data sizes. + train_data_local_dict (dict): Dictionary containing local training data. + test_data_local_dict (dict): Dictionary containing local test data. + + Returns: + None + """ logging.info("############setup_clients (START)#############") for client_idx in range(self.args.client_num_per_round): c = Client( @@ -55,6 +75,17 @@ def setup_clients( logging.info("############setup_clients (END)#############") def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Perform client sampling for federated training. + + Args: + round_idx (int): The current communication round index. 
+            client_num_in_total (int): Total number of clients.
+            client_num_per_round (int): Number of clients to sample in each round.
+
+        Returns:
+            list: List of client indexes selected for the current round.
+        """
         if client_num_in_total == client_num_per_round:
             client_indexes = [
                 client_index for client_index in range(client_num_in_total)
@@ -71,6 +102,12 @@ def client_sampling(self, round_idx, client_num_in_total, client_num_per_round):
         return client_indexes
 
     def train(self):
+        """
+        Perform federated training using the FedNova optimizer.
+
+        Returns:
+            None
+        """
         for round_idx in range(self.args.comm_round):
             logging.info("################Communication round : {}".format(round_idx))
 
@@ -134,6 +171,18 @@ def train(self):
             self.local_test_on_all_clients(self.model_global, round_idx)
 
     def aggregate(self, params, norm_grads, tau_effs, tau_eff=0):
+        """
+        Aggregate local gradients and update global model parameters.
+
+        Args:
+            params (dict): Dictionary containing global model parameters.
+            norm_grads (list of dict): List of dictionaries containing normalized local gradients.
+            tau_effs (list): List of effective tau values for each client.
+            tau_eff (float): Effective tau value (optional).
+
+        Returns:
+            dict: Updated global model parameters.
+        """
         # get tau_eff
         if tau_eff == 0:
             tau_eff = sum(tau_effs)
@@ -164,6 +213,16 @@ def aggregate(self, params, norm_grads, tau_effs, tau_eff=0):
         return params
 
     def local_test_on_all_clients(self, model_global, round_idx):
+        """
+        Perform local testing on all clients and log results.
+
+        Args:
+            model_global (torch.nn.Module): The global model to evaluate.
+            round_idx (int): The current communication round index.
+
+        Returns:
+            None
+        """
         logging.info("################local_test_on_all_clients : {}".format(round_idx))
         train_metrics = {
             "num_samples": [],
diff --git a/python/fedml/simulation/sp/fedopt/client.py b/python/fedml/simulation/sp/fedopt/client.py
index 993634a74f..856749a9b2 100644
--- a/python/fedml/simulation/sp/fedopt/client.py
+++ b/python/fedml/simulation/sp/fedopt/client.py
@@ -5,6 +5,18 @@ class Client:
     def __init__(
         self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model_trainer,
     ):
+        """
+        Initialize a client in the federated learning system.
+
+        Args:
+            client_idx (int): The unique identifier for this client.
+            local_training_data (torch.utils.data.Dataset): The local training dataset for this client.
+            local_test_data (torch.utils.data.Dataset): The local test dataset for this client.
+            local_sample_number (int): The number of samples in the local training dataset.
+            args: Additional arguments and settings.
+            device: The device (e.g., CPU or GPU) on which to perform computations.
+            model_trainer: The model trainer responsible for training and testing.
+        """
         self.client_idx = client_idx
         self.local_training_data = local_training_data
         self.local_test_data = local_test_data
@@ -16,21 +28,54 @@ def __init__(
         self.model_trainer = model_trainer
 
     def update_local_dataset(self, client_idx, local_training_data, local_test_data, local_sample_number):
+        """
+        Update the local dataset for this client.
+
+        Args:
+            client_idx (int): The unique identifier for this client.
+            local_training_data (torch.utils.data.Dataset): The new local training dataset.
+            local_test_data (torch.utils.data.Dataset): The new local test dataset.
+            local_sample_number (int): The number of samples in the new local training dataset.
+ """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_test_data = local_test_data self.local_sample_number = local_sample_number def get_sample_number(self): + """ + Get the number of samples in the local training dataset. + + Returns: + int: The number of samples in the local training dataset. + """ return self.local_sample_number def train(self, w_global): + """ + Train the client's local model. + + Args: + w_global: The global model weights. + + Returns: + weights: The updated local model weights. + """ self.model_trainer.set_model_params(w_global) self.model_trainer.train(self.local_training_data, self.device, self.args) weights = self.model_trainer.get_model_params() return weights def local_test(self, b_use_test_dataset): + """ + Perform local testing using either the local test dataset or local training dataset. + + Args: + b_use_test_dataset (bool): If True, use the local test dataset for testing. Otherwise, use the local training dataset. + + Returns: + metrics: The evaluation metrics obtained during testing. + """ if b_use_test_dataset: test_data = self.local_test_data else: diff --git a/python/fedml/simulation/sp/fedopt/fedopt_api.py b/python/fedml/simulation/sp/fedopt/fedopt_api.py index 8b0dd9e457..29cb6de340 100644 --- a/python/fedml/simulation/sp/fedopt/fedopt_api.py +++ b/python/fedml/simulation/sp/fedopt/fedopt_api.py @@ -12,6 +12,18 @@ class FedOptAPI(object): + """ + Base class for Federated Optimization. + + This class provides the foundation for federated optimization techniques. It sets up clients, + handles client sampling, and manages the global model and optimizer. + + Args: + args (object): Arguments containing configuration options. + device (str): Device (e.g., 'cpu' or 'cuda') to run computations on. + dataset (tuple): A tuple containing dataset information. + model (torch.nn.Module): The global model used for federated optimization. + """ def __init__(self, args, device, dataset, model): self.device = device self.args = args @@ -44,6 +56,14 @@ def __init__(self, args, device, dataset, model): self._setup_clients(train_data_local_num_dict, train_data_local_dict, test_data_local_dict) def _setup_clients(self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict): + """ + Set up client instances for federated optimization. + + Args: + train_data_local_num_dict (dict): A dictionary mapping client indices to the number of local training samples. + train_data_local_dict (dict): A dictionary mapping client indices to their local training datasets. + test_data_local_dict (dict): A dictionary mapping client indices to their local test datasets. + """ logging.info("############setup_clients (START)#############") for client_idx in range(self.args.client_num_per_round): c = Client( @@ -59,6 +79,17 @@ def _setup_clients(self, train_data_local_num_dict, train_data_local_dict, test_ logging.info("############setup_clients (END)#############") def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Sample a set of clients for a communication round. + + Args: + round_idx (int): The current communication round index. + client_num_in_total (int): Total number of clients in the system. + client_num_per_round (int): Number of clients to sample for the current round. + + Returns: + List[int]: A list of sampled client indices. 
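+
+        Example (minimal sketch of the seeded draw used below; seeding NumPy
+        with ``round_idx`` makes every process sample the same subset in a
+        given round)::
+
+            import numpy as np
+
+            np.random.seed(round_idx)
+            client_indexes = np.random.choice(
+                range(client_num_in_total), client_num_per_round, replace=False
+            )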
+        """
         if client_num_in_total == client_num_per_round:
             client_indexes = [client_index for client_index in range(client_num_in_total)]
         else:
@@ -69,6 +100,15 @@ def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round)
         return client_indexes
 
     def _generate_validation_set(self, num_samples=10000):
+        """
+        Generate a validation set from the global test dataset.
+
+        Args:
+            num_samples (int): Number of samples to include in the validation set.
+
+        Note:
+            This method updates the `val_global` attribute with the generated validation set.
+        """
         test_data_num = len(self.test_global.dataset)
         sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num))
         subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices)
@@ -76,6 +116,11 @@ def _generate_validation_set(self, num_samples=10000):
         self.val_global = sample_testset
 
     def _instanciate_opt(self):
+        """
+        Initialize the server optimizer.
+
+        This method instantiates the server optimizer from `args.server_optimizer` and `args.server_lr`.
+        """
         self.opt = OptRepo.name2cls(self.args.server_optimizer)(
             # self.model_global.parameters(), lr=self.args.server_lr
             self.model_trainer.model.parameters(),
@@ -85,6 +130,11 @@ def _instanciate_opt(self):
         )
 
     def train(self):
+        """
+        Train the global model using federated optimization.
+
+        This method runs the optimization loop over multiple communication rounds: it samples clients, trains locally, aggregates the updates, and applies the server optimizer.
+        """
         for round_idx in range(self.args.comm_round):
             w_global = self.model_trainer.get_model_params()
             logging.info("################ Communication round : {}".format(round_idx))
@@ -141,6 +191,15 @@ def train(self):
             self._local_test_on_all_clients(round_idx)
 
     def _aggregate(self, w_locals):
+        """
+        Aggregate local model weights to compute global model weights.
+
+        Args:
+            w_locals (list): A list of tuples containing local sample numbers and local model weights.
+
+        Returns:
+            dict: A dictionary containing aggregated global model weights.
+        """
         training_num = 0
         for idx in range(len(w_locals)):
             (sample_num, averaged_params) = w_locals[idx]
@@ -158,6 +217,12 @@ def _aggregate(self, w_locals):
         return averaged_params
 
     def _set_model_global_grads(self, new_state):
+        """
+        Set the gradients of the global model based on the difference between new and current model states.
+
+        Args:
+            new_state (dict): The new state of the global model.
+        """
         new_model = copy.deepcopy(self.model_trainer.model)
         new_model.load_state_dict(new_state)
         with torch.no_grad():
@@ -171,6 +236,12 @@ def _set_model_global_grads(self, new_state):
             self.model_trainer.set_model_params(new_model_state_dict)
 
     def _local_test_on_all_clients(self, round_idx):
+        """
+        Perform local testing on all clients.
+
+        Args:
+            round_idx (int): The current communication round index.
+        """
         logging.info("################local_test_on_all_clients : {}".format(round_idx))
         train_metrics = {"num_samples": [], "num_correct": [], "losses": []}
 
@@ -231,6 +302,12 @@ def _local_test_on_all_clients(self, round_idx):
         logging.info(stats)
 
     def _local_test_on_validation_set(self, round_idx):
+        """
+        Perform local testing on the validation set.
+
+        Args:
+            round_idx (int): The current communication round index.
+        """
         logging.info("################local_test_on_validation_set : {}".format(round_idx))
 
         if self.val_global is None:
diff --git a/python/fedml/simulation/sp/fedopt/optrepo.py b/python/fedml/simulation/sp/fedopt/optrepo.py
index 50615227d7..2a82c60e1e 100644
--- a/python/fedml/simulation/sp/fedopt/optrepo.py
+++ b/python/fedml/simulation/sp/fedopt/optrepo.py
@@ -5,8 +5,12 @@
 
 
 class OptRepo:
-    """Collects and provides information about the subclasses of torch.optim.Optimizer."""
+    """
+    Collects and provides information about the subclasses of torch.optim.Optimizer.
 
+    This class allows you to retrieve optimizer classes by name and obtain information about supported optimizers.
+    """
+
     repo = {x.__name__.lower(): x for x in torch.optim.Optimizer.__subclasses__()}
 
     @classmethod
diff --git a/python/fedml/simulation/sp/fedprox/client.py b/python/fedml/simulation/sp/fedprox/client.py
index cc74a9d932..ff669658bd 100644
--- a/python/fedml/simulation/sp/fedprox/client.py
+++ b/python/fedml/simulation/sp/fedprox/client.py
@@ -1,17 +1,39 @@
 class Client:
+    """
+    Represents a federated learning client.
+
+    Args:
+        client_idx (int): Index of the client.
+        local_training_data (Dataset): Local training dataset for the client.
+        local_test_data (Dataset): Local test dataset for the client.
+        local_sample_number (int): Number of local training samples.
+        args (argparse.Namespace): Command-line arguments.
+        device (torch.device): Device for training (e.g., "cpu" or "cuda").
+        model_trainer (ModelTrainer): Trainer for the client's model.
+    """
+
     def __init__(
-        self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model_trainer,
+        self, client_idx, local_training_data, local_test_data, local_sample_number, args, device, model_trainer
     ):
         self.client_idx = client_idx
         self.local_training_data = local_training_data
         self.local_test_data = local_test_data
         self.local_sample_number = local_sample_number
         self.args = args
         self.device = device
         self.model_trainer = model_trainer
 
     def update_local_dataset(self, client_idx, local_training_data, local_test_data, local_sample_number):
+        """
+        Update the local dataset for the client.
+
+        Args:
+            client_idx (int): Index of the client.
+            local_training_data (Dataset): New local training dataset.
+            local_test_data (Dataset): New local test dataset.
+            local_sample_number (int): Number of local training samples.
+        """
         self.client_idx = client_idx
         self.local_training_data = local_training_data
         self.local_test_data = local_test_data
@@ -19,15 +41,39 @@ def update_local_dataset(self, client_idx, local_training_data, local_test_data,
         self.model_trainer.set_id(client_idx)
 
     def get_sample_number(self):
+        """
+        Get the number of local training samples.
+
+        Returns:
+            int: Number of local training samples.
+        """
         return self.local_sample_number
 
     def train(self, w_global):
+        """
+        Train the client's model using the global model weights.
+
+        Args:
+            w_global (dict): Global model weights.
+
+        Returns:
+            dict: Updated client model weights.
+        """
         self.model_trainer.set_model_params(w_global)
         self.model_trainer.train(self.local_training_data, self.device, self.args)
         weights = self.model_trainer.get_model_params()
         return weights
 
     def local_test(self, b_use_test_dataset):
+        """
+        Test the client's model on the local test or training dataset.
+
+        Args:
+            b_use_test_dataset (bool): Flag to indicate whether to use the test dataset.
+
+        Returns:
+            dict: Evaluation metrics.
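+
+        Example (illustrative; assumes the trainer reports ``test_correct``,
+        ``test_total`` and ``test_loss`` keys, as the FedML model trainers in
+        this tree do)::
+
+            metrics = client.local_test(b_use_test_dataset=True)
+            accuracy = metrics["test_correct"] / metrics["test_total"]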
+ """ if b_use_test_dataset: test_data = self.local_test_data else: diff --git a/python/fedml/simulation/sp/fedprox/fedprox_trainer.py b/python/fedml/simulation/sp/fedprox/fedprox_trainer.py index df333b1f7a..0473e70ec7 100644 --- a/python/fedml/simulation/sp/fedprox/fedprox_trainer.py +++ b/python/fedml/simulation/sp/fedprox/fedprox_trainer.py @@ -15,6 +15,16 @@ class FedProxTrainer(object): + """ + Federated Proximal Trainer for a federated learning model. + + Args: + dataset (list): A list containing various dataset components. + model (nn.Module): The federated learning model. + device (torch.device): Device for training (e.g., "cpu" or "cuda"). + args (argparse.Namespace): Command-line arguments. + """ + def __init__(self, dataset, model, device, args): self.device = device self.args = args @@ -51,6 +61,15 @@ def __init__(self, dataset, model, device, args): def _setup_clients( self, train_data_local_num_dict, train_data_local_dict, test_data_local_dict, model_trainer, ): + """ + Set up federated clients. + + Args: + train_data_local_num_dict (dict): Number of local training samples for each client. + train_data_local_dict (dict): Local training datasets for clients. + test_data_local_dict (dict): Local test datasets for clients. + model_trainer (ModelTrainer): Trainer for the client's model. + """ logging.info("############setup_clients (START)#############") for client_idx in range(self.args.client_num_per_round): c = Client( @@ -66,6 +85,14 @@ def _setup_clients( logging.info("############setup_clients (END)#############") def train(self): + """ + Train the federated model using federated learning. + + This method performs federated learning by aggregating client updates. + + Returns: + None + """ logging.info("self.model_trainer = {}".format(self.model_trainer)) w_global = self.model_trainer.get_model_params() mlops.log_training_status(mlops.ClientConstants.MSG_MLOPS_CLIENT_STATUS_TRAINING) @@ -126,6 +153,17 @@ def train(self): mlops.log_aggregation_finished_status() def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Sample a subset of clients for communication in a round. + + Args: + round_idx (int): Index of the communication round. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to sample per round. + + Returns: + list: List of sampled client indexes. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -136,6 +174,12 @@ def _client_sampling(self, round_idx, client_num_in_total, client_num_per_round) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation dataset from the test dataset. + + Args: + num_samples (int): Number of samples to include in the validation set (default is 10,000). + """ test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) @@ -143,11 +187,26 @@ def _generate_validation_set(self, num_samples=10000): self.val_global = sample_testset def _aggregate(self, w_locals): + """ + Aggregate local model weights from multiple clients. + + Args: + w_locals (list): List of local model weights. + + Returns: + dict: Averaged global model weights. 
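+
+        Example (a minimal sketch of the sample-weighted, FedAvg-style
+        average that ``FedMLAggOperator.agg`` computes for this
+        configuration)::
+
+            import copy
+
+            total = sum(n for n, _ in w_locals)
+            avg = copy.deepcopy(w_locals[0][1])
+            for k in avg.keys():
+                avg[k] = sum(params[k] * (n / total) for n, params in w_locals)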
+        """
 
         avg_params = FedMLAggOperator.agg(self.args, w_locals)
         return avg_params
 
     def _local_test_on_all_clients(self, round_idx):
+        """
+        Perform local testing on all clients in the federation.
+
+        Args:
+            round_idx (int): Index of the communication round.
+        """
         logging.info("################local_test_on_all_clients : {}".format(round_idx))
 
@@ -209,6 +268,12 @@ def _local_test_on_all_clients(self, round_idx):
         logging.info(stats)
 
     def _local_test_on_validation_set(self, round_idx):
+        """
+        Perform local testing on the validation set for all clients.
+
+        Args:
+            round_idx (int): Index of the communication round.
+        """
         logging.info("################local_test_on_validation_set : {}".format(round_idx))
 

From 092d3a2660c2ce8596c527ca93af0550e527df09 Mon Sep 17 00:00:00 2001
From: rajveer43
Date: Thu, 7 Sep 2023 11:29:08 +0530
Subject: [PATCH 42/70] `python\fedml\utils\ `update

---
 python/fedml/__init__.py                     | 124 ++++++-
 python/fedml/launch_cheeath.py               |  12 +-
 python/fedml/launch_cross_device.py          |   5 +-
 python/fedml/launch_cross_silo_hi.py         |  12 +-
 python/fedml/launch_cross_silo_horizontal.py |  12 +-
 python/fedml/launch_serving.py               |  12 +-
 python/fedml/utils/compression.py            | 326 ++++++++++++++++++-
 python/fedml/utils/context.py                |  50 +++
 python/fedml/utils/logging.py                |  13 +-
 python/fedml/utils/model_utils.py            | 133 +++++++-
 10 files changed, 654 insertions(+), 45 deletions(-)

diff --git a/python/fedml/__init__.py b/python/fedml/__init__.py
index 7ab441b9e8..6417657041 100644
--- a/python/fedml/__init__.py
+++ b/python/fedml/__init__.py
@@ -32,7 +32,17 @@ def init(args=None, check_env=True, should_init_logs=True):
+    """
+    Initialize the FedML Engine.
+
+    Args:
+        args (argparse.Namespace, optional): Command-line arguments. Defaults to None.
+        check_env (bool, optional): Whether to check the environment. Defaults to True.
+        should_init_logs (bool, optional): Whether to initialize logs. Defaults to True.
+
+    Returns:
+        argparse.Namespace: Updated command-line arguments.
+    """
     if args is None:
         args = load_arguments(fedml._global_training_type, fedml._global_comm_backend)
 
-    """Initialize FedML Engine."""
     if check_env:
         collect_env(args)
 
@@ -120,6 +130,12 @@ def init(args=None, check_env=True, should_init_logs=True):
 
 
 def print_args(args):
+    """
+    Print the arguments to the log, excluding sensitive paths.
+
+    Args:
+        args (argparse.Namespace): Command-line arguments.
+    """
     mqtt_config_path = None
     s3_config_path = None
     args_copy = args
@@ -138,7 +154,9 @@ def print_args(args):
 
 def update_client_specific_args(args):
     """
-    data_silo_config is used for reading specific configuration for each client
+    Update client-specific arguments based on data_silo_config.
+
+    data_silo_config is used for reading specific configuration for each client
     Example: In fedml_config.yaml, we have the following configuration
         client_specific_args:
         data_silo_config:
@@ -149,6 +167,9 @@ def update_client_specific_args(args):
             fedml_config/data_silo_4_config.yaml,
             ]
-        data_silo_1_config.yaml contains some client client speicifc arguments.
+        data_silo_1_config.yaml contains client-specific arguments.
+
+    Args:
+        args (argparse.Namespace): Command-line arguments.
     """
     if (
         hasattr(args, "data_silo_config")
@@ -166,7 +187,17 @@ def update_client_specific_args(args):
 
 
 def init_simulation_mpi(args):
+
     from mpi4py import MPI
+    """
+    Initialize MPI-based simulation.
+
+    Args:
+        args (argparse.Namespace): Command-line arguments.
+
+    Returns:
+        argparse.Namespace: Updated command-line arguments.
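+
+    Note:
+        Illustrative launch for the MPI backend (the script name, process
+        count, and config path are placeholders, not part of this API)::
+
+            mpirun -np 4 python main.py --cf fedml_config.yaml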
+ """ comm = MPI.COMM_WORLD process_id = comm.Get_rank() @@ -183,14 +214,35 @@ def init_simulation_mpi(args): def init_simulation_sp(args): + """ + Initialize single-process simulation. + + Args: + args (argparse.Namespace): Command-line arguments. + + Returns: + argparse.Namespace: Updated command-line arguments. + """ return args def init_simulation_nccl(args): + """ + Initialize NCCL-based simulation. + + Args: + args (argparse.Namespace): Command-line arguments. + """ return def manage_profiling_args(args): + """ + Manage profiling-related arguments and configurations. + + Args: + args (argparse.Namespace): Command-line arguments. + """ if not hasattr(args, "sys_perf_profiling"): args.sys_perf_profiling = True if not hasattr(args, "sys_perf_profiling"): @@ -236,6 +288,12 @@ def manage_profiling_args(args): def manage_cuda_rpc_args(args): + """ + Manage CUDA RPC-related arguments and configurations. + + Args: + args (argparse.Namespace): Command-line arguments. + """ if (not hasattr(args, "enable_cuda_rpc")) or (not args.using_gpu): args.enable_cuda_rpc = False @@ -264,6 +322,12 @@ def manage_cuda_rpc_args(args): def manage_mpi_args(args): + """ + Manage MPI-related arguments and configurations. + + Args: + args (argparse.Namespace): Command-line arguments. + """ if hasattr(args, "backend") and args.backend == "MPI": from mpi4py import MPI @@ -282,6 +346,15 @@ def manage_mpi_args(args): args.comm = None def init_cross_silo_horizontal(args): + """ + Initialize the cross-silo training for the horizontal scenario. + + Args: + args (argparse.Namespace): Command-line arguments. + + Returns: + args (argparse.Namespace): Updated command-line arguments. + """ args.n_proc_in_silo = 1 args.proc_rank_in_silo = 0 manage_mpi_args(args) @@ -291,6 +364,15 @@ def init_cross_silo_horizontal(args): def init_cross_silo_hierarchical(args): + """ + Initialize the cross-silo training for the hierarchical scenario. + + Args: + args (argparse.Namespace): Command-line arguments. + + Returns: + args (argparse.Namespace): Updated command-line arguments. + """ manage_mpi_args(args) manage_cuda_rpc_args(args) @@ -344,6 +426,15 @@ def init_cross_silo_hierarchical(args): def init_cheetah(args): + """ + Initialize the CheetaH training scenario. + + Args: + args (argparse.Namespace): Command-line arguments. + + Returns: + args (argparse.Namespace): Updated command-line arguments. + """ args.n_proc_in_silo = 1 args.proc_rank_in_silo = 0 manage_mpi_args(args) @@ -353,6 +444,15 @@ def init_cheetah(args): def init_model_serving(args): + """ + Initialize the model serving scenario. + + Args: + args (argparse.Namespace): Command-line arguments. + + Returns: + args (argparse.Namespace): Updated command-line arguments. + """ args.n_proc_in_silo = 1 args.proc_rank_in_silo = 0 manage_cuda_rpc_args(args) @@ -361,10 +461,12 @@ def init_model_serving(args): def update_client_id_list(args): - """ - generate args.client_id_list for CLI mode where args.client_id_list is set to None - In MLOps mode, args.client_id_list will be set to real-time client id list selected by UI (not starting from 1) + Generate args.client_id_list for the CLI mode where args.client_id_list is set to None. + In MLOps mode, args.client_id_list will be set to a real-time client id list selected by the UI (not starting from 1). + + Args: + args (argparse.Namespace): Command-line arguments. 
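+
+    Example:
+        In the default CLI path (no MLOps), the server process typically ends
+        up with one id per sampled client, e.g. ``[1, 2, 3]`` when
+        ``client_num_per_round == 3``, while each client process gets a
+        singleton list containing its own rank.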
""" if not hasattr(args, "using_mlops") or (hasattr(args, "using_mlops") and not args.using_mlops): if not hasattr(args, "client_id_list") or args.client_id_list is None or args.client_id_list == "None" or args.client_id_list == "[]": @@ -396,12 +498,24 @@ def update_client_id_list(args): def init_cross_device(args): + """ + Initialize the cross-device training scenario. + + Args: + args (argparse.Namespace): Command-line arguments. + + Returns: + args (argparse.Namespace): Updated command-line arguments. + """ args.rank = 0 # only server runs on Python package args.role = "server" return args def run_distributed(): + """ + Placeholder function for running distributed training. + """ pass diff --git a/python/fedml/launch_cheeath.py b/python/fedml/launch_cheeath.py index d0c40f8a14..e323bf2d26 100644 --- a/python/fedml/launch_cheeath.py +++ b/python/fedml/launch_cheeath.py @@ -5,7 +5,11 @@ def run_cheetah_server(): - """FedML Cheetah""" + """ + Run the server for the FedML Cheetah platform. + + This function initializes the server, loads data, and starts training using the Cheetah server. + """ fedml._global_training_type = FEDML_TRAINING_PLATFORM_CHEETAH args = fedml.init() @@ -26,7 +30,11 @@ def run_cheetah_server(): def run_cheetah_client(): - """FedML Cheetah""" + """ + Run a client for the FedML Cheetah platform. + + This function initializes a client, loads data, and starts training using the Cheetah client. + """ fedml._global_training_type = FEDML_TRAINING_PLATFORM_CHEETAH args = fedml.init() diff --git a/python/fedml/launch_cross_device.py b/python/fedml/launch_cross_device.py index 23934bcabb..1b613f88a6 100644 --- a/python/fedml/launch_cross_device.py +++ b/python/fedml/launch_cross_device.py @@ -5,8 +5,11 @@ def run_mnn_server(): from .cross_device import ServerMNN + """ + Run the server for the FedML BeeHive platform. - """FedML BeeHive""" + This function initializes the server, loads data, and starts training using the MNN (Multi-device Neural Network) server. + """ fedml._global_training_type = FEDML_TRAINING_PLATFORM_CROSS_DEVICE args = fedml.init() diff --git a/python/fedml/launch_cross_silo_hi.py b/python/fedml/launch_cross_silo_hi.py index 140cb1718e..c3ca6499bf 100644 --- a/python/fedml/launch_cross_silo_hi.py +++ b/python/fedml/launch_cross_silo_hi.py @@ -5,7 +5,11 @@ def run_hierarchical_cross_silo_server(): - """FedML Octopus""" + """ + Run the server for the FedML Octopus platform. + + This function initializes the server, loads data, and starts training using the Cross-Silo Octopus server. + """ fedml._global_training_type = FEDML_TRAINING_PLATFORM_CROSS_SILO args = fedml.init() @@ -26,7 +30,11 @@ def run_hierarchical_cross_silo_server(): def run_hierarchical_cross_silo_client(): - """FedML Octopus""" + """ + Run a client for the FedML Octopus platform. + + This function initializes a client, loads data, and starts training using the Cross-Silo Octopus client. + """ fedml._global_training_type = FEDML_TRAINING_PLATFORM_CROSS_SILO args = fedml.init() diff --git a/python/fedml/launch_cross_silo_horizontal.py b/python/fedml/launch_cross_silo_horizontal.py index aebe72c06c..85484e18ef 100644 --- a/python/fedml/launch_cross_silo_horizontal.py +++ b/python/fedml/launch_cross_silo_horizontal.py @@ -5,7 +5,11 @@ def run_cross_silo_server(): - """FedML Octopus""" + """ + Run the server for the FedML Octopus platform using Cross-Silo training. + + This function initializes the server, loads data, and starts training for the Cross-Silo Octopus platform. 
+ """ fedml._global_training_type = FEDML_TRAINING_PLATFORM_CROSS_SILO args = fedml.init() @@ -26,7 +30,11 @@ def run_cross_silo_server(): def run_cross_silo_client(): - """FedML Octopus""" + """ + Run a client for the FedML Octopus platform using Cross-Silo training. + + This function initializes a client, loads data, and starts training for the Cross-Silo Octopus platform. + """ global _global_training_type _global_training_type = FEDML_TRAINING_PLATFORM_CROSS_SILO diff --git a/python/fedml/launch_serving.py b/python/fedml/launch_serving.py index 2d9c8bf5c4..719ce7f8f9 100644 --- a/python/fedml/launch_serving.py +++ b/python/fedml/launch_serving.py @@ -5,7 +5,11 @@ def run_model_serving_server(): - """FedML Model Serving""" + """ + Run the server for the FedML Model Serving platform. + + This function initializes the server, loads data, and starts serving the model for the Model Serving platform. + """ fedml._global_training_type = FEDML_TRAINING_PLATFORM_SERVING args = fedml.init() @@ -26,7 +30,11 @@ def run_model_serving_server(): def run_model_serving_client(): - """FedML Model Serving""" + """ + Run a client for the FedML Model Serving platform. + + This function initializes a client, loads data, and starts serving the model for the Model Serving platform. + """ fedml._global_training_type = FEDML_TRAINING_PLATFORM_SERVING args = fedml.init() diff --git a/python/fedml/utils/compression.py b/python/fedml/utils/compression.py index 8038abfc36..064fabd29e 100644 --- a/python/fedml/utils/compression.py +++ b/python/fedml/utils/compression.py @@ -7,13 +7,38 @@ class NoneCompressor(): + """ + A compressor that does not perform any compression. + + This compressor simply returns the input tensor as-is when compressing and decompressing. + """ def __init__(self): self.name = 'none' def compress(self, tensor): + """ + Compresses the input tensor. + + Args: + tensor: The input tensor to be compressed. + + Returns: + compressed_tensor: The same input tensor. + dtype: The data type of the tensor. + """ return tensor, tensor.dtype def decompress(self, tensor, ctc): + """ + Decompresses the input tensor. + + Args: + tensor: The compressed tensor. + ctc: The data type of the tensor (ignored). + + Returns: + z: The decompressed tensor, which is the same as the input tensor. + """ z = tensor return z @@ -23,6 +48,9 @@ class TopKCompressor(): Sparse Communication for Distributed Gradient Descent, Alham Fikri Aji et al., 2017 """ def __init__(self): + """ + Initialize the TopKCompressor. + """ self.residuals = {} self.sparsities = [] self.zero_conditions = {} @@ -38,9 +66,23 @@ def __init__(self): def _process_data_before_selecting(self, name, data): + """ + Perform data processing before selecting the top-k values. + + Args: + name (str): The name of the data. + data (Tensor): The input data tensor. + """ pass def _process_data_after_residual(self, name, data): + """ + Perform data processing after applying residuals. + + Args: + name (str): The name of the data. + data (Tensor): The input data tensor. + """ if name not in self.zero_conditions: self.zero_conditions[name] = torch.ones(data.numel(), dtype=torch.float32, device=data.device) zero_condition = self.zero_conditions[name] @@ -49,6 +91,9 @@ def _process_data_after_residual(self, name, data): self.zc = zero_condition def clear(self): + """ + Clear the compressor's internal state. 
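+
+        Example (end-to-end sketch of one compress/decompress round with this
+        class; ``grad`` is any tensor, and the name and ratio are arbitrary)::
+
+            comp = TopKCompressor()
+            flat = comp.flatten(grad, name="layer1")
+            _, idxs, vals = comp.compress(flat, name="layer1", ratio=0.05)
+            dense = comp.decompress_new(vals, idxs, name="layer1")
+            restored = comp.unflatten(dense, name="layer1")
+            comp.clear()  # reset residuals and cached state between runs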
+        """
         self.residuals = {}
         self.sparsities = []
         self.zero_conditions = {}
@@ -57,6 +102,20 @@ def clear(self):
 
 
     def compress(self, tensor, name=None, sigma_scale=2.5, ratio=0.05):
+        """
+        Compress the input tensor using top-k selection.
+
+        Args:
+            tensor (Tensor): The input tensor to be compressed.
+            name (str): The name of the tensor (optional).
+            sigma_scale (float): Scaling factor for selecting top-k values (default: 2.5).
+            ratio (float): Ratio of values to be retained (default: 0.05).
+
+        Returns:
+            tensor (Tensor): The compressed tensor.
+            indexes (Tensor): The indexes of the top-k values.
+            values (Tensor): The top-k values.
+        """
         start = time.time()
         with torch.no_grad():
             # top-k solution
@@ -73,14 +132,32 @@ def compress(self, tensor, name=None, sigma_scale=2.5, ratio=0.05):
         return tensor, indexes, values
 
     def decompress(self, tensor, original_tensor_size):
+        """
+        Decompress the input tensor.
+
+        Args:
+            tensor (Tensor): The compressed tensor.
+            original_tensor_size: The size of the original tensor (ignored).
+
+        Returns:
+            tensor (Tensor): The decompressed tensor, which is the same as the input tensor.
+        """
         return tensor
 
     def decompress_new(self, tensor, indexes, name=None, shape=None):
-        '''
-        Just decompress, without unflatten.
-        Remember to do unflatter after decompress
-        '''
+        """
+        Decompress the input tensor without unflattening it; remember to call `unflatten` on the result afterwards.
+
+        Args:
+            tensor (Tensor): The compressed tensor.
+            indexes (Tensor): The indexes of the top-k values.
+            name (str): The name of the tensor (optional).
+            shape (tuple): The shape of the tensor (optional).
+
+        Returns:
+            decompress_tensor (Tensor): The decompressed tensor, which may need to be unflattened.
+        """
         if shape is None:
             decompress_tensor = torch.zeros(
                 self.shapes[name], dtype=tensor.dtype, device=tensor.device).view(-1)
@@ -97,30 +174,69 @@ def decompress_new(self, tensor, indexes, name=None, shape=None):
         return decompress_tensor
 
     def flatten(self, tensor, name=None):
-        '''
-        flatten a tensor
-        '''
+        """
+        Flatten the input tensor.
+
+        Args:
+            tensor (Tensor): The input tensor to be flattened.
+            name (str): The name of the tensor (optional).
+
+        Returns:
+            flattened_tensor (Tensor): The flattened tensor.
+        """
         self.shapes[name] = tensor.shape
         return tensor.view(-1)
 
     def unflatten(self, tensor, name=None, shape=None):
-        '''
-        unflatten a tensor
-        '''
+        """
+        Unflatten the input tensor.
+
+        Args:
+            tensor (Tensor): The input tensor to be unflattened.
+            name (str): The name of the tensor (optional).
+            shape (tuple): The desired shape for unflattening (optional).
+
+        Returns:
+            unflattened_tensor (Tensor): The unflattened tensor.
+        """
         if shape is None:
             return tensor.view(self.shapes[name])
         else:
             return tensor.view(shape)
 
     def update_shapes_dict(self, tensor, name):
+        """
+        Update the shapes dictionary with the shape of the tensor.
+
+        Args:
+            tensor (Tensor): The input tensor.
+            name (str): The name of the tensor.
+        """
         self.shapes[name] = tensor.shape
 
     def get_residuals(self, name, like_tensor):
+        """
+        Get the residuals for a given tensor name.
+
+        Args:
+            name (str): The name of the tensor.
+            like_tensor (Tensor): A tensor with the same shape and device as the residuals.
+
+        Returns:
+            residuals (Tensor): The residuals tensor.
+        """
         if name not in self.residuals:
             self.residuals[name] = torch.zeros_like(like_tensor.data)
         return self.residuals[name]
 
     def add_residuals(self, included_indexes, name):
+        """
+        Add residuals to the tensor for specified indexes.
+ + Args: + included_indexes (Tensor or ndarray): The indexes to include in the residuals. + name (str): The name of the tensor. + """ with torch.no_grad(): residuals = self.residuals[name] if type(included_indexes) is np.ndarray: @@ -138,12 +254,39 @@ def add_residuals(self, included_indexes, name): class EFTopKCompressor(TopKCompressor): """ + EFTopKCompressor extends the TopKCompressor class to provide error-feedback top-k compression. + + Args: + None + + Attributes: + name (str): The name of the compressor. + + Methods: + __init__(): Initializes the EFTopKCompressor instance. + compress(tensor, name=None, sigma_scale=2.5, ratio=0.05): Compresses the input tensor using error-feedback top-k compression. + _process_data_before_selecting(name, data): Helper method to process data before selecting top-k values. """ def __init__(self): + """ + Initializes a new instance of EFTopKCompressor. + """ super().__init__() self.name = 'eftopk' def compress(self, tensor, name=None, sigma_scale=2.5, ratio=0.05): + """ + Compresses the input tensor using error-feedback top-k compression. + + Args: + tensor (torch.Tensor): The input tensor to be compressed. + name (str): The name associated with the compression operation (optional). + sigma_scale (float): The scale factor for sigma used in compression (default: 2.5). + ratio (float): The compression ratio (default: 0.05). + + Returns: + tuple: A tuple containing the compressed tensor, indexes of top-k values, and the top-k values themselves. + """ start = time.time() with torch.no_grad(): if name not in self.residuals: @@ -168,11 +311,38 @@ def compress(self, tensor, name=None, sigma_scale=2.5, ratio=0.05): return tensor, indexes, values def _process_data_before_selecting(self, name, data): + """ + Helper method to process data before selecting top-k values. + + Args: + name (str): The name associated with the compression operation. + data (torch.Tensor): The data tensor to be processed. + """ data.add_(self.residuals[name].data) class QuantizationCompressor(object): + """ + Quantization Compressor. + + This class represents a compressor that performs quantization on tensors. + + Attributes: + name (str): The name of the compressor. + residuals (dict): A dictionary to store residuals. + values (dict): A dictionary to store quantized values. + zc: Not specified in the code. + current_ratio (float): The current quantization ratio. + shapes (dict): A dictionary to store tensor shapes. + + Methods: + get_naive_quantize(x, s, is_biased=False): Calculate quantized values for the input tensor. + compress(tensor, name=None, quantize_level=32, is_biased=True): Compress a tensor. + decompress_new(tensor): Decompress a tensor. + update_shapes_dict(tensor, name): Update the shapes dictionary. + + """ def __init__(self): self.name = 'quant' self.residuals = {} @@ -183,6 +353,17 @@ def __init__(self): self.shapes = {} def get_naive_quantize(self, x, s, is_biased=False): + """ + Calculate quantized values for the input tensor. + + Args: + x: Input tensor. + s: Quantization level. + is_biased (bool): Whether to use biased quantization. + + Returns: + Tensor: Quantized tensor. + """ norm = x.norm(p=2) # calculate the quantization value of tensor `x` at level `log_2 s`. level_float = s * x.abs() / norm @@ -191,6 +372,18 @@ def get_naive_quantize(self, x, s, is_biased=False): return torch.sign(x) * norm * previous_level / s def compress(self, tensor, name=None, quantize_level=32, is_biased=True): + """ + Compress a tensor. + + Args: + tensor: Input tensor. 
+            name: Name for the tensor.
+            quantize_level: Quantization level.
+            is_biased (bool): Whether to use biased quantization.
+
+        Returns:
+            Tensor: Compressed tensor.
+        """
         if quantize_level != 32:
             s = 2 ** quantize_level - 1
             values = self.get_naive_quantize(tensor, s, is_biased)
@@ -199,15 +392,53 @@ def compress(self, tensor, name=None, quantize_level=32, is_biased=True):
         return values
 
     def decompress_new(self, tensor):
+        """
+        Decompress a tensor.
+
+        Args:
+            tensor: Compressed tensor.
+
+        Returns:
+            Tensor: Decompressed tensor.
+        """
         return tensor
 
     def update_shapes_dict(self, tensor, name):
+        """
+        Update the shapes dictionary with the shape of the given tensor.
+
+        Args:
+            tensor: Input tensor.
+            name (str): Name for the tensor.
+        """
         self.shapes[name] = tensor.shape
 
 
 class QSGDCompressor(object):
+    """
+    QSGD (Quantized Stochastic Gradient Descent) Compressor.
+
+    QSGD is a compression technique for gradient updates in distributed training.
+
+    Attributes:
+        name (str): The name of the compressor.
+        residuals (dict): Dictionary to store residuals.
+        values (dict): Dictionary to store quantized values.
+        zc: Zero-condition placeholder (unused by this compressor).
+        current_ratio (float): Current quantization ratio.
+        shapes (dict): Dictionary to store tensor shapes.
+
+    Methods:
+        get_qsgd(x, s, is_biased=False): Calculate quantized values for the input tensor.
+        qsgd_quantize_numpy(x, s, is_biased=False): Quantize a numpy array.
+        compress(tensor, name=None, quantize_level=32, is_biased=True): Compress a tensor.
+        decompress_new(tensor): Decompress a tensor.
+        update_shapes_dict(tensor, name): Update the shapes dictionary.
+
+    """
     def __init__(self):
         self.name = 'qsgd'
         self.residuals = {}
@@ -218,6 +449,17 @@ def __init__(self):
         self.shapes = {}
 
     def get_qsgd(self, x, s, is_biased=False):
+        """
+        Calculate quantized values for the input tensor.
+
+        Args:
+            x: Input tensor.
+            s: Quantization level.
+            is_biased (bool): Whether to use biased quantization.
+
+        Returns:
+            Tensor: Quantized tensor.
+        """
         norm = x.norm(p=2)
         # calculate the quantization value of tensor `x` at level `log_2 s`.
         level_float = s * x.abs() / norm
@@ -235,7 +477,17 @@ def get_qsgd(self, x, s, is_biased=False):
         return scale * torch.sign(x) * norm * new_level / s
 
     def qsgd_quantize_numpy(self, x, s, is_biased=False):
-        """quantize the tensor x in d level on the absolute value coef wise"""
+        """
+        Quantize the numpy array `x` coefficient-wise on absolute values at quantization level `s`.
+
+        Args:
+            x: Input numpy array.
+            s: Quantization level.
+            is_biased (bool): Whether to use biased quantization.
+
+        Returns:
+            ndarray: Quantized numpy array.
+        """
         norm = np.sqrt(np.sum(np.square(x)))
         # calculate the quantization value of tensor `x` at level `log_2 s`.
         level_float = s * np.abs(x) / norm
@@ -253,6 +505,18 @@ def qsgd_quantize_numpy(self, x, s, is_biased=False):
         return scale * np.sign(x) * norm * new_level / s
 
     def compress(self, tensor, name=None, quantize_level=32, is_biased=True):
+        """
+        Compress a tensor.
+
+        Args:
+            tensor: Input tensor.
+            name: Name for the tensor.
+            quantize_level: Quantization level.
+            is_biased (bool): Whether to use biased quantization.
+
+        Returns:
+            Tensor: Compressed tensor.
+        """
         if quantize_level != 32:
             s = 2 ** quantize_level - 1
             values = self.get_qsgd(tensor, s, is_biased)
@@ -261,13 +525,26 @@ def compress(self, tensor, name=None, quantize_level=32, is_biased=True):
         return values
 
     def decompress_new(self, tensor):
-        return tensor
+        """
+        Decompress a tensor.
- def update_shapes_dict(self, tensor, name): - self.shapes[name] = tensor.shape + Args: + tensor: Compressed tensor. + Returns: + Tensor: Decompressed tensor. + """ + return tensor + def update_shapes_dict(self, tensor, name): + """ + Update the shapes dictionary. + Args: + tensor: Input tensor. + name: Name for the tensor. + """ + self.shapes[name] = tensor.shape compressors = { @@ -282,11 +559,30 @@ def update_shapes_dict(self, tensor, name): def gen_threshold_from_normal_distribution(p_value, mu, sigma): r"""PPF.""" + """ + Generate threshold from a normal distribution. + + Args: + p_value (float): The p-value. + mu (float): The mean of the distribution. + sigma (float): The standard deviation of the distribution. + + Returns: + left_thres (float): The left threshold value. + right_thres (float): The right threshold value. + """ zvalue = stats.norm.ppf((1-p_value)/2) return mu+zvalue*sigma, mu-zvalue*sigma def test_gaussion_thres(): + """ + Test threshold calculation for a Gaussian distribution. + + This function generates random data from a Gaussian distribution and computes various statistics + including p-value, mean, and standard deviation. It then calculates a threshold and compares it + with the threshold generated from the Gaussian distribution. + """ set_mean = 0.0; set_std = 0.5 d = np.random.normal(set_mean, set_std, 10000) k2, p = stats.normaltest(d) diff --git a/python/fedml/utils/context.py b/python/fedml/utils/context.py index 1303a3fbc1..40a6c25f79 100644 --- a/python/fedml/utils/context.py +++ b/python/fedml/utils/context.py @@ -7,6 +7,21 @@ @contextmanager def raise_MPI_error(): + """ + Context manager to catch and handle MPI-related errors. + + This context manager is used to catch exceptions and errors that may occur + during MPI (Message Passing Interface) operations and handle them gracefully. + + Usage: + ```python + with raise_MPI_error(): + # Code that may raise MPI-related errors + ``` + + Returns: + None + """ import logging logging.debug("Debugging, Enter the MPI catch error") @@ -20,6 +35,21 @@ def raise_MPI_error(): @contextmanager def raise_error_without_process(): + """ + Context manager to catch and handle errors without aborting the MPI process. + + This context manager is used to catch exceptions and errors without aborting + the MPI (Message Passing Interface) process, allowing it to continue running. + + Usage: + ```python + with raise_error_without_process(): + # Code that may raise errors + ``` + + Returns: + None + """ import logging logging.debug("Debugging, Enter the MPI catch error") @@ -32,6 +62,26 @@ def raise_error_without_process(): @contextmanager def get_lock(lock: threading.Lock()): + """ + Context manager to acquire and release a threading lock. + + This context manager is used to acquire and release a threading lock in a controlled + manner. It ensures that the lock is always released, even in the presence of exceptions. + + Args: + lock (threading.Lock): The threading lock to acquire and release. 
+ + Usage: + ```python + my_lock = threading.Lock() + with get_lock(my_lock): + # Code that requires the lock + # The lock is automatically released after the code block + ``` + + Returns: + None + """ lock.acquire() yield if lock.locked(): diff --git a/python/fedml/utils/logging.py b/python/fedml/utils/logging.py index 8aa089d5f9..5fa886b027 100644 --- a/python/fedml/utils/logging.py +++ b/python/fedml/utils/logging.py @@ -1,5 +1,6 @@ import logging +#define log levels log_levels = { "debug": logging.DEBUG, "info": logging.INFO, @@ -12,16 +13,16 @@ class LoggerCreator: @staticmethod def create_logger(name=None, level=logging.INFO, args=None): - """create a logger + """ + Create and configure a logger. Args: - name (str): name of the logger - level: level of logger + name (str): The name of the logger. + level: The logging level for the logger. - Raises: - ValueError is name is None + Returns: + logger: An instance of the logger. """ - if name is None: raise ValueError("name for logger cannot be None") diff --git a/python/fedml/utils/model_utils.py b/python/fedml/utils/model_utils.py index 0c4b58421c..6961473e49 100644 --- a/python/fedml/utils/model_utils.py +++ b/python/fedml/utils/model_utils.py @@ -9,7 +9,13 @@ def get_weights(state): """ - Returns list of weights from state_dict + Returns a list of weights from a state_dict. + + Args: + state (dict or None): A PyTorch state_dict or None. + + Returns: + list or None: A list of tensor weights or None if the state is None. """ if state is not None: return list(state.values()) @@ -18,6 +24,12 @@ def get_weights(state): def clear_optim_buffer(optimizer): + """ + Clears the optimizer's momentum buffers for each parameter. + + Args: + optimizer: A PyTorch optimizer. + """ for group in optimizer.param_groups: for p in group["params"]: param_state = optimizer.state[p] @@ -30,6 +42,13 @@ def clear_optim_buffer(optimizer): def optimizer_to(optim, device): + """ + Moves the optimizer's state and associated tensors to the specified device. + + Args: + optim (torch.optim.Optimizer): A PyTorch optimizer. + device (torch.device): The target device (e.g., 'cuda' or 'cpu'). + """ for param in optim.state.values(): # Not sure there are any global tensors in the state dict if isinstance(param, torch.Tensor): @@ -45,6 +64,16 @@ def optimizer_to(optim, device): def move_to_cpu(model, optimizer): + """ + Moves a PyTorch model and its associated optimizer to the CPU device. + + Args: + model (torch.nn.Module): The PyTorch model. + optimizer (torch.optim.Optimizer): The optimizer associated with the model. + + Returns: + torch.nn.Module: The model after moving it to the CPU. + """ if str(next(model.parameters()).device) == "cpu": pass else: @@ -56,6 +85,17 @@ def move_to_cpu(model, optimizer): def move_to_gpu(model, optimizer, device): + """ + Moves a PyTorch model and its associated optimizer to the specified GPU device. + + Args: + model (torch.nn.Module): The PyTorch model. + optimizer (torch.optim.Optimizer): The optimizer associated with the model. + device (str or torch.device): The target GPU device, e.g., 'cuda:0'. + + Returns: + torch.nn.Module: The model after moving it to the GPU. + """ if str(next(model.parameters()).device) == "cpu": model = model.to(device) else: @@ -72,9 +112,15 @@ def move_to_gpu(model, optimizer, device): def get_named_data(model, mode="MODEL", use_cuda=True): """ - getting the whole model and getting the gradients can be conducted - by using different methods for reducing the communication. 
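Optimizer state (e.g., SGD momentum buffers) lives outside the model's parameters, which is why `optimizer_to` walks `optim.state` explicitly. A self-contained sketch of the same idea (helper name is illustrative):

```python
import torch

def optimizer_state_to(optimizer, device):
    # Move every tensor held in the optimizer's per-parameter state.
    for state in optimizer.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.to(device)

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
model(torch.randn(3, 4)).sum().backward()
opt.step()  # creates the momentum buffers
optimizer_state_to(opt, "cpu")
```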
- `model` choices: ['MODEL', 'GRAD', 'MODEL+GRAD'] + Get various components of a PyTorch model based on the specified mode. + + Args: + model (torch.nn.Module): The PyTorch model. + mode (str): Mode for extracting components ('MODEL', 'GRAD', or 'MODEL+GRAD'). + use_cuda (bool): Whether to use CUDA (GPU) for extraction. + + Returns: + dict: A dictionary containing the requested components. """ if mode == "MODEL": own_state = model.cpu().state_dict() @@ -113,6 +159,17 @@ def get_named_data(model, mode="MODEL", use_cuda=True): def get_bn_params(prefix, module, use_cuda=True): + """ + Get batch normalization parameters with the specified prefix. + + Args: + prefix (str): Prefix for parameter names. + module (nn.BatchNorm2d): Batch normalization module. + use_cuda (bool): Whether to use CUDA (GPU) for extraction. + + Returns: + dict: A dictionary containing batch normalization parameters. + """ bn_params = {} if use_cuda: bn_params[f"{prefix}.weight"] = module.weight @@ -130,6 +187,16 @@ def get_bn_params(prefix, module, use_cuda=True): def get_all_bn_params(model, use_cuda=True): + """ + Get all batch normalization parameters from a PyTorch model. + + Args: + model (torch.nn.Module): The PyTorch model. + use_cuda (bool): Whether to use CUDA (GPU) for extraction. + + Returns: + dict: A dictionary containing all batch normalization parameters. + """ all_bn_params = {} for module_name, module in model.named_modules(): # print(f"key:{key}, module, {module}") @@ -146,6 +213,12 @@ def get_all_bn_params(model, use_cuda=True): def check_bn_status(bn_module): + """ + Print and log batch normalization parameters and status. + + Args: + bn_module (nn.BatchNorm2d): Batch normalization module. + """ logging.info(f"weight: {bn_module.weight[:10].mean()}") logging.info(f"bias: {bn_module.bias[:10].mean()}") logging.info(f"running_mean: {bn_module.running_mean[:10].mean()}") @@ -159,10 +232,15 @@ def check_bn_status(bn_module): def average_named_params(named_params_list, average_weights_dict_list, inplace=True): """ - This is a weighted average operation. - average_weights_dict_list: includes weights with respect to clients. Same for each param. - inplace: Whether change the first client's model inplace. - Note: This function also can be used to average gradients. + Average named parameters based on a list of parameters and their associated weights. + + Args: + named_params_list (list): List of named parameters to be averaged. + average_weights_dict_list (list): List of weights for each set of named parameters. + inplace (bool): Whether to modify the first set of parameters in-place. + + Returns: + dict: Averaged named parameters. """ # logging.info("################aggregate: %d" % len(named_params_list)) @@ -219,6 +297,15 @@ def average_named_params(named_params_list, average_weights_dict_list, inplace=T def get_average_weight(sample_num_list): + """ + Calculate average weights based on a list of sample numbers. + + Args: + sample_num_list (list): List of sample numbers. + + Returns: + list: List of average weights. + """ # balance_sample_number_list = [] average_weights_dict_list = [] sum = 0 @@ -239,6 +326,16 @@ def get_average_weight(sample_num_list): def check_device(data_src, device=None): + """ + Ensure data is on the specified device. + + Args: + data_src: Data to be moved to the device. + device (str): Device to move the data to (e.g., 'cpu' or 'cuda'). + + Returns: + Data on the specified device. 
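`average_named_params` combined with weights from `get_average_weight` is a standard FedAvg-style weighted mean. A minimal sketch of both steps together, under the same assumptions (one weight per client, identical parameter names; the helper is hypothetical):

```python
import torch

def weighted_average_sketch(named_params_list, sample_num_list):
    # Each client's weight is its share of the total sample count.
    total = sum(sample_num_list)
    weights = [n / total for n in sample_num_list]
    averaged = {}
    for name in named_params_list[0]:
        averaged[name] = sum(w * p[name] for p, w in zip(named_params_list, weights))
    return averaged

a, b = {"w": torch.ones(2)}, {"w": torch.zeros(2)}
print(weighted_average_sketch([a, b], [30, 10]))  # {'w': tensor([0.75, 0.75])}
```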
+ """ if device is not None: if data_src.device is not device: return data_src.to(device) @@ -252,7 +349,16 @@ def check_device(data_src, device=None): def get_diff_weights(weights1, weights2): - """ Produce a direction from 'weights1' to 'weights2'.""" + """ + Calculate the difference between two sets of weights. + + Args: + weights1: First set of weights. + weights2: Second set of weights. + + Returns: + Difference between the two sets of weights. + """ if isinstance(weights1, list) and isinstance(weights2, list): return [w2 - w1 for (w1, w2) in zip(weights1, weights2)] elif isinstance(weights1, torch.Tensor) and isinstance(weights2, torch.Tensor): @@ -263,7 +369,14 @@ def get_diff_weights(weights1, weights2): def get_name_params_difference(named_parameters1, named_parameters2): """ - return named_parameters2 - named_parameters1 + Calculate the difference between two sets of named parameters. + + Args: + named_parameters1 (dict): First set of named parameters. + named_parameters2 (dict): Second set of named parameters. + + Returns: + dict: Dictionary containing the differences between common named parameters. """ common_names = list(set(named_parameters1.keys()).intersection(set(named_parameters2.keys()))) named_diff_parameters = {} From d100a0b2dce2f9a35c2adfcd23508bb13d040983 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Thu, 7 Sep 2023 13:28:39 +0530 Subject: [PATCH 43/70] udpate `python\fedml\simulation\nccl\base_framework` server.py remaning --- python/fedml/__init__.py | 100 +++++----- .../nccl/base_framework/LocalAggregator.py | 131 ++++++++++++- .../nccl/base_framework/algorithm_api.py | 16 ++ .../simulation/nccl/base_framework/common.py | 183 +++++++++++++++++- .../simulation/nccl/base_framework/params.py | 167 ++++++++++++++-- 5 files changed, 524 insertions(+), 73 deletions(-) diff --git a/python/fedml/__init__.py b/python/fedml/__init__.py index 6417657041..b3274d1843 100644 --- a/python/fedml/__init__.py +++ b/python/fedml/__init__.py @@ -22,6 +22,56 @@ ) from .core.common.ml_engine_backend import MLEngineBackend +from fedml import device +from fedml import data +from fedml import model +from fedml import mlops + +from .arguments import load_arguments + +from .launch_simulation import run_simulation + +from .launch_cross_silo_horizontal import run_cross_silo_server +from .launch_cross_silo_horizontal import run_cross_silo_client + +from .launch_cross_silo_hi import run_hierarchical_cross_silo_server +from .launch_cross_silo_hi import run_hierarchical_cross_silo_client + +from .launch_cheeath import run_cheetah_server +from .launch_cheeath import run_cheetah_client + +from .launch_serving import run_model_serving_client +from .launch_serving import run_model_serving_server + +from .launch_cross_device import run_mnn_server + +from .core.common.ml_engine_backend import MLEngineBackend + +from .runner import FedMLRunner + +from fedml import api + +__all__ = [ + "MLEngineBackend", + "device", + "data", + "model", + "mlops", + "FedMLRunner", + "run_simulation", + "run_cross_silo_server", + "run_cross_silo_client", + "run_hierarchical_cross_silo_server", + "run_hierarchical_cross_silo_client", + "run_cheetah_server", + "run_cheetah_client", + "run_model_serving_client", + "run_model_serving_server", + "run_mnn_server", + "api" +] + + _global_training_type = None _global_comm_backend = None @@ -517,53 +567,3 @@ def run_distributed(): Placeholder function for running distributed training. 
""" pass - - -from fedml import device -from fedml import data -from fedml import model -from fedml import mlops - -from .arguments import load_arguments - -from .launch_simulation import run_simulation - -from .launch_cross_silo_horizontal import run_cross_silo_server -from .launch_cross_silo_horizontal import run_cross_silo_client - -from .launch_cross_silo_hi import run_hierarchical_cross_silo_server -from .launch_cross_silo_hi import run_hierarchical_cross_silo_client - -from .launch_cheeath import run_cheetah_server -from .launch_cheeath import run_cheetah_client - -from .launch_serving import run_model_serving_client -from .launch_serving import run_model_serving_server - -from .launch_cross_device import run_mnn_server - -from .core.common.ml_engine_backend import MLEngineBackend - -from .runner import FedMLRunner - -from fedml import api - -__all__ = [ - "MLEngineBackend", - "device", - "data", - "model", - "mlops", - "FedMLRunner", - "run_simulation", - "run_cross_silo_server", - "run_cross_silo_client", - "run_hierarchical_cross_silo_server", - "run_hierarchical_cross_silo_client", - "run_cheetah_server", - "run_cheetah_client", - "run_model_serving_client", - "run_model_serving_server", - "run_mnn_server", - "api" -] diff --git a/python/fedml/simulation/nccl/base_framework/LocalAggregator.py b/python/fedml/simulation/nccl/base_framework/LocalAggregator.py index e6bb9d7cd5..36057a0b6c 100644 --- a/python/fedml/simulation/nccl/base_framework/LocalAggregator.py +++ b/python/fedml/simulation/nccl/base_framework/LocalAggregator.py @@ -16,11 +16,53 @@ class BaseLocalAggregator(object): """ Used to manage and aggregate results from local trainers (clients). It needs to know all datasets. - device: indicates the device of this local aggregator. + device: indicates the device of this local aggregator + + Args: + args: The command-line arguments for the aggregator. + rank (int): The rank of this local aggregator. + worker_number (int): The total number of workers, including the server and clients. + comm: The communication state. + device: The device where the aggregator is located. + dataset: The dataset used for training and testing. + model: The model used for training. + trainer: The trainer responsible for training the model. + + Attributes: + device: Indicates the device of this local aggregator. + args: The command-line arguments for the aggregator. + trainer: The trainer responsible for training the model. + train_global: The global training dataset. + test_global: The global testing dataset. + val_global: The global validation dataset (if available). + train_data_num_in_total: The total number of training data points across all clients. + test_data_num_in_total: The total number of testing data points across all clients. + train_data_local_num_dict: A dictionary mapping client indices to the number of training data points for each client. + train_data_local_dict: A dictionary mapping client indices to their local training datasets. + test_data_local_dict: A dictionary mapping client indices to their local testing datasets. + comm: The communication state. + rank: The rank of this local aggregator. + device_rank: The rank of this local aggregator as a device (GPU). + worker_number: The total number of workers, including the server and clients. + device_number: The total number of devices (GPUs) used for training. + groups: A dictionary of communication groups, where each group is associated with a specific device. 
+
+ Methods:
+ measure_client_runtime(): Measures the runtime of client operations.
+ simulate_client(server_params, client_index, average_weight): Simulates a client's training process.
+ add_client_result(localAggregatorToServerParams, client_params): Adds client results to be aggregated and sent to the server.
 """

 # def __init__(self, args, trainer, device, dataset, comm=None, rank=0, size=0, backend="NCCL"):
 def __init__(self, args, rank, worker_number, comm, device, dataset, model, trainer):
+ """
+ Initialize the BaseLocalAggregator.
+
+ Args:
+ args: The command-line arguments for the aggregator.
+ rank (int): The rank of this local aggregator.
+ worker_number (int): The total number of workers, including the server and clients.
+ comm: The communication state.
+ device: The device where the aggregator is located.
+ dataset: The dataset used for training and testing.
+ model: The model used for training.
+ trainer: The trainer responsible for training the model.
+ """
 self.device = device
 self.args = args
 self.trainer = trainer
@@ -64,9 +106,29 @@ def __init__(self, args, rank, worker_number, comm, device, dataset, model, trai
 logging.info("self.trainer = {}".format(self.trainer))

 def measure_client_runtime(self):
+ """
+ Measure the runtime of client operations.
+
+ This method measures the runtime of client operations and can be used for performance analysis.
+
+ Returns:
+ None
+ """
 pass

 def simulate_client(self, server_params, client_index, average_weight):
+ """
+ Simulate a client's training process.
+
+ Args:
+ server_params: Parameters received from the server.
+ client_index (int): The index of the simulated client.
+ average_weight: The average weight used in the simulation.
+
+ Returns:
+ client_params: Parameters to be sent back to the server.
+ """
+
 # server_model_parameters = server_params.get("model_params")
 # self.trainer.set_model_params(server_model_parameters)
 self.trainer.id = client_index
@@ -83,6 +145,16 @@
 return client_params

 def add_client_result(self, localAggregatorToServerParams, client_params):
+ """
+ Add client results to be aggregated and sent to the server.
+
+ Args:
+ localAggregatorToServerParams: Parameters to be sent to the server.
+ client_params: Parameters received from a client.
+
+ Returns:
+ None
+ """
 # Add params that needed to be reduces from clients
 mean_sum_param_names = client_params.get_sum_reduce_param_names()
 for name in mean_sum_param_names:
@@ -96,6 +168,15 @@
 )

 def simulate_all_tasks(self, server_params):
+ """
+ Simulate all tasks for this local aggregator.
+
+ Args:
+ server_params: Parameters received from the server.
+
+ Returns:
+ localAggregatorToServerParams: Parameters to be sent back to the server.
+ """
 average_weight_dict = self.decode_average_weight_dict(server_params)
 client_indexes = server_params.get(f"client_schedule{self.device_rank}").numpy()
 simulated_client_indexes = []
@@ -124,7 +205,16 @@
 def client_schedule(self, round_idx, client_num_in_total, client_num_per_round, server_params):
 """
- This is used for receiving server schedule client indexes.
+ Receive server's schedule of client indexes for this local aggregator.
+
+ Args:
+ round_idx: The current round index.
+ client_num_in_total: The total number of clients.
+ client_num_per_round: The number of clients to be scheduled for this round.
+ server_params: Parameters received from the server.
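The accumulation in `add_client_result` can be pictured without the communication layer: SUM-style entries are added directly, while entries destined for a mean are pre-scaled by the client's averaging weight so the later SUM reduce yields a weighted mean. A hedged sketch (the dictionary layout is illustrative, not FedML's exact structure):

```python
import torch

def accumulate_client_result_sketch(buffer, sum_params, mean_params, avg_weight):
    for name, tensor in sum_params.items():
        buffer[name] = buffer.get(name, 0) + tensor
    for name, tensor in mean_params.items():
        # Pre-scaling turns a global SUM reduce into a weighted mean.
        buffer[name] = buffer.get(name, 0) + avg_weight * tensor
    return buffer

buf = {}
accumulate_client_result_sketch(buf, {"n": torch.tensor(10.0)}, {"w": torch.ones(2)}, 0.5)
print(buf["w"])  # tensor([0.5000, 0.5000])
```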
+ + Returns: + None """ # scheduler(workloads, constraints, memory) for i in range(self.device_number): @@ -133,12 +223,31 @@ def client_schedule(self, round_idx, client_num_in_total, client_num_per_round, return None, None def get_average_weight(self, client_indexes): + """ + Get average weight for a list of client indexes. + + Args: + client_indexes: A list of client indexes. + + Returns: + average_weight_dict: A dictionary mapping client indexes to their average weights. + """ average_weight_dict = {} for client_index in client_indexes: average_weight_dict[client_index] = 0.0 return average_weight_dict def encode_average_weight_dict(self, server_params, average_weight_dict): + """ + Encode and add the average weight dictionary to server parameters. + + Args: + server_params: Parameters to be sent to the server. + average_weight_dict: A dictionary mapping client indexes to their average weights. + + Returns: + None + """ server_params.add_broadcast_param( name="average_weight_dict_keys", param=torch.tensor(list(average_weight_dict.keys())) ) @@ -147,6 +256,15 @@ def encode_average_weight_dict(self, server_params, average_weight_dict): ) def decode_average_weight_dict(self, server_params): + """ + Decode the average weight dictionary from server parameters. + + Args: + server_params: Parameters received from the server. + + Returns: + average_weight_dict: A dictionary mapping client indexes to their average weights. + """ average_weight_dict_keys = server_params.get("average_weight_dict_keys").numpy() average_weight_dict_values = server_params.get("average_weight_dict_values").numpy() average_weight_dict = {} @@ -154,6 +272,15 @@ def decode_average_weight_dict(self, server_params): return average_weight_dict def train(self): + """ + Train the federated learning model. + + This method handles the federated learning training process, including communication with the server, + scheduling clients, and aggregating local client results. + + Returns: + None + """ server_params = ServerToClientParams() server_params.add_broadcast_param(name="broadcastTest", param=torch.tensor([0, 0, 0])) server_params.broadcast() diff --git a/python/fedml/simulation/nccl/base_framework/algorithm_api.py b/python/fedml/simulation/nccl/base_framework/algorithm_api.py index e656d220e1..127efb9459 100644 --- a/python/fedml/simulation/nccl/base_framework/algorithm_api.py +++ b/python/fedml/simulation/nccl/base_framework/algorithm_api.py @@ -3,6 +3,22 @@ def FedML_Base_NCCL(args, process_id, worker_number, comm, device, dataset, model, model_trainer=None): + """ + Create an instance of either the BaseServer or BaseLocalAggregator based on the process ID. + + Args: + args: The arguments for configuring the FedML engine. + process_id (int): The ID of the current process. + worker_number (int): The total number of workers in the simulation. + comm: The communication backend (e.g., MPI communicator). + device: The device on which the model should be placed. + dataset: The dataset used for training. + model: The model to be trained. + model_trainer: An optional trainer for the model. + + Returns: + BaseServer or BaseLocalAggregator: An instance of either the server or local aggregator based on the process ID. 
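`encode_average_weight_dict` and `decode_average_weight_dict` ship a Python dict over tensor-only collectives as two aligned tensors. The round trip in isolation:

```python
import torch

def encode_dict(d):
    # Keys and values travel as two tensors; shared ordering keeps them aligned.
    return torch.tensor(list(d.keys())), torch.tensor(list(d.values()))

def decode_dict(keys, values):
    return {int(k): float(v) for k, v in zip(keys, values)}

keys, values = encode_dict({3: 0.25, 7: 0.75})
print(decode_dict(keys, values))  # {3: 0.25, 7: 0.75}
```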
+ """ if process_id == 0: return BaseServer(args, process_id, worker_number, comm, device, dataset, model, model_trainer) diff --git a/python/fedml/simulation/nccl/base_framework/common.py b/python/fedml/simulation/nccl/base_framework/common.py index 7011ebdda4..1b64f8a63a 100644 --- a/python/fedml/simulation/nccl/base_framework/common.py +++ b/python/fedml/simulation/nccl/base_framework/common.py @@ -11,8 +11,14 @@ def get_weights(state): """ - Returns list of weights from state_dict - """ + Returns a list of weights from the state dictionary. + + Args: + state (dict): The state dictionary containing model parameters. + + Returns: + list or None: A list of model weights or None if the state is None. + """" if state is not None: return list(state.values()) else: @@ -20,12 +26,25 @@ def get_weights(state): def set_model_params_with_list(model, new_model_params): + """ + Set the model parameters with a list of new parameters. + + Args: + model: The model whose parameters will be updated. + new_model_params (list): A list of new model parameters. + """ for model_param, model_update_param in zip(model.parameters(), new_model_params): print(f"model_param.shape: {model_param.shape}, model_update_param.shape: {model_update_param.shape}") # model_param.data = model_update_param def clear_optim_buffer(optimizer): + """ + Clear the optimization buffer for momentum. + + Args: + optimizer: The optimizer whose buffer will be cleared. + """ for group in optimizer.param_groups: for p in group["params"]: param_state = optimizer.state[p] @@ -38,6 +57,13 @@ def clear_optim_buffer(optimizer): def optimizer_to(optim, device): + """ + Move optimizer parameters to the specified device. + + Args: + optim: The optimizer whose parameters will be moved. + device (str): The target device (e.g., 'cpu' or 'cuda'). + """ for param in optim.state.values(): # Not sure there are any global tensors in the state dict if isinstance(param, torch.Tensor): @@ -53,6 +79,16 @@ def optimizer_to(optim, device): def move_to_cpu(model, optimizer): + """ + Move the model and optimizer to the CPU. + + Args: + model: The model to be moved. + optimizer: The optimizer to be moved. + + Returns: + model: The model on the CPU. + """ if str(next(model.parameters()).device) == "cpu": pass else: @@ -64,6 +100,17 @@ def move_to_cpu(model, optimizer): def move_to_gpu(model, optimizer, device): + """ + Move the model and optimizer to the specified GPU device. + + Args: + model: The model to be moved. + optimizer: The optimizer to be moved. + device (str): The target GPU device (e.g., 'cuda'). + + Returns: + model: The model on the specified GPU device. + """ if str(next(model.parameters()).device) == "cpu": model = model.to(device) else: @@ -104,6 +151,16 @@ class CommState: def init_ddp(args): + """ + Initialize Distributed Data Parallel (DDP) for training. + + Args: + args: The arguments containing DDP configuration. + + Returns: + global_rank (int): The global rank of the current process. + world_size (int): The total number of processes in the world. + """ # use InfiniBand os.environ["NCCL_DEBUG"] = "INFO" os.environ["NCCL_SOCKET_IFNAME"] = "lo" @@ -127,6 +184,15 @@ def init_ddp(args): def FedML_NCCL_Similulation_init(args): + """ + Initialize NCCL-based simulation environment. + + Args: + args: The arguments containing simulation configuration. + + Returns: + args (object): The updated arguments. 
+ """ # dist.init_process_group( # init_method='tcp://10.1.1.20:23456', # rank=args.rank, @@ -157,27 +223,74 @@ def FedML_NCCL_Similulation_init(args): def get_rank(): - return dist.get_rank() + """ + Get the rank of the current process in the distributed environment. + Returns: + int: The rank of the current process. + """ + return dist.get_rank() def get_server_rank(): - return CommState.server_rank + """ + Get the rank of the server process in the distributed environment. + Returns: + int: The rank of the server process. + """ + return CommState.server_rank def get_world_size(): - return dist.get_world_size() + """ + Get the total number of processes in the distributed environment. + Returns: + int: The total number of processes. + """ + return dist.get_world_size() def get_worker_number(): + """ + Get the number of worker processes (excluding the server) in the distributed environment. + + Returns: + int: The number of worker processes. + """ return CommState.device_size def new_group(ranks): + """ + Create a new process group with the specified ranks. + + Args: + ranks (list): A list of ranks to include in the new group. + + Returns: + dist.ProcessGroup: The new process group. + """ return dist.new_group(ranks=ranks) # dist.new_group(ranks=None, timeout=datetime.timedelta(seconds=1800), backend=None, pg_options=None) def fedml_nccl_send_to_server(tensor, src=0, group=None): + """ + Send a tensor from a device (GPU) to the server process. + + Args: + tensor (torch.Tensor): The tensor to send. + src (int): The source rank of the sending process. + group (dist.ProcessGroup, optional): The process group to use for communication. + + Note: + This function is used to send tensors from a device (GPU) to the server during communication. + + Example: + ```python + fedml_nccl_send_to_server(my_tensor, src=1, group=my_group) + ``` + + """ is_cuda = tensor.is_cuda # if not is_cuda: # logging.info("Warning: Tensor is not on GPU!!!") @@ -186,6 +299,22 @@ def fedml_nccl_send_to_server(tensor, src=0, group=None): def fedml_nccl_broadcast(tensor, src): + """ + Broadcast a tensor from the server process to all devices (GPUs). + + Args: + tensor (torch.Tensor): The tensor to broadcast. + src (int): The source rank of the broadcasting process. + + Note: + This function is used to broadcast tensors from the server to all devices during communication. + + Example: + ```python + fedml_nccl_broadcast(my_tensor, src=0) + ``` + + """ is_cuda = tensor.is_cuda # if not is_cuda: # logging.info("Warning: Tensor is not on GPU!!!") @@ -195,7 +324,21 @@ def fedml_nccl_broadcast(tensor, src): def fedml_nccl_reduce(tensor, dst, op: ReduceOp = ReduceOp.SUM): """ - :param op: Currently only supports SUM and MEAN reduction ops + Reduce a tensor across processes with the specified reduction operation. + + Args: + tensor (torch.Tensor): The tensor to reduce. + dst (int): The destination rank for the reduced tensor. + op (ReduceOp): The reduction operation (SUM or MEAN). Currently only supports SUM and MEAN reduction ops + + Note: + This function is used to perform reduction operations (SUM or MEAN) on tensors across processes. + + Example: + ```python + fedml_nccl_reduce(my_tensor, dst=0, op=ReduceOp.SUM) + ``` + """ is_cuda = tensor.is_cuda # if not is_cuda: @@ -216,10 +359,38 @@ def fedml_nccl_reduce(tensor, dst, op: ReduceOp = ReduceOp.SUM): def fedml_nccl_barrier(): + """ + Synchronize all processes in the distributed environment. 
+ + Note: + This function is used to ensure that all processes reach a barrier and synchronize their execution. + + Example: + ```python + fedml_nccl_barrier() + ``` + + """ dist.barrier() def broadcast_model_state(state_dict, src): + """ + Broadcast the model's state dictionary from the server process to all devices (GPUs). + + Args: + state_dict (dict): The model's state dictionary to broadcast. + src (int): The source rank of the broadcasting process. + + Note: + This function is used to broadcast the model's state dictionary from the server to all devices during communication. + + Example: + ```python + broadcast_model_state(my_state_dict, src=0) + ``` + + """ # for name, param in state_dict.items(): # logging.info(f"name:{name}, param.shape: {param.shape}") for param in state_dict.values(): diff --git a/python/fedml/simulation/nccl/base_framework/params.py b/python/fedml/simulation/nccl/base_framework/params.py index 1e1c39777f..9b790a8ebf 100644 --- a/python/fedml/simulation/nccl/base_framework/params.py +++ b/python/fedml/simulation/nccl/base_framework/params.py @@ -10,15 +10,20 @@ class Params(Params): """ - Unified Parameter Object for passing arguments among APIs - from the algorithm frame (e.g., client_trainer.py and server aggregator.py). + Unified Parameter Object for passing arguments among APIs. - Usage:: + This class is used for passing arguments among different parts of the algorithm framework. + You can add parameters and retrieve them using attribute access. + + Example: >> my_params = Params() - >> # add parameter + >> # Add a parameter >> my_params.add(name="w", param=model_weights) - >> # get parameter - >> my_params.w + >> # Get a parameter + >> weight = my_params.w + + Attributes: + _params (dict): A dictionary to store parameter names and values. """ def __init__(self, **kwargs): @@ -27,7 +32,20 @@ def __init__(self, **kwargs): class ServerToClientParams(Params): """ - Normally, ServerToClient only broadcast parameters, hence all devices will receive same data from server. + Parameters sent from server to clients for broadcasting. + + This class represents parameters that are broadcasted from the server to all clients. + It allows adding broadcast parameters and performing the broadcasting operation. + + Example: + >> server_params = ServerToClientParams() + >> # Add a broadcast parameter + >> server_params.add_broadcast_param(name="w", param=model_weights) + >> # Broadcast the added parameters to all clients + >> server_params.broadcast() + + Attributes: + _broadcast_params (list): A list of parameter names to be broadcasted. """ def __init__(self, **kwargs): @@ -36,14 +54,26 @@ def __init__(self, **kwargs): # self._broadcast_params = {} def add_broadcast_param(self, name, param): + """ + Add a parameter to be broadcasted to all clients. + + Args: + name (str): The name of the parameter. + param (torch.Tensor or list of torch.Tensor): The parameter to be broadcasted. + + Returns: + None + """ self.__dict__.update({name: param}) self._broadcast_params.append(name) # self._broadcast_params.update({name: param}) - def broadcast(self): + def broadcast(self): """ - Perform communication of the added parameters. - Note that this is a collective operation and all processes (server and devices) must call this function. + Perform broadcasting of the added parameters to all clients. + + Note: + This is a collective operation, and all processes (server and devices) must call this function. 
""" for param_name in self._broadcast_params: @@ -56,13 +86,26 @@ def broadcast(self): class LocalAggregatorToServerParams(Params): + """ + Parameters sent from local aggregator to the server for aggregation. + + This class represents parameters that are sent from local aggregators to the server + for aggregation and communication between clients and the server. + + Attributes: + _reduce_params (dict): A dictionary containing lists of parameters to be reduced using different operations. + _gather_params (list): A list of parameter names to be gathered from clients. + client_indexes (list): List of client indexes for which this local aggregator has data. + """" # def __init__(self, client_indexes, rank, group, **kwargs): def __init__(self, client_indexes, **kwargs): """ - client_indexes and group are used to indicate client_indexes that are - simulated by currernt LocalAggregator, - This will be used for gathering data. + Initialize the LocalAggregatorToServerParams object. + + Args: + client_indexes (list): List of client indexes that are simulated by this LocalAggregator. """ + super().__init__(**kwargs) self._reduce_params = dict([(ReduceOp.SUM, []),]) self._gather_params = [] @@ -71,6 +114,17 @@ def __init__(self, client_indexes, **kwargs): # self.group = group def add_reduce_param(self, name, param, op=ReduceOp.SUM): + """ + Add a parameter to be reduced. + + Args: + name (str): The name of the parameter. + param (torch.Tensor): The parameter to be reduced. + op (ReduceOp, optional): The reduction operation (default is ReduceOp.SUM). + + Returns: + None + """ if name in self.__dict__: if isinstance(self.__dict__[name], list): for i, tensor in enumerate(param): @@ -83,8 +137,15 @@ def add_reduce_param(self, name, param, op=ReduceOp.SUM): def add_gather_params(self, client_index, name, param): """ - Server needs to add all gather param of all clients, - Then the collective communication can work. + Add parameters to be gathered from clients. + + Args: + client_index (int): The client index for which the parameter is added. + name (str): The name of the parameter. + param (torch.Tensor): The parameter to be gathered. + + Returns: + None """ # new_name = f"client{client_index}_name" # self.__dict__.update({new_name: param}) @@ -96,6 +157,17 @@ def add_gather_params(self, client_index, name, param): self.__dict__[name][client_index] = param def communicate(self, rank, groups, client_schedule=None): + """ + Perform communication between local aggregator and server. + + Args: + rank (int): The rank of the local aggregator. + groups (dict): Dictionary of communication groups. + client_schedule (list, optional): Schedule of client indexes (default is None). + + Returns: + None + """ for param_name in self._reduce_params[ReduceOp.SUM]: param = getattr(self, param_name) if isinstance(param, list): @@ -128,29 +200,94 @@ def communicate(self, rank, groups, client_schedule=None): class ClientToLocalAggregatorParams(Params): + """ + Parameters sent from a client to a local aggregator for aggregation. + + This class represents parameters that are sent from a client to a local aggregator + for aggregation and communication within a local group. + + Attributes: + client_index (int): The client index. + _reduce_params (dict): A dictionary containing lists of parameters to be reduced using different operations. + _gather_params (list): A list of parameter names to be gathered by the local aggregator. 
+ """ def __init__(self, client_index, **kwargs): + """ + Initialize the ClientToLocalAggregatorParams object. + + Args: + client_index (int): The client index for which the parameters are intended. + """ super().__init__(**kwargs) self.client_index = client_index self._reduce_params = dict([(ReduceOp.MEAN, []), (ReduceOp.SUM, []),]) self._gather_params = [] def add_reduce_param(self, name, param, op=ReduceOp.SUM): + """ + Add a parameter to be reduced. + + Args: + name (str): The name of the parameter. + param (torch.Tensor): The parameter to be reduced. + op (ReduceOp, optional): The reduction operation (default is ReduceOp.SUM). + + Returns: + None + """ self.__dict__.update({name: param}) self._reduce_params[op].append(name) def add_gather_params(self, name, param): + """ + Add parameters to be gathered by the local aggregator. + + Args: + name (str): The name of the parameter. + param (torch.Tensor): The parameter to be gathered. + + Returns: + None + """ self.__dict__.update({name: param}) self._gather_params.append(name) def get_mean_reduce_param_names(self): + """ + Get the names of parameters to be reduced with the MEAN operation. + + Returns: + list: A list of parameter names. + """ return self._reduce_params[ReduceOp.MEAN] def get_sum_reduce_param_names(self): + """ + Get the names of parameters to be reduced with the SUM operation. + + Returns: + list: A list of parameter names. + """ return self._reduce_params[ReduceOp.SUM] def get_gather_param_names(self): + """ + Get the names of parameters to be gathered by the local aggregator. + + Returns: + list: A list of parameter names. + """ return self._gather_params def local_gather(local_gather_params): + """ + Perform local gathering of parameters. + + Args: + local_gather_params (ClientToLocalAggregatorParams): Parameters to be gathered. + + Returns: + None + """ pass From d4471bfc5803016aa9d5b72d12b09029fa71054f Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Thu, 7 Sep 2023 18:09:36 +0530 Subject: [PATCH 44/70] `python\fedml\simulation\nccl` python\fedml\simulation\nccl done --- .../simulation/nccl/base_framework/Server.py | 176 ++++++++++++++++++ .../fedml/simulation/nccl/fedavg/FedAvgAPI.py | 20 ++ 2 files changed, 196 insertions(+) diff --git a/python/fedml/simulation/nccl/base_framework/Server.py b/python/fedml/simulation/nccl/base_framework/Server.py index aaae29ad63..82c8fee9c4 100644 --- a/python/fedml/simulation/nccl/base_framework/Server.py +++ b/python/fedml/simulation/nccl/base_framework/Server.py @@ -17,9 +17,57 @@ class BaseServer: Used to manage and aggregate results from local aggregators. We hope users does not need to modify this code. """ + """ + Used to manage and aggregate results from local aggregators. + + Attributes: + device (str): The device associated with this server. + args: Command-line arguments. + trainer: The trainer used for training. + train_global: Global training data. + test_global: Global test data. + val_global: Global validation data. + train_data_num_in_total (int): The total number of training data points. + test_data_num_in_total (int): The total number of test data points. + train_data_local_num_dict: A dictionary containing local training data counts. + train_data_local_dict: A dictionary containing local training data. + test_data_local_dict: A dictionary containing local test data. + comm: Communication object. + rank (int): The rank of this server. + worker_number (int): The total number of workers (devices). + device_number (int): The total number of devices excluding the server. 
+ groups (dict): A dictionary of communication groups. + client_runtime_history (dict): A history of client runtimes. + + Methods: + client_sampling(round_idx, client_num_in_total, client_num_per_round): + Randomly sample clients for communication in a federated round. + + simulate_all_tasks(server_params, client_indexes): + Simulate tasks for all selected clients and create localAggregatorToServerParams. + + workload_estimate(client_indexes, mode="simulate"): + Estimate the workload of clients in a federated round. + + memory_estimate(client_indexes, mode="simulate"): + Estimate the memory usage of clients in a federated round. + """ # def __init__(self, args, trainer, device, dataset, comm=None, rank=0, size=0, backend="NCCL"): def __init__(self, args, rank, worker_number, comm, device, dataset, model, trainer): + """ + Initialize the BaseServer object. + + Args: + args: Command-line arguments. + rank (int): The rank of this server. + worker_number (int): The total number of workers (devices). + comm: Communication object. + device (str): The device associated with this server. + dataset: Dataset information. + model: The model used for federated learning. + trainer: The trainer used for training. + """ self.device = device self.args = args self.trainer = trainer @@ -56,6 +104,17 @@ def __init__(self, args, rank, worker_number, comm, device, dataset, model, trai self.client_runtime_history = {} def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Randomly sample clients for communication in a federated round. + + Args: + round_idx (int): The index of the federated round. + client_num_in_total (int): The total number of clients in the dataset. + client_num_per_round (int): The number of clients to be sampled in each round. + + Returns: + list: A list of client indexes sampled for communication. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -66,6 +125,16 @@ def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): return client_indexes def simulate_all_tasks(self, server_params, client_indexes): + """ + Simulate tasks for all selected clients and create localAggregatorToServerParams. + + Args: + server_params: Server parameters. + client_indexes (list): List of client indexes selected for communication. + + Returns: + LocalAggregatorToServerParams: Parameters to be communicated to the local aggregators. + """ localAggregatorToServerParams = LocalAggregatorToServerParams(None) # model_update = [torch.zeros_like(v) for v in get_weights(self.trainer.get_model_params())] # localAggregatorToServerParams.add_reduce_param(name="model_params", @@ -81,6 +150,16 @@ def simulate_all_tasks(self, server_params, client_indexes): return localAggregatorToServerParams def workload_estimate(self, client_indexes, mode="simulate"): + """ + Estimate the workload of clients in a federated round. + + Args: + client_indexes (list): List of client indexes. + mode (str, optional): The mode for workload estimation (default is "simulate"). + + Returns: + list: A list of estimated client workloads. 
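Seeding NumPy with the round index, as `client_sampling` does, guarantees that every process draws the same client subset without extra communication. Standalone:

```python
import numpy as np

def sample_clients_sketch(round_idx, client_num_in_total, client_num_per_round):
    if client_num_in_total == client_num_per_round:
        return list(range(client_num_in_total))
    np.random.seed(round_idx)  # identical seed on all processes -> identical draw
    num = min(client_num_per_round, client_num_in_total)
    return np.random.choice(range(client_num_in_total), num, replace=False).tolist()

print(sample_clients_sketch(round_idx=3, client_num_in_total=10, client_num_per_round=4))
```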
+ """ if mode == "simulate": client_samples = [self.train_data_local_num_dict[client_index] for client_index in client_indexes] workload = client_samples @@ -91,6 +170,16 @@ def workload_estimate(self, client_indexes, mode="simulate"): return workload def memory_estimate(self, client_indexes, mode="simulate"): + """ + Estimate the memory usage of clients in a federated round. + + Args: + client_indexes (list): List of client indexes. + mode (str, optional): The mode for memory estimation (default is "simulate"). + + Returns: + np.ndarray: An array representing the estimated memory usage for each client. + """ if mode == "simulate": memory = np.ones(self.device_number) elif mode == "real": @@ -100,6 +189,15 @@ def memory_estimate(self, client_indexes, mode="simulate"): return memory def resource_estimate(self, mode="simulate"): + """ + Estimate the resource usage of clients in a federated round. + + Args: + mode (str, optional): The mode for resource estimation (default is "simulate"). + + Returns: + np.ndarray: An array representing the estimated resource usage for each client. + """ if mode == "simulate": resource = np.ones(self.device_number) elif mode == "real": @@ -109,6 +207,19 @@ def resource_estimate(self, mode="simulate"): return resource def client_schedule(self, round_idx, client_num_in_total, client_num_per_round, server_params, mode="simulate"): + """ + Schedule clients for communication in a federated round. + + Args: + round_idx (int): The index of the federated round. + client_num_in_total (int): The total number of clients in the dataset. + client_num_per_round (int): The number of clients to be scheduled in each round. + server_params: Server parameters. + mode (str, optional): The mode for scheduling (default is "simulate"). + + Returns: + tuple: A tuple containing the selected client indexes and their schedule for communication. + """ # scheduler(workloads, constraints, memory) client_indexes = self.client_sampling(round_idx, client_num_in_total, client_num_per_round) # workload = self.workload_estimate(client_indexes, mode) @@ -129,6 +240,15 @@ def client_schedule(self, round_idx, client_num_in_total, client_num_per_round, return client_indexes, client_schedule def get_average_weight(self, client_indexes): + """ + Calculate the average weight for each client based on their training data size. + + Args: + client_indexes (list): List of client indexes. + + Returns: + dict: A dictionary mapping client indexes to their average weights. + """ average_weight_dict = {} training_num = 0 for client_index in client_indexes: @@ -139,6 +259,13 @@ def get_average_weight(self, client_indexes): return average_weight_dict def encode_average_weight_dict(self, server_params, average_weight_dict): + """ + Encode the average weight dictionary into server parameters. + + Args: + server_params: Server parameters. + average_weight_dict (dict): A dictionary mapping client indexes to their average weights. + """ server_params.add_broadcast_param( name="average_weight_dict_keys", param=torch.tensor(list(average_weight_dict.keys())) ) @@ -147,12 +274,49 @@ def encode_average_weight_dict(self, server_params, average_weight_dict): ) def decode_average_weight_dict(self, server_params): + """ + Decode the average weight dictionary received from the server. + + This method is used to decode the average weight dictionary that was previously encoded and broadcasted + by the server. The average weight dictionary represents the weights assigned to each client based on + their training data size. 
+ + Args: + server_params (ServerToClientParams): The server parameters containing the average weight dictionary. + + Returns: + dict: The decoded average weight dictionary. + """ pass def record_client_runtime(self, client_runtimes): + """ + Record the runtime of each client during a training round. + + This method is used to record the runtime of each client during a training round. The client runtimes are + typically collected and communicated by the local aggregators. + + Args: + client_runtimes (list): A list of client runtimes for each client. + + Returns: + None + """ pass def train(self): + """ + Train the federated learning model using the server-client communication protocol. + + This method implements the federated learning training process by coordinating communication + between the server and clients for multiple rounds of training. + + Args: + None + + Returns: + None + """ server_params = ServerToClientParams() server_params.add_broadcast_param(name="broadcastTest", param=torch.tensor([1, 2, 3])) server_params.broadcast() @@ -198,6 +362,18 @@ def train(self): self.test_on_server_for_all_clients(round) def test_on_server_for_all_clients(self, round_idx): + """ + Perform testing on the server for all clients after a certain number of rounds. + + This method tests the federated learning model on both the training and test datasets + for all clients on the server side. + + Args: + round_idx (int): The current round index. + + Returns: + None + """ if self.trainer.test_on_the_server( self.train_data_local_dict, self.test_data_local_dict, self.device, self.args, ): diff --git a/python/fedml/simulation/nccl/fedavg/FedAvgAPI.py b/python/fedml/simulation/nccl/fedavg/FedAvgAPI.py index 094be26ee1..2e459ad72f 100644 --- a/python/fedml/simulation/nccl/fedavg/FedAvgAPI.py +++ b/python/fedml/simulation/nccl/fedavg/FedAvgAPI.py @@ -4,6 +4,26 @@ def FedML_FedAvg_NCCL(args, process_id, worker_number, comm, device, dataset, model, model_trainer=None): + """ + Create a FedAvgServer or FedAvgLocalAggregator object based on the process ID. + + This function is a factory function for creating either a FedAvgServer or a FedAvgLocalAggregator object + based on the value of the process ID. If the process ID is 0, it creates a FedAvgServer object; otherwise, + it creates a FedAvgLocalAggregator object. + + Args: + args (object): Arguments for the federated learning setup. + process_id (int): The process ID. + worker_number (int): The total number of worker processes. + comm (object): The communication backend. + device (object): The device on which the model is trained. + dataset (tuple): A tuple containing dataset-related information. + model (object): The machine learning model. + model_trainer (object, optional): The model trainer. If not provided, it will be created. + + Returns: + object: A FedAvgServer or FedAvgLocalAggregator object. 
+ """ if model_trainer is None: model_trainer = create_model_trainer(model, args) if process_id == 0: From 14c75cb14978ed67273c2d117c7c2a95cf49bec0 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Thu, 7 Sep 2023 19:44:57 +0530 Subject: [PATCH 45/70] `python\fedml\simulation\mpi python\fedml\simulation\mpi fedseg spilt_nn --- .../fedml/simulation/mpi/fedseg/FedSegAPI.py | 59 ++++ .../simulation/mpi/fedseg/FedSegAggregator.py | 96 +++++++ .../mpi/fedseg/FedSegClientManager.py | 92 +++++++ .../mpi/fedseg/FedSegServerManager.py | 86 ++++++ .../simulation/mpi/fedseg/FedSegTrainer.py | 78 ++++++ .../simulation/mpi/fedseg/MyModelTrainer.py | 49 ++++ python/fedml/simulation/mpi/fedseg/utils.py | 260 +++++++++++++++++- .../simulation/mpi/split_nn/SplitNNAPI.py | 47 ++++ .../fedml/simulation/mpi/split_nn/client.py | 30 ++ .../simulation/mpi/split_nn/client_manager.py | 86 ++++++ .../fedml/simulation/mpi/split_nn/server.py | 46 +++- .../simulation/mpi/split_nn/server_manager.py | 46 ++++ 12 files changed, 963 insertions(+), 12 deletions(-) diff --git a/python/fedml/simulation/mpi/fedseg/FedSegAPI.py b/python/fedml/simulation/mpi/fedseg/FedSegAPI.py index 4c07213060..fead0eaa4b 100644 --- a/python/fedml/simulation/mpi/fedseg/FedSegAPI.py +++ b/python/fedml/simulation/mpi/fedseg/FedSegAPI.py @@ -10,6 +10,12 @@ def FedML_init(): + """ + Initialize the federated learning environment. + + Returns: + tuple: A tuple containing the MPI communicator (`comm`), process ID (`process_id`), and worker number (`worker_number`). + """ comm = MPI.COMM_WORLD process_id = comm.Get_rank() worker_number = comm.Get_size() @@ -29,6 +35,25 @@ def FedML_FedSeg_distributed( args, model_trainer=None, ): + """ + Initialize and run the federated Segmentation training process. + + Args: + process_id (int): The ID of the current process. + worker_number (int): The total number of workers (including the server). + device: The device on which the model is trained. + comm: The MPI communicator. + model: The neural network model. + train_data_num: The number of training data samples. + train_data_local_num_dict: A dictionary containing the number of local training data samples for each worker. + train_data_local_dict: A dictionary containing the local training data for each worker. + test_data_local_dict: A dictionary containing the local testing data for each worker. + args: Additional arguments for the federated learning setup. + model_trainer: The model trainer for training the model (optional). + + Notes: + - If `process_id` is 0, it initializes the server. Otherwise, it initializes a client. + """ if process_id == 0: init_server(args, device, comm, process_id, worker_number, model, model_trainer) @@ -49,6 +74,21 @@ def FedML_FedSeg_distributed( def init_server(args, device, comm, rank, size, model, model_trainer): + """ + Initialize the federated learning server. + + Args: + args: Additional arguments for the server initialization. + device: The device on which the model is trained. + comm: The MPI communicator. + rank (int): The rank of the current process. + size (int): The total number of processes. + model: The neural network model. + model_trainer: The model trainer for training the model (optional). + + Notes: + This function initializes the server for federated Segmentation training. + """ logging.info("Initializing Server") if model_trainer is None: @@ -78,6 +118,25 @@ def init_client( test_data_local_dict, model_trainer, ): + """ + Initialize and run a federated learning client. 
+ + Args: + args: Additional arguments for the client initialization. + device: The device on which the model is trained. + comm: The MPI communicator. + process_id (int): The ID of the current client process. + size (int): The total number of processes. + model: The neural network model. + train_data_num: The number of training data samples. + train_data_local_num_dict: A dictionary containing the number of local training data samples for each client. + train_data_local_dict: A dictionary containing the local training data for each client. + test_data_local_dict: A dictionary containing the local testing data for each client. + model_trainer: The model trainer for training the model (optional). + + Notes: + This function initializes and runs a federated learning client. + """ client_index = process_id - 1 logging.info("Initializing Client: {0}".format(client_index)) diff --git a/python/fedml/simulation/mpi/fedseg/FedSegAggregator.py b/python/fedml/simulation/mpi/fedseg/FedSegAggregator.py index bb126cae8e..a0e6e5f3d6 100644 --- a/python/fedml/simulation/mpi/fedseg/FedSegAggregator.py +++ b/python/fedml/simulation/mpi/fedseg/FedSegAggregator.py @@ -8,6 +8,44 @@ class FedSegAggregator(object): + """ + Federated Segmentation Aggregator for collecting and managing model updates and statistics from clients. + + Args: + worker_num (int): Number of worker (client) nodes. + device: The computing device (e.g., GPU) for training. + model: The segmentation model used in federated learning. + args: Additional configuration arguments. + model_trainer: Trainer for the segmentation model. + + Attributes: + trainer: The model trainer for training and evaluation. + worker_num (int): Number of worker (client) nodes. + device: The computing device for training. + args: Additional configuration arguments. + model_dict (dict): Dictionary to store model parameters received from clients. + sample_num_dict (dict): Dictionary to store the number of training samples from clients. + flag_client_model_uploaded_dict (dict): Dictionary to track whether each client has uploaded its model. + train_acc_client_dict (dict): Dictionary to store training accuracy for each client. + train_acc_class_client_dict (dict): Dictionary to store training class-wise accuracy for each client. + train_mIoU_client_dict (dict): Dictionary to store training mean Intersection over Union (mIoU) for each client. + train_FWIoU_client_dict (dict): Dictionary to store training frequency-weighted IoU (FWIoU) for each client. + train_loss_client_dict (dict): Dictionary to store training loss for each client. + test_acc_client_dict (dict): Dictionary to store test accuracy for each client. + test_acc_class_client_dict (dict): Dictionary to store test class-wise accuracy for each client. + test_mIoU_client_dict (dict): Dictionary to store test mean Intersection over Union (mIoU) for each client. + test_FWIoU_client_dict (dict): Dictionary to store test frequency-weighted IoU (FWIoU) for each client. + test_loss_client_dict (dict): Dictionary to store test loss for each client. + best_mIoU (float): Best mIoU value among all clients. + best_mIoU_clients (dict): Dictionary to store the clients with the best mIoU. + saver: Saver for saving experiment configurations and results. + + Methods: + get_global_model_params: Get the global model parameters. + set_global_model_params: Set the global model parameters. + add_local_trained_result: Add model parameters and sample count from a client. 
+ check_whether_all_receive: Check if all clients have uploaded their models. + """ def __init__(self, worker_num, device, model, args, model_trainer): self.trainer = model_trainer self.worker_num = worker_num @@ -43,18 +81,44 @@ def __init__(self, worker_num, device, model, args, model_trainer): ) def get_global_model_params(self): + """ + Get the global model parameters. + + Returns: + Global model parameters. + """ return self.trainer.get_model_params() def set_global_model_params(self, model_parameters): + """ + Set the global model parameters. + + Args: + model_parameters: Global model parameters to set. + """ self.trainer.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, sample_num): + """ + Add the model parameters and sample count from a client. + + Args: + index (int): Index or identifier of the client. + model_params: Model parameters trained by the client. + sample_num (int): Number of training samples used by the client. + """ logging.info("Add model index: {}".format(index)) self.model_dict[index] = model_params self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check whether all clients have uploaded their models. + + Returns: + True if all clients have uploaded their models, False otherwise. + """ for idx in range(self.worker_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -63,6 +127,12 @@ def check_whether_all_receive(self): return True def aggregate(self): + """ + Aggregate model updates from multiple clients. + + Returns: + Averaged model parameters after aggregation. + """ start_time = time.time() model_list = [] training_num = 0 @@ -93,6 +163,17 @@ def aggregate(self): return averaged_params def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Randomly select a subset of clients for federated learning. + + Args: + round_idx (int): Current federated learning round index. + client_num_in_total (int): Total number of available clients. + client_num_per_round (int): Number of clients to select for the current round. + + Returns: + List of selected client indexes. + """ if client_num_in_total == client_num_per_round: client_indexes = [ client_index for client_index in range(client_num_in_total) @@ -115,6 +196,15 @@ def add_client_test_result( train_eval_metrics: EvaluationMetricsKeeper, test_eval_metrics: EvaluationMetricsKeeper, ): + """ + Add evaluation metrics and results from a client. + + Args: + round_idx (int): Current federated learning round index. + client_idx (int): Index or identifier of the client. + train_eval_metrics (EvaluationMetricsKeeper): Evaluation metrics for training data. + test_eval_metrics (EvaluationMetricsKeeper): Evaluation metrics for testing data. + """ logging.info("Adding client test result : {}".format(client_idx)) # Populating Training Dictionary @@ -176,6 +266,12 @@ def add_client_test_result( self.saver.save_checkpoint(saver_state, is_best, filename) def output_global_acc_and_loss(self, round_idx): + """ + Output global accuracy and loss statistics for the current federated learning round. + + Args: + round_idx (int): Current federated learning round index. 
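The best-model bookkeeping in `add_client_test_result` reduces to a compare-and-update on the running best mIoU; a hedged sketch of that step (the real code also snapshots model state through the saver):

```python
def track_best_miou_sketch(best_state, client_idx, test_miou):
    # A checkpoint would be written whenever `is_best` comes back True.
    is_best = test_miou > best_state["mIoU"]
    if is_best:
        best_state["mIoU"] = test_miou
        best_state["client"] = client_idx
    return is_best

best = {"mIoU": 0.0, "client": None}
print(track_best_miou_sketch(best, client_idx=2, test_miou=0.61))  # True
```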
+ """ logging.info( "################## Output global accuracy and loss for round {} :".format( round_idx diff --git a/python/fedml/simulation/mpi/fedseg/FedSegClientManager.py b/python/fedml/simulation/mpi/fedseg/FedSegClientManager.py index de6ce9f19b..5370c557d4 100644 --- a/python/fedml/simulation/mpi/fedseg/FedSegClientManager.py +++ b/python/fedml/simulation/mpi/fedseg/FedSegClientManager.py @@ -7,16 +7,65 @@ class FedSegClientManager(FedMLCommManager): + """ + Client manager for federated segmentation. + + This class manages the client-side communication and training in a federated segmentation system. + + Args: + args: Additional configuration arguments. + trainer: Model trainer for federated segmentation. + comm: MPI communicator for distributed communication. + rank (int): Rank of the client. + size (int): Total number of processes. + backend (str): Communication backend (default: "MPI"). + + Attributes: + args: Additional configuration arguments. + trainer: Model trainer for federated segmentation. + num_rounds (int): Number of communication rounds. + + Methods: + run(): Start the client manager. + register_message_receive_handlers(): Register message handlers for receiving initialization and model synchronization messages. + handle_message_init(msg_params): Handle the initialization message from the central server. + start_training(): Start the training process. + handle_message_receive_model_from_server(msg_params): Handle received model updates from the central server. + send_model_to_server(receive_id, weights, local_sample_num, train_evaluation_metrics, test_evaluation_metrics): Send trained model updates to the central server. + """ def __init__(self, args, trainer, comm=None, rank=0, size=0, backend="MPI"): + """ + Initialize the FedSegClientManager. + + Args: + args: Additional configuration arguments. + trainer: Model trainer for federated segmentation. + comm: MPI communicator for distributed communication. + rank (int): Rank of the client. + size (int): Total number of processes. + backend (str): Communication backend (default: "MPI"). + """ super().__init__(args, comm, rank, size, backend) self.trainer = trainer self.num_rounds = args.comm_round self.args.round_idx = 0 def run(self): + """ + Start the client manager. + + Notes: + This function starts the client manager to handle communication and training. + """ super().run() def register_message_receive_handlers(self): + """ + Register message handlers for receiving initialization and model synchronization messages. + + Notes: + This function registers message handlers to process incoming messages from the central server. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -26,6 +75,15 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): + """ + Handle the initialization message from the central server. + + Args: + msg_params (dict): Parameters included in the received message. + + Notes: + This function processes the initialization message from the central server, updates the model and dataset, and starts training. + """ global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) logging.info( @@ -39,10 +97,25 @@ def handle_message_init(self, msg_params): self.__train() def start_training(self): + """ + Start the training process. + + Notes: + This function initiates the training process on the client side. 
+ """ self.args.round_idx = 0 self.__train() def handle_message_receive_model_from_server(self, msg_params): + """ + Handle received model updates from the central server. + + Args: + msg_params (dict): Parameters included in the received message. + + Notes: + This function processes received model updates from the central server, updates the model and dataset, and continues training. + """ logging.info("handle_message_receive_model_from_server.") model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -62,6 +135,19 @@ def send_model_to_server( train_evaluation_metrics, test_evaluation_metrics, ): + """ + Send trained model updates to the central server. + + Args: + receive_id (int): Receiver's ID. + weights: Trained model parameters. + local_sample_num (int): Number of local training samples. + train_evaluation_metrics: Evaluation metrics for training. + test_evaluation_metrics: Evaluation metrics for testing. + + Notes: + This function sends the trained model updates and evaluation metrics to the central server. + """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), @@ -78,6 +164,12 @@ def send_model_to_server( self.send_message(message) def __train(self): + """ + Perform training on the client side. + + Notes: + This method initiates the training process on the client side, including testing the global parameters, training the local model, and sending updates to the central server. + """ train_evaluation_metrics = test_evaluation_metrics = None logging.info( diff --git a/python/fedml/simulation/mpi/fedseg/FedSegServerManager.py b/python/fedml/simulation/mpi/fedseg/FedSegServerManager.py index d1382c6d77..677d17f757 100644 --- a/python/fedml/simulation/mpi/fedseg/FedSegServerManager.py +++ b/python/fedml/simulation/mpi/fedseg/FedSegServerManager.py @@ -7,7 +7,44 @@ class FedSegServerManager(FedMLCommManager): + """ + Server manager for federated segmentation. + + This class manages the server-side communication and aggregation of model updates in a federated segmentation system. + + Args: + args: Additional configuration arguments. + aggregator: Aggregator for federated segmentation models. + comm: MPI communicator for distributed communication. + rank (int): Rank of the server. + size (int): Total number of processes. + backend (str): Communication backend (default: "MPI"). + + Attributes: + args: Additional configuration arguments. + aggregator: Aggregator for federated segmentation models. + round_num (int): Number of communication rounds. + + Methods: + run(): Start the server manager. + send_init_msg(): Send initial configuration messages to clients. + register_message_receive_handlers(): Register message handlers for receiving model updates from clients. + handle_message_receive_model_from_client(msg_params): Handle received model updates from clients. + send_message_init_config(receive_id, global_model_params, client_index): Send initial configuration messages to clients. + send_message_sync_model_to_client(receive_id, global_model_params, client_index): Send model synchronization messages to clients. + """ def __init__(self, args, aggregator, comm=None, rank=0, size=0, backend="MPI"): + """ + Initialize the FedSegServerManager. + + Args: + args: Additional configuration arguments. + aggregator: Aggregator for federated segmentation models. + comm: MPI communicator for distributed communication. + rank (int): Rank of the server. 
+ size (int): Total number of processes. + backend (str): Communication backend (default: "MPI"). + """ super().__init__(args, comm, rank, size, backend) self.args = args self.aggregator = aggregator @@ -16,9 +53,21 @@ def __init__(self, args, aggregator, comm=None, rank=0, size=0, backend="MPI"): logging.info("Initializing Server Manager") def run(self): + """ + Start the server manager. + + Notes: + This function starts the server manager to handle communication and aggregation. + """ super().run() def send_init_msg(self): + """ + Send initial configuration messages to clients. + + Notes: + This function sends initial configuration messages to clients, including global model parameters and client indexes. + """ # sampling clients client_indexes = self.aggregator.client_sampling( self.args.round_idx, @@ -32,12 +81,27 @@ def send_init_msg(self): ) def register_message_receive_handlers(self): + """ + Register message handlers for receiving model updates from clients. + + Notes: + This function registers message handlers to process incoming messages from clients. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.handle_message_receive_model_from_client, ) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle received model updates from clients. + + Args: + msg_params (dict): Parameters included in the received message. + + Notes: + This function processes received model updates from clients, aggregates them, and initiates the next round of communication. + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) @@ -82,6 +146,17 @@ def handle_message_receive_model_from_client(self, msg_params): ) def send_message_init_config(self, receive_id, global_model_params, client_index): + """ + Send initial configuration messages to clients. + + Args: + receive_id (int): Receiver's ID. + global_model_params: Global model parameters. + client_index (int): Index of the client. + + Notes: + This function sends initial configuration messages to clients, including global model parameters and client indexes. + """ logging.info("Initial Configurations sent to client {0}".format(client_index)) message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id @@ -93,6 +168,17 @@ def send_message_init_config(self, receive_id, global_model_params, client_index def send_message_sync_model_to_client( self, receive_id, global_model_params, client_index ): + """ + Send model synchronization messages to clients. + + Args: + receive_id (int): Receiver's ID. + global_model_params: Global model parameters. + client_index (int): Index of the client. + + Notes: + This function sends model synchronization messages to clients, updating their models with the global parameters. + """ logging.info( "send_message_sync_model_to_client. receive_id {0}".format(receive_id) ) diff --git a/python/fedml/simulation/mpi/fedseg/FedSegTrainer.py b/python/fedml/simulation/mpi/fedseg/FedSegTrainer.py index f0bda08b47..d9b057f030 100644 --- a/python/fedml/simulation/mpi/fedseg/FedSegTrainer.py +++ b/python/fedml/simulation/mpi/fedseg/FedSegTrainer.py @@ -2,6 +2,40 @@ class FedSegTrainer(object): + """ + Trainer for federated segmentation models on a client. + + This class manages the training process of a federated segmentation model on a client. 
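+ It keeps references to the client's local train/test splits and delegates
+ the actual optimization and evaluation to the supplied model_trainer.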
+ + Args: + client_index (int): The index of the client within the federated system. + train_data_local_dict (dict): A dictionary containing local training data for each client. + train_data_local_num_dict (dict): A dictionary containing the number of local training samples for each client. + train_data_num (int): Total number of training samples across all clients. + test_data_local_dict (dict): A dictionary containing local test data for each client. + device (torch.device): The device on which to perform training and evaluation. + model (nn.Module): The segmentation model to be trained. + args: Additional configuration arguments. + model_trainer: Trainer for the segmentation model. + + Attributes: + args: Additional configuration arguments. + trainer: Trainer for the segmentation model. + client_index (int): The index of the client. + train_data_local_dict (dict): A dictionary containing local training data for the client. + train_data_local_num_dict (dict): A dictionary containing the number of local training samples for each client. + test_data_local_dict (dict): A dictionary containing local test data for the client. + all_train_data_num (int): Total number of training samples across all clients. + train_local: Local training data for the client. + local_sample_number (int): The number of local training samples for the client. + test_local: Local test data for the client. + + Methods: + update_model(weights): Update the model with the provided weights. + update_dataset(client_index): Update the dataset for the client with the given index. + train(): Perform training on the local dataset and return trained weights and the number of local samples. + test(): Perform testing on the local test dataset and return evaluation metrics. + """ def __init__( self, client_index, @@ -14,6 +48,20 @@ def __init__( args, model_trainer, ): + """ + Initialize the FedSegTrainer for a client. + + Args: + client_index (int): The index of the client. + train_data_local_dict (dict): A dictionary containing local training data for each client. + train_data_local_num_dict (dict): A dictionary containing the number of local training samples for each client. + train_data_num (int): Total number of training samples across all clients. + test_data_local_dict (dict): A dictionary containing local test data for each client. + device (torch.device): The device on which to perform training and evaluation. + model: The segmentation model to be trained. + args: Additional configuration arguments. + model_trainer: Trainer for the segmentation model. + """ self.args = args self.trainer = model_trainer @@ -30,15 +78,39 @@ def __init__( self.device = device def update_model(self, weights): + """ + Update the model with the provided weights. + + Args: + weights: Model weights to be set. + + Notes: + This function updates the model with the provided weights. + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """ + Update the dataset for the client with the given index. + + Args: + client_index (int): The index of the client. + + Notes: + This function updates the dataset and client-related attributes for the specified client index. + """ self.client_index = client_index self.train_local = self.train_data_local_dict[client_index] self.local_sample_number = self.train_data_local_num_dict[client_index] self.test_local = self.test_data_local_dict[client_index] def train(self): + """ + Perform training on the local dataset and return trained weights and the number of local samples. 
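+
+ Internally this calls the underlying trainer on the local training split
+ and then reads the updated weights back via get_model_params().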
+ + Returns: + tuple: A tuple containing trained model weights and the number of local training samples. + """ self.trainer.train(self.train_local, self.device) weights = self.trainer.get_model_params() @@ -46,6 +118,12 @@ def train(self): return weights, self.local_sample_number def test(self): + """ + Perform testing on the local test dataset and return evaluation metrics. + + Returns: + tuple: A tuple containing evaluation metrics on the local test dataset. + """ train_evaluation_metrics = None if self.args.round_idx and self.args.round_idx % self.args.evaluation_frequency == 0: diff --git a/python/fedml/simulation/mpi/fedseg/MyModelTrainer.py b/python/fedml/simulation/mpi/fedseg/MyModelTrainer.py index a4230e4e29..abec3eb5a7 100644 --- a/python/fedml/simulation/mpi/fedseg/MyModelTrainer.py +++ b/python/fedml/simulation/mpi/fedseg/MyModelTrainer.py @@ -9,7 +9,28 @@ class MyModelTrainer(ClientTrainer): + """ + A custom model trainer for federated learning clients. + + This trainer is designed for training and evaluating a segmentation model in a federated learning setting. + + Attributes: + model (nn.Module): The segmentation model to be trained and evaluated. + args: Additional configuration arguments for training and evaluation. + + Methods: + get_model_params(): Get the model parameters for the current trainer. + set_model_params(model_parameters): Set the model parameters for the current trainer. + train(train_data, device, args): Train the model on the provided training data. + test(test_data, device, args): Evaluate the model on the provided test data. + """ def get_model_params(self): + """ + Get the model parameters for the current trainer. + + Returns: + dict: A dictionary containing the model parameters. + """ if self.args.backbone_freezed: logging.info("Initializing model; Backbone Freezed") return self.model.encoder_decoder.cpu().state_dict() @@ -18,6 +39,12 @@ def get_model_params(self): return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters for the current trainer. + + Args: + model_parameters (dict): A dictionary containing the model parameters to be set. + """ if self.args.backbone_freezed: logging.info("Updating Global model; Backbone Freezed") self.model.encoder_decoder.load_state_dict(model_parameters) @@ -26,6 +53,17 @@ def set_model_params(self, model_parameters): self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model on the provided training data. + + Args: + train_data (DataLoader): DataLoader containing the training data. + device (torch.device): The device on which to perform training. + args: Additional arguments for training. + + Notes: + This function trains the model using the provided data and updates its parameters. + """ model = self.model args = self.args @@ -100,6 +138,17 @@ def train(self, train_data, device, args): ) def test(self, test_data, device, args): + """ + Evaluate the model on the provided test data. + + Args: + test_data (DataLoader): DataLoader containing the test data. + device (torch.device): The device on which to perform evaluation. + args: Additional arguments for evaluation. + + Returns: + EvaluationMetricsKeeper: An object containing various evaluation metrics. 
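+
+ Example:
+ Illustrative sketch; ``trainer``, ``test_loader``, and ``args`` are
+ assumed to be constructed elsewhere in the FedSeg pipeline:
+
+ >>> metrics = trainer.test(test_loader, torch.device("cuda"), args)
+ >>> print(metrics.acc, metrics.mIoU)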
+ """ logging.info("Evaluation on trainer ID:{}".format(self.id)) model = self.model args = self.args diff --git a/python/fedml/simulation/mpi/fedseg/utils.py b/python/fedml/simulation/mpi/fedseg/utils.py index 74cff9d830..815d4ab674 100644 --- a/python/fedml/simulation/mpi/fedseg/utils.py +++ b/python/fedml/simulation/mpi/fedseg/utils.py @@ -16,6 +16,15 @@ def transform_list_to_tensor(model_params_list): + """ + Transform a dictionary of model parameters from a list of NumPy arrays to PyTorch tensors. + + Args: + model_params_list (dict): A dictionary containing model parameters as NumPy arrays. + + Returns: + dict: A dictionary containing model parameters as PyTorch tensors. + """ for k in model_params_list.keys(): model_params_list[k] = torch.from_numpy( np.asarray(model_params_list[k]) @@ -24,27 +33,73 @@ def transform_list_to_tensor(model_params_list): def transform_tensor_to_list(model_params): + """ + Transform a dictionary of model parameters from PyTorch tensors to a list of NumPy arrays. + + Args: + model_params (dict): A dictionary containing model parameters as PyTorch tensors. + + Returns: + dict: A dictionary containing model parameters as NumPy arrays. + """ for k in model_params.keys(): model_params[k] = model_params[k].detach().numpy().tolist() return model_params def save_as_pickle_file(path, data): + """ + Save data to a pickle file. + + Args: + path (str): The file path where the data will be saved. + data (any): The data to be saved. + """ with open(path, "wb") as f: pickle.dump(data, f) f.close() def load_from_pickle_file(path): + """ + Load data from a pickle file. + + Args: + path (str): The file path from which the data will be loaded. + + Returns: + any: The loaded data. + """ return pickle.load(open(path, "rb")) def count_parameters(model): + """ + Count the number of trainable parameters in a PyTorch model. + + Args: + model (torch.nn.Module): The PyTorch model. + + Returns: + float: The number of trainable parameters in millions (M). + """ params = sum(p.numel() for p in model.parameters() if p.requires_grad) return params / 1000000 def str_to_bool(s): + """ + Convert a string to a boolean value. + + Args: + s (str): The input string. + + Returns: + bool: The boolean value corresponding to the string ("True" or "False"). + + Raises: + ValueError: If the input string is neither "True" nor "False". + """ if s == "True": return True elif s == "False": @@ -54,6 +109,23 @@ def str_to_bool(s): class EvaluationMetricsKeeper: + """ + A class to store and manage evaluation metrics. + + Args: + accuracy (float): Accuracy metric. + accuracy_class (float): Accuracy per class metric. + mIoU (float): Mean Intersection over Union (mIoU) metric. + FWIoU (float): Frequency-Weighted Intersection over Union (FWIoU) metric. + loss (float): Loss metric. + + Attributes: + acc (float): Accuracy metric. + acc_class (float): Accuracy per class metric. + mIoU (float): Mean Intersection over Union (mIoU) metric. + FWIoU (float): Frequency-Weighted Intersection over Union (FWIoU) metric. + loss (float): Loss metric. + """ def __init__(self, accuracy, accuracy_class, mIoU, FWIoU, loss): self.acc = accuracy self.acc_class = accuracy_class @@ -64,13 +136,37 @@ def __init__(self, accuracy, accuracy_class, mIoU, FWIoU, loss): # Segmentation Loss class SegmentationLosses(object): + """ + A class for managing segmentation loss functions. + + Args: + size_average (bool): Whether to compute the size-average loss. + batch_average (bool): Whether to compute the batch-average loss. 
+ ignore_index (int): The index to ignore in the loss computation. + + Attributes: + ignore_index (int): The index to ignore in the loss computation. + size_average (bool): Whether to compute the size-average loss. + batch_average (bool): Whether to compute the batch-average loss. + """ def __init__(self, size_average=True, batch_average=True, ignore_index=255): self.ignore_index = ignore_index self.size_average = size_average self.batch_average = batch_average def build_loss(self, mode="ce"): - """Choices: ['ce' or 'focal']""" + """ + Build a segmentation loss function based on the specified mode. + + Args: + mode (str): The mode of the loss function. Choices: ['ce' or 'focal'] + + Returns: + function: The selected segmentation loss function. + + Raises: + NotImplementedError: If an unsupported mode is specified. + """ if mode == "ce": return self.CrossEntropyLoss elif mode == "focal": @@ -79,6 +175,19 @@ def build_loss(self, mode="ce"): raise NotImplementedError def CrossEntropyLoss(self, logit, target): + """ + Compute the Cross Entropy loss. + + Args: + logit (torch.Tensor): The predicted logit tensor. + target (torch.Tensor): The target tensor. + + Returns: + torch.Tensor: The computed loss. + + Note: + This function uses the specified ignore_index and handles size and batch averaging. + """ n, c, h, w = logit.size() criterion = nn.CrossEntropyLoss( ignore_index=self.ignore_index, size_average=self.size_average @@ -91,6 +200,21 @@ def CrossEntropyLoss(self, logit, target): return loss def FocalLoss(self, logit, target, gamma=2, alpha=0.5): + """ + Compute the Focal loss. + + Args: + logit (torch.Tensor): The predicted logit tensor. + target (torch.Tensor): The target tensor. + gamma (float): The Focal loss gamma parameter. + alpha (float): The Focal loss alpha parameter. + + Returns: + torch.Tensor: The computed loss. + + Note: + This function uses the specified ignore_index and handles size and batch averaging. + """ n, c, h, w = logit.size() criterion = nn.CrossEntropyLoss( ignore_index=self.ignore_index, size_average=self.size_average @@ -109,16 +233,33 @@ def FocalLoss(self, logit, target, gamma=2, alpha=0.5): # LR Scheduler class LR_Scheduler(object): - """Learning Rate Scheduler + """ + Learning Rate Scheduler for adjusting the learning rate during training. + Step mode: ``lr = baselr * 0.1 ^ {floor(epoch-1 / lr_step)}`` Cosine mode: ``lr = baselr * 0.5 * (1 + cos(iter/maxiter))`` Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9`` + Args: - args: - :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`,`step`), - :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs, - :attr:`args.lr_step` - iters_per_epoch: number of iterations per epoch + mode (str): The mode of the learning rate scheduler. + Choices: ['cos', 'poly', 'step'] + - 'cos': Cosine mode. + - 'poly': Polynomial mode. + - 'step': Step mode. + base_lr (float): The base learning rate. + num_epochs (int): The total number of training epochs. + iters_per_epoch (int): The number of iterations per epoch. + lr_step (int): The step size for the 'step' mode. + warmup_epochs (int): The number of warm-up epochs. + + Attributes: + mode (str): The mode of the learning rate scheduler. + lr (float): The current learning rate. + lr_step (int): The step size for the 'step' mode. + iters_per_epoch (int): The number of iterations per epoch. + N (int): The total number of iterations over all epochs. + epoch (int): The current epoch. + warmup_iters (int): The number of warm-up iterations. 
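+
+ Example:
+ A minimal sketch; the numeric values below are assumptions, not defaults:
+
+ >>> scheduler = LR_Scheduler(mode="poly", base_lr=0.01, num_epochs=50,
+ ... iters_per_epoch=len(train_loader))
+ >>> for epoch in range(50):
+ ... for i, (image, target) in enumerate(train_loader):
+ ... scheduler(optimizer, i, epoch)  # adjust lr before each step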
""" def __init__( @@ -136,6 +277,14 @@ def __init__( self.warmup_iters = warmup_epochs * iters_per_epoch def __call__(self, optimizer, i, epoch): + """ + Adjusts the learning rate based on the specified mode. + + Args: + optimizer: The optimizer whose learning rate will be adjusted. + i (int): The current iteration. + epoch (int): The current epoch. + """ T = epoch * self.iters_per_epoch + i if self.mode == "cos": lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi)) @@ -154,6 +303,13 @@ def __call__(self, optimizer, i, epoch): self._adjust_learning_rate(optimizer, lr) def _adjust_learning_rate(self, optimizer, lr): + """ + Adjusts the learning rate of the optimizer. + + Args: + optimizer: The optimizer whose learning rate will be adjusted. + lr (float): The new learning rate. + """ if len(optimizer.param_groups) == 1: optimizer.param_groups[0]["lr"] = lr else: @@ -165,7 +321,25 @@ def _adjust_learning_rate(self, optimizer, lr): # save model checkpoints (centralized) class Saver(object): + """ + Utility class for saving checkpoints and experiment configuration. + + Args: + args (argparse.Namespace): The command-line arguments. + + Attributes: + args (argparse.Namespace): The command-line arguments. + directory (str): The directory where experiments are stored. + runs (list): A list of existing experiment directories. + experiment_dir (str): The directory for the current experiment. + """ def __init__(self, args): + """ + Initializes a new Saver object for saving checkpoints and experiment configuration. + + Args: + args (argparse.Namespace): The command-line arguments. + """ self.args = args self.directory = os.path.join("run", args.dataset, args.model, args.checkname) self.runs = sorted(glob.glob(os.path.join(self.directory, "experiment_*"))) @@ -178,7 +352,14 @@ def __init__(self, args): os.makedirs(self.experiment_dir) def save_checkpoint(self, state, is_best, filename="checkpoint.pth.tar"): - """Saves checkpoint to disk""" + """ + Saves a checkpoint to disk. + + Args: + state (dict): The state to be saved. + is_best (bool): True if this is the best checkpoint, False otherwise. + filename (str, optional): The filename for the checkpoint. Defaults to "checkpoint.pth.tar". + """ filename = os.path.join(self.experiment_dir, filename) torch.save(state, filename) if is_best: @@ -211,6 +392,9 @@ def save_checkpoint(self, state, is_best, filename="checkpoint.pth.tar"): ) def save_experiment_config(self): + """ + Saves the experiment configuration to a text file. + """ logfile = os.path.join(self.experiment_dir, "parameters.txt") log_file = open(logfile, "w") @@ -251,20 +435,54 @@ def save_experiment_config(self): # Evaluation Metrics class Evaluator(object): + """ + Class for evaluating segmentation results. + + Args: + num_class (int): The number of classes in the segmentation task. + + Attributes: + num_class (int): The number of classes in the segmentation task. + confusion_matrix (numpy.ndarray): The confusion matrix for evaluating segmentation results. + """ def __init__(self, num_class): + """ + Initializes an Evaluator object for evaluating segmentation results. + + Args: + num_class (int): The number of classes in the segmentation task. + """ self.num_class = num_class self.confusion_matrix = np.zeros((self.num_class,) * 2) def Pixel_Accuracy(self): + """ + Computes the Pixel Accuracy for segmentation evaluation. + + Returns: + float: The Pixel Accuracy. 
+ """ Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum() return Acc def Pixel_Accuracy_Class(self): + """ + Computes the Pixel Accuracy per class for segmentation evaluation. + + Returns: + float: The mean Pixel Accuracy per class. + """ Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1) Acc = np.nanmean(Acc) return Acc def Mean_Intersection_over_Union(self): + """ + Computes the Mean Intersection over Union (IoU) for segmentation evaluation. + + Returns: + float: The Mean IoU. + """ MIoU = np.diag(self.confusion_matrix) / ( np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) @@ -274,6 +492,12 @@ def Mean_Intersection_over_Union(self): return MIoU def Frequency_Weighted_Intersection_over_Union(self): + """ + Computes the Frequency Weighted Intersection over Union (IoU) for segmentation evaluation. + + Returns: + float: The Frequency Weighted IoU. + """ freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) iu = np.diag(self.confusion_matrix) / ( np.sum(self.confusion_matrix, axis=1) @@ -285,6 +509,16 @@ def Frequency_Weighted_Intersection_over_Union(self): return FWIoU def _generate_matrix(self, gt_image, pre_image): + """ + Generates a confusion matrix for segmentation evaluation. + + Args: + gt_image (numpy.ndarray): Ground truth segmentation image. + pre_image (numpy.ndarray): Predicted segmentation image. + + Returns: + numpy.ndarray: The confusion matrix. + """ mask = (gt_image >= 0) & (gt_image < self.num_class) label = self.num_class * gt_image[mask].astype("int") + pre_image[mask] count = np.bincount(label, minlength=self.num_class**2) @@ -292,8 +526,18 @@ def _generate_matrix(self, gt_image, pre_image): return confusion_matrix def add_batch(self, gt_image, pre_image): + """ + Adds a batch of ground truth and predicted images for evaluation. + + Args: + gt_image (numpy.ndarray): Batch of ground truth segmentation images. + pre_image (numpy.ndarray): Batch of predicted segmentation images. + """ assert gt_image.shape == pre_image.shape self.confusion_matrix += self._generate_matrix(gt_image, pre_image) def reset(self): + """ + Resets the confusion matrix to zero. + """ self.confusion_matrix = np.zeros((self.num_class,) * 2) diff --git a/python/fedml/simulation/mpi/split_nn/SplitNNAPI.py b/python/fedml/simulation/mpi/split_nn/SplitNNAPI.py index 76ab846d65..26834eee89 100644 --- a/python/fedml/simulation/mpi/split_nn/SplitNNAPI.py +++ b/python/fedml/simulation/mpi/split_nn/SplitNNAPI.py @@ -10,6 +10,21 @@ def SplitNN_distributed( process_id, worker_number, device, comm, model, dataset, args, ): + """ + Initialize and distribute a Split Neural Network for training. + + Args: + process_id (int): The ID of the current process. + worker_number (int): Total number of worker processes. + device: The computing device (e.g., GPU) for training. + comm: Communication backend for distributed training. + model: The neural network model to be trained. + dataset: Dataset information including data splits. + args: Additional training configuration arguments. + + Returns: + None + """ [ train_data_num, local_data_num, @@ -47,6 +62,20 @@ def SplitNN_distributed( def init_server(comm, server_model, process_id, worker_number, device, args): + """ + Initialize and run the server-side component of Split Neural Network training. + + Args: + comm: Communication backend for distributed training. + server_model: The server-side portion of the neural network model. + process_id (int): The ID of the current process. 
+ worker_number (int): Total number of worker processes.
+ device: The computing device (e.g., GPU) for training.
+ args: Additional training configuration arguments.
+
+ Returns:
+ None
+ """
 arg_dict = {
 "comm": comm,
 "model": server_model,
@@ -63,6 +92,24 @@ def init_server(comm, server_model, process_id, worker_number, device, args):
 def init_client(
 comm, client_model, worker_number, train_data_local, test_data_local, process_id, server_rank, epochs, device, args,
 ):
+ """
+ Initialize and run the client-side component of Split Neural Network training.
+
+ Args:
+ comm: Communication backend for distributed training.
+ client_model: The client-side portion of the neural network model.
+ worker_number (int): Total number of worker processes.
+ train_data_local: Local training data for the client.
+ test_data_local: Local testing data for the client.
+ process_id (int): The ID of the current process.
+ server_rank (int): The rank of the server process.
+ epochs: Number of training epochs for the client.
+ device: The computing device (e.g., GPU) for training.
+ args: Additional training configuration arguments.
+
+ Returns:
+ None
+ """
 client_ID = process_id - 1
 arg_dict = {
 "client_index": client_ID,
diff --git a/python/fedml/simulation/mpi/split_nn/client.py b/python/fedml/simulation/mpi/split_nn/client.py
index 9fc8a816de..9a16af3a22 100644
--- a/python/fedml/simulation/mpi/split_nn/client.py
+++ b/python/fedml/simulation/mpi/split_nn/client.py
@@ -4,7 +4,19 @@


 class SplitNN_client:
+ """
+ SplitNN_client represents a client in a Split Learning setup.
+
+ Args:
+ args (dict): Dictionary containing client-specific configuration.
+ """
 def __init__(self, args):
+ """
+ Initialize a SplitNN_client instance.
+
+ Args:
+ args (dict): Dictionary containing client-specific configuration.
+ """
 self.client_idx = args['client_index']
 self.comm = args["comm"]
 self.model = args["model"]
@@ -26,6 +38,12 @@ def __init__(self, args):
 self.device = args["device"]

 def forward_pass(self):
+ """
+ Perform a forward pass through the model.
+
+ Returns:
+ tuple: Tuple containing model activations (outputs) and labels.
+ """
 logging.info("forward_pass")
 inputs, labels = next(self.dataloader)
 inputs, labels = inputs.to(self.device), labels.to(self.device)
@@ -40,16 +58,28 @@ def forward_pass(self):
 return self.acts, labels

 def backward_pass(self, grads):
+ """
+ Perform a backward pass and update model parameters.
+
+ Args:
+ grads: Gradients used for the backward pass.
+ """
 logging.info("backward_pass")
 self.acts.backward(grads)
 self.optimizer.step()

 def eval_mode(self):
+ """
+ Switch the model to evaluation mode and prepare the test data loader.
+ """
 logging.info("eval_mode")
 self.dataloader = iter(self.testloader)
 self.model.eval()

 def train_mode(self):
+ """
+ Switch the model to training mode and prepare the training data loader.
+ """
 logging.info("train_mode")
 self.dataloader = iter(self.trainloader)
 self.model.train()
diff --git a/python/fedml/simulation/mpi/split_nn/client_manager.py b/python/fedml/simulation/mpi/split_nn/client_manager.py
index 1bb1e84580..73639f3fcb 100644
--- a/python/fedml/simulation/mpi/split_nn/client_manager.py
+++ b/python/fedml/simulation/mpi/split_nn/client_manager.py
@@ -6,7 +6,27 @@


 class SplitNNClientManager(FedMLCommManager):
+ """
+ Manages the client-side operations for Split Learning in a Federated Learning setting.
+
+ Args:
+ arg_dict (dict): A dictionary containing necessary arguments.
+ trainer (Trainer): The trainer responsible for the client's model. + backend (str): The communication backend (e.g., "MPI"). + + Attributes: + trainer (Trainer): The trainer responsible for the client's model. + args (args): Arguments for the client manager. + """ def __init__(self, arg_dict, trainer, backend="MPI"): + """ + Initialize a SplitNNClientManager. + + Args: + arg_dict (dict): A dictionary containing necessary arguments. + trainer (Trainer): The trainer responsible for the client's model. + backend (str): The communication backend (e.g., "MPI"). + """ super().__init__( arg_dict["args"], arg_dict["comm"], @@ -19,12 +39,20 @@ def __init__(self, arg_dict, trainer, backend="MPI"): self.args.round_idx = 0 def run(self): + """ + Start the client manager. + + If the trainer's rank is 1, it starts the protocol by running the forward pass. + """ if self.trainer.rank == 1: logging.info("Starting protocol from rank 1 process") self.run_forward_pass() super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for different message types. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2C_SEMAPHORE, self.handle_message_semaphore ) @@ -33,12 +61,23 @@ def register_message_receive_handlers(self): ) def handle_message_semaphore(self, msg_params): + """ + Handle the semaphore message and start the training process. + + Args: + msg_params: Parameters of the semaphore message. + """ # no point in checking the semaphore message logging.info("Starting training at node {}".format(self.trainer.rank)) self.trainer.train_mode() self.run_forward_pass() def run_forward_pass(self): + """ + Run the forward pass of the trainer. + + Sends activations and labels to the server afterward. + """ acts, labels = self.trainer.forward_pass() self.send_activations_and_labels_to_server( acts, labels, self.trainer.SERVER_RANK @@ -46,6 +85,15 @@ def run_forward_pass(self): self.trainer.batch_idx += 1 def run_eval(self): + """ + Run the evaluation process for the client. + + This method sends a validation signal to the server, switches the trainer to evaluation mode, + and performs the forward pass for validation data. After validation, it sends a validation + completion signal to the server and updates the round index. If the maximum number of + epochs per node is reached, it sends a finish signal to the server. + + """ self.send_validation_signal_to_server(self.trainer.SERVER_RANK) self.trainer.eval_mode() for i in range(len(self.trainer.testloader)): @@ -69,6 +117,12 @@ def run_eval(self): self.finish() def handle_message_gradients(self, msg_params): + """ + Handle received gradients and initiate backward pass. + + Args: + msg_params: Parameters of the received gradients message. + """ grads = msg_params.get(MyMessage.MSG_ARG_KEY_GRADS) self.trainer.backward_pass(grads) if self.trainer.batch_idx == len(self.trainer.trainloader): @@ -79,6 +133,14 @@ def handle_message_gradients(self, msg_params): self.run_forward_pass() def send_activations_and_labels_to_server(self, acts, labels, receive_id): + """ + Send activations and labels to the server. + + Args: + acts: Activations to be sent. + labels: Labels corresponding to the activations. + receive_id: ID of the receiving entity (typically, the server). 
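+
+ Note:
+ The activations and labels are packed into a single
+ MSG_TYPE_C2S_SEND_ACTS message and dispatched via send_message().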
+ """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_ACTS, self.get_sender_id(), receive_id ) @@ -86,24 +148,48 @@ def send_activations_and_labels_to_server(self, acts, labels, receive_id): self.send_message(message) def send_semaphore_to_client(self, receive_id): + """ + Send a semaphore message to a client. + + Args: + receive_id: ID of the receiving client. + """ message = Message( MyMessage.MSG_TYPE_C2C_SEMAPHORE, self.get_sender_id(), receive_id ) self.send_message(message) def send_validation_signal_to_server(self, receive_id): + """ + Send a validation signal message to the server. + + Args: + receive_id: ID of the receiving entity (typically, the server). + """ message = Message( MyMessage.MSG_TYPE_C2S_VALIDATION_MODE, self.get_sender_id(), receive_id ) self.send_message(message) def send_validation_over_to_server(self, receive_id): + """ + Send a validation completion signal message to the server. + + Args: + receive_id: ID of the receiving entity (typically, the server). + """ message = Message( MyMessage.MSG_TYPE_C2S_VALIDATION_OVER, self.get_sender_id(), receive_id ) self.send_message(message) def send_finish_to_server(self, receive_id): + """ + Send a finish signal message to the server. + + Args: + receive_id: ID of the receiving entity (typically, the server). + """ message = Message( MyMessage.MSG_TYPE_C2S_PROTOCOL_FINISHED, self.get_sender_id(), receive_id ) diff --git a/python/fedml/simulation/mpi/split_nn/server.py b/python/fedml/simulation/mpi/split_nn/server.py index 1187d16fc4..cf9cff9c17 100644 --- a/python/fedml/simulation/mpi/split_nn/server.py +++ b/python/fedml/simulation/mpi/split_nn/server.py @@ -5,21 +5,39 @@ class SplitNN_server: + """ + SplitNN Server for managing communication and training. + """ + def __init__(self, args): + """ + Initialize the SplitNN Server. + + Args: + args (dict): A dictionary containing configuration arguments. + """ self.comm = args["comm"] self.model = args["model"] self.MAX_RANK = args["max_rank"] self.init_params() def init_params(self): + """ + Initialize training parameters and optimizer. + """ self.epoch = 0 self.log_step = 50 self.active_node = 1 self.train_mode() - self.optimizer = optim.SGD(self.model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4) + self.optimizer = optim.SGD( + self.model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4 + ) self.criterion = nn.CrossEntropyLoss() def reset_local_params(self): + """ + Reset local training parameters. + """ logging.info("reset_local_params") self.total = 0 self.correct = 0 @@ -28,25 +46,38 @@ def reset_local_params(self): self.batch_idx = 0 def train_mode(self): + """ + Switch to training mode. + """ logging.info("train_mode") self.model.train() self.phase = "train" self.reset_local_params() def eval_mode(self): + """ + Switch to evaluation mode. + """ logging.info("eval_mode") self.model.eval() self.phase = "validation" self.reset_local_params() def forward_pass(self, acts, labels): + """ + Perform a forward pass of the model. + + Args: + acts: Activations. + labels: Ground truth labels. 
+ """ logging.info("forward_pass") self.acts = acts self.optimizer.zero_grad() self.acts.retain_grad() logits = self.model(acts) _, predictions = logits.max(1) - self.loss = self.criterion(logits, labels) # pylint: disable=E1102 + self.loss = self.criterion(logits, labels) self.total += labels.size(0) self.correct += predictions.eq(labels).sum().item() if self.step % self.log_step == 0 and self.phase == "train": @@ -61,18 +92,25 @@ def forward_pass(self, acts, labels): self.step += 1 def backward_pass(self): + """ + Perform a backward pass and update model weights. + """ logging.info("backward_pass") self.loss.backward() self.optimizer.step() return self.acts.grad def validation_over(self): + """ + Handle the end of validation and switch to the next node. + """ logging.info("validation_over") - # not precise estimation of validation loss self.val_loss /= self.step acc = self.correct / self.total logging.info( - "phase={} acc={} loss={} epoch={} and step={}".format(self.phase, acc, self.val_loss, self.epoch, self.step) + "phase={} acc={} loss={} epoch={} and step={}".format( + self.phase, acc, self.val_loss, self.epoch, self.step + ) ) self.epoch += 1 diff --git a/python/fedml/simulation/mpi/split_nn/server_manager.py b/python/fedml/simulation/mpi/split_nn/server_manager.py index cd7aa3ad52..683bb1e50f 100644 --- a/python/fedml/simulation/mpi/split_nn/server_manager.py +++ b/python/fedml/simulation/mpi/split_nn/server_manager.py @@ -4,7 +4,18 @@ class SplitNNServerManager(FedMLCommManager): + """ + Manager for the SplitNN server that handles communication. + """ def __init__(self, arg_dict, trainer, backend="MPI"): + """ + Initialize the SplitNNServerManager. + + Args: + arg_dict (dict): A dictionary containing configuration arguments. + trainer: The trainer instance for the server. + backend (str): The communication backend to use (default is "MPI"). + """ super().__init__( arg_dict["args"], arg_dict["comm"], @@ -19,6 +30,9 @@ def run(self): super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for different message types. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_ACTS, self.handle_message_acts ) @@ -34,6 +48,12 @@ def register_message_receive_handlers(self): ) def send_grads_to_client(self, receive_id, grads): + """ + Handle a message containing activations. + + Args: + msg_params (dict): Parameters of the received message. + """ message = Message( MyMessage.MSG_TYPE_S2C_GRADS, self.get_sender_id(), receive_id ) @@ -41,6 +61,12 @@ def send_grads_to_client(self, receive_id, grads): self.send_message(message) def handle_message_acts(self, msg_params): + """ + Handle a message containing activations. + + Args: + msg_params (dict): Parameters of the received message. + """ acts, labels = msg_params.get(MyMessage.MSG_ARG_KEY_ACTS) self.trainer.forward_pass(acts, labels) if self.trainer.phase == "train": @@ -48,10 +74,30 @@ def handle_message_acts(self, msg_params): self.send_grads_to_client(self.trainer.active_node, grads) def handle_message_validation_mode(self, msg_params): + """ + Handle a message indicating validation mode. + + Args: + msg_params (dict): Parameters of the received message. + """ + self.trainer.eval_mode() def handle_message_validation_over(self, msg_params): + """ + Handle a message indicating the end of validation. + + Args: + msg_params (dict): Parameters of the received message. 
+ """ + self.trainer.validation_over() def handle_message_finish_protocol(self): + """ + Handle a message indicating the protocol has finished. + + Args: + msg_params (dict): Parameters of the received message. + """ self.finish() From a8d0246475db4fa960290e458fe9124fcf0729ca Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Thu, 7 Sep 2023 21:28:52 +0530 Subject: [PATCH 46/70] python\fedml\simulation\mpi\async_fedavg python\fedml\simulation\mpi\async_fedavg --- .../mpi/async_fedavg/AsyncFedAVGAggregator.py | 138 ++++++++++++++++++ .../mpi/async_fedavg/AsyncFedAVGTrainer.py | 74 ++++++++++ .../async_fedavg/AsyncFedAvgClientManager.py | 108 ++++++++++++-- .../mpi/async_fedavg/AsyncFedAvgSeqAPI.py | 58 ++++++++ .../async_fedavg/AsyncFedAvgServerManager.py | 93 ++++++++++++ .../mpi/async_fedavg/MyModelTrainer.py | 60 ++++++++ .../mpi/async_fedavg/my_model_trainer.py | 60 ++++++++ .../my_model_trainer_classification.py | 60 ++++++++ .../mpi/async_fedavg/my_model_trainer_nwp.py | 60 ++++++++ .../my_model_trainer_tag_prediction.py | 60 ++++++++ .../simulation/mpi/async_fedavg/utils.py | 29 ++++ 11 files changed, 789 insertions(+), 11 deletions(-) diff --git a/python/fedml/simulation/mpi/async_fedavg/AsyncFedAVGAggregator.py b/python/fedml/simulation/mpi/async_fedavg/AsyncFedAVGAggregator.py index ce87629513..50a14adb20 100644 --- a/python/fedml/simulation/mpi/async_fedavg/AsyncFedAVGAggregator.py +++ b/python/fedml/simulation/mpi/async_fedavg/AsyncFedAVGAggregator.py @@ -12,6 +12,67 @@ from ....core.schedule.runtime_estimate import t_sample_fit class AsyncFedAVGAggregator(object): + """ + Aggregator for the asynchronous Federated Averaging server in a federated learning system. + + Args: + train_global: Global training data. + test_global: Global testing data. + all_train_data_num: Total number of training data samples. + train_data_local_dict: Dictionary containing local training data for each client. + test_data_local_dict: Dictionary containing local testing data for each client. + train_data_local_num_dict: Dictionary containing the number of local training data samples for each client. + worker_num: Number of worker processes. + device: The computing device (e.g., CPU or GPU). + args: Command-line arguments and configurations. + model_trainer: Trainer for the federated learning model. + + Attributes: + trainer: Trainer for the federated learning model. + args: Command-line arguments and configurations. + train_global: Global training data. + test_global: Global testing data. + val_global: Global validation data generated from the global training data. + all_train_data_num: Total number of training data samples. + train_data_local_dict: Dictionary containing local training data for each client. + test_data_local_dict: Dictionary containing local testing data for each client. + train_data_local_num_dict: Dictionary containing the number of local training data samples for each client. + worker_num: Number of worker processes. + device: The computing device (e.g., CPU or GPU). + model_dict: Dictionary containing client models indexed by client ID. + sample_num_dict: Dictionary containing the number of samples trained by each client. + flag_client_model_uploaded_dict: Dictionary tracking whether client models have been uploaded. + runtime_history: Dictionary containing runtime information for clients. + model_weights: Global model weights updated during aggregation. + client_running_status: Array tracking the status of running clients. 
+ + Methods: + get_global_model_params(): + Get the global model parameters. + + set_global_model_params(model_parameters): + Set the global model parameters. + + add_local_trained_result(index, model_params, local_sample_number, + current_round, client_round): + Add the locally trained model results to the aggregator and update the global model. + + client_schedule(round_idx, client_indexes, mode="simulate"): + Generate a schedule for clients based on runtime information. + + get_average_weight(client_indexes): + Calculate the average weight assigned to each client based on the number of training samples. + + client_sampling(round_idx, client_num_in_total, client_num_per_round): + Sample clients for communication in a round. + + _generate_validation_set(num_samples=10000): + Generate a validation set from the global testing data. + + test_on_server_for_all_clients(round_idx): + Perform testing on the server for all clients and log the results. + """ + def __init__( self, train_global, @@ -54,14 +115,42 @@ def __init__( def get_global_model_params(self): + """ + Get the global model parameters. + + Returns: + dict: Global model parameters. + """ # return self.trainer.get_model_params() return self.model_weights def set_global_model_params(self, model_parameters): + """ + Set the global model parameters. + + Args: + model_parameters (dict): Global model parameters to be set. + + Returns: + None + """ self.trainer.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, local_sample_number, current_round, client_round): + """ + Add the locally trained model results to the aggregator and update the global model. + + Args: + index (int): Index of the client. + model_params (dict): Model parameters trained by the client. + local_sample_number (int): Number of local training data samples used by the client. + current_round (int): Current communication round. + client_round (int): Round index for the client. + + Returns: + None + """ logging.info("add_model. index = %d" % index) self.client_running_status = np.setdiff1d(self.client_running_status, @@ -76,6 +165,17 @@ def add_local_trained_result(self, index, model_params, local_sample_number, def client_schedule(self, round_idx, client_indexes, mode="simulate"): + """ + Generate a schedule for clients based on runtime information. + + Args: + round_idx (int): Current communication round. + client_indexes (list): List of client indexes. + mode (str): The scheduling mode ("simulate" or "release"). + + Returns: + list: List of client schedules. + """ self.runtime_history = {} for i in range(self.worker_num): self.runtime_history[i] = {} @@ -91,6 +191,15 @@ def client_schedule(self, round_idx, client_indexes, mode="simulate"): def get_average_weight(self, client_indexes): + """ + Calculate the average weight assigned to each client based on the number of training samples. + + Args: + client_indexes (list): List of client indexes. + + Returns: + dict: A dictionary mapping client indexes to their respective average weights. + """ average_weight_dict = {} training_num = 0 for client_index in client_indexes: @@ -102,6 +211,17 @@ def get_average_weight(self, client_indexes): def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Sample clients for communication in a round. + + Args: + round_idx (int): Current communication round. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to sample per round. 
+ + Returns: + list: List of client indexes selected for communication in the current round. + """ num_clients = min(client_num_per_round, client_num_in_total) np.random.seed( round_idx @@ -116,6 +236,15 @@ def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set from the global testing data. + + Args: + num_samples (int): Number of samples to include in the validation set. + + Returns: + torch.utils.data.DataLoader: DataLoader containing the validation set. + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) sample_indices = random.sample( @@ -130,6 +259,15 @@ def _generate_validation_set(self, num_samples=10000): return self.test_global def test_on_server_for_all_clients(self, round_idx): + """ + Perform testing on the server for all clients and log the results. + + Args: + round_idx (int): Current communication round. + + Returns: + None + """ if self.trainer.test_on_the_server( self.train_data_local_dict, self.test_data_local_dict, diff --git a/python/fedml/simulation/mpi/async_fedavg/AsyncFedAVGTrainer.py b/python/fedml/simulation/mpi/async_fedavg/AsyncFedAVGTrainer.py index 1265fb298a..589480abe7 100644 --- a/python/fedml/simulation/mpi/async_fedavg/AsyncFedAVGTrainer.py +++ b/python/fedml/simulation/mpi/async_fedavg/AsyncFedAVGTrainer.py @@ -2,6 +2,46 @@ class AsyncFedAVGTrainer(object): + """ + An asynchronous Federated Averaging trainer for client nodes in a federated learning system. + + Args: + client_index (int): The index of the client node. + train_data_local_dict (dict): A dictionary containing local training data for each client. + train_data_local_num_dict (dict): A dictionary containing the number of local training samples for each client. + test_data_local_dict (dict): A dictionary containing local testing data for each client. + train_data_num (int): The total number of training samples across all clients. + device (torch.device): The device (e.g., CPU or GPU) to perform training and testing on. + args (argparse.Namespace): Command-line arguments and configurations for training. + model_trainer (ClientTrainer): An instance of a client-side model trainer. + + Attributes: + trainer (ClientTrainer): The model trainer used for training and testing. + client_index (int): The index of the client node. + train_data_local_dict (dict): A dictionary containing local training data for each client. + train_data_local_num_dict (dict): A dictionary containing the number of local training samples for each client. + test_data_local_dict (dict): A dictionary containing local testing data for each client. + all_train_data_num (int): The total number of training samples across all clients. + train_local (Dataset): The local training dataset for the current client. + local_sample_number (int): The number of local training samples for the current client. + test_local (Dataset): The local testing dataset for the current client. + device (torch.device): The device used for training and testing. + args (argparse.Namespace): Command-line arguments and configurations for training. + + Methods: + update_model(weights): + Update the model's weights with the provided weights. + + update_dataset(client_index): + Update the local training and testing datasets for the current client. + + train(round_idx=None): + Train the model on the local training dataset. 
+ + test(): + Test the model on both the local training and testing datasets. + + """ def __init__( self, client_index, @@ -28,15 +68,42 @@ def __init__( self.args = args def update_model(self, weights): + """ + Update the model's weights with the provided weights. + + Args: + weights (dict): The model parameters as a dictionary of tensors. + + Returns: + None + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """ + Update the local training and testing datasets for the current client. + + Args: + client_index (int): The index of the current client. + + Returns: + None + """ self.client_index = client_index self.train_local = self.train_data_local_dict[client_index] self.local_sample_number = self.train_data_local_num_dict[client_index] self.test_local = self.test_data_local_dict[client_index] def train(self, round_idx=None): + """ + Train the model on the local training dataset. + + Args: + round_idx (int, optional): The current round index. Defaults to None. + + Returns: + tuple: A tuple containing the trained model's weights and the number of local training samples. + """ self.args.round_idx = round_idx self.trainer.train(self.train_local, self.device, self.args) @@ -45,6 +112,13 @@ def train(self, round_idx=None): return weights, self.local_sample_number def test(self): + """ + Test the model on both the local training and testing datasets. + + Returns: + tuple: A tuple containing various metrics, including training accuracy, training loss, the number + of training samples, testing accuracy, testing loss, and the number of testing samples. + """ # train data train_metrics = self.trainer.test(self.train_local, self.device, self.args) train_tot_correct, train_num_sample, train_loss = ( diff --git a/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgClientManager.py b/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgClientManager.py index 3e8fb9af07..d97c437f21 100644 --- a/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgClientManager.py +++ b/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgClientManager.py @@ -11,6 +11,45 @@ class AsyncFedAVGClientManager(FedMLCommManager): + """ + Client manager for asynchronous Federated Averaging in a federated learning system. + + Args: + args: Command-line arguments and configurations. + trainer: Trainer for the federated learning model. + comm: Communication backend. + rank: Rank of the client manager. + size: Total number of client managers. + backend: Communication backend type (default: "MPI"). + + Attributes: + trainer: Trainer for the federated learning model. + num_rounds: Total number of communication rounds. + round_idx: Current communication round index. + worker_id: Unique identifier for the client manager. + + Methods: + run(): + Run the client manager. + + register_message_receive_handlers(): + Register message receive handlers for communication. + + handle_message_init(msg_params): + Handle the initialization message from the server. + + start_training(): + Start the training process. + + handle_message_receive_model_from_server(msg_params): + Handle the received model from the server. + + send_result_to_server(receive_id, weights, local_sample_num, client_runtime_info): + Send training results to the server. + + __train(global_model_params, client_index): + Perform model training for a client. + """ def __init__( self, args, @@ -27,9 +66,21 @@ def __init__( self.worker_id = self.rank - 1 def run(self): + """ + Run the communication manager. 
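+
+ Delegates to FedMLCommManager.run(), which starts the underlying
+ message-receiving loop for this client process.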
+ + Returns: + None + """ super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for different message types. + + Returns: + None + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -39,6 +90,15 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): + """ + Handle the initialization message from the server. + + Args: + msg_params (dict): Dictionary of message parameters. + + Returns: + None + """ global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -46,10 +106,25 @@ def handle_message_init(self, msg_params): self.__train(global_model_params, client_index) def start_training(self): + """ + Start the training process. + + Returns: + None + """ self.round_idx = 0 # self.__train() def handle_message_receive_model_from_server(self, msg_params): + """ + Handle the received model from the server. + + Args: + msg_params (dict): Dictionary of message parameters. + + Returns: + None + """ logging.info("handle_message_receive_model_from_server.") global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -62,6 +137,18 @@ def handle_message_receive_model_from_server(self, msg_params): def send_result_to_server(self, receive_id, weights, local_sample_num, client_runtime_info): + """ + Send the training result to the server. + + Args: + receive_id (int): ID of the message receiver. + weights (dict): Model parameters. + local_sample_num (int): Number of local training samples. + client_runtime_info (dict): Dictionary of client runtime information. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), @@ -74,6 +161,16 @@ def send_result_to_server(self, receive_id, weights, local_sample_num, client_ru def __train(self, global_model_params, client_index): + """ + Perform the training process for a client. + + Args: + global_model_params (dict): Global model parameters. + client_index (int): Index of the client. + + Returns: + None + """ logging.info("#######training########### round_id = %d" % self.round_idx) local_agg_model_params = {} @@ -92,14 +189,3 @@ def __train(self, global_model_params, client_index): # diff_weights = get_name_params_difference(global_model_params, weights) # weights - global_model_params self.send_result_to_server(0, weights, local_sample_num, client_runtime_info) - - - - - - - - - - - diff --git a/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgSeqAPI.py b/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgSeqAPI.py index 25c0f9bfd9..bb3424f773 100644 --- a/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgSeqAPI.py +++ b/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgSeqAPI.py @@ -11,6 +11,23 @@ def FedML_Async_distributed( args, process_id, worker_number, comm, device, dataset, model, model_trainer=None, preprocessed_sampling_lists=None, ): + """ + Run the asynchronous federated learning process. + + Args: + args (object): An object containing the configuration parameters. + process_id (int): The unique ID of the current process. + worker_number (int): The total number of worker processes. + comm (object): The communication object. + device (object): The device to run the training on (e.g., GPU). + dataset (list): A list containing dataset-related information. 
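+            The function unpacks this list into the global/local train and
+            test data and their sample counts.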
+ model (object): The federated learning model. + model_trainer (object, optional): The model trainer object. Defaults to None. + preprocessed_sampling_lists (list, optional): Preprocessed sampling lists for clients. Defaults to None. + + Returns: + None + """ [ train_data_num, test_data_num, @@ -75,6 +92,28 @@ def init_server( model_trainer, preprocessed_sampling_lists=None, ): + """ + Initialize the server for asynchronous federated learning. + + Args: + args (object): An object containing the configuration parameters. + device (object): The device to run the training on (e.g., GPU). + comm (object): The communication object. + rank (int): The rank of the current process. + size (int): The total number of processes. + model (object): The federated learning model. + train_data_num (int): The number of training data samples. + train_data_global (object): The global training dataset. + test_data_global (object): The global test dataset. + train_data_local_dict (dict): A dictionary containing local training data for clients. + test_data_local_dict (dict): A dictionary containing local test data for clients. + train_data_local_num_dict (dict): A dictionary containing the number of local training data samples for clients. + model_trainer (object): The model trainer object. + preprocessed_sampling_lists (list, optional): Preprocessed sampling lists for clients. Defaults to None. + + Returns: + None + """ if model_trainer is None: model_trainer = create_model_trainer(model, args) model_trainer.set_id(-1) @@ -126,6 +165,25 @@ def init_client( test_data_local_dict, model_trainer=None, ): + """ + Initialize a client for asynchronous federated learning. + + Args: + args (object): An object containing the configuration parameters. + device (object): The device to run the training on (e.g., GPU). + comm (object): The communication object. + process_id (int): The unique ID of the current process. + size (int): The total number of processes. + model (object): The federated learning model. + train_data_num (int): The number of training data samples. + train_data_local_num_dict (dict): A dictionary containing the number of local training data samples for clients. + train_data_local_dict (dict): A dictionary containing local training data for clients. + test_data_local_dict (dict): A dictionary containing local test data for clients. + model_trainer (object, optional): The model trainer object. Defaults to None. + + Returns: + None + """ client_index = process_id - 1 if model_trainer is None: model_trainer = create_model_trainer(model, args) diff --git a/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgServerManager.py b/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgServerManager.py index 2a456df3e5..d7f89afdab 100644 --- a/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgServerManager.py +++ b/python/fedml/simulation/mpi/async_fedavg/AsyncFedAvgServerManager.py @@ -8,6 +8,49 @@ class AsyncFedAVGServerManager(FedMLCommManager): + """ + Manager for the asynchronous Federated Averaging server in a federated learning system. + + Args: + args (argparse.Namespace): Command-line arguments and configurations for the server. + aggregator: An instance of the aggregator responsible for aggregating client updates. + comm: The communication object for inter-process communication. + rank (int): The rank of the server process. + size (int): The total number of processes. + backend (str): The communication backend (e.g., "MPI"). + is_preprocessed (bool): Indicates whether the data is preprocessed. 
+ preprocessed_client_lists (list): A list of preprocessed client data. + + Attributes: + args (argparse.Namespace): Command-line arguments and configurations for the server. + aggregator: An instance of the aggregator responsible for aggregating client updates. + round_num (int): The total number of communication rounds. + round_idx (int): The current round index. + is_preprocessed (bool): Indicates whether the data is preprocessed. + preprocessed_client_lists (list): A list of preprocessed client data. + client_round_dict (dict): A dictionary to track the round index for each client. + + Methods: + run(): + Start the server and begin the federated learning process. + + send_init_msg(): + Send initialization messages to client processes to start communication. + + register_message_receive_handlers(): + Register message handlers for receiving client updates. + + handle_message_receive_model_from_client(msg_params): + Handle the received client update message, record client runtime information, + aggregate the updates, and perform testing. + + send_message_init_config(receive_id, global_model_params, client_index): + Send initialization configuration messages to clients. + + send_message_sync_model_to_client(receive_id, global_model_params, client_index): + Send synchronized model updates to clients. + + """ def __init__( self, args, @@ -32,10 +75,22 @@ def __init__( def run(self): + """ + Start the server and begin the federated learning process. + + Returns: + None + """ super().run() def send_init_msg(self): + """ + Send initialization messages to client processes to start communication. + + Returns: + None + """ # sampling clients # client_indexes = self.aggregator.client_sampling( # self.round_idx, @@ -54,12 +109,28 @@ def send_init_msg(self): def register_message_receive_handlers(self): + """ + Register message handlers for receiving client updates. + + Returns: + None + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.handle_message_receive_model_from_client, ) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle the received client update message, record client runtime information, + aggregate the updates, and perform testing. + + Args: + msg_params (dict): Message parameters containing client update information. + + Returns: + None + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) @@ -107,6 +178,17 @@ def handle_message_receive_model_from_client(self, msg_params): def send_message_init_config(self, receive_id, global_model_params, client_index): + """ + Send initialization configuration messages to clients. + + Args: + receive_id (int): The ID of the receiving client. + global_model_params (dict): Global model parameters to be sent to clients. + client_index (list): List of client indexes for the current communication round. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) @@ -117,6 +199,17 @@ def send_message_init_config(self, receive_id, global_model_params, def send_message_sync_model_to_client(self, receive_id, global_model_params, client_index): + """ + Send synchronized model updates to clients. + + Args: + receive_id (int): The ID of the receiving client. + global_model_params (dict): Global model parameters to be sent to clients. 
+ client_index (list): List of client indexes for the current communication round. + + Returns: + None + """ logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id) message = Message( MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, diff --git a/python/fedml/simulation/mpi/async_fedavg/MyModelTrainer.py b/python/fedml/simulation/mpi/async_fedavg/MyModelTrainer.py index 008582c6d3..b0cd08e362 100644 --- a/python/fedml/simulation/mpi/async_fedavg/MyModelTrainer.py +++ b/python/fedml/simulation/mpi/async_fedavg/MyModelTrainer.py @@ -7,13 +7,50 @@ class MyModelTrainer(ClientTrainer): + """ + Custom client model trainer for federated learning. + + Args: + None + + Methods: + get_model_params(): Get the model parameters. + set_model_params(model_parameters): Set the model parameters. + train(train_data, device, args): Train the model on the client. + test(test_data, device, args): Test the model on the client. + test_on_the_server(train_data_local_dict, test_data_local_dict, device, args=None): + Test the model on the server (not implemented in this class). + """ def get_model_params(self): + """ + Get the model parameters. + + Returns: + dict: The model parameters as a dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (dict): A dictionary containing the model parameters to set. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model on the client. + + Args: + train_data (torch.utils.data.DataLoader): DataLoader containing training data. + device (torch.device): The device (CPU or GPU) to train on. + args: Additional training arguments. + + Returns: + None + """ model = self.model model.to(device) @@ -50,6 +87,17 @@ def train(self, train_data, device, args): ) def test(self, test_data, device, args): + """ + Test the model on the client. + + Args: + test_data (torch.utils.data.DataLoader): DataLoader containing test data. + device (torch.device): The device (CPU or GPU) to test on. + args: Additional testing arguments. + + Returns: + dict: A dictionary containing test metrics. + """ model = self.model model.eval() @@ -91,4 +139,16 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Test the model on the server (not implemented in this class). + + Args: + train_data_local_dict (dict): Dictionary containing local training data. + test_data_local_dict (dict): Dictionary containing local test data. + device (torch.device): The device (CPU or GPU) to test on. + args: Additional testing arguments. + + Returns: + bool: Always returns False in this implementation. + """ return False diff --git a/python/fedml/simulation/mpi/async_fedavg/my_model_trainer.py b/python/fedml/simulation/mpi/async_fedavg/my_model_trainer.py index 7ed01c3703..5941065ed7 100644 --- a/python/fedml/simulation/mpi/async_fedavg/my_model_trainer.py +++ b/python/fedml/simulation/mpi/async_fedavg/my_model_trainer.py @@ -6,13 +6,50 @@ class MyModelTrainer(ClientTrainer): + """ + Custom client model trainer for federated learning. + + Args: + None + + Methods: + get_model_params(): Get the model parameters. + set_model_params(model_parameters): Set the model parameters. + train(train_data, device, args): Train the model on the client. + test(test_data, device, args): Test the model on the client. 
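+            Returns a dict of test metrics (e.g., correct predictions, loss,
+            and sample count).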
+ test_on_the_server(train_data_local_dict, test_data_local_dict, device, args=None): + Test the model on the server (not implemented in this class). + """ def get_model_params(self): + """ + Get the model parameters. + + Returns: + dict: The model parameters as a dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (dict): A dictionary containing the model parameters to set. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model on the client. + + Args: + train_data (torch.utils.data.DataLoader): DataLoader containing training data. + device (torch.device): The device (CPU or GPU) to train on. + args: Additional training arguments. + + Returns: + None + """ model = self.model model.to(device) @@ -61,6 +98,17 @@ def train(self, train_data, device, args): # self.client_idx, epoch, sum(epoch_loss) / len(epoch_loss))) def test(self, test_data, device, args): + """ + Test the model on the client. + + Args: + test_data (torch.utils.data.DataLoader): DataLoader containing test data. + device (torch.device): The device (CPU or GPU) to test on. + args: Additional testing arguments. + + Returns: + dict: A dictionary containing test metrics. + """ model = self.model model.to(device) @@ -115,4 +163,16 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Test the model on the server (not implemented in this class). + + Args: + train_data_local_dict (dict): Dictionary containing local training data. + test_data_local_dict (dict): Dictionary containing local test data. + device (torch.device): The device (CPU or GPU) to test on. + args: Additional testing arguments. + + Returns: + bool: Always returns False in this implementation. + """ return False diff --git a/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_classification.py b/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_classification.py index 8aa72effea..ec9aeca3c8 100644 --- a/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_classification.py +++ b/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_classification.py @@ -6,13 +6,50 @@ class MyModelTrainer(ClientTrainer): + """ + Custom client model trainer for federated learning. + + Args: + None + + Methods: + get_model_params(): Get the model parameters. + set_model_params(model_parameters): Set the model parameters. + train(train_data, device, args): Train the model on the client. + test(test_data, device, args): Test the model on the client. + test_on_the_server(train_data_local_dict, test_data_local_dict, device, args=None): + Test the model on the server (not implemented in this class). + """ def get_model_params(self): + """ + Get the model parameters. + + Returns: + dict: The model parameters as a dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (dict): A dictionary containing the model parameters to set. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model on the client. + + Args: + train_data (torch.utils.data.DataLoader): DataLoader containing training data. + device (torch.device): The device (CPU or GPU) to train on. + args: Additional training arguments. 
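+                Typically includes the number of local epochs, the learning
+                rate, and the optimizer choice.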
+ + Returns: + None + """ model = self.model model.to(device) @@ -65,6 +102,17 @@ def train(self, train_data, device, args): ) def test(self, test_data, device, args): + """ + Test the model on the client. + + Args: + test_data (torch.utils.data.DataLoader): DataLoader containing test data. + device (torch.device): The device (CPU or GPU) to test on. + args: Additional testing arguments. + + Returns: + dict: A dictionary containing test metrics. + """ model = self.model model.to(device) @@ -92,4 +140,16 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Test the model on the server (not implemented in this class). + + Args: + train_data_local_dict (dict): Dictionary containing local training data. + test_data_local_dict (dict): Dictionary containing local test data. + device (torch.device): The device (CPU or GPU) to test on. + args: Additional testing arguments. + + Returns: + bool: Always returns False in this implementation. + """ return False diff --git a/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_nwp.py b/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_nwp.py index 08d9b13f65..faf2f1690f 100644 --- a/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_nwp.py +++ b/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_nwp.py @@ -5,13 +5,50 @@ class MyModelTrainer(ClientTrainer): + """ + Custom client model trainer for federated learning. + + Args: + None + + Methods: + get_model_params(): Get the model parameters. + set_model_params(model_parameters): Set the model parameters. + train(train_data, device, args): Train the model on the client. + test(test_data, device, args): Test the model on the client. + test_on_the_server(train_data_local_dict, test_data_local_dict, device, args=None): + Test the model on the server (not implemented in this class). + """ def get_model_params(self): + """ + Get the model parameters. + + Returns: + dict: The model parameters as a dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (dict): A dictionary containing the model parameters to set. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model on the client. + + Args: + train_data (torch.utils.data.DataLoader): DataLoader containing training data. + device (torch.device): The device (CPU or GPU) to train on. + args: Additional training arguments. + + Returns: + None + """ model = self.model model.to(device) @@ -56,6 +93,17 @@ def train(self, train_data, device, args): # self.client_idx, epoch, sum(epoch_loss) / len(epoch_loss))) def test(self, test_data, device, args): + """ + Test the model on the client. + + Args: + test_data (torch.utils.data.DataLoader): DataLoader containing test data. + device (torch.device): The device (CPU or GPU) to test on. + args: Additional testing arguments. + + Returns: + dict: A dictionary containing test metrics. + """ model = self.model model.to(device) @@ -83,4 +131,16 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Test the model on the server (not implemented in this class). + + Args: + train_data_local_dict (dict): Dictionary containing local training data. + test_data_local_dict (dict): Dictionary containing local test data. 
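+                Both dictionaries are keyed by client index.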
+ device (torch.device): The device (CPU or GPU) to test on. + args: Additional testing arguments. + + Returns: + bool: Always returns False in this implementation. + """ return False diff --git a/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_tag_prediction.py b/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_tag_prediction.py index 50539b3ea5..08fd44d0ce 100644 --- a/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_tag_prediction.py +++ b/python/fedml/simulation/mpi/async_fedavg/my_model_trainer_tag_prediction.py @@ -5,13 +5,50 @@ class MyModelTrainer(ClientTrainer): + """ + Custom client model trainer for federated learning. + + Args: + None + + Methods: + get_model_params(): Get the model parameters. + set_model_params(model_parameters): Set the model parameters. + train(train_data, device, args): Train the model on the client. + test(test_data, device, args): Test the model on the client. + test_on_the_server(train_data_local_dict, test_data_local_dict, device, args=None): + Test the model on the server (not implemented in this class). + """ def get_model_params(self): + """ + Get the model parameters. + + Returns: + dict: The model parameters as a dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (dict): A dictionary containing the model parameters to set. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model on the client. + + Args: + train_data (torch.utils.data.DataLoader): DataLoader containing training data. + device (torch.device): The device (CPU or GPU) to train on. + args: Additional training arguments. + + Returns: + None + """ model = self.model model.to(device) @@ -56,6 +93,17 @@ def train(self, train_data, device, args): # self.client_idx, epoch, sum(epoch_loss) / len(epoch_loss))) def test(self, test_data, device, args): + """ + Test the model on the client. + + Args: + test_data (torch.utils.data.DataLoader): DataLoader containing test data. + device (torch.device): The device (CPU or GPU) to test on. + args: Additional testing arguments. + + Returns: + dict: A dictionary containing test metrics. + """ model = self.model model.to(device) @@ -100,4 +148,16 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Test the model on the server (not implemented in this class). + + Args: + train_data_local_dict (dict): Dictionary containing local training data. + test_data_local_dict (dict): Dictionary containing local test data. + device (torch.device): The device (CPU or GPU) to test on. + args: Additional testing arguments. + + Returns: + bool: Always returns False in this implementation. + """ return False diff --git a/python/fedml/simulation/mpi/async_fedavg/utils.py b/python/fedml/simulation/mpi/async_fedavg/utils.py index aea2449590..63aa625d5f 100644 --- a/python/fedml/simulation/mpi/async_fedavg/utils.py +++ b/python/fedml/simulation/mpi/async_fedavg/utils.py @@ -5,6 +5,16 @@ def transform_list_to_tensor(model_params_list): + """ + Transform a dictionary of model parameters from NumPy arrays to PyTorch tensors. + + Args: + model_params_list (dict): A dictionary of model parameters with keys as parameter names + and values as NumPy arrays. + + Returns: + dict: A dictionary of model parameters with the same keys, but values as PyTorch tensors. 
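+
+    Example:
+        >>> params = {"weight": [[0.1, 0.2], [0.3, 0.4]]}  # illustrative values
+        >>> tensors = transform_list_to_tensor(params)  # converts in place and returns the dict
+        >>> isinstance(tensors["weight"], torch.Tensor)
+        True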
+    """
     for k in model_params_list.keys():
         model_params_list[k] = torch.from_numpy(
             np.asarray(model_params_list[k])
@@ -13,12 +23,31 @@


 def transform_tensor_to_list(model_params):
+    """
+    Transform a dictionary of model parameters from PyTorch tensors to nested Python lists.
+
+    Args:
+        model_params (dict): A dictionary of model parameters with keys as parameter names
+                             and values as PyTorch tensors.
+
+    Returns:
+        dict: A dictionary of model parameters with the same keys, but values converted to
+              nested Python lists via Tensor.detach().numpy().tolist().
+    """
     for k in model_params.keys():
         model_params[k] = model_params[k].detach().numpy().tolist()
     return model_params


 def post_complete_message_to_sweep_process(args):
+    """
+    Post a completion message to a sweep process using a named pipe.
+
+    Args:
+        args: The run configuration to include in the completion message.
+
+    Returns:
+        None
+    """
     pipe_path = "./tmp/fedml"
     os.system("mkdir ./tmp/; touch ./tmp/fedml")
     if not os.path.exists(pipe_path):

From c755ec9729e6501ce3aea2b7d68ca1dbccfd21c6 Mon Sep 17 00:00:00 2001
From: rajveer43 
Date: Fri, 8 Sep 2023 12:16:34 +0530
Subject: [PATCH 47/70] python\fedml\simulation\mpi base framework
 classical_vertical_fl decentralized framework fedavg

---
 .../mpi/base_framework/algorithm_api.py       |  44 +++++++
 .../mpi/base_framework/central_manager.py     |  60 ++++++++-
 .../mpi/base_framework/central_worker.py      |  51 ++++++++
 .../mpi/base_framework/client_manager.py      |  93 ++++++++++++++
 .../mpi/base_framework/client_worker.py       |  44 ++++++-
 .../mpi/base_framework/message_define.py      |   2 +-
 .../classical_vertical_fl/guest_manager.py    |  80 ++++++++++++
 .../classical_vertical_fl/guest_trainer.py    | 120 ++++++++++++++++++
 .../mpi/classical_vertical_fl/host_manager.py |  34 ++++-
 .../mpi/classical_vertical_fl/host_trainer.py |  54 +++++++-
 .../decentralized_framework/algorithm_api.py  |  18 ++-
 .../decentralized_worker.py                   |  29 +++++
 .../decentralized_worker_manager.py           |  57 ++++++++-
 .../simulation/mpi/fedavg/FedAVGAggregator.py |  85 ++++++++++++-
 .../simulation/mpi/fedavg/FedAVGTrainer.py    |  66 ++++++++++
 .../fedml/simulation/mpi/fedavg/FedAvgAPI.py  |  49 +++++++
 .../mpi/fedavg/FedAvgClientManager.py         |  49 ++++++-
 .../mpi/fedavg/FedAvgServerManager.py         |  87 ++++++++++++-
 python/fedml/simulation/mpi/fedavg/utils.py   |  30 ++++-
 19 files changed, 1023 insertions(+), 29 deletions(-)

diff --git a/python/fedml/simulation/mpi/base_framework/algorithm_api.py b/python/fedml/simulation/mpi/base_framework/algorithm_api.py
index 5df90e0187..65eb2023dc 100644
--- a/python/fedml/simulation/mpi/base_framework/algorithm_api.py
+++ b/python/fedml/simulation/mpi/base_framework/algorithm_api.py
@@ -7,6 +7,14 @@


 def FedML_init():
+    """
+    Initialize the MPI communication and retrieve process information.
+
+    Returns:
+        comm (object): MPI communication object.
+        process_id (int): Unique ID of the current process.
+        worker_number (int): Total number of worker processes.
+    """
     comm = MPI.COMM_WORLD
     process_id = comm.Get_rank()
     worker_number = comm.Get_size()
@@ -14,6 +22,18 @@


 def FedML_Base_distributed(args, process_id, worker_number, comm):
+    """
+    Run the base distributed federated learning process.
+
+    Args:
+        args (object): An object containing the configuration parameters.
+        process_id (int): Unique ID of the current process.
+        worker_number (int): Total number of worker processes.
+        comm (object): MPI communication object.
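+            This is the communicator returned by FedML_init(), i.e., MPI.COMM_WORLD.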
+ + Returns: + None + """ if process_id == 0: init_central_worker(args, comm, process_id, worker_number) else: @@ -21,6 +41,18 @@ def FedML_Base_distributed(args, process_id, worker_number, comm): def init_central_worker(args, comm, process_id, size): + """ + Initialize the central worker for distributed federated learning. + + Args: + args (object): An object containing the configuration parameters. + comm (object): MPI communication object. + process_id (int): Unique ID of the current process. + size (int): Total number of processes. + + Returns: + None + """ # aggregator client_num = size - 1 aggregator = BaseCentralWorker(client_num, args) @@ -31,6 +63,18 @@ def init_central_worker(args, comm, process_id, size): def init_client_worker(args, comm, process_id, size): + """ + Initialize a client worker for distributed federated learning. + + Args: + args (object): An object containing the configuration parameters. + comm (object): MPI communication object. + process_id (int): Unique ID of the current process. + size (int): Total number of processes. + + Returns: + None + """ # trainer client_ID = process_id - 1 trainer = BaseClientWorker(client_ID) diff --git a/python/fedml/simulation/mpi/base_framework/central_manager.py b/python/fedml/simulation/mpi/base_framework/central_manager.py index dc192a187c..4ccaaaafe1 100644 --- a/python/fedml/simulation/mpi/base_framework/central_manager.py +++ b/python/fedml/simulation/mpi/base_framework/central_manager.py @@ -7,6 +7,19 @@ class BaseCentralManager(FedMLCommManager): def __init__(self, args, comm, rank, size, aggregator): + """ + Initialize the BaseCentralManager. + + Args: + args (object): An object containing configuration parameters. + comm (object): MPI communication object. + rank (int): The rank of the current process. + size (int): The total number of processes. + aggregator (object): The aggregator for aggregating results. + + Returns: + None + """ super().__init__(args, comm, rank, size) self.aggregator = aggregator @@ -14,17 +27,40 @@ def __init__(self, args, comm, rank, size, aggregator): self.args.round_idx = 0 def run(self): + """ + Run the central manager. + + This method initiates the communication with client processes and aggregates their results. + + Returns: + None + """ for process_id in range(1, self.size): self.send_message_init_config(process_id) super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for the central manager. + + Returns: + None + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_INFORMATION, self.handle_message_receive_model_from_client, ) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle messages received from client processes. + + Args: + msg_params (dict): A dictionary containing message parameters. 
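+                The sender id and the client's local result are read from
+                MSG_ARG_KEY_SENDER and MSG_ARG_KEY_INFORMATION.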
+ + Returns: + None + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) client_local_result = msg_params.get(MyMessage.MSG_ARG_KEY_INFORMATION) @@ -34,11 +70,12 @@ def handle_message_receive_model_from_client(self, msg_params): logging.info("b_all_received = " + str(b_all_received)) if b_all_received: logging.info( - "**********************************ROUND INDEX = " + str(self.args.round_idx) + "**********************************ROUND INDEX = " + + str(self.args.round_idx) ) global_result = self.aggregator.aggregate() - # start the next round + # Start the next round self.args.round_idx += 1 if self.args.round_idx == self.round_num: self.finish() @@ -48,12 +85,31 @@ def handle_message_receive_model_from_client(self, msg_params): self.send_message_to_client(receiver_id, global_result) def send_message_init_config(self, receive_id): + """ + Send initialization configuration message to a client process. + + Args: + receive_id (int): The ID of the receiving client process. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) self.send_message(message) def send_message_to_client(self, receive_id, global_result): + """ + Send a message to a client process containing global results. + + Args: + receive_id (int): The ID of the receiving client process. + global_result (object): The global result to be sent. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_S2C_INFORMATION, self.get_sender_id(), receive_id ) diff --git a/python/fedml/simulation/mpi/base_framework/central_worker.py b/python/fedml/simulation/mpi/base_framework/central_worker.py index c9a862251b..982d90932a 100644 --- a/python/fedml/simulation/mpi/base_framework/central_worker.py +++ b/python/fedml/simulation/mpi/base_framework/central_worker.py @@ -2,7 +2,36 @@ class BaseCentralWorker(object): + """ + Base class representing a central worker in a distributed system. + + This class is responsible for managing client local results and aggregating them. + + Attributes: + client_num (int): The number of client processes. + args (object): An object containing configuration parameters. + client_local_result_list (dict): A dictionary to store client local results. + flag_client_model_uploaded_dict (dict): A dictionary to track whether each client has uploaded results. + + Methods: + add_client_local_result(index, client_local_result): + Add client's local result to the worker. + check_whether_all_receive(): + Check if all clients have uploaded their local results. + aggregate(): + Aggregate client local results. + """ def __init__(self, client_num, args): + """ + Initialize the BaseCentralWorker. + + Args: + client_num (int): The number of client processes. + args (object): An object containing configuration parameters. + + Returns: + None + """ self.client_num = client_num self.args = args @@ -13,11 +42,27 @@ def __init__(self, client_num, args): self.flag_client_model_uploaded_dict[idx] = False def add_client_local_result(self, index, client_local_result): + """ + Add client's local result to the worker. + + Args: + index (int): The index of the client. + client_local_result (object): The local result from the client. + + Returns: + None + """ logging.info("add_client_local_result. index = %d" % index) self.client_local_result_list[index] = client_local_result self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all clients have uploaded their local results. 
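+
+        Once all results are in, the per-client upload flags are reset so that
+        the next round starts from a clean state.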
+ + Returns: + bool: True if all clients have uploaded their results, False otherwise. + """ for idx in range(self.client_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -26,6 +71,12 @@ def check_whether_all_receive(self): return True def aggregate(self): + """ + Aggregate client local results. + + Returns: + object: The aggregated global result. + """ global_result = 0 for k in self.client_local_result_list.keys(): global_result += self.client_local_result_list[k] diff --git a/python/fedml/simulation/mpi/base_framework/client_manager.py b/python/fedml/simulation/mpi/base_framework/client_manager.py index 00e5fc12ca..2ff3436cdb 100644 --- a/python/fedml/simulation/mpi/base_framework/client_manager.py +++ b/python/fedml/simulation/mpi/base_framework/client_manager.py @@ -4,16 +4,72 @@ class BaseClientManager(FedMLCommManager): + """ + Base class representing a client manager in a distributed system. + + This class handles the communication between clients and the central server. + + Attributes: + args (object): An object containing configuration parameters. + comm (object): A communication object for MPI communication. + rank (int): The rank of the current process. + size (int): The total number of processes. + trainer (object): An object responsible for client-side training. + num_rounds (int): The total number of communication rounds. + + Methods: + run(): + Start the client manager. + handle_message_init(msg_params): + Handle initialization message from the server. + handle_message_receive_model_from_server(msg_params): + Handle receiving model update from the server. + send_model_to_server(receive_id, client_gradient): + Send client-side model updates to the server. + __train(): + Perform training and send updates to the server. + """ def __init__(self, args, comm, rank, size, trainer): + """ + Initialize the BaseClientManager. + + Args: + args (object): An object containing configuration parameters. + comm (object): A communication object for MPI communication. + rank (int): The rank of the current process. + size (int): The total number of processes. + trainer (object): An object responsible for client-side training. + + Returns: + None + """ super().__init__(args, comm, rank, size) self.trainer = trainer self.num_rounds = args.comm_round self.args.round_idx = 0 def run(self): + """ + Start the client manager. + + Args: + None + + Returns: + None + """ super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for different message types. + + Args: + None + + Returns: + None + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -23,11 +79,29 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): + """ + Handle initialization message from the server. + + Args: + msg_params (dict): Parameters included in the message. + + Returns: + None + """ self.trainer.update(0) self.args.round_idx = 0 self.__train() def handle_message_receive_model_from_server(self, msg_params): + """ + Handle receiving model update from the server. + + Args: + msg_params (dict): Parameters included in the message. 
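+                The aggregated global result is read from MSG_ARG_KEY_INFORMATION.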
+ + Returns: + None + """ global_result = msg_params.get(MyMessage.MSG_ARG_KEY_INFORMATION) self.trainer.update(global_result) self.args.round_idx += 1 @@ -36,6 +110,16 @@ def handle_message_receive_model_from_server(self, msg_params): self.finish() def send_model_to_server(self, receive_id, client_gradient): + """ + Send client-side model updates to the server. + + Args: + receive_id (int): The ID of the recipient (server). + client_gradient (object): The client-side model update. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_C2S_INFORMATION, self.get_sender_id(), receive_id ) @@ -43,6 +127,15 @@ def send_model_to_server(self, receive_id, client_gradient): self.send_message(message) def __train(self): + """ + Perform training and send updates to the server. + + Args: + None + + Returns: + None + """ # do something here (e.g., training) training_interation_result = self.trainer.train() diff --git a/python/fedml/simulation/mpi/base_framework/client_worker.py b/python/fedml/simulation/mpi/base_framework/client_worker.py index bf607d8beb..f8673b7178 100644 --- a/python/fedml/simulation/mpi/base_framework/client_worker.py +++ b/python/fedml/simulation/mpi/base_framework/client_worker.py @@ -1,12 +1,54 @@ class BaseClientWorker(object): + """ + Base class representing a client worker in a distributed system. + + This class is responsible for client-side operations, such as training and updating information. + + Attributes: + client_index (int): The index of the client worker. + updated_information (int): Information that can be updated during training. + + Methods: + update(updated_information): + Update the information associated with the client. + train(): + Perform client-specific training or operation. + + """ + def __init__(self, client_index): + """ + Initialize the BaseClientWorker. + + Args: + client_index (int): The index of the client worker. + + Returns: + None + """ self.client_index = client_index self.updated_information = 0 def update(self, updated_information): + """ + Update the information associated with the client. + + Args: + updated_information (int): The new information to be associated with the client. + + Returns: + None + """ self.updated_information = updated_information print(self.updated_information) def train(self): - # complete your own algorithm operation here, as am example, we return the client_index + """ + Perform client-specific training or operation. + + Returns: + int: An example result (client_index in this case). + """ + # Complete your own algorithm operation here. + # As an example, we return the client_index. 
return self.client_index diff --git a/python/fedml/simulation/mpi/base_framework/message_define.py b/python/fedml/simulation/mpi/base_framework/message_define.py index 27ba9f14d4..b8b52f79d5 100644 --- a/python/fedml/simulation/mpi/base_framework/message_define.py +++ b/python/fedml/simulation/mpi/base_framework/message_define.py @@ -15,6 +15,6 @@ class MyMessage(object): MSG_ARG_KEY_RECEIVER = "receiver" """ - message payload keywords definition + message payload keywords definition """ MSG_ARG_KEY_INFORMATION = "information" diff --git a/python/fedml/simulation/mpi/classical_vertical_fl/guest_manager.py b/python/fedml/simulation/mpi/classical_vertical_fl/guest_manager.py index fbcfafef43..b3f1c03469 100644 --- a/python/fedml/simulation/mpi/classical_vertical_fl/guest_manager.py +++ b/python/fedml/simulation/mpi/classical_vertical_fl/guest_manager.py @@ -4,7 +4,47 @@ class GuestManager(FedMLCommManager): + """ + Class representing the manager for a guest in a distributed system. + + This class is responsible for handling communication between the guest and other participants, + as well as coordinating training rounds. + + Attributes: + args: Arguments for the manager. + comm: The communication interface. + rank: The rank of the guest in the communication group. + size: The total number of participants in the communication group. + guest_trainer: The trainer responsible for guest-specific training. + + Methods: + run(): + Start the guest manager and run communication. + register_message_receive_handlers(): + Register message receive handlers for handling incoming messages. + handle_message_receive_logits_from_client(msg_params): + Handle the reception of logits and trigger training when all data is received. + send_message_init_config(receive_id): + Send an initialization message to a client. + send_message_to_client(receive_id, global_result): + Send a message containing global training results to a client. + + """ + def __init__(self, args, comm, rank, size, guest_trainer): + """ + Initialize the GuestManager. + + Args: + args: Arguments for the manager. + comm: The communication interface. + rank: The rank of the guest in the communication group. + size: The total number of participants in the communication group. + guest_trainer: The trainer responsible for guest-specific training. + + Returns: + None + """ super().__init__(args, comm, rank, size) self.guest_trainer = guest_trainer @@ -12,17 +52,38 @@ def __init__(self, args, comm, rank, size, guest_trainer): self.args.round_idx = 0 def run(self): + """ + Start the guest manager and run communication. + + Returns: + None + """ for process_id in range(1, self.size): self.send_message_init_config(process_id) super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for handling incoming messages. + + Returns: + None + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_LOGITS, self.handle_message_receive_logits_from_client, ) def handle_message_receive_logits_from_client(self, msg_params): + """ + Handle the reception of logits and trigger training when all data is received. + + Args: + msg_params: Parameters of the received message. 
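+                Carries the host's train and test logits under
+                MSG_ARG_KEY_TRAIN_LOGITS and MSG_ARG_KEY_TEST_LOGITS.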
+ + Returns: + None + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) host_train_logits = msg_params.get(MyMessage.MSG_ARG_KEY_TRAIN_LOGITS) host_test_logits = msg_params.get(MyMessage.MSG_ARG_KEY_TEST_LOGITS) @@ -44,12 +105,31 @@ def handle_message_receive_logits_from_client(self, msg_params): self.finish() def send_message_init_config(self, receive_id): + """ + Send an initialization message to a client. + + Args: + receive_id: The ID of the client to receive the message. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) self.send_message(message) def send_message_to_client(self, receive_id, global_result): + """ + Send a message containing global training results to a client. + + Args: + receive_id: The ID of the client to receive the message. + global_result: The global training result to send. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_S2C_GRADIENT, self.get_sender_id(), receive_id ) diff --git a/python/fedml/simulation/mpi/classical_vertical_fl/guest_trainer.py b/python/fedml/simulation/mpi/classical_vertical_fl/guest_trainer.py index 9be9ee9bc2..a8610b46ce 100644 --- a/python/fedml/simulation/mpi/classical_vertical_fl/guest_trainer.py +++ b/python/fedml/simulation/mpi/classical_vertical_fl/guest_trainer.py @@ -8,6 +8,43 @@ class GuestTrainer(object): + """ + Class representing the trainer for a guest in a distributed system. + + This class handles training and gradient aggregation for the guest. + + Attributes: + client_num: The number of clients in the system. + device: The device (e.g., CPU or GPU) used for training. + X_train: The training data features. + y_train: The training data labels. + X_test: The test data features. + y_test: The test data labels. + model_feature_extractor: The feature extractor model. + model_classifier: The classifier model. + args: Arguments for the trainer. + + Methods: + get_batch_num(): + Get the number of batches for training. + add_client_local_result(index, host_train_logits, host_test_logits): + Add client local results to the trainer. + check_whether_all_receive(): + Check if all client local results have been received. + train(round_idx): + Perform training for a round and return gradients to hosts. + _bp_classifier(x, grads): + Backpropagate gradients through the classifier. + _bp_feature_extractor(x, grads): + Backpropagate gradients through the feature extractor. + _test(round_idx): + Perform testing and calculate evaluation metrics. + _sigmoid(x): + Compute the sigmoid function. + _compute_correct_prediction(y_targets, y_prob_preds, threshold): + Compute correct predictions and evaluation statistics. + + """ def __init__( self, client_num, @@ -72,15 +109,38 @@ def __init__( self.loss_list = list() def get_batch_num(self): + """ + Get the number of batches for training. + + Returns: + int: The number of batches. + """ return self.n_batches def add_client_local_result(self, index, host_train_logits, host_test_logits): + """ + Add client local results to the trainer. + + Args: + index: The index of the client. + host_train_logits: Logits from the client's local training data. + host_test_logits: Logits from the client's local test data. + + Returns: + None + """ # logging.info("add_client_local_result. 
index = %d" % index) self.host_local_train_logits_list[index] = host_train_logits self.host_local_test_logits_list[index] = host_test_logits self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all client local results have been received. + + Returns: + bool: True if all results have been received, False otherwise. + """ for idx in range(self.client_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -89,6 +149,15 @@ def check_whether_all_receive(self): return True def train(self, round_idx): + """ + Perform training for a round and return gradients to hosts. + + Args: + round_idx: The index of the training round. + + Returns: + ndarray: Gradients to hosts. + """ batch_x = self.X_train[ self.batch_idx * self.batch_size : self.batch_idx * self.batch_size + self.batch_size @@ -137,6 +206,17 @@ def train(self, round_idx): return gradients_to_hosts def _bp_classifier(self, x, grads): + """ + Backpropagate gradients through the classifier. + + Args: + x: Input data. + grads: Gradients to be backpropagated. + + Returns: + ndarray: Gradients of the input data. + """ + x = x.clone().detach().requires_grad_(True) output = self.model_classifier(x) output.backward(gradient=grads) @@ -146,12 +226,31 @@ def _bp_classifier(self, x, grads): return x_grad def _bp_feature_extractor(self, x, grads): + """ + Backpropagate gradients through the feature extractor. + + Args: + x: Input data. + grads: Gradients to be backpropagated. + + Returns: + None + """ output = self.model_feature_extractor(x) output.backward(gradient=grads) self.optimizer_fe.step() self.optimizer_fe.zero_grad() def _test(self, round_idx): + """ + Perform testing and calculate evaluation metrics. + + Args: + round_idx: The index of the training round. + + Returns: + None + """ X_test = torch.tensor(self.X_test).float().to(self.device) y_test = self.y_test @@ -183,9 +282,30 @@ def _test(self, round_idx): ) def _sigmoid(self, x): + """ + Compute the sigmoid function. + + Args: + x: Input data. + + Returns: + ndarray: Sigmoid values. + """ return 1.0 / (1.0 + np.exp(-x)) def _compute_correct_prediction(self, y_targets, y_prob_preds, threshold=0.5): + """ + Compute correct predictions and evaluation statistics. + + Args: + y_targets: True labels. + y_prob_preds: Predicted probabilities. + threshold: Threshold for classification. + + Returns: + ndarray: Predicted labels. + list: Statistics (positive predictions, negative predictions, correct predictions). + """ y_hat_lbls = [] pred_pos_count = 0 pred_neg_count = 0 diff --git a/python/fedml/simulation/mpi/classical_vertical_fl/host_manager.py b/python/fedml/simulation/mpi/classical_vertical_fl/host_manager.py index 82d31aa70d..cb37de2a54 100644 --- a/python/fedml/simulation/mpi/classical_vertical_fl/host_manager.py +++ b/python/fedml/simulation/mpi/classical_vertical_fl/host_manager.py @@ -2,18 +2,29 @@ from ....core.distributed.fedml_comm_manager import FedMLCommManager from ....core.distributed.communication.message import Message - class HostManager(FedMLCommManager): def __init__(self, args, comm, rank, size, trainer): + """ + Initialize a HostManager instance. + + Args: + args: Configuration arguments. + comm: MPI communication object. + rank: Rank of the process. + size: Number of processes in the communicator. + trainer: Trainer for host-specific tasks. 
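+
+        Example (sketch; assumes an MPI launch with one guest and several hosts):
+            >>> manager = HostManager(args, comm, rank, size, trainer)
+            >>> manager.run()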
+ """ super().__init__(args, comm, rank, size) self.trainer = trainer self.num_rounds = args.comm_round self.round_idx = 0 def run(self): + """Start the HostManager.""" super().run() def register_message_receive_handlers(self): + """Register message receive handlers.""" self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -23,10 +34,22 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): + """ + Handle the initialization message. + + Args: + msg_params: Parameters from the initialization message. + """ self.round_idx = 0 self.__train() def handle_message_receive_gradient_from_server(self, msg_params): + """ + Handle the gradient message received from the server. + + Args: + msg_params: Parameters from the gradient message. + """ gradient = msg_params.get(MyMessage.MSG_ARG_KEY_GRADIENT) self.trainer.update_model(gradient) self.round_idx += 1 @@ -35,6 +58,14 @@ def handle_message_receive_gradient_from_server(self, msg_params): self.finish() def send_model_to_server(self, receive_id, host_train_logits, host_test_logits): + """ + Send host training and test logits to the server. + + Args: + receive_id: ID of the receiver. + host_train_logits: Host's training logits. + host_test_logits: Host's test logits. + """ message = Message( MyMessage.MSG_TYPE_C2S_LOGITS, self.get_sender_id(), receive_id ) @@ -43,6 +74,7 @@ def send_model_to_server(self, receive_id, host_train_logits, host_test_logits): self.send_message(message) def __train(self): + """Perform host training and send logits to the server.""" host_train_logits, host_test_logits = self.trainer.computer_logits( self.round_idx ) diff --git a/python/fedml/simulation/mpi/classical_vertical_fl/host_trainer.py b/python/fedml/simulation/mpi/classical_vertical_fl/host_trainer.py index 06af3bcd27..6d18e386e1 100644 --- a/python/fedml/simulation/mpi/classical_vertical_fl/host_trainer.py +++ b/python/fedml/simulation/mpi/classical_vertical_fl/host_trainer.py @@ -3,6 +3,20 @@ class HostTrainer(object): + """ + Trainer for host-specific tasks in a federated learning environment. + + This class manages the training and gradient update process for host-specific tasks in a federated learning system. + + Args: + client_index: Index of the host client. + device: Computing device (e.g., CPU or GPU) to perform training. + X_train: Training data for the host. + X_test: Test data for the host. + model_feature_extractor: Feature extractor model. + model_classifier: Classifier model. + args: Configuration arguments. + """ def __init__( self, client_index, @@ -13,6 +27,9 @@ def __init__( model_classifier, args, ): + """ + Initialize a HostTrainer instance. + """ # device information self.client_index = client_index self.device = device @@ -55,9 +72,19 @@ def __init__( self.cached_extracted_features = None def get_batch_num(self): + """Get the number of training batches.""" return self.n_batches def computer_logits(self, round_idx): + """ + Compute logits for host-specific tasks. + + Args: + round_idx: Current round index. + + Returns: + tuple: A tuple containing host training logits and host test logits. 
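+
+        Note:
+            Test logits are recomputed only every frequency_of_the_test
+            rounds; training logits are produced on every call.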
+ """ batch_x = self.X_train[ self.batch_idx * self.batch_size : self.batch_idx * self.batch_size + self.batch_size @@ -65,13 +92,12 @@ def computer_logits(self, round_idx): self.batch_x = torch.tensor(batch_x).float().to(self.device) self.extracted_feature = self.model_feature_extractor.forward(self.batch_x) logits = self.model_classifier.forward(self.extracted_feature) - # copy to CPU host memory logits_train = logits.cpu().detach().numpy() self.batch_idx += 1 if self.batch_idx == self.n_batches: self.batch_idx = 0 - # for test + # For test if (round_idx + 1) % self.args.frequency_of_the_test == 0: X_test = torch.tensor(self.X_test).float().to(self.device) extracted_feature = self.model_feature_extractor.forward(X_test) @@ -83,12 +109,27 @@ def computer_logits(self, round_idx): return logits_train, logits_test def update_model(self, gradient): - # logging.info("#######################gradient = " + str(gradient)) + """ + Update the model using the received gradient. + + Args: + gradient: Gradient received from the server. + """ gradient = torch.tensor(gradient).float().to(self.device) back_grad = self._bp_classifier(self.extracted_feature, gradient) self._bp_feature_extractor(self.batch_x, back_grad) def _bp_classifier(self, x, grads): + """ + Backpropagate gradients through the classifier model. + + Args: + x: Input data. + grads: Gradients to backpropagate. + + Returns: + x_grad: Gradients of the input data. + """ x = x.clone().detach().requires_grad_(True) output = self.model_classifier(x) output.backward(gradient=grads) @@ -98,6 +139,13 @@ def _bp_classifier(self, x, grads): return x_grad def _bp_feature_extractor(self, x, grads): + """ + Backpropagate gradients through the feature extractor model. + + Args: + x: Input data. + grads: Gradients to backpropagate. + """ output = self.model_feature_extractor(x) output.backward(gradient=grads) self.optimizer_fe.step() diff --git a/python/fedml/simulation/mpi/decentralized_framework/algorithm_api.py b/python/fedml/simulation/mpi/decentralized_framework/algorithm_api.py index c72161fad4..c4b4c7bd62 100644 --- a/python/fedml/simulation/mpi/decentralized_framework/algorithm_api.py +++ b/python/fedml/simulation/mpi/decentralized_framework/algorithm_api.py @@ -4,16 +4,28 @@ from .decentralized_worker_manager import DecentralizedWorkerManager from ....core.distributed.topology.symmetric_topology_manager import SymmetricTopologyManager - def FedML_Decentralized_Demo_distributed(args, process_id, worker_number, comm): - # initialize the topology (ring) + """ + Run the decentralized federated learning demo on a distributed system. + + This function initializes the topology (ring) for decentralized federated learning, + initializes the decentralized worker (trainer), and runs the decentralized worker manager. + + Args: + args: Configuration arguments. + process_id: The unique ID of the current process. + worker_number: The total number of workers in the distributed system. + comm: MPI communication object for distributed communication. 
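+
+    Example (sketch; assumes an MPI launch such as mpirun -np <worker_number>,
+    with args parsed beforehand and FedML_init() as in the base framework):
+        >>> comm, process_id, worker_number = FedML_init()
+        >>> FedML_Decentralized_Demo_distributed(args, process_id, worker_number, comm)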
+    """
+    # Initialize the topology (ring)
     tpmgr = SymmetricTopologyManager(worker_number, 2)
     tpmgr.generate_topology()
     logging.info(tpmgr.topology)

-    # initialize the decentralized trainer (worker)
+    # Initialize the decentralized trainer (worker)
     worker_index = process_id
     trainer = DecentralizedWorker(worker_index, tpmgr)

+    # Initialize the decentralized worker manager
     client_manager = DecentralizedWorkerManager(args, comm, process_id, worker_number, trainer, tpmgr)
     client_manager.run()
diff --git a/python/fedml/simulation/mpi/decentralized_framework/decentralized_worker.py b/python/fedml/simulation/mpi/decentralized_framework/decentralized_worker.py
index 03d64e0525..e09994ee78 100644
--- a/python/fedml/simulation/mpi/decentralized_framework/decentralized_worker.py
+++ b/python/fedml/simulation/mpi/decentralized_framework/decentralized_worker.py
@@ -1,5 +1,15 @@
 class DecentralizedWorker(object):
+    """
+    Represents a decentralized federated learning worker.
+    """
     def __init__(self, worker_index, topology_manager):
+        """
+        Initialize the decentralized worker and its neighbor bookkeeping.
+
+        Args:
+            worker_index: The index or ID of the worker.
+            topology_manager: The topology manager for communication with neighboring workers.
+        """
         self.worker_index = worker_index
         self.in_neighbor_idx_list = topology_manager.get_in_neighbor_idx_list(
             self.worker_index
@@ -11,10 +21,23 @@
             self.flag_neighbor_result_received_dict[neighbor_idx] = False

     def add_result(self, worker_index, updated_information):
+        """
+        Add the result received from a neighboring worker.
+
+        Args:
+            worker_index: The index or ID of the neighboring worker.
+            updated_information: The updated information received from the neighboring worker.
+        """
         self.worker_result_dict[worker_index] = updated_information
         self.flag_neighbor_result_received_dict[worker_index] = True

     def check_whether_all_receive(self):
+        """
+        Check if results have been received from all neighboring workers.
+
+        Returns:
+            bool: True if results have been received from all neighbors, False otherwise.
+        """
         for neighbor_idx in self.in_neighbor_idx_list:
             if not self.flag_neighbor_result_received_dict[neighbor_idx]:
                 return False
@@ -23,5 +46,11 @@
         return True

     def train(self):
+        """
+        Perform the training process for the decentralized worker.
+
+        Returns:
+            int: A placeholder value (0 in this case) representing the result of the training iteration.
+        """
         self.add_result(self.worker_index, 0)
         return 0
diff --git a/python/fedml/simulation/mpi/decentralized_framework/decentralized_worker_manager.py b/python/fedml/simulation/mpi/decentralized_framework/decentralized_worker_manager.py
index a782a5ea3b..97bac8ca67 100644
--- a/python/fedml/simulation/mpi/decentralized_framework/decentralized_worker_manager.py
+++ b/python/fedml/simulation/mpi/decentralized_framework/decentralized_worker_manager.py
@@ -4,9 +4,22 @@
 from ....core.distributed.communication.message import Message
 from ....core.distributed.fedml_comm_manager import FedMLCommManager

-
 class DecentralizedWorkerManager(FedMLCommManager):
+    """
+    Communication manager for a decentralized federated learning worker in a distributed system.
+    """
     def __init__(self, args, comm, rank, size, trainer, topology_manager):
+        """
+        Initialize the decentralized worker manager.
+
+        Args:
+            args: Configuration arguments.
+            comm: MPI communication object for distributed communication.
+            rank: The rank (ID) of the current worker.
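+                Stored as worker_index and used to look up this worker's
+                neighbors in the topology.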
+ size: The total number of workers in the distributed system. + trainer: The decentralized worker/trainer. + topology_manager: The topology manager for communication between workers. + """ super().__init__(args, comm, rank, size) self.worker_index = rank self.trainer = trainer @@ -15,21 +28,36 @@ def __init__(self, args, comm, rank, size, trainer, topology_manager): self.round_idx = 0 def run(self): + """ + Start the training process for decentralized federated learning workers. + """ self.start_training() super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for handling incoming messages. + """ self.register_message_receive_handler(MyMessage.MSG_TYPE_SEND_MSG_TO_NEIGHBOR, self.handle_msg_from_neighbor) def start_training(self): + """ + Initialize and start the training process. + """ self.round_idx = 0 self.__train() def handle_msg_from_neighbor(self, msg_params): + """ + Handle messages received from neighboring workers. + + Args: + msg_params: Parameters included in the received message. + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) - training_interation_result = msg_params.get(MyMessage.MSG_ARG_KEY_PARAMS_1) + training_iteration_result = msg_params.get(MyMessage.MSG_ARG_KEY_PARAMS_1) logging.info("handle_msg_from_neighbor. sender_id = " + str(sender_id)) - self.trainer.add_result(sender_id, training_interation_result) + self.trainer.add_result(sender_id, training_iteration_result) if self.trainer.check_whether_all_receive(): logging.info(">>>>>>>>>>>>>>>WORKER %d, ROUND %d finished!<<<<<<<<" % (self.worker_index, self.round_idx)) self.round_idx += 1 @@ -38,17 +66,34 @@ def handle_msg_from_neighbor(self, msg_params): self.__train() def __train(self): - # do something here (e.g., training) - training_interation_result = self.trainer.train() + """ + Perform the training process and communicate with neighboring workers. + """ + # Perform the training process here (e.g., training iteration) + training_iteration_result = self.trainer.train() + # Send the training iteration result to neighboring workers for neighbor_idx in self.topology_manager.get_out_neighbor_idx_list(self.worker_index): - self.send_result_to_neighbors(neighbor_idx, training_interation_result) + self.send_result_to_neighbors(neighbor_idx, training_iteration_result) def send_message_init_config(self, receive_id): + """ + Send an initialization message to a specified worker. + + Args: + receive_id: The ID of the receiving worker. + """ message = Message(MyMessage.MSG_TYPE_INIT, self.get_sender_id(), receive_id) self.send_message(message) def send_result_to_neighbors(self, receive_id, client_params1): + """ + Send training iteration results to neighboring workers. + + Args: + receive_id: The ID of the receiving worker. + client_params1: Parameters to be sent in the message. + """ logging.info("send_result_to_neighbors. 
receive_id = " + str(receive_id)) message = Message(MyMessage.MSG_TYPE_SEND_MSG_TO_NEIGHBOR, self.get_sender_id(), receive_id) message.add_params(MyMessage.MSG_ARG_KEY_PARAMS_1, client_params1) diff --git a/python/fedml/simulation/mpi/fedavg/FedAVGAggregator.py b/python/fedml/simulation/mpi/fedavg/FedAVGAggregator.py index 0c99809001..893b2c4b03 100644 --- a/python/fedml/simulation/mpi/fedavg/FedAVGAggregator.py +++ b/python/fedml/simulation/mpi/fedavg/FedAVGAggregator.py @@ -13,6 +13,21 @@ from ....core.security.fedml_defender import FedMLDefender class FedAVGAggregator(object): + """ + Represents a Federated Averaging (FedAVG) aggregator for federated learning. + + Args: + train_global: The global training dataset. + test_global: The global testing dataset. + all_train_data_num: The total number of training data samples. + train_data_local_dict: A dictionary mapping worker indices to their local training datasets. + test_data_local_dict: A dictionary mapping worker indices to their local testing datasets. + train_data_local_num_dict: A dictionary mapping worker indices to the number of local training samples. + worker_num: The number of worker nodes participating in the federated learning. + device: The device (e.g., 'cuda' or 'cpu') used for computations. + args: Additional configuration arguments. + server_aggregator: The server-side aggregator used for communication with workers. + """ def __init__( self, train_global, @@ -47,18 +62,44 @@ def __init__( self.flag_client_model_uploaded_dict[idx] = False def get_global_model_params(self): + """ + Get the global model parameters from the aggregator. + + Returns: + dict: The global model parameters. + """ return self.aggregator.get_model_params() def set_global_model_params(self, model_parameters): + """ + Set the global model parameters in the aggregator. + + Args: + model_parameters (dict): The global model parameters to set. + """ self.aggregator.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, sample_num): + """ + Add the locally trained model result from a worker. + + Args: + index: The index or ID of the worker. + model_params (dict): The model parameters trained by the worker. + sample_num (int): The number of training samples used by the worker. + """ logging.info("add_model. index = %d" % index) self.model_dict[index] = model_params self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if model results have been received from all workers. + + Returns: + bool: True if results have been received from all workers, False otherwise. + """ logging.debug("worker_num = {}".format(self.worker_num)) for idx in range(self.worker_num): if not self.flag_client_model_uploaded_dict[idx]: @@ -68,6 +109,12 @@ def check_whether_all_receive(self): return True def aggregate(self): + """ + Aggregate the model updates from worker nodes using Federated Averaging (FedAVG). + + Returns: + dict: The averaged model parameters. + """ start_time = time.time() model_list = [] @@ -97,6 +144,15 @@ def aggregate(self): return averaged_params def _fedavg_aggregation_(self, model_list): + """ + Perform the FedAVG aggregation on a list of local model updates. + + Args: + model_list (list): A list of tuples containing local sample numbers and model parameters. + + Returns: + dict: The aggregated model parameters. 
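The elided body of `_fedavg_aggregation_` weights each client's parameters by its share of the total sample count, the standard FedAvg rule; a self-contained numeric sketch under that assumption, using plain tensors in place of full state dicts:

import torch

model_list = [
    (10, {"w": torch.tensor([1.0, 1.0])}),   # (local_sample_number, model params)
    (30, {"w": torch.tensor([3.0, 3.0])}),
]
training_num = sum(n for n, _ in model_list)

averaged = {}
for k in model_list[0][1]:
    for n, params in model_list:
        # Scale each client's parameter by its sample fraction and accumulate.
        term = params[k] * (n / training_num)
        averaged[k] = term if k not in averaged else averaged[k] + term

print(averaged["w"])  # tensor([2.5000, 2.5000]) = (10*1 + 30*3) / 40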
+ """ training_num = 0 for i in range(0, len(model_list)): local_sample_number, local_model_params = model_list[i] @@ -116,6 +172,17 @@ def _fedavg_aggregation_(self, model_list): return averaged_params def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Randomly sample a subset of clients for a federated learning round. + + Args: + round_idx (int): The index of the current federated learning round. + client_num_in_total (int): The total number of clients available. + client_num_per_round (int): The number of clients to sample for the current round. + + Returns: + list: A list of client indexes selected for the current round. + """ if client_num_in_total == client_num_per_round: client_indexes = [ client_index for client_index in range(client_num_in_total) @@ -131,7 +198,17 @@ def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): logging.info("client_indexes = %s" % str(client_indexes)) return client_indexes + def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set for testing purposes. + + Args: + num_samples (int): The number of samples to include in the validation set. + + Returns: + DataLoader: A DataLoader containing the validation set. + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) sample_indices = random.sample( @@ -146,6 +223,12 @@ def _generate_validation_set(self, num_samples=10000): return self.test_global def test_on_server_for_all_clients(self, round_idx): + """ + Perform testing on the server for all clients in a federated learning round. + + Args: + round_idx (int): The index of the current federated learning round. + """ if self.aggregator.test_all( self.train_data_local_dict, self.test_data_local_dict, @@ -170,4 +253,4 @@ def test_on_server_for_all_clients(self, round_idx): metric_result_in_current_round = self.aggregator.test(self.val_global, self.device, self.args) logging.info("metric_result_in_current_round = {}".format(metric_result_in_current_round)) else: - mlops.log({"round_idx": round_idx}) \ No newline at end of file + mlops.log({"round_idx": round_idx}) diff --git a/python/fedml/simulation/mpi/fedavg/FedAVGTrainer.py b/python/fedml/simulation/mpi/fedavg/FedAVGTrainer.py index 6b1e271d09..d0cfbd3d3a 100644 --- a/python/fedml/simulation/mpi/fedavg/FedAVGTrainer.py +++ b/python/fedml/simulation/mpi/fedavg/FedAVGTrainer.py @@ -2,6 +2,42 @@ class FedAVGTrainer(object): + """ + A class that handles training and testing on a local client in the FedAVG framework. + + This class is responsible for training and testing a local model using client-specific data in a federated learning setting. + + Args: + client_index: The index or ID of the client. + train_data_local_dict: A dictionary containing local training data. + train_data_local_num_dict: A dictionary containing the number of training samples for each client. + test_data_local_dict: A dictionary containing local testing data. + train_data_num: The total number of training samples. + device: The computing device (e.g., "cuda" or "cpu") to perform training and testing. + args: An object containing configuration parameters. + model_trainer: A model trainer object responsible for training and testing. + + Attributes: + trainer: A model trainer object responsible for training and testing. + client_index: The index or ID of the client. + train_data_local_dict: A dictionary containing local training data. 
+ train_data_local_num_dict: A dictionary containing the number of training samples for each client. + test_data_local_dict: A dictionary containing local testing data. + all_train_data_num: The total number of training samples. + train_local: Local training data for the current client. + local_sample_number: The number of training samples for the current client. + test_local: Local testing data for the current client. + device: The computing device (e.g., "cuda" or "cpu") to perform training and testing. + args: An object containing configuration parameters. + + Methods: + update_model(weights): Update the model with new weights. + update_dataset(client_index): Update the local datasets and client index. + train(round_idx=None): Train the local model using the current client's data. + test(): Test the local model on both training and testing data. + + """ + def __init__( self, client_index, @@ -28,9 +64,19 @@ def __init__( self.args = args def update_model(self, weights): + """Update the model with new weights. + + Args: + weights: The new model weights to set. + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """Update the local datasets and client index. + + Args: + client_index: The index or ID of the client. + """ self.client_index = client_index if self.train_data_local_dict is not None: @@ -49,6 +95,15 @@ def update_dataset(self, client_index): self.test_local = None def train(self, round_idx=None): + """Train the local model using the current client's data. + + Args: + round_idx: The current communication round index (optional). + + Returns: + weights: The trained model weights. + local_sample_number: The number of training samples used for training. + """ self.args.round_idx = round_idx self.trainer.train(self.train_local, self.device, self.args) @@ -57,6 +112,17 @@ def train(self, round_idx=None): return weights, self.local_sample_number def test(self): + """Test the local model on both training and testing data. + + Returns: + A tuple containing the following metrics: + - train_tot_correct: The total number of correct predictions on the training data. + - train_loss: The loss on the training data. + - train_num_sample: The total number of training samples. + - test_tot_correct: The total number of correct predictions on the testing data. + - test_loss: The loss on the testing data. + - test_num_sample: The total number of testing samples. + """ # train data train_metrics = self.trainer.test(self.train_local, self.device, self.args) train_tot_correct, train_num_sample, train_loss = ( diff --git a/python/fedml/simulation/mpi/fedavg/FedAvgAPI.py b/python/fedml/simulation/mpi/fedavg/FedAvgAPI.py index bc5069ab6d..d88ce8e3d1 100644 --- a/python/fedml/simulation/mpi/fedavg/FedAvgAPI.py +++ b/python/fedml/simulation/mpi/fedavg/FedAvgAPI.py @@ -21,6 +21,20 @@ def FedML_FedAvg_distributed( client_trainer: ClientTrainer = None, server_aggregator: ServerAggregator = None, ): + """ + Run Federated Averaging (FedAvg) in a distributed setting. + + Args: + args: The command-line arguments and configuration for the FedAvg process. + process_id (int): The unique identifier for the current process. + worker_number (int): The total number of worker processes. + comm: The communication backend for inter-process communication. + device: The target device (e.g., CPU or GPU) for training. + dataset: The dataset for training and testing. + model: The machine learning model to be trained. 
+ client_trainer (ClientTrainer, optional): The client trainer responsible for local training. + server_aggregator (ServerAggregator, optional): The server aggregator for model aggregation. + """ [ train_data_num, test_data_num, @@ -83,6 +97,24 @@ def init_server( train_data_local_num_dict, server_aggregator ): + """ + Initialize the server for FedAvg. + + Args: + args: The command-line arguments and configuration for the FedAvg process. + device: The target device (e.g., CPU or GPU) for training. + comm: The communication backend for inter-process communication. + rank (int): The rank or identifier of the server process. + size (int): The total number of processes. + model: The machine learning model to be trained. + train_data_num (int): The number of training samples. + train_data_global: The global training dataset. + test_data_global: The global testing dataset. + train_data_local_dict: A dictionary mapping client IDs to their local training datasets. + test_data_local_dict: A dictionary mapping client IDs to their local testing datasets. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training samples. + server_aggregator: The server aggregator for model aggregation. + """ if server_aggregator is None: server_aggregator = create_server_aggregator(model, args) server_aggregator.set_id(-1) @@ -109,6 +141,7 @@ def init_server( server_manager.run() + def init_client( args, device, @@ -122,6 +155,22 @@ def init_client( test_data_local_dict, model_trainer=None, ): + """ + Initialize a client for FedAvg. + + Args: + args: The command-line arguments and configuration for the FedAvg process. + device: The target device (e.g., CPU or GPU) for training. + comm: The communication backend for inter-process communication. + process_id (int): The unique identifier for the client process. + size (int): The total number of processes. + model: The machine learning model to be trained. + train_data_num (int): The number of training samples. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training samples. + train_data_local_dict: A dictionary mapping client IDs to their local training datasets. + test_data_local_dict: A dictionary mapping client IDs to their local testing datasets. + model_trainer (ModelTrainer, optional): The model trainer responsible for local training. + """ client_index = process_id - 1 if model_trainer is None: model_trainer = create_model_trainer(model, args) diff --git a/python/fedml/simulation/mpi/fedavg/FedAvgClientManager.py b/python/fedml/simulation/mpi/fedavg/FedAvgClientManager.py index 2cc81658b9..8eac4ff5a3 100644 --- a/python/fedml/simulation/mpi/fedavg/FedAvgClientManager.py +++ b/python/fedml/simulation/mpi/fedavg/FedAvgClientManager.py @@ -7,6 +7,9 @@ class FedAVGClientManager(FedMLCommManager): + """ + Class representing the client manager in the FedAVG federated learning process. + """ def __init__( self, args, @@ -16,16 +19,32 @@ def __init__( size=0, backend="MPI", ): + """ + Initialize the client manager for the FedAVG federated learning process. + + Args: + args (Namespace): Command-line arguments and configuration for the FedAVG process. + trainer: The federated learning trainer responsible for local training. + comm: The communication backend for inter-process communication. + rank (int): The rank or identifier of the current client. + size (int): The total number of clients. + backend (str): The backend for distributed computing (e.g., "MPI"). 
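A small sketch of the rank-to-role convention implied by `init_server` / `init_client` above: MPI process 0 hosts the server, and every other process hosts one client trainer with `client_index = process_id - 1`. The helper name is hypothetical:

def role_of(process_id):
    # Process 0 is the server; clients are numbered from 0 on the remaining ranks.
    if process_id == 0:
        return ("server", None)
    return ("client", process_id - 1)

for pid in range(4):
    print(pid, role_of(pid))
# 0 ('server', None)  1 ('client', 0)  2 ('client', 1)  3 ('client', 2)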
+ """ super().__init__(args, comm, rank, size, backend) self.trainer = trainer self.num_rounds = args.comm_round self.args.round_idx = 0 - def run(self): + """ + Start the client manager to handle federated learning tasks. + """ super().run() def register_message_receive_handlers(self): + """ + Register message handlers for processing incoming messages. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -35,6 +54,12 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): + """ + Handle the initialization message from the server. + + Args: + msg_params (dict): Parameters received in the message. + """ global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -44,10 +69,19 @@ def handle_message_init(self, msg_params): self.__train() def start_training(self): + """ + Start the federated training process. + """ self.args.round_idx = 0 self.__train() def handle_message_receive_model_from_server(self, msg_params): + """ + Handle the model update message received from the server. + + Args: + msg_params (dict): Parameters received in the message. + """ logging.info("handle_message_receive_model_from_server.") model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -56,11 +90,19 @@ def handle_message_receive_model_from_server(self, msg_params): self.trainer.update_dataset(int(client_index)) self.args.round_idx += 1 self.__train() + if self.args.round_idx == self.num_rounds - 1: - # post_complete_message_to_sweep_process(self.args) self.finish() def send_model_to_server(self, receive_id, weights, local_sample_num): + """ + Send the locally trained model to the server. + + Args: + receive_id (int): The ID of the server to receive the model. + weights: The model parameters. + local_sample_num (int): The number of local training samples. + """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), @@ -71,6 +113,9 @@ def send_model_to_server(self, receive_id, weights, local_sample_num): self.send_message(message) def __train(self): + """ + Perform federated training for a round. + """ logging.info("#######training########### round_id = %d" % self.args.round_idx) weights, local_sample_num = self.trainer.train(self.args.round_idx) self.send_model_to_server(0, weights, local_sample_num) diff --git a/python/fedml/simulation/mpi/fedavg/FedAvgServerManager.py b/python/fedml/simulation/mpi/fedavg/FedAvgServerManager.py index 631db27f08..e0cc5cb096 100644 --- a/python/fedml/simulation/mpi/fedavg/FedAvgServerManager.py +++ b/python/fedml/simulation/mpi/fedavg/FedAvgServerManager.py @@ -1,3 +1,4 @@ + import logging from .message_define import MyMessage @@ -7,6 +8,38 @@ class FedAVGServerManager(FedMLCommManager): + """ + A class that manages the server-side operations in a Federated Averaging (FedAVG) framework. + + This class handles the synchronization of model parameters and training progress across multiple clients + in a federated learning setting using the FedAVG algorithm. + + Args: + args: An object containing configuration parameters. + aggregator: An aggregator object responsible for aggregating client updates. + comm: A communication object for inter-process communication. + rank: The rank or ID of this process in the communication group. + size: The total number of processes in the communication group. 
+ backend: The backend used for communication (e.g., "MPI" or "gloo"). + is_preprocessed: A flag indicating whether the client data is preprocessed. + preprocessed_client_lists: A list of preprocessed client data. + + Attributes: + args: An object containing configuration parameters. + aggregator: An aggregator object responsible for aggregating client updates. + round_num: The total number of communication rounds. + is_preprocessed: A flag indicating whether the client data is preprocessed. + preprocessed_client_lists: A list of preprocessed client data. + + Methods: + run(): Start the server manager and enter the main execution loop. + send_init_msg(): Send an initialization message to clients to start the federated learning process. + register_message_receive_handlers(): Register message handlers for message types. + handle_message_receive_model_from_client(msg_params): Handle a message received from a client containing model updates. + send_message_init_config(receive_id, global_model_params, client_index): Send an initialization message to a specific client. + send_message_sync_model_to_client(receive_id, global_model_params, client_index): Send a model synchronization message to a client. + """ + def __init__( self, args, @@ -18,6 +51,19 @@ def __init__( is_preprocessed=False, preprocessed_client_lists=None, ): + """ + Initialize the server manager for the FedAVG federated learning process. + + Args: + args (Namespace): Command-line arguments and configuration for the FedAVG process. + aggregator: The federated learning aggregator responsible for model aggregation. + comm: The communication backend for inter-process communication. + rank (int): The rank or identifier of the current server. + size (int): The total number of clients and servers. + backend (str): The backend for distributed computing (e.g., "MPI"). + is_preprocessed (bool): Whether client sampling has been preprocessed. + preprocessed_client_lists (list): Preprocessed client sampling lists. + """ super().__init__(args, comm, rank, size, backend) self.args = args self.aggregator = aggregator @@ -27,10 +73,15 @@ def __init__( self.preprocessed_client_lists = preprocessed_client_lists def run(self): + """ + Start the server manager to handle federated learning tasks. + """ super().run() def send_init_msg(self): - # sampling clients + """ + Send initialization messages to clients, including global model parameters and client indexes. + """ client_indexes = self.aggregator.client_sampling( self.args.round_idx, self.args.client_num_in_total, @@ -43,12 +94,21 @@ def send_init_msg(self): ) def register_message_receive_handlers(self): + """ + Register message handlers for processing incoming messages. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.handle_message_receive_model_from_client, ) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle the model update message received from a client. + + Args: + msg_params (dict): Parameters received in the message. 
+ """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) @@ -62,20 +122,19 @@ def handle_message_receive_model_from_client(self, msg_params): global_model_params = self.aggregator.aggregate() self.aggregator.test_on_server_for_all_clients(self.args.round_idx) - # start the next round + # Start the next round self.args.round_idx += 1 if self.args.round_idx == self.round_num: - # post_complete_message_to_sweep_process(self.args) self.finish() return if self.is_preprocessed: if self.preprocessed_client_lists is None: - # sampling has already been done in data preprocessor + # Sampling has already been done in data preprocessor client_indexes = [self.args.round_idx] * self.args.client_num_per_round else: client_indexes = self.preprocessed_client_lists[self.args.round_idx] else: - # sampling clients + # Sampling clients client_indexes = self.aggregator.client_sampling( self.args.round_idx, self.args.client_num_in_total, @@ -91,6 +150,14 @@ def handle_message_receive_model_from_client(self, msg_params): ) def send_message_init_config(self, receive_id, global_model_params, client_index): + """ + Send an initialization message to a client. + + Args: + receive_id (int): The ID of the client to receive the message. + global_model_params: The global model parameters. + client_index: The index of the client. + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) @@ -101,6 +168,14 @@ def send_message_init_config(self, receive_id, global_model_params, client_index def send_message_sync_model_to_client( self, receive_id, global_model_params, client_index ): + """ + Send a model synchronization message to a client. + + Args: + receive_id (int): The ID of the client to receive the message. + global_model_params: The global model parameters. + client_index: The index of the client. + """ logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id) message = Message( MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, @@ -110,3 +185,5 @@ def send_message_sync_model_to_client( message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) self.send_message(message) + + \ No newline at end of file diff --git a/python/fedml/simulation/mpi/fedavg/utils.py b/python/fedml/simulation/mpi/fedavg/utils.py index aea2449590..7d58689867 100644 --- a/python/fedml/simulation/mpi/fedavg/utils.py +++ b/python/fedml/simulation/mpi/fedavg/utils.py @@ -1,24 +1,46 @@ import os - import numpy as np import torch - def transform_list_to_tensor(model_params_list): + """ + Transform a dictionary of model parameters from a list format to PyTorch tensors. + + Args: + model_params_list (dict): A dictionary containing model parameters in list format. + + Returns: + dict: A dictionary containing model parameters as PyTorch tensors. + """ for k in model_params_list.keys(): model_params_list[k] = torch.from_numpy( np.asarray(model_params_list[k]) ).float() return model_params_list - def transform_tensor_to_list(model_params): + """ + Transform a dictionary of model parameters from PyTorch tensors to a list format. + + Args: + model_params (dict): A dictionary containing model parameters as PyTorch tensors. + + Returns: + dict: A dictionary containing model parameters in list format. 
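A round-trip check for the two transforms above: tensors are flattened to nested Python lists for transport and rebuilt as float tensors on the other side. Values survive the trip; the dtype is coerced to float32 by the final `.float()`:

import numpy as np
import torch

params = {"w": torch.tensor([[1.0, 2.0], [3.0, 4.0]])}
# transform_tensor_to_list direction: tensors -> nested lists
as_lists = {k: v.detach().numpy().tolist() for k, v in params.items()}
# transform_list_to_tensor direction: nested lists -> float32 tensors
restored = {k: torch.from_numpy(np.asarray(v)).float() for k, v in as_lists.items()}
assert torch.equal(params["w"], restored["w"])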
+ """ for k in model_params.keys(): model_params[k] = model_params[k].detach().numpy().tolist() return model_params - def post_complete_message_to_sweep_process(args): + """ + Post a completion message to a sweep process. + + This function creates a named pipe and writes a completion message to it, along with the provided arguments. + + Args: + args: An object containing configuration parameters. + """ pipe_path = "./tmp/fedml" os.system("mkdir ./tmp/; touch ./tmp/fedml") if not os.path.exists(pipe_path): From 0a6b8d505a11082e1107c9c6c8a761042db10fe8 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Fri, 8 Sep 2023 13:30:38 +0530 Subject: [PATCH 48/70] python\fedml\simulation\mpi \fedavg_seq --- .../mpi/fedavg_seq/FedAVGAggregator.py | 160 +++++++++++++++--- .../mpi/fedavg_seq/FedAVGTrainer.py | 84 +++++++++ .../mpi/fedavg_seq/FedAvgClientManager.py | 90 ++++++++-- .../simulation/mpi/fedavg_seq/FedAvgSeqAPI.py | 51 ++++++ .../mpi/fedavg_seq/FedAvgServerManager.py | 58 +++++++ .../my_model_trainer_classification.py | 77 +++++++++ .../fedml/simulation/mpi/fedavg_seq/utils.py | 33 +++- 7 files changed, 511 insertions(+), 42 deletions(-) diff --git a/python/fedml/simulation/mpi/fedavg_seq/FedAVGAggregator.py b/python/fedml/simulation/mpi/fedavg_seq/FedAVGAggregator.py index 5adb0e0208..c0b4eaa458 100644 --- a/python/fedml/simulation/mpi/fedavg_seq/FedAVGAggregator.py +++ b/python/fedml/simulation/mpi/fedavg_seq/FedAVGAggregator.py @@ -15,6 +15,23 @@ class FedAVGAggregator(object): + """ + Federated Averaging Aggregator. + + This class handles the aggregation of local model updates from clients in a federated learning setup using Federated Averaging. + + Args: + train_global: The global training dataset. + test_global: The global test dataset. + all_train_data_num: The total number of training samples across all clients. + train_data_local_dict: A dictionary containing local training datasets for each client. + test_data_local_dict: A dictionary containing local test datasets for each client. + train_data_local_num_dict: A dictionary containing the number of local training samples for each client. + worker_num: The number of worker nodes (clients). + device: The device (e.g., 'cpu' or 'cuda') on which the model and data should be placed. + args: An object containing configuration parameters. + server_aggregator: An optional server aggregator object. + """ def __init__( self, train_global, @@ -57,18 +74,42 @@ def __init__( self.runtime_avg[i][j] = None def get_global_model_params(self): + """ + Get the global model parameters. + + Returns: + dict: A dictionary containing the global model parameters. + """ return self.aggregator.get_model_params() def set_global_model_params(self, model_parameters): + """ + Set the global model parameters. + + Args: + model_parameters (dict): A dictionary containing the global model parameters. + """ self.aggregator.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params): + """ + Add the local model update from a client. + + Args: + index (int): The index of the client. + model_params (dict): A dictionary containing the local model parameters. + """ logging.info("add_model. index = %d" % index) self.model_dict[index] = model_params - # self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all clients have uploaded their local model updates. + + Returns: + bool: True if all clients have uploaded their updates, False otherwise. 
+        """
         logging.debug("worker_num = {}".format(self.worker_num))
         for idx in range(self.worker_num):
             if not self.flag_client_model_uploaded_dict[idx]:
@@ -78,6 +119,16 @@ def check_whether_all_receive(self):
         return True
 
     def workload_estimate(self, client_indexes, mode="simulate"):
+        """
+        Estimate the workload of clients.
+
+        Args:
+            client_indexes (list): A list of client indexes.
+            mode (str): The estimation mode, either "simulate" or "real".
+
+        Returns:
+            list: A list of estimated workloads.
+        """
         if mode == "simulate":
             client_samples = [
                 self.train_data_local_num_dict[client_index]
@@ -91,6 +142,16 @@ def workload_estimate(self, client_indexes, mode="simulate"):
             workload = client_samples
         return workload
 
     def memory_estimate(self, client_indexes, mode="simulate"):
+        """
+        Estimate the memory usage of clients.
+
+        Args:
+            client_indexes (list): A list of client indexes.
+            mode (str): The estimation mode, either "simulate" or "real".
+
+        Returns:
+            list: A list of estimated memory usages.
+        """
         if mode == "simulate":
             memory = np.ones(self.worker_num)
         elif mode == "real":
@@ -100,6 +161,15 @@ def memory_estimate(self, client_indexes, mode="simulate"):
         return memory
 
     def resource_estimate(self, mode="simulate"):
+        """
+        Estimate the resource usage of clients.
+
+        Args:
+            mode (str): The estimation mode, either "simulate" or "real".
+
+        Returns:
+            list: A list of estimated resource usages.
+        """
         if mode == "simulate":
             resource = np.ones(self.worker_num)
         elif mode == "real":
@@ -109,6 +179,13 @@ def resource_estimate(self, mode="simulate"):
         return resource
 
     def record_client_runtime(self, worker_id, client_runtimes):
+        """
+        Record the runtime of client training and update the running estimates.
+
+        Args:
+            worker_id (int): The ID of the worker.
+            client_runtimes (dict): A dictionary containing client runtime information.
+        """
         for client_id, runtime in client_runtimes.items():
             self.runtime_history[worker_id][client_id].append(runtime)
         if hasattr(self.args, "runtime_est_mode"):
@@ -117,21 +194,27 @@ def record_client_runtime(self, worker_id, client_runtimes):
                 if self.runtime_avg[worker_id][client_id] is None:
                     self.runtime_avg[worker_id][client_id] = runtime
                 else:
-                    self.runtime_avg[worker_id][client_id] += self.runtime_avg[worker_id][client_id]/2 + runtime/2
+                    self.runtime_avg[worker_id][client_id] = self.runtime_avg[worker_id][client_id] / 2 + runtime / 2
             elif self.args.runtime_est_mode == 'time_window':
                 for client_id, runtime in client_runtimes.items():
                     self.runtime_history[worker_id][client_id] = self.runtime_history[worker_id][client_id][-3:]
 
+
     def generate_client_schedule(self, round_idx, client_indexes):
-        # self.runtime_history = {}
-        # for i in range(self.worker_num):
-        #     self.runtime_history[i] = {}
-        #     for j in range(self.args.client_num_in_total):
-        #         self.runtime_history[i][j] = []
+        """
+        Generate the schedule of clients for a given round.
+
+        Args:
+            round_idx (int): The index of the round.
+            client_indexes (list): A list of client indexes.
+
+        Returns:
+            list: A list of per-worker client schedules.
+        """
         previous_time = time.time()
         if hasattr(self.args, "simulation_schedule") and round_idx > 5:
-            # Need some rounds to record some information.
+            # Need some rounds to record some information.
simulation_schedule = self.args.simulation_schedule if hasattr(self.args, "runtime_est_mode"): if self.args.runtime_est_mode == 'EMA': @@ -144,7 +227,7 @@ def generate_client_schedule(self, round_idx, client_indexes): runtime_to_fit = self.runtime_history fit_params, fit_funcs, fit_errors = t_sample_fit( - self.worker_num, self.args.client_num_in_total, runtime_to_fit, + self.worker_num, self.args.client_num_in_total, runtime_to_fit, self.train_data_local_num_dict, uniform_client=True, uniform_gpu=False) if self.args.enable_wandb: @@ -187,6 +270,15 @@ def generate_client_schedule(self, round_idx, client_indexes): return client_schedule def get_average_weight(self, client_indexes): + """ + Calculate the average weight of clients based on their data sizes. + + Args: + client_indexes (list): A list of client indexes. + + Returns: + dict: A dictionary containing the average weight for each client. + """ average_weight_dict = {} training_num = 0 for client_index in client_indexes: @@ -199,36 +291,33 @@ def get_average_weight(self, client_indexes): return average_weight_dict def aggregate(self): + """ + Aggregate the local model updates from clients and compute the global model parameters. + + Returns: + dict: A dictionary containing the global model parameters. + """ start_time = time.time() model_list = [] training_num = 0 for idx in range(self.worker_num): - # added for attack & defense; enable multiple defenses - # if FedMLDefender.get_instance().is_defense_enabled(): - # self.model_dict[idx] = FedMLDefender.get_instance().defend( - # self.model_dict[idx], self.get_global_model_params() - # ) - if len(self.model_dict[idx]) > 0: - # some workers may not have parameters + # Some workers may not have parameters model_list.append(self.model_dict[idx]) # training_num += self.sample_num_dict[idx] logging.info("len of self.model_dict[idx] = " + str(len(self.model_dict))) - # logging.info("################aggregate: %d" % len(model_list)) - # (num0, averaged_params) = model_list[0] averaged_params = model_list[0] for k in averaged_params.keys(): for i in range(0, len(model_list)): local_model_params = model_list[i] - # w = local_sample_number / training_num if i == 0: averaged_params[k] = local_model_params[k] else: averaged_params[k] += local_model_params[k] - # update the global model which is cached at the server side + # Update the global model which is cached at the server side self.set_global_model_params(averaged_params) end_time = time.time() @@ -236,6 +325,17 @@ def aggregate(self): return averaged_params def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Randomly select a subset of clients for training in a round. + + Args: + round_idx (int): The index of the round. + client_num_in_total (int): The total number of clients. + client_num_per_round (int): The number of clients to select per round. + + Returns: + list: A list of selected client indexes. 
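Seeding NumPy with `round_idx`, as `client_sampling` does, makes every run draw the same client subset in a given round, which keeps comparisons between algorithms aligned. A standalone sketch mirroring that logic:

import numpy as np

def sample_clients(round_idx, client_num_in_total, client_num_per_round):
    if client_num_in_total == client_num_per_round:
        return list(range(client_num_in_total))
    np.random.seed(round_idx)  # same round -> same subset across runs
    num = min(client_num_per_round, client_num_in_total)
    return np.random.choice(range(client_num_in_total), num, replace=False).tolist()

print(sample_clients(3, 10, 4) == sample_clients(3, 10, 4))  # True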
+ """ if client_num_in_total == client_num_per_round: client_indexes = [ client_index for client_index in range(client_num_in_total) @@ -244,7 +344,7 @@ def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): num_clients = min(client_num_per_round, client_num_in_total) np.random.seed( round_idx - ) # make sure for each comparison, we are selecting the same clients each round + ) # Make sure for each comparison, we are selecting the same clients each round client_indexes = np.random.choice( range(client_num_in_total), num_clients, replace=False ) @@ -252,6 +352,15 @@ def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set for testing. + + Args: + num_samples (int, optional): The number of samples in the validation set. Defaults to 10000. + + Returns: + torch.utils.data.DataLoader: A DataLoader containing the validation set. + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) sample_indices = random.sample( @@ -266,6 +375,12 @@ def _generate_validation_set(self, num_samples=10000): return self.test_global def test_on_server_for_all_clients(self, round_idx): + """ + Test the global model on all clients. + + Args: + round_idx (int): The index of the current round. + """ if ( round_idx % self.args.frequency_of_the_test == 0 or round_idx == self.args.comm_round - 1 @@ -278,6 +393,8 @@ def test_on_server_for_all_clients(self, round_idx): train_num_samples = [] train_tot_corrects = [] train_losses = [] + + # Note: The following code is commented out, so it doesn't affect the execution. # for client_idx in range(self.args.client_num_in_total): # # train data # metrics = self.trainer.test( @@ -312,6 +429,7 @@ def test_on_server_for_all_clients(self, round_idx): else: metrics = self.aggregator.test(self.val_global, self.device, self.args) + # Note: The following code is commented out, so it doesn't affect the execution. # test_tot_correct, test_num_sample, test_loss = ( # metrics["test_correct"], # metrics["test_total"], diff --git a/python/fedml/simulation/mpi/fedavg_seq/FedAVGTrainer.py b/python/fedml/simulation/mpi/fedavg_seq/FedAVGTrainer.py index 148eb9b7c3..cbbf31181d 100644 --- a/python/fedml/simulation/mpi/fedavg_seq/FedAVGTrainer.py +++ b/python/fedml/simulation/mpi/fedavg_seq/FedAVGTrainer.py @@ -2,6 +2,49 @@ class FedAVGTrainer(object): + """ + Trainer class for federated learning clients using the FedAVG algorithm. + + Args: + client_index (int): The index of the client. + train_data_local_dict (dict): A dictionary containing local training datasets. + train_data_local_num_dict (dict): A dictionary containing the number of samples for each local dataset. + test_data_local_dict (dict): A dictionary containing local testing datasets. + train_data_num (int): The total number of training samples. + device (str): The device (e.g., "cpu" or "cuda") for training. + args (Namespace): Command-line arguments and configuration. + model_trainer (object): An instance of the model trainer used for training. + + Attributes: + trainer (object): The model trainer instance. + client_index (int): The index of the client. + train_data_local_dict (dict): A dictionary containing local training datasets. + train_data_local_num_dict (dict): A dictionary containing the number of samples for each local dataset. + test_data_local_dict (dict): A dictionary containing local testing datasets. 
+ all_train_data_num (int): The total number of training samples. + train_local (Dataset): The local training dataset. + local_sample_number (int): The number of local training samples. + test_local (Dataset): The local testing dataset. + device (str): The device for training (e.g., "cpu" or "cuda"). + args (Namespace): Command-line arguments and configuration. + + Methods: + update_model(weights): + Update the model with given weights. + + update_dataset(client_index): + Update the current dataset for training and testing. + + get_lr(progress): + Calculate the learning rate based on the training progress. + + train(round_idx=None): + Train the model on the local dataset for a given round. + + test(): + Evaluate the trained model on both local training and testing datasets. + + """ def __init__( self, client_index, @@ -28,15 +71,39 @@ def __init__( self.args = args def update_model(self, weights): + """ + Update the model with the provided weights. + + Args: + weights (dict): The model parameters to set. + + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """ + Update the current dataset for training and testing. + + Args: + client_index (int): The index of the client representing the dataset to be used. + + """ self.client_index = client_index self.train_local = self.train_data_local_dict[client_index] self.local_sample_number = self.train_data_local_num_dict[client_index] self.test_local = self.test_data_local_dict[client_index] def get_lr(self, progress): + """ + Calculate the learning rate based on the training progress. + + Args: + progress (int): The training progress, typically the round index. + + Returns: + float: The calculated learning rate. + + """ # This aims to make a float step_size work. if self.args.lr_schedule == "StepLR": exp_num = progress / self.args.lr_step_size @@ -56,6 +123,16 @@ def get_lr(self, progress): return lr def train(self, round_idx=None): + """ + Train the model on the local dataset for a given round. + + Args: + round_idx (int, optional): The current round index. Defaults to None. + + Returns: + tuple: A tuple containing the trained model weights and the number of local samples used. + + """ self.args.round_idx = round_idx # lr = self.get_lr(round_idx) # self.trainer.train(self.train_local, self.device, self.args, lr=lr) @@ -65,6 +142,13 @@ def train(self, round_idx=None): return weights, self.local_sample_number def test(self): + """ + Evaluate the trained model on both local training and testing datasets. + + Returns: + tuple: A tuple containing training and testing metrics, including correct predictions, loss, and sample counts. + + """ # train data train_metrics = self.trainer.test(self.train_local, self.device, self.args) train_tot_correct, train_num_sample, train_loss = ( diff --git a/python/fedml/simulation/mpi/fedavg_seq/FedAvgClientManager.py b/python/fedml/simulation/mpi/fedavg_seq/FedAvgClientManager.py index 7cfaa43056..edaca7d713 100644 --- a/python/fedml/simulation/mpi/fedavg_seq/FedAvgClientManager.py +++ b/python/fedml/simulation/mpi/fedavg_seq/FedAvgClientManager.py @@ -8,10 +8,37 @@ from ....core.distributed.fedml_comm_manager import FedMLCommManager + + class FedAVGClientManager(FedMLCommManager): + """ + Manager for federated learning clients using the Federated Averaging (FedAvg) algorithm. + + This class handles communication between the server and clients, as well as the training + process on each client. + + Args: + args (Namespace): Command-line arguments and configuration. 
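`get_lr` above supports a float `lr_step_size` by exponentiating the decay rate with `progress / lr_step_size`. The sketch below mirrors that StepLR branch and adds an assumed cosine-annealing branch; the cosine formula is not shown in this hunk, so treat it as illustrative, and the parameter names simply echo the `args` attributes used in this file:

import math

def get_lr(progress, lr=0.1, schedule="StepLR", step_size=30.0, decay=0.5, comm_round=100):
    if schedule == "StepLR":
        # Float-valued step size: fractional exponent instead of integer steps.
        return lr * decay ** (progress / step_size)
    if schedule == "CosineAnnealingLR":
        # Assumed standard cosine annealing over the total number of rounds.
        return 0.5 * lr * (1 + math.cos(math.pi * progress / comm_round))
    return lr

print(round(get_lr(45.0), 5))  # 0.1 * 0.5**1.5 ~= 0.03536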
+ trainer (object): An instance of the model trainer used for local training on clients. + comm (object, optional): The communication backend (e.g., MPI). Defaults to None. + rank (int, optional): The rank of the client. Defaults to 0. + size (int, optional): The total number of clients. Defaults to 0. + backend (str, optional): The communication backend type (e.g., MPI). Defaults to "MPI". + """ def __init__( self, args, trainer, comm=None, rank=0, size=0, backend="MPI", ): + """ + Initialize the FedAVGClientManager. + + Args: + args: The command-line arguments. + trainer: The trainer for client-side training. + comm: The communication backend. + rank: The rank of the client. + size: The total number of clients. + backend: The communication backend (e.g., "MPI"). + """ super().__init__(args, comm, rank, size, backend) self.trainer = trainer self.num_rounds = args.comm_round @@ -19,18 +46,28 @@ def __init__( self.worker_id = self.rank - 1 def run(self): + """ + Run the FedAVGClientManager. + """ super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers. + """ self.register_message_receive_handler(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init) self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.handle_message_receive_model_from_server, ) def handle_message_init(self, msg_params): - global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) - # client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) + """ + Handle initialization message from the server. + Args: + msg_params: The message parameters. + """ + global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) average_weight_dict = msg_params.get(MyMessage.MSG_ARG_KEY_AVG_WEIGHTS) client_schedule = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_SCHEDULE) client_indexes = client_schedule[self.worker_id] @@ -39,14 +76,20 @@ def handle_message_init(self, msg_params): self.__train(global_model_params, client_indexes, average_weight_dict) def start_training(self): + """ + Start the training process. + """ self.round_idx = 0 - # self.__train() def handle_message_receive_model_from_server(self, msg_params): + """ + Handle the received model from the server. + + Args: + msg_params: The message parameters. + """ logging.info("handle_message_receive_model_from_server.") global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) - # client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) - average_weight_dict = msg_params.get(MyMessage.MSG_ARG_KEY_AVG_WEIGHTS) client_schedule = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_SCHEDULE) client_indexes = client_schedule[self.worker_id] @@ -54,18 +97,31 @@ def handle_message_receive_model_from_server(self, msg_params): self.round_idx += 1 self.__train(global_model_params, client_indexes, average_weight_dict) if self.round_idx == self.num_rounds - 1: - # post_complete_message_to_sweep_process(self.args) self.finish() def send_result_to_server(self, receive_id, weights, client_runtime_info): + """ + Send the training results to the server. + + Args: + receive_id: The ID of the recipient (server). + weights: The model weights. + client_runtime_info: Information about client runtime. 
+ """ message = Message(MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), receive_id,) message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights) - # message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_RUNTIME_INFO, client_runtime_info) self.send_message(message) def add_client_model(self, local_agg_model_params, model_params, weight=1.0): - # Add params that needed to be reduces from clients + """ + Add the client model parameters to the local aggregation. + + Args: + local_agg_model_params: The local aggregation of model parameters. + model_params: The model parameters. + weight: The weight for averaging. + """ for name, param in model_params.items(): if name not in local_agg_model_params: local_agg_model_params[name] = param * weight @@ -73,32 +129,34 @@ def add_client_model(self, local_agg_model_params, model_params, weight=1.0): local_agg_model_params[name] += param * weight def __train(self, global_model_params, client_indexes, average_weight_dict): + """ + Train the client model. + + Args: + global_model_params: The global model parameters. + client_indexes: The indexes of clients. + average_weight_dict: The dictionary of average weights. + """ logging.info("#######training########### round_id = %d" % self.round_idx) if hasattr(self.args, "simulation_gpu_hetero"): - # runtime_speed_ratio - # runtime_speed_ratio * t_train - t_train - # time.sleep(runtime_speed_ratio * t_train - t_train) simulation_gpu_hetero = self.args.simulation_gpu_hetero runtime_speed_ratio = self.args.gpu_hetero_ratio * self.worker_id / self.args.worker_num if hasattr(self.args, "simulation_environment_hetero"): - # runtime_speed_ratio - # runtime_speed_ratio * t_train - t_train - # time.sleep(runtime_speed_ratio * t_train - t_train) if self.args.simulation_environment_hetero == "cos": runtime_speed_ratio = self.args.environment_hetero_ratio * \ (1 + cos(self.round_idx / self.num_rounds*3.1415926 + self.worker_id)) else: raise NotImplementedError - local_agg_model_params = {} client_runtime_info = {} for client_index in client_indexes: logging.info( "#######training########### Simulating client_index = %d, average weight: %f " - % (client_index, average_weight_dict[client_index]) + % (client_index, + average_weight_dict[client_index]) ) start_time = time.time() self.trainer.update_model(global_model_params) diff --git a/python/fedml/simulation/mpi/fedavg_seq/FedAvgSeqAPI.py b/python/fedml/simulation/mpi/fedavg_seq/FedAvgSeqAPI.py index c9aa2dbfb1..a0afb8e868 100644 --- a/python/fedml/simulation/mpi/fedavg_seq/FedAvgSeqAPI.py +++ b/python/fedml/simulation/mpi/fedavg_seq/FedAvgSeqAPI.py @@ -12,6 +12,21 @@ def FedML_FedAvgSeq_distributed( args, process_id, worker_number, comm, device, dataset, model, client_trainer=None, server_aggregator=None ): + """ + Function to initialize and run federated learning in a distributed environment using the FedAvg algorithm. + + Args: + args (Namespace): Command-line arguments and configuration. + process_id (int): The unique identifier for the current process. + worker_number (int): The total number of worker processes. + comm (object): The communication backend (e.g., MPI). + device (str): The device (e.g., "cpu" or "cuda") for training. + dataset (list): List containing dataset information. + model (nn.Module): The federated learning model. + client_trainer (object, optional): An instance of the client model trainer. Defaults to None. 
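`add_client_model` above pre-scales each simulated client's parameters by its averaging weight while accumulating, which is why the server-side `aggregate` in this package can combine worker results by plain summation. A tiny numeric check:

import torch

def add_client_model(local_agg, params, weight):
    for name, p in params.items():
        # Weighted in-place accumulation across the clients this worker simulates.
        local_agg[name] = p * weight if name not in local_agg else local_agg[name] + p * weight

local_agg = {}
add_client_model(local_agg, {"w": torch.tensor([1.0])}, weight=0.25)
add_client_model(local_agg, {"w": torch.tensor([3.0])}, weight=0.75)
print(local_agg["w"])  # tensor([2.5000]) = 0.25*1 + 0.75*3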
+ server_aggregator (object, optional): An instance of the server aggregator. Defaults to None. + """ + [ train_data_num, test_data_num, @@ -58,6 +73,7 @@ def FedML_FedAvgSeq_distributed( client_trainer, ) +# Rest of the code... def init_server( args, @@ -74,6 +90,25 @@ def init_server( train_data_local_num_dict, server_aggregator, ): + """ + Initialize and run the federated learning server. + + Args: + args (Namespace): Command-line arguments and configuration. + device (str): The device (e.g., "cpu" or "cuda") for training. + comm (object): The communication backend (e.g., MPI). + rank (int): The rank of the server process. + size (int): The total number of processes. + model (nn.Module): The federated learning model. + train_data_num (int): The total number of training samples. + train_data_global (Dataset): The global training dataset. + test_data_global (Dataset): The global test dataset. + train_data_local_dict (dict): A dictionary of local training datasets. + test_data_local_dict (dict): A dictionary of local test datasets. + train_data_local_num_dict (dict): A dictionary of the number of samples in each local training dataset. + server_aggregator (object): An instance of the server aggregator. + """ + if server_aggregator is None: server_aggregator = create_server_aggregator(model, args) server_aggregator.set_id(-1) @@ -113,6 +148,22 @@ def init_client( test_data_local_dict, client_trainer=None, ): + """ + Initialize and run a federated learning client. + + Args: + args (Namespace): Command-line arguments and configuration. + device (str): The device (e.g., "cpu" or "cuda") for training. + comm (object): The communication backend (e.g., MPI). + process_id (int): The unique identifier for the client process. + size (int): The total number of processes. + model (nn.Module): The federated learning model. + train_data_num (int): The total number of training samples. + train_data_local_num_dict (dict): A dictionary of the number of samples in each local training dataset. + train_data_local_dict (dict): A dictionary of local training datasets. + test_data_local_dict (dict): A dictionary of local test datasets. + client_trainer (object, optional): An instance of the client model trainer. Defaults to None. + """ client_index = process_id - 1 if client_trainer is None: client_trainer = create_model_trainer(model, args) diff --git a/python/fedml/simulation/mpi/fedavg_seq/FedAvgServerManager.py b/python/fedml/simulation/mpi/fedavg_seq/FedAvgServerManager.py index f9269a50b9..b59bb6b64e 100644 --- a/python/fedml/simulation/mpi/fedavg_seq/FedAvgServerManager.py +++ b/python/fedml/simulation/mpi/fedavg_seq/FedAvgServerManager.py @@ -10,6 +10,19 @@ class FedAVGServerManager(FedMLCommManager): + """ + Class responsible for managing the server in the FedAVG federated learning system. + + Args: + args (Namespace): Command-line arguments and configuration. + aggregator (object): An instance of the aggregator used for federated learning. + comm (object, optional): The communication backend (e.g., MPI). Defaults to None. + rank (int, optional): The rank of the server process. Defaults to 0. + size (int, optional): The total number of processes. Defaults to 0. + backend (str, optional): The backend used for communication. Defaults to "MPI". + is_preprocessed (bool, optional): Indicates whether client lists are preprocessed. Defaults to False. + preprocessed_client_lists (list, optional): Preprocessed client lists. Defaults to None. 
+ """ def __init__( self, args, @@ -30,9 +43,19 @@ def __init__( self.preprocessed_client_lists = preprocessed_client_lists def run(self): + """ + Run the server manager to coordinate federated learning. + + This method runs the server manager to coordinate the federated learning process. + """ super().run() def send_init_msg(self): + """ + Send the initialization message to clients. + + This method sends an initialization message to client processes to begin federated learning. + """ # sampling clients self.previous_time = time.time() client_indexes = self.aggregator.client_sampling( @@ -48,11 +71,24 @@ def send_init_msg(self): self.send_message_init_config(process_id, global_model_params, average_weight_dict, client_schedule) def register_message_receive_handlers(self): + """ + Register message receive handlers. + + This method registers message receive handlers for processing incoming messages. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.handle_message_receive_model_from_client, ) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle the received model update from a client. + + Args: + msg_params (dict): The parameters of the received message. + + This method handles the model update received from a client during federated learning. + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) # local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) @@ -112,6 +148,17 @@ def handle_message_receive_model_from_client(self, msg_params): ) def send_message_init_config(self, receive_id, global_model_params, average_weight_dict, client_schedule): + """ + Send the initialization configuration message to a client. + + Args: + receive_id (int): The ID of the receiving client. + global_model_params (dict): Global model parameters. + average_weight_dict (dict): Average weight dictionary for clients. + client_schedule (list): The schedule of clients for the current round. + + This method sends an initialization configuration message to a client process. + """ message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id) message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) # message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) @@ -120,6 +167,17 @@ def send_message_init_config(self, receive_id, global_model_params, average_weig self.send_message(message) def send_message_sync_model_to_client(self, receive_id, global_model_params, average_weight_dict, client_schedule): + """ + Send the model synchronization message to a client. + + Args: + receive_id (int): The ID of the receiving client. + global_model_params (dict): Global model parameters. + average_weight_dict (dict): Average weight dictionary for clients. + client_schedule (list): The schedule of clients for the current round. + + This method sends a model synchronization message to a client process. + """ logging.info("send_message_sync_model_to_client. 
receive_id = %d" % receive_id) message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id,) message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) diff --git a/python/fedml/simulation/mpi/fedavg_seq/my_model_trainer_classification.py b/python/fedml/simulation/mpi/fedavg_seq/my_model_trainer_classification.py index 20ce167511..19da2853aa 100644 --- a/python/fedml/simulation/mpi/fedavg_seq/my_model_trainer_classification.py +++ b/python/fedml/simulation/mpi/fedavg_seq/my_model_trainer_classification.py @@ -6,13 +6,65 @@ class MyModelTrainer(ClientTrainer): + """ + Custom model trainer for federated learning clients. + + Args: + model (nn.Module): The PyTorch model to be trained. + id (int): The identifier of the client. + + Attributes: + model (nn.Module): The PyTorch model being trained. + id (int): The identifier of the client. + + Methods: + get_model_params(): + Get the model parameters as a dictionary. + + set_model_params(model_parameters): + Set the model parameters using a dictionary. + + train(train_data, device, args, lr=None): + Train the model on the provided training data. + + test(test_data, device, args): + Evaluate the model on the provided test data. + + test_on_the_server(train_data_local_dict, test_data_local_dict, device, args=None): + Perform testing on the server (not implemented in this class). + + """ def get_model_params(self): + """ + Get the model parameters as a dictionary. + + Returns: + dict: A dictionary containing the model's state dictionary. + + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters using a dictionary. + + Args: + model_parameters (dict): A dictionary containing the model's state dictionary. + + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args, lr=None): + """ + Train the model on the provided training data. + + Args: + train_data (DataLoader): The DataLoader containing the training data. + device (str): The device (e.g., "cpu" or "cuda") for training. + args (Namespace): Command-line arguments and configuration. + lr (float, optional): The learning rate. Defaults to None. + + """ model = self.model model.to(device) @@ -66,6 +118,18 @@ def train(self, train_data, device, args, lr=None): ) def test(self, test_data, device, args): + """ + Evaluate the model on the provided test data. + + Args: + test_data (DataLoader): The DataLoader containing the test data. + device (str): The device (e.g., "cpu" or "cuda") for testing. + args (Namespace): Command-line arguments and configuration. + + Returns: + dict: A dictionary containing test metrics, including correct predictions, loss, and total samples. + + """ model = self.model model.to(device) @@ -93,4 +157,17 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Perform testing on the server (not implemented in this class). + + Args: + train_data_local_dict (dict): A dictionary containing local training datasets. + test_data_local_dict (dict): A dictionary containing local testing datasets. + device (str): The device (e.g., "cpu" or "cuda") for testing. + args (Namespace, optional): Command-line arguments and configuration. Defaults to None. + + Returns: + bool: Always returns False as this method is not implemented in this class. 
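+
+        Note:
+            Returning ``False`` signals callers that no server-side evaluation
+            was performed, so they can fall back to ordinary per-client testing.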
+ + """ return False diff --git a/python/fedml/simulation/mpi/fedavg_seq/utils.py b/python/fedml/simulation/mpi/fedavg_seq/utils.py index aea2449590..479e91c857 100644 --- a/python/fedml/simulation/mpi/fedavg_seq/utils.py +++ b/python/fedml/simulation/mpi/fedavg_seq/utils.py @@ -1,24 +1,47 @@ +import torch +import numpy as np import os -import numpy as np -import torch +def transform_list_to_tensor(model_params_list): + """ + Transform a dictionary of model parameters from NumPy arrays to PyTorch tensors. + Args: + model_params_list (dict): A dictionary of model parameters as NumPy arrays. -def transform_list_to_tensor(model_params_list): + Returns: + dict: A dictionary of model parameters as PyTorch tensors. + + """ for k in model_params_list.keys(): model_params_list[k] = torch.from_numpy( np.asarray(model_params_list[k]) ).float() return model_params_list - def transform_tensor_to_list(model_params): + """ + Transform a dictionary of model parameters from PyTorch tensors to NumPy arrays. + + Args: + model_params (dict): A dictionary of model parameters as PyTorch tensors. + + Returns: + dict: A dictionary of model parameters as NumPy arrays. + + """ for k in model_params.keys(): model_params[k] = model_params[k].detach().numpy().tolist() return model_params - def post_complete_message_to_sweep_process(args): + """ + Post a completion message to a named pipe for communication. + + Args: + args: Additional information or arguments (usually configuration). + + """ pipe_path = "./tmp/fedml" os.system("mkdir ./tmp/; touch ./tmp/fedml") if not os.path.exists(pipe_path): From 0a9649183990b77a35e129e04875b56d81ef8845 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Fri, 8 Sep 2023 19:33:23 +0530 Subject: [PATCH 49/70] python\fedml\simulation\mpi fedgan fedgkt --- .../simulation/mpi/fedgan/FedGANAggregator.py | 119 +++++++++++- .../simulation/mpi/fedgan/FedGANTrainer.py | 70 ++++++- .../fedml/simulation/mpi/fedgan/FedGanAPI.py | 58 ++++++ .../mpi/fedgan/FedGanClientManager.py | 49 +++++ .../mpi/fedgan/FedGanServerManager.py | 61 +++++- .../simulation/mpi/fedgan/gan_trainer.py | 62 +++++- python/fedml/simulation/mpi/fedgan/utils.py | 29 ++- .../fedml/simulation/mpi/fedgkt/FedGKTAPI.py | 58 +++++- .../simulation/mpi/fedgkt/GKTClientManager.py | 107 +++++++++- .../simulation/mpi/fedgkt/GKTClientTrainer.py | 74 +++++++ .../simulation/mpi/fedgkt/GKTServerManager.py | 73 ++++++- .../simulation/mpi/fedgkt/GKTServerTrainer.py | 113 ++++++++++- python/fedml/simulation/mpi/fedgkt/utils.py | 182 +++++++++++++++--- 13 files changed, 998 insertions(+), 57 deletions(-) diff --git a/python/fedml/simulation/mpi/fedgan/FedGANAggregator.py b/python/fedml/simulation/mpi/fedgan/FedGANAggregator.py index 826b2da7ec..745bcc1882 100644 --- a/python/fedml/simulation/mpi/fedgan/FedGANAggregator.py +++ b/python/fedml/simulation/mpi/fedgan/FedGANAggregator.py @@ -13,6 +13,35 @@ class FedGANAggregator(object): + """ + A class for aggregating and managing local models in a Federated Generative Adversarial Network (FedGAN) setup. + + Attributes: + trainer: Model trainer object for training and testing. + args: Configuration arguments. + train_global: Global training dataset. + test_global: Global testing dataset. + val_global: Validation dataset for testing. + all_train_data_num: Total number of training samples. + train_data_local_dict: Dictionary of local training datasets for each worker. + test_data_local_dict: Dictionary of local testing datasets for each worker. 
+ train_data_local_num_dict: Dictionary of the number of local training samples for each worker. + worker_num: Number of worker nodes. + device: Torch device for computation (e.g., 'cuda' or 'cpu'). + model_dict: Dictionary to store local models from client workers. + sample_num_dict: Dictionary to store the number of training samples from client workers. + flag_client_model_uploaded_dict: Dictionary to track whether client models have been uploaded. + + Methods: + get_global_model_params(): Get the global model parameters. + set_global_model_params(model_parameters): Set the global model parameters. + add_local_trained_result(index, model_params, sample_num): Add local trained model results to the aggregator. + check_whether_all_receive(): Check if all client workers have uploaded their local models. + aggregate(): Aggregate local models from client workers. + client_sampling(round_idx, client_num_in_total, client_num_per_round): Randomly sample a subset of clients for communication in a round. + _generate_validation_set(num_samples): Generate a validation dataset for testing. + test_on_server_for_all_clients(round_idx): Perform testing on the server side for all clients. + """ def __init__( self, train_global, @@ -26,6 +55,22 @@ def __init__( args, model_trainer, ): + """ + Initialize the FedGANAggregator. + + Args: + train_global: Global training dataset. + test_global: Global testing dataset. + all_train_data_num: Total number of training samples. + train_data_local_dict: Dictionary of local training datasets for each worker. + test_data_local_dict: Dictionary of local testing datasets for each worker. + train_data_local_num_dict: Dictionary of the number of local training samples for each worker. + worker_num: Number of worker nodes. + device: Torch device for computation (e.g., 'cuda' or 'cpu'). + args: Configuration arguments. + model_trainer: Model trainer object for training and testing. + + """ self.trainer = model_trainer self.args = args @@ -47,18 +92,48 @@ def __init__( self.flag_client_model_uploaded_dict[idx] = False def get_global_model_params(self): + """ + Get the global model parameters. + + Returns: + dict: Global model parameters. + + """ return self.trainer.get_model_params() def set_global_model_params(self, model_parameters): + """ + Set the global model parameters. + + Args: + model_parameters (dict): Global model parameters to set. + + """ self.trainer.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, sample_num): + """ + Add local trained model results to the aggregator. + + Args: + index: Index of the client worker. + model_params (dict): Local model parameters. + sample_num (int): Number of local training samples. + + """ logging.info("add_model. index = %d" % index) self.model_dict[index] = model_params self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all client workers have uploaded their local models. + + Returns: + bool: True if all clients have uploaded their models, False otherwise. + + """ for idx in range(self.worker_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -67,6 +142,13 @@ def check_whether_all_receive(self): return True def aggregate(self): + """ + Aggregate local models from client workers. + + Returns: + dict: Averaged global model parameters. 
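+
+        Note:
+            Each client's parameters are weighted by its share of the total
+            sample count. A toy sketch of the rule applied to every parameter
+            tensor, for two hypothetical clients with ``n0`` and ``n1`` local
+            samples:
+
+            ```
+            w0, w1 = n0 / (n0 + n1), n1 / (n0 + n1)
+            averaged_param = w0 * param0 + w1 * param1
+            ```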
+ + """ start_time = time.time() model_list = [] training_num = 0 @@ -77,7 +159,6 @@ def aggregate(self): logging.info("len of self.model_dict[idx] = " + str(len(self.model_dict))) - # logging.info("################aggregate: %d" % len(model_list)) (num0, averaged_params) = model_list[0] for net in averaged_params.keys(): for k in averaged_params[net].keys(): @@ -89,7 +170,7 @@ def aggregate(self): else: averaged_params[net][k] += local_model_params[net][k] * w - # update the global model which is cached at the server side + # Update the global model which is cached at the server side self.set_global_model_params(averaged_params) end_time = time.time() @@ -97,6 +178,18 @@ def aggregate(self): return averaged_params def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Randomly sample a subset of clients for communication in a round. + + Args: + round_idx (int): Current communication round. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to sample per round. + + Returns: + list: List of client indexes selected for communication in the current round. + + """ if client_num_in_total == client_num_per_round: client_indexes = [ client_index for client_index in range(client_num_in_total) @@ -105,7 +198,7 @@ def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): num_clients = min(client_num_per_round, client_num_in_total) np.random.seed( round_idx - ) # make sure for each comparison, we are selecting the same clients each round + ) # Make sure for each comparison, we are selecting the same clients each round client_indexes = np.random.choice( range(client_num_in_total), num_clients, replace=False ) @@ -113,6 +206,16 @@ def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation dataset for testing. + + Args: + num_samples (int): Number of samples to include in the validation set. + + Returns: + DataLoader: Validation dataset loader. + + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) sample_indices = random.sample( @@ -127,6 +230,16 @@ def _generate_validation_set(self, num_samples=10000): return self.test_global def test_on_server_for_all_clients(self, round_idx): + """ + Perform testing on the server side for all clients. + + Args: + round_idx (int): Current communication round. + + Returns: + bool: True if testing on the server side is performed, False otherwise. + + """ if self.trainer.test_on_the_server( self.train_data_local_dict, self.test_data_local_dict, diff --git a/python/fedml/simulation/mpi/fedgan/FedGANTrainer.py b/python/fedml/simulation/mpi/fedgan/FedGANTrainer.py index d9caae6958..69f7d9324d 100644 --- a/python/fedml/simulation/mpi/fedgan/FedGANTrainer.py +++ b/python/fedml/simulation/mpi/fedgan/FedGANTrainer.py @@ -2,6 +2,33 @@ class FedGANTrainer(object): + """ + Trainer for a federated GAN client. + + Args: + client_index (int): Index of the client. + train_data_local_dict (dict): Dictionary of local training datasets. + train_data_local_num_dict (dict): Dictionary of local training dataset sizes. + test_data_local_dict (dict): Dictionary of local test datasets. + train_data_num (int): Number of samples in the global training dataset. + device: Device for training (e.g., 'cuda' or 'cpu'). + args: Configuration arguments. + model_trainer: Trainer for the GAN model. 
+ + Attributes: + trainer: Trainer for the GAN model. + client_index (int): Index of the client. + train_data_local_dict (dict): Dictionary of local training datasets. + train_data_local_num_dict (dict): Dictionary of local training dataset sizes. + test_data_local_dict (dict): Dictionary of local test datasets. + all_train_data_num (int): Number of samples in the global training dataset. + train_local: Local training dataset. + local_sample_number: Number of samples in the local training dataset. + test_local: Local test dataset. + device: Device for training (e.g., 'cuda' or 'cpu'). + args: Configuration arguments. + """ + def __init__( self, client_index, @@ -14,7 +41,6 @@ def __init__( model_trainer, ): self.trainer = model_trainer - self.client_index = client_index self.train_data_local_dict = train_data_local_dict self.train_data_local_num_dict = train_data_local_num_dict @@ -23,28 +49,59 @@ def __init__( self.train_local = None self.local_sample_number = None self.test_local = None - self.device = device self.args = args def update_model(self, weights): + """ + Update the model with new weights. + + Args: + weights: New model weights. + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """ + Update the client's dataset. + + Args: + client_index (int): Index of the client. + """ self.client_index = client_index self.train_local = self.train_data_local_dict[client_index] self.local_sample_number = self.train_data_local_num_dict[client_index] - # self.test_local = self.test_data_local_dict[client_index] def train(self, round_idx=None): + """ + Train the client's GAN model. + + Args: + round_idx: Index of the training round (optional). + + Returns: + weights: Updated model weights. + local_sample_number: Number of samples in the local dataset. + """ self.args.round_idx = round_idx self.trainer.train(self.train_local, self.device, self.args) - weights = self.trainer.get_model_params() return weights, self.local_sample_number def test(self): - # train data + """ + Test the client's GAN model on both training and test datasets. + + Returns: + Tuple containing: + - train_tot_correct: Total correct predictions on the training dataset. + - train_loss: Loss on the training dataset. + - train_num_sample: Number of samples in the training dataset. + - test_tot_correct: Total correct predictions on the test dataset. + - test_loss: Loss on the test dataset. + - test_num_sample: Number of samples in the test dataset. + """ + # Train data metrics train_metrics = self.trainer.test(self.train_local, self.device, self.args) train_tot_correct, train_num_sample, train_loss = ( train_metrics["test_correct"], @@ -52,7 +109,7 @@ def test(self): train_metrics["test_loss"], ) - # test data + # Test data metrics test_metrics = self.trainer.test(self.test_local, self.device, self.args) test_tot_correct, test_num_sample, test_loss = ( test_metrics["test_correct"], @@ -68,3 +125,4 @@ def test(self): test_loss, test_num_sample, ) + diff --git a/python/fedml/simulation/mpi/fedgan/FedGanAPI.py b/python/fedml/simulation/mpi/fedgan/FedGanAPI.py index eda4b5ff75..8f498adb41 100644 --- a/python/fedml/simulation/mpi/fedgan/FedGanAPI.py +++ b/python/fedml/simulation/mpi/fedgan/FedGanAPI.py @@ -8,6 +8,12 @@ def FedML_init(): + """ + Initialize the MPI communication and return necessary information. + + Returns: + tuple: A tuple containing the MPI communication object, process ID, and worker number. 
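+
+    Example:
+        ```
+        comm, process_id, worker_number = FedML_init()
+        ```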
+ """ comm = MPI.COMM_WORLD process_id = comm.Get_rank() worker_number = comm.Get_size() @@ -27,6 +33,21 @@ def FedML_FedGan_distributed( model_trainer=None, preprocessed_sampling_lists=None, ): + """ + Initialize and run the Federated GAN distributed training. + + Args: + args: Configuration arguments. + process_id (int): The process ID of the current worker. + worker_number (int): Total number of workers. + device: Torch device for computation (e.g., 'cuda' or 'cpu'). + comm: MPI communication object. + model: GAN model to be trained. + dataset: Dataset information including training and testing data. + model_trainer: Model trainer object for training and testing. + preprocessed_sampling_lists: Preprocessed client sampling lists. + + """ [ train_data_num, test_data_num, @@ -92,6 +113,26 @@ def init_server( model_trainer, preprocessed_sampling_lists=None, ): + """ + Initialize the server for Federated GAN training. + + Args: + args: Configuration arguments. + device: Torch device for computation (e.g., 'cuda' or 'cpu'). + comm: MPI communication object. + rank (int): Rank of the current process. + size (int): Total number of processes. + model: GAN model to be trained. + train_data_num: Total number of training samples. + train_data_global: Global training dataset. + test_data_global: Global testing dataset. + train_data_local_dict: Dictionary of local training datasets for each worker. + test_data_local_dict: Dictionary of local testing datasets for each worker. + train_data_local_num_dict: Dictionary of the number of local training samples for each worker. + model_trainer: Model trainer object for training and testing. + preprocessed_sampling_lists: Preprocessed client sampling lists. + + """ if model_trainer is None: pass @@ -148,6 +189,23 @@ def init_client( test_data_local_dict, model_trainer=None, ): + """ + Initialize a client for Federated GAN training. + + Args: + args: Configuration arguments. + device: Torch device for computation (e.g., 'cuda' or 'cpu'). + comm: MPI communication object. + process_id (int): The process ID of the current client. + size (int): Total number of processes. + model: GAN model to be trained. + train_data_num: Total number of training samples. + train_data_local_num_dict: Dictionary of the number of local training samples for each worker. + train_data_local_dict: Dictionary of local training datasets for each worker. + test_data_local_dict: Dictionary of local testing datasets for each worker. + model_trainer: Model trainer object for training and testing. + + """ client_index = process_id - 1 model_trainer.set_id(client_index) diff --git a/python/fedml/simulation/mpi/fedgan/FedGanClientManager.py b/python/fedml/simulation/mpi/fedgan/FedGanClientManager.py index df8dcc55bd..ad5385b561 100644 --- a/python/fedml/simulation/mpi/fedgan/FedGanClientManager.py +++ b/python/fedml/simulation/mpi/fedgan/FedGanClientManager.py @@ -7,6 +7,23 @@ class FedGANClientManager(FedMLCommManager): + """ + Manager for Federated GAN client-side operations. + + Args: + args: Configuration arguments. + trainer: Model trainer for local training. + comm: MPI communication object. + rank (int): Rank of the current process. + size (int): Total number of processes. + backend (str): Backend for communication (e.g., 'MPI'). + + Attributes: + trainer: Model trainer for local training. + num_rounds: Number of communication rounds. + args.round_idx: Current communication round index. 
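+
+    Example:
+        A minimal sketch; ``trainer`` is assumed to be the ``FedGANTrainer``
+        instance built by the API layer:
+
+        ```
+        client_manager = FedGANClientManager(args, trainer, comm, rank, size)
+        client_manager.run()
+        ```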
+ """ + def __init__(self, args, trainer, comm=None, rank=0, size=0, backend="MPI"): super().__init__(args, comm, rank, size, backend) self.trainer = trainer @@ -14,9 +31,15 @@ def __init__(self, args, trainer, comm=None, rank=0, size=0, backend="MPI"): self.args.round_idx = 0 def run(self): + """ + Start the client manager's execution. + """ super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for initialization and model updates. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -26,6 +49,12 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): + """ + Handle the initialization message from the server. + + Args: + msg_params (dict): Message parameters containing model parameters and client index. + """ global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -35,10 +64,19 @@ def handle_message_init(self, msg_params): self.__train() def start_training(self): + """ + Start the client training. + """ self.args.round_idx = 0 self.__train() def handle_message_receive_model_from_server(self, msg_params): + """ + Handle the received model update message from the server. + + Args: + msg_params (dict): Message parameters containing model parameters and client index. + """ logging.info("handle_message_receive_model_from_server.") model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -52,6 +90,14 @@ def handle_message_receive_model_from_server(self, msg_params): self.finish() def send_model_to_server(self, receive_id, weights, local_sample_num): + """ + Send the local model to the server. + + Args: + receive_id (int): ID of the server receiving the model. + weights: Model weights to be sent. + local_sample_num: Number of local samples used for training. + """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), @@ -62,6 +108,9 @@ def send_model_to_server(self, receive_id, weights, local_sample_num): self.send_message(message) def __train(self): + """ + Perform the local training and send the updated model to the server. + """ logging.info("#######training########### round_id = %d" % self.args.round_idx) weights, local_sample_num = self.trainer.train(self.args.round_idx) self.send_model_to_server(0, weights, local_sample_num) diff --git a/python/fedml/simulation/mpi/fedgan/FedGanServerManager.py b/python/fedml/simulation/mpi/fedgan/FedGanServerManager.py index 15b8dc7390..5088f24cf1 100644 --- a/python/fedml/simulation/mpi/fedgan/FedGanServerManager.py +++ b/python/fedml/simulation/mpi/fedgan/FedGanServerManager.py @@ -7,6 +7,28 @@ class FedGANServerManager(FedMLCommManager): + """ + Manager for Federated GAN server-side operations. + + Args: + args: Configuration arguments. + aggregator: Aggregator for model updates. + comm: MPI communication object. + rank (int): Rank of the current process. + size (int): Total number of processes. + backend (str): Backend for communication (e.g., 'MPI'). + is_preprocessed (bool): Indicates if client sampling is preprocessed. + preprocessed_client_lists (list): Preprocessed client sampling lists. + + Attributes: + args: Configuration arguments. + aggregator: Aggregator for model updates. + round_num: Number of communication rounds. + args.round_idx: Current communication round index. 
+ is_preprocessed: Indicates if client sampling is preprocessed. + preprocessed_client_lists: Preprocessed client sampling lists. + """ + def __init__( self, args, @@ -27,10 +49,16 @@ def __init__( self.preprocessed_client_lists = preprocessed_client_lists def run(self): + """ + Start the server manager's execution. + """ super().run() def send_init_msg(self): - # sampling clients + """ + Send initialization message to clients, including global model parameters and client indexes. + """ + # Sampling clients client_indexes = self.aggregator.client_sampling( self.args.round_idx, self.args.client_num_in_total, @@ -43,12 +71,21 @@ def send_init_msg(self): ) def register_message_receive_handlers(self): + """ + Register message receive handlers for receiving model updates from clients. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.handle_message_receive_model_from_client, ) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle the received model update message from a client. + + Args: + msg_params (dict): Message parameters containing sender ID, model parameters, and local sample count. + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) @@ -62,7 +99,7 @@ def handle_message_receive_model_from_client(self, msg_params): global_model_params = self.aggregator.aggregate() # self.aggregator.test_on_server_for_all_clients(self.args.round_idx) - # start the next round + # Start the next round self.args.round_idx += 1 if self.args.round_idx == self.round_num: # post_complete_message_to_sweep_process(self.args) @@ -71,12 +108,12 @@ def handle_message_receive_model_from_client(self, msg_params): return if self.is_preprocessed: if self.preprocessed_client_lists is None: - # sampling has already been done in data preprocessor + # Sampling has already been done in data preprocessor client_indexes = [self.args.round_idx] * self.args.client_num_per_round else: client_indexes = self.preprocessed_client_lists[self.args.round_idx] else: - # sampling clients + # Sampling clients client_indexes = self.aggregator.client_sampling( self.args.round_idx, self.args.client_num_in_total, @@ -92,6 +129,14 @@ def handle_message_receive_model_from_client(self, msg_params): ) def send_message_init_config(self, receive_id, global_model_params, client_index): + """ + Send initialization configuration message to a client. + + Args: + receive_id (int): ID of the client receiving the configuration. + global_model_params: Global model parameters. + client_index: Index of the client. + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) @@ -102,6 +147,14 @@ def send_message_init_config(self, receive_id, global_model_params, client_index def send_message_sync_model_to_client( self, receive_id, global_model_params, client_index ): + """ + Send a model synchronization message to a client. + + Args: + receive_id (int): ID of the client receiving the model. + global_model_params: Global model parameters. + client_index: Index of the client. + """ logging.info("send_message_sync_model_to_client. 
receive_id = %d" % receive_id) message = Message( MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, diff --git a/python/fedml/simulation/mpi/fedgan/gan_trainer.py b/python/fedml/simulation/mpi/fedgan/gan_trainer.py index 3ecb788a26..bad160fcb0 100644 --- a/python/fedml/simulation/mpi/fedgan/gan_trainer.py +++ b/python/fedml/simulation/mpi/fedgan/gan_trainer.py @@ -8,22 +8,57 @@ class GANTrainer(ClientTrainer): + """ + Trainer for a Generative Adversarial Network (GAN) client. + + Args: + netd: Discriminator network. + netg: Generator network. + + Attributes: + netg: Generator network. + netd: Discriminator network. + """ + def __init__(self, netd, netg): self.netg = netg self.netd = netd super(GANTrainer, self).__init__(model=None, args=None) def get_model_params(self): + """ + Get the parameters of the generator and discriminator networks. + + Returns: + dict: Dictionary containing the state dictionaries of the generator and discriminator networks. + """ weights_d = self.netd.cpu().state_dict() weights_g = self.netg.cpu().state_dict() weights = {"netg": weights_g, "netd": weights_d} return weights def set_model_params(self, model_parameters): + """ + Set the parameters of the generator and discriminator networks. + + Args: + model_parameters (dict): Dictionary containing the state dictionaries of the generator and discriminator networks. + """ self.netg.load_state_dict(model_parameters["netg"]) self.netd.load_state_dict(model_parameters["netd"]) def train(self, train_data, device, args): + """ + Train the generator and discriminator networks of the GAN. + + Args: + train_data: Training data for the GAN. + device: Device for training (e.g., 'cuda' or 'cpu'). + args: Configuration arguments for training. + + Returns: + None + """ netg = self.netg netd = self.netd @@ -32,7 +67,7 @@ def train(self, train_data, device, args): netd.to(device) netd.train() - criterion = nn.BCELoss() # pylint: disable=E1102 + criterion = nn.BCELoss() # Binary Cross-Entropy Loss optimizer_g = torch.optim.Adam(netg.parameters(), lr=args.lr) optimizer_d = torch.optim.Adam(netd.parameters(), lr=args.lr) @@ -43,29 +78,28 @@ def train(self, train_data, device, args): batch_d_loss = [] batch_g_loss = [] for batch_idx, (x, _) in enumerate(train_data): - # logging.info(batch_idx) - # logging.info(x.shape) if len(x) < 2: continue x = x.to(device) real_labels = torch.ones(x.size(0), 1).to(device) fake_labels = torch.zeros(x.size(0), 1).to(device) optimizer_d.zero_grad() - d_real_loss = criterion(netd(x), real_labels) # pylint: disable=E1102 + d_real_loss = criterion(netd(x), real_labels) noise = torch.randn(x.size(0), 100).to(device) - d_fake_loss = criterion(netd(netg(noise)), fake_labels) # pylint: disable=E1102 + d_fake_loss = criterion(netd(netg(noise)), fake_labels) d_loss = d_real_loss + d_fake_loss d_loss.backward() optimizer_d.step() noise = torch.randn(x.size(0), 100).to(device) optimizer_g.zero_grad() - g_loss = criterion(netd(netg(noise)), real_labels) # pylint: disable=E1102 + g_loss = criterion(netd(netg(noise)), real_labels) g_loss.backward() optimizer_g.step() batch_d_loss.append(d_loss.item()) batch_g_loss.append(g_loss.item()) + if len(batch_g_loss) > 0: epoch_g_loss.append(sum(batch_g_loss) / len(batch_g_loss)) epoch_d_loss.append(sum(batch_d_loss) / len(batch_d_loss)) @@ -81,7 +115,7 @@ def train(self, train_data, device, args): ) netg.eval() z = torch.randn(100, 100).to(device) - y_hat = netg(z).view(100, 28, 28) # (100, 28, 28) + y_hat = netg(z).view(100, 28, 28) result = y_hat.cpu().data.numpy() img = 
np.zeros([280, 280]) for j in range(10): @@ -89,8 +123,20 @@ def train(self, train_data, device, args): [x for x in result[j * 10: (j + 1) * 10]], axis=-1 ) + # Save generated images if needed # imsave("samples/{}_{}.jpg".format(self.id, epoch), img, cmap="gray") netg.train() def test(self, test_data, device, args): - pass + """ + Test the GAN model. + + Args: + test_data: Test data for the GAN. + device: Device for testing (e.g., 'cuda' or 'cpu'). + args: Configuration arguments for testing. + + Returns: + None + """ + pass # Testing is not implemented in this trainer diff --git a/python/fedml/simulation/mpi/fedgan/utils.py b/python/fedml/simulation/mpi/fedgan/utils.py index 195d130aea..e5edfa8ed9 100644 --- a/python/fedml/simulation/mpi/fedgan/utils.py +++ b/python/fedml/simulation/mpi/fedgan/utils.py @@ -5,6 +5,15 @@ def transform_list_to_tensor(model_params_list): + """ + Convert a dictionary of model parameters from NumPy arrays to PyTorch tensors. + + Args: + model_params_list (dict): A dictionary containing model parameters. + + Returns: + dict: A dictionary with model parameters converted to PyTorch tensors. + """ for net in model_params_list.keys(): for k in model_params_list[net].keys(): model_params_list[net][k] = torch.from_numpy( @@ -14,6 +23,15 @@ def transform_list_to_tensor(model_params_list): def transform_tensor_to_list(model_params): + """ + Convert a dictionary of model parameters from PyTorch tensors to NumPy arrays. + + Args: + model_params (dict): A dictionary containing model parameters as PyTorch tensors. + + Returns: + dict: A dictionary with model parameters converted to NumPy arrays. + """ for net in model_params.keys(): for k in model_params[net].keys(): model_params[net][k] = model_params[net][k].detach().numpy().tolist() @@ -21,10 +39,19 @@ def transform_tensor_to_list(model_params): def post_complete_message_to_sweep_process(args): + """ + Post a completion message to a named pipe. + + Args: + args: Information or data to be included in the completion message. + + Returns: + None + """ pipe_path = "./tmp/fedml" if not os.path.exists(pipe_path): os.mkfifo(pipe_path) pipe_fd = os.open(pipe_path, os.O_WRONLY) with os.fdopen(pipe_fd, "w") as pipe: - pipe.write("training is finished! \n%s\n" % (str(args))) + pipe.write("Training is finished! \n%s\n" % (str(args))) diff --git a/python/fedml/simulation/mpi/fedgkt/FedGKTAPI.py b/python/fedml/simulation/mpi/fedgkt/FedGKTAPI.py index 9c4916a337..8d1fb71caa 100644 --- a/python/fedml/simulation/mpi/fedgkt/FedGKTAPI.py +++ b/python/fedml/simulation/mpi/fedgkt/FedGKTAPI.py @@ -7,6 +7,14 @@ def FedML_init(): + """ + Initialize the Federated Learning environment. + + Returns: + comm: The MPI communication object. + process_id: The ID of the current process. + worker_number: The total number of worker processes. + """ comm = MPI.COMM_WORLD process_id = comm.Get_rank() worker_number = comm.Get_size() @@ -22,6 +30,21 @@ def FedML_FedGKT_distributed( dataset, args, ): + """ + Perform Federated Knowledge Transfer (FedGKT) in a distributed setting. + + Args: + process_id: The ID of the current process. + worker_number: The total number of worker processes. + device: The device (e.g., CPU or GPU) for training. + comm: The MPI communication object. + model: A tuple containing client and server models. + dataset: A list containing dataset-related information. + args: Additional arguments and settings. 
+ + Returns: + None + """ [ train_data_num, test_data_num, @@ -50,12 +73,26 @@ def FedML_FedGKT_distributed( def init_server(args, device, comm, rank, size, model): + """ + Initialize the Federated Knowledge Transfer (FedGKT) server. + + Args: + args: Additional arguments and settings. + device: The device (e.g., CPU or GPU) for training. + comm: The MPI communication object. + rank: The rank of the current process. + size: The total number of processes. + model: The server model for FedGKT. + + Returns: + None + """ # aggregator client_num = size - 1 server_trainer = GKTServerTrainer(client_num, device, model, args) # start the distributed training - server_manager = GKTServerMananger(args, server_trainer, comm, rank, size) + server_manager = GKTServerManager(args, server_trainer, comm, rank, size) server_manager.run() @@ -70,6 +107,23 @@ def init_client( test_data_local_dict, train_data_local_num_dict, ): + """ + Initialize a FedGKT client. + + Args: + args: Additional arguments and settings. + device: The device (e.g., CPU or GPU) for training. + comm: The MPI communication object. + process_id: The ID of the current process. + size: The total number of processes. + model: The client model for FedGKT. + train_data_local_dict: A dictionary of local training data. + test_data_local_dict: A dictionary of local testing data. + train_data_local_num_dict: A dictionary of the number of local training samples. + + Returns: + None + """ client_ID = process_id - 1 # 2. initialize the trainer @@ -84,5 +138,5 @@ def init_client( ) # 3. start the distributed training - client_manager = GKTClientMananger(args, trainer, comm, process_id, size) + client_manager = GKTClientManager(args, trainer, comm, process_id, size) client_manager.run() diff --git a/python/fedml/simulation/mpi/fedgkt/GKTClientManager.py b/python/fedml/simulation/mpi/fedgkt/GKTClientManager.py index befc3a5618..77c76b706d 100644 --- a/python/fedml/simulation/mpi/fedgkt/GKTClientManager.py +++ b/python/fedml/simulation/mpi/fedgkt/GKTClientManager.py @@ -5,18 +5,80 @@ from ....core.distributed.communication.message import Message -class GKTClientMananger(FedMLCommManager): +class GKTClientManager(FedMLCommManager): + """ + A class representing the client-side manager for Global Knowledge Transfer (GKT). + + This manager is responsible for coordinating communication between the client and the server + during the GKT training process. + + Args: + args (argparse.Namespace): Additional arguments and settings. + trainer (GKTClientTrainer): The client-side trainer responsible for training the client model. + comm (MPI.Comm): MPI communication object. + rank (int): The rank or identifier of the client process. + size (int): The total number of processes in the communication group. + backend (str): The MPI backend for communication (default is "MPI"). + + Attributes: + args (argparse.Namespace): Additional arguments and settings. + trainer (GKTClientTrainer): The client-side trainer responsible for training the client model. + num_rounds (int): The total number of communication rounds. + device (torch.device): The device (e.g., GPU) used for training. + args.round_idx (int): The current round index. + + Methods: + run(): Start the client manager to initiate communication and training. + register_message_receive_handlers(): Register message receive handlers for communication. + handle_message_init(msg_params): Handle the initialization message from the server. 
+ handle_message_receive_logits_from_server(msg_params): Handle logits received from the server. + send_model_to_server(extracted_feature_dict, logits_dict, labels_dict, extracted_feature_dict_test, labels_dict_test): + Send extracted features, logits, and labels to the server for knowledge transfer. + __train(): Start the client model training process. + + """ def __init__(self, args, trainer, comm=None, rank=0, size=0, backend="MPI"): - super().__init__(args, comm, rank, size, backend) + """ + Initialize the GKT (Global Knowledge Transfer) client manager. + + Args: + args: Additional arguments and settings. + trainer: The GKT client trainer instance. + comm: The MPI communication object. + rank: The rank of the current process. + size: The total number of processes. + backend: The communication backend (default: "MPI"). + Returns: + None + """ + super().__init__(args, comm, rank, size, backend) self.trainer = trainer self.num_rounds = args.comm_round self.args.round_idx = 0 def run(self): + """ + Start the GKT client manager. + + Args: + None + + Returns: + None + """ super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for the GKT client manager. + + Args: + None + + Returns: + None + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -26,11 +88,29 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): + """ + Handle the initialization message from the server. + + Args: + msg_params: Parameters from the received message. + + Returns: + None + """ logging.info("handle_message_init. Rank = " + str(self.rank)) self.args.round_idx = 0 self.__train() def handle_message_receive_logits_from_server(self, msg_params): + """ + Handle the message containing logits from the server. + + Args: + msg_params: Parameters from the received message. + + Returns: + None + """ logging.info( "handle_message_receive_logits_from_server. Rank = " + str(self.rank) ) @@ -50,6 +130,20 @@ def send_model_to_server( extracted_feature_dict_test, labels_dict_test, ): + """ + Send extracted features, logits, and labels to the server. + + Args: + receive_id: The ID of the recipient (usually the server). + extracted_feature_dict: A dictionary of extracted features. + logits_dict: A dictionary of logits. + labels_dict: A dictionary of labels. + extracted_feature_dict_test: A dictionary of extracted features for testing. + labels_dict_test: A dictionary of labels for testing. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_FEATURE_AND_LOGITS, self.get_sender_id(), @@ -65,6 +159,15 @@ def send_model_to_server( self.send_message(message) def __train(self): + """ + Perform the training process for the GKT client. + + Args: + None + + Returns: + None + """ logging.info("#######training########### round_id = %d" % self.args.round_idx) ( extracted_feature_dict, diff --git a/python/fedml/simulation/mpi/fedgkt/GKTClientTrainer.py b/python/fedml/simulation/mpi/fedgkt/GKTClientTrainer.py index 927cb0b63f..fc969119f5 100644 --- a/python/fedml/simulation/mpi/fedgkt/GKTClientTrainer.py +++ b/python/fedml/simulation/mpi/fedgkt/GKTClientTrainer.py @@ -7,6 +7,40 @@ class GKTClientTrainer(object): + """ + A class representing the client-side trainer for Global Knowledge Transfer (GKT). + + This trainer is responsible for training a client model and exchanging knowledge with the server. + + Args: + client_index (int): The index of the client. 
+ local_training_data (list): Local training data for the client. + local_test_data (list): Local test data for the client. + local_sample_number (int): The number of local training samples. + device (torch.device): The device (e.g., GPU) on which the client model is located. + client_model (torch.nn.Module): The client model. + args (argparse.Namespace): Additional arguments and settings. + + Attributes: + client_index (int): The index of the client. + local_training_data (list): Local training data for the client. + local_test_data (list): Local test data for the client. + local_sample_number (int): The number of local training samples. + args (argparse.Namespace): Additional arguments and settings. + device (torch.device): The device (e.g., GPU) on which the client model is located. + client_model (torch.nn.Module): The client model. + model_params (iterable): The parameters of the client model. + master_params (iterable): The master parameters of the client model. + optimizer (torch.optim.Optimizer): The optimizer used for training. + criterion_CE (torch.nn.CrossEntropyLoss): The cross-entropy loss criterion. + criterion_KL (KL_Loss): The KL divergence loss criterion for knowledge distillation. + server_logits_dict (dict): A dictionary to store logits received from the server. + + Methods: + get_sample_number(): Get the number of local training samples. + update_large_model_logits(logits): Update the logits received from the server. + train(): Train the client model and return extracted features, logits, and labels for training and test data. + """ def __init__( self, client_index, @@ -17,6 +51,21 @@ def __init__( client_model, args, ): + """ + Initialize the GKT (Global Knowledge Transfer) client trainer. + + Args: + client_index (int): The index of the client. + local_training_data (list): Local training data for the client. + local_test_data (list): Local test data for the client. + local_sample_number (int): The number of local training samples. + device (torch.device): The device (e.g., GPU) on which the client model is located. + client_model (torch.nn.Module): The client model. + args (argparse.Namespace): Additional arguments and settings. + + Returns: + None + """ self.client_index = client_index self.local_training_data = local_training_data[client_index] self.local_test_data = local_test_data[client_index] @@ -60,12 +109,37 @@ def __init__( self.server_logits_dict = dict() def get_sample_number(self): + """ + Get the number of local training samples. + + Returns: + int: The number of local training samples. + """ return self.local_sample_number def update_large_model_logits(self, logits): + """ + Update the logits received from the server. + + Args: + logits (dict): Logits received from the server. + + Returns: + None + """ self.server_logits_dict = logits def train(self): + """ + Train the client model. + + Returns: + dict: Extracted features for training data. + dict: Logits for training data. + dict: Labels for training data. + dict: Extracted features for test data. + dict: Labels for test data. 
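+
+        Example:
+            A sketch of unpacking the five return values; ``trainer`` is a
+            hypothetical ``GKTClientTrainer`` instance:
+
+            ```
+            (feat_train, logits_train, labels_train,
+             feat_test, labels_test) = trainer.train()
+            ```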
+ """ # key: batch_index; value: extracted_feature_map extracted_feature_dict = dict() diff --git a/python/fedml/simulation/mpi/fedgkt/GKTServerManager.py b/python/fedml/simulation/mpi/fedgkt/GKTServerManager.py index bad4f0be1c..309771d803 100644 --- a/python/fedml/simulation/mpi/fedgkt/GKTServerManager.py +++ b/python/fedml/simulation/mpi/fedgkt/GKTServerManager.py @@ -5,13 +5,51 @@ class GKTServerMananger(FedMLCommManager): + """ + Manager class for the server in the Global Knowledge Transfer (GKT) framework. + + This class handles communication and coordination between the server and clients in the GKT framework. + + Args: + args: Additional arguments and settings. + server_trainer: The server trainer responsible for aggregating client updates. + comm: MPI communication object. + rank (int): Rank of the server process. + size (int): Total number of processes. + backend (str): Backend used for communication. + + Attributes: + server_trainer: The server trainer instance. + round_num: The total number of communication rounds. + args: Additional arguments and settings. + count: A counter used for tracking communication rounds. + + Methods: + run(): Start the server manager to handle communication with clients. + register_message_receive_handlers(): Register message handlers for message types. + handle_message_receive_feature_and_logits_from_client(msg_params): Handle client messages containing feature maps, logits, and labels. + send_message_init_config(receive_id, global_model_params): Send an initialization message to a client. + send_message_sync_model_to_client(receive_id, global_logits): Send a synchronization message with global logits to a client. + """ def __init__(self, args, server_trainer, comm=None, rank=0, size=0, backend="MPI"): - super().__init__(args, comm, rank, size, backend) + """ + Initialize the GKT (Global Knowledge Transfer) server manager. + + Args: + args: Additional arguments and settings. + server_trainer: The server trainer. + comm: MPI communication object. + rank (int): Rank of the server process. + size (int): Total number of processes. + backend (str): Backend used for communication. + Returns: + None + """ + super().__init__(args, comm, rank, size, backend) self.server_trainer = server_trainer self.round_num = args.comm_round self.args.round_idx = 0 - self.count = 0 def run(self): @@ -27,6 +65,15 @@ def register_message_receive_handlers(self): ) def handle_message_receive_feature_and_logits_from_client(self, msg_params): + """ + Handle the message received from a client containing feature maps, logits, and labels. + + Args: + msg_params: Parameters received in the message. + + Returns: + None + """ logging.info("handle_message_receive_feature_and_logits_from_client") sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) extracted_feature_dict = msg_params.get(MyMessage.MSG_ARG_KEY_FEATURE) @@ -48,7 +95,7 @@ def handle_message_receive_feature_and_logits_from_client(self, msg_params): if b_all_received: self.server_trainer.train(self.args.round_idx) - # start the next round + # Start the next round self.args.round_idx += 1 if self.args.round_idx == self.round_num: self.finish() @@ -59,6 +106,16 @@ def handle_message_receive_feature_and_logits_from_client(self, msg_params): self.send_message_sync_model_to_client(receiver_id, global_logits) def send_message_init_config(self, receive_id, global_model_params): + """ + Send an initialization message to a client. + + Args: + receive_id: ID of the client to receive the message. 
+ global_model_params: Global model parameters. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) @@ -66,6 +123,16 @@ def send_message_init_config(self, receive_id, global_model_params): logging.info("send_message_init_config. Receive_id: " + str(receive_id)) def send_message_sync_model_to_client(self, receive_id, global_logits): + """ + Send a synchronization message with global logits to a client. + + Args: + receive_id: ID of the client to receive the message. + global_logits: Global logits. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_S2C_SYNC_TO_CLIENT, self.get_sender_id(), receive_id ) diff --git a/python/fedml/simulation/mpi/fedgkt/GKTServerTrainer.py b/python/fedml/simulation/mpi/fedgkt/GKTServerTrainer.py index 5bca610f5a..310cb648ba 100644 --- a/python/fedml/simulation/mpi/fedgkt/GKTServerTrainer.py +++ b/python/fedml/simulation/mpi/fedgkt/GKTServerTrainer.py @@ -10,6 +10,15 @@ class GKTServerTrainer(object): + """ + Server-side trainer for Global Knowledge Transfer (GKT) in federated learning. + + Args: + client_num (int): Number of client devices. + device (str): The device on which to perform training (e.g., 'cuda' or 'cpu'). + server_model (nn.Module): The global server model. + args (argparse.Namespace): Command-line arguments and configurations. + """ def __init__(self, client_num, device, server_model, args): self.client_num = client_num self.device = device @@ -97,6 +106,17 @@ def add_local_trained_result( extracted_feature_dict_test, labels_dict_test, ): + """ + Add local training results from a client. + + Args: + index (int): Index of the client. + extracted_feature_dict (dict): Extracted feature maps from the client model. + logits_dict (dict): Logits from the client model. + labels_dict (dict): Labels from the client model. + extracted_feature_dict_test (dict): Extracted feature maps from the client model for testing. + labels_dict_test (dict): Labels from the client model for testing. + """ logging.info("add_model. index = %d" % index) self.client_extracted_feauture_dict[index] = extracted_feature_dict self.client_logits_dict[index] = logits_dict @@ -107,6 +127,12 @@ def add_local_trained_result( self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check whether all client models have uploaded updates. + + Returns: + bool: True if all clients have uploaded updates, False otherwise. + """ for idx in range(self.client_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -115,9 +141,24 @@ def check_whether_all_receive(self): return True def get_global_logits(self, client_index): + """ + Get global logits for a specific client. + + Args: + client_index (int): Index of the client. + + Returns: + dict: Global logits for the client. + """ return self.server_logits_dict[client_index] def train(self, round_idx): + """ + Train the server model using client updates. + + Args: + round_idx (int): Current communication round index. + """ if self.args.sweep == 1: self.sweep(round_idx) else: @@ -127,6 +168,12 @@ def train(self, round_idx): self.do_not_train_on_client(round_idx) def train_and_distill_on_client(self, round_idx): + """ + Train the server model on the client using distillation from client logits. + + Args: + round_idx (int): Current communication round index. 
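+
+        Note:
+            Conceptually, each server epoch minimizes a cross-entropy term on
+            the true labels plus a KL distillation term against the uploaded
+            client logits; a sketch, where ``criterion_CE``, ``criterion_KL``
+            and ``alpha`` are hypothetical stand-ins for the trainer's loss
+            criteria and weighting factor:
+
+            ```
+            loss = criterion_CE(server_out, labels) + alpha * criterion_KL(server_out, client_logits)
+            ```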
+ """ if self.args.test: epochs_server, whether_distill_back = self.get_server_epoch_strategy_test() else: @@ -146,21 +193,48 @@ def train_and_distill_on_client(self, round_idx): self.scheduler.step(self.best_acc, epoch=round_idx) def do_not_train_on_client(self, round_idx): + """ + Perform no training on the client model, only evaluation. + + Args: + round_idx (int): Current communication round index. + """ self.train_and_eval(round_idx, 1) self.scheduler.step(self.best_acc, epoch=round_idx) def sweep(self, round_idx): + """ + Perform sweeping training on the client model. + + Args: + round_idx (int): Current communication round index. + """ # train according to the logits from the client self.train_and_eval(round_idx, self.args.epochs_server) self.scheduler.step(self.best_acc, epoch=round_idx) def get_server_epoch_strategy_test(self): + """ + Get the training strategy for server epoch in the test mode. + + Returns: + tuple: Tuple containing the number of epochs (1) and whether to distill back (True). + """ return 1, True # ResNet56 def get_server_epoch_strategy_reset56(self, round_idx): + """ + Get the training strategy for server epoch in the ResNet56 client model. + + Args: + round_idx (int): Current communication round index. + + Returns: + tuple: Tuple containing the number of epochs and whether to distill back (True/False). + """ whether_distill_back = True - # set the training strategy + # set the training strategy based on round index if round_idx < 20: epochs = 20 elif 20 <= round_idx < 30: @@ -183,6 +257,15 @@ def get_server_epoch_strategy_reset56(self, round_idx): # ResNet56-2 def get_server_epoch_strategy_reset56_2(self, round_idx): + """ + Get the training strategy for server epoch in the ResNet56-2 client model. + + Args: + round_idx (int): Current communication round index. + + Returns: + tuple: Tuple containing the number of epochs and whether to distill back (True/False). + """ whether_distill_back = True # set the training strategy epochs = self.args.epochs_server @@ -190,6 +273,15 @@ def get_server_epoch_strategy_reset56_2(self, round_idx): # not increase after 40 epochs def get_server_epoch_strategy2(self, round_idx): + """ + Determine the training strategy (number of epochs and distillation) for the server model. + + Args: + round_idx (int): Current communication round index. + + Returns: + tuple: Tuple containing the number of epochs and whether to distill back (True/False). + """ whether_distill_back = True # set the training strategy if round_idx < 20: @@ -213,6 +305,13 @@ def get_server_epoch_strategy2(self, round_idx): return epochs, whether_distill_back def train_and_eval(self, round_idx, epochs): + """ + Train and evaluate the server model for a specified number of epochs. + + Args: + round_idx (int): Current communication round index. + epochs (int): Number of epochs to train for. + """ for epoch in range(epochs): logging.info( "train_and_eval. round_idx = %d, epoch = %d" % (round_idx, epoch) @@ -295,6 +394,12 @@ def train_and_eval(self, round_idx, epochs): ) def train_large_model_on_the_server(self): + """ + Train the server model using client features and logits. + + Returns: + dict: Dictionary containing training metrics (loss, accuracy). + """ # clear the server side logits for key in self.server_logits_dict.keys(): @@ -371,6 +476,12 @@ def train_large_model_on_the_server(self): return train_metrics def eval_large_model_on_the_server(self): + """ + Evaluate the server model on the test dataset provided by clients. 
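+
+        The global model is switched to evaluation mode and the metrics are
+        accumulated over the feature maps and labels uploaded by the clients.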
+ + Returns: + dict: Dictionary containing test metrics (loss, accuracy). + """ # set model to evaluation mode self.model_global.eval() diff --git a/python/fedml/simulation/mpi/fedgkt/utils.py b/python/fedml/simulation/mpi/fedgkt/utils.py index fe2ae83878..5f7feb88de 100644 --- a/python/fedml/simulation/mpi/fedgkt/utils.py +++ b/python/fedml/simulation/mpi/fedgkt/utils.py @@ -7,6 +7,15 @@ def get_state_dict(file): + """ + Load a PyTorch state dictionary from a file. + + Args: + file (str): The path to the file containing the state dictionary. + + Returns: + dict: The loaded state dictionary. + """ try: pretrain_state_dict = torch.load(file) except AssertionError: @@ -15,8 +24,16 @@ def get_state_dict(file): ) return pretrain_state_dict - def get_flat_params_from(model): + """ + Get a flat tensor containing all the parameters of a PyTorch model. + + Args: + model (nn.Module): The PyTorch model. + + Returns: + torch.Tensor: A 1D tensor containing the flattened parameters. + """ params = [] for param in model.parameters(): params.append(param.data.view(-1)) @@ -24,8 +41,14 @@ def get_flat_params_from(model): flat_params = torch.cat(params) return flat_params - def set_flat_params_to(model, flat_params): + """ + Set the parameters of a PyTorch model using a flat tensor of parameters. + + Args: + model (nn.Module): The PyTorch model. + flat_params (torch.Tensor): A 1D tensor containing the flattened parameters. + """ prev_ind = 0 for param in model.parameters(): flat_size = int(np.prod(list(param.size()))) @@ -35,32 +58,59 @@ def set_flat_params_to(model, flat_params): prev_ind += flat_size + class RunningAverage: - """A simple class that maintains the running average of a quantity + """ + A simple class that maintains the running average of a quantity Example: - ``` - loss_avg = RunningAverage() - loss_avg.update(2) - loss_avg.update(4) - loss_avg() = 3 - ``` + ``` + loss_avg = RunningAverage() + loss_avg.update(2) + loss_avg.update(4) + loss_avg() = 3.0 + ``` + + Attributes: + steps (int): The number of updates made to the running average. + total (float): The cumulative sum of values for the running average. """ def __init__(self): + """ + Initialize a RunningAverage object. + """ self.steps = 0 self.total = 0 def update(self, val): + """Update the running average with a new value. + + Args: + val (float): The new value to update the running average. + """ self.total += val self.steps += 1 def value(self): - return self.total / float(self.steps) + """Get the current value of the running average. + Returns: + float: The current running average value. + """ + return self.total / float(self.steps) def accuracy(output, target, topk=(1,)): - """Computes the precision@k for the specified values of k""" + """Computes the precision@k for the specified values of k. + + Args: + output (torch.Tensor): The model's output tensor. + target (torch.Tensor): The target tensor. + topk (tuple): A tuple of integers specifying the top-k values to compute. + + Returns: + list: A list of accuracy values for each k in topk. + """ maxk = max(topk) batch_size = target.size(0) @@ -76,17 +126,44 @@ def accuracy(output, target, topk=(1,)): class KL_Loss(nn.Module): + """ + Kullback-Leibler (KL) Divergence Loss with Temperature Scaling. + + This class represents the KL divergence loss with an optional temperature + scaling parameter for softening the logits. It is commonly used in knowledge + distillation between a student and a teacher model. 
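+
+    Concretely, ``forward`` returns
+    ``T**2 * KL(softmax(teacher / T) || softmax(student / T))``; the ``T**2``
+    factor keeps gradient magnitudes comparable across temperatures.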
+ + Args: + temperature (float, optional): The temperature parameter for softening + the logits (default is 1). + + Attributes: + T (float): The temperature parameter for temperature scaling. + + """ + def __init__(self, temperature=1): + """ + Initialize the KL Divergence Loss. + + Args: + temperature (float, optional): The temperature parameter for softening + the logits (default is 1). + + """ super(KL_Loss, self).__init__() self.T = temperature def forward(self, output_batch, teacher_outputs): - # output_batch -> B X num_classes - # teacher_outputs -> B X num_classes + """Compute the KL divergence loss between output_batch and teacher_outputs. - # loss_2 = -torch.sum(torch.sum(torch.mul(F.log_softmax(teacher_outputs,dim=1), F.softmax(teacher_outputs,dim=1)+10**(-7))))/teacher_outputs.size(0) - # print('loss H:',loss_2) + Args: + output_batch (torch.Tensor): The output tensor from the student model. + teacher_outputs (torch.Tensor): The output tensor from the teacher model. + Returns: + torch.Tensor: The computed KL divergence loss. + """ output_batch = F.log_softmax(output_batch / self.T, dim=1) teacher_outputs = F.softmax(teacher_outputs / self.T, dim=1) + 10 ** (-7) @@ -96,20 +173,48 @@ def forward(self, output_batch, teacher_outputs): * nn.KLDivLoss(reduction="batchmean")(output_batch, teacher_outputs) ) - # Same result KL-loss implementation - # loss = T * T * torch.sum(torch.sum(torch.mul(teacher_outputs, torch.log(teacher_outputs) - output_batch)))/teacher_outputs.size(0) return loss + class CE_Loss(nn.Module): + """ + Cross-Entropy Loss with Temperature Scaling. + + This class represents the cross-entropy loss with an optional temperature + scaling parameter for softening the logits. It is commonly used in knowledge + distillation between a student and a teacher model. + + Args: + temperature (float, optional): The temperature parameter for softening + the logits (default is 1). + + Attributes: + T (float): The temperature parameter for temperature scaling. + + """ + def __init__(self, temperature=1): + """ + Initialize the Cross-Entropy (CE) Loss. + + Args: + temperature (float): The temperature parameter for softening the logits (default is 1). + + """ super(CE_Loss, self).__init__() self.T = temperature def forward(self, output_batch, teacher_outputs): - # output_batch -> B X num_classes - # teacher_outputs -> B X num_classes + """Compute the cross-entropy loss between output_batch and teacher_outputs. + + Args: + output_batch (torch.Tensor): The output tensor from the student model. + teacher_outputs (torch.Tensor): The output tensor from the teacher model. + Returns: + torch.Tensor: The computed cross-entropy loss. + """ output_batch = F.log_softmax(output_batch / self.T, dim=1) teacher_outputs = F.softmax(teacher_outputs / self.T, dim=1) @@ -123,28 +228,51 @@ def forward(self, output_batch, teacher_outputs): return loss - def save_dict_to_json(d, json_path): - """Saves dict of floats in json file + """Saves a dictionary of floats in a JSON file. Args: - d: (dict) of float-castable values (np.float, int, float, etc.) - json_path: (string) path to json file + d (dict): A dictionary of float-castable values (np.float, int, float, etc.). + json_path (str): Path to the JSON file where the dictionary will be saved. """ with open(json_path, "w") as f: - # We need to convert the values to float for json (it doesn't accept np.array, np.float, ) - d = {k: v for k, v in d.items()} + # We need to convert the values to float for JSON (it doesn't accept np.array, np.float, etc.) 
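+        # e.g. {"val_acc": np.float32(0.93)} -> {"val_acc": 0.93} (illustrative values only)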
+ d = {k: float(v) for k, v in d.items()} json.dump(d, f, indent=4) - # Filter out batch norm parameters and remove them from weight decay - gets us higher accuracy 93.2 -> 93.48 # https://arxiv.org/pdf/1807.11205.pdf def bnwd_optim_params(model, model_params, master_params): + """Split model parameters into two groups for optimization. + + This function separates model parameters into two groups: batch normalization parameters + and remaining parameters. It sets the weight decay for batch normalization parameters to 0. + + Args: + model (nn.Module): The neural network model. + model_params (list): List of model parameters. + master_params (list): List of master parameters. + + Returns: + list: List of dictionaries specifying parameter groups for optimization. + """ bn_params, remaining_params = split_bn_params(model, model_params, master_params) return [{"params": bn_params, "weight_decay": 0}, {"params": remaining_params}] - def split_bn_params(model, model_params, master_params): + """Split model parameters into batch normalization and remaining parameters. + + This function separates model parameters into two groups: batch normalization parameters + and remaining parameters. + + Args: + model (nn.Module): The neural network model. + model_params (list): List of model parameters. + master_params (list): List of master parameters. + + Returns: + tuple: Two lists containing batch normalization parameters and remaining parameters. + """ def get_bn_params(module): if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): return module.parameters() From f93ca88fbbe2b024a0e9e1688f696cb701df2dc5 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Sat, 9 Sep 2023 13:02:02 +0530 Subject: [PATCH 50/70] python\fedml\simulation\mpi\ fednas fednova fedopt --- .../fedml/simulation/mpi/fednas/FedNASAPI.py | 52 +++++- .../simulation/mpi/fednas/FedNASAggregator.py | 112 ++++++++++++ .../mpi/fednas/FedNASClientManager.py | 50 +++++- .../mpi/fednas/FedNASServerManager.py | 45 ++++- .../simulation/mpi/fednas/FedNASTrainer.py | 123 +++++++++++++ .../simulation/mpi/fednova/FedNovaAPI.py | 49 +++++ .../mpi/fednova/FedNovaAggregator.py | 167 ++++++++++++++++++ .../mpi/fednova/FedNovaClientManager.py | 113 ++++++++---- .../mpi/fednova/FedNovaServerManager.py | 79 ++++++++- .../simulation/mpi/fednova/FedNovaTrainer.py | 78 +++++++- .../my_model_trainer_classification.py | 66 ++++++- python/fedml/simulation/mpi/fednova/utils.py | 35 +++- .../fedml/simulation/mpi/fedopt/FedOptAPI.py | 63 ++++++- .../simulation/mpi/fedopt/FedOptAggregator.py | 138 +++++++++++++-- .../mpi/fedopt/FedOptClientManager.py | 31 ++++ .../mpi/fedopt/FedOptServerManager.py | 44 ++++- .../simulation/mpi/fedopt/FedOptTrainer.py | 48 ++++- python/fedml/simulation/mpi/fedopt/optrepo.py | 25 ++- python/fedml/simulation/mpi/fedopt/utils.py | 29 ++- .../mpi/fedopt_seq/FedOptAggregator.py | 167 ++++++++++++++++++ 20 files changed, 1419 insertions(+), 95 deletions(-) diff --git a/python/fedml/simulation/mpi/fednas/FedNASAPI.py b/python/fedml/simulation/mpi/fednas/FedNASAPI.py index d213473b75..1615f5e43f 100644 --- a/python/fedml/simulation/mpi/fednas/FedNASAPI.py +++ b/python/fedml/simulation/mpi/fednas/FedNASAPI.py @@ -8,12 +8,17 @@ def FedML_init(): + """ + Initialize the Federated Machine Learning environment using MPI (Message Passing Interface). + + Returns: + Tuple: A tuple containing the MPI communicator (`comm`), process ID (`process_id`), and worker number (`worker_number`). 
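+
+    Example:
+        A minimal usage sketch (the ``mpirun`` launch line is an assumption
+        for illustration; rank 0 acts as the server, as in the code below)::
+
+            # mpirun -np 3 python main.py
+            comm, process_id, worker_number = FedML_init()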
+ """ comm = MPI.COMM_WORLD process_id = comm.Get_rank() worker_number = comm.Get_size() return comm, process_id, worker_number - def FedML_FedNAS_distributed( args, process_id, @@ -25,6 +30,20 @@ def FedML_FedNAS_distributed( client_trainer: ClientTrainer = None, server_aggregator: ServerAggregator = None, ): + """ + Initialize and run the Federated NAS (Neural Architecture Search) distributed training process. + + Args: + args: Command-line arguments and configurations. + process_id (int): The process ID of the current worker. + worker_number (int): The total number of workers. + comm: The MPI communicator. + device: The device (e.g., GPU) to run the training on. + dataset: A list containing dataset information. + model: The neural network model. + client_trainer (ClientTrainer, optional): The client trainer instance. + server_aggregator (ServerAggregator, optional): The server aggregator instance. + """ [ train_data_num, test_data_num, @@ -53,10 +72,23 @@ def FedML_FedNAS_distributed( test_data_local_dict, ) - def init_server( args, device, comm, process_id, worker_number, model, train_data_num, train_data_global, test_data_global, ): + """ + Initialize and run the server component of the Federated NAS distributed training. + + Args: + args: Command-line arguments and configurations. + device: The device (e.g., GPU) to run the training on. + comm: The MPI communicator. + process_id (int): The process ID of the current worker. + worker_number (int): The total number of workers. + model: The neural network model. + train_data_num: The number of training data samples. + train_data_global: The global training data. + test_data_global: The global testing data. + """ # aggregator client_num = worker_number - 1 aggregator = FedNASAggregator(train_data_global, test_data_global, train_data_num, client_num, model, device, args,) @@ -65,7 +97,6 @@ def init_server( server_manager = FedNASServerManager(args, comm, process_id, worker_number, aggregator) server_manager.run() - def init_client( args, device, @@ -78,6 +109,21 @@ def init_client( train_data_local, test_data_local, ): + """ + Initialize and run the client component of the Federated NAS distributed training. + + Args: + args: Command-line arguments and configurations. + device: The device (e.g., GPU) to run the training on. + comm: The MPI communicator. + process_id (int): The process ID of the current worker. + worker_number (int): The total number of workers. + model: The neural network model. + train_data_num: The number of training data samples. + local_data_num: The number of local training data samples. + train_data_local: The local training data. + test_data_local: The local testing data. + """ # trainer client_ID = process_id - 1 trainer = FedNASTrainer( diff --git a/python/fedml/simulation/mpi/fednas/FedNASAggregator.py b/python/fedml/simulation/mpi/fednas/FedNASAggregator.py index 988ae59e23..301bb8b7e8 100644 --- a/python/fedml/simulation/mpi/fednas/FedNASAggregator.py +++ b/python/fedml/simulation/mpi/fednas/FedNASAggregator.py @@ -7,6 +7,38 @@ class FedNASAggregator(object): + """ + A class responsible for aggregating model parameters and architectures from multiple clients. + + Args: + train_global (Dataset): The global training dataset. + test_global (Dataset): The global testing dataset. + all_train_data_num (int): The total number of training data samples. + client_num (int): The number of clients participating in federated learning. + model (nn.Module): The neural network model to be aggregated. 
+ device (str): The device (e.g., 'cuda' or 'cpu') on which the model is trained. + args (argparse.Namespace): Command-line arguments and configurations. + + Attributes: + train_global (Dataset): The global training dataset. + test_global (Dataset): The global testing dataset. + all_train_data_num (int): The total number of training data samples. + client_num (int): The number of clients participating in federated learning. + device (str): The device (e.g., 'cuda' or 'cpu') on which the model is trained. + args (argparse.Namespace): Command-line arguments and configurations. + model (nn.Module): The neural network model to be aggregated. + model_dict (dict): A dictionary to store client model parameters. + arch_dict (dict): A dictionary to store client model architectures. + sample_num_dict (dict): A dictionary to store the number of samples from each client. + train_acc_dict (dict): A dictionary to store training accuracy from each client. + train_loss_dict (dict): A dictionary to store training loss from each client. + train_acc_avg (float): The average training accuracy. + test_acc_avg (float): The average testing accuracy. + test_loss_avg (float): The average testing loss. + flag_client_model_uploaded_dict (dict): A dictionary to track whether client models have been uploaded. + best_accuracy (float): The best accuracy achieved during aggregation. + best_accuracy_different_cnn_counts (dict): A dictionary to store the best accuracy with different CNN counts. + """ def __init__( self, train_global, @@ -17,6 +49,19 @@ def __init__( device, args, ): + """ + Initialize a FedNASAggregator object. + + Args: + train_global (Dataset): The global training dataset. + test_global (Dataset): The global testing dataset. + all_train_data_num (int): The total number of training data samples. + client_num (int): The number of clients participating in federated learning. + model (nn.Module): The neural network model to be aggregated. + device (str): The device (e.g., 'cuda' or 'cpu') on which the model is trained. + args (argparse.Namespace): Command-line arguments and configurations. + """ + self.train_global = train_global self.test_global = test_global self.all_train_data_num = all_train_data_num @@ -43,11 +88,29 @@ def __init__( self.wandb_table = wandb.Table(columns=["Epoch", "Searched Architecture"]) def get_model(self): + """ + Get the aggregated model. + + Returns: + nn.Module: The aggregated neural network model. + """ return self.model def add_local_trained_result( self, index, model_params, arch_params, sample_num, train_acc, train_loss ): + """ + Add the results from a locally trained model to the aggregator. + + Args: + index (int): The index of the client. + model_params (dict): The model parameters from the client. + arch_params (dict): The model architecture parameters from the client. + sample_num (int): The number of samples used for training by the client. + train_acc (float): The training accuracy achieved by the client. + train_loss (torch.Tensor): The training loss from the client. + """ + logging.info("add_model. index = %d" % index) self.model_dict[index] = model_params self.arch_dict[index] = arch_params @@ -57,6 +120,12 @@ def add_local_trained_result( self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all client models have been received by the aggregator. + + Returns: + bool: True if all client models have been received, False otherwise. 
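+
+        Example:
+            A sketch of the round barrier this check enables (the server
+            manager follows the same pattern when handling client uploads)::
+
+                if aggregator.check_whether_all_receive():
+                    global_model_params = aggregator.aggregate()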
+ """ for idx in range(self.client_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -65,6 +134,12 @@ def check_whether_all_receive(self): return True def aggregate(self): + """ + Aggregate model parameters and architectures from multiple clients. + + Returns: + dict: The aggregated model parameters and architectures. + """ averaged_weights = self.__aggregate_weight() self.model.load_state_dict(averaged_weights) if self.args.stage == "search": @@ -75,11 +150,23 @@ def aggregate(self): return averaged_weights def __update_arch(self, alphas): + """ + Update the architecture parameters of the aggregator's model. + + Args: + alphas (list): A list of architecture parameters. + """ logging.info("update_arch. server.") for a_g, model_arch in zip(alphas, self.model.arch_parameters()): model_arch.data.copy_(a_g.data) def __aggregate_weight(self): + """ + Aggregate model weights from multiple clients. + + Returns: + dict: The aggregated model weights. + """ logging.info("################aggregate weights############") start_time = time.time() model_list = [] @@ -104,6 +191,12 @@ def __aggregate_weight(self): return averaged_params def __aggregate_alpha(self): + """ + Calculate and log statistics including training accuracy, training loss, validation accuracy, and validation loss. + + Args: + round_idx (int): The current round index. + """ logging.info("################aggregate alphas############") start_time = time.time() alpha_list = [] @@ -124,6 +217,12 @@ def __aggregate_alpha(self): return averaged_alphas def statistics(self, round_idx): + """ + Calculate and log statistics including training accuracy, training loss, validation accuracy, and validation loss. + + Args: + round_idx (int): The current round index. + """ # train acc train_acc_list = self.train_acc_dict.values() self.train_acc_avg = sum(train_acc_list) / len(train_acc_list) @@ -175,6 +274,12 @@ def statistics(self, round_idx): ) def infer(self, round_idx): + """ + Perform model inference and calculate test accuracy and loss. + + Args: + round_idx (int): The current round index. + """ self.model.eval() self.model.to(self.device) if ( @@ -217,6 +322,13 @@ def infer(self, round_idx): logging.info("server_infer time cost: %d" % (end_time - start_time)) def record_model_global_architecture(self, round_idx): + """ + Record and log the architecture information of the global model, including genotype, CNN count, + and best accuracy for different CNN structures. + + Args: + round_idx (int): The current round index. + """ # save the structure genotype, normal_cnn_count, reduce_cnn_count = self.model.genotype() cnn_count = normal_cnn_count + reduce_cnn_count diff --git a/python/fedml/simulation/mpi/fednas/FedNASClientManager.py b/python/fedml/simulation/mpi/fednas/FedNASClientManager.py index 369e7677fa..11716a4826 100644 --- a/python/fedml/simulation/mpi/fednas/FedNASClientManager.py +++ b/python/fedml/simulation/mpi/fednas/FedNASClientManager.py @@ -7,6 +7,17 @@ class FedNASClientManager(FedMLCommManager): + """ + Manager class for the client in the Federated NAS (Neural Architecture Search) distributed training. + + Args: + args: Command-line arguments and configurations. + comm: The MPI communicator. + rank: The process rank of the current worker. + size: The total number of workers. + trainer: The client trainer instance. 
+ """ + def __init__(self, args, comm, rank, size, trainer): super().__init__(args, comm, rank, size) @@ -15,9 +26,15 @@ def __init__(self, args, comm, rank, size, trainer): self.args.round_idx = 0 def run(self): + """ + Start the client manager. + """ super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for different message types. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.__handle_msg_client_receive_config ) @@ -27,6 +44,12 @@ def register_message_receive_handlers(self): ) def __handle_msg_client_receive_config(self, msg_params): + """ + Handle the received configuration message from the server. + + Args: + msg_params (dict): The message parameters containing model and architecture information. + """ logging.info("__handle_msg_client_receive_config") global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) arch_params = msg_params.get(MyMessage.MSG_ARG_KEY_ARCH_PARAMS) @@ -35,10 +58,16 @@ def __handle_msg_client_receive_config(self, msg_params): self.trainer.update_arch(arch_params) self.args.round_idx = 0 - # start to train + # Start training self.__train() def __handle_msg_client_receive_model_from_server(self, msg_params): + """ + Handle the received model message from the server. + + Args: + msg_params (dict): The message parameters containing model and architecture information. + """ process_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) arch_params = msg_params.get(MyMessage.MSG_ARG_KEY_ARCH_PARAMS) @@ -54,6 +83,9 @@ def __handle_msg_client_receive_model_from_server(self, msg_params): self.finish() def __train(self): + """ + Perform the local training for the client. + """ logging.info("#######training########### round_id = %d" % self.args.round_idx) start_time = time.time() if self.args.stage == "search": @@ -68,7 +100,7 @@ def __train(self): weights, local_sample_num, train_acc, train_loss = self.trainer.train() alphas = [] train_finished_time = time.time() - # for one epoch, the local searching time cost is: 75s (based on RTX2080Ti) + # For one epoch, the local searching time cost is approximately 75s (based on RTX2080Ti) logging.info( "local searching time cost: %d" % (train_finished_time - start_time) ) @@ -77,7 +109,7 @@ def __train(self): weights, alphas, local_sample_num, train_acc, train_loss ) communication_finished_time = time.time() - # for one epoch, the local communication time cost is: < 1s (based o n RTX2080Ti) + # For one epoch, the local communication time cost is less than 1s (based on RTX2080Ti) logging.info( "local communication time cost: %d" % (communication_finished_time - train_finished_time) @@ -86,10 +118,20 @@ def __train(self): def __send_msg_fedavg_send_model_to_server( self, weights, alphas, local_sample_num, valid_acc, valid_loss ): + """ + Send the model updates and training results to the server. + + Args: + weights: The updated model weights. + alphas: The updated architecture parameters (only in the search stage). + local_sample_num: The number of local training samples. + valid_acc: The local training accuracy. + valid_loss: The local training loss. 
+ """ message = Message(MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.rank, 0) message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num) message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights) message.add_params(MyMessage.MSG_ARG_KEY_ARCH_PARAMS, alphas) message.add_params(MyMessage.MSG_ARG_KEY_LOCAL_TRAINING_ACC, valid_acc) message.add_params(MyMessage.MSG_ARG_KEY_LOCAL_TRAINING_LOSS, valid_loss) - self.send_message(message) + self.send_message(message) \ No newline at end of file diff --git a/python/fedml/simulation/mpi/fednas/FedNASServerManager.py b/python/fedml/simulation/mpi/fednas/FedNASServerManager.py index 9f5b1c94c4..921bf28bae 100644 --- a/python/fedml/simulation/mpi/fednas/FedNASServerManager.py +++ b/python/fedml/simulation/mpi/fednas/FedNASServerManager.py @@ -8,6 +8,17 @@ class FedNASServerManager(FedMLCommManager): + """ + Manager class for the server in the Federated NAS (Neural Architecture Search) distributed training. + + Args: + args: Command-line arguments and configurations. + comm: The MPI communicator. + rank: The process rank of the current worker. + size: The total number of workers. + aggregator: The aggregator for collecting client updates. + """ + def __init__(self, args, comm, rank, size, aggregator): super().__init__(args, comm, rank, size) @@ -17,6 +28,9 @@ def __init__(self, args, comm, rank, size, aggregator): self.aggregator = aggregator def run(self): + """ + Start the server manager. + """ global_model = self.aggregator.get_model() global_model_params = global_model.state_dict() global_arch_params = None @@ -29,6 +43,9 @@ def run(self): super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for different message types. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.__handle_msg_server_receive_model_from_client_opt_send, @@ -37,6 +54,14 @@ def register_message_receive_handlers(self): def __send_initial_config_to_client( self, process_id, global_model_params, global_arch_params ): + """ + Send the initial configuration to a client. + + Args: + process_id: The ID of the target client. + global_model_params: The global model parameters. + global_arch_params: The global architecture parameters (only in the search stage). + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), process_id ) @@ -46,6 +71,12 @@ def __send_initial_config_to_client( self.send_message(message) def __handle_msg_server_receive_model_from_client_opt_send(self, msg_params): + """ + Handle the received model message from a client and optionally send updated models to clients. + + Args: + msg_params (dict): The message parameters containing model and architecture information. 
+ """ process_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) arch_params = msg_params.get(MyMessage.MSG_ARG_KEY_ARCH_PARAMS) @@ -69,15 +100,15 @@ def __handle_msg_server_receive_model_from_client_opt_send(self, msg_params): else: global_model_params = self.aggregator.aggregate() global_arch_params = [] - self.aggregator.infer(self.args.round_idx) # for NAS, it cost 151 seconds + self.aggregator.infer(self.args.round_idx) # For NAS, it takes approximately 151 seconds self.aggregator.statistics(self.args.round_idx) if self.args.stage == "search": self.aggregator.record_model_global_architecture(self.args.round_idx) - # free all teh GPU memory cache + # Free all GPU memory cache torch.cuda.empty_cache() - # start the next round + # Start the next round self.args.round_idx += 1 if self.args.round_idx == self.round_num: self.finish() @@ -91,6 +122,14 @@ def __handle_msg_server_receive_model_from_client_opt_send(self, msg_params): def __send_model_to_client_message( self, process_id, global_model_params, global_arch_params ): + """ + Send the updated model to a client. + + Args: + process_id: The ID of the target client. + global_model_params: The updated global model parameters. + global_arch_params: The updated global architecture parameters (only in the search stage). + """ message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, 0, process_id) message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) message.add_params(MyMessage.MSG_ARG_KEY_ARCH_PARAMS, global_arch_params) diff --git a/python/fedml/simulation/mpi/fednas/FedNASTrainer.py b/python/fedml/simulation/mpi/fednas/FedNASTrainer.py index 837162f24b..296744b2bd 100644 --- a/python/fedml/simulation/mpi/fednas/FedNASTrainer.py +++ b/python/fedml/simulation/mpi/fednas/FedNASTrainer.py @@ -8,6 +8,44 @@ class FedNASTrainer(object): + """ + Federated NAS Trainer for local model training and inference. + + This class is responsible for performing local training and inference on client devices during federated NAS. + + Args: + client_index (int): Index of the client within the federated system. + train_data_local_dict (dict): Dictionary containing local training datasets for each client. + test_data_local_dict (dict): Dictionary containing local test/validation datasets for each client. + train_data_local_num (int): Number of training samples on the local client. + train_data_num (int): Total number of training samples across all clients. + model (nn.Module): The neural network model to be trained. + device: The computing device (e.g., GPU) to perform training and inference. + args: Additional configuration and hyperparameters for training and inference. + + Methods: + update_model(weights): + Update the model's weights with global model weights. + + update_arch(alphas): + Update the model's architecture with global architecture parameters. + + search(): + Perform local architecture search and training. + + train(): + Perform local training. + + local_train(train_queue, valid_queue, model, criterion, optimizer): + Perform local training on a batch of data. + + local_infer(valid_queue, model, criterion): + Perform local inference on a batch of data. + + infer(): + Perform inference using the trained model. + + """ def __init__( self, client_index, @@ -33,16 +71,39 @@ def __init__( self.test_local = test_data_local_dict[client_index] def update_model(self, weights): + """ + Update the model with new weights. 
+ + Args: + weights (dict): The model weights to update. + """ logging.info("update_model. client_index = %d" % self.client_index) self.model.load_state_dict(weights) def update_arch(self, alphas): + """ + Update the model architecture parameters (only used in the search stage). + + Args: + alphas (list): The architecture parameters to update. + """ logging.info("update_arch. client_index = %d" % self.client_index) for a_g, model_arch in zip(alphas, self.model.arch_parameters()): model_arch.data.copy_(a_g.data) # local search def search(self): + """ + Perform local neural architecture search. + + Returns: + tuple: A tuple containing the following elements: + - weights (dict): The updated model weights. + - alphas (list): The updated architecture parameters (only in the search stage). + - local_sample_number (int): The number of local training samples. + - local_avg_train_acc (float): The average training accuracy. + - local_avg_train_loss (float): The average training loss. + """ self.model.to(self.device) self.model.train() @@ -108,6 +169,22 @@ def search(self): def local_search( self, train_queue, valid_queue, model, architect, criterion, optimizer ): + """ + Perform local neural architecture search. + + Args: + train_queue (DataLoader): DataLoader for the training dataset. + valid_queue (DataLoader): DataLoader for the validation dataset. + model (nn.Module): The neural network model. + architect (Architect): The architect responsible for architecture search. + criterion: The loss criterion for optimization. + optimizer: The optimizer for weight updates. + + Returns: + tuple: A tuple containing the following elements: + - top1_accuracy (float): Top-1 accuracy achieved during local search. + - loss (float): Average loss during local search. + """ objs = utils.AvgrageMeter() top1 = utils.AvgrageMeter() top5 = utils.AvgrageMeter() @@ -168,6 +245,16 @@ def local_search( return top1.avg / 100.0, objs.avg / 100.0, loss def train(self): + """ + Perform local training. + + Returns: + tuple: A tuple containing the following elements: + - weights (dict): The updated model weights. + - local_sample_number (int): The number of local training samples. + - local_avg_train_acc (float): The average training accuracy. + - local_avg_train_loss (float): The average training loss. + """ self.model.to(self.device) self.model.train() @@ -213,6 +300,21 @@ def train(self): ) def local_train(self, train_queue, valid_queue, model, criterion, optimizer): + """ + Perform local training. + + Args: + train_queue (DataLoader): DataLoader for the training dataset. + valid_queue (DataLoader): DataLoader for the validation dataset. + model (nn.Module): The neural network model. + criterion: The loss criterion for optimization. + optimizer: The optimizer for weight updates. + + Returns: + tuple: A tuple containing the following elements: + - top1_accuracy (float): Top-1 accuracy achieved during local training. + - loss (float): Average loss during local training. + """ objs = utils.AvgrageMeter() top1 = utils.AvgrageMeter() top5 = utils.AvgrageMeter() @@ -249,6 +351,19 @@ def local_train(self, train_queue, valid_queue, model, criterion, optimizer): return top1.avg, objs.avg, loss def local_infer(self, valid_queue, model, criterion): + """ + Perform local inference. + + Args: + valid_queue (DataLoader): DataLoader for the validation dataset. + model (nn.Module): The neural network model. + criterion: The loss criterion for evaluation. 
+ + Returns: + tuple: A tuple containing the following elements: + - top1_accuracy (float): Top-1 accuracy achieved during local inference. + - loss (float): Average loss during local inference. + """ objs = utils.AvgrageMeter() top1 = utils.AvgrageMeter() top5 = utils.AvgrageMeter() @@ -281,6 +396,14 @@ def local_infer(self, valid_queue, model, criterion): # after searching, infer() function is used to infer the searched architecture def infer(self): + """ + Perform inference using the trained model. + + Returns: + tuple: A tuple containing the following elements: + - test_accuracy (float): Test accuracy achieved using the trained model. + - test_loss (float): Test loss using the trained model. + """ self.model.to(self.device) self.model.eval() diff --git a/python/fedml/simulation/mpi/fednova/FedNovaAPI.py b/python/fedml/simulation/mpi/fednova/FedNovaAPI.py index 16b48a5e17..2dd230e1ef 100644 --- a/python/fedml/simulation/mpi/fednova/FedNovaAPI.py +++ b/python/fedml/simulation/mpi/fednova/FedNovaAPI.py @@ -11,6 +11,20 @@ def FedML_FedNova_distributed( args, process_id, worker_number, comm, device, dataset, model, client_trainer=None, server_aggregator=None ): + """ + Initialize and run the FedNova distributed training process. + + Args: + args: Command-line arguments. + process_id (int): ID of the current process. + worker_number (int): Total number of worker processes. + comm: Communication backend for distributed training. + device: PyTorch device (CPU or GPU) to run computations. + dataset: Dataset information including data loaders and other data-related details. + model: The model used for training. + client_trainer: Client-specific trainer (if applicable). + server_aggregator: Server aggregator for model updates (if provided). + """ [ train_data_num, test_data_num, @@ -72,6 +86,25 @@ def init_server( train_data_local_num_dict, server_aggregator, ): + """ + Initialize the server for FedNova federated learning. + + Args: + args: Command-line arguments. + device: PyTorch device (CPU or GPU) to run computations. + comm: Communication backend for distributed training. + rank (int): Rank of the current process. + size (int): Total number of processes. + model: The model used for training. + train_data_num: Total number of training samples. + train_data_global: Global training dataset. + test_data_global: Global test dataset. + train_data_local_dict: Dictionary of local training datasets for clients. + test_data_local_dict: Dictionary of local test datasets for clients. + train_data_local_num_dict: Dictionary of the number of local training samples for clients. + server_aggregator: Server aggregator for model updates. + """ + if server_aggregator is None: server_aggregator = create_server_aggregator(model, args) server_aggregator.set_id(-1) @@ -111,6 +144,22 @@ def init_client( test_data_local_dict, client_trainer=None, ): + """ + Initialize a client for FedNova federated learning. + + Args: + args: Command-line arguments. + device: PyTorch device (CPU or GPU) to run computations. + comm: Communication backend for distributed training. + process_id (int): ID of the current client process. + size (int): Total number of processes. + model: The model used for training. + train_data_num: Total number of training samples. + train_data_local_num_dict: Dictionary of the number of local training samples for clients. + train_data_local_dict: Dictionary of local training datasets for clients. + test_data_local_dict: Dictionary of local test datasets for clients. 
+ client_trainer: Client-specific trainer (if applicable). + """ client_index = process_id - 1 if client_trainer is None: # client_trainer = create_model_trainer(model, args) diff --git a/python/fedml/simulation/mpi/fednova/FedNovaAggregator.py b/python/fedml/simulation/mpi/fednova/FedNovaAggregator.py index 71fb4743c0..4748409d09 100644 --- a/python/fedml/simulation/mpi/fednova/FedNovaAggregator.py +++ b/python/fedml/simulation/mpi/fednova/FedNovaAggregator.py @@ -13,6 +13,63 @@ class FedNovaAggregator(object): + """ + Federated Nova Aggregator for aggregating local model updates in a federated learning setup. + + This class manages the aggregation of local model updates from multiple clients in a federated learning system + using the Federated Nova (FedNova) approach. + + Args: + train_global (Dataset): Global training dataset. + test_global (Dataset): Global test/validation dataset. + all_train_data_num (int): Total number of training samples across all clients. + train_data_local_dict (dict): Dictionary containing local training datasets for each client. + test_data_local_dict (dict): Dictionary containing local test/validation datasets for each client. + train_data_local_num_dict (dict): Dictionary containing the number of local training samples for each client. + worker_num (int): Number of worker nodes/clients. + device: The computing device (e.g., GPU) to perform aggregation and computations. + args: Additional configuration and hyperparameters. + server_aggregator: Server-side aggregator for aggregation methods. + + Methods: + get_global_model_params(): + Get the global model parameters. + + set_global_model_params(model_parameters): + Set the global model parameters. + + add_local_trained_result(index, local_result): + Add the local training results from a client. + + check_whether_all_receive(): + Check if all clients have uploaded their local models. + + record_client_runtime(worker_id, client_runtimes): + Record client runtime information for scheduling. + + generate_client_schedule(round_idx, client_indexes): + Generate a schedule for client training in the federated round. + + get_average_weight(client_indexes): + Calculate the average weight for client selection. + + fednova_aggregate(params, norm_grads, tau_effs, tau_eff=0): + Perform FedNova aggregation of local model updates. + + aggregate(): + Aggregate local model updates using the FedNova aggregation method. + + client_sampling(round_idx, client_num_in_total, client_num_per_round): + Perform client sampling for a federated round. + + _generate_validation_set(num_samples=10000): + Generate a validation dataset for testing. + + test_on_server_for_all_clients(round_idx): + Test the global model on all clients' datasets. + + """ + def __init__( self, train_global, @@ -26,6 +83,21 @@ def __init__( args, server_aggregator, ): + """ + Initialize the FedNova manager. + + Args: + train_global: Global training dataset. + test_global: Global test dataset. + all_train_data_num: Total number of training samples. + train_data_local_dict: Dictionary containing local training datasets for clients. + test_data_local_dict: Dictionary containing local test datasets for clients. + train_data_local_num_dict: Dictionary containing the number of local training samples for clients. + worker_num: Number of worker nodes (clients). + device: PyTorch device (CPU or GPU) to run computations. + args: Command-line arguments. + server_aggregator: Aggregator for model updates. 
+ """ self.aggregator = server_aggregator self.args = args @@ -55,18 +127,43 @@ def __init__( self.global_momentum_buffer = dict() def get_global_model_params(self): + """ + Get the global model parameters. + + Returns: + dict: Global model parameters. + """ return self.aggregator.get_model_params() def set_global_model_params(self, model_parameters): + """ + Set the global model parameters. + + Args: + model_parameters (dict): Global model parameters to be set. + """ self.aggregator.set_model_params(model_parameters) def add_local_trained_result(self, index, local_result): + """ + Add the local training result for a client. + + Args: + index (int): Index of the client. + local_result (dict): Local training result. + """ logging.info("add_model. index = %d" % index) self.result_dict[index] = local_result # self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all clients have uploaded their local model updates. + + Returns: + bool: True if all clients have uploaded, False otherwise. + """ logging.debug("worker_num = {}".format(self.worker_num)) for idx in range(self.worker_num): if not self.flag_client_model_uploaded_dict[idx]: @@ -76,10 +173,27 @@ def check_whether_all_receive(self): return True def record_client_runtime(self, worker_id, client_runtimes): + """ + Record client runtime information. + + Args: + worker_id (int): Index of the worker. + client_runtimes (dict): Dictionary containing client runtime information. + """ for client_id, runtime in client_runtimes.items(): self.runtime_history[worker_id][client_id].append(runtime) def generate_client_schedule(self, round_idx, client_indexes): + """ + Generate a schedule for selecting clients in the current round. + + Args: + round_idx (int): Current round index. + client_indexes (list): List of client indexes. + + Returns: + list: List of client schedules for each worker. + """ # self.runtime_history = {} # for i in range(self.worker_num): # self.runtime_history[i] = {} @@ -128,6 +242,15 @@ def generate_client_schedule(self, round_idx, client_indexes): return client_schedule def get_average_weight(self, client_indexes): + """ + Get the average weight for clients based on the number of local samples. + + Args: + client_indexes (list): List of client indexes. + + Returns: + dict: Dictionary mapping client index to average weight. + """ average_weight_dict = {} training_num = 0 for client_index in client_indexes: @@ -138,6 +261,18 @@ def get_average_weight(self, client_indexes): return average_weight_dict def fednova_aggregate(self, params, norm_grads, tau_effs, tau_eff=0): + """ + Perform FedNova aggregation. + + Args: + params (dict): Model parameters to be aggregated. + norm_grads (list): List of normalized gradients from clients. + tau_effs (list): List of effective tau values. + tau_eff (int): Effective tau for aggregation (optional). + + Returns: + dict: Aggregated model parameters. + """ # get tau_eff if tau_eff == 0: tau_eff = sum(tau_effs) @@ -166,6 +301,12 @@ def fednova_aggregate(self, params, norm_grads, tau_effs, tau_eff=0): return params def aggregate(self): + """ + Aggregate model updates from clients. + + Returns: + dict: Aggregated model parameters. + """ start_time = time.time() grad_results = [] t_eff_results = [] @@ -191,6 +332,17 @@ def aggregate(self): return w_global def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Sample clients for the current round. 
+ + Args: + round_idx (int): Current round index. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to sample per round. + + Returns: + list: List of sampled client indexes. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -201,6 +353,15 @@ def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set for testing. + + Args: + num_samples (int): Number of samples to include in the validation set (optional). + + Returns: + DataLoader: DataLoader for the validation set. + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) @@ -211,6 +372,12 @@ def _generate_validation_set(self, num_samples=10000): return self.test_global def test_on_server_for_all_clients(self, round_idx): + """ + Perform testing on the server for all clients. + + Args: + round_idx (int): Current round index. + """ if round_idx % self.args.frequency_of_the_test == 0 or round_idx == self.args.comm_round - 1: logging.info("################test_on_server_for_all_clients : {}".format(round_idx)) train_num_samples = [] diff --git a/python/fedml/simulation/mpi/fednova/FedNovaClientManager.py b/python/fedml/simulation/mpi/fednova/FedNovaClientManager.py index ad2a99b2ed..8ea484c125 100644 --- a/python/fedml/simulation/mpi/fednova/FedNovaClientManager.py +++ b/python/fedml/simulation/mpi/fednova/FedNovaClientManager.py @@ -9,6 +9,29 @@ class FedNovaClientManager(FedMLCommManager): + """ + Manager for the client-side of the FedNova federated learning process. + + Parameters: + args: Command-line arguments. + trainer: Client trainer responsible for local training. + comm: Communication backend for distributed training. + rank (int): Rank of the client process. + size (int): Total number of processes. + backend (str): Communication backend (e.g., "MPI"). + + Methods: + __init__: Initialize the FedNovaClientManager. + run: Start the client manager. + register_message_receive_handlers: Register message receive handlers for handling incoming messages. + handle_message_init: Handle the initialization message received from the server. + start_training: Start the training process. + handle_message_receive_model_from_server: Handle the received model from the server. + send_result_to_server: Send training results to the server. + add_client_model: Add client model parameters to the aggregation. + __train: Perform the training process for the specified clients. + """ + def __init__( self, args, @@ -18,6 +41,17 @@ def __init__( size=0, backend="MPI", ): + """ + Initialize the FedNovaClientManager. + + Args: + args: Command-line arguments. + trainer: Client trainer responsible for local training. + comm: Communication backend for distributed training. + rank (int): Rank of the client process. + size (int): Total number of processes. + backend (str): Communication backend (e.g., "MPI"). + """ super().__init__(args, comm, rank, size, backend) self.trainer = trainer self.num_rounds = args.comm_round @@ -25,9 +59,15 @@ def __init__( self.worker_id = self.rank - 1 def run(self): + """ + Start the client manager. + """ super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for handling incoming messages. 
+ """ self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -37,9 +77,13 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): - global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) - # client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) + """ + Handle the initialization message received from the server. + Args: + msg_params: Parameters included in the received message. + """ + global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) average_weight_dict = msg_params.get(MyMessage.MSG_ARG_KEY_AVG_WEIGHTS) client_schedule = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_SCHEDULE) client_indexes = client_schedule[self.worker_id] @@ -48,14 +92,19 @@ def handle_message_init(self, msg_params): self.__train(global_model_params, client_indexes, average_weight_dict) def start_training(self): + """ + Start the training process. + """ self.round_idx = 0 - # self.__train() def handle_message_receive_model_from_server(self, msg_params): - logging.info("handle_message_receive_model_from_server.") - global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) - # client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) + """ + Handle the received model from the server. + Args: + msg_params: Parameters included in the received message. + """ + global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) average_weight_dict = msg_params.get(MyMessage.MSG_ARG_KEY_AVG_WEIGHTS) client_schedule = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_SCHEDULE) client_indexes = client_schedule[self.worker_id] @@ -63,40 +112,52 @@ def handle_message_receive_model_from_server(self, msg_params): self.round_idx += 1 self.__train(global_model_params, client_indexes, average_weight_dict) if self.round_idx == self.num_rounds - 1: - # post_complete_message_to_sweep_process(self.args) self.finish() - def send_result_to_server(self, receive_id, weights, client_runtime_info): + """ + Send training results to the server. + + Args: + receive_id: ID of the recipient (e.g., the server). + weights: Model weights or parameters. + client_runtime_info: Information about client runtime. + """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), receive_id, ) message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights) - # message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_RUNTIME_INFO, client_runtime_info) self.send_message(message) - def add_client_model(self, local_agg_model_params, client_index, grad, t_eff, weight=1.0): - # Add params that needed to be reduces from clients - # for name, param in model_params.items(): - # if name not in local_agg_model_params: - # local_agg_model_params[name] = param * weight - # else: - # local_agg_model_params[name] += param * weight - # local_agg_model_params[client_index]["grad"] = grad - # local_agg_model_params[client_index]["t_eff"] = t_eff + """ + Add client model parameters to the aggregation. + + Args: + local_agg_model_params: Local aggregation of model parameters. + client_index: Index or ID of the client. + grad: Gradients computed during training. + t_eff: Efficiency factor. + weight: Weight assigned to the client's contribution. 
+ """ local_agg_model_params.append({ "grad": grad, "t_eff": t_eff, }) - def __train(self, global_model_params, client_indexes, average_weight_dict): + """ + Perform the training process for the specified clients. + + Args: + global_model_params: Global model parameters. + client_indexes: Indexes of the clients to train. + average_weight_dict: Dictionary of average weights for clients. + """ logging.info("#######training########### round_id = %d" % self.round_idx) - # local_agg_model_params = {} local_agg_model_params = [] client_runtime_info = {} for client_index in client_indexes: @@ -105,7 +166,6 @@ def __train(self, global_model_params, client_indexes, average_weight_dict): start_time = time.time() self.trainer.update_model(global_model_params) self.trainer.update_dataset(int(client_index)) - # weights, local_sample_num = self.trainer.train(self.round_idx) loss, grad, t_eff = self.trainer.train(self.round_idx) self.add_client_model(local_agg_model_params, client_index, grad, t_eff, weight=average_weight_dict[client_index]) @@ -116,14 +176,3 @@ def __train(self, global_model_params, client_indexes, average_weight_dict): logging.info("#######training########### End Simulating client_index = %d, consuming time: %f" % \ (client_index, client_runtime)) self.send_result_to_server(0, local_agg_model_params, client_runtime_info) - - - - - - - - - - - diff --git a/python/fedml/simulation/mpi/fednova/FedNovaServerManager.py b/python/fedml/simulation/mpi/fednova/FedNovaServerManager.py index 97d257dbe2..d423d03d99 100644 --- a/python/fedml/simulation/mpi/fednova/FedNovaServerManager.py +++ b/python/fedml/simulation/mpi/fednova/FedNovaServerManager.py @@ -8,6 +8,28 @@ class FedNovaServerManager(FedMLCommManager): + """ + Manager for the server-side of the FedNova federated learning process. + + Methods: + __init__: Initialize the FedNovaServerManager. + run: Start the server manager. + send_init_msg: Send initialization messages to clients. + register_message_receive_handlers: Register message receive handlers for handling incoming messages. + handle_message_receive_model_from_client: Handle the received model from a client. + send_message_init_config: Send initialization configuration message to a client. + send_message_sync_model_to_client: Send model synchronization message to a client. + + Parameters: + args: Command-line arguments. + aggregator: Server aggregator responsible for aggregating client updates. + comm: Communication backend for distributed training. + rank (int): Rank of the server process. + size (int): Total number of processes. + backend (str): Communication backend (e.g., "MPI"). + is_preprocessed (bool): Indicates whether clients have been preprocessed. + preprocessed_client_lists (list): Lists of preprocessed clients for each round. + """ def __init__( self, args, @@ -19,6 +41,19 @@ def __init__( is_preprocessed=False, preprocessed_client_lists=None, ): + """ + Initialize the FedNovaServerManager. + + Args: + args: Command-line arguments. + aggregator: Server aggregator responsible for aggregating client updates. + comm: Communication backend for distributed training. + rank (int): Rank of the server process. + size (int): Total number of processes. + backend (str): Communication backend (e.g., "MPI"). + is_preprocessed (bool): Indicates whether clients have been preprocessed. + preprocessed_client_lists (list): Lists of preprocessed clients for each round. 
+ """ super().__init__(args, comm, rank, size, backend) self.args = args self.aggregator = aggregator @@ -28,12 +63,18 @@ def __init__( self.preprocessed_client_lists = preprocessed_client_lists def run(self): + """ + Start the server manager. + """ super().run() def send_init_msg(self): + """ + Send initialization messages to clients. + """ # sampling clients client_indexes = self.aggregator.client_sampling( self.round_idx, @@ -53,12 +94,18 @@ def send_init_msg(self): ) def register_message_receive_handlers(self): + """ + Register message receive handlers for handling incoming messages. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.handle_message_receive_model_from_client, ) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle the received model from a client. + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) # local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) @@ -112,8 +159,18 @@ def handle_message_receive_model_from_client(self, msg_params): average_weight_dict, client_schedule ) - def send_message_init_config(self, receive_id, global_model_params, - average_weight_dict, client_schedule): + def send_message_init_config( + self, receive_id, global_model_params, average_weight_dict, client_schedule + ): + """ + Send initialization configuration message to a client. + + Args: + receive_id: Receiver's process ID. + global_model_params: Global model parameters. + average_weight_dict: Dictionary of average weights for clients. + client_schedule: Schedule of clients for the current round. + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) @@ -123,8 +180,22 @@ def send_message_init_config(self, receive_id, global_model_params, message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_SCHEDULE, client_schedule) self.send_message(message) - def send_message_sync_model_to_client(self, receive_id, global_model_params, - average_weight_dict, client_schedule): + def send_message_sync_model_to_client( + self, + receive_id, + global_model_params, + average_weight_dict, + client_schedule + ): + """ + Send model synchronization message to a client. + + Args: + receive_id: Receiver's process ID. + global_model_params: Global model parameters. + average_weight_dict: Dictionary of average weights for clients. + client_schedule: Schedule of clients for the current round. + """ logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id) message = Message( MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, diff --git a/python/fedml/simulation/mpi/fednova/FedNovaTrainer.py b/python/fedml/simulation/mpi/fednova/FedNovaTrainer.py index d55e6d9822..420be4e3ac 100644 --- a/python/fedml/simulation/mpi/fednova/FedNovaTrainer.py +++ b/python/fedml/simulation/mpi/fednova/FedNovaTrainer.py @@ -2,6 +2,27 @@ class FedNovaTrainer(object): + """ + Trainer class for FedNova federated learning. + + Methods: + __init__: Initialize the FedNovaTrainer. + update_model: Update the model with global weights. + update_dataset: Update the local dataset for training. + get_lr: Calculate the learning rate for the current round. + train: Train the model on the local dataset. + test: Evaluate the model on the local training and test datasets. + + Parameters: + client_index (int): Index of the client. + train_data_local_dict (dict): Local training dataset for each client. 
+ train_data_local_num_dict (dict): Number of samples in the local training dataset for each client. + test_data_local_dict (dict): Local test dataset for each client. + train_data_num (int): Total number of training samples across all clients. + device: Device (e.g., GPU or CPU) for model training. + args: Command-line arguments. + model_trainer: Trainer for the machine learning model. + """ def __init__( self, client_index, @@ -13,6 +34,19 @@ def __init__( args, model_trainer, ): + """ + Initialize the FedNovaTrainer. + + Args: + client_index (int): Index of the client. + train_data_local_dict (dict): Local training dataset for each client. + train_data_local_num_dict (dict): Number of samples in the local training dataset for each client. + test_data_local_dict (dict): Local test dataset for each client. + train_data_num (int): Total number of training samples across all clients. + device: Device (e.g., GPU or CPU) for model training. + args: Command-line arguments. + model_trainer: Trainer for the machine learning model. + """ self.trainer = model_trainer self.client_index = client_index @@ -29,15 +63,36 @@ def __init__( self.args = args def update_model(self, weights): + """ + Update the model with global weights. + + Args: + weights: Global model weights. + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """ + Update the local dataset for training. + + Args: + client_index (int): Index of the client. + """ self.client_index = client_index self.train_local = self.train_data_local_dict[client_index] self.local_sample_number = self.train_data_local_num_dict[client_index] self.test_local = self.test_data_local_dict[client_index] def get_lr(self, progress): + """ + Calculate the learning rate for the current round. + + Args: + progress (int): Current round index. + + Returns: + float: Learning rate. + """ # This aims to make a float step_size work. if self.args.lr_schedule == "StepLR": exp_num = progress / self.args.lr_step_size @@ -57,18 +112,28 @@ def get_lr(self, progress): return lr def train(self, round_idx=None): + """ + Train the model on the local dataset. + + Args: + round_idx (int): Current round index. + + Returns: + tuple: A tuple containing average loss, normalized gradient, and effective tau. + """ self.args.round_idx = round_idx - # lr = self.get_lr(round_idx) - # self.trainer.train(self.train_local, self.device, self.args, lr=lr) avg_loss, norm_grad, tau_eff = self.trainer.train(self.train_local, self.device, self.args, ratio=self.local_sample_number / self.total_train_num) - # weights = self.trainer.get_model_params() - - # return weights, self.local_sample_number return avg_loss, norm_grad, tau_eff - def test(self): + """ + Evaluate the model on the local training and test datasets. + + Returns: + tuple: A tuple containing training accuracy, training loss, training sample count, + test accuracy, test loss, and test sample count. 
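+
+        Example:
+            Unpacking the returned tuple in the order documented above
+            (the variable names are illustrative)::
+
+                (train_acc, train_loss, train_num,
+                 test_acc, test_loss, test_num) = trainer.test()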
+ """ # train data train_metrics = self.trainer.test(self.train_local, self.device, self.args) train_tot_correct, train_num_sample, train_loss = ( @@ -93,3 +158,4 @@ def test(self): test_loss, test_num_sample, ) + \ No newline at end of file diff --git a/python/fedml/simulation/mpi/fednova/my_model_trainer_classification.py b/python/fedml/simulation/mpi/fednova/my_model_trainer_classification.py index bf56731b5b..698730fd07 100644 --- a/python/fedml/simulation/mpi/fednova/my_model_trainer_classification.py +++ b/python/fedml/simulation/mpi/fednova/my_model_trainer_classification.py @@ -1,18 +1,55 @@ import torch from torch import nn - from ....core.alg_frame.client_trainer import ClientTrainer import logging - class MyModelTrainer(ClientTrainer): + """ + Custom client trainer for federated learning using PyTorch. + + Methods: + get_model_params: Get the model parameters as a state dictionary. + set_model_params: Set the model parameters from a state dictionary. + train: Train the model on the given training data. + test: Evaluate the model on the given test data. + test_on_the_server: Perform server-side testing (not implemented). + + Parameters: + model: The PyTorch model to be trained. + id (int): The identifier of the client. + """ + def get_model_params(self): + """ + Get the model parameters as a state dictionary. + + Returns: + dict: Model parameters as a state dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters from a state dictionary. + + Args: + model_parameters (dict): Model parameters as a state dictionary. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args, lr=None): + """ + Train the model on the given training data. + + Args: + train_data: Training data for the client. + device: Device (e.g., GPU or CPU) for model training. + args: Command-line arguments for training configuration. + lr (float): Learning rate for optimization (optional). + + Returns: + None + """ model = self.model model.to(device) @@ -44,7 +81,7 @@ def train(self, train_data, device, args, lr=None): loss = criterion(log_probs, labels) # pylint: disable=E1102 loss.backward() - # Uncommet this following line to avoid nan loss + # Uncomment this following line to avoid nan loss # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) optimizer.step() @@ -66,6 +103,17 @@ def train(self, train_data, device, args, lr=None): ) def test(self, test_data, device, args): + """ + Evaluate the model on the given test data. + + Args: + test_data: Test data for the client. + device: Device (e.g., GPU or CPU) for model evaluation. + args: Command-line arguments for evaluation configuration. + + Returns: + dict: Evaluation metrics, including test_correct, test_loss, and test_total. + """ model = self.model model.to(device) @@ -93,4 +141,16 @@ def test(self, test_data, device, args): def test_on_the_server( self, train_data_local_dict, test_data_local_dict, device, args=None ) -> bool: + """ + Perform server-side testing (not implemented). + + Args: + train_data_local_dict: Local training data for all clients. + test_data_local_dict: Local test data for all clients. + device: Device (e.g., GPU or CPU) for testing. + args: Command-line arguments for testing configuration (not used). + + Returns: + bool: Always returns False (not implemented). 
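+
+        Example:
+            A hypothetical override could run centralized evaluation instead::
+
+                def test_on_the_server(self, train_data_local_dict,
+                                       test_data_local_dict, device, args=None):
+                    for client_idx, test_data in test_data_local_dict.items():
+                        self.test(test_data, device, args)
+                    return True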
+ """ return False diff --git a/python/fedml/simulation/mpi/fednova/utils.py b/python/fedml/simulation/mpi/fednova/utils.py index aea2449590..f19b0adbab 100644 --- a/python/fedml/simulation/mpi/fednova/utils.py +++ b/python/fedml/simulation/mpi/fednova/utils.py @@ -1,24 +1,47 @@ import os - -import numpy as np import torch - +import numpy as np def transform_list_to_tensor(model_params_list): + """ + Transform a dictionary of model parameters from a list of NumPy arrays to PyTorch tensors. + + Args: + model_params_list (dict): Dictionary of model parameters represented as NumPy arrays. + + Returns: + dict: Dictionary of model parameters with tensors as values. + """ for k in model_params_list.keys(): model_params_list[k] = torch.from_numpy( np.asarray(model_params_list[k]) ).float() return model_params_list - def transform_tensor_to_list(model_params): + """ + Transform a dictionary of model parameters from PyTorch tensors to lists. + + Args: + model_params (dict): Dictionary of model parameters represented as PyTorch tensors. + + Returns: + dict: Dictionary of model parameters with lists as values. + """ for k in model_params.keys(): model_params[k] = model_params[k].detach().numpy().tolist() return model_params - def post_complete_message_to_sweep_process(args): + """ + Post a completion message to a named pipe for communication with another process. + + Args: + args: Additional information or configuration to include in the message. + + Returns: + None + """ pipe_path = "./tmp/fedml" os.system("mkdir ./tmp/; touch ./tmp/fedml") if not os.path.exists(pipe_path): @@ -26,4 +49,4 @@ def post_complete_message_to_sweep_process(args): pipe_fd = os.open(pipe_path, os.O_WRONLY) with os.fdopen(pipe_fd, "w") as pipe: - pipe.write("training is finished! \n%s\n" % (str(args))) + pipe.write("Training is finished! \n%s\n" % (str(args))) diff --git a/python/fedml/simulation/mpi/fedopt/FedOptAPI.py b/python/fedml/simulation/mpi/fedopt/FedOptAPI.py index dd1ec50208..81b48950f4 100644 --- a/python/fedml/simulation/mpi/fedopt/FedOptAPI.py +++ b/python/fedml/simulation/mpi/fedopt/FedOptAPI.py @@ -10,6 +10,11 @@ def FedML_init(): + """Initialize the Federated Learning environment using MPI. + + Returns: + tuple: A tuple containing MPI communication object, process ID, and worker number. + """ comm = MPI.COMM_WORLD process_id = comm.Get_rank() worker_number = comm.Get_size() @@ -24,9 +29,23 @@ def FedML_FedOpt_distributed( device, dataset, model, - client_trainer: ClientTrainer = None, - server_aggregator: ServerAggregator = None, + client_trainer=None, + server_aggregator=None, ): + """Initialize and run the Federated Optimization process. + + Args: + args: A configuration object containing federated optimization parameters. + process_id: The process ID. + worker_number: The total number of workers. + comm: MPI communication object. + device: The device (e.g., CPU or GPU) for training. + dataset: A list containing dataset information. + model: The machine learning model. + client_trainer: An optional client trainer object. + server_aggregator: An optional server aggregator object. + + """ [ train_data_num, test_data_num, @@ -37,6 +56,7 @@ def FedML_FedOpt_distributed( test_data_local_dict, class_num, ] = dataset + if process_id == 0: init_server( args, @@ -84,10 +104,28 @@ def init_server( train_data_local_num_dict, server_aggregator, ): + """Initialize the server-side components for federated optimization. + + Args: + args: A configuration object containing server parameters. 
+ device: The device (e.g., CPU or GPU) for training. + comm: MPI communication object. + rank: The rank of the server process. + size: The total number of processes. + model: The machine learning model. + train_data_num: The number of training data samples. + train_data_global: Global training data. + test_data_global: Global test data. + train_data_local_dict: Dictionary of local training data. + test_data_local_dict: Dictionary of local test data. + train_data_local_num_dict: Dictionary of the number of local training data samples. + server_aggregator: The server aggregator object. + + """ if server_aggregator is None: server_aggregator = create_server_aggregator(model, args) server_aggregator.set_id(-1) - # aggregator + worker_num = size - 1 aggregator = FedOptAggregator( train_data_global, @@ -102,7 +140,7 @@ def init_server( server_aggregator, ) - # start the distributed training + server_manager = FedOptServerManager(args, aggregator, comm, rank, size) server_manager.send_init_msg() server_manager.run() @@ -121,6 +159,22 @@ def init_client( test_data_local_dict, model_trainer=None, ): + """Initialize the client-side components for federated optimization. + + Args: + args: A configuration object containing client parameters. + device: The device (e.g., CPU or GPU) for training. + comm: MPI communication object. + process_id: The process ID. + size: The total number of processes. + model: The machine learning model. + train_data_num: The number of training data samples. + train_data_local_num_dict: Dictionary of the number of local training data samples. + train_data_local_dict: Dictionary of local training data. + test_data_local_dict: Dictionary of local test data. + model_trainer: An optional client trainer object. + + """ client_index = process_id - 1 if model_trainer is None: model_trainer = create_model_trainer(model, args) @@ -129,5 +183,6 @@ def init_client( trainer = FedOptTrainer( client_index, train_data_local_dict, train_data_local_num_dict, train_data_num, device, args, model_trainer, ) + client_manager = FedOptClientManager(args, trainer, comm, process_id, size) client_manager.run() diff --git a/python/fedml/simulation/mpi/fedopt/FedOptAggregator.py b/python/fedml/simulation/mpi/fedopt/FedOptAggregator.py index e86172ec2c..5d589f1c76 100644 --- a/python/fedml/simulation/mpi/fedopt/FedOptAggregator.py +++ b/python/fedml/simulation/mpi/fedopt/FedOptAggregator.py @@ -12,6 +12,39 @@ class FedOptAggregator(object): + """Aggregator for Federated Optimization. + + This class manages the aggregation of model updates from client devices in a federated optimization setting. + + Args: + train_global: The global training dataset. + test_global: The global testing dataset. + all_train_data_num: The total number of training samples across all clients. + train_data_local_dict: A dictionary mapping client indices to their local training datasets. + test_data_local_dict: A dictionary mapping client indices to their local testing datasets. + train_data_local_num_dict: A dictionary mapping client indices to the number of samples in their local training datasets. + worker_num: The number of worker (client) devices. + device: The device (CPU or GPU) to use for model aggregation. + args: An argparse.Namespace object containing various configuration options. + server_aggregator: An optional ServerAggregator object used for model aggregation. + + Attributes: + aggregator: The server aggregator for model aggregation. + args: An argparse.Namespace object containing various configuration options. 
+ train_global: The global training dataset. + test_global: The global testing dataset. + val_global: A subset of the testing dataset used for validation. + all_train_data_num: The total number of training samples across all clients. + train_data_local_dict: A dictionary mapping client indices to their local training datasets. + test_data_local_dict: A dictionary mapping client indices to their local testing datasets. + train_data_local_num_dict: A dictionary mapping client indices to the number of samples in their local training datasets. + worker_num: The number of worker (client) devices. + device: The device (CPU or GPU) to use for model aggregation. + model_dict: A dictionary mapping client indices to their local model updates. + sample_num_dict: A dictionary mapping client indices to the number of samples used for their local updates. + flag_client_model_uploaded_dict: A dictionary tracking whether each client has uploaded its local model update. + opt: The server optimizer used for model aggregation. + """ def __init__( self, train_global, @@ -25,6 +58,20 @@ def __init__( args, server_aggregator, ): + """Initialize the FedOptAggregator. + + Args: + train_global: Global training data. + test_global: Global test data. + all_train_data_num: Total number of training data samples. + train_data_local_dict: Dictionary of local training data. + test_data_local_dict: Dictionary of local test data. + train_data_local_num_dict: Dictionary of the number of local training data samples. + worker_num: Number of worker clients. + device: The device (e.g., CPU or GPU) for training. + args: A configuration object containing aggregator parameters. + server_aggregator: The server aggregator object. + """ self.aggregator = server_aggregator self.args = args @@ -47,6 +94,11 @@ def __init__( self.flag_client_model_uploaded_dict[idx] = False def _instantiate_opt(self): + """Instantiate the optimizer. + + Returns: + torch.optim.Optimizer: The instantiated optimizer. + """ return OptRepo.name2cls(self.args.server_optimizer)( filter(lambda p: p.requires_grad, self.get_model_params()), lr=self.args.server_lr, @@ -54,23 +106,48 @@ def _instantiate_opt(self): ) def get_model_params(self): - # return model parameters in type of generator + """Get model parameters. + + Returns: + generator: Generator of model parameters. + """ return self.aggregator.model.parameters() def get_global_model_params(self): - # return model parameters in type of ordered_dict + """Get global model parameters. + + Returns: + OrderedDict: Global model parameters. + """ return self.aggregator.get_model_params() def set_global_model_params(self, model_parameters): + """Set global model parameters. + + Args: + model_parameters: New global model parameters. + """ self.aggregator.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, sample_num): + """Add locally trained model results. + + Args: + index: Index of the client. + model_params: Model parameters. + sample_num: Number of training samples. + """ logging.info("add_model. index = %d" % index) self.model_dict[index] = model_params self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """Check if all clients have uploaded their models. + + Returns: + bool: True if all clients have uploaded their models, False otherwise. 
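+
+        Example:
+            Typical server-side flow (illustrative)::
+
+                aggregator.add_local_trained_result(idx, model_params, sample_num)
+                if aggregator.check_whether_all_receive():
+                    global_params = aggregator.aggregate()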
+ """ for idx in range(self.worker_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -79,6 +156,11 @@ def check_whether_all_receive(self): return True def aggregate(self): + """Aggregate locally trained models. + + Returns: + OrderedDict: Aggregated global model parameters. + """ start_time = time.time() model_list = [] training_num = 0 @@ -89,8 +171,9 @@ def aggregate(self): logging.info("len of self.model_dict[idx] = " + str(len(self.model_dict))) - # logging.info("################aggregate: %d" % len(model_list)) + (num0, averaged_params) = model_list[0] + for k in averaged_params.keys(): for i in range(0, len(model_list)): local_sample_number, local_model_params = model_list[i] @@ -100,14 +183,13 @@ def aggregate(self): else: averaged_params[k] += local_model_params[k] * w - # server optimizer - # save optimizer state + # Server optimizer self.opt.zero_grad() opt_state = self.opt.state_dict() - # set new aggregated grad + self.set_model_global_grads(averaged_params) self.opt = self._instantiate_opt() - # load optimizer state + self.opt.load_state_dict(opt_state) self.opt.step() @@ -116,30 +198,53 @@ def aggregate(self): return self.get_global_model_params() def set_model_global_grads(self, new_state): + """Set global model gradients. + + Args: + new_state: New global model parameters. + """ new_model = copy.deepcopy(self.aggregator.model) new_model.load_state_dict(new_state) with torch.no_grad(): for parameter, new_parameter in zip(self.aggregator.model.parameters(), new_model.parameters()): parameter.grad = parameter.data - new_parameter.data - # because we go to the opposite direction of the gradient + model_state_dict = self.aggregator.model.state_dict() new_model_state_dict = new_model.state_dict() for k in dict(self.aggregator.model.named_parameters()).keys(): new_model_state_dict[k] = model_state_dict[k] - # self.trainer.model.load_state_dict(new_model_state_dict) + self.set_global_model_params(new_model_state_dict) def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """Sample clients for communication. + + Args: + round_idx: The current communication round. + client_num_in_total: Total number of clients. + client_num_per_round: Number of clients to sample per round. + + Returns: + list: List of sampled client indexes. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: num_clients = min(client_num_per_round, client_num_in_total) - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False) logging.info("client_indexes = %s" % str(client_indexes)) return client_indexes def _generate_validation_set(self, num_samples=10000): + """Generate a validation dataset. + + Args: + num_samples: Number of samples in the validation dataset. + + Returns: + DataLoader: DataLoader for the validation dataset. + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) @@ -150,6 +255,11 @@ def _generate_validation_set(self, num_samples=10000): return self.test_global def test_on_server_for_all_clients(self, round_idx): + """Test on the server for all clients. + + Args: + round_idx: The current communication round. 
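+
+        Example:
+            Called by the server manager after each aggregation (illustrative)::
+
+                global_model_params = aggregator.aggregate()
+                aggregator.test_on_server_for_all_clients(args.round_idx)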
+ """ if self.aggregator.test_all( self.train_data_local_dict, self.test_data_local_dict, @@ -169,7 +279,7 @@ def test_on_server_for_all_clients(self, round_idx): train_tot_corrects = [] train_losses = [] for client_idx in range(self.args.client_num_in_total): - # train data + # Train data metrics = self.aggregator.test( self.train_data_local_dict[client_idx], self.device, self.args ) @@ -182,7 +292,7 @@ def test_on_server_for_all_clients(self, round_idx): train_num_samples.append(copy.deepcopy(train_num_sample)) train_losses.append(copy.deepcopy(train_loss)) - # test on training dataset + # Test on training dataset train_acc = sum(train_tot_corrects) / sum(train_num_samples) train_loss = sum(train_losses) / sum(train_num_samples) if self.args.enable_wandb: @@ -191,7 +301,7 @@ def test_on_server_for_all_clients(self, round_idx): stats = {"training_acc": train_acc, "training_loss": train_loss} logging.info(stats) - # test data + # Test data test_num_samples = [] test_tot_corrects = [] test_losses = [] @@ -210,7 +320,7 @@ def test_on_server_for_all_clients(self, round_idx): test_num_samples.append(copy.deepcopy(test_num_sample)) test_losses.append(copy.deepcopy(test_loss)) - # test on test dataset + # Test on test dataset test_acc = sum(test_tot_corrects) / sum(test_num_samples) test_loss = sum(test_losses) / sum(test_num_samples) if self.args.enable_wandb: diff --git a/python/fedml/simulation/mpi/fedopt/FedOptClientManager.py b/python/fedml/simulation/mpi/fedopt/FedOptClientManager.py index 63222972ea..bc48cd9040 100644 --- a/python/fedml/simulation/mpi/fedopt/FedOptClientManager.py +++ b/python/fedml/simulation/mpi/fedopt/FedOptClientManager.py @@ -7,6 +7,30 @@ class FedOptClientManager(FedMLCommManager): + """Manages client-side operations for federated optimization. + + This class is responsible for managing client-side operations during federated optimization. + It handles communication with the server, updates model parameters, and performs training rounds. + + Attributes: + args: A configuration object containing client parameters. + trainer: An instance of the federated optimizer trainer. + comm: The communication backend. + rank: The rank of the client in the communication group. + size: The total number of processes in the communication group. + backend: The communication backend (e.g., "MPI"). + + Methods: + run(): Runs the client manager to participate in federated optimization. + register_message_receive_handlers(): Registers message handlers for receiving updates from the server. + handle_message_init(msg_params): Handles initialization messages from the server. + start_training(): Starts the federated training process. + handle_message_receive_model_from_server(msg_params): Handles received model updates from the server. + send_model_to_server(receive_id, weights, local_sample_num): Sends updated model to the server. + __train(): Performs the training process. 
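+
+    Example:
+        Illustrative client-side setup, mirroring ``init_client`` in FedOptAPI::
+
+            client_manager = FedOptClientManager(args, trainer, comm, process_id, size)
+            client_manager.run()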
+ + """ + def __init__(self, args, trainer, comm=None, rank=0, size=0, backend="MPI"): super().__init__(args, comm, rank, size, backend) self.trainer = trainer @@ -14,9 +38,11 @@ def __init__(self, args, trainer, comm=None, rank=0, size=0, backend="MPI"): self.args.round_idx = 0 def run(self): + """Runs the client manager to participate in federated optimization.""" super().run() def register_message_receive_handlers(self): + """Registers message handlers for receiving updates from the server.""" self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -26,6 +52,7 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): + """Handles initialization messages from the server.""" global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -35,10 +62,12 @@ def handle_message_init(self, msg_params): self.__train() def start_training(self): + """Starts the federated training process.""" self.args.round_idx = 0 self.__train() def handle_message_receive_model_from_server(self, msg_params): + """Handles received model updates from the server.""" logging.info("handle_message_receive_model_from_server.") model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -52,6 +81,7 @@ def handle_message_receive_model_from_server(self, msg_params): self.finish() def send_model_to_server(self, receive_id, weights, local_sample_num): + """Sends updated model to the server.""" message = Message( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), @@ -62,6 +92,7 @@ def send_model_to_server(self, receive_id, weights, local_sample_num): self.send_message(message) def __train(self): + """Performs the training process.""" logging.info("#######training########### round_id = %d" % self.args.round_idx) weights, local_sample_num = self.trainer.train(self.args.round_idx) self.send_model_to_server(0, weights, local_sample_num) diff --git a/python/fedml/simulation/mpi/fedopt/FedOptServerManager.py b/python/fedml/simulation/mpi/fedopt/FedOptServerManager.py index febbb4ac39..a1fa6e85d6 100644 --- a/python/fedml/simulation/mpi/fedopt/FedOptServerManager.py +++ b/python/fedml/simulation/mpi/fedopt/FedOptServerManager.py @@ -5,8 +5,32 @@ from ....core.distributed.fedml_comm_manager import FedMLCommManager from ....core.distributed.communication.message import Message - class FedOptServerManager(FedMLCommManager): + """Manages the server-side operations for federated optimization. + + This class is responsible for managing the server-side operations during federated optimization. + It handles communication with clients, aggregation of model updates, and coordination of training rounds. + + Attributes: + args: A configuration object containing server parameters. + aggregator: An aggregator for collecting and aggregating model updates from clients. + comm: The communication backend. + rank: The rank of the server in the communication group. + size: The total number of processes in the communication group. + backend: The communication backend (e.g., "MPI"). + is_preprocessed: A boolean flag indicating whether data preprocessing has been applied. + preprocessed_client_lists: A list of preprocessed client data (optional). + + Methods: + run(): Runs the server manager to coordinate federated optimization. 
+ send_init_msg(): Sends initialization messages to clients at the start of each round. + register_message_receive_handlers(): Registers message handlers for receiving updates from clients. + handle_message_receive_model_from_client(msg_params): Handles received model updates from clients. + send_message_init_config(receive_id, global_model_params, client_index): Sends initialization messages to clients. + send_message_sync_model_to_client(receive_id, global_model_params, client_index): Sends updated models to clients. + + """ + def __init__( self, args, @@ -27,10 +51,12 @@ def __init__( self.preprocessed_client_lists = preprocessed_client_lists def run(self): + """Runs the server manager to coordinate federated optimization.""" super().run() def send_init_msg(self): - # sampling clients + """Sends initialization messages to clients at the start of each round.""" + # Sampling clients client_indexes = self.aggregator.client_sampling( self.args.round_idx, self.args.client_num_in_total, @@ -43,12 +69,14 @@ def send_init_msg(self): ) def register_message_receive_handlers(self): + """Registers message handlers for receiving updates from clients.""" self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.handle_message_receive_model_from_client, ) def handle_message_receive_model_from_client(self, msg_params): + """Handles received model updates from clients.""" sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) @@ -62,36 +90,35 @@ def handle_message_receive_model_from_client(self, msg_params): global_model_params = self.aggregator.aggregate() self.aggregator.test_on_server_for_all_clients(self.args.round_idx) - # start the next round + # Start the next round self.args.round_idx += 1 if self.args.round_idx == self.round_num: post_complete_message_to_sweep_process(self.args) self.finish() return - # sampling clients + # Sampling clients if self.is_preprocessed: if self.preprocessed_client_lists is None: - # sampling has already been done in data preprocessor + # Sampling has already been done in data preprocessor client_indexes = [self.args.round_idx] * self.args.client_num_per_round else: client_indexes = self.preprocessed_client_lists[self.args.round_idx] else: - # # sampling clients + # Sampling clients client_indexes = self.aggregator.client_sampling( self.args.round_idx, self.args.client_num_in_total, self.args.client_num_per_round, ) - print("size = %d" % self.size) - for receiver_id in range(1, self.size): self.send_message_sync_model_to_client( receiver_id, global_model_params, client_indexes[receiver_id - 1] ) def send_message_init_config(self, receive_id, global_model_params, client_index): + """Sends initialization messages to clients.""" message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) @@ -102,6 +129,7 @@ def send_message_init_config(self, receive_id, global_model_params, client_index def send_message_sync_model_to_client( self, receive_id, global_model_params, client_index ): + """Sends updated models to clients.""" logging.info("send_message_sync_model_to_client. 
receive_id = %d" % receive_id) message = Message( MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, diff --git a/python/fedml/simulation/mpi/fedopt/FedOptTrainer.py b/python/fedml/simulation/mpi/fedopt/FedOptTrainer.py index 00661f35b0..8f99915857 100644 --- a/python/fedml/simulation/mpi/fedopt/FedOptTrainer.py +++ b/python/fedml/simulation/mpi/fedopt/FedOptTrainer.py @@ -2,6 +2,28 @@ class FedOptTrainer(object): + """Trains a federated optimizer on a client's local data. + + This class is responsible for training a federated optimizer on a client's + local data. It updates the model using the federated optimization technique + and returns the updated model weights. + + Attributes: + trainer: The model trainer used for local training. + client_index: The index of the client. + train_data_local_dict: A dictionary containing local training data. + train_data_local_num_dict: A dictionary containing the number of samples for each client. + all_train_data_num: The total number of training samples across all clients. + device: The device (e.g., CPU or GPU) for training. + args: A configuration object containing training parameters. + + Methods: + update_model(weights): Updates the model with the provided weights. + update_dataset(client_index): Updates the dataset for the given client. + train(round_idx=None): Trains the federated optimizer on the local data. + + """ + def __init__( self, client_index, @@ -17,22 +39,44 @@ def __init__( self.client_index = client_index self.train_data_local_dict = train_data_local_dict self.train_data_local_num_dict = train_data_local_num_dict + self.all_train_data_num = train_data_num - # self.train_local = self.train_data_local_dict[client_index] - # self.local_sample_number = self.train_data_local_num_dict[client_index] + self.device = device self.args = args def update_model(self, weights): + """Update the model with the provided weights. + + Args: + weights: The updated model weights. + + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """Update the dataset for the given client. + + Args: + client_index: The index of the client. + + """ self.client_index = client_index self.train_local = self.train_data_local_dict[client_index] self.local_sample_number = self.train_data_local_num_dict[client_index] def train(self, round_idx=None): + """Train the federated optimizer on the local data. + + Args: + round_idx: The index of the training round (optional). + + Returns: + weights: The updated model weights. + local_sample_number: The number of local training samples. + + """ self.args.round_idx = round_idx self.trainer.train(self.train_local, self.device, self.args) diff --git a/python/fedml/simulation/mpi/fedopt/optrepo.py b/python/fedml/simulation/mpi/fedopt/optrepo.py index 50615227d7..14a7e1e135 100644 --- a/python/fedml/simulation/mpi/fedopt/optrepo.py +++ b/python/fedml/simulation/mpi/fedopt/optrepo.py @@ -1,17 +1,29 @@ import logging from typing import List, Union - import torch - class OptRepo: - """Collects and provides information about the subclasses of torch.optim.Optimizer.""" + """Collects and provides information about the subclasses of torch.optim.Optimizer. + + This class allows you to access and retrieve information about different PyTorch + optimizer classes. + + Attributes: + repo (dict): A dictionary containing optimizer class names as keys and the + corresponding optimizer classes as values. + + Methods: + get_opt_names(): Returns a list of supported optimizer names. 
+ name2cls(name: str): Returns the optimizer class based on its name. + supported_parameters(opt: Union[str, torch.optim.Optimizer]): Returns a list of + __init__ function parameters of an optimizer. + """ repo = {x.__name__.lower(): x for x in torch.optim.Optimizer.__subclasses__()} @classmethod def get_opt_names(cls) -> List[str]: - """Returns a list of supported optimizers. + """Returns a list of supported optimizer names. Returns: List[str]: Names of optimizers. @@ -29,6 +41,9 @@ def name2cls(cls, name: str) -> torch.optim.Optimizer: Returns: torch.optim.Optimizer: The class corresponding to the name. + + Raises: + KeyError: If the provided optimizer name is invalid. """ try: return cls.repo[name.lower()] @@ -39,7 +54,7 @@ def name2cls(cls, name: str) -> torch.optim.Optimizer: @classmethod def supported_parameters(cls, opt: Union[str, torch.optim.Optimizer]) -> List[str]: - """Returns a lost of __init__ function parametrs of an optimizer. + """Returns a list of __init__ function parameters of an optimizer. Args: opt (Union[str, torch.optim.Optimizer]): The name or class of the optimizer. diff --git a/python/fedml/simulation/mpi/fedopt/utils.py b/python/fedml/simulation/mpi/fedopt/utils.py index aea2449590..5bcbb1954a 100644 --- a/python/fedml/simulation/mpi/fedopt/utils.py +++ b/python/fedml/simulation/mpi/fedopt/utils.py @@ -5,6 +5,15 @@ def transform_list_to_tensor(model_params_list): + """ + Transform a dictionary of model parameters from a list of NumPy arrays to PyTorch tensors. + + Args: + model_params_list (dict): Dictionary of model parameters represented as NumPy arrays. + + Returns: + dict: Dictionary of model parameters with tensors as values. + """ for k in model_params_list.keys(): model_params_list[k] = torch.from_numpy( np.asarray(model_params_list[k]) @@ -13,12 +22,30 @@ def transform_list_to_tensor(model_params_list): def transform_tensor_to_list(model_params): + """ + Transform a dictionary of model parameters from PyTorch tensors to lists. + + Args: + model_params (dict): Dictionary of model parameters represented as PyTorch tensors. + + Returns: + dict: Dictionary of model parameters with lists as values. + """ for k in model_params.keys(): model_params[k] = model_params[k].detach().numpy().tolist() return model_params def post_complete_message_to_sweep_process(args): + """ + Post a completion message to a named pipe for communication with another process. + + Args: + args: Additional information or configuration to include in the message. + + Returns: + None + """ pipe_path = "./tmp/fedml" os.system("mkdir ./tmp/; touch ./tmp/fedml") if not os.path.exists(pipe_path): @@ -26,4 +53,4 @@ def post_complete_message_to_sweep_process(args): pipe_fd = os.open(pipe_path, os.O_WRONLY) with os.fdopen(pipe_fd, "w") as pipe: - pipe.write("training is finished! \n%s\n" % (str(args))) + pipe.write("Training is finished! \n%s\n" % (str(args))) diff --git a/python/fedml/simulation/mpi/fedopt_seq/FedOptAggregator.py b/python/fedml/simulation/mpi/fedopt_seq/FedOptAggregator.py index dc91017c3b..70920e7688 100644 --- a/python/fedml/simulation/mpi/fedopt_seq/FedOptAggregator.py +++ b/python/fedml/simulation/mpi/fedopt_seq/FedOptAggregator.py @@ -15,6 +15,42 @@ class FedOptAggregator(object): + """Aggregator for Federated Optimization. + + This class manages the aggregation of model updates from client devices in a federated optimization setting. + + Args: + train_global: The global training dataset. + test_global: The global testing dataset. 
+        all_train_data_num: The total number of training samples across all clients.
+        train_data_local_dict: A dictionary mapping client indices to their local training datasets.
+        test_data_local_dict: A dictionary mapping client indices to their local testing datasets.
+        train_data_local_num_dict: A dictionary mapping client indices to the number of samples in their local training datasets.
+        worker_num: The number of worker (client) devices.
+        device: The device (CPU or GPU) to use for model aggregation.
+        args: An argparse.Namespace object containing various configuration options.
+        server_aggregator: An optional ServerAggregator object used for model aggregation.
+
+    Attributes:
+        aggregator: The server aggregator for model aggregation.
+        args: An argparse.Namespace object containing various configuration options.
+        train_global: The global training dataset.
+        test_global: The global testing dataset.
+        val_global: A subset of the testing dataset used for validation.
+        all_train_data_num: The total number of training samples across all clients.
+        train_data_local_dict: A dictionary mapping client indices to their local training datasets.
+        test_data_local_dict: A dictionary mapping client indices to their local testing datasets.
+        train_data_local_num_dict: A dictionary mapping client indices to the number of samples in their local training datasets.
+        worker_num: The number of worker (client) devices.
+        device: The device (CPU or GPU) to use for model aggregation.
+        model_dict: A dictionary mapping client indices to their local model updates.
+        sample_num_dict: A dictionary mapping client indices to the number of samples used for their local updates.
+        flag_client_model_uploaded_dict: A dictionary tracking whether each client has uploaded its local model update.
+        opt: The server optimizer used for model aggregation.
+        runtime_history: A dictionary to track the runtime history of clients.
+        runtime_avg: A dictionary to track the average runtime of clients.
+    """
+
     def __init__(
         self,
         train_global,
@@ -28,6 +64,11 @@ def __init__(
         args,
         server_aggregator,
     ):
+        """Initialize the FedOptAggregator.
+
+        Arguments are described in the class docstring; this sets up the
+        per-client bookkeeping dictionaries and the server optimizer.
+        """
         self.aggregator = server_aggregator
 
         self.args = args
@@ -59,6 +100,12 @@ def __init__(
 
 
     def _instantiate_opt(self):
+        """
+        Instantiate the server optimizer based on configuration options.
+
+        Returns:
+            torch.optim.Optimizer: The server optimizer.
+        """
         return OptRepo.name2cls(self.args.server_optimizer)(
             filter(lambda p: p.requires_grad, self.get_model_params()),
             lr=self.args.server_lr,
@@ -66,23 +113,55 @@ def __init__(
         )
 
     def get_model_params(self):
+        """
+        Get the model parameters in the form of a generator.
+
+        Returns:
+            generator: A generator of model parameters.
+        """
         # return model parameters in type of generator
         return self.aggregator.model.parameters()
 
     def get_global_model_params(self):
+        """
+        Get the global model parameters as an ordered dictionary.
+
+        Returns:
+            collections.OrderedDict: The global model parameters.
+        """
+        # return model parameters in type of ordered_dict
         return self.aggregator.get_model_params()
 
     def set_global_model_params(self, model_parameters):
+        """
+        Set the global model parameters based on a provided dictionary.
+
+        Args:
+            model_parameters (dict): A dictionary containing global model parameters.
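+
+        Example:
+            Round-trip with the corresponding getter (illustrative)::
+
+                params = aggregator.get_global_model_params()
+                aggregator.set_global_model_params(params)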
+ """ self.aggregator.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params): + """ + Add the local trained model update for a client. + + Args: + index (int): The index of the client. + model_params (dict): The model parameters of the local trained model. + """ logging.info("add_model. index = %d" % index) self.model_dict[index] = model_params # self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check whether all clients have uploaded their local model updates. + + Returns: + bool: True if all clients have uploaded their updates, False otherwise. + """ for idx in range(self.worker_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -93,6 +172,16 @@ def check_whether_all_receive(self): def workload_estimate(self, client_indexes, mode="simulate"): + """ + Estimate the workload for selected clients. + + Args: + client_indexes (list): The indices of selected clients. + mode (str): The mode for workload estimation ("simulate" or "real"). + + Returns: + list: Workload estimates for the selected clients. + """ if mode == "simulate": client_samples = [ self.train_data_local_num_dict[client_index] @@ -106,6 +195,16 @@ def workload_estimate(self, client_indexes, mode="simulate"): return workload def memory_estimate(self, client_indexes, mode="simulate"): + """ + Estimate the memory requirements for selected clients. + + Args: + client_indexes (list): The indices of selected clients. + mode (str): The mode for memory estimation ("simulate" or "real"). + + Returns: + numpy.ndarray: Memory estimates for the selected clients. + """ if mode == "simulate": memory = np.ones(self.worker_num) elif mode == "real": @@ -115,6 +214,15 @@ def memory_estimate(self, client_indexes, mode="simulate"): return memory def resource_estimate(self, mode="simulate"): + """ + Estimate the resource requirements for clients. + + Args: + mode (str): The mode for resource estimation ("simulate" or "real"). + + Returns: + numpy.ndarray: Resource estimates for clients. + """ if mode == "simulate": resource = np.ones(self.worker_num) elif mode == "real": @@ -124,6 +232,13 @@ def resource_estimate(self, mode="simulate"): return resource def record_client_runtime(self, worker_id, client_runtimes): + """ + Record the runtime of clients. + + Args: + worker_id (int): The ID of the worker (client). + client_runtimes (dict): A dictionary mapping client IDs to their runtimes. + """ for client_id, runtime in client_runtimes.items(): self.runtime_history[worker_id][client_id].append(runtime) if hasattr(self.args, "runtime_est_mode"): @@ -140,6 +255,15 @@ def record_client_runtime(self, worker_id, client_runtimes): def generate_client_schedule(self, round_idx, client_indexes): + """Generate a schedule of clients for training. + + Args: + round_idx (int): The current communication round index. + client_indexes (list): The indices of selected clients. + + Returns: + list: A schedule of clients for training. + """ # self.runtime_history = {} # for i in range(self.worker_num): # self.runtime_history[i] = {} @@ -195,6 +319,14 @@ def generate_client_schedule(self, round_idx, client_indexes): def get_average_weight(self, client_indexes): + """Calculate the average weight for selected clients. + + Args: + client_indexes (list): The indices of selected clients. + + Returns: + dict: A dictionary mapping client indices to their average weights. 
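+
+        Example:
+            Illustrative call; each weight is that client's share of the
+            sampled training data, so the values sum to 1.0::
+
+                average_weight_dict = aggregator.get_average_weight([0, 1, 2])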
+ """ average_weight_dict = {} training_num = 0 for client_index in client_indexes: @@ -208,6 +340,12 @@ def get_average_weight(self, client_indexes): def aggregate(self): + """ + Aggregate the model updates from clients. + + Returns: + collections.OrderedDict: The aggregated global model parameters. + """ start_time = time.time() model_list = [] training_num = 0 @@ -246,6 +384,12 @@ def aggregate(self): return self.get_global_model_params() def set_model_global_grads(self, new_state): + """ + Set the global model gradients based on a provided dictionary. + + Args: + new_state (dict): A dictionary containing the new global model gradients. + """ new_model = copy.deepcopy(self.aggregator.model) new_model.load_state_dict(new_state) with torch.no_grad(): @@ -260,6 +404,16 @@ def set_model_global_grads(self, new_state): self.set_global_model_params(new_model_state_dict) def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """Randomly sample a subset of clients for a communication round. + + Args: + round_idx (int): The current communication round index. + client_num_in_total (int): The total number of clients. + client_num_per_round (int): The number of clients to sample for the round. + + Returns: + list: A list of indices representing the selected clients for the round. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: @@ -270,6 +424,14 @@ def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): return client_indexes def _generate_validation_set(self, num_samples=10000): + """Generate a subset of the testing dataset for validation. + + Args: + num_samples (int): The number of samples to include in the validation set. + + Returns: + torch.utils.data.DataLoader: A DataLoader containing the validation subset. + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) @@ -280,6 +442,11 @@ def _generate_validation_set(self, num_samples=10000): return self.test_global def test_on_server_for_all_clients(self, round_idx): + """Test the global model on all clients. + + Args: + round_idx (int): The current communication round index. + """ if ( round_idx % self.args.frequency_of_the_test == 0 or round_idx == self.args.comm_round - 1 From 75483ba7b4c7454d7b11e3564cfa5f0a37c723de Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Sat, 9 Sep 2023 13:14:08 +0530 Subject: [PATCH 51/70] g --- .../simulation/mpi/fedopt_seq/optrepo.py | 7 ++++- .../fedml/simulation/mpi/fedopt_seq/utils.py | 30 +++++++++++++++---- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/python/fedml/simulation/mpi/fedopt_seq/optrepo.py b/python/fedml/simulation/mpi/fedopt_seq/optrepo.py index 50615227d7..6942b78b85 100644 --- a/python/fedml/simulation/mpi/fedopt_seq/optrepo.py +++ b/python/fedml/simulation/mpi/fedopt_seq/optrepo.py @@ -3,6 +3,7 @@ import torch +from typing import List, Union class OptRepo: """Collects and provides information about the subclasses of torch.optim.Optimizer.""" @@ -29,6 +30,9 @@ def name2cls(cls, name: str) -> torch.optim.Optimizer: Returns: torch.optim.Optimizer: The class corresponding to the name. + + Raises: + KeyError: If the provided optimizer name is invalid. 
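+
+        Example:
+            Illustrative lookup; repository keys are lowercased class names::
+
+                opt_cls = OptRepo.name2cls("sgd")
+                optimizer = opt_cls(model.parameters(), lr=0.01)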
""" try: return cls.repo[name.lower()] @@ -39,7 +43,7 @@ def name2cls(cls, name: str) -> torch.optim.Optimizer: @classmethod def supported_parameters(cls, opt: Union[str, torch.optim.Optimizer]) -> List[str]: - """Returns a lost of __init__ function parametrs of an optimizer. + """Returns a list of __init__ function parameters of an optimizer. Args: opt (Union[str, torch.optim.Optimizer]): The name or class of the optimizer. @@ -60,4 +64,5 @@ def supported_parameters(cls, opt: Union[str, torch.optim.Optimizer]) -> List[st @classmethod def _update_repo(cls): + """Updates the optimizer repository with the latest subclasses.""" cls.repo = {x.__name__: x for x in torch.optim.Optimizer.__subclasses__()} diff --git a/python/fedml/simulation/mpi/fedopt_seq/utils.py b/python/fedml/simulation/mpi/fedopt_seq/utils.py index aea2449590..8c2ff95a19 100644 --- a/python/fedml/simulation/mpi/fedopt_seq/utils.py +++ b/python/fedml/simulation/mpi/fedopt_seq/utils.py @@ -1,24 +1,44 @@ +import torch +import numpy as np import os -import numpy as np -import torch +def transform_list_to_tensor(model_params_list): + """ + Convert a dictionary of model parameters from NumPy arrays in a list to PyTorch tensors. + Args: + model_params_list (dict): A dictionary of model parameters, where values are lists of NumPy arrays. -def transform_list_to_tensor(model_params_list): + Returns: + dict: A dictionary of model parameters with values as PyTorch tensors. + """ for k in model_params_list.keys(): model_params_list[k] = torch.from_numpy( np.asarray(model_params_list[k]) ).float() return model_params_list - def transform_tensor_to_list(model_params): + """ + Convert a dictionary of model parameters from PyTorch tensors to lists of NumPy arrays. + + Args: + model_params (dict): A dictionary of model parameters, where values are PyTorch tensors. + + Returns: + dict: A dictionary of model parameters with values as lists of NumPy arrays. + """ for k in model_params.keys(): model_params[k] = model_params[k].detach().numpy().tolist() return model_params - def post_complete_message_to_sweep_process(args): + """ + Send a completion message to a sweep process using a named pipe. + + Args: + args (str): A string containing information about the training completion status or other relevant details. + """ pipe_path = "./tmp/fedml" os.system("mkdir ./tmp/; touch ./tmp/fedml") if not os.path.exists(pipe_path): From bcb67fed64c53e4b13ed53927300ad04f06157ae Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Sat, 9 Sep 2023 13:17:05 +0530 Subject: [PATCH 52/70] gg --- .../mpi/fedopt_seq/FedOptTrainer.py | 33 +++++++++++++++++-- .../simulation/mpi/fedopt_seq/optrepo.py | 17 +++++++--- 2 files changed, 43 insertions(+), 7 deletions(-) diff --git a/python/fedml/simulation/mpi/fedopt_seq/FedOptTrainer.py b/python/fedml/simulation/mpi/fedopt_seq/FedOptTrainer.py index 00661f35b0..162fbf1ef9 100644 --- a/python/fedml/simulation/mpi/fedopt_seq/FedOptTrainer.py +++ b/python/fedml/simulation/mpi/fedopt_seq/FedOptTrainer.py @@ -2,6 +2,8 @@ class FedOptTrainer(object): + """Trains a federated learning model for a specific client.""" + def __init__( self, client_index, @@ -12,27 +14,54 @@ def __init__( args, model_trainer, ): + """Initialize the FedOptTrainer. + + Args: + client_index (int): The index of the client. + train_data_local_dict (dict): A dictionary mapping client indexes to their local training datasets. + train_data_local_num_dict (dict): A dictionary mapping client indexes to the number of samples in their local datasets. 
+ train_data_num (int): The total number of training samples. + device (str): The device (e.g., 'cuda' or 'cpu') on which to perform training. + args (object): Configuration parameters for training. + model_trainer (object): An instance of the model trainer for this client. + """ self.trainer = model_trainer self.client_index = client_index self.train_data_local_dict = train_data_local_dict self.train_data_local_num_dict = train_data_local_num_dict self.all_train_data_num = train_data_num - # self.train_local = self.train_data_local_dict[client_index] - # self.local_sample_number = self.train_data_local_num_dict[client_index] self.device = device self.args = args def update_model(self, weights): + """Update the model parameters. + + Args: + weights (dict): A dictionary containing the updated model parameters. + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """Update the local dataset for the client. + + Args: + client_index (int): The index of the client whose dataset should be updated. + """ self.client_index = client_index self.train_local = self.train_data_local_dict[client_index] self.local_sample_number = self.train_data_local_num_dict[client_index] def train(self, round_idx=None): + """Train the federated learning model for the client. + + Args: + round_idx (int, optional): The current federated learning round index. Defaults to None. + + Returns: + tuple: A tuple containing the updated model weights and the number of local samples used for training. + """ self.args.round_idx = round_idx self.trainer.train(self.train_local, self.device, self.args) diff --git a/python/fedml/simulation/mpi/fedopt_seq/optrepo.py b/python/fedml/simulation/mpi/fedopt_seq/optrepo.py index 6942b78b85..df6ec80985 100644 --- a/python/fedml/simulation/mpi/fedopt_seq/optrepo.py +++ b/python/fedml/simulation/mpi/fedopt_seq/optrepo.py @@ -6,13 +6,16 @@ from typing import List, Union class OptRepo: - """Collects and provides information about the subclasses of torch.optim.Optimizer.""" + """ + Collects and provides information about the subclasses of torch.optim.Optimizer. + """ repo = {x.__name__.lower(): x for x in torch.optim.Optimizer.__subclasses__()} @classmethod def get_opt_names(cls) -> List[str]: - """Returns a list of supported optimizers. + """ + Returns a list of supported optimizers. Returns: List[str]: Names of optimizers. @@ -23,7 +26,8 @@ def get_opt_names(cls) -> List[str]: @classmethod def name2cls(cls, name: str) -> torch.optim.Optimizer: - """Returns the optimizer class belonging to the name. + """ + Returns the optimizer class belonging to the name. Args: name (str): Name of the optimizer. @@ -43,7 +47,8 @@ def name2cls(cls, name: str) -> torch.optim.Optimizer: @classmethod def supported_parameters(cls, opt: Union[str, torch.optim.Optimizer]) -> List[str]: - """Returns a list of __init__ function parameters of an optimizer. + """ + Returns a list of __init__ function parameters of an optimizer. Args: opt (Union[str, torch.optim.Optimizer]): The name or class of the optimizer. @@ -64,5 +69,7 @@ def supported_parameters(cls, opt: Union[str, torch.optim.Optimizer]) -> List[st @classmethod def _update_repo(cls): - """Updates the optimizer repository with the latest subclasses.""" + """ + Updates the optimizer repository with the latest subclasses. 
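+
+        Note:
+            Unlike the class-level ``repo``, this rebuild keys optimizers by
+            their unmodified ``__name__`` (not lowercased), so ``name2cls``
+            lookups may miss entries after an update.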
+ """ cls.repo = {x.__name__: x for x in torch.optim.Optimizer.__subclasses__()} From 988b48fbe711fd41762592d656371c3ec3917567 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Sat, 9 Sep 2023 17:30:30 +0530 Subject: [PATCH 53/70] uodtae --- .../mpi/fedopt_seq/FedOptClientManager.py | 89 ++++++++++++++++- .../simulation/mpi/fedopt_seq/FedOptSeqAPI.py | 63 ++++++++++++ .../mpi/fedopt_seq/FedOptServerManager.py | 77 ++++++++++++++- .../simulation/mpi/fedopt_seq/optrepo.py | 3 +- .../simulation/mpi/fedprox/FedProxAPI.py | 54 ++++++++++ .../mpi/fedprox/FedProxAggregator.py | 99 ++++++++++++++++--- .../mpi/fedprox/FedProxClientManager.py | 48 +++++++++ .../mpi/fedprox/FedProxServerManager.py | 62 +++++++++++- .../simulation/mpi/fedprox/FedProxTrainer.py | 61 +++++++++++- .../simulation/mpi/fedprox/message_define.py | 11 +-- python/fedml/simulation/mpi/fedprox/utils.py | 24 +++++ 11 files changed, 554 insertions(+), 37 deletions(-) diff --git a/python/fedml/simulation/mpi/fedopt_seq/FedOptClientManager.py b/python/fedml/simulation/mpi/fedopt_seq/FedOptClientManager.py index 3ec4cdf370..531352fdf0 100644 --- a/python/fedml/simulation/mpi/fedopt_seq/FedOptClientManager.py +++ b/python/fedml/simulation/mpi/fedopt_seq/FedOptClientManager.py @@ -8,6 +8,24 @@ class FedOptClientManager(FedMLCommManager): + """ + Manager for Federated Optimization Clients. + + Args: + args (object): Arguments for configuration. + trainer (object): Trainer for client-side training. + comm (object, optional): Communication module (default: None). + rank (int, optional): Client's rank (default: 0). + size (int, optional): Number of clients (default: 0). + backend (str, optional): Backend for communication (default: "MPI"). + + Attributes: + trainer (object): Trainer for client-side training. + num_rounds (int): Number of communication rounds. + round_idx (int): Current communication round index. + worker_id (int): Worker's unique identifier within the communication group. + """ + def __init__(self, args, trainer, comm=None, rank=0, size=0, backend="MPI"): super().__init__(args, comm, rank, size, backend) self.trainer = trainer @@ -19,6 +37,9 @@ def run(self): super().run() def register_message_receive_handlers(self): + """ + Register handlers for receiving messages. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -28,6 +49,16 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): + """ + Handle initialization message from the server. + + Args: + msg_params (dict): Message parameters. + + Notes: + This method handles the initialization message from the server, including + model parameters, average weights, and client schedule. + """ global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) # client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -39,10 +70,25 @@ def handle_message_init(self, msg_params): self.__train(global_model_params, client_indexes, average_weight_dict) def start_training(self): + """ + Start the training process for a new round. + """ self.round_idx = 0 def handle_message_receive_model_from_server(self, msg_params): + """ + Handle the received model from the server. + + Args: + msg_params (dict): Message parameters. + + Notes: + This method handles the received model from the server, including model + parameters, average weights, and client schedule. It triggers the training + process and completes communication rounds. 
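+
+        Note:
+            Once the final round completes, this handler calls ``finish()``
+            to leave the communication loop.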
+ """ logging.info("handle_message_receive_model_from_server.") + global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) # client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -57,28 +103,61 @@ def handle_message_receive_model_from_server(self, msg_params): self.finish() def send_model_to_server(self, receive_id, weights, client_runtime_info): + """ + Send the client's model to the server. + + Args: + receive_id (int): Receiver's ID. + weights (dict): Model parameters. + client_runtime_info (dict): Information about client runtime. + + Notes: + This method constructs and sends a message containing the client's model + and runtime information to the server. + """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), receive_id, ) message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights) - # message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_RUNTIME_INFO, client_runtime_info) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_RUNTIME_INFO, client_runtime_info + ) self.send_message(message) def add_client_model(self, local_agg_model_params, model_params, weight=1.0): - # Add params that needed to be reduces from clients + """ + Add a client's model to the local aggregated model. + + Args: + local_agg_model_params (dict): Local aggregated model parameters. + model_params (dict): Client's model parameters. + weight (float, optional): Weight for the client's model (default: 1.0). + + Notes: + This method adds client model parameters to the local aggregated model. + """ for name, param in model_params.items(): if name not in local_agg_model_params: local_agg_model_params[name] = param * weight else: local_agg_model_params[name] += param * weight - - def __train(self, global_model_params, client_indexes, average_weight_dict): + """ + Train the client's model. + + Args: + global_model_params (dict): Global model parameters. + client_indexes (list): List of client indexes. + average_weight_dict (dict): Dictionary of average weights for clients. + + Notes: + This method simulates client-side training, updating the local aggregated + model with the client's contributions. + """ logging.info("#######training########### round_id = %d" % self.round_idx) local_agg_model_params = {} diff --git a/python/fedml/simulation/mpi/fedopt_seq/FedOptSeqAPI.py b/python/fedml/simulation/mpi/fedopt_seq/FedOptSeqAPI.py index f771a78598..6383688af2 100644 --- a/python/fedml/simulation/mpi/fedopt_seq/FedOptSeqAPI.py +++ b/python/fedml/simulation/mpi/fedopt_seq/FedOptSeqAPI.py @@ -10,6 +10,12 @@ def FedML_init(): + """ + Initialize the Federated Learning environment. + + Returns: + tuple: A tuple containing the MPI communicator, process ID, and worker number. + """ comm = MPI.COMM_WORLD process_id = comm.Get_rank() worker_number = comm.Get_size() @@ -27,6 +33,23 @@ def FedML_FedOptSeq_distributed( client_trainer: ClientTrainer = None, server_aggregator: ServerAggregator = None, ): + """ + Run the Federated Optimization (FedOpt) distributed training. + + Args: + args (object): Arguments for configuration. + process_id (int): Process ID or rank. + worker_number (int): Total number of workers. + comm (object): MPI communicator. + device (object): Device for computation. + dataset (list): List of dataset elements. + model (object): Model for training. + client_trainer (ClientTrainer, optional): Client trainer (default: None). 
+ server_aggregator (ServerAggregator, optional): Server aggregator (default: None). + + Notes: + This function orchestrates the FedOpt distributed training process. + """ [ train_data_num, test_data_num, @@ -84,6 +107,27 @@ def init_server( train_data_local_num_dict, server_aggregator, ): + """ + Initialize the server for FedOpt distributed training. + + Args: + args (object): Arguments for configuration. + device (object): Device for computation. + comm (object): MPI communicator. + rank (int): Server's rank. + size (int): Total number of workers. + model (object): Model for training. + train_data_num (int): Number of training data samples. + train_data_global (object): Global training data. + test_data_global (object): Global test data. + train_data_local_dict (dict): Local training data per client. + test_data_local_dict (dict): Local test data per client. + train_data_local_num_dict (dict): Number of local training data per client. + server_aggregator (ServerAggregator, optional): Server aggregator (default: None). + + Notes: + This function initializes the server and starts distributed training. + """ if server_aggregator is None: server_aggregator = create_server_aggregator(model, args) server_aggregator.set_id(-1) @@ -121,6 +165,25 @@ def init_client( test_data_local_dict, model_trainer=None, ): + """ + Initialize a client for FedOpt distributed training. + + Args: + args (object): Arguments for configuration. + device (object): Device for computation. + comm (object): MPI communicator. + process_id (int): Client's process ID. + size (int): Total number of workers. + model (object): Model for training. + train_data_num (int): Number of training data samples. + train_data_local_num_dict (dict): Number of local training data per client. + train_data_local_dict (dict): Local training data per client. + test_data_local_dict (dict): Local test data per client. + model_trainer (object, optional): Model trainer (default: None). + + Notes: + This function initializes a client and runs the training process. + """ client_index = process_id - 1 if model_trainer is None: model_trainer = create_model_trainer(model, args) diff --git a/python/fedml/simulation/mpi/fedopt_seq/FedOptServerManager.py b/python/fedml/simulation/mpi/fedopt_seq/FedOptServerManager.py index 207fcf37ed..a089018a96 100644 --- a/python/fedml/simulation/mpi/fedopt_seq/FedOptServerManager.py +++ b/python/fedml/simulation/mpi/fedopt_seq/FedOptServerManager.py @@ -8,6 +8,28 @@ class FedOptServerManager(FedMLCommManager): + """ + Manager for the Federated Optimization (FedOpt) Server. + + Args: + args (object): Arguments for configuration. + aggregator (object): Aggregator for Federated Optimization. + comm (object, optional): Communication module (default: None). + rank (int, optional): Server's rank (default: 0). + size (int, optional): Total number of workers (default: 0). + backend (str, optional): Backend for communication (default: "MPI"). + is_preprocessed (bool, optional): Flag indicating preprocessed data (default: False). + preprocessed_client_lists (list, optional): Preprocessed client lists (default: None). + + Attributes: + args (object): Arguments for configuration. + aggregator (object): Aggregator for Federated Optimization. + round_num (int): Number of communication rounds. + round_idx (int): Current communication round index. + is_preprocessed (bool): Flag indicating preprocessed data. + preprocessed_client_lists (list): Preprocessed client lists. 
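+
+    Example:
+        A minimal wiring sketch mirroring init_server; the variable names are
+        illustrative:
+
+        >>> server_manager = FedOptServerManager(args, aggregator, comm, rank, size)
+        >>> server_manager.send_init_msg()
+        >>> server_manager.run()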
+ """ + def __init__( self, args, @@ -31,6 +53,13 @@ def run(self): super().run() def send_init_msg(self): + """ + Send initialization messages to clients. + + Notes: + This method initializes and sends configuration messages to clients for the + start of a new communication round. + """ # sampling clients client_indexes = self.aggregator.client_sampling( self.round_idx, @@ -49,12 +78,30 @@ def send_init_msg(self): ) def register_message_receive_handlers(self): + """ + Register handlers for receiving messages. + + Notes: + This method registers message handlers for the server to process incoming + messages from clients. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.handle_message_receive_model_from_client, ) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle the received model from a client. + + Args: + msg_params (dict): Message parameters. + + Notes: + This method handles the received model from a client, records client + runtime information, adds local trained results, and checks whether all + clients have sent their updates to proceed to the next round. + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) # local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) @@ -106,6 +153,19 @@ def handle_message_receive_model_from_client(self, msg_params): def send_message_init_config(self, receive_id, global_model_params, average_weight_dict, client_schedule): + """ + Send initialization configuration message to a client. + + Args: + receive_id (int): Receiver's ID. + global_model_params (dict): Global model parameters. + average_weight_dict (dict): Dictionary of average weights for clients. + client_schedule (list): Schedule of clients for the round. + + Notes: + This method constructs and sends an initialization configuration message to + a client. + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) @@ -117,6 +177,18 @@ def send_message_init_config(self, receive_id, global_model_params, def send_message_sync_model_to_client(self, receive_id, global_model_params, average_weight_dict, client_schedule): + """ + Send model synchronization message to a client. + + Args: + receive_id (int): Receiver's ID. + global_model_params (dict): Global model parameters. + average_weight_dict (dict): Dictionary of average weights for clients. + client_schedule (list): Schedule of clients for the round. + + Notes: + This method constructs and sends a model synchronization message to a client. + """ logging.info("send_message_sync_model_to_client. 
receive_id = %d" % receive_id)
         message = Message(
             MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT,
@@ -127,7 +199,4 @@ def send_message_sync_model_to_client(self, receive_id, global_model_params,
         # message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index))
         message.add_params(MyMessage.MSG_ARG_KEY_AVG_WEIGHTS, average_weight_dict)
         message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_SCHEDULE, client_schedule)
-        self.send_message(message)
-
-
-
+        self.send_message(message)
diff --git a/python/fedml/simulation/mpi/fedopt_seq/optrepo.py b/python/fedml/simulation/mpi/fedopt_seq/optrepo.py
index df6ec80985..a6c07959b3 100644
--- a/python/fedml/simulation/mpi/fedopt_seq/optrepo.py
+++ b/python/fedml/simulation/mpi/fedopt_seq/optrepo.py
@@ -1,3 +1,4 @@
+import torch
 import logging
 from typing import List, Union

@@ -72,4 +73,4 @@ def _update_repo(cls):
         """
         Updates the optimizer repository with the latest subclasses.
         """
-        cls.repo = {x.__name__: x for x in torch.optim.Optimizer.__subclasses__()}
+        cls.repo = {x.__name__.lower(): x for x in torch.optim.Optimizer.__subclasses__()}
diff --git a/python/fedml/simulation/mpi/fedprox/FedProxAPI.py b/python/fedml/simulation/mpi/fedprox/FedProxAPI.py
index 4ab1af38da..9be6f03cfe 100644
--- a/python/fedml/simulation/mpi/fedprox/FedProxAPI.py
+++ b/python/fedml/simulation/mpi/fedprox/FedProxAPI.py
@@ -10,6 +10,12 @@


 def FedML_init():
+    """
+    Initialize the Federated Machine Learning environment.
+
+    Returns:
+        tuple: A tuple containing the MPI communication object, process ID, and worker number.
+    """
     comm = MPI.COMM_WORLD
     process_id = comm.Get_rank()
     worker_number = comm.Get_size()
@@ -27,6 +33,20 @@ def FedML_FedProx_distributed(
     client_trainer: ClientTrainer = None,
     server_aggregator: ServerAggregator = None,
 ):
+    """
+    Run the Federated Proximal training process.
+
+    Args:
+        args (object): Arguments for configuration.
+        process_id (int): The process ID of the current worker.
+        worker_number (int): The total number of workers.
+        comm (object): Communication object.
+        device (object): Device for computation.
+        dataset (list): List containing dataset information.
+        model (object): Model for training.
+        client_trainer (object): Trainer for client-side training (default: None).
+        server_aggregator (object): Server aggregator for aggregation (default: None).
+    """
     [
         train_data_num,
         test_data_num,
@@ -84,6 +104,24 @@ def init_server(
     train_data_local_num_dict,
     server_aggregator,
 ):
+    """
+    Initialize the server for Federated Proximal training.
+
+    Args:
+        args (object): Arguments for configuration.
+        device (object): Device for computation.
+        comm (object): Communication object.
+        rank (int): Rank of the server.
+        size (int): Total number of participants.
+        model (object): Model for training.
+        train_data_num (int): Number of training data samples.
+        train_data_global (object): Global training data.
+        test_data_global (object): Global testing data.
+        train_data_local_dict (dict): Dictionary of local training data.
+        test_data_local_dict (dict): Dictionary of local testing data.
+        train_data_local_num_dict (dict): Dictionary of local training data sizes.
+        server_aggregator (object): Server aggregator for aggregation.
+    """
     if server_aggregator is None:
         server_aggregator = create_server_aggregator(model, args)
     server_aggregator.set_id(-1)
@@ -123,6 +161,22 @@ def init_client(
     test_data_local_dict,
     model_trainer=None,
 ):
+    """
+    Initialize a client for Federated Proximal training.
+
+    Args:
+        args (object): Arguments for configuration.
+ device (object): Device for computation. + comm (object): Communication object. + process_id (int): Process ID of the client. + size (int): Total number of participants. + model (object): Model for training. + train_data_num (int): Number of training data samples. + train_data_local_num_dict (dict): Dictionary of local training data sizes. + train_data_local_dict (dict): Dictionary of local training data. + test_data_local_dict (dict): Dictionary of local testing data. + model_trainer (object): Trainer for the model (default: None). + """ client_index = process_id - 1 if model_trainer is None: model_trainer = create_model_trainer(model, args) diff --git a/python/fedml/simulation/mpi/fedprox/FedProxAggregator.py b/python/fedml/simulation/mpi/fedprox/FedProxAggregator.py index 026729e264..e5da91e8d4 100644 --- a/python/fedml/simulation/mpi/fedprox/FedProxAggregator.py +++ b/python/fedml/simulation/mpi/fedprox/FedProxAggregator.py @@ -10,6 +10,10 @@ class FedProxAggregator(object): + """ + Aggregator for Federated Proximal training. + """ + def __init__( self, train_global, @@ -23,39 +27,78 @@ def __init__( args, server_aggregator, ): + """ + Initialize the FedProxAggregator. + + Args: + train_global (object): Global training data. + test_global (object): Global testing data. + all_train_data_num (int): Number of training data samples. + train_data_local_dict (dict): Dictionary of local training data. + test_data_local_dict (dict): Dictionary of local testing data. + train_data_local_num_dict (dict): Dictionary of local training data sizes. + worker_num (int): Number of workers. + device (object): Device for computation. + args (object): Arguments for configuration. + server_aggregator (object): Server aggregator for aggregation. + """ self.aggregator = server_aggregator - self.args = args self.train_global = train_global self.test_global = test_global self.val_global = self._generate_validation_set() self.all_train_data_num = all_train_data_num - self.train_data_local_dict = train_data_local_dict self.test_data_local_dict = test_data_local_dict self.train_data_local_num_dict = train_data_local_num_dict - self.worker_num = worker_num self.device = device self.model_dict = dict() self.sample_num_dict = dict() self.flag_client_model_uploaded_dict = dict() + for idx in range(self.worker_num): self.flag_client_model_uploaded_dict[idx] = False def get_global_model_params(self): + """ + Get the global model parameters. + + Returns: + dict: Global model parameters. + """ return self.aggregator.get_model_params() def set_global_model_params(self, model_parameters): + """ + Set the global model parameters. + + Args: + model_parameters (dict): Global model parameters. + """ self.aggregator.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, sample_num): + """ + Add local trained model results to the aggregator. + + Args: + index (int): Index of the client. + model_params (dict): Local model parameters. + sample_num (int): Number of local samples. + """ logging.info("add_model. index = %d" % index) self.model_dict[index] = model_params self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all clients have uploaded their models. + + Returns: + bool: True if all models have been received, False otherwise. 
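+
+        Note:
+            The per-client upload flags checked here are set to True by
+            add_local_trained_result as each client's update arrives.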
+ """ for idx in range(self.worker_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -64,6 +107,12 @@ def check_whether_all_receive(self): return True def aggregate(self): + """ + Aggregate local models from clients and calculate the global model. + + Returns: + dict: Averaged global model parameters. + """ start_time = time.time() model_list = [] training_num = 0 @@ -74,7 +123,6 @@ def aggregate(self): logging.info("len of self.model_dict[idx] = " + str(len(self.model_dict))) - # logging.info("################aggregate: %d" % len(model_list)) (num0, averaged_params) = model_list[0] for k in averaged_params.keys(): for i in range(0, len(model_list)): @@ -85,7 +133,7 @@ def aggregate(self): else: averaged_params[k] += local_model_params[k] * w - # update the global model which is cached at the server side + # Update the global model which is cached at the server side self.set_global_model_params(averaged_params) end_time = time.time() @@ -93,16 +141,36 @@ def aggregate(self): return averaged_params def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Randomly select clients for participation in each round of training. + + Args: + round_idx (int): Current training round index. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to select per round. + + Returns: + list: List of client indexes selected for the current training round. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: num_clients = min(client_num_per_round, client_num_in_total) - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) # Ensure consistent client selection for each round client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False) logging.info("client_indexes = %s" % str(client_indexes)) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set for testing. + + Args: + num_samples (int): Number of samples in the validation set. + + Returns: + object: Validation dataset. + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) @@ -113,6 +181,12 @@ def _generate_validation_set(self, num_samples=10000): return self.test_global def test_on_server_for_all_clients(self, round_idx): + """ + Perform testing on the server for all clients. + + Args: + round_idx (int): Current training round index. + """ if self.aggregator.test_all(self.train_data_local_dict, self.test_data_local_dict, self.device, self.args,): return @@ -122,7 +196,7 @@ def test_on_server_for_all_clients(self, round_idx): train_tot_corrects = [] train_losses = [] for client_idx in range(self.args.client_num_in_total): - # train data + # Train data metrics = self.aggregator.test(self.train_data_local_dict[client_idx], self.device, self.args) train_tot_correct, train_num_sample, train_loss = ( metrics["test_correct"], @@ -135,12 +209,13 @@ def test_on_server_for_all_clients(self, round_idx): """ Note: CI environment is CPU-based computing. - The training speed for RNN training is to slow in this setting, so we only test a client to make sure there is no programming error. 
+ The training speed for RNN training is too slow in this setting, + so we only test a client to make sure there is no programming error. """ if self.args.ci == 1: break - # test on training dataset + # Test on training dataset train_acc = sum(train_tot_corrects) / sum(train_num_samples) train_loss = sum(train_losses) / sum(train_num_samples) # wandb.log({"Train/Acc": train_acc, "round": round_idx}) @@ -148,7 +223,7 @@ def test_on_server_for_all_clients(self, round_idx): stats = {"training_acc": train_acc, "training_loss": train_loss} logging.info(stats) - # test data + # Test data test_num_samples = [] test_tot_corrects = [] test_losses = [] @@ -167,10 +242,10 @@ def test_on_server_for_all_clients(self, round_idx): test_num_samples.append(copy.deepcopy(test_num_sample)) test_losses.append(copy.deepcopy(test_loss)) - # test on test dataset + # Test on test dataset test_acc = sum(test_tot_corrects) / sum(test_num_samples) test_loss = sum(test_losses) / sum(test_num_samples) # wandb.log({"Test/Acc": test_acc, "round": round_idx}) # wandb.log({"Test/Loss": test_loss, "round": round_idx}) stats = {"test_acc": test_acc, "test_loss": test_loss} - logging.info(stats) + logging.info(stats) \ No newline at end of file diff --git a/python/fedml/simulation/mpi/fedprox/FedProxClientManager.py b/python/fedml/simulation/mpi/fedprox/FedProxClientManager.py index 860fe336b0..cc13142319 100644 --- a/python/fedml/simulation/mpi/fedprox/FedProxClientManager.py +++ b/python/fedml/simulation/mpi/fedprox/FedProxClientManager.py @@ -7,6 +7,18 @@ class FedProxClientManager(FedMLCommManager): + """ + Client manager for Federated Proximal training. + + Args: + args (object): Arguments for configuration. + trainer (object): Trainer for the client. + comm (object): Communication object. + rank (int): Rank of the client. + size (int): Total number of participants. + backend (str): Backend for communication (default: "MPI"). + """ + def __init__(self, args, trainer, comm=None, rank=0, size=0, backend="MPI"): super().__init__(args, comm, rank, size, backend) self.trainer = trainer @@ -17,6 +29,16 @@ def run(self): super().run() def register_message_receive_handlers(self): + """ + Register message receive handlers for the client manager. + + This method registers message handlers for receiving initialization + and model synchronization messages. + + Message Types: + - MyMessage.MSG_TYPE_S2C_INIT_CONFIG + - MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init ) @@ -26,6 +48,12 @@ def register_message_receive_handlers(self): ) def handle_message_init(self, msg_params): + """ + Handle initialization message from the server. + + Args: + msg_params (dict): Message parameters. + """ global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -37,10 +65,19 @@ def handle_message_init(self, msg_params): self.__train() def start_training(self): + """ + Start the training process. + """ self.args.round_idx = 0 self.__train() def handle_message_receive_model_from_server(self, msg_params): + """ + Handle model synchronization message from the server. + + Args: + msg_params (dict): Message parameters. 
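+
+        Note:
+            The handler updates the local dataset and model, advances the round
+            index, and triggers local training; the client finishes once the
+            configured number of rounds has completed.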
+ """ logging.info("handle_message_receive_model_from_server.") model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -56,6 +93,14 @@ def handle_message_receive_model_from_server(self, msg_params): self.finish() def send_model_to_server(self, receive_id, weights, local_sample_num): + """ + Send the trained model to the server. + + Args: + receive_id (int): Receiver ID (typically the server). + weights (object): Model weights. + local_sample_num (int): Number of local training samples. + """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), @@ -66,6 +111,9 @@ def send_model_to_server(self, receive_id, weights, local_sample_num): self.send_message(message) def __train(self): + """ + Execute the training process. + """ logging.info("#######training########### round_id = %d" % self.args.round_idx) weights, local_sample_num = self.trainer.train(self.args.round_idx) self.send_model_to_server(0, weights, local_sample_num) diff --git a/python/fedml/simulation/mpi/fedprox/FedProxServerManager.py b/python/fedml/simulation/mpi/fedprox/FedProxServerManager.py index ccf9f087cf..3e20185317 100644 --- a/python/fedml/simulation/mpi/fedprox/FedProxServerManager.py +++ b/python/fedml/simulation/mpi/fedprox/FedProxServerManager.py @@ -7,6 +7,20 @@ class FedProxServerManager(FedMLCommManager): + """ + Server manager for Federated Proximal training. + + Args: + args (object): Arguments for configuration. + aggregator (object): Aggregator for model updates. + comm (object): Communication object. + rank (int): Rank of the server. + size (int): Total number of participants. + backend (str): Backend for communication (default: "MPI"). + is_preprocessed (bool): Flag indicating if data is preprocessed (default: False). + preprocessed_client_lists (list): Preprocessed client lists (default: None). + """ + def __init__( self, args, @@ -30,7 +44,13 @@ def run(self): super().run() def send_init_msg(self): - # sampling clients + """ + Send initialization messages to clients. + + Initializes the communication with clients by sending initial model parameters + and client indexes. + """ + # Sampling clients client_indexes = self.aggregator.client_sampling( self.args.round_idx, self.args.client_num_in_total, @@ -44,12 +64,30 @@ def send_init_msg(self): ) def register_message_receive_handlers(self): + """ + Register message receive handlers for the server manager. + + This method registers the message receive handler for receiving model updates from clients. + + Message Types: + - MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER + + Message Handler: + - self.handle_message_receive_model_from_client + + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.handle_message_receive_model_from_client, ) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle received model updates from clients. + + Args: + msg_params (dict): Message parameters. 
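+
+        Note:
+            Once updates from every client have arrived, the aggregator averages
+            the local models weighted by sample count, server-side evaluation
+            runs, and the next round of client sampling begins.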
+ """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) @@ -63,7 +101,7 @@ def handle_message_receive_model_from_client(self, msg_params): global_model_params = self.aggregator.aggregate() self.aggregator.test_on_server_for_all_clients(self.args.round_idx) - # start the next round + # Start the next round self.args.round_idx += 1 if self.args.round_idx == self.round_num: post_complete_message_to_sweep_process(self.args) @@ -72,12 +110,12 @@ def handle_message_receive_model_from_client(self, msg_params): return if self.is_preprocessed: if self.preprocessed_client_lists is None: - # sampling has already been done in data preprocessor + # Sampling has already been done in data preprocessor client_indexes = [self.args.round_idx] * self.args.client_num_per_round else: client_indexes = self.preprocessed_client_lists[self.args.round_idx] else: - # sampling clients + # Sampling clients client_indexes = self.aggregator.client_sampling( self.args.round_idx, self.args.client_num_in_total, @@ -93,6 +131,14 @@ def handle_message_receive_model_from_client(self, msg_params): ) def send_message_init_config(self, receive_id, global_model_params, client_index): + """ + Send initialization configuration message to a client. + + Args: + receive_id (int): Receiver ID. + global_model_params (object): Global model parameters. + client_index (int): Index of the client. + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) @@ -103,6 +149,14 @@ def send_message_init_config(self, receive_id, global_model_params, client_index def send_message_sync_model_to_client( self, receive_id, global_model_params, client_index ): + """ + Send model synchronization message to a client. + + Args: + receive_id (int): Receiver ID. + global_model_params (object): Global model parameters. + client_index (int): Index of the client. + """ logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id) message = Message( MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, diff --git a/python/fedml/simulation/mpi/fedprox/FedProxTrainer.py b/python/fedml/simulation/mpi/fedprox/FedProxTrainer.py index e77096b452..18ababf006 100644 --- a/python/fedml/simulation/mpi/fedprox/FedProxTrainer.py +++ b/python/fedml/simulation/mpi/fedprox/FedProxTrainer.py @@ -2,6 +2,33 @@ class FedProxTrainer(object): + """ + Federated Proximal Trainer for model training. + + Args: + client_index (int): Index of the client. + train_data_local_dict (dict): Dictionary of local training data. + train_data_local_num_dict (dict): Dictionary of local training data counts. + test_data_local_dict (dict): Dictionary of local testing data. + train_data_num (int): Total number of training data samples. + device (object): Device for training (e.g., CPU or GPU). + args (object): Arguments for configuration. + model_trainer (object): Model trainer for training. + + Attributes: + trainer (object): Model trainer for training. + client_index (int): Index of the client. + train_data_local_dict (dict): Dictionary of local training data. + train_data_local_num_dict (dict): Dictionary of local training data counts. + test_data_local_dict (dict): Dictionary of local testing data. + all_train_data_num (int): Total number of training data samples. + train_local (object): Local training data for the client. + local_sample_number (int): Number of local training data samples. 
+ test_local (object): Local testing data for the client. + device (object): Device for training. + args (object): Arguments for configuration. + """ + def __init__( self, client_index, @@ -20,9 +47,6 @@ def __init__( self.train_data_local_num_dict = train_data_local_num_dict self.test_data_local_dict = test_data_local_dict self.all_train_data_num = train_data_num - # self.train_local = self.train_data_local_dict[client_index] - # self.local_sample_number = self.train_data_local_num_dict[client_index] - # self.test_local = self.test_data_local_dict[client_index] self.train_local = None self.local_sample_number = None self.test_local = None @@ -31,15 +55,36 @@ def __init__( self.args = args def update_model(self, weights): + """ + Update the model with new weights. + + Args: + weights (object): New model weights. + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """ + Update the dataset for training and testing. + + Args: + client_index (int): Index of the client. + """ self.client_index = client_index self.train_local = self.train_data_local_dict[client_index] self.local_sample_number = self.train_data_local_num_dict[client_index] self.test_local = self.test_data_local_dict[client_index] def train(self, round_idx=None): + """ + Train the model. + + Args: + round_idx (int, optional): Index of the training round (default: None). + + Returns: + tuple: Tuple containing trained model weights and local sample count. + """ self.args.round_idx = round_idx self.trainer.train(self.train_local, self.device, self.args) @@ -48,7 +93,13 @@ def train(self, round_idx=None): return weights, self.local_sample_number def test(self): - # train data + """ + Test the trained model. + + Returns: + tuple: Tuple containing training and testing metrics. + """ + # Train data metrics train_metrics = self.trainer.test(self.train_local, self.device, self.args) train_tot_correct, train_num_sample, train_loss = ( train_metrics["test_correct"], @@ -56,7 +107,7 @@ def test(self): train_metrics["test_loss"], ) - # test data + # Test data metrics test_metrics = self.trainer.test(self.test_local, self.device, self.args) test_tot_correct, test_num_sample, test_loss = ( test_metrics["test_correct"], diff --git a/python/fedml/simulation/mpi/fedprox/message_define.py b/python/fedml/simulation/mpi/fedprox/message_define.py index 092e2ba618..57a51e3b1c 100644 --- a/python/fedml/simulation/mpi/fedprox/message_define.py +++ b/python/fedml/simulation/mpi/fedprox/message_define.py @@ -1,23 +1,22 @@ class MyMessage(object): """ - message type definition + Message type definition. 
""" - # server to client + # Server to client messages MSG_TYPE_S2C_INIT_CONFIG = 1 MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT = 2 - # client to server + # Client to server messages MSG_TYPE_C2S_SEND_MODEL_TO_SERVER = 3 MSG_TYPE_C2S_SEND_STATS_TO_SERVER = 4 + # Message argument keys MSG_ARG_KEY_TYPE = "msg_type" MSG_ARG_KEY_SENDER = "sender" MSG_ARG_KEY_RECEIVER = "receiver" - """ - message payload keywords definition - """ + # Message payload keywords MSG_ARG_KEY_NUM_SAMPLES = "num_samples" MSG_ARG_KEY_MODEL_PARAMS = "model_params" MSG_ARG_KEY_CLIENT_INDEX = "client_idx" diff --git a/python/fedml/simulation/mpi/fedprox/utils.py b/python/fedml/simulation/mpi/fedprox/utils.py index aea2449590..932ca053de 100644 --- a/python/fedml/simulation/mpi/fedprox/utils.py +++ b/python/fedml/simulation/mpi/fedprox/utils.py @@ -5,6 +5,15 @@ def transform_list_to_tensor(model_params_list): + """ + Transform a dictionary of model parameters from lists to tensors. + + Args: + model_params_list (dict): Dictionary of model parameters with lists. + + Returns: + dict: Dictionary of model parameters with tensors. + """ for k in model_params_list.keys(): model_params_list[k] = torch.from_numpy( np.asarray(model_params_list[k]) @@ -13,12 +22,27 @@ def transform_list_to_tensor(model_params_list): def transform_tensor_to_list(model_params): + """ + Transform a dictionary of model parameters from tensors to lists. + + Args: + model_params (dict): Dictionary of model parameters with tensors. + + Returns: + dict: Dictionary of model parameters with lists. + """ for k in model_params.keys(): model_params[k] = model_params[k].detach().numpy().tolist() return model_params def post_complete_message_to_sweep_process(args): + """ + Post a completion message to a sweep process. + + Args: + args (object): Arguments for configuration. 
+ """ pipe_path = "./tmp/fedml" os.system("mkdir ./tmp/; touch ./tmp/fedml") if not os.path.exists(pipe_path): From 4f00cd99eecca2aa0ee683698820d734883f5e9f Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Sat, 9 Sep 2023 19:41:38 +0530 Subject: [PATCH 54/70] docs --- .../launch/serve_mnist/model/minist_model.py | 45 ++- python/fedml/model/cv/vgg.py | 52 ++- python/fedml/model/finance/vfl_classifier.py | 29 ++ .../model/finance/vfl_feature_extractor.py | 36 ++ .../model/finance/vfl_models_standalone.py | 110 +++++- python/fedml/model/linear/lr.py | 34 ++ python/fedml/model/linear/lr_cifar10.py | 30 ++ python/fedml/model/mobile/mnn_lenet.py | 21 +- python/fedml/model/mobile/mnn_resnet.py | 180 ++++++---- python/fedml/model/mobile/torch_lenet.py | 37 ++ python/fedml/model/model_hub.py | 18 + python/fedml/model/nlp/model_args.py | 334 +++++++++++++++++- python/fedml/model/nlp/rnn.py | 74 +++- .../serving/client/client_initializer.py | 60 ++++ .../fedml/serving/client/client_launcher.py | 32 ++ .../client/fedml_client_master_manager.py | 99 ++++++ .../client/fedml_client_slave_manager.py | 38 +- python/fedml/serving/client/fedml_trainer.py | 48 ++- .../client/fedml_trainer_dist_adapter.py | 63 +++- .../serving/client/process_group_manager.py | 24 +- python/fedml/serving/client/utils.py | 40 ++- .../example/mnist/src/mnist_serve_main.py | 36 ++ .../src/app/pipe/instruct_pipeline.py | 98 ++++- python/fedml/serving/fedml_client.py | 50 ++- .../fedml/serving/fedml_inference_runner.py | 47 ++- python/fedml/serving/fedml_predictor.py | 106 +++++- python/fedml/serving/fedml_server.py | 30 ++ .../fedml/serving/server/fedml_aggregator.py | 160 +++++++-- .../serving/server/fedml_server_manager.py | 200 ++++++++++- python/fedml/serving/server/message_define.py | 19 +- .../serving/server/server_initializer.py | 24 +- 31 files changed, 1989 insertions(+), 185 deletions(-) diff --git a/python/examples/launch/serve_mnist/model/minist_model.py b/python/examples/launch/serve_mnist/model/minist_model.py index 25789d4e1c..1aed515cd6 100644 --- a/python/examples/launch/serve_mnist/model/minist_model.py +++ b/python/examples/launch/serve_mnist/model/minist_model.py @@ -1,11 +1,54 @@ import torch class LogisticRegression(torch.nn.Module): + """ + Logistic Regression model for binary classification. + + This class defines a logistic regression model with a single linear layer followed by a sigmoid activation function + for binary classification tasks. + + Args: + input_dim (int): The dimensionality of the input features. + output_dim (int): The number of output classes, which should be 1 for binary classification. + + Example: + # Create a logistic regression model for binary classification + input_dim = 10 + output_dim = 1 + model = LogisticRegression(input_dim, output_dim) + + Forward Method: + The forward method computes the output of the model for a given input. + + Example: + # Forward pass with input tensor 'x' + input_tensor = torch.tensor([0.1, 0.2, 0.3, ..., 0.9]) + output = model(input_tensor) + + Note: + - For binary classification, the `output_dim` should be set to 1. + - The `forward` method applies a sigmoid activation to the linear output, producing values in the range [0, 1]. + + """ + def __init__(self, input_dim, output_dim): super(LogisticRegression, self).__init__() self.linear = torch.nn.Linear(input_dim, output_dim) def forward(self, x): - import torch + """ + Forward pass of the logistic regression model. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, input_dim). 
+ + Returns: + torch.Tensor: Output tensor of shape (batch_size, output_dim). + + Example: + # Forward pass with input tensor 'x' + input_tensor = torch.tensor([0.1, 0.2, 0.3, ..., 0.9]) + output = model(input_tensor) + """ outputs = torch.sigmoid(self.linear(x)) return outputs diff --git a/python/fedml/model/cv/vgg.py b/python/fedml/model/cv/vgg.py index 303a804137..3a2088369b 100644 --- a/python/fedml/model/cv/vgg.py +++ b/python/fedml/model/cv/vgg.py @@ -18,6 +18,25 @@ class VGG(nn.Module): + """ + VGG model implementation. + + Args: + features (nn.Module): The feature extractor module. + num_classes (int): Number of output classes. + init_weights (bool): Whether to initialize the model weights. + + Attributes: + features (nn.Module): The feature extractor module. + avgpool (nn.AdaptiveAvgPool2d): Adaptive average pooling layer. + classifier (nn.Sequential): Classifier module. + + Methods: + forward(x): Forward pass of the VGG model. + _initialize_weights(): Initialize model weights. + + """ + def __init__( self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True ) -> None: @@ -37,6 +56,16 @@ def __init__( self._initialize_weights() def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the VGG model. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + + """ x = self.features(x) x = self.avgpool(x) x = torch.flatten(x, 1) @@ -44,6 +73,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x def _initialize_weights(self) -> None: + """ + Initialize model weights. + + """ for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") @@ -58,19 +91,34 @@ def _initialize_weights(self) -> None: def make_layers(cfg, batch_norm=False): + """ + Create a list of layers for a VGG network based on the provided configuration. + + Args: + cfg (list): List of layer configurations where each element represents + the number of filters or "M" for max-pooling. + batch_norm (bool): If True, apply batch normalization after convolution. + + Returns: + nn.Sequential: A sequential container of layers. + + """ layers = [] - in_channels = 3 + in_channels = 3 # Input channel for RGB images for v in cfg: if v == "M": + # Max-pooling layer layers += [nn.MaxPool2d(kernel_size=2, stride=2)] else: v = int(v) conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) if batch_norm: + # Add convolution, batch normalization, and ReLU activation layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] else: + # Add convolution and ReLU activation layers += [conv2d, nn.ReLU(inplace=True)] - in_channels = v + in_channels = v # Update the input channels for the next layer return nn.Sequential(*layers) diff --git a/python/fedml/model/finance/vfl_classifier.py b/python/fedml/model/finance/vfl_classifier.py index 2359e42209..c4f80065e8 100644 --- a/python/fedml/model/finance/vfl_classifier.py +++ b/python/fedml/model/finance/vfl_classifier.py @@ -2,6 +2,25 @@ class VFLClassifier(nn.Module): + """ + Virtual Federated Learning (VFL) Classifier Model. + + Args: + input_dim (int): The input dimension, typically the number of features in each input sample. + output_dim (int): The output dimension, representing the number of classes. + + Input: + - Input tensor of shape (batch_size, input_dim), where batch_size is the number of input samples. + + Output: + - Output tensor of shape (batch_size, output_dim), representing class predictions or scores. 
+ + Architecture: + - Linear Layer: + - Input: input_dim neurons + - Output: output_dim neurons (typically the number of classes) + + """ def __init__(self, input_dim, output_dim, bias=True): super(VFLClassifier, self).__init__() self.classifier = nn.Sequential( @@ -9,4 +28,14 @@ def __init__(self, input_dim, output_dim, bias=True): ) def forward(self, x): + """ + Forward pass of the VFL Classifier model. + + Args: + x (Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + predictions (Tensor): Output tensor of shape (batch_size, output_dim) with class predictions or scores. + + """ return self.classifier(x) diff --git a/python/fedml/model/finance/vfl_feature_extractor.py b/python/fedml/model/finance/vfl_feature_extractor.py index 95a17c171f..c1bcccee73 100644 --- a/python/fedml/model/finance/vfl_feature_extractor.py +++ b/python/fedml/model/finance/vfl_feature_extractor.py @@ -2,6 +2,25 @@ class VFLFeatureExtractor(nn.Module): + """ + Virtual Federated Learning (VFL) Feature Extractor Model. + + Args: + input_dim (int): The input dimension, typically the number of features in each input sample. + output_dim (int): The output dimension, representing the desired feature dimension. + + Input: + - Input tensor of shape (batch_size, input_dim), where batch_size is the number of input samples. + + Output: + - Output tensor of shape (batch_size, output_dim), representing extracted features. + + Architecture: + - Linear Layer followed by Leaky ReLU activation: + - Input: input_dim neurons + - Output: output_dim neurons (representing feature dimension) + + """ def __init__(self, input_dim, output_dim): super(VFLFeatureExtractor, self).__init__() self.classifier = nn.Sequential( @@ -10,7 +29,24 @@ def __init__(self, input_dim, output_dim): self.output_dim = output_dim def forward(self, x): + """ + Forward pass of the VFL Feature Extractor model. + + Args: + x (Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + features (Tensor): Output tensor of shape (batch_size, output_dim) with extracted features. + + """ return self.classifier(x) def get_output_dim(self): + """ + Get the output dimension of the feature extractor. + + Returns: + int: The output dimension (feature dimension). + + """ return self.output_dim diff --git a/python/fedml/model/finance/vfl_models_standalone.py b/python/fedml/model/finance/vfl_models_standalone.py index 89640c8453..46ab393090 100644 --- a/python/fedml/model/finance/vfl_models_standalone.py +++ b/python/fedml/model/finance/vfl_models_standalone.py @@ -4,6 +4,27 @@ class DenseModel(nn.Module): + """ + Dense Model with Linear Classifier. + + Args: + input_dim (int): The input dimension, typically the number of features in each input sample. + output_dim (int): The output dimension, representing the number of classes or features. + learning_rate (float, optional): The learning rate for the optimizer. Default is 0.01. + bias (bool, optional): Whether to include bias terms in the linear layer. Default is True. + + Input: + - Input tensor of shape (batch_size, input_dim), where batch_size is the number of input samples. + + Output: + - Output tensor of shape (batch_size, output_dim) representing the model's predictions. + + Methods: + - forward(x): Forward pass of the model to make predictions. + - backward(x, grads): Backward pass to compute gradients and update model parameters. 
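+
+    Example:
+        A minimal usage sketch; the dimensions and input values are
+        illustrative:
+
+        >>> model = DenseModel(input_dim=10, output_dim=1)
+        >>> preds = model.forward([[0.0] * 10])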
+ + """ + def __init__(self, input_dim, output_dim, learning_rate=0.01, bias=True): super(DenseModel, self).__init__() self.classifier = nn.Sequential( @@ -15,20 +36,42 @@ def __init__(self, input_dim, output_dim, learning_rate=0.01, bias=True): ) def forward(self, x): + """ + Forward pass of the Dense Model. + + Args: + x (Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + predictions (Tensor): Output tensor of shape (batch_size, output_dim) with model predictions. + + """ if self.is_debug: print("[DEBUG] DenseModel.forward") x = torch.tensor(x).float() - return self.classifier(x).detach().numpy() + return self.classifier(x) def backward(self, x, grads): + """ + Backward pass of the Dense Model. + + Args: + x (array-like): Input data of shape (batch_size, input_dim). + grads (array-like): Gradients of the loss with respect to the model's output. + + Returns: + x_grad (array-like): Gradients of the loss with respect to the input data. + + """ if self.is_debug: print("[DEBUG] DenseModel.backward") x = torch.tensor(x, requires_grad=True).float() grads = torch.tensor(grads).float() output = self.classifier(x) - output.backward(gradient=grads) + loss = torch.sum(output * grads) # Compute dot product for backward pass + loss.backward() x_grad = x.grad.numpy() self.optimizer.step() @@ -38,6 +81,25 @@ def backward(self, x, grads): class LocalModel(nn.Module): + """ + Local Model with a Linear Classifier. + + Args: + input_dim (int): The input dimension, typically the number of features in each input sample. + output_dim (int): The output dimension, representing the number of classes or features. + learning_rate (float): The learning rate for the optimizer. + + Attributes: + output_dim (int): The output dimension of the model. + + Methods: + forward(x): Forward pass of the model to make predictions. + predict(x): Make predictions using the model. + backward(x, grads): Backward pass to compute gradients and update model parameters. + get_output_dim(): Get the output dimension of the model. + + """ + def __init__(self, input_dim, output_dim, learning_rate): super(LocalModel, self).__init__() self.classifier = nn.Sequential( @@ -51,30 +113,66 @@ def __init__(self, input_dim, output_dim, learning_rate): ) def forward(self, x): + """ + Forward pass of the Local Model. + + Args: + x (Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + predictions (array-like): Output predictions as a numpy array. + + """ if self.is_debug: - print("[DEBUG] DenseModel.forward") + print("[DEBUG] LocalModel.forward") x = torch.tensor(x).float() return self.classifier(x).detach().numpy() def predict(self, x): + """ + Make predictions using the Local Model. + + Args: + x (Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + predictions (array-like): Output predictions as a numpy array. + + """ if self.is_debug: - print("[DEBUG] DenseModel.predict") + print("[DEBUG] LocalModel.predict") x = torch.tensor(x).float() return self.classifier(x).detach().numpy() def backward(self, x, grads): + """ + Backward pass of the Local Model. + + Args: + x (array-like): Input data of shape (batch_size, input_dim). + grads (array-like): Gradients of the loss with respect to the model's output. 
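+
+        Note:
+            Backpropagating the scalar sum(output * grads) produces the same
+            parameter gradients as output.backward(gradient=grads): for a fixed
+            g, the derivative of sum(o * g) with respect to the weights is g
+            applied to the Jacobian of o.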
+ + """ if self.is_debug: - print("[DEBUG] DenseModel.backward") + print("[DEBUG] LocalModel.backward") x = torch.tensor(x).float() grads = torch.tensor(grads).float() output = self.classifier(x) - output.backward(gradient=grads) + loss = torch.sum(output * grads) # Compute dot product for backward pass + loss.backward() self.optimizer.step() self.optimizer.zero_grad() def get_output_dim(self): + """ + Get the output dimension of the Local Model. + + Returns: + output_dim (int): The output dimension of the model. + + """ return self.output_dim diff --git a/python/fedml/model/linear/lr.py b/python/fedml/model/linear/lr.py index 53b5ce0c09..d5bca7fde2 100644 --- a/python/fedml/model/linear/lr.py +++ b/python/fedml/model/linear/lr.py @@ -2,11 +2,45 @@ class LogisticRegression(torch.nn.Module): + """ + Logistic Regression Model. + + Args: + input_dim (int): The input dimension, typically the number of features in each input sample. + output_dim (int): The output dimension, representing the number of classes or a single output. + + Input: + - Input tensor of shape (batch_size, input_dim), where batch_size is the number of input samples. + + Output: + - Output tensor of shape (batch_size, output_dim), representing class probabilities or a single output. + + Architecture: + - Linear Layer: + - Input: input_dim neurons + - Output: output_dim neurons + - Activation: Sigmoid (for binary classification) or Softmax (for multi-class classification) + + Note: + - For binary classification, output_dim is typically set to 1. + - For multi-class classification, output_dim is the number of classes. + + """ def __init__(self, input_dim, output_dim): super(LogisticRegression, self).__init__() self.linear = torch.nn.Linear(input_dim, output_dim) def forward(self, x): + """ + Forward pass of the Logistic Regression model. + + Args: + x (Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + outputs (Tensor): Output tensor of shape (batch_size, output_dim) with class probabilities or a single output. + + """ # try: outputs = torch.sigmoid(self.linear(x)) # except: diff --git a/python/fedml/model/linear/lr_cifar10.py b/python/fedml/model/linear/lr_cifar10.py index 762b9c2c3a..87d593a547 100644 --- a/python/fedml/model/linear/lr_cifar10.py +++ b/python/fedml/model/linear/lr_cifar10.py @@ -2,11 +2,41 @@ class LogisticRegression_Cifar10(torch.nn.Module): + """ + Logistic Regression Model for CIFAR-10 Image Classification. + + Args: + input_dim (int): The input dimension, typically the number of features in each input sample. + output_dim (int): The output dimension, representing the number of classes in CIFAR-10. + + Input: + - Input tensor of shape (batch_size, input_dim), where batch_size is the number of input samples. + + Output: + - Output tensor of shape (batch_size, output_dim), representing class probabilities for CIFAR-10 classes. + + Architecture: + - Linear Layer: + - Input: input_dim neurons (flattened image vectors) + - Output: output_dim neurons (class probabilities) + - Activation: Sigmoid (to produce class probabilities) + + """ def __init__(self, input_dim, output_dim): super(LogisticRegression_Cifar10, self).__init__() self.linear = torch.nn.Linear(input_dim, output_dim) def forward(self, x): + """ + Forward pass of the Logistic Regression model. + + Args: + x (Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + outputs (Tensor): Output tensor of shape (batch_size, output_dim) with class probabilities. 
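+
+        Example:
+            A minimal sketch; the batch size is illustrative:
+
+            >>> import torch
+            >>> model = LogisticRegression_Cifar10(3 * 32 * 32, 10)
+            >>> probs = model(torch.randn(4, 3, 32, 32))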
+ + """ # Flatten images into vectors # print(f"size = {x.size()}") x = x.view(x.size(0), -1) diff --git a/python/fedml/model/mobile/mnn_lenet.py b/python/fedml/model/mobile/mnn_lenet.py index a803da7fea..2378fc9695 100644 --- a/python/fedml/model/mobile/mnn_lenet.py +++ b/python/fedml/model/mobile/mnn_lenet.py @@ -5,7 +5,17 @@ class Lenet5(nn.Module): - """construct a lenet 5 model""" + """ + LeNet-5 convolutional neural network model. + + This class defines the LeNet-5 architecture for image classification. + + Args: + None + + Returns: + torch.Tensor: Model predictions. + """ def __init__(self): super(Lenet5, self).__init__() @@ -15,6 +25,15 @@ def __init__(self): self.fc2 = nn.linear(500, 10) def forward(self, x): + """ + Forward pass of the LeNet-5 model. + + Args: + x (torch.Tensor): Input image tensor. + + Returns: + torch.Tensor: Model predictions. + """ x = F.relu(self.conv1(x)) x = F.max_pool(x, [2, 2], [2, 2]) x = F.relu(self.conv2(x)) diff --git a/python/fedml/model/mobile/mnn_resnet.py b/python/fedml/model/mobile/mnn_resnet.py index 9ae9703bb3..4f9cf53744 100644 --- a/python/fedml/model/mobile/mnn_resnet.py +++ b/python/fedml/model/mobile/mnn_resnet.py @@ -5,93 +5,126 @@ class ResBlock(nn.Module): + """ + Residual Block for a ResNet-like architecture. + + This class defines a basic residual block with two convolutional layers and batch normalization. + + Args: + in_planes (int): Number of input channels. + planes (int): Number of output channels (number of filters in the convolutional layers). + stride (int): Stride value for the first convolutional layer (default is 1). + + Returns: + torch.Tensor: Output tensor from the residual block. + """ + def __init__(self, in_planes, planes, stride=1): super(ResBlock, self).__init__() - self.conv1 = nn.conv( - in_planes, - planes, - kernel_size=[3, 3], - stride=[stride, stride], - padding=[1, 1], - bias=False, - padding_mode=MNN.expr.Padding_Mode.SAME, + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False ) - self.bn1 = nn.batch_norm(planes) - self.conv2 = nn.conv( - planes, - planes, - kernel_size=[3, 3], - stride=[1, 1], - padding=[1, 1], - bias=False, - padding_mode=MNN.expr.Padding_Mode.SAME, + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=1, padding=1, bias=False ) - self.bn2 = nn.batch_norm(planes) + self.bn2 = nn.BatchNorm2d(planes) def forward(self, x): + """ + Forward pass of the Residual Block. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after passing through the residual block. + """ + out = F.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) - out += x + out += x # Skip connection out = F.relu(out) return out class ResBlock_conv_shortcut(nn.Module): + """ + Residual Block with Convolutional Shortcuts for a ResNet-like architecture. + + This class defines a residual block with convolutional shortcuts. It consists of two convolutional layers + with batch normalization and a convolutional shortcut connection. + + Args: + in_planes (int): Number of input channels. + planes (int): Number of output channels (number of filters in the convolutional layers). + stride (int): Stride value for the first convolutional layer (default is 1). + + Returns: + torch.Tensor: Output tensor from the residual block. 
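+
+    Note:
+        The 1x1 convolution on the shortcut path matches the channel count and
+        stride of the main path, so the skip connection can still be added when
+        the width or spatial resolution changes between stages.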
+ """ + def __init__(self, in_planes, planes, stride=1): super(ResBlock_conv_shortcut, self).__init__() - self.conv1 = nn.conv( - in_planes, - planes, - kernel_size=[3, 3], - stride=[stride, stride], - padding=[1, 1], - bias=False, - padding_mode=MNN.expr.Padding_Mode.SAME, + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False ) - self.bn1 = nn.batch_norm(planes) - self.conv2 = nn.conv( - planes, - planes, - kernel_size=[3, 3], - stride=[1, 1], - padding=[1, 1], - bias=False, - padding_mode=MNN.expr.Padding_Mode.SAME, + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=1, padding=1, bias=False ) - self.bn2 = nn.batch_norm(planes) + self.bn2 = nn.BatchNorm2d(planes) - self.conv_shortcut = nn.conv( - in_planes, - planes, - kernel_size=[1, 1], - stride=[stride, stride], - bias=False, - padding_mode=MNN.expr.Padding_Mode.SAME, + self.conv_shortcut = nn.Conv2d( + in_planes, planes, kernel_size=1, stride=stride, bias=False ) - self.bn_shortcut = nn.batch_norm(planes) + self.bn_shortcut = nn.BatchNorm2d(planes) def forward(self, x): + """ + Forward pass of the Residual Block. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after passing through the residual block. + """ + out = F.relu(self.bn1(self.conv1(x))) out = self.bn2(self.conv2(out)) - out += self.bn_shortcut(self.conv_shortcut(x)) + shortcut = self.bn_shortcut(self.conv_shortcut(x)) + out += shortcut # Skip connection with convolutional shortcut out = F.relu(out) return out class Resnet20(nn.Module): + """ + ResNet-20 implementation for image classification. + + This class defines a ResNet-20 architecture with convolutional blocks and shortcuts. + It consists of four stages, each containing convolutional blocks. + + Args: + num_classes (int): Number of output classes. + + Returns: + torch.Tensor: Output tensor representing class probabilities. + """ + def __init__(self, num_classes=10): super(Resnet20, self).__init__() - self.conv1 = nn.conv( + self.conv1 = nn.Conv2d( 3, 16, - kernel_size=[3, 3], - stride=[1, 1], - padding=[1, 1], + kernel_size=3, + stride=1, + padding=1, bias=False, - padding_mode=MNN.expr.Padding_Mode.SAME, ) - self.bn1 = nn.batch_norm(16) + self.bn1 = nn.BatchNorm2d(16) self.layer1 = ResBlock(16, 16, 1) self.layer2 = ResBlock(16, 16, 1) @@ -105,28 +138,37 @@ def __init__(self, num_classes=10): self.layer8 = ResBlock(64, 64, 1) self.layer9 = ResBlock(64, 64, 1) - self.fc = nn.linear(64, num_classes) + self.fc = nn.Linear(64, num_classes) def forward(self, x): + """ + Forward pass of the ResNet-20 model. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor representing class probabilities. 
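+
+        Note:
+            The nine residual blocks contribute two convolutions each; together
+            with the stem convolution and the final linear layer, this accounts
+            for the twenty weighted layers implied by the ResNet-20 name.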
+ """ + x = F.relu(self.bn1(self.conv1(x))) - x = self.layer1.forward(x) - x = self.layer2.forward(x) - x = self.layer3.forward(x) - # print(x.shape) - x = self.layer4.forward(x) - x = self.layer5.forward(x) - x = self.layer6.forward(x) - # print(x.shape) - x = self.layer7.forward(x) - x = self.layer8.forward(x) - x = self.layer9.forward(x) - # print(x.shape) - x = F.avg_pool(x, kernel=[8, 8], stride=[8, 8]) - x = F.convert(x, F.NCHW) - x = F.reshape(x, [0, -1]) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.layer4(x) + x = self.layer5(x) + x = self.layer6(x) + + x = self.layer7(x) + x = self.layer8(x) + x = self.layer9(x) + + x = F.avg_pool2d(x, kernel_size=8, stride=8) + x = x.view(x.size(0), -1) x = self.fc(x) - out = F.softmax(x, 1) + out = F.softmax(x, dim=1) return out diff --git a/python/fedml/model/mobile/torch_lenet.py b/python/fedml/model/mobile/torch_lenet.py index fc1a64b457..ee3f30241f 100644 --- a/python/fedml/model/mobile/torch_lenet.py +++ b/python/fedml/model/mobile/torch_lenet.py @@ -3,6 +3,43 @@ class LeNet(nn.Module): + """ + LeNet-5 Convolutional Neural Network model for image classification. + + Args: + None + + Input: + - Input tensor of shape (batch_size, 1, 32, 32), where batch_size is the number of input samples. + + Output: + - Output tensor of shape (batch_size, 10), representing class probabilities for 10 classes. + + Architecture: + - Convolutional Layer 1: + - Input: 1 channel (grayscale image) + - Output: 20 feature maps + - Kernel size: 5x5 + - Activation: ReLU + - Max Pooling: 2x2 + - Convolutional Layer 2: + - Input: 20 feature maps + - Output: 50 feature maps + - Kernel size: 5x5 + - Activation: ReLU + - Max Pooling: 2x2 + - Fully Connected Layer 1: + - Input: 800 neurons (flattened 50x4x4 from previous layer) + - Output: 500 neurons + - Activation: ReLU + - Dropout: 50% dropout rate + - Fully Connected Layer 2: + - Input: 500 neurons + - Output: 10 neurons (class probabilities) + - Activation: Softmax + + """ + def __init__(self): super(LeNet, self).__init__() self.fc2 = nn.Linear(500, 10) diff --git a/python/fedml/model/model_hub.py b/python/fedml/model/model_hub.py index e4ccc3acfc..fefa775055 100644 --- a/python/fedml/model/model_hub.py +++ b/python/fedml/model/model_hub.py @@ -17,6 +17,24 @@ def create(args, output_dim): + """ + Create a deep learning model based on the provided arguments and dataset. + + Args: + args (Namespace): Command-line arguments containing model and dataset information. + output_dim (int): Dimension of the model's output. + + Returns: + torch.nn.Module or Tuple[torch.nn.Module, torch.nn.Module] or None: The created model(s). + + Raises: + Exception: If the specified model or dataset is not supported. + + Example: + >>> import argparse + >>> args = argparse.Namespace(model="cnn", dataset="mnist") + >>> model = create(args, 10) + """ global model model_name = args.model logging.info("create_model. model_name = %s, output_dim = %s" % (model_name, output_dim)) diff --git a/python/fedml/model/nlp/model_args.py b/python/fedml/model/nlp/model_args.py index 2aaaa7319f..871f56ccd4 100644 --- a/python/fedml/model/nlp/model_args.py +++ b/python/fedml/model/nlp/model_args.py @@ -9,18 +9,78 @@ def get_default_process_count(): + """ + Get the default number of processes to use for multi-processing tasks. + + Returns: + int: The default process count. 
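+
+    Note:
+        The cap of 61 on Windows is a conservative bound related to the
+        64-handle limit of the WaitForMultipleObjects API that Python's
+        multiprocessing machinery relies on.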
+ + Example: + >>> process_count = get_default_process_count() + """ process_count = int(cpu_count() / 2) if cpu_count() > 2 else 1 if sys.platform == "win32": process_count = min(process_count, 61) return process_count - def get_special_tokens(): + """ + Get a list of special tokens commonly used in natural language processing tasks. + + Returns: + List[str]: A list of special tokens. + + Example: + >>> special_tokens = get_special_tokens() + """ return ["", "", "", "", ""] @dataclass class ModelArgs: + """ + Configuration class for model training and evaluation. + + Attributes: + adam_epsilon (float): Epsilon value for Adam optimizer. Default is 1e-8. + best_model_dir (str): Directory to save the best model checkpoints. Default is "outputs/best_model". + cache_dir (str): Directory for caching data. Default is "cache_dir/". + config (dict): Additional configuration settings as a dictionary. Default is an empty dictionary. + custom_layer_parameters (list): List of custom layer parameters. Default is an empty list. + custom_parameter_groups (list): List of custom parameter groups. Default is an empty list. + dataloader_num_workers (int): Number of workers for data loading. Default is determined by `get_default_process_count`. + do_lower_case (bool): Whether to convert input text to lowercase. Default is False. + dynamic_quantize (bool): Whether to dynamically quantize the model. Default is False. + early_stopping_consider_epochs (bool): Whether to consider epochs for early stopping. Default is False. + early_stopping_delta (float): Minimum change in metric value to consider for early stopping. Default is 0. + early_stopping_metric (str): Metric to monitor for early stopping. Default is "eval_loss". + early_stopping_metric_minimize (bool): Whether to minimize the early stopping metric. Default is True. + early_stopping_patience (int): Number of epochs with no improvement to wait before early stopping. Default is 3. + encoding (str): Encoding for input text. Default is None. + eval_batch_size (int): Batch size for evaluation. Default is 8. + evaluate_during_training (bool): Whether to perform evaluation during training. Default is False. + evaluate_during_training_silent (bool): Whether to silence evaluation logs during training. Default is True. + evaluate_during_training_steps (int): Frequency of evaluation steps during training. Default is 2000. + evaluate_during_training_verbose (bool): Whether to print evaluation results during training. Default is False. + evaluate_each_epoch (bool): Whether to perform evaluation after each epoch. Default is True. + fp16 (bool): Whether to use mixed-precision training (FP16). Default is True. + gradient_accumulation_steps (int): Number of gradient accumulation steps. Default is 1. + learning_rate (float): Learning rate for training. Default is 4e-5. + local_rank (int): Local rank for distributed training. Default is -1. + logging_steps (int): Frequency of logging training steps. Default is 50. + manual_seed (int): Seed for random number generation. Default is None. + max_grad_norm (float): Maximum gradient norm for clipping gradients. Default is 1.0. + max_seq_length (int): Maximum sequence length for input data. Default is 128. + model_name (str): Name of the model being used. Default is None. + model_type (str): Type of the model being used. Default is None. + ... (other attributes) + + Methods: + update_from_dict(new_values): Update attribute values from a dictionary. + get_args_for_saving(): Get a dictionary of attributes suitable for saving. 
+ save(output_dir): Save the model configuration to a JSON file in the specified output directory. + load(input_dir): Load the model configuration from a JSON file in the specified input directory. + """ adam_epsilon: float = 1e-8 best_model_dir: str = "outputs/best_model" cache_dir: str = "cache_dir/" @@ -84,6 +144,20 @@ class ModelArgs: skip_special_tokens: bool = True def update_from_dict(self, new_values): + """ + Update attributes of the ModelArgs instance from a dictionary. + + Args: + new_values (dict): A dictionary containing attribute-value pairs to update. + + Raises: + TypeError: If the input `new_values` is not a Python dictionary. + + Example: + model_args = ModelArgs() + new_values = {'learning_rate': 0.01, 'train_batch_size': 16} + model_args.update_from_dict(new_values) + """ if isinstance(new_values, dict): for key, value in new_values.items(): setattr(self, key, value) @@ -91,6 +165,16 @@ def update_from_dict(self, new_values): raise (TypeError(f"{new_values} is not a Python dict.")) def get_args_for_saving(self): + """ + Get a dictionary of model arguments suitable for saving. + + Returns: + dict: A dictionary containing model arguments, excluding those specified in `not_saved_args`. + + Example: + model_args = ModelArgs() + args_to_save = model_args.get_args_for_saving() + """ args_for_saving = { key: value for key, value in asdict(self).items() @@ -99,11 +183,31 @@ def get_args_for_saving(self): return args_for_saving def save(self, output_dir): + """ + Save the model configuration to a JSON file in the specified output directory. + + Args: + output_dir (str): The directory where the model configuration JSON file will be saved. + + Example: + model_args = ModelArgs() + model_args.save("output_directory") + """ os.makedirs(output_dir, exist_ok=True) with open(os.path.join(output_dir, "model_args.json"), "w") as f: json.dump(self.get_args_for_saving(), f) def load(self, input_dir): + """ + Load the model configuration from a JSON file in the specified input directory. + + Args: + input_dir (str): The directory where the model configuration JSON file is located. + + Example: + model_args = ModelArgs() + model_args.load("input_directory") + """ if input_dir: model_args_file = os.path.join(input_dir, "model_args.json") if os.path.isfile(model_args_file): @@ -120,41 +224,151 @@ class ClassificationArgs(ModelArgs): """ model_class: str = "ClassificationModel" + """ + (str) The name of the classification model class. Defaults to "ClassificationModel". + """ + labels_list: list = field(default_factory=list) + """ + (list) A list of labels used for classification. Defaults to an empty list. + """ + labels_map: dict = field(default_factory=dict) + """ + (dict) A dictionary that maps labels to their corresponding indices. Defaults to an empty dictionary. + """ + lazy_delimiter: str = "\t" + """ + (str) The delimiter used for lazy loading of data. Defaults to the tab character ("\t"). + """ + lazy_labels_column: int = 1 + """ + (int) The column index (1-based) containing labels when using lazy loading. Defaults to 1. + """ + lazy_loading: bool = False + """ + (bool) Whether to use lazy loading of data. Defaults to False. + """ + lazy_loading_start_line: int = 1 + """ + (int) The line number (1-based) to start reading data when using lazy loading. Defaults to 1. + """ + lazy_text_a_column: bool = None + """ + (bool) Whether the lazy loading data contains a text column for input "text_a". Defaults to None. 
+ """ + lazy_text_b_column: bool = None + """ + (bool) Whether the lazy loading data contains a text column for input "text_b". Defaults to None. + """ + lazy_text_column: int = 0 + """ + (int) The column index (0-based) containing text data when using lazy loading. Defaults to 0. + """ + onnx: bool = False + """ + (bool) Whether to use ONNX format for the model. Defaults to False. + """ + regression: bool = False + """ + (bool) Whether the task is regression (True) or classification (False). Defaults to False. + """ + sliding_window: bool = False + """ + (bool) Whether to use a sliding window approach for long documents. Defaults to False. + """ + stride: float = 0.8 + """ + (float) The stride used in the sliding window approach. Defaults to 0.8. + """ + tie_value: int = 1 + """ + (int) The value used for tied tokens in the dataset. Defaults to 1. + """ + evaluate_during_training_steps: int = 20 + """ + (int) The number of steps between evaluations during training. Defaults to 20. + """ + evaluate_during_training: bool = True + """ + (bool) Whether to perform evaluations during training. Defaults to True. + """ @dataclass class SeqTaggingArgs(ModelArgs): """ - Model args for a SeqTaggingArgs + Model args for a SeqTaggingModel """ model_class: str = "SeqTaggingModel" + """ + (str) The name of the SeqTagging model class. Defaults to "SeqTaggingModel". + """ + labels_list: list = field(default_factory=list) + """ + (list) A list of labels used for sequence tagging. Defaults to an empty list. + """ + lazy_delimiter: str = "\t" + """ + (str) The delimiter used for lazy loading of data. Defaults to the tab character ("\t"). + """ + lazy_labels_column: int = 1 + """ + (int) The column index (1-based) containing labels when using lazy loading. Defaults to 1. + """ + lazy_loading: bool = False + """ + (bool) Whether to use lazy loading of data. Defaults to False. + """ + lazy_loading_start_line: int = 1 + """ + (int) The line number (1-based) to start reading data when using lazy loading. Defaults to 1. + """ + onnx: bool = False + """ + (bool) Whether to use ONNX format for the model. Defaults to False. + """ + evaluate_during_training_steps: int = 20 + """ + (int) The number of steps between evaluations during training. Defaults to 20. + """ + evaluate_during_training: bool = True + """ + (bool) Whether to perform evaluations during training. Defaults to True. + """ + classification_report: bool = True + """ + (bool) Whether to generate a classification report. Defaults to True. + """ + pad_token_label_id: int = CrossEntropyLoss().ignore_index + """ + (int) The ID of the pad token label used for padding. Defaults to CrossEntropyLoss().ignore_index. + """ @dataclass @@ -164,16 +378,60 @@ class SpanExtractionArgs(ModelArgs): """ model_class: str = "QuestionAnsweringModel" + """ + (str) The name of the SpanExtraction model class. Defaults to "QuestionAnsweringModel". + """ + doc_stride: int = 384 + """ + (int) The document stride for span extraction. Defaults to 384. + """ + early_stopping_metric: str = "correct" + """ + (str) The early stopping metric. Defaults to "correct". + """ + early_stopping_metric_minimize: bool = False + """ + (bool) Whether to minimize the early stopping metric. Defaults to False. + """ + lazy_loading: bool = False + """ + (bool) Whether to use lazy loading of data. Defaults to False. + """ + max_answer_length: int = 100 + """ + (int) The maximum answer length. Defaults to 100. + """ + max_query_length: int = 64 + """ + (int) The maximum query length. Defaults to 64. 
+ """ + n_best_size: int = 20 + """ + (int) The number of best answers to consider. Defaults to 20. + """ + null_score_diff_threshold: float = 0.0 + """ + (float) The null score difference threshold. Defaults to 0.0. + """ + evaluate_during_training_steps: int = 20 + """ + (int) The number of steps between evaluations during training. Defaults to 20. + """ + evaluate_during_training: bool = True + """ + (bool) Whether to perform evaluations during training. Defaults to True. + """ + @dataclass @@ -183,20 +441,92 @@ class Seq2SeqArgs(ModelArgs): """ model_class: str = "Seq2SeqModel" + """ + (str) The name of the Seq2Seq model class. Defaults to "Seq2SeqModel". + """ + base_marian_model_name: str = None + """ + (str) The base Marian model name. Defaults to None. + """ + dataset_class: Dataset = None + """ + (Dataset) The dataset class. Defaults to None. + """ + do_sample: bool = False + """ + (bool) Whether to perform sampling during decoding. Defaults to False. + """ + early_stopping: bool = True + """ + (bool) Whether to use early stopping during training. Defaults to True. + """ + evaluate_generated_text: bool = False + """ + (bool) Whether to evaluate generated text. Defaults to False. + """ + length_penalty: float = 2.0 + """ + (float) The length penalty factor during decoding. Defaults to 2.0. + """ + max_length: int = 20 + """ + (int) The maximum length of generated text. Defaults to 20. + """ + max_steps: int = -1 + """ + (int) The maximum number of training steps. Defaults to -1 (unlimited). + """ + num_beams: int = 4 + """ + (int) The number of beams used during decoding. Defaults to 4. + """ + num_return_sequences: int = 1 + """ + (int) The number of generated sequences to return. Defaults to 1. + """ + repetition_penalty: float = 1.0 + """ + (float) The repetition penalty factor during decoding. Defaults to 1.0. + """ + top_k: float = None + """ + (float) The top-k value used during decoding. Defaults to None. + """ + top_p: float = None + """ + (float) The top-p value used during decoding. Defaults to None. + """ + use_multiprocessed_decoding: bool = False + """ + (bool) Whether to use multiprocessed decoding. Defaults to False. + """ + evaluate_during_training: bool = True + """ + (bool) Whether to perform evaluations during training. Defaults to True. + """ + src_lang: str = "en_XX" + """ + (str) The source language for translation. Defaults to "en_XX". + """ + tgt_lang: str = "ro_RO" + """ + (str) The target language for translation. Defaults to "ro_RO". + """ + diff --git a/python/fedml/model/nlp/rnn.py b/python/fedml/model/nlp/rnn.py index 7af13e4618..b46c57f018 100644 --- a/python/fedml/model/nlp/rnn.py +++ b/python/fedml/model/nlp/rnn.py @@ -3,17 +3,20 @@ class RNN_OriginalFedAvg(nn.Module): - """Creates a RNN model using LSTM layers for Shakespeare language models (next character prediction task). + """ + Creates a RNN model using LSTM layers for Shakespeare language models (next character prediction task). This replicates the model structure in the paper: Communication-Efficient Learning of Deep Networks from Decentralized Data - H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, Blaise Agueray Arcas. AISTATS 2017. - https://arxiv.org/abs/1602.05629 + H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, Blaise Agueray Arcas. AISTATS 2017. + https://arxiv.org/abs/1602.05629 This is also recommended model by "Adaptive Federated Optimization. 
ICML 2020" (https://arxiv.org/pdf/2003.00295.pdf) + Args: - vocab_size: the size of the vocabulary, used as a dimension in the input embedding. - sequence_length: the length of input sequences. + embedding_dim: The dimension of word embeddings. Default is 8. + vocab_size: The size of the vocabulary, used as a dimension in the input embedding. Default is 90. + hidden_size: The size of the hidden state in the LSTM layers. Default is 256. Returns: - An uncompiled `torch.nn.Module`. + An uncompiled `torch.nn.Module`. """ def __init__(self, embedding_dim=8, vocab_size=90, hidden_size=256): @@ -30,6 +33,14 @@ def __init__(self, embedding_dim=8, vocab_size=90, hidden_size=256): self.fc = nn.Linear(hidden_size, vocab_size) def forward(self, input_seq): + """ + Forward pass of the model. + + Args: + input_seq: Input sequence of character indices. + Returns: + output: Model predictions. + """ embeds = self.embeddings(input_seq) # Note that the order of mini-batch is random so there is no hidden relationship among batches. # So we do not input the previous batch's hidden state, @@ -45,6 +56,20 @@ def forward(self, input_seq): class RNN_FedShakespeare(nn.Module): + """ + RNN model for Shakespeare language modeling (next character prediction task). + + This class defines an RNN model for predicting the next character in a sequence of text, + specifically tailored for the "fed_shakespeare" task. + + Args: + embedding_dim (int): Dimension of the character embeddings. + vocab_size (int): Size of the vocabulary (number of unique characters). + hidden_size (int): Size of the hidden state of the LSTM layers. + + Returns: + torch.Tensor: The model's output predictions. + """ def __init__(self, embedding_dim=8, vocab_size=90, hidden_size=256): super(RNN_FedShakespeare, self).__init__() self.embeddings = nn.Embedding( @@ -59,6 +84,14 @@ def __init__(self, embedding_dim=8, vocab_size=90, hidden_size=256): self.fc = nn.Linear(hidden_size, vocab_size) def forward(self, input_seq): + """ + Forward pass of the model. + + Args: + input_seq: Input sequence of character indices. + Returns: + output: Model predictions. + """ embeds = self.embeddings(input_seq) # Note that the order of mini-batch is random so there is no hidden relationship among batches. # So we do not input the previous batch's hidden state, @@ -74,15 +107,22 @@ def forward(self, input_seq): class RNN_StackOverFlow(nn.Module): - """Creates a RNN model using LSTM layers for StackOverFlow (next word prediction task). - This replicates the model structure in the paper: + """ + RNN model for StackOverflow language modeling (next word prediction task). + + This class defines an RNN model for predicting the next word in a sequence of text, specifically tailored + for the "stackoverflow_nwp" task. "Adaptive Federated Optimization. ICML 2020" (https://arxiv.org/pdf/2003.00295.pdf) - Table 9 + Args: - vocab_size: the size of the vocabulary, used as a dimension in the input embedding. - sequence_length: the length of input sequences. + vocab_size (int): Size of the vocabulary (number of unique words). + num_oov_buckets (int): Number of out-of-vocabulary (OOV) buckets. + embedding_size (int): Dimension of the word embeddings. + latent_size (int): Size of the LSTM hidden state. + num_layers (int): Number of LSTM layers. + Returns: - An uncompiled `torch.nn.Module`. + torch.Tensor: The model's output predictions. 
""" def __init__( @@ -107,6 +147,16 @@ def __init__( self.fc2 = nn.Linear(embedding_size, extended_vocab_size) def forward(self, input_seq, hidden_state=None): + """ + Forward pass of the model. + + Args: + input_seq (torch.Tensor): Input sequence of word indices. + hidden_state (tuple): Initial hidden state of the LSTM. + + Returns: + torch.Tensor: Model predictions. + """ embeds = self.word_embeddings(input_seq) lstm_out, hidden_state = self.lstm(embeds, hidden_state) fc1_output = self.fc1(lstm_out[:, :]) diff --git a/python/fedml/serving/client/client_initializer.py b/python/fedml/serving/client/client_initializer.py index 37791b80de..b26e727937 100644 --- a/python/fedml/serving/client/client_initializer.py +++ b/python/fedml/serving/client/client_initializer.py @@ -16,6 +16,25 @@ def init_client( test_data_local_dict, model_trainer=None, ): + """ + Initialize and run a federated learning client. + + Args: + args: Arguments and configuration for the client. + device: The device on which the client should run (e.g., 'cpu' or 'cuda'). + comm: The communication backend for distributed training. + client_rank: The rank or identifier of this client. + client_num: The total number of clients in the federated learning scenario. + model: The machine learning model to be trained. + train_data_num: The number of training data points. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training data points. + train_data_local_dict: A dictionary mapping client IDs to their local training data. + test_data_local_dict: A dictionary mapping client IDs to their local testing data. + model_trainer: An optional custom model trainer. + + Returns: + None + """ backend = args.backend trainer_dist_adapter = get_trainer_dist_adapter( @@ -60,6 +79,23 @@ def get_trainer_dist_adapter( test_data_local_dict, model_trainer, ): + """ + Get a distributed trainer adapter for the federated learning client. + + Args: + args: Arguments and configuration for the client. + device: The device on which the client should run (e.g., 'cpu' or 'cuda'). + client_rank: The rank or identifier of this client. + model: The machine learning model to be trained. + train_data_num: The number of training data points. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training data points. + train_data_local_dict: A dictionary mapping client IDs to their local training data. + test_data_local_dict: A dictionary mapping client IDs to their local testing data. + model_trainer: An optional custom model trainer. + + Returns: + TrainerDistAdapter: A distributed trainer adapter. + """ return TrainerDistAdapter( args, device, @@ -74,10 +110,34 @@ def get_trainer_dist_adapter( def get_client_manager_master(args, trainer_dist_adapter, comm, client_rank, client_num, backend): + """ + Get a federated learning client manager for the master client in the hierarchical scenario. + + Args: + args: Arguments and configuration for the client. + trainer_dist_adapter: A distributed trainer adapter. + comm: The communication backend for distributed training. + client_rank: The rank or identifier of this client. + client_num: The total number of clients in the federated learning scenario. + backend: The backend for distributed training (e.g., 'nccl' or 'gloo'). + + Returns: + ClientMasterManager: A federated learning client manager for the master client. 
+ """ return ClientMasterManager(args, trainer_dist_adapter, comm, client_rank, client_num, backend) def get_client_manager_salve(args, trainer_dist_adapter): + """ + Get a federated learning client manager for a slave client in the hierarchical scenario. + + Args: + args: Arguments and configuration for the client. + trainer_dist_adapter: A distributed trainer adapter. + + Returns: + ClientSlaveManager: A federated learning client manager for a slave client. + """ from .fedml_client_slave_manager import ClientSlaveManager return ClientSlaveManager(args, trainer_dist_adapter) diff --git a/python/fedml/serving/client/client_launcher.py b/python/fedml/serving/client/client_launcher.py index 1a4831b11e..1034ab2cb7 100644 --- a/python/fedml/serving/client/client_launcher.py +++ b/python/fedml/serving/client/client_launcher.py @@ -27,6 +27,16 @@ class CrossSiloLauncher: @staticmethod def launch_dist_trainers(torch_client_filename, inputs): + """ + Launch distributed trainers based on the specified scenario. + + Args: + torch_client_filename (str): The filename of the torch client script to be launched. + inputs (List[str]): List of input arguments to be passed to the torch client script. + + Returns: + None + """ # this is only used by the client (DDP or single process), so there is no need to specify the backend. args = load_arguments(FEDML_TRAINING_PLATFORM_CROSS_SILO) if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL: @@ -38,12 +48,34 @@ def launch_dist_trainers(torch_client_filename, inputs): @staticmethod def _run_cross_silo_horizontal(args, torch_client_filename, inputs): + """ + Run distributed training in a horizontal federated learning scenario. + + Args: + args: Arguments and configuration for the client. + torch_client_filename (str): The filename of the torch client script to be launched. + inputs (List[str]): List of input arguments to be passed to the torch client script. + + Returns: + None + """ python_path = subprocess.run(["which", "python"], capture_output=True, text=True).stdout.strip() process_arguments = [python_path, torch_client_filename] + inputs subprocess.run(process_arguments) @staticmethod def _run_cross_silo_hierarchical(args, torch_client_filename, inputs): + """ + Run distributed training in a hierarchical federated learning scenario. + + Args: + args: Arguments and configuration for the client. + torch_client_filename (str): The filename of the torch client script to be launched. + inputs (List[str]): List of input arguments to be passed to the torch client script. + + Returns: + None + """ def get_torchrun_arguments(node_rank): torchrun_path = subprocess.run(["which", "torchrun"], capture_output=True, text=True).stdout.strip() diff --git a/python/fedml/serving/client/fedml_client_master_manager.py b/python/fedml/serving/client/fedml_client_master_manager.py index 6e4d2b7495..c14720b214 100644 --- a/python/fedml/serving/client/fedml_client_master_manager.py +++ b/python/fedml/serving/client/fedml_client_master_manager.py @@ -19,6 +19,17 @@ class ClientMasterManager(FedMLCommManager): RUN_FINISHED_STATUS_FLAG = "FINISHED" def __init__(self, args, trainer_dist_adapter, comm=None, rank=0, size=0, backend="MPI"): + """ + Initialize the ClientMasterManager. + + Args: + args: Arguments and configuration for the client manager. + trainer_dist_adapter: Trainer distribution adapter for distributed training. + comm: Communication backend (MPI, etc.). + rank: Rank of the client. + size: Size of the client group. 
+ backend: Backend for distributed training (MPI, etc.). + """ super().__init__(args, comm, rank, size, backend) self.trainer_dist_adapter = trainer_dist_adapter self.args = args @@ -35,6 +46,9 @@ def __init__(self, args, trainer_dist_adapter, comm=None, rank=0, size=0, backen self.is_inited = False def register_message_receive_handlers(self): + """ + Register message receive handlers for handling various types of messages. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_message_connection_ready ) @@ -53,6 +67,12 @@ def register_message_receive_handlers(self): ) def handle_message_connection_ready(self, msg_params): + """ + Handle the "connection ready" message. + + Args: + msg_params: Parameters of the message. + """ if not self.has_sent_online_msg: self.has_sent_online_msg = True self.send_client_status(0) @@ -60,9 +80,21 @@ def handle_message_connection_ready(self, msg_params): mlops.log_sys_perf(self.args) def handle_message_check_status(self, msg_params): + """ + Handle the "check client status" message. + + Args: + msg_params: Parameters of the message. + """ self.send_client_status(0) def handle_message_init(self, msg_params): + """ + Handle the "initialize" message and prepare for training. + + Args: + msg_params: Parameters of the message. + """ if self.is_inited: return @@ -88,6 +120,12 @@ def handle_message_init(self, msg_params): self.round_idx += 1 def handle_message_receive_model_from_server(self, msg_params): + """ + Handle the "receive model from server" message. + + Args: + msg_params: Parameters of the message. + """ logging.info("handle_message_receive_model_from_server.") model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -108,15 +146,35 @@ def handle_message_receive_model_from_server(self, msg_params): self.finish() def handle_message_finish(self, msg_params): + """ + Handle the "finish" message and perform cleanup. + + Args: + msg_params: Parameters of the message. + """ logging.info(" ====================cleanup ====================") self.cleanup() def cleanup(self): + """ + Perform cleanup operations at the end of training. + """ self.send_client_status(0, ClientMasterManager.RUN_FINISHED_STATUS_FLAG) mlops.log_training_finished_status() self.finish() def send_model_to_server(self, receive_id, weights, local_sample_num): + """ + Send the model to the server. + + Args: + receive_id: ID of the recipient (usually the server). + weights: Model weights to be sent. + local_sample_num: Number of local training samples. + + Note: + This method sends model parameters to the server for aggregation. + """ tick = time.time() mlops.event("comm_c2s", event_started=True, event_value=str(self.round_idx)) message = Message(MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.client_real_id, receive_id,) @@ -130,6 +188,17 @@ def send_model_to_server(self, receive_id, weights, local_sample_num): ) def send_client_status(self, receive_id, status=ONLINE_STATUS_FLAG): + """ + Send the client status message to the specified recipient. + + Args: + receive_id: ID of the recipient. + status: Status flag to be sent (default is ONLINE_STATUS_FLAG). + + Note: + This method sends information about the client's status, including the operating system. 
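+
+ Example (illustrative; `manager` is an initialized ClientMasterManager):
+ >>> manager.send_client_status(0) # reports ONLINE_STATUS_FLAG by default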
+ + """ logging.info("send_client_status") logging.info("self.client_real_id = {}".format(self.client_real_id)) message = Message(MyMessage.MSG_TYPE_C2S_CLIENT_STATUS, self.client_real_id, receive_id) @@ -149,9 +218,32 @@ def send_client_status(self, receive_id, status=ONLINE_STATUS_FLAG): self.send_message(message) def report_training_status(self, status): + """ + Report the training status to MLOps. + + Args: + status: Training status to be reported. + + Note: + This method logs the training status using MLOps. + + """ mlops.log_training_status(status) def sync_process_group(self, round_idx, model_params=None, client_index=None, src=0): + """ + Synchronize the process group with information about the current training round. + + Args: + round_idx: The current training round index. + model_params: Model parameters (default is None). + client_index: Client index (default is None). + src: Source of the synchronization (default is 0). + + Note: + This method broadcasts information about the current training round to the process group. + + """ logging.info("sending round number to pg") round_number = [round_idx, model_params, client_index] dist.broadcast_object_list( @@ -160,6 +252,13 @@ def sync_process_group(self, round_idx, model_params=None, client_index=None, sr logging.info("round number %d broadcast to process group" % round_number[0]) def __train(self): + """ + Perform the training for the current round. + + Note: + This method initiates the training process and sends the updated model to the server. + + """ logging.info("#######training########### round_id = %d" % self.round_idx) mlops.event("train", event_started=True, event_value=str(self.round_idx)) diff --git a/python/fedml/serving/client/fedml_client_slave_manager.py b/python/fedml/serving/client/fedml_client_slave_manager.py index 48f30d8263..5b817e34fa 100644 --- a/python/fedml/serving/client/fedml_client_slave_manager.py +++ b/python/fedml/serving/client/fedml_client_slave_manager.py @@ -5,6 +5,14 @@ class ClientSlaveManager: def __init__(self, args, trainer_dist_adapter): + """ + Initialize the ClientSlaveManager. + + Args: + args: Command-line arguments. + trainer_dist_adapter: Trainer distributed adapter. + + """ self.trainer_dist_adapter = trainer_dist_adapter self.args = args self.round_idx = 0 @@ -12,6 +20,13 @@ def __init__(self, args, trainer_dist_adapter): self.finished = False def train(self): + """ + Perform training for the current round. + + This method synchronizes with the process group, updates the model and dataset if necessary, and initiates training + for the current round. + + """ [round_idx, model_params, client_index] = self.await_sync_process_group() if round_idx: self.round_idx = round_idx @@ -28,7 +43,12 @@ def train(self): self.trainer_dist_adapter.train(self.round_idx) def finish(self): - # pass + """ + Finish the training process. + + This method performs cleanup operations and logs the completion of training. + + """ self.trainer_dist_adapter.cleanup_pg() logging.info( "Training finished for slave client rank %s in silo %s" @@ -37,6 +57,16 @@ def finish(self): self.finished = True def await_sync_process_group(self, src=0): + """ + Wait for synchronization with the process group. + + Args: + src: Source rank for synchronization (default is 0). + + Returns: + List: A list containing round_idx, model_params, and client_index. 
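+
+ Note:
+ This call blocks until the silo's master process broadcasts the round
+ payload via `dist.broadcast_object_list`.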
+ + """ logging.info("process %d waiting for round number" % dist.get_rank()) objects = [None, None, None] dist.broadcast_object_list( @@ -46,5 +76,11 @@ def await_sync_process_group(self, src=0): return objects def run(self): + """ + Start the client manager's main execution loop. + + This method continuously trains the client while it is not finished. + + """ while not self.finished: self.train() diff --git a/python/fedml/serving/client/fedml_trainer.py b/python/fedml/serving/client/fedml_trainer.py index 827644cc42..ae6d9e9a7f 100755 --- a/python/fedml/serving/client/fedml_trainer.py +++ b/python/fedml/serving/client/fedml_trainer.py @@ -17,8 +17,21 @@ def __init__( args, model_trainer, ): + """ + Initialize the Federated Learning Trainer. + + Args: + client_index: Index of the client. + train_data_local_dict: Dictionary mapping client IDs to local training datasets. + train_data_local_num_dict: Dictionary mapping client IDs to local training data counts. + test_data_local_dict: Dictionary mapping client IDs to local test datasets. + train_data_num: Number of training data samples. + device: Torch device for training. + args: Command-line arguments. + model_trainer: Trainer for the model. + + """ self.trainer = model_trainer - self.client_index = client_index if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL: @@ -32,15 +45,28 @@ def __init__( self.train_local = None self.local_sample_number = None self.test_local = None - self.device = device self.args = args self.args.device = device def update_model(self, weights): + """ + Update the model with new parameters. + + Args: + weights: Updated model parameters. + + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """ + Update the local dataset for training. + + Args: + client_index: Index of the client to update the dataset for. + + """ self.client_index = client_index if self.train_data_local_dict is not None: @@ -64,6 +90,16 @@ def update_dataset(self, client_index): self.trainer.update_dataset(self.train_local, self.test_local, self.local_sample_number) def train(self, round_idx=None): + """ + Perform federated training for the specified round. + + Args: + round_idx (Optional): Index of the current training round (default is None). + + Returns: + Tuple: A tuple containing the updated model weights and the number of local training samples. + + """ self.args.round_idx = round_idx tick = time.time() @@ -77,6 +113,14 @@ def train(self, round_idx=None): return weights, self.local_sample_number def test(self): + """ + Test the model on local data. + + Returns: + Tuple: A tuple containing training accuracy, training loss, number of training samples, + test accuracy, test loss, and number of test samples. + + """ # train data train_metrics = self.trainer.test(self.train_local, self.device, self.args) train_tot_correct, train_num_sample, train_loss = ( diff --git a/python/fedml/serving/client/fedml_trainer_dist_adapter.py b/python/fedml/serving/client/fedml_trainer_dist_adapter.py index 60383d31cf..36db5d3bb3 100644 --- a/python/fedml/serving/client/fedml_trainer_dist_adapter.py +++ b/python/fedml/serving/client/fedml_trainer_dist_adapter.py @@ -19,7 +19,21 @@ def __init__( test_data_local_dict, model_trainer, ): + """ + Initialize the TrainerDistAdapter. + Args: + args: Command-line arguments. + device: Torch device for training. + client_rank: Rank of the client. + model: The neural network model. + train_data_num: Number of training data samples. 
+ train_data_local_num_dict: Dictionary mapping client IDs to local training data counts. + train_data_local_dict: Dictionary mapping client IDs to local training datasets. + test_data_local_dict: Dictionary mapping client IDs to local test datasets. + model_trainer: Trainer for the model. + + """ ml_engine_adapter.model_to_device(args, model, device) if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL: @@ -62,6 +76,23 @@ def get_trainer( args, model_trainer, ): + """ + Create and return a trainer for the federated learning process. + + Args: + client_index: Index of the client. + train_data_local_dict: Dictionary mapping client IDs to local training datasets. + train_data_local_num_dict: Dictionary mapping client IDs to local training data counts. + test_data_local_dict: Dictionary mapping client IDs to local test datasets. + train_data_num: Number of training data samples. + device: Torch device for training. + args: Command-line arguments. + model_trainer: Trainer for the model. + + Returns: + FedMLTrainer: Trainer instance for federated learning. + + """ return FedMLTrainer( client_index, train_data_local_dict, @@ -74,20 +105,50 @@ def get_trainer( ) def train(self, round_idx): + """ + Perform federated training for the specified round. + + Args: + round_idx: Index of the current training round. + + Returns: + Tuple: A tuple containing the updated model weights and the number of local training samples. + + """ weights, local_sample_num = self.trainer.train(round_idx) return weights, local_sample_num def update_model(self, model_params): + """ + Update the model with new parameters. + + Args: + model_params: Updated model parameters. + + """ self.trainer.update_model(model_params) def update_dataset(self, client_index=None): + """ + Update the local dataset for training. + + Args: + client_index (Optional): Index of the client to update the dataset for (default is None, uses client's index). + + """ _client_index = client_index or self.client_index self.trainer.update_dataset(int(_client_index)) def cleanup_pg(self): + """ + Clean up the process group if using distributed training. + + This method is called to clean up the process group when hierarchical federated learning is used. + + """ if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL: logging.info( - "Cleaningup process group for client %s in silo %s" + "Cleaning up process group for client %s in silo %s" % (self.args.proc_rank_in_silo, self.args.rank_in_node) ) self.process_group_manager.cleanup() diff --git a/python/fedml/serving/client/process_group_manager.py b/python/fedml/serving/client/process_group_manager.py index 92519c6cc4..06da2e9738 100644 --- a/python/fedml/serving/client/process_group_manager.py +++ b/python/fedml/serving/client/process_group_manager.py @@ -7,6 +7,17 @@ class ProcessGroupManager: def __init__(self, rank, world_size, master_address, master_port, only_gpu): + """ + Initialize a process group manager for distributed training. + + Args: + rank (int): The rank of the current process. + world_size (int): The total number of processes in the group. + master_address (str): The address of the master process. + master_port (int): The port for communication with the master process. + only_gpu (bool): Whether to use NCCL backend for GPU communication. 
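+
+ Example (illustrative sketch; assumes both processes can reach the rendezvous address):
+ >>> pgm = ProcessGroupManager(0, 2, "127.0.0.1", 29500, only_gpu=False)
+ >>> pg = pgm.get_process_group()
+ >>> pgm.cleanup()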
+ + """ logging.info("Start process group") logging.info( "rank: %d, world_size: %d, master_address: %s, master_port: %s" @@ -31,7 +42,18 @@ def __init__(self, rank, world_size, master_address, master_port, only_gpu): logging.info("Initiated") def cleanup(self): + """ + Cleanup the process group. + + """ dist.destroy_process_group() def get_process_group(self): - return self.messaging_pg + """ + Get the messaging process group. + + Returns: + torch.distributed.ProcessGroup: The process group for communication. + + """ + return self.messaging_pg \ No newline at end of file diff --git a/python/fedml/serving/client/utils.py b/python/fedml/serving/client/utils.py index 38f4a169d1..4d8657fe1c 100644 --- a/python/fedml/serving/client/utils.py +++ b/python/fedml/serving/client/utils.py @@ -3,16 +3,42 @@ # ref: https://discuss.pytorch.org/t/failed-to-load-model-trained-by-ddp-for-inference/84841/2?u=amir_zsh def convert_model_params_from_ddp(ddp_model_params): - model_params = OrderedDict() - for k, v in ddp_model_params.items(): - name = k[7:] # remove 'module.' of DataParallel/DistributedDataParallel - model_params[name] = v - return model_params + """ + Convert model parameters from DistributedDataParallel (DDP) format to a regular format. + + Args: + ddp_model_params (OrderedDict): Model parameters in DDP format. + Returns: + OrderedDict: Model parameters in regular format. -def convert_model_params_to_ddp(ddp_model_params): + Example: + >>> ddp_params = OrderedDict([('module.conv1.weight', tensor), ('module.fc1.weight', tensor)]) + >>> regular_params = convert_model_params_from_ddp(ddp_params) + """ model_params = OrderedDict() for k, v in ddp_model_params.items(): - name = f"module.{k}" # add 'module.' of DataParallel/DistributedDataParallel + name = k[7:] # Remove 'module.' of DataParallel/DistributedDataParallel model_params[name] = v return model_params + + +def convert_model_params_to_ddp(model_params): + """ + Convert model parameters from a regular format to DistributedDataParallel (DDP) format. + + Args: + model_params (OrderedDict): Model parameters in regular format. + + Returns: + OrderedDict: Model parameters in DDP format. + + Example: + >>> regular_params = OrderedDict([('conv1.weight', tensor), ('fc1.weight', tensor)]) + >>> ddp_params = convert_model_params_to_ddp(regular_params) + """ + ddp_model_params = OrderedDict() + for k, v in model_params.items(): + name = f"module.{k}" # Add 'module.' for DataParallel/DistributedDataParallel + ddp_model_params[name] = v + return ddp_model_params diff --git a/python/fedml/serving/example/mnist/src/mnist_serve_main.py b/python/fedml/serving/example/mnist/src/mnist_serve_main.py index c6d8fdabcd..7de58955d1 100644 --- a/python/fedml/serving/example/mnist/src/mnist_serve_main.py +++ b/python/fedml/serving/example/mnist/src/mnist_serve_main.py @@ -12,7 +12,24 @@ # DATA_CACHE_DIR = "" class MnistPredictor(FedMLPredictor): + """ + A custom predictor for MNIST digit classification using a logistic regression model. + + This class loads a pretrained logistic regression model and provides a predict method to make predictions + on input data. + + Args: + None + + Example: + predictor = MnistPredictor() + input_data = {"arr": [0.1, 0.2, 0.3, ..., 0.9]} + prediction = predictor.predict(input_data) + """ def __init__(self): + """ + Initialize the MnistPredictor by loading a pretrained logistic regression model. 
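+
+ Note:
+ The unpickled model and the `torch.tensor` converter are cached on the
+ instance so that `predict` can reuse them for every request.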
+ """ import pickle import torch @@ -28,6 +45,25 @@ def __init__(self): self.list_to_tensor_func = torch.tensor def predict(self, request): + """ + Perform predictions on input data using the pretrained logistic regression model. + + Args: + request (dict): A dictionary containing input data for prediction. + The dictionary should have the following key: + - "arr" (list): A list of float values representing the input features for a MNIST digit image. + + Returns: + torch.Tensor: A tensor representing the model's prediction. + + Example: + predictor = MnistPredictor() + input_data = {"arr": [0.1, 0.2, 0.3, ..., 0.9]} + prediction = predictor.predict(input_data) + + Note: + The input data should be a list of float values with the same dimensionality as the model's input. + """ arr = request["arr"] input_tensor = self.list_to_tensor_func(arr) return self.model(input_tensor) diff --git a/python/fedml/serving/example/quick_start/src/app/pipe/instruct_pipeline.py b/python/fedml/serving/example/quick_start/src/app/pipe/instruct_pipeline.py index edcc1a643b..f3ae5a4089 100644 --- a/python/fedml/serving/example/quick_start/src/app/pipe/instruct_pipeline.py +++ b/python/fedml/serving/example/quick_start/src/app/pipe/instruct_pipeline.py @@ -2,10 +2,8 @@ Adapted from https://github.com/databrickslabs/dolly/blob/master/training/generate.py """ from typing import List, Optional, Tuple - import logging import re - import torch from transformers import ( AutoModelForCausalLM, @@ -27,13 +25,17 @@ def load_model_tokenizer_for_generate( pretrained_model_name_or_path: str, ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: - """Loads the model and tokenizer so that it can be used for generating responses. + """ + Load the model and tokenizer for generating responses. Args: - pretrained_model_name_or_path (str): name or path for model + pretrained_model_name_or_path (str): Name or path for the pretrained model. Returns: - Tuple[PreTrainedModel, PreTrainedTokenizer]: model and tokenizer + Tuple[PreTrainedModel, PreTrainedTokenizer]: A tuple containing the loaded model and tokenizer. + + Example: + model, tokenizer = load_model_tokenizer_for_generate("gpt2") """ tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, padding_side="left") model = AutoModelForCausalLM.from_pretrained( @@ -41,22 +43,25 @@ def load_model_tokenizer_for_generate( ) return model, tokenizer - def get_special_token_id(tokenizer: PreTrainedTokenizer, key: str) -> int: - """Gets the token ID for a given string that has been added to the tokenizer as a special token. + """ + Get the token ID for a given string that has been added to the tokenizer as a special token. - When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are - treated specially and converted to a single, new token. This retrieves the token ID each of these keys map to. + When training, we configure the tokenizer so that sequences like "### Instruction:" and "### End" are + treated specially and converted to a single, new token. This function retrieves the token ID for a given key. Args: - tokenizer (PreTrainedTokenizer): the tokenizer - key (str): the key to convert to a single token + tokenizer (PreTrainedTokenizer): The tokenizer. + key (str): The key to convert to a single token. Raises: - ValueError: if more than one ID was generated + ValueError: If more than one ID was generated for the key. Returns: - int: the token ID for the given key + int: The token ID for the given key. 
+ + Example: + special_token_id = get_special_token_id(tokenizer, "### Instruction:") """ token_ids = tokenizer.encode(key) if len(token_ids) > 1: @@ -64,6 +69,7 @@ def get_special_token_id(tokenizer: PreTrainedTokenizer, key: str) -> int: return token_ids[0] + class InstructionTextGenerationPipeline(Pipeline): def __init__( self, @@ -98,6 +104,18 @@ def _sanitize_parameters( return_full_text: bool = None, **generate_kwargs ): + """ + Sanitize and configure parameters for text generation. + + Args: + return_full_text (bool, optional): Whether to return the full text. Defaults to None. + + Returns: + Tuple[Dict, Dict, Dict]: A tuple containing preprocess_params, forward_params, and postprocess_params. + + Raises: + ValueError: If the response key token is not found. + """ preprocess_params = {} # newer versions of the tokenizer configure the response key as a special token. newer versions still may @@ -130,6 +148,18 @@ def _sanitize_parameters( return preprocess_params, forward_params, postprocess_params def preprocess(self, instruction_text, **generate_kwargs): + """ + Preprocess the input text for text generation. + + Args: + instruction_text (str): The instruction text. + + Returns: + Dict: Preprocessed inputs for text generation. + + Example: + inputs = preprocess("Write a summary of a book.") + """ prompt_text = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction_text) inputs = self.tokenizer( prompt_text, @@ -140,6 +170,15 @@ def preprocess(self, instruction_text, **generate_kwargs): return inputs def _forward(self, model_inputs, **generate_kwargs): + """ + Forward pass for text generation. + + Args: + model_inputs (Dict): Inputs for text generation. + + Returns: + Dict: Model outputs for text generation. + """ input_ids = model_inputs["input_ids"] attention_mask = model_inputs.get("attention_mask", None) @@ -173,6 +212,21 @@ def postprocess( end_key_token_id: Optional[int] = None, return_full_text: bool = False ): + """ + Postprocess the model outputs for text generation. + + Args: + model_outputs (Dict): Model outputs for text generation. + response_key_token_id (int, optional): Token ID for the response key. Defaults to None. + end_key_token_id (int, optional): Token ID for the end key. Defaults to None. + return_full_text (bool, optional): Whether to return the full text. Defaults to False. + + Returns: + List[Dict]: List of generated text records. + + Example: + generated_text = postprocess(model_outputs) + """ generated_sequence: torch.Tensor = model_outputs["generated_sequence"][0] instruction_text = model_outputs["instruction_text"] @@ -236,6 +290,9 @@ def postprocess( records.append(rec) return records + + + def generate_response( @@ -245,16 +302,21 @@ def generate_response( tokenizer: PreTrainedTokenizer, **kwargs, ) -> str: - """Given an instruction, uses the model and tokenizer to generate a response. This formats the instruction in + """ + Given an instruction, uses the model and tokenizer to generate a response. This formats the instruction in the instruction format that the model was fine-tuned on. Args: - instruction (str): _description_ - model (PreTrainedModel): the model to use - tokenizer (PreTrainedTokenizer): the tokenizer to use + instruction (str): The instruction for text generation. + model (PreTrainedModel): The pretrained model to use for text generation. + tokenizer (PreTrainedTokenizer): The tokenizer associated with the pretrained model. + **kwargs: Additional keyword arguments for text generation. 
Returns: - str: response + str: The generated response based on the provided instruction. + + Example: + response = generate_response("Write a summary of a book.", model=my_model, tokenizer=my_tokenizer) """ generation_pipeline = InstructionTextGenerationPipeline(model=model, tokenizer=tokenizer, **kwargs) diff --git a/python/fedml/serving/fedml_client.py b/python/fedml/serving/fedml_client.py index 69ca57a923..0c88d1a4ee 100644 --- a/python/fedml/serving/fedml_client.py +++ b/python/fedml/serving/fedml_client.py @@ -3,9 +3,49 @@ class FedMLModelServingClient: + """ + Client for Federated Machine Learning Model Serving. + + This class is responsible for initializing and running the client for federated machine learning model serving. + + Args: + args: An instance of arguments containing configuration settings. + end_point_name: The name of the model serving endpoint. + model_name: The name of the machine learning model. + model_version: The version of the machine learning model. + inference_request: An optional inference request configuration. + device: The device (e.g., 'cuda:0') to run the client on. + dataset: The dataset used for training and testing the model. + model: The machine learning model to be used. + model_trainer: An optional client trainer for model training. + + Attributes: + end_point_name: The name of the model serving endpoint. + model_name: The name of the machine learning model. + model_version: The version of the machine learning model. + inference_request: An optional inference request configuration. + + Methods: + run(): Start the client for federated machine learning model serving. + """ + def __init__(self, args, end_point_name, model_name, model_version, inference_request=None, device=None, dataset=None, model=None, model_trainer: ClientTrainer = None): + """ + Initializes the FedMLModelServingClient. + + Args: + args: An instance of arguments containing configuration settings. + end_point_name: The name of the model serving endpoint. + model_name: The name of the machine learning model. + model_version: The version of the machine learning model. + inference_request: An optional inference request configuration. + device: The device (e.g., 'cuda:0') to run the client on. + dataset: The dataset used for training and testing the model. + model: The machine learning model to be used. + model_trainer: An optional client trainer for model training. + """ self.end_point_name = end_point_name self.model_name = model_name self.model_version = model_version @@ -36,7 +76,15 @@ def __init__(self, args, end_point_name, model_name, model_version, model_trainer, ) else: - raise Exception("Exception") + raise Exception("Unsupported federated optimizer") def run(self): + """ + Start the client for federated machine learning model serving. + + This method initializes and runs the client for federated machine learning model serving. + + Returns: + None + """ pass diff --git a/python/fedml/serving/fedml_inference_runner.py b/python/fedml/serving/fedml_inference_runner.py index 5257bf75e7..d87309f347 100644 --- a/python/fedml/serving/fedml_inference_runner.py +++ b/python/fedml/serving/fedml_inference_runner.py @@ -2,22 +2,65 @@ from fastapi import FastAPI, Request class FedMLInferenceRunner(ABC): + """ + Abstract base class for federated machine learning inference runners. + + Subclasses should implement the `predict` method for making predictions. + + Attributes: + client_predictor: An instance of a client predictor class that implements the `predict` method. 
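+
+ Example (illustrative sketch; reuses the MnistPredictor sample shown earlier):
+ >>> predictor = MnistPredictor()
+ >>> runner = FedMLInferenceRunner(predictor)
+ >>> runner.run() # serves POST /predict and GET /ready on port 2345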
+ + Methods: + run(): Start the FastAPI server to handle prediction requests. + """ + def __init__(self, client_predictor): + """ + Initializes the FedMLInferenceRunner. + + Args: + client_predictor: An instance of a client predictor class that implements the `predict` method. + """ self.client_predictor = client_predictor def run(self): + """ + Start the FastAPI server to handle prediction requests. + + This method creates an HTTP server using FastAPI and defines two routes: '/predict' for making predictions + and '/ready' to check the server's readiness. + + Returns: + None + """ api = FastAPI() + @api.post("/predict") async def predict(request: Request): + """ + Handle POST requests to the '/predict' route for making predictions. + + Args: + request: The HTTP request object containing the input data. + + Returns: + dict: A JSON response containing the generated text. + """ input_json = await request.json() response_text = self.client_predictor.predict(input_json) - + return {"generated_text": str(response_text)} @api.get("/ready") async def ready(): + """ + Handle GET requests to the '/ready' route to check the server's readiness. + + Returns: + dict: A JSON response indicating the server's readiness status. + """ return {"status": "Success"} import uvicorn port = 2345 - uvicorn.run(api, host="0.0.0.0", port=port) \ No newline at end of file + uvicorn.run(api, host="0.0.0.0", port=port) diff --git a/python/fedml/serving/fedml_predictor.py b/python/fedml/serving/fedml_predictor.py index 3f6fb26023..84f26e138a 100644 --- a/python/fedml/serving/fedml_predictor.py +++ b/python/fedml/serving/fedml_predictor.py @@ -8,9 +8,111 @@ from ..computing.scheduler.comm_utils import sys_utils class FedMLPredictor(ABC): + """ + Abstract base class for federated machine learning predictors. + + Subclasses should implement the `predict` method for making predictions. + + Attributes: + None + + Methods: + predict(*args, **kwargs): Abstract method for making predictions. + """ + def __init__(self): - pass + """ + Initializes the FedMLPredictor. + + This constructor can be extended by subclasses as needed. + """ + build_dynamic_args() @abstractmethod def predict(self, *args, **kwargs): - pass \ No newline at end of file + """ + Abstract method for making predictions. + + Subclasses should implement this method to define the prediction logic. + + Args: + *args: Variable-length arguments. + **kwargs: Keyword arguments. + + Returns: + None + """ + pass + +def build_dynamic_args(): + """ + Builds dynamic arguments based on environment variables. + + This function checks for environment variables related to a bootstrap script and executes it if found. + + Args: + None + + Returns: + bool: True if the bootstrap script runs successfully, False otherwise. 
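+
+ Note:
+ If the `BOOTSTRAP_DIR` environment variable is unset or empty, the function
+ returns early (implicitly returning None) without running any script.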
+ """ + DEFAULT_BOOTSTRAP_FULL_DIR = os.environ.get("BOOTSTRAP_DIR", None) + if DEFAULT_BOOTSTRAP_FULL_DIR is None or DEFAULT_BOOTSTRAP_FULL_DIR == "": + return + + print("DEFAULT_BOOTSTRAP_FULL_DIR: {}".format(DEFAULT_BOOTSTRAP_FULL_DIR)) + + DEFAULT_BOOTSTRAP_SCRIPT_DIR = os.path.dirname(DEFAULT_BOOTSTRAP_FULL_DIR) + DEFAULT_BOOTSTRAP_SCRIPT_PATH = os.path.dirname(DEFAULT_BOOTSTRAP_FULL_DIR) + DEFAULT_BOOTSTRAP_SCRIPT_FILE = os.path.basename(DEFAULT_BOOTSTRAP_FULL_DIR) + bootstrap_script_dir = DEFAULT_BOOTSTRAP_SCRIPT_DIR + bootstrap_script_path = DEFAULT_BOOTSTRAP_SCRIPT_PATH + bootstrap_script_file = DEFAULT_BOOTSTRAP_SCRIPT_FILE + + is_bootstrap_run_ok = True + print("bootstrap_script_dir: {}".format(bootstrap_script_dir)) + print("bootstrap_script_path: {}".format(bootstrap_script_path)) + print("bootstrap_script_file: {}".format(bootstrap_script_file)) + try: + if bootstrap_script_path is not None: + if os.path.exists(bootstrap_script_path): + bootstrap_stat = os.stat(bootstrap_script_path) + if platform.system() == 'Windows': + os.chmod(bootstrap_script_path, + bootstrap_stat.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) + bootstrap_scripts = "{}".format(bootstrap_script_path) + else: + os.chmod(bootstrap_script_path, + bootstrap_stat.st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) + bootstrap_scripts = "cd {}; sh {}".format(bootstrap_script_dir, # Use sh over ./ to avoid permission denied error + os.path.basename(bootstrap_script_file)) + bootstrap_scripts = str(bootstrap_scripts).replace('\\', os.sep).replace('/', os.sep) + + process = ClientConstants.exec_console_with_script(bootstrap_scripts, should_capture_stdout=True, + should_capture_stderr=True) + # ClientConstants.save_bootstrap_process(run_id, process.pid) + ret_code, out, err = ClientConstants.get_console_pipe_out_err_results(process) + + if ret_code is None or ret_code <= 0: + if out is not None: + out_str = sys_utils.decode_our_err_result(out) + if out_str != "": + logging.info("{}".format(out_str)) + + sys_utils.log_return_info(bootstrap_script_file, 0) + + is_bootstrap_run_ok = True + else: + if err is not None: + err_str = sys_utils.decode_our_err_result(err) + if err_str != "": + logging.error("{}".format(err_str)) + + sys_utils.log_return_info(bootstrap_script_file, ret_code) + + is_bootstrap_run_ok = False + except Exception as e: + logging.error("Bootstrap script error: {}".format(traceback.format_exc())) + is_bootstrap_run_ok = False + + return is_bootstrap_run_ok diff --git a/python/fedml/serving/fedml_server.py b/python/fedml/serving/fedml_server.py index 3663755102..1ba0cfc682 100644 --- a/python/fedml/serving/fedml_server.py +++ b/python/fedml/serving/fedml_server.py @@ -2,9 +2,36 @@ class FedMLModelServingServer: + """ + Represents a server for serving federated machine learning models. + + This class initializes and manages the server-side functionality for serving federated models + in a federated learning system. + + Args: + args (object): Configuration arguments for the server. + end_point_name (str): The name of the endpoint for serving the model. + model_name (str): The name of the federated model. + model_version (str): The version of the federated model. + inference_request (object, optional): An inference request object for making predictions. + device (str, optional): The hardware device to use for inference (e.g., 'cpu' or 'cuda'). + dataset (list, optional): A list containing dataset-related information. + model (object, optional): The federated machine learning model. 
+ server_aggregator (ServerAggregator, optional): The server aggregator for model aggregation. + + Methods: + run(): Starts the server and serves the federated model for inference. + + Note: + This class is designed for serving federated models in a federated learning system. + """ + def __init__(self, args, end_point_name, model_name, model_version, inference_request=None, device=None, dataset=None, model=None, server_aggregator: ServerAggregator = None): + """ + Initializes a Federated Model Serving Server instance. + """ self.end_point_name = end_point_name self.model_name = model_name self.model_version = model_version @@ -42,4 +69,7 @@ def __init__(self, args, end_point_name, model_name, model_version, raise Exception("Exception") def run(self): + """ + Starts the server and serves the federated model for inference. + """ pass diff --git a/python/fedml/serving/server/fedml_aggregator.py b/python/fedml/serving/server/fedml_aggregator.py index 08f4ead226..9cb2caa5ec 100644 --- a/python/fedml/serving/server/fedml_aggregator.py +++ b/python/fedml/serving/server/fedml_aggregator.py @@ -11,6 +11,22 @@ class FedMLAggregator(object): + """ + A class for federated machine learning aggregation and related tasks. + + Args: + train_global: Global training data. + test_global: Global testing data. + all_train_data_num: Number of samples in the entire training dataset. + train_data_local_dict: Local training data dictionary. + test_data_local_dict: Local testing data dictionary. + train_data_local_num_dict: Number of local samples for each client. + client_num: Number of clients. + device: Device to run computations (e.g., 'cuda' or 'cpu'). + args: Additional configuration arguments. + server_aggregator: Aggregator for server-side operations. + """ + def __init__( self, train_global, @@ -49,15 +65,35 @@ def __init__( self.flag_client_model_uploaded_dict[idx] = False def get_global_model_params(self): + """ + Get the global model parameters. + + Returns: + dict: Global model parameters. + """ return self.aggregator.get_model_params() def set_global_model_params(self, model_parameters): + """ + Set the global model parameters. + + Args: + model_parameters (dict): Global model parameters. + """ self.aggregator.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, sample_num): + """ + Add locally trained model parameters for aggregation. + + Args: + index (int): Index of the client. + model_params (dict): Local model parameters. + sample_num (int): Number of local samples used for training. + """ logging.info("add_model. index = %d" % index) - # for dictionary model_params, we let the user level code to control the device + # For dictionary model_params, let the user-level code control the device if type(model_params) is not dict: model_params = ml_engine_adapter.model_params_to_device(self.args, model_params, self.device) @@ -66,6 +102,12 @@ def add_local_trained_result(self, index, model_params, sample_num): self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all clients have uploaded their models. + + Returns: + bool: True if all clients have uploaded their models, False otherwise. + """ logging.debug("client_num = {}".format(self.client_num)) for idx in range(self.client_num): if not self.flag_client_model_uploaded_dict[idx]: @@ -75,20 +117,29 @@ def check_whether_all_receive(self): return True def aggregate(self): + """ + Aggregate local models from clients to obtain a global model. 
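
`add_local_trained_result` and `check_whether_all_receive` implement a simple per-round bookkeeping pattern: store each upload, flip a flag, and reset once every client has reported. A condensed sketch of that bookkeeping, with illustrative names:

    class UploadTracker:
        # Minimal stand-in for FedMLAggregator's upload bookkeeping.
        def __init__(self, client_num):
            self.client_num = client_num
            self.model_dict = {}
            self.sample_num_dict = {}
            self.uploaded = {i: False for i in range(client_num)}

        def add_local_trained_result(self, index, model_params, sample_num):
            self.model_dict[index] = model_params
            self.sample_num_dict[index] = sample_num
            self.uploaded[index] = True

        def check_whether_all_receive(self):
            if all(self.uploaded.values()):
                # Reset the flags so the next round starts clean.
                self.uploaded = {i: False for i in range(self.client_num)}
                return True
            return False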
+ + Returns: + tuple: A tuple containing: + - dict: Averaged global model parameters. + - list: List of model tuples before aggregation. + - list: List of indices corresponding to selected models for aggregation. + """ start_time = time.time() model_list = [] for idx in range(self.client_num): model_list.append((self.sample_num_dict[idx], self.model_dict[idx])) - # model_list is the list after outlier removal + # Model list is the list after outlier removal model_list, model_list_idxes = self.aggregator.on_before_aggregation(model_list) Context().add(Context.KEY_CLIENT_MODEL_LIST, model_list) averaged_params = self.aggregator.aggregate(model_list) if type(averaged_params) is dict: - if len(averaged_params) == self.client_num + 1: # aggregator pass extra {-1 : global_parms_dict} as global_params - itr_count = len(averaged_params) - 1 # do not apply on_after_aggregation to client -1 + if len(averaged_params) == self.client_num + 1: # Aggregator passes extra {-1: global_parms_dict} as global_params + itr_count = len(averaged_params) - 1 # Do not apply on_after_aggregation to client -1 else: itr_count = len(averaged_params) @@ -104,23 +155,24 @@ def aggregate(self): return averaged_params, model_list, model_list_idxes def assess_contribution(self): + """ + Assess the contribution of clients to the global model. + """ if hasattr(self.args, "enable_contribution") and \ self.args.enable_contribution is not None and self.args.enable_contribution: self.aggregator.assess_contribution() def data_silo_selection(self, round_idx, client_num_in_total, client_num_per_round): """ + Select a subset of data silos (clients) for a federated learning round. Args: - round_idx: round index, starting from 0 - client_num_in_total: this is equal to the users in a synthetic data, - e.g., in synthetic_1_1, this value is 30 - client_num_per_round: the number of edge devices that can train + round_idx (int): Round index, starting from 0. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to select for the current round. Returns: - data_silo_index_list: e.g., when client_num_in_total = 30, client_num_in_total = 3, - this value is the form of [0, 11, 20] - + list: List of selected data silo indices. """ logging.info( "client_num_in_total = %d, client_num_per_round = %d" % (client_num_in_total, client_num_per_round) @@ -130,39 +182,59 @@ def data_silo_selection(self, round_idx, client_num_in_total, client_num_per_rou if client_num_in_total == client_num_per_round: return [i for i in range(client_num_per_round)] else: - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) # Make sure for each comparison, we are selecting the same clients each round data_silo_index_list = np.random.choice(range(client_num_in_total), client_num_per_round, replace=False) return data_silo_index_list def client_selection(self, round_idx, client_id_list_in_total, client_num_per_round): """ + Select a subset of clients for a federated learning round. + Args: - round_idx: round index, starting from 0 - client_id_list_in_total: this is the real edge IDs. - In MLOps, its element is real edge ID, e.g., [64, 65, 66, 67]; - in simulated mode, its element is client index starting from 1, e.g., [1, 2, 3, 4] - client_num_per_round: + round_idx (int): Round index, starting from 0. + client_id_list_in_total (list): List of real edge IDs or client indices. 
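
The aggregation step consumes `model_list` as `(sample_num, state_dict)` pairs. A minimal sample-weighted FedAvg sketch, assuming floating-point parameter tensors (the real aggregator is pluggable and may implement something different):

    import torch

    def fedavg(model_list):
        # Weight each client's parameters by its share of the total samples.
        total = sum(n for n, _ in model_list)
        (n0, first), rest = model_list[0], model_list[1:]
        averaged = {k: v.clone() * (n0 / total) for k, v in first.items()}
        for n, params in rest:
            for k in averaged:
                averaged[k] += params[k] * (n / total)
        return averaged

    # e.g. two clients with 10 and 30 samples:
    sd = lambda v: {"w": torch.full((2, 2), v)}
    print(fedavg([(10, sd(1.0)), (30, sd(2.0))])["w"])  # tensor filled with 1.75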
+ client_num_per_round (int): Number of clients to select for the current round. Returns: - client_id_list_in_this_round: sampled real edge ID list, e.g., [64, 66] + list: List of selected client IDs or indices. """ if client_num_per_round == len(client_id_list_in_total): return client_id_list_in_total - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) # Make sure for each comparison, we are selecting the same clients each round client_id_list_in_this_round = np.random.choice(client_id_list_in_total, client_num_per_round, replace=False) return client_id_list_in_this_round def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Sample a subset of clients for a federated learning round. + + Args: + round_idx (int): Round index, starting from 0. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to sample for the current round. + + Returns: + list: List of sampled client indices. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: num_clients = min(client_num_per_round, client_num_in_total) - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) # Make sure for each comparison, we are selecting the same clients each round client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False) logging.info("client_indexes = %s" % str(client_indexes)) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set for model evaluation. + + Args: + num_samples (int, optional): Number of samples to include in the validation set. Defaults to 10000. + + Returns: + DataLoader: DataLoader containing the validation set. + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) @@ -173,6 +245,12 @@ def _generate_validation_set(self, num_samples=10000): return self.test_global def test_on_server_for_all_clients(self, round_idx): + """ + Perform model testing on the server for all clients in the current round. + + Args: + round_idx (int): Round index. + """ if round_idx % self.args.frequency_of_the_test == 0 or round_idx == self.args.comm_round - 1: logging.info("################test_on_server_for_all_clients : {}".format(round_idx)) self.aggregator.test_all( @@ -183,7 +261,7 @@ def test_on_server_for_all_clients(self, round_idx): ) if round_idx == self.args.comm_round - 1: - # we allow to return four metrics, such as accuracy, AUC, loss, etc. + # Allow returning multiple metrics (e.g., accuracy, AUC, loss, etc.) in the final round metric_result_in_current_round = self.aggregator.test(self.test_global, self.device, self.args) else: metric_result_in_current_round = self.aggregator.test(self.val_global, self.device, self.args) @@ -200,25 +278,39 @@ def test_on_server_for_all_clients(self, round_idx): mlops.log({"round_idx": round_idx}) def get_dummy_input_tensor(self): + """ + Get a dummy input tensor from the test data. + + Returns: + list: List of dummy input tensors. 
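
All three selection helpers rely on the same trick: seeding NumPy with the round index so every run samples the same clients in a given round. A self-contained sketch:

    import numpy as np

    def sample_clients(round_idx, client_num_in_total, client_num_per_round):
        if client_num_in_total == client_num_per_round:
            return list(range(client_num_in_total))
        np.random.seed(round_idx)  # same seed per round => reproducible selection
        return np.random.choice(range(client_num_in_total),
                                client_num_per_round, replace=False).tolist()

    print(sample_clients(5, 30, 3))  # identical output on every run for round 5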
+ """ test_data = None if self.test_global: test_data = self.test_global - else: # if test_global is None, then we use the first non-empty test_data_local_dict + else: # If test_global is None, use the first non-empty test_data_local_dict for k, v in self.test_data_local_dict.items(): if v: test_data = v break with torch.no_grad(): - batch_idx, features_label_tensors = next(enumerate(test_data)) # test_data -> dataloader obj + batch_idx, features_label_tensors = next(enumerate(test_data)) # test_data -> DataLoader object dummy_list = [] for tensor in features_label_tensors: - dummy_tensor = tensor[:1] # only take the first element as dummy input + dummy_tensor = tensor[:1] # Only take the first element as dummy input dummy_list.append(dummy_tensor) - features = dummy_list[:-1] # Can adapt Process Multi-Label + features = dummy_list[:-1] # Can adapt to process multi-label data return features def get_input_shape_type(self): + """ + Get the shapes and types of input features in the test data. + + Returns: + tuple: A tuple containing: + - list: List of input feature shapes. + - list: List of input feature types ('int' or 'float'). + """ test_data = None if self.test_global: test_data = self.test_global @@ -248,10 +340,24 @@ def get_input_shape_type(self): return input_shape, input_type + def save_dummy_input_tensor(self): + """ + Save the dummy input tensor information to a file. + + This function saves the input shape and type information to a file named 'dummy_input_tensor.pkl'. + The saved file can be used for reference or documentation purposes. + + Note: To save the file to a specific location (e.g., S3), additional implementation is required. + + Example: + To save to a specific location (e.g., S3), you can modify this function to upload the file accordingly. + + """ import pickle - features = self.get_input_size_type() + features = self.get_input_shape_type() with open('dummy_input_tensor.pkl', 'wb') as handle: pickle.dump(features, handle) - # TODO: save the dummy_input_tensor.pkl to s3, and transfer when click "Create Model Card" + # TODO: Save the 'dummy_input_tensor.pkl' to S3 or another desired location, and transfer it when needed. + \ No newline at end of file diff --git a/python/fedml/serving/server/fedml_server_manager.py b/python/fedml/serving/server/fedml_server_manager.py index 9e871c6ff6..93eb32380a 100644 --- a/python/fedml/serving/server/fedml_server_manager.py +++ b/python/fedml/serving/server/fedml_server_manager.py @@ -13,11 +13,26 @@ class FedMLServerManager(FedMLCommManager): + """ + Manages the server-side operations for federated machine learning. + + This class handles communication with clients, aggregation of model updates, + and the overall server-side federated learning process. + + Args: + args: Configuration arguments for the server. + aggregator: Aggregator for model updates from clients. + comm: Communication backend (e.g., MQTT, S3). + client_rank: Rank of the client. + client_num: Total number of clients. + backend: Communication backend (default is "MQTT_S3"). + """ + ONLINE_STATUS_FLAG = "ONLINE" RUN_FINISHED_STATUS_FLAG = "FINISHED" def __init__( - self, args, aggregator, comm=None, client_rank=0, client_num=0, backend="MQTT_S3", + self, args, aggregator, comm=None, client_rank=0, client_num=0, backend="MQTT_S3", ): super().__init__(args, comm, client_rank, client_num, backend) self.args = args @@ -35,9 +50,15 @@ def __init__( self.data_silo_index_list = None def run(self): + """ + Start the federated server manager. 
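
`get_dummy_input_tensor` and `get_input_shape_type` both peel one sample off the first batch of a DataLoader. A runnable sketch with a hypothetical tabular dataset (the tensor sizes here are invented for illustration):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    features = torch.randn(8, 24)           # hypothetical feature tensor
    labels = torch.randint(0, 2, (8,))
    loader = DataLoader(TensorDataset(features, labels), batch_size=4)

    with torch.no_grad():
        batch = next(iter(loader))          # list of tensors: [features, labels]
        dummy = [t[:1] for t in batch]      # keep only the first sample of each tensor
        dummy_features = dummy[:-1]         # drop the label tensor, as above
        shapes = [t.shape for t in dummy_features]
        dtypes = ["float" if t.dtype.is_floating_point else "int" for t in dummy_features]
    print(shapes, dtypes)                   # [torch.Size([1, 24])] ['float']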
+ """ super().run() def send_init_msg(self): + """ + Send initialization messages to clients to start the training process. + """ global_model_params = self.aggregator.get_global_model_params() global_model_url = None @@ -54,26 +75,29 @@ def send_init_msg(self): mlops.event("server.wait", event_started=True, event_value=str(self.args.round_idx)) try: - # get input type and shape for inference + # Get input type and shape for inference dummy_input_tensor = self.aggregator.get_dummy_input_tensor() logging.info(f"dummy tensor: {dummy_input_tensor}") # sample tensor for ONNX if not getattr(self.args, "skip_log_model_net", False): model_net_url = mlops.log_training_model_net_info(self.aggregator.aggregator.model, dummy_input_tensor) - # type and shape for later configuration + # Type and shape for later configuration input_shape, input_type = self.aggregator.get_input_shape_type() logging.info(f"input shape: {input_shape}") # [torch.Size([1, 24]), torch.Size([1, 2])] logging.info(f"input type: {input_type}") # [torch.int64, torch.float32] - # Send output input size and type (saved as json) to s3, - # and transfer when click "Create Model Card" + # Send output input size and type (saved as json) to S3, + # and transfer when clicking "Create Model Card" model_input_url = mlops.log_training_model_input_info(list(input_shape), list(input_type)) except Exception as e: logging.info("exception when processing model net and model input info: {}".format( traceback.format_exc())) def register_message_receive_handlers(self): + """ + Register message handlers for different message types. + """ logging.info("register_message_receive_handlers------") self.register_message_receive_handler( MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_message_connection_ready @@ -88,17 +112,33 @@ def register_message_receive_handlers(self): ) def handle_message_connection_ready(self, msg_params): + """ + Handles the connection readiness message from clients and initiates the federated learning process. + + This method is called when the server receives a message indicating that clients are ready to connect. + It selects the clients for the current round and checks their status. + + Args: + msg_params: Parameters of the received message. + + Returns: + None + """ if not self.is_initialized: self.client_id_list_in_this_round = self.aggregator.client_selection( - self.args.round_idx, self.client_real_ids, self.args.client_num_per_round + self.args.round_idx, + self.client_real_ids, + self.args.client_num_per_round ) self.data_silo_index_list = self.aggregator.data_silo_selection( - self.args.round_idx, self.args.client_num_in_total, len(self.client_id_list_in_this_round), + self.args.round_idx, + self.args.client_num_in_total, + len(self.client_id_list_in_this_round), ) mlops.log_round_info(self.round_num, -1) - # check client status in case that some clients start earlier than the server + # Check client status in case that some clients start earlier than the server client_idx_in_this_round = 0 for client_id in self.client_id_list_in_this_round: try: @@ -111,6 +151,19 @@ def handle_message_connection_ready(self, msg_params): client_idx_in_this_round += 1 def process_online_status(self, client_status, msg_params): + """ + Processes online status messages from clients. + + This method is called when the server receives an online status message from a client. + It updates the client online mapping and checks if all clients are online. + + Args: + client_status: The client's online status. 
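
`register_message_receive_handlers` follows a plain type-to-callback registry pattern. A minimal sketch of the dispatch mechanism (illustrative, not the FedMLCommManager API):

    class MiniCommManager:
        def __init__(self):
            self._handlers = {}

        def register_message_receive_handler(self, msg_type, handler):
            self._handlers[msg_type] = handler

        def dispatch(self, msg_type, msg_params):
            # Route an incoming message to whichever handler was registered for it.
            handler = self._handlers.get(msg_type)
            if handler is not None:
                handler(msg_params)

    manager = MiniCommManager()
    manager.register_message_receive_handler(0, lambda msg: print("connection ready:", msg))
    manager.dispatch(0, {"sender": 1})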
+ msg_params: Parameters of the received message. + + Returns: + None + """ self.client_online_mapping[str(msg_params.get_sender_id())] = True logging.info("self.client_online_mapping = {}".format(self.client_online_mapping)) @@ -128,11 +181,24 @@ def process_online_status(self, client_status, msg_params): if all_client_is_online: mlops.log_aggregation_status(MyMessage.MSG_MLOPS_SERVER_STATUS_RUNNING) - # send initialization message to all clients to start training + # Send initialization message to all clients to start training self.send_init_msg() self.is_initialized = True def process_finished_status(self, client_status, msg_params): + """ + Processes finished status messages from clients. + + This method is called when the server receives a finished status message from a client. + It updates the client finished mapping and checks if all clients have finished. + + Args: + client_status: The client's finished status. + msg_params: Parameters of the received message. + + Returns: + None + """ self.client_finished_mapping[str(msg_params.get_sender_id())] = True all_client_is_finished = True @@ -151,6 +217,18 @@ def process_finished_status(self, client_status, msg_params): self.finish() def handle_message_client_status_update(self, msg_params): + """ + Handles client status update messages. + + This method is called when the server receives a client status update message. + It processes the received client status and takes appropriate actions. + + Args: + msg_params: Parameters of the received message. + + Returns: + None + """ client_status = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_STATUS) logging.info(f"received client status {client_status}") if client_status == FedMLServerManager.ONLINE_STATUS_FLAG: @@ -159,6 +237,18 @@ def handle_message_client_status_update(self, msg_params): self.process_finished_status(client_status, msg_params) def handle_message_receive_model_from_client(self, msg_params): + """ + Handles messages containing trained models received from clients. + + This method is called when the server receives a message containing a trained model from a client. + It processes the received model, performs aggregation, and sends updated models to clients. + + Args: + msg_params: Parameters of the received message. 
+ + Returns: + None + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) mlops.event("comm_c2s", event_started=False, event_value=str(self.args.round_idx), event_edge_id=sender_id) @@ -191,7 +281,7 @@ def handle_message_receive_model_from_client(self, msg_params): mlops.event("server.agg_and_eval", event_started=False, event_value=str(self.args.round_idx)) - # send round info to the MQTT backend + # Send round info to the MQTT backend mlops.log_round_info(self.round_num, self.args.round_idx) self.client_id_list_in_this_round = self.aggregator.client_selection( @@ -211,7 +301,7 @@ def handle_message_receive_model_from_client(self, msg_params): for receiver_id in self.client_id_list_in_this_round: client_index = self.data_silo_index_list[client_idx_in_this_round] if type(global_model_params) is dict: - # compatible with the old version that, user did not give {-1 : global_parms_dict} + # Compatible with the old version that user did not give {-1 : global_parms_dict} global_model_url, global_model_key = self.send_message_diff_sync_model_to_client( receiver_id, global_model_params[client_index], client_index ) @@ -221,7 +311,7 @@ def handle_message_receive_model_from_client(self, msg_params): ) client_idx_in_this_round += 1 - # if user give {-1 : global_parms_dict}, then record global_model url separately + # If the user gives {-1 : global_parms_dict}, then record global_model url separately if type(global_model_params) is dict and (-1 in global_model_params.keys()): global_model_url, global_model_key = self.send_message_diff_sync_model_to_client( -1, global_model_params[-1], -1 @@ -230,13 +320,21 @@ def handle_message_receive_model_from_client(self, msg_params): self.args.round_idx += 1 mlops.log_aggregated_model_info( self.args.round_idx, model_url=global_model_url, - ) + ) logging.info("\n\n==========end {}-th round training===========\n".format(self.args.round_idx)) if self.args.round_idx < self.round_num: mlops.event("server.wait", event_started=True, event_value=str(self.args.round_idx)) def cleanup(self): + """ + Cleans up after a round of federated learning. + + This method is called to clean up resources and send finish messages to clients. + + Returns: + None + """ client_idx_in_this_round = 0 for client_id in self.client_id_list_in_this_round: self.send_message_finish( @@ -245,7 +343,23 @@ def cleanup(self): client_idx_in_this_round += 1 def send_message_init_config(self, receive_id, global_model_params, datasilo_index, - global_model_url=None, global_model_key=None): + global_model_url=None, global_model_key=None): + """ + Sends initialization configuration message to a client. + + This method constructs and sends an initialization configuration message to a specified client. + + Args: + receive_id: The ID of the client to receive the message. + global_model_params: Global model parameters to be sent. + datasilo_index: Index of the data silo associated with the client. + global_model_url: URL of the global model (optional). + global_model_key: Key of the global model (optional). + + Returns: + global_model_url: URL of the global model (if provided). + global_model_key: Key of the global model (if provided). 
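
Putting the receive path together: each upload is recorded, and aggregation fires only once every selected client has reported. A sketch built on the UploadTracker and fedavg sketches above (names are illustrative):

    def on_model_received(tracker, aggregate_fn, sender_index, model_params, sample_num):
        # Record the upload; aggregate once all selected clients have reported.
        tracker.add_local_trained_result(sender_index, model_params, sample_num)
        if tracker.check_whether_all_receive():
            model_list = [(tracker.sample_num_dict[i], tracker.model_dict[i])
                          for i in range(tracker.client_num)]
            return aggregate_fn(model_list)  # e.g. the fedavg sketch above
        return None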
+ """ tick = time.time() message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id) if global_model_url is not None: @@ -262,11 +376,35 @@ def send_message_init_config(self, receive_id, global_model_params, datasilo_ind return global_model_url, global_model_key def send_message_check_client_status(self, receive_id, datasilo_index): + """ + Sends a message to check the status of a client. + + This method constructs and sends a message to check the status of a specified client. + + Args: + receive_id: The ID of the client to receive the message. + datasilo_index: Index of the data silo associated with the client. + + Returns: + None + """ message = Message(MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, self.get_sender_id(), receive_id) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) def send_message_finish(self, receive_id, datasilo_index): + """ + Sends a finish message to a client. + + This method constructs and sends a finish message to a specified client. + + Args: + receive_id: The ID of the client to receive the message. + datasilo_index: Index of the data silo associated with the client. + + Returns: + None + """ message = Message(MyMessage.MSG_TYPE_S2C_FINISH, self.get_sender_id(), receive_id) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) @@ -275,7 +413,23 @@ def send_message_finish(self, receive_id, datasilo_index): logging.info(" ====================send cleanup message to {}====================".format(str(datasilo_index))) def send_message_sync_model_to_client(self, receive_id, global_model_params, client_index, - global_model_url=None, global_model_key=None): + global_model_url=None, global_model_key=None): + """ + Sends a synchronized model to a client. + + This method constructs and sends a message containing synchronized model parameters to a specified client. + + Args: + receive_id: The ID of the client to receive the message. + global_model_params: The synchronized global model parameters to be sent. + client_index: Index of the client associated with the model. + global_model_url: URL for the global model parameters (optional). + global_model_key: Key for the global model parameters (optional). + + Returns: + global_model_url: URL for the global model parameters. + global_model_key: Key for the global model parameters. + """ tick = time.time() logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id) message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id, ) @@ -296,6 +450,20 @@ def send_message_sync_model_to_client(self, receive_id, global_model_params, cli return global_model_url, global_model_key def send_message_diff_sync_model_to_client(self, receive_id, client_model_params, client_index): + """ + Sends a differentiated synchronized model to a client. + + This method constructs and sends a message containing differentiated synchronized model parameters to a specified client. + + Args: + receive_id: The ID of the client to receive the message. + client_model_params: The differentiated synchronized model parameters to be sent. + client_index: Index of the client associated with the model. + + Returns: + global_model_url: URL for the global model parameters. + global_model_key: Key for the global model parameters. + """ tick = time.time() logging.info("send_message_sync_model_to_client. 
receive_id = %d" % receive_id) message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id, ) @@ -309,4 +477,4 @@ def send_message_diff_sync_model_to_client(self, receive_id, client_model_params global_model_url = message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL) global_model_key = message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY) - return global_model_url, global_model_key \ No newline at end of file + return global_model_url, global_model_key diff --git a/python/fedml/serving/server/message_define.py b/python/fedml/serving/server/message_define.py index 1c1db66741..e10b68f760 100644 --- a/python/fedml/serving/server/message_define.py +++ b/python/fedml/serving/server/message_define.py @@ -1,30 +1,29 @@ class MyMessage(object): """ - message type definition + Defines message types and their associated constants for communication between server and clients. """ - # connection info + # Connection Info MSG_TYPE_CONNECTION_IS_READY = 0 - # server to client + # Server to Client Messages MSG_TYPE_S2C_INIT_CONFIG = 1 MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT = 2 MSG_TYPE_S2C_CHECK_CLIENT_STATUS = 6 MSG_TYPE_S2C_FINISH = 7 - # client to server + # Client to Server Messages MSG_TYPE_C2S_SEND_MODEL_TO_SERVER = 3 MSG_TYPE_C2S_SEND_STATS_TO_SERVER = 4 MSG_TYPE_C2S_CLIENT_STATUS = 5 MSG_TYPE_C2S_FINISHED = 8 + # Message Argument Keys MSG_ARG_KEY_TYPE = "msg_type" MSG_ARG_KEY_SENDER = "sender" MSG_ARG_KEY_RECEIVER = "receiver" - """ - message payload keywords definition - """ + # Message Payload Keywords MSG_ARG_KEY_NUM_SAMPLES = "num_samples" MSG_ARG_KEY_MODEL_PARAMS = "model_params" MSG_ARG_KEY_MODEL_PARAMS_URL = "model_params_url" @@ -41,14 +40,12 @@ class MyMessage(object): MSG_ARG_KEY_CLIENT_STATUS = "client_status" MSG_ARG_KEY_CLIENT_OS = "client_os" - + MSG_ARG_KEY_EVENT_NAME = "event_name" MSG_ARG_KEY_EVENT_VALUE = "event_value" MSG_ARG_KEY_EVENT_MSG = "event_msg" - """ - MLOps related message - """ + # MLOps Related Messages # Client Status MSG_MLOPS_CLIENT_STATUS_IDLE = "IDLE" MSG_MLOPS_CLIENT_STATUS_UPGRADING = "UPGRADING" diff --git a/python/fedml/serving/server/server_initializer.py b/python/fedml/serving/server/server_initializer.py index 5877d96fea..941e486fd8 100644 --- a/python/fedml/serving/server/server_initializer.py +++ b/python/fedml/serving/server/server_initializer.py @@ -16,13 +16,31 @@ def init_server( train_data_local_dict, test_data_local_dict, train_data_local_num_dict, - server_aggregator, + server_aggregator=None, ): + """ + Initialize and start the server for federated machine learning. + + Args: + args: Configuration arguments for the server. + device: The device (e.g., GPU) to be used for computation. + comm: Communication module for distributed computing. + rank: The rank of the server in the communication group. + worker_num: The number of worker nodes in the federated setup. + model: The machine learning model to be used. + train_data_num: The number of training data samples. + train_data_global: The global training dataset. + test_data_global: The global test dataset. + train_data_local_dict: Dictionary of local training datasets for workers. + test_data_local_dict: Dictionary of local test datasets for workers. + train_data_local_num_dict: Dictionary of the number of local training samples for workers. + server_aggregator: The aggregator responsible for aggregating model updates (default: None). 
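
The constants in MyMessage are just integer tags and dictionary keys. A sketch of how a server-to-client sync message might be assembled from them, with a plain dict standing in for the Message class (values copied from the definitions above):

    MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT = 2
    MSG_ARG_KEY_SENDER = "sender"
    MSG_ARG_KEY_RECEIVER = "receiver"
    MSG_ARG_KEY_MODEL_PARAMS = "model_params"
    MSG_ARG_KEY_CLIENT_INDEX = "client_idx"

    def build_sync_message(sender_id, receiver_id, model_params, client_index):
        # A dict is enough to show the payload layout the Message class carries.
        return {
            "msg_type": MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT,
            MSG_ARG_KEY_SENDER: sender_id,
            MSG_ARG_KEY_RECEIVER: receiver_id,
            MSG_ARG_KEY_MODEL_PARAMS: model_params,
            MSG_ARG_KEY_CLIENT_INDEX: str(client_index),
        }

    print(build_sync_message(0, 64, {}, 3))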
+ """ if server_aggregator is None: server_aggregator = create_server_aggregator(model, args) server_aggregator.set_id(0) - # aggregator + # Create the aggregator aggregator = FedMLAggregator( train_data_global, test_data_global, @@ -36,7 +54,7 @@ def init_server( server_aggregator, ) - # start the distributed training + # Start the distributed training backend = args.backend server_manager = FedMLServerManager(args, aggregator, comm, rank, worker_num, backend) server_manager.run() From 3f23f67d168d8d2a2c1e8181db80616cdd251bf7 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Sat, 9 Sep 2023 23:21:48 +0530 Subject: [PATCH 55/70] j --- python/fedml/model/cv/resnet.py | 123 ++++++++++++++++-- .../serving/server/fedml_server_manager.py | 1 + 2 files changed, 111 insertions(+), 13 deletions(-) diff --git a/python/fedml/model/cv/resnet.py b/python/fedml/model/cv/resnet.py index 17cf6a622c..d833e3b762 100644 --- a/python/fedml/model/cv/resnet.py +++ b/python/fedml/model/cv/resnet.py @@ -17,7 +17,18 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): - """3x3 convolution with padding""" + """3x3 convolution with padding. + + Args: + in_planes (int): Number of input channels. + out_planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + dilation (int, optional): Dilation factor for convolution. Default is 1. + + Returns: + nn.Conv2d: 3x3 convolutional layer. + """ return nn.Conv2d( in_planes, out_planes, @@ -31,11 +42,22 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): def conv1x1(in_planes, out_planes, stride=1): - """1x1 convolution""" + """1x1 convolution. + + Args: + in_planes (int): Number of input channels. + out_planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + + Returns: + nn.Conv2d: 1x1 convolutional layer. + """ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class BasicBlock(nn.Module): + """Basic residual block used in ResNet architectures.""" + expansion = 1 def __init__( @@ -49,6 +71,18 @@ def __init__( dilation=1, norm_layer=None, ): + """Initialize a BasicBlock instance. + + Args: + inplanes (int): Number of input channels. + planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + downsample (nn.Module, optional): Downsample layer for shortcut connections. Default is None. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + base_width (int, optional): Base width for the convolution. Default is 64. + dilation (int, optional): Dilation factor for convolution. Default is 1. + norm_layer (nn.Module, optional): Normalization layer. Default is None. 
+ """ super(BasicBlock, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d @@ -56,16 +90,27 @@ def __init__( raise ValueError("BasicBlock only supports groups=1 and base_width=64") if dilation > 1: raise NotImplementedError("Dilation > 1 not supported in BasicBlock") - # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + + # Define the convolutional layers self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = norm_layer(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = conv3x3(planes, planes) self.bn2 = norm_layer(planes) + + # Downsample layer for shortcut connections self.downsample = downsample self.stride = stride def forward(self, x): + """Forward pass through the BasicBlock. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ identity = x out = self.conv1(x) @@ -98,22 +143,47 @@ def __init__( dilation=1, norm_layer=None, ): + """Initialize a Bottleneck block used in ResNet architectures. + + Args: + inplanes (int): Number of input channels. + planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + downsample (nn.Module, optional): Downsample layer for shortcut connections. Default is None. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + base_width (int, optional): Base width for the convolution. Default is 64. + dilation (int, optional): Dilation factor for convolution. Default is 1. + norm_layer (nn.Module, optional): Normalization layer. Default is None. + """ super(Bottleneck, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d width = int(planes * (base_width / 64.0)) * groups - # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + + # Define the three convolutional layers self.conv1 = conv1x1(inplanes, width) self.bn1 = norm_layer(width) self.conv2 = conv3x3(width, width, stride, groups, dilation) self.bn2 = norm_layer(width) self.conv3 = conv1x1(width, planes * self.expansion) self.bn3 = norm_layer(planes * self.expansion) + + # ReLU activation function self.relu = nn.ReLU(inplace=True) + + # Downsample layer for shortcut connections self.downsample = downsample self.stride = stride def forward(self, x): + """Forward pass through the Bottleneck block. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ identity = x out = self.conv1(x) @@ -135,7 +205,6 @@ def forward(self, x): return out - class ResNet(nn.Module): def __init__( self, @@ -149,6 +218,19 @@ def __init__( norm_layer=None, KD=False, ): + """Initialize a ResNet model. + + Args: + block (nn.Module): The residual block type, either BasicBlock or Bottleneck. + layers (list): List of integers indicating the number of blocks in each layer. + num_classes (int, optional): Number of output classes. Default is 10. + zero_init_residual (bool, optional): If True, zero-initialize the last BN in each residual branch. Default is False. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + width_per_group (int, optional): Base width for the convolution. Default is 64. + replace_stride_with_dilation (tuple, optional): Replace stride with dilation in certain stages. Default is None. + norm_layer (nn.Module, optional): Normalization layer. Default is None. + KD (bool, optional): Whether to perform Knowledge Distillation. Default is False. 
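
Both residual blocks share the same skeleton: two or three convolutions, an optional downsample on the shortcut, and an addition before the final ReLU. A condensed, runnable BasicBlock (groups and dilation omitted):

    import torch
    import torch.nn as nn

    def conv3x3(in_planes, out_planes, stride=1):
        return nn.Conv2d(in_planes, out_planes, kernel_size=3,
                         stride=stride, padding=1, bias=False)

    class TinyBasicBlock(nn.Module):
        # Condensed version of the BasicBlock above.
        def __init__(self, inplanes, planes, stride=1, downsample=None):
            super().__init__()
            self.conv1, self.bn1 = conv3x3(inplanes, planes, stride), nn.BatchNorm2d(planes)
            self.conv2, self.bn2 = conv3x3(planes, planes), nn.BatchNorm2d(planes)
            self.relu = nn.ReLU(inplace=True)
            self.downsample = downsample

        def forward(self, x):
            identity = x if self.downsample is None else self.downsample(x)
            out = self.relu(self.bn1(self.conv1(x)))
            out = self.bn2(self.conv2(out))
            return self.relu(out + identity)  # the residual (skip) connection

    x = torch.randn(2, 16, 32, 32)
    print(TinyBasicBlock(16, 16)(x).shape)  # torch.Size([2, 16, 32, 32])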
+        """
         super(ResNet, self).__init__()
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
@@ -173,7 +255,6 @@ def __init__(
         )
         self.bn1 = nn.BatchNorm2d(self.inplanes)
         self.relu = nn.ReLU(inplace=True)
-        # self.maxpool = nn.MaxPool2d()
         self.layer1 = self._make_layer(block, 16, layers[0])
         self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
         self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
@@ -197,6 +278,19 @@ def __init__(
             nn.init.constant_(m.bn2.weight, 0)

     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        """
+        Create a layer in the ResNet model.
+
+        Args:
+            block (nn.Module): The residual block type, either BasicBlock or Bottleneck.
+            planes (int): Number of output channels for the layer.
+            blocks (int): Number of residual blocks in the layer.
+            stride (int, optional): The stride for the convolutional layers. Default is 1.
+            dilate (bool, optional): Whether to apply dilation to the convolutional layers. Default is False.
+
+        Returns:
+            nn.Sequential: A sequential layer containing the specified number of residual blocks.
+        """
         norm_layer = self._norm_layer
         downsample = None
         previous_dilation = self.dilation
@@ -238,6 +332,14 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
         return nn.Sequential(*layers)

     def forward(self, x):
+        """Forward pass through the ResNet model.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Output tensor (or a (feature, logits) tuple when KD is enabled).
+        """
         x = self.conv1(x)
         x = self.bn1(x)
         x = self.relu(x)  # B x 16 x 32 x 32
@@ -245,13 +347,8 @@ def forward(self, x):
         x = self.layer2(x)  # B x 32 x 16 x 16
         x = self.layer3(x)  # B x 64 x 8 x 8

         x = self.avgpool(x)  # B x 64 x 1 x 1
         x_f = x.view(x.size(0), -1)  # B x 64
         x = self.fc(x_f)  # B x num_classes
-        if self.KD == True:
+        if self.KD:
             return x_f, x
-        else:
-            return x
+        return x


 def resnet20(class_num, pretrained=False, path=None, **kwargs):
diff --git a/python/fedml/serving/server/fedml_server_manager.py b/python/fedml/serving/server/fedml_server_manager.py
index 93eb32380a..b2fb7cc9c0 100644
--- a/python/fedml/serving/server/fedml_server_manager.py
+++ b/python/fedml/serving/server/fedml_server_manager.py
@@ -249,6 +249,7 @@ def handle_message_receive_model_from_client(self, msg_params):
         Returns:
             None
         """
+
         sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER)

         mlops.event("comm_c2s", event_started=False, event_value=str(self.args.round_idx), event_edge_id=sender_id)

From b2f820a7958ecb7448e3b965eb41289fe1923188 Mon Sep 17 00:00:00 2001
From: rajveer43
Date: Sun, 10 Sep 2023 11:06:07 +0530
Subject: [PATCH 56/70] qdd ds

---
 python/fedml/model/cv/common.py       | 338 +++++++++++++++++++++++++-
 python/fedml/model/cv/resnet.py       |  16 +-
 python/fedml/model/cv/resnet_torch.py | 165 ++++++++++++-
 3 files changed, 498 insertions(+), 21 deletions(-)

diff --git a/python/fedml/model/cv/common.py b/python/fedml/model/cv/common.py
index bcd3e452ff..267bb4494d 100644
--- a/python/fedml/model/cv/common.py
+++ b/python/fedml/model/cv/common.py
@@ -37,60 +37,190 @@ def round_channels(channels,
     -------
     int
         Weighted number of channels.
+
+    Examples:
+    --------
+    >>> channels = 64
+    >>> rounded_channels = round_channels(channels)
+    >>> print(rounded_channels)
+    64
+
+    >>> channels = 57
+    >>> rounded_channels = round_channels(channels)
+    >>> print(rounded_channels)
+    56
     """
     rounded_channels = max(int(channels + divisor / 2.0) // divisor * divisor, divisor)
     if float(rounded_channels) < 0.9 * channels:
         rounded_channels += divisor
     return rounded_channels


 class Identity(nn.Module):
     """
     Identity block.
+
+    This block represents the identity function, which means it does not perform any
+    operations on the input and simply returns it unchanged. It is commonly used in
+    residual neural networks (ResNets) to create skip connections.
+
+    Attributes:
+        None
+
+    Methods:
+        forward(x): Performs a forward pass of the identity block.
+        __repr__(): Returns a string representation of the Identity block.
+
+    Examples:
+        >>> identity_block = Identity()
+        >>> x = torch.randn(1, 64, 32, 32)
+        >>> output = identity_block(x)
+        >>> assert torch.allclose(x, output)  # The output should be the same as the input.
+
     """

     def __init__(self):
         super(Identity, self).__init__()

     def forward(self, x):
+        """
+        Forward pass of the identity block.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The input tensor unchanged.
+        """
         return x

     def __repr__(self):
+        """
+        String representation of the Identity block.
+
+        Returns:
+            str: A string representing the Identity block.
+        """
         return '{name}()'.format(name=self.__class__.__name__)


 class BreakBlock(nn.Module):
     """
-    Break coonnection block for hourglass.
+    Break connection block for hourglass network.
+
+    This block serves as a break in the network's connections. It takes an input and returns None.
+    It is commonly used in hourglass-style networks to create skips in the network flow.
+
+    Attributes:
+    ----------
+    None
+
+    Methods:
+    -------
+    forward(x):
+        Forward pass through the block.
+
+    __repr__():
+        Returns a string representation of the block.
     """

     def __init__(self):
         super(BreakBlock, self).__init__()

     def forward(self, x):
+        """
+        Forward pass through the block.
+
+        Parameters:
+        ----------
+        x : torch.Tensor
+            Input tensor.
+
+        Returns:
+        -------
+        None
+            The block returns None, effectively breaking the connection.
+        """
         return None

     def __repr__(self):
+        """
+        Returns a string representation of the block.
+
+        Returns:
+        -------
+        str
+            A string representation of the block, indicating its name.
+        """
         return '{name}()'.format(name=self.__class__.__name__)

+
 class Swish(nn.Module):
     """
     Swish activation function from 'Searching for Activation Functions,' https://arxiv.org/abs/1710.05941.
+
+    This activation function is defined as Swish(x) = x * sigmoid(x).
+
+    Attributes:
+    ----------
+    None
+
+    Methods:
+    -------
+    forward(x):
+        Forward pass through the Swish activation function.
     """

     def forward(self, x):
+        """
+        Forward pass through the Swish activation function.
+
+        Parameters:
+        ----------
+        x : torch.Tensor
+            Input tensor.
+
+        Returns:
+        -------
+        torch.Tensor
+            Output tensor after applying the Swish activation function.
+        """
         return x * torch.sigmoid(x)


 class HSigmoid(nn.Module):
     """
-    Approximated sigmoid function, so-called hard-version of sigmoid from 'Searching for MobileNetV3,'
+    Approximated sigmoid function, the hard version of sigmoid, from 'Searching for MobileNetV3,'
     https://arxiv.org/abs/1905.02244.
+
+    This activation function is defined as HSigmoid(x) = relu6(x + 3.0) / 6.0.
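
The three activations above are one-liners over torch primitives. A quick numeric check of the stated relations, Swish(x) = x * sigmoid(x) and HSwish(x) = x * HSigmoid(x):

    import torch
    import torch.nn.functional as F

    x = torch.linspace(-5.0, 5.0, steps=11)
    h_sigmoid = F.relu6(x + 3.0) / 6.0   # HSigmoid(x)
    h_swish = x * h_sigmoid              # HSwish(x) = x * HSigmoid(x)
    swish = x * torch.sigmoid(x)         # exact Swish, for comparison
    print(torch.stack([x, swish, h_swish]))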
+ + Attributes: + ---------- + None + + Methods: + ------- + forward(x): + Forward pass through the HSigmoid activation function. """ def forward(self, x): + """ + Forward pass through the HSigmoid activation function. + + Parameters: + ---------- + x : torch.Tensor + Input tensor. + + Returns: + ------- + torch.Tensor + Output tensor after applying the HSigmoid activation function. + """ return F.relu6(x + 3.0, inplace=True) / 6.0 @@ -98,10 +228,22 @@ class HSwish(nn.Module): """ H-Swish activation function from 'Searching for MobileNetV3,' https://arxiv.org/abs/1905.02244. + This activation function is defined as HSwish(x) = x * relu6(x + 3.0) / 6.0. + Parameters: ---------- + inplace : bool, optional (default=False) + Whether to use the inplace version of the module. + + Attributes: + ---------- inplace : bool - Whether to use inplace version of the module. + Indicates whether the inplace version is used. + + Methods: + ------- + forward(x): + Forward pass through the H-Swish activation function. """ def __init__(self, inplace=False): @@ -109,24 +251,42 @@ def __init__(self, inplace=False): self.inplace = inplace def forward(self, x): + """ + Forward pass through the H-Swish activation function. + + Parameters: + ---------- + x : torch.Tensor + Input tensor. + + Returns: + ------- + torch.Tensor + Output tensor after applying the H-Swish activation function. + """ return x * F.relu6(x + 3.0, inplace=self.inplace) / 6.0 def get_activation_layer(activation): """ - Create activation layer from string/function. + Create an activation layer from a string/function. Parameters: ---------- - activation : function, or str, or nn.Module - Activation function or name of activation function. + activation : function, str, or nn.Module + Activation function or name of the activation function. Returns: ------- nn.Module Activation layer. + + Raises: + ------- + NotImplementedError: + If the specified activation function is not supported. """ - assert (activation is not None) + assert activation is not None if isfunction(activation): return activation() elif isinstance(activation, str): @@ -145,9 +305,9 @@ def get_activation_layer(activation): elif activation == "identity": return Identity() else: - raise NotImplementedError() + raise NotImplementedError("Unsupported activation function: {}".format(activation)) else: - assert (isinstance(activation, nn.Module)) + assert isinstance(activation, nn.Module) return activation @@ -165,6 +325,21 @@ class SelectableDense(nn.Module): Whether the layer uses a bias vector. num_options : int, default 1 Number of selectable options. + + Attributes: + ---------- + in_features : int + Number of input features. + out_features : int + Number of output features. + use_bias : bool + Whether the layer uses a bias vector. + num_options : int + Number of selectable options. + weight : torch.nn.Parameter + Learnable weight parameter. + bias : torch.nn.Parameter + Learnable bias parameter (if use_bias=True). """ def __init__(self, @@ -185,6 +360,21 @@ def __init__(self, self.register_parameter("bias", None) def forward(self, x, indices): + """ + Forward pass through the SelectableDense layer. + + Parameters: + ---------- + x : torch.Tensor + Input tensor. + indices : torch.Tensor + Tensor containing the indices of the selected options. + + Returns: + ------- + torch.Tensor + Output tensor after applying the SelectableDense layer. 
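
SelectableDense picks one weight matrix per sample via `index_select` and a batched matmul. The core of its forward pass, runnable in isolation:

    import torch

    num_options, in_features, out_features = 4, 8, 3
    weight = torch.randn(num_options, out_features, in_features)

    x = torch.randn(5, in_features)                        # batch of 5 inputs
    indices = torch.tensor([0, 2, 1, 3, 0])                # one option id per sample
    w = torch.index_select(weight, dim=0, index=indices)   # (5, out, in)
    y = w.bmm(x.unsqueeze(-1)).squeeze(-1)                 # per-sample matmul -> (5, out)
    print(y.shape)                                         # torch.Size([5, 3])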
+ """ weight = torch.index_select(self.weight, dim=0, index=indices) x = x.unsqueeze(-1) x = weight.bmm(x) @@ -195,6 +385,14 @@ def forward(self, x, indices): return x def extra_repr(self): + """ + Extra representation of the SelectableDense layer. + + Returns: + ------- + str + String representation of the layer's attributes. + """ return "in_features={}, out_features={}, bias={}, num_options={}".format( self.in_features, self.out_features, self.use_bias, self.num_options) @@ -242,6 +440,19 @@ def __init__(self, self.activ = get_activation_layer(activation) def forward(self, x): + """ + Forward pass of the dense block. + + Parameters: + ---------- + x : torch.Tensor + Input tensor. + + Returns: + ------- + torch.Tensor + Output tensor. + """ x = self.fc(x) if self.use_bn: x = self.bn(x) @@ -313,6 +524,19 @@ def __init__(self, self.activ = get_activation_layer(activation) def forward(self, x): + """ + Forward pass of the 1D convolution block. + + Parameters: + ---------- + x : torch.Tensor + Input tensor. + + Returns: + ------- + torch.Tensor + Output tensor. + """ x = self.conv(x) if self.use_bn: x = self.bn(x) @@ -341,6 +565,11 @@ def conv1x1(in_channels, Number of groups. bias : bool, default False Whether the layer uses a bias vector. + + Returns: + ------- + nn.Conv2d + 1x1 convolutional layer. """ return nn.Conv2d( in_channels=in_channels, @@ -377,6 +606,11 @@ def conv3x3(in_channels, Number of groups. bias : bool, default False Whether the layer uses a bias vector. + + Returns: + ------- + nn.Conv2d + 3x3 convolutional layer. """ return nn.Conv2d( in_channels=in_channels, @@ -409,6 +643,11 @@ def depthwise_conv3x3(channels, Dilation value for convolution layer. bias : bool, default False Whether the layer uses a bias vector. + + Returns: + ------- + nn.Conv2d + Depthwise 3x3 convolutional layer. """ return nn.Conv2d( in_channels=channels, @@ -449,6 +688,15 @@ class ConvBlock(nn.Module): Small float added to variance in Batch norm. activation : function or str or None, default nn.ReLU(inplace=True) Activation function or name of activation function. + + Examples: + -------- + An example of using the ConvBlock: + + >>> import torch + >>> x = torch.randn(1, 3, 64, 64) # Input tensor with shape (batch_size, channels, height, width) + >>> conv_block = ConvBlock(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1) + >>> output = conv_block(x) # Forward pass through the ConvBlock """ def __init__(self, @@ -489,6 +737,19 @@ def __init__(self, self.activ = get_activation_layer(activation) def forward(self, x): + """ + Forward pass of the ConvBlock. + + Parameters: + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, in_channels, height, width). + + Returns: + ------- + torch.Tensor + Output tensor after applying convolution, batch normalization, and activation. + """ if self.use_pad: x = self.pad(x) x = self.conv(x) @@ -531,6 +792,11 @@ def conv1x1_block(in_channels, Small float added to variance in Batch norm. activation : function or str or None, default nn.ReLU(inplace=True) Activation function or name of activation function. + + Returns: + ------- + nn.Module + 1x1 Convolutional Block. """ return ConvBlock( in_channels=in_channels, @@ -580,6 +846,11 @@ def conv3x3_block(in_channels, Small float added to variance in Batch norm. activation : function or str or None, default nn.ReLU(inplace=True) Activation function or name of activation function. + + Returns: + ------- + nn.Module + 3x3 Convolutional Block. 
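
ConvBlock and the conv*_block helpers all reduce to the same conv -> (BatchNorm) -> (activation) sandwich. A compact sketch of that composition (illustrative, not the exact pytorchcv signature):

    import torch
    import torch.nn as nn

    def conv_block(in_ch, out_ch, k, stride=1, padding=0, use_bn=True, act=True):
        # conv -> optional BatchNorm -> optional activation, as in ConvBlock above
        layers = [nn.Conv2d(in_ch, out_ch, k, stride=stride,
                            padding=padding, bias=not use_bn)]
        if use_bn:
            layers.append(nn.BatchNorm2d(out_ch))
        if act:
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    block = conv_block(3, 64, k=3, stride=1, padding=1)    # roughly conv3x3_block(3, 64)
    print(block(torch.randn(1, 3, 32, 32)).shape)          # torch.Size([1, 64, 32, 32])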
""" return ConvBlock( in_channels=in_channels, @@ -594,7 +865,6 @@ def conv3x3_block(in_channels, bn_eps=bn_eps, activation=activation) - def conv5x5_block(in_channels, out_channels, stride=1, @@ -630,6 +900,11 @@ def conv5x5_block(in_channels, Small float added to variance in Batch norm. activation : function or str or None, default nn.ReLU(inplace=True) Activation function or name of activation function. + + Returns: + ------- + nn.Module + 5x5 Convolutional Block. """ return ConvBlock( in_channels=in_channels, @@ -664,7 +939,7 @@ def conv7x7_block(in_channels, Number of input channels. out_channels : int Number of output channels. - padding : int, or tuple/list of 2 int, or tuple/list of 4 int, default 1 + stride : int or tuple/list of 2 int, default 1 Strides of the convolution. padding : int or tuple/list of 2 int, default 3 Padding value for convolution layer. @@ -680,6 +955,11 @@ def conv7x7_block(in_channels, Small float added to variance in Batch norm. activation : function or str or None, default nn.ReLU(inplace=True) Activation function or name of activation function. + + Returns: + ------- + nn.Module + 7x7 Convolutional Block. """ return ConvBlock( in_channels=in_channels, @@ -730,6 +1010,11 @@ def dwconv_block(in_channels, Small float added to variance in Batch norm. activation : function or str or None, default nn.ReLU(inplace=True) Activation function or name of activation function. + + Returns: + ------- + nn.Module + Depthwise Convolutional Block. """ return ConvBlock( in_channels=in_channels, @@ -774,6 +1059,11 @@ def dwconv3x3_block(in_channels, Small float added to variance in Batch norm. activation : function or str or None, default nn.ReLU(inplace=True) Activation function or name of activation function. + + Returns: + ------- + nn.Module + 3x3 Depthwise Convolutional Block. """ return dwconv_block( in_channels=in_channels, @@ -816,6 +1106,11 @@ def dwconv5x5_block(in_channels, Small float added to variance in Batch norm. activation : function or str or None, default nn.ReLU(inplace=True) Activation function or name of activation function. + + Returns: + ------- + nn.Module + 5x5 Depthwise Convolutional Block. """ return dwconv_block( in_channels=in_channels, @@ -829,6 +1124,7 @@ def dwconv5x5_block(in_channels, activation=activation) + class DwsConvBlock(nn.Module): """ Depthwise separable convolution block with BatchNorms and activations at each convolution layers. @@ -859,6 +1155,11 @@ class DwsConvBlock(nn.Module): Activation function after the depthwise convolution block. pw_activation : function or str or None, default nn.ReLU(inplace=True) Activation function after the pointwise convolution block. + + Returns: + ---------- + torch.Tensor + The output tensor after applying depthwise separable convolution block. """ def __init__(self, @@ -895,6 +1196,19 @@ def __init__(self, activation=pw_activation) def forward(self, x): + """ + Forward pass of the depthwise separable convolution block. + + Parameters: + ---------- + x : torch.Tensor + Input tensor. + + Returns: + ---------- + torch.Tensor + The output tensor after applying depthwise separable convolution block. + """ x = self.dw_conv(x) x = self.pw_conv(x) return x diff --git a/python/fedml/model/cv/resnet.py b/python/fedml/model/cv/resnet.py index d833e3b762..9a4106c6e3 100644 --- a/python/fedml/model/cv/resnet.py +++ b/python/fedml/model/cv/resnet.py @@ -17,7 +17,8 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): - """3x3 convolution with padding. + """ + 3x3 convolution with padding. 
Args: in_planes (int): Number of input channels. @@ -42,7 +43,8 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): def conv1x1(in_planes, out_planes, stride=1): - """1x1 convolution. + """ + 1x1 convolution. Args: in_planes (int): Number of input channels. @@ -56,7 +58,9 @@ def conv1x1(in_planes, out_planes, stride=1): class BasicBlock(nn.Module): - """Basic residual block used in ResNet architectures.""" + """ + Basic residual block used in ResNet architectures. + """ expansion = 1 @@ -71,7 +75,8 @@ def __init__( dilation=1, norm_layer=None, ): - """Initialize a BasicBlock instance. + """ + Initialize a BasicBlock instance. Args: inplanes (int): Number of input channels. @@ -103,7 +108,8 @@ def __init__( self.stride = stride def forward(self, x): - """Forward pass through the BasicBlock. + """ + Forward pass through the BasicBlock. Args: x (torch.Tensor): Input tensor. diff --git a/python/fedml/model/cv/resnet_torch.py b/python/fedml/model/cv/resnet_torch.py index 5008edf8c4..bdf50e468b 100644 --- a/python/fedml/model/cv/resnet_torch.py +++ b/python/fedml/model/cv/resnet_torch.py @@ -9,7 +9,16 @@ # from .._internally_replaced_utils import load_state_dict_from_url # from ..utils import _log_api_usage_once +""" +This module provides pre-trained ResNet models and their URLs for download. +- `ResNet`: The main ResNet model class. +- `resnet18`, `resnet34`, `resnet50`, `resnet101`, `resnet152`: Pre-trained ResNet models. +- `resnext50_32x4d`, `resnext101_32x8d`: Pre-trained ResNeXt models. +- `wide_resnet50_2`, `wide_resnet101_2`: Pre-trained Wide ResNet models. + +You can use these models for various computer vision tasks. +""" __all__ = [ "ResNet", "resnet18", @@ -38,7 +47,19 @@ def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: - """3x3 convolution with padding""" + """ + 3x3 convolution with padding. + + Args: + in_planes (int): Number of input channels. + out_planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + dilation (int, optional): Dilation rate for convolution. Default is 1. + + Returns: + nn.Conv2d: Convolutional layer. + """ return nn.Conv2d( in_planes, out_planes, @@ -52,11 +73,38 @@ def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, d def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: - """1x1 convolution""" + """ + 1x1 convolution. + + Args: + in_planes (int): Number of input channels. + out_planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + + Returns: + nn.Conv2d: Convolutional layer. + """ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) class BasicBlock(nn.Module): + """ + Basic ResNet block. + + Args: + inplanes (int): Number of input channels. + planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + downsample (nn.Module, optional): Downsample layer. Default is None. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + base_width (int, optional): Base width for grouped convolution. Default is 64. + dilation (int, optional): Dilation rate for convolution. Default is 1. + norm_layer (Callable[..., nn.Module], optional): Normalization layer. Default is None. + + Attributes: + expansion (int): Expansion factor. 
+ """ + expansion: int = 1 def __init__( @@ -87,6 +135,15 @@ def __init__( self.stride = stride def forward(self, x: Tensor) -> Tensor: + """ + Forward pass through the BasicBlock. + + Args: + x (Tensor): Input tensor. + + Returns: + Tensor: Output tensor. + """ identity = x out = self.conv1(x) @@ -112,6 +169,23 @@ class Bottleneck(nn.Module): # This variant is also known as ResNet V1.5 and improves accuracy according to # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + """ + Bottleneck block for ResNet. + + Args: + inplanes (int): Number of input channels. + planes (int): Number of output channels. + stride (int, optional): Stride for the convolution. Default is 1. + downsample (nn.Module, optional): Downsample layer. Default is None. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + base_width (int, optional): Base width for grouped convolution. Default is 64. + dilation (int, optional): Dilation rate for convolution. Default is 1. + norm_layer (Callable[..., nn.Module], optional): Normalization layer. Default is None. + + Attributes: + expansion (int): Expansion factor. + """ + expansion: int = 4 def __init__( @@ -141,6 +215,15 @@ def __init__( self.stride = stride def forward(self, x: Tensor) -> Tensor: + """ + Forward pass through the Bottleneck block. + + Args: + x (Tensor): Input tensor. + + Returns: + Tensor: Output tensor. + """ identity = x out = self.conv1(x) @@ -177,6 +260,26 @@ def __init__( args = None, in_channels = 3, ) -> None: + """ + Residual Neural Network (ResNet) model. + + Args: + block (Type[Union[BasicBlock, Bottleneck]]): Type of residual block to use (BasicBlock or Bottleneck). + layers (List[int]): List specifying the number of blocks in each layer. + num_classes (int, optional): Number of output classes. Default is 1000. + zero_init_residual (bool, optional): If True, zero-initializes the last BN in each residual branch. + Default is False. + groups (int, optional): Number of groups for grouped convolution. Default is 1. + width_per_group (int, optional): Base width for grouped convolution. Default is 64. + replace_stride_with_dilation (Optional[List[bool]], optional): List specifying if strides should be replaced + with dilations in each layer. Default is None. + norm_layer (Optional[Callable[..., nn.Module]], optional): Normalization layer. Default is None. + args: Additional arguments (not used in the model). + in_channels: Number of input channels. Default is 3. + + Attributes: + expansion (int): Expansion factor for bottleneck blocks. + """ super().__init__() # _log_api_usage_once(self) if norm_layer is None: @@ -240,6 +343,19 @@ def _make_layer( stride: int = 1, dilate: bool = False, ) -> nn.Sequential: + """ + Create a layer consisting of multiple blocks. + + Args: + block (Type[Union[BasicBlock, Bottleneck]]): Type of residual block to use (BasicBlock or Bottleneck). + planes (int): Number of output channels for the layer. + blocks (int): Number of blocks in the layer. + stride (int, optional): Stride for the first block. Default is 1. + dilate (bool, optional): If True, use dilation in the layer. Default is False. + + Returns: + nn.Sequential: A sequential module containing the blocks. + """ norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation @@ -274,6 +390,15 @@ def _make_layer( return nn.Sequential(*layers) def _forward_impl(self, x: Tensor) -> Tensor: + """ + Forward pass of the model. + + Args: + x (Tensor): Input tensor. 
+
+        Returns:
+            Tensor: Output tensor.
+        """
         # See note [TorchScript super()]
         if len(x.shape) < 4:
             x = torch.unsqueeze(x, 1)
@@ -295,6 +420,15 @@ def _forward_impl(self, x: Tensor) -> Tensor:
         return x

     def forward(self, x: Tensor) -> Tensor:
+        """
+        Forward pass of the model.
+
+        Args:
+            x (Tensor): Input tensor.
+
+        Returns:
+            Tensor: Output tensor.
+        """
         return self._forward_impl(x)


@@ -308,6 +442,20 @@ def _resnet(
     progress: bool,
     **kwargs: Any,
 ) -> ResNet:
+    """
+    Constructs a ResNet model.
+
+    Args:
+        arch (str): Architecture name.
+        block (Type[Union[BasicBlock, Bottleneck]]): Type of residual block to use (BasicBlock or Bottleneck).
+        layers (List[int]): List specifying the number of blocks in each layer.
+        pretrained (bool): If True, loads pre-trained weights.
+        progress (bool): If True, displays download progress for pre-trained weights.
+        **kwargs: Additional keyword arguments to pass to the ResNet constructor.
+
+    Returns:
+        ResNet: ResNet model.
+    """
     model = ResNet(block, layers, **kwargs)
     # if pretrained:
     #     state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
@@ -318,9 +466,14 @@ def _resnet(


 def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
     r"""ResNet-18 model from
     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
     Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-        progress (bool): If True, displays a progress bar of the download to stderr
+        pretrained (bool): If True, returns a model pre-trained on ImageNet.
+        progress (bool): If True, displays a progress bar of the download to stderr.
+        **kwargs: Additional keyword arguments to pass to the ResNet constructor.
+
+    Returns:
+        ResNet: ResNet-18 model.
""" return _resnet("resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) From 1714f22a12dbe3741a6416676e90c34784c96093 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Wed, 13 Sep 2023 13:30:47 +0530 Subject: [PATCH 57/70] model done --- python/fedml/model/cv/batchnorm_utils.py | 406 ++++++++++++++++-- python/fedml/model/cv/common.py | 3 +- python/fedml/model/cv/darts/architect.py | 185 ++++++++ python/fedml/model/cv/darts/model.py | 123 +++++- python/fedml/model/cv/darts/model_search.py | 315 +++++++++++++- .../fedml/model/cv/darts/model_search_gdas.py | 72 +++- python/fedml/model/cv/darts/operations.py | 100 +++++ python/fedml/model/cv/darts/train.py | 25 ++ python/fedml/model/cv/darts/train_search.py | 34 ++ python/fedml/model/cv/darts/utils.py | 128 +++++- python/fedml/model/cv/darts/visualize.py | 13 + python/fedml/model/cv/efficientnet_utils.py | 265 +++++++++--- python/fedml/model/cv/group_normalization.py | 173 +++++++- .../fedml/model/cv/resnet56/resnet_client.py | 143 +++++- .../model/cv/resnet56/resnet_pretrained.py | 153 ++++++- .../fedml/model/cv/resnet56/resnet_server.py | 178 +++++++- python/fedml/model/linear/lr.py | 39 +- python/fedml/model/linear/lr_cifar10.py | 23 +- python/fedml/model/mobile/mnn_lenet.py | 23 + python/fedml/model/mobile/mnn_resnet.py | 22 + python/fedml/model/mobile/torch_lenet.py | 12 +- 21 files changed, 2259 insertions(+), 176 deletions(-) diff --git a/python/fedml/model/cv/batchnorm_utils.py b/python/fedml/model/cv/batchnorm_utils.py index 876454b55a..7ecb309ebb 100644 --- a/python/fedml/model/cv/batchnorm_utils.py +++ b/python/fedml/model/cv/batchnorm_utils.py @@ -24,20 +24,84 @@ class FutureResult(object): - """A thread-safe future implementation. Used only as one-to-one pipe.""" + """A thread-safe future implementation used for one-to-one communication. + This class provides a thread-safe mechanism for transferring results between threads, + typically in a producer-consumer pattern. It is designed for one-to-one communication + and ensures that the result is safely passed from one thread to another. + + Args: + None + + Attributes: + _result: The result value stored in the future. + _lock: A lock to ensure thread safety. + _cond: A condition variable associated with the lock for waiting and notifying. + + Methods: + put(result): + Puts a result value into the future. If a result already exists, it raises an + assertion error. + + get(): + Retrieves the result value from the future. If the result is not available yet, + it blocks until the result is put into the future. + + Example: + Here's an example of using `FutureResult` for communication between two threads: + + ```python + import threading + + def producer(future): + result = 42 # Some computation or value to produce + future.put(result) + + def consumer(future): + result = future.get() + print(f"Received result: {result}") + + future = FutureResult() + + # Start the producer and consumer threads + producer_thread = threading.Thread(target=producer, args=(future,)) + consumer_thread = threading.Thread(target=consumer, args=(future,)) + + producer_thread.start() + consumer_thread.start() + + producer_thread.join() + consumer_thread.join() + ``` + + Note: + This class is intended for one-to-one communication between threads. + """ def __init__(self): self._result = None self._lock = threading.Lock() self._cond = threading.Condition(self._lock) def put(self, result): + """Put a result into the future. + + Args: + result: The result value to be stored in the future. 
+
+        Raises:
+            AssertionError: If a result is already present in the future.
+        """
         with self._lock:
             assert self._result is None, "Previous result has't been fetched."
             self._result = result
             self._cond.notify()

     def get(self):
+        """Get the result from the future, blocking if necessary.
+
+        Returns:
+            The result value stored in the future.
+        """
         with self._lock:
             if self._result is None:
                 self._cond.wait()
@@ -54,9 +118,69 @@ def get(self):


 class SlavePipe(_SlavePipeBase):
-    """Pipe for master-slave communication."""
+    """Pipe for master-slave communication in a multi-threaded environment.
+
+    This class represents a pipe used for communication between a master thread and one
+    or more slave threads. It is designed for multi-threaded applications where the
+    master thread delegates work to the slave threads and waits for their results.
+
+    Args:
+        identifier (int): An identifier for the slave thread.
+        queue (Queue): A queue for sending messages from the slave thread to the master.
+        result (FutureResult): A FutureResult object for receiving the reply from the master.
+
+    Attributes:
+        identifier (int): An identifier for the slave thread.
+        queue (Queue): A queue for sending messages from the slave thread to the master.
+        result (FutureResult): A FutureResult object for receiving the reply from the master.
+
+    Methods:
+        run_slave(msg):
+            Sends a message from the slave thread to the master thread, blocks until the
+            master's reply arrives, and returns it.
+
+    Example:
+        Here's an example of using `SlavePipe` for master-slave communication. Pipes are
+        normally obtained from `SyncMaster.register_slave` rather than constructed by
+        hand (a `SyncMaster` named `sync_master` and some `task_data`/`master_data` are
+        assumed here):
+
+        ```python
+        import threading
+
+        def slave_function(pipe, data):
+            # Send `data` to the master and block until the reply arrives
+            return pipe.run_slave(data)
+
+        # Obtain a SlavePipe from the SyncMaster (identifier 1)
+        slave_pipe = sync_master.register_slave(1)
+
+        # Start the slave thread
+        slave_thread = threading.Thread(target=slave_function, args=(slave_pipe, task_data))
+        slave_thread.start()
+
+        # The master collects all slave messages and answers them in one pass
+        task_result = sync_master.run_master(master_data)
- During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, and passed to a registered callback. - After receiving the messages, the master device should gather the information and determine to message passed back to each slave devices. + + Args: + master_callback (callable): A callback function to be invoked after collecting messages from slave devices. + + Attributes: + _master_callback (callable): A callback function to be invoked after collecting messages from slave devices. + _queue (queue.Queue): A queue for exchanging messages between master and slave devices. + _registry (collections.OrderedDict): A registry of slave devices and their associated communication pipes. + _activated (bool): A flag indicating whether the SyncMaster is activated for communication. + + Methods: + register_slave(identifier): + Register a slave device and obtain a `SlavePipe` object for communication with the master device. + + run_master(master_msg): + Main entry for the master device during each forward pass. Collects messages from all devices, + invokes the master callback to compute a response message, and sends messages back to each device. + + nr_slaves: + Property that returns the number of registered slave devices. + + Example: + Here's an example of using `SyncMaster` for coordinating communication in a data parallel setting: + + ```python + def master_callback(messages): + # Compute the master message based on received messages + master_msg = messages[0][1] + return [(0, master_msg)] # Send the same message back to the master + + sync_master = SyncMaster(master_callback) + + # Register slave devices and obtain communication pipes + slave_pipe1 = sync_master.register_slave(1) + slave_pipe2 = sync_master.register_slave(2) + + # During the forward pass, master device runs run_master to coordinate communication + master_msg = "Hello from master" + response_msg = sync_master.run_master(master_msg) + + # Use the response_msg and coordinate further actions + + # Get the number of registered slave devices + num_slaves = sync_master.nr_slaves + ``` + + Note: + This class is intended for use in multi-device data parallel applications where a master device + coordinates communication with multiple slave devices. """ def __init__(self, master_callback): @@ -90,11 +268,27 @@ def __setstate__(self, state): self.__init__(state["master_callback"]) def register_slave(self, identifier): - """ - Register an slave device. + """Register a slave device with the SyncMaster. + Args: - identifier: an identifier, usually is the device id. - Returns: a `SlavePipe` object which can be used to communicate with the master device. + identifier (int): An identifier, usually the device ID. + + Returns: + SlavePipe: A `SlavePipe` object for communicating with the master device. + + Raises: + AssertionError: If the SyncMaster is already activated and the queue is not empty. + + Notes: + This method should be called by slave devices to register themselves with the SyncMaster. + The returned `SlavePipe` object can be used for communication with the master device. + + Example: + ```python + sync_master = SyncMaster(master_callback) + slave_pipe = sync_master.register_slave(1) + ``` + """ if self._activated: assert self._queue.empty(), "Queue is not clean before next initialization." 
@@ -105,15 +299,30 @@ def register_slave(self, identifier):
         return SlavePipe(identifier, self._queue, future)

     def run_master(self, master_msg):
-        """
-        Main entry for the master device in each forward pass.
+        """Run the master device during each forward pass.
+
         The messages were first collected from each devices (including the master device), and then an callback will
         be invoked to compute the message to be sent back to each devices (including the master device).
+
         Args:
-            master_msg: the message that the master want to send to itself. This will be placed as the first
-            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
-        Returns: the message to be sent back to the master device.
+            master_msg: The message that the master wants to send to itself.
+                This message will be placed as the first message when calling `master_callback`.
+
+        Returns:
+            Any: The message to be sent back to the master device.
+
+        Notes:
+            This method is the main entry for the master device during each forward pass.
+            It collects messages from all devices, invokes the master callback to compute a response message,
+            and sends messages back to each device.
+
+        Example:
+            ```python
+            master_msg = "Hello from master"
+            response_msg = sync_master.run_master(master_msg)
+            ```
+
         """
         self._activated = True

@@ -136,16 +345,57 @@ def run_master(self, master_msg):

     @property
     def nr_slaves(self):
+        """Get the number of registered slave devices.
+
+        Returns:
+            int: The number of registered slave devices.
+
+        Example:
+            ```python
+            num_slaves = sync_master.nr_slaves
+            ```
+
+        """
         return len(self._registry)


 def _sum_ft(tensor):
-    """sum over the first and last dimention"""
+    """Sum over the first and last dimensions of a tensor.
+
+    Args:
+        tensor (torch.Tensor): The input tensor.
+
+    Returns:
+        torch.Tensor: A tensor with the sum of values over the first and last dimensions.
+
+    Example:
+        ```python
+        input_tensor = torch.tensor([[1, 2], [3, 4]])
+        result = _sum_ft(input_tensor)
+        # Result: tensor(10)
+        ```
+
+    """
     return tensor.sum(dim=0).sum(dim=-1)


 def _unsqueeze_ft(tensor):
-    """add new dementions at the front and the tail"""
+    """Add new dimensions at the front and the tail of a tensor.
+
+    Args:
+        tensor (torch.Tensor): The input tensor.
+
+    Returns:
+        torch.Tensor: A tensor with new dimensions added at the front and the tail.
+
+    Example:
+        ```python
+        input_tensor = torch.tensor([1, 2, 3])
+        result = _unsqueeze_ft(input_tensor)
+        # Result: tensor([[[1], [2], [3]]]) with shape (1, 3, 1)
+        ```
+
+    """
     return tensor.unsqueeze(0).unsqueeze(-1)


@@ -154,6 +404,21 @@ def _unsqueeze_ft(tensor):


 class _SynchronizedBatchNorm(_BatchNorm):
+    """Synchronized Batch Normalization for parallel computation.
+
+    This class extends PyTorch's `_BatchNorm` to support synchronization for data parallelism.
+    It uses a master-slave communication pattern to compute batch statistics efficiently.
+
+    Args:
+        num_features (int): Number of features in the input tensor.
+        eps (float): Small constant added to the denominator for numerical stability. Default: 1e-5
+        momentum (float): Momentum factor for the running statistics. Default: 0.1
+        affine (bool): If True, apply learned affine transformation. Default: True
+
+    Note:
+        This class is typically used in a data parallel setup where multiple GPUs work together.
+
+    """
     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
         super(_SynchronizedBatchNorm, self).__init__(
             num_features, eps=eps, momentum=momentum, affine=affine
         )
@@ -166,6 +431,15 @@ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
         self._slave_pipe = None

     def forward(self, input):
+        """Forward pass through the synchronized batch normalization layer.
+
+        Args:
+            input (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The normalized and optionally affine-transformed tensor.
+
+        """
         # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
         if not (self._is_parallel and self.training):
             return F.batch_norm(
@@ -221,7 +495,16 @@ def __data_parallel_replicate__(self, ctx, copy_id):
         self._slave_pipe = ctx.sync_master.register_slave(copy_id)

     def _data_parallel_master(self, intermediates):
-        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
+        """Reduce the sum and square-sum, compute the statistics, and broadcast them.
+
+        This method runs on the master copy only, after every replica has reported.
+
+        Args:
+            intermediates (list): (identifier, message) pairs collected from all copies, master first.
+
+        Returns:
+            list: (identifier, message) pairs carrying the computed statistics back to each copy.
+        """
         # Always using same "device order" makes the ReduceAdd operation faster.
         # Thanks to:: Tete Xiao (http://tetexiao.com/)
@@ -244,8 +526,19 @@ def _data_parallel_master(self, intermediates):
         return outputs

     def _compute_mean_std(self, sum_, ssum, size):
-        """Compute the mean and standard-deviation with sum and square-sum. This method
-        also maintains the moving average on the master device."""
+        """Compute the mean and standard-deviation with sum and square-sum.
+
+        This method also maintains the moving average on the master device.
+
+        Args:
+            sum_ (torch.Tensor): Sum of values.
+            ssum (torch.Tensor): Sum of squared values.
+            size (int): Size of the input batch.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: Mean and standard-deviation.
+
+        """
         assert (
             size > 1
         ), "BatchNorm computes unbiased standard-deviation, which requires size > 1."
@@ -288,25 +579,29 @@ class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
     During evaluation, this running mean/variance is used for normalization.
     Because the BatchNorm is done over the `C` dimension, computing statistics
     on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.
+
+    Note:
+        This layer behaves like the built-in PyTorch BatchNorm1d when used on a single GPU or CPU.
+
     Args:
-        num_features: num_features from an expected input of size
-            `batch_size x num_features [x width]`
-        eps: a value added to the denominator for numerical stability.
-            Default: 1e-5
-        momentum: the value used for the running_mean and running_var
-            computation. Default: 0.1
-        affine: a boolean value that when set to ``True``, gives the layer learnable
-            affine parameters. Default: ``True``
+        num_features (int): Number of features in the input tensor. `batch_size x num_features [x width]`
+        eps (float): A small constant added to the denominator for numerical stability. Default: 1e-5
+        momentum (float): The momentum factor used for computing running statistics. Default: 0.1
+        affine (bool): If True, learnable affine parameters (gamma and beta) are applied.
Default: True + Shape: - - Input: :math:`(N, C)` or :math:`(N, C, L)` - - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) + - Input: (N, C) or (N, C, L) + - Output: (N, C) or (N, C, L) (same shape as input) + Examples: >>> # With Learnable Parameters >>> m = SynchronizedBatchNorm1d(100) >>> # Without Learnable Parameters >>> m = SynchronizedBatchNorm1d(100, affine=False) - >>> input = torch.autograd.Variable(torch.randn(20, 100)) + >>> input = torch.randn(20, 100) # 2D input >>> output = m(input) + >>> input_3d = torch.randn(20, 100, 30) # 3D input + >>> output_3d = m(input_3d) """ def _check_input_dim(self, input): @@ -426,14 +722,24 @@ class CallbackContext(object): def execute_replication_callbacks(modules): """ - Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. - The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` - Note that, as all modules are isomorphism, we assign each sub-module with a context + Execute a replication callback `__data_parallel_replicate__` on each module created by original replication. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`. + Note that, as all modules are isomorphic, we assign each sub-module with a context (shared among multiple copies of this module on different devices). Through this context, different copies can share some information. We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback of any slave copies. + + Args: + modules (list): List of replicated modules. + + Examples: + >>> # Replicate a module and execute replication callbacks + >>> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + >>> replicated_sync_bn = DataParallelWithCallback(replicate(sync_bn, device_ids=[0, 1])) + >>> # sync_bn.__data_parallel_replicate__ will be invoked. """ + master_copy = modules[0] nr_modules = len(list(master_copy.modules())) ctxs = [CallbackContext() for _ in range(nr_modules)] @@ -447,13 +753,19 @@ def execute_replication_callbacks(modules): class DataParallelWithCallback(DataParallel): """ Data Parallel with a replication callback. - An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by - original `replicate` function. - The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + the original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`. + + Args: + module (Module): The module to be parallelized. + device_ids (list): List of device IDs to use for parallelization. + Examples: - > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) - > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) - # sync_bn.__data_parallel_replicate__ will be invoked. + >>> # Parallelize a module with a replication callback + >>> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + >>> replicated_sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + >>> # sync_bn.__data_parallel_replicate__ will be invoked. """ def replicate(self, module, device_ids): @@ -466,13 +778,21 @@ def patch_replication_callback(data_parallel): """ Monkey-patch an existing `DataParallel` object. Add the replication callback. Useful when you have customized `DataParallel` implementation. 
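Putting the pieces of this file together, here is a minimal usage sketch for synchronized BN under `DataParallelWithCallback` (it assumes at least two CUDA devices and the import path below):

```python
import torch
import torch.nn as nn

from fedml.model.cv.batchnorm_utils import (  # assumed import path
    DataParallelWithCallback,
    SynchronizedBatchNorm1d,
)

# BN statistics are computed over the whole batch instead of per replica
model = nn.Sequential(nn.Linear(32, 64), SynchronizedBatchNorm1d(64), nn.ReLU()).cuda()
model = DataParallelWithCallback(model, device_ids=[0, 1])

x = torch.randn(16, 32).cuda()
y = model(x)  # each replica sees 8 samples, but BN reduces over all 16
```

The replication callback is what wires each replica's BN layer to the shared `SyncMaster`; plain `DataParallel` would silently fall back to per-replica statistics.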
+ + Args: + data_parallel (DataParallel): The existing DataParallel object to be patched. + Examples: - > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) - > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) - > patch_replication_callback(sync_bn) - # this is equivalent to - > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) - > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + >>> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + >>> sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + >>> patch_replication_callback(sync_bn) + # This is equivalent to: + >>> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + >>> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + + Note: + This function monkey-patches the `DataParallel` object to add the replication callback + without the need to create a new `DataParallelWithCallback` object. """ assert isinstance(data_parallel, DataParallel) diff --git a/python/fedml/model/cv/common.py b/python/fedml/model/cv/common.py index 267bb4494d..1f01c89022 100644 --- a/python/fedml/model/cv/common.py +++ b/python/fedml/model/cv/common.py @@ -1811,8 +1811,7 @@ def __repr__(self): groups=self.groups) -def channel_shuffle2(x, - groups): +def channel_shuffle2(x, groups): """ Channel shuffle operation from 'ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices,' https://arxiv.org/abs/1707.01083. The alternative version. diff --git a/python/fedml/model/cv/darts/architect.py b/python/fedml/model/cv/darts/architect.py index 27cca11c19..25fcf2883d 100644 --- a/python/fedml/model/cv/darts/architect.py +++ b/python/fedml/model/cv/darts/architect.py @@ -11,6 +11,64 @@ def _concat(xs): class Architect(object): + """ + The Architect class is responsible for architecture optimization in neural architecture search (NAS). + It adapts the architecture of a neural network to improve its performance on a specific task using gradient-based methods. + + Attributes: + network_momentum (float): The momentum term for the network weights. + network_weight_decay (float): The weight decay term for the network weights. + model (nn.Module): The neural network model for which the architecture is optimized. + criterion (nn.Module): The loss criterion used for training. + optimizer (torch.optim.Optimizer): The optimizer for architecture parameters. + device (torch.device): The device on which the operations are performed. + is_multi_gpu (bool): Flag indicating if the model is trained on multiple GPUs. + + Args: + model (nn.Module): The neural network model being optimized. + criterion (nn.Module): The loss criterion for training. + args (object): A configuration object containing hyperparameters. + device (torch.device): The device (e.g., 'cuda' or 'cpu') on which to perform computations. + + Methods: + step(input_train, target_train, input_valid, target_valid, eta, network_optimizer, unrolled): + Perform a single step of architecture optimization. + + step_v2(input_train, target_train, input_valid, target_valid, lambda_train_regularizer, lambda_valid_regularizer): + Perform a single step of architecture optimization with custom regularization terms. + + step_single_level(input_train, target_train): + Perform a single step of architecture optimization for a single level. + + step_wa(input_train, target_train, input_valid, target_valid, lambda_regularizer): + Perform a single step of architecture optimization with weight adaptation. 
+ + step_AOS(input_train, target_train, input_valid, target_valid): + Perform a single step of architecture optimization using the AOS method. + + _backward_step(input_valid, target_valid): + Perform the backward step during optimization. + + _backward_step_unrolled(input_train, target_train, input_valid, target_valid, eta, network_optimizer): + Perform the unrolled backward step during optimization. + + _construct_model_from_theta(theta): + Construct a new model using architecture parameters. + + _hessian_vector_product(vector, input, target, r=1e-2): + Compute the product of the Hessian matrix and a vector. + + _compute_unrolled_model(input, target, eta, network_optimizer): + Compute the unrolled model with updated weights. + + Example: + # Create an Architect instance + architect = Architect(model, criterion, args, device) + + # Perform architecture optimization + architect.step(input_train, target_train, input_valid, target_valid, eta, network_optimizer, unrolled=True) + """ + def __init__(self, model, criterion, args, device): self.network_momentum = args.momentum self.network_weight_decay = args.weight_decay @@ -34,6 +92,18 @@ def __init__(self, model, criterion, args, device): # W_j = V_j + W_jx x # https://www.youtube.com/watch?v=k8fTYJPd3_I def _compute_unrolled_model(self, input, target, eta, network_optimizer): + """ + Compute the unrolled model with respect to the architecture parameters. + + Args: + input: Input data. + target: Target data. + eta (float): Learning rate. + network_optimizer: The network optimizer. + + Returns: + unrolled_model: The unrolled model. + """ logits = self.model(input) loss = self.criterion(logits, target) # pylint: disable=E1102 @@ -65,6 +135,18 @@ def step( network_optimizer, unrolled, ): + """ + Perform one optimization step for architecture search. + + Args: + input_train: Training input data. + target_train: Training target data. + input_valid: Validation input data. + target_valid: Validation target data. + eta (float): Learning rate. + network_optimizer: The network optimizer. + unrolled (bool): Whether to compute an unrolled model. + """ self.optimizer.zero_grad() if unrolled: # logging.info("first order") @@ -91,6 +173,17 @@ def step_v2( lambda_train_regularizer, lambda_valid_regularizer, ): + """ + Perform one optimization step for architecture search (variant 2). + + Args: + input_train: Training input data. + target_train: Training target data. + input_valid: Validation input data. + target_valid: Validation target data. + lambda_train_regularizer (float): Regularization weight for training. + lambda_valid_regularizer (float): Regularization weight for validation. + """ self.optimizer.zero_grad() # grads_alpha_with_train_dataset @@ -143,6 +236,13 @@ def step_v2( # ours def step_single_level(self, input_train, target_train): + """ + Perform one optimization step for architecture search (single level). + + Args: + input_train: Training input data. + target_train: Training target data. + """ self.optimizer.zero_grad() # grads_alpha_with_train_dataset @@ -174,6 +274,16 @@ def step_single_level(self, input_train, target_train): def step_wa( self, input_train, target_train, input_valid, target_valid, lambda_regularizer ): + """ + Perform one optimization step for architecture search (weighted average). + + Args: + input_train: Training input data. + target_train: Training target data. + input_valid: Validation input data. + target_valid: Validation target data. + lambda_regularizer (float): Regularization weight. 
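All of these `step*` variants are driven by the same alternating outer loop. A hedged sketch of that loop follows; the loader, optimizer, and learning-rate names are assumptions, not part of this file:

```python
def search_epoch(architect, model, criterion, optimizer, train_queue, valid_queue, lr):
    """One epoch of DARTS-style bilevel optimization (sketch)."""
    for (x_train, y_train), (x_valid, y_valid) in zip(train_queue, valid_queue):
        # 1) update the architecture parameters (alphas) on a validation batch
        architect.step(x_train, y_train, x_valid, y_valid, lr, optimizer, unrolled=True)

        # 2) update the network weights w on a training batch
        optimizer.zero_grad()
        loss = criterion(model(x_train), y_train)
        loss.backward()
        optimizer.step()
```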
+ """ self.optimizer.zero_grad() # grads_alpha_with_train_dataset @@ -220,6 +330,15 @@ def step_wa( self.optimizer.step() def step_AOS(self, input_train, target_train, input_valid, target_valid): + """ + Perform one optimization step for architecture search (AOS). + + Args: + input_train: Training input data. + target_train: Training target data. + input_valid: Validation input data. + target_valid: Validation target data. + """ self.optimizer.zero_grad() output_search = self.model(input_valid) arch_loss = self.criterion(output_search, target_valid) # pylint: disable=E1102 @@ -227,6 +346,13 @@ def step_AOS(self, input_train, target_train, input_valid, target_valid): self.optimizer.step() def _backward_step(self, input_valid, target_valid): + """ + Perform a backward step for the architecture optimization. + + Args: + input_valid: Validation input data. + target_valid: Validation target data. + """ logits = self.model(input_valid) loss = self.criterion(logits, target_valid) # pylint: disable=E1102 @@ -241,6 +367,17 @@ def _backward_step_unrolled( eta, network_optimizer, ): + """ + Perform a backward step for the architecture optimization with unrolled training. + + Args: + input_train: Training input data. + target_train: Training target data. + input_valid: Validation input data. + target_valid: Validation target data. + eta: Learning rate for unrolled training. + network_optimizer: The optimizer for the network weights. + """ # calculate w' in equation (7): # approximate w(*) by adapting w using only a single training step and enable momentum. unrolled_model = self._compute_unrolled_model( @@ -277,6 +414,15 @@ def _backward_step_unrolled( v.grad.data.copy_(g.data) def _construct_model_from_theta(self, theta): + """ + Construct a new model from the given theta. + + Args: + theta: A flattened parameter tensor. + + Returns: + model_new: A new model constructed using the provided theta. + """ model_new = self.model.new() model_dict = self.model.state_dict() @@ -311,6 +457,18 @@ def _construct_model_from_theta(self, theta): return model_new.to(self.device) def _hessian_vector_product(self, vector, input, target, r=1e-2): + """ + Calculate the Hessian-vector product. + + Args: + vector: A list of gradient vectors. + input: Input data. + target: Target data. + r: Regularization term. + + Returns: + List of Hessian-vector products. + """ # vector is (gradient of w' on validation dataset) R = r / _concat(vector).norm() parameters = ( @@ -374,6 +532,19 @@ def step_v2_2ndorder( lambda_train_regularizer, lambda_valid_regularizer, ): + """ + Perform a step for architecture optimization using the second-order method. + + Args: + input_train: Training input data. + target_train: Training target data. + input_valid: Validation input data. + target_valid: Validation target data. + eta: Learning rate for unrolled training. + network_optimizer: The optimizer for the network weights. + lambda_train_regularizer: Regularization term for training dataset. + lambda_valid_regularizer: Regularization term for validation dataset. + """ self.optimizer.zero_grad() # approximate w(*) by adapting w using only a single training step and enable momentum. @@ -465,6 +636,20 @@ def step_v2_2ndorder2( lambda_train_regularizer, lambda_valid_regularizer, ): + """ + Perform a step for architecture optimization using the second-order method with modifications. + + Args: + input_train: Training input data. + target_train: Training target data. + input_valid: Validation input data. + target_valid: Validation target data. 
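`_hessian_vector_product` above implements the central-difference approximation from the DARTS paper: (grad_alpha L(w + eps*v) - grad_alpha L(w - eps*v)) / (2*eps). A generic standalone re-sketch of that trick, not the FedML API; `loss_fn` is assumed to recompute the loss each time it is called:

```python
import torch

def hvp_finite_difference(loss_fn, w_params, alpha_params, vector, r=1e-2):
    """Approximate the mixed Hessian-vector product by central differences (sketch)."""
    eps = r / torch.cat([v.reshape(-1) for v in vector]).norm().item()
    with torch.no_grad():
        for p, v in zip(w_params, vector):
            p.add_(v, alpha=eps)                 # w + eps * v
    grads_pos = torch.autograd.grad(loss_fn(), alpha_params)
    with torch.no_grad():
        for p, v in zip(w_params, vector):
            p.sub_(v, alpha=2 * eps)             # w - eps * v
    grads_neg = torch.autograd.grad(loss_fn(), alpha_params)
    with torch.no_grad():
        for p, v in zip(w_params, vector):
            p.add_(v, alpha=eps)                 # restore the original w
    return [(gp - gn) / (2 * eps) for gp, gn in zip(grads_pos, grads_neg)]
```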
+ eta: Learning rate for unrolled training. + network_optimizer: The optimizer for the network weights. + lambda_train_regularizer: Regularization term for training dataset. + lambda_valid_regularizer: Regularization term for validation dataset. + """ + self.optimizer.zero_grad() # approximate w(*) by adapting w using only a single training step and enable momentum. diff --git a/python/fedml/model/cv/darts/model.py b/python/fedml/model/cv/darts/model.py index 62c11388b1..5f5f3badb8 100644 --- a/python/fedml/model/cv/darts/model.py +++ b/python/fedml/model/cv/darts/model.py @@ -6,9 +6,29 @@ class Cell(nn.Module): + """ + Cell in a neural architecture described by a genotype. + + Args: + genotype (Genotype): Genotype describing the cell's architecture. + C_prev_prev (int): Number of input channels from two steps back. + C_prev (int): Number of input channels from the previous step. + C (int): Number of output channels. + reduction (bool): Whether the cell is a reduction cell. + reduction_prev (bool): Whether the previous cell was a reduction cell. + + Input: + - s0 (Tensor): Input tensor from two steps back, shape (batch_size, C_prev_prev, H, W). + - s1 (Tensor): Input tensor from the previous step, shape (batch_size, C_prev, H, W). + - drop_prob (float): Dropout probability for drop path regularization during training. + + Output: + - Output tensor of the cell, shape (batch_size, C, H, W). + + """ + def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev): super(Cell, self).__init__() - print(C_prev_prev, C_prev, C) if reduction_prev: self.preprocess0 = FactorizedReduce(C_prev_prev, C) @@ -25,6 +45,17 @@ def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev): self._compile(C, op_names, indices, concat, reduction) def _compile(self, C, op_names, indices, concat, reduction): + """ + Compiles the operations for the cell based on the given genotype. + + Args: + C (int): Number of output channels for the cell. + op_names (list of str): Names of the operations for each edge in the cell. + indices (list of int): Indices of the operations for each edge in the cell. + concat (list of int): Concatenation points for the cell. + reduction (bool): Whether the cell is a reduction cell. + + """ assert len(op_names) == len(indices) self._steps = len(op_names) // 2 self._concat = concat @@ -38,6 +69,18 @@ def _compile(self, C, op_names, indices, concat, reduction): self._indices = indices def forward(self, s0, s1, drop_prob): + """ + Forward pass through the cell. + + Args: + s0 (Tensor): Input tensor from two steps back, shape (batch_size, C_prev_prev, H, W). + s1 (Tensor): Input tensor from the previous step, shape (batch_size, C_prev, H, W). + drop_prob (float): Dropout probability for drop path regularization during training. + + Returns: + Tensor: Output tensor of the cell, shape (batch_size, C, H, W). + + """ s0 = self.preprocess0(s0) s1 = self.preprocess1(s1) @@ -59,15 +102,40 @@ def forward(self, s0, s1, drop_prob): return torch.cat([states[i] for i in self._concat], dim=1) + class AuxiliaryHeadCIFAR(nn.Module): + """ + Auxiliary head for CIFAR classification in the DARTS model. + + Args: + C (int): Number of input channels. + num_classes (int): Number of classes for classification. + + Input: + - Input tensor of shape (batch_size, C, 8, 8), assuming an input size of 8x8. + + Output: + - Output tensor of shape (batch_size, num_classes), representing class scores. 
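`Cell.forward` above threads a `drop_prob` into every operation. The underlying drop-path transform boils down to zeroing whole residual branches per sample; this is a re-sketch of the canonical DARTS version (the package's own `utils.py` implementation may differ in details):

```python
import torch

def drop_path(x, drop_prob):
    # zero out a whole residual branch per sample, rescaling the survivors
    if drop_prob > 0.0:
        keep_prob = 1.0 - drop_prob
        mask = torch.bernoulli(
            torch.full((x.size(0), 1, 1, 1), keep_prob, device=x.device)
        )
        x = x.div(keep_prob) * mask
    return x
```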
+ + Architecture: + - ReLU activation + - Average pooling with 5x5 kernel and stride 3 (resulting in an image size of 2x2) + - 1x1 convolution with 128 output channels + - Batch normalization + - ReLU activation + - 2x2 convolution with 768 output channels + - Batch normalization + - ReLU activation + - Linear layer with num_classes output units for classification. + + """ + def __init__(self, C, num_classes): - """assuming input size 8x8""" + super(AuxiliaryHeadCIFAR, self).__init__() self.features = nn.Sequential( nn.ReLU(inplace=True), - nn.AvgPool2d( - 5, stride=3, padding=0, count_include_pad=False - ), # image size = 2 x 2 + nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), nn.Conv2d(C, 128, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(inplace=True), @@ -108,6 +176,32 @@ def forward(self, x): class NetworkCIFAR(nn.Module): + """ + DARTS network architecture for CIFAR dataset. + + Args: + C (int): Initial number of channels. + num_classes (int): Number of classes for classification. + layers (int): Number of layers. + auxiliary (bool): Whether to use auxiliary heads. + genotype (Genotype): Genotype specifying the cell structure. + + Input: + - Input tensor of shape (batch_size, 3, 32, 32), where 3 is for RGB channels. + + Output: + - Main network output tensor of shape (batch_size, num_classes). + - Auxiliary head output tensor if auxiliary is True and during training. + + Architecture: + - Stem: Initial convolution layer followed by batch normalization. + - Cells: Stack of cells with specified genotype. + - Auxiliary Head: Optional auxiliary head for training stability. + - Global Pooling: Adaptive average pooling to 1x1 size. + - Classifier: Linear layer for classification. + + """ + def __init__(self, C, num_classes, layers, auxiliary, genotype): super(NetworkCIFAR, self).__init__() self._layers = layers @@ -158,6 +252,25 @@ def forward(self, input): class NetworkImageNet(nn.Module): + """ + Network architecture for ImageNet dataset. + + Args: + C (int): Initial number of channels. + num_classes (int): Number of classes for classification. + layers (int): Number of layers. + auxiliary (bool): Whether to include an auxiliary head. + genotype (Genotype): Genotype specifying the cell structure. + + Input: + - Input tensor of shape (batch_size, 3, height, width). + + Output: + - Main classifier logits tensor of shape (batch_size, num_classes). + - Auxiliary classifier logits tensor if auxiliary is True, otherwise None. + + """ + def __init__(self, C, num_classes, layers, auxiliary, genotype): super(NetworkImageNet, self).__init__() self._layers = layers diff --git a/python/fedml/model/cv/darts/model_search.py b/python/fedml/model/cv/darts/model_search.py index 75c5a504dd..dba356e6ea 100644 --- a/python/fedml/model/cv/darts/model_search.py +++ b/python/fedml/model/cv/darts/model_search.py @@ -7,7 +7,44 @@ from .utils import count_parameters_in_MB +import torch.nn as nn + class MixedOp(nn.Module): + """ + Mixed Operation Module for Neural Architecture Search (NAS). + + This module represents a mixture of different operations and allows for dynamic selection of one + of these operations based on a set of weights. + + Args: + C (int): Number of input channels. + stride (int): The stride for the operations. + + Input: + - Input tensor `x` of shape (batch_size, C, H, W), where `C` is the number of input channels, + and `H` and `W` are the spatial dimensions. 
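For a concrete end-to-end use of `NetworkCIFAR` with its auxiliary head, a short sketch; the `genotypes` module and the 0.4 auxiliary weight follow the original DARTS release and are assumptions here, not guarantees of this package:

```python
import torch
import torch.nn.functional as F

from fedml.model.cv.darts import genotypes  # assumed to exist, as in the DARTS release
from fedml.model.cv.darts.model import NetworkCIFAR

model = NetworkCIFAR(C=36, num_classes=10, layers=20, auxiliary=True, genotype=genotypes.DARTS)
model.drop_path_prob = 0.2                  # read by the cells during training
x = torch.randn(2, 3, 32, 32)
y = torch.randint(0, 10, (2,))

logits, logits_aux = model(x)
loss = F.cross_entropy(logits, y)
if logits_aux is not None:
    loss = loss + 0.4 * F.cross_entropy(logits_aux, y)  # assumed DARTS-style weighting
loss.backward()
```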
+ + Output: + - Output tensor of shape (batch_size, C, H', W'), where `C` is the number of output channels, + and `H'` and `W'` are the spatial dimensions after applying the selected operation. + + Attributes: + - _ops (nn.ModuleList): A list of operations to be mixed based on weights. + + Note: + - This module is typically used in Neural Architecture Search (NAS) to create a mixed operation + that combines different operations (e.g., convolution, pooling) and allows the architecture + search algorithm to learn which operations to use. + + Example: + To create an instance of the MixedOp module and use it in a NAS cell: + >>> mixed_op = MixedOp(C=64, stride=1) + >>> input_tensor = torch.randn(1, 64, 32, 32) # Example input tensor + >>> weights = torch.randn(5) # Example operation mixing weights + >>> output = mixed_op(input_tensor, weights) # Apply the mixed operation to the input + + """ + def __init__(self, C, stride): super(MixedOp, self).__init__() self._ops = nn.ModuleList() @@ -18,11 +55,67 @@ def __init__(self, C, stride): self._ops.append(op) def forward(self, x, weights): - # w is the operation mixing weights. see equation 2 in the original paper. + """ + Forward pass of the MixedOp module. + + Args: + x (Tensor): Input tensor of shape (batch_size, C, H, W). + weights (Tensor): Operation mixing weights of shape (num_operations,). + + Returns: + output (Tensor): Output tensor of shape (batch_size, C, H', W'). + + """ + # Apply the selected operation based on the given weights return sum(w * op(x) for w, op in zip(weights, self._ops)) class Cell(nn.Module): + """ + Cell Module for Neural Architecture Search (NAS). + + This module represents a cell in a neural network architecture designed for NAS. It contains a sequence + of mixed operations and is used to create the architecture search space. + + Args: + steps (int): The number of steps (operations) in the cell. + multiplier (int): The multiplier for the number of output channels. + C_prev_prev (int): Number of input channels from two steps back. + C_prev (int): Number of input channels from the previous step. + C (int): Number of output channels. + reduction (bool): Whether the cell performs reduction (downsampling). + reduction_prev (bool): Whether the previous cell performs reduction. + + Input: + - Two input tensors `s0` and `s1` of shape (batch_size, C_prev_prev, H, W) and (batch_size, C_prev, H, W), + where `C_prev_prev` is the number of input channels from two steps back, `C_prev` is the number of input + channels from the previous step, and `H` and `W` are the spatial dimensions. + + Output: + - Output tensor of shape (batch_size, C, H', W'), where `C` is the number of output channels, + and `H'` and `W'` are the spatial dimensions after applying the cell operations. + + Attributes: + - preprocess0 (nn.Module): Preprocessing layer for input `s0`. + - preprocess1 (nn.Module): Preprocessing layer for input `s1`. + - _steps (int): The number of steps (operations) in the cell. + - _multiplier (int): The multiplier for the number of output channels. + - _ops (nn.ModuleList): List of mixed operations to be applied in the cell. + + Note: + - This module is typically used in Neural Architecture Search (NAS) to create cells with different + combinations of operations, which are then combined to form a complete neural network architecture. 
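The continuous relaxation that `MixedOp` implements is easy to probe directly. A small sketch, assuming the import path below; `_ops` is a private attribute and is used here only to size the alpha vector:

```python
import torch
import torch.nn.functional as F

from fedml.model.cv.darts.model_search import MixedOp  # assumed import path

op = MixedOp(C=16, stride=1)
alpha = torch.randn(len(op._ops), requires_grad=True)  # one logit per candidate op
x = torch.randn(4, 16, 32, 32)

out = op(x, F.softmax(alpha, dim=-1))  # softmax-weighted sum over all candidates
out.mean().backward()
print(alpha.grad is not None)          # True: gradients flow into the architecture logits
```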
+ + Example: + To create an instance of the Cell module and use it in an NAS network: + >>> cell = Cell(steps=4, multiplier=4, C_prev_prev=48, C_prev=48, C=192, reduction=False, reduction_prev=True) + >>> input_s0 = torch.randn(1, 48, 32, 32) # Example input tensor s0 + >>> input_s1 = torch.randn(1, 48, 32, 32) # Example input tensor s1 + >>> weights = torch.randn(14) # Example operation mixing weights + >>> output = cell(input_s0, input_s1, weights) # Apply the cell operations to the inputs + + """ + def __init__( self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev ): @@ -38,7 +131,7 @@ def __init__( self._multiplier = multiplier self._ops = nn.ModuleList() - self._bns = nn.ModuleList() + for i in range(self._steps): for j in range(2 + i): stride = 2 if reduction and j < 2 else 1 @@ -46,6 +139,18 @@ def __init__( self._ops.append(op) def forward(self, s0, s1, weights): + """ + Forward pass of the Cell module. + + Args: + s0 (Tensor): Input tensor s0 of shape (batch_size, C_prev_prev, H, W). + s1 (Tensor): Input tensor s1 of shape (batch_size, C_prev, H, W). + weights (Tensor): Operation mixing weights of shape (num_operations,). + + Returns: + output (Tensor): Output tensor of shape (batch_size, C, H', W'). + + """ s0 = self.preprocess0(s0) s1 = self.preprocess1(s1) @@ -63,6 +168,52 @@ def forward(self, s0, s1, weights): class InnerCell(nn.Module): + """ + InnerCell Module for Neural Architecture Search (NAS). + + This module represents an inner cell in a neural network architecture designed for NAS. It contains a sequence + of mixed operations and is used to create the architecture search space. + + Args: + steps (int): The number of steps (operations) in the inner cell. + multiplier (int): The multiplier for the number of output channels. + C_prev_prev (int): Number of input channels from two steps back. + C_prev (int): Number of input channels from the previous step. + C (int): Number of output channels. + reduction (bool): Whether the inner cell performs reduction (downsampling). + reduction_prev (bool): Whether the previous cell performs reduction. + weights (Tensor): Operation mixing weights for the inner cell. + + Input: + - Two input tensors `s0` and `s1` of shape (batch_size, C_prev_prev, H, W) and (batch_size, C_prev, H, W), + where `C_prev_prev` is the number of input channels from two steps back, `C_prev` is the number of input + channels from the previous step, and `H` and `W` are the spatial dimensions. + + Output: + - Output tensor of shape (batch_size, C, H', W'), where `C` is the number of output channels, + and `H'` and `W'` are the spatial dimensions after applying the inner cell operations. + + Attributes: + - preprocess0 (nn.Module): Preprocessing layer for input `s0`. + - preprocess1 (nn.Module): Preprocessing layer for input `s1`. + - _steps (int): The number of steps (operations) in the inner cell. + - _multiplier (int): The multiplier for the number of output channels. + - _ops (nn.ModuleList): List of mixed operations to be applied in the inner cell. + + Note: + - This module is typically used in Neural Architecture Search (NAS) to create inner cells with different + combinations of operations, which are then combined to form a complete neural network architecture. + + Example: + To create an instance of the InnerCell module and use it in an NAS network: + >>> inner_cell = InnerCell(steps=4, multiplier=4, C_prev_prev=48, C_prev=48, C=192, reduction=False, + ... 
reduction_prev=True, weights=weights) + >>> input_s0 = torch.randn(1, 48, 32, 32) # Example input tensor s0 + >>> input_s1 = torch.randn(1, 48, 32, 32) # Example input tensor s1 + >>> output = inner_cell(input_s0, input_s1) # Apply the inner cell operations to the inputs + + """ + def __init__( self, steps, @@ -86,8 +237,7 @@ def __init__( self._multiplier = multiplier self._ops = nn.ModuleList() - self._bns = nn.ModuleList() - # len(self._ops)=2+3+4+5=14 + offset = 0 keys = list(OPS.keys()) for i in range(self._steps): @@ -102,6 +252,17 @@ def __init__( offset += i + 2 def forward(self, s0, s1): + """ + Forward pass of the InnerCell module. + + Args: + s0 (Tensor): Input tensor s0 of shape (batch_size, C_prev_prev, H, W). + s1 (Tensor): Input tensor s1 of shape (batch_size, C_prev, H, W). + + Returns: + output (Tensor): Output tensor of shape (batch_size, C, H', W'). + + """ s0 = self.preprocess0(s0) s1 = self.preprocess1(s1) @@ -117,16 +278,51 @@ def forward(self, s0, s1): class ModelForModelSizeMeasure(nn.Module): """ - This class is used only for calculating the size of the generated model. - The choices of opeartions are made using the current alpha value of the DARTS model. - The main difference between this model and DARTS model are the following: - 1. The __init__ takes one more parameter "alphas_normal" and "alphas_reduce" - 2. The new Cell module is rewriten to contain the functionality of both Cell and MixedOp - 3. To be more specific, MixedOp is replaced with a fixed choice of operation based on - the argmax(alpha_values) - 4. The new Cell class is redefined as an Inner Class. The name is the same, so please be - very careful when you change the code later - 5. + Model used solely for measuring the size of the generated model. + + This class is designed to calculate the size of a model based on specific choices of operations determined by + the alpha values of the DARTS model. It serves the purpose of estimating the model size without performing + actual training or inference. + + Differences from the DARTS model: + 1. Additional parameters "alphas_normal" and "alphas_reduce" are required in the constructor. + 2. The Cell module combines the functionality of both Cell and MixedOp. + 3. MixedOp is replaced with a fixed choice of operation based on the argmax(alpha_values). + 4. The Cell class is redefined as an inner class with the same name. + + Args: + C (int): The number of channels in the input data. + num_classes (int): The number of output classes. + layers (int): The number of layers in the model. + criterion: The loss criterion used for training. + alphas_normal (Tensor): Alpha values for normal cells. + alphas_reduce (Tensor): Alpha values for reduction cells. + steps (int, optional): The number of steps (operations) in each cell. Default is 4. + multiplier (int, optional): The multiplier for the number of output channels. Default is 4. + stem_multiplier (int, optional): The multiplier for the number of channels in the stem. Default is 3. + + Input: + - Input tensor of shape (batch_size, 3, H, W), where `batch_size` is the number of input samples, + `H` and `W` are the spatial dimensions, and `3` represents the RGB channels. + + Output: + - Output tensor of shape (batch_size, num_classes), representing class predictions. + + Attributes: + - stem (nn.Sequential): Stem layer consisting of a convolutional layer and batch normalization. + - cells (nn.ModuleList): List of inner cells that make up the model. 
+ - global_pooling (nn.AdaptiveAvgPool2d): Global pooling layer for spatial aggregation. + - classifier (nn.Linear): Fully connected layer for class prediction. + + Note: + - This class is primarily used for measuring the size of a model and does not perform training or inference. + + Example: + To create an instance of the ModelForModelSizeMeasure and use it to measure the model size: + >>> model = ModelForModelSizeMeasure(C=16, num_classes=10, layers=8, criterion=nn.CrossEntropyLoss(), + ... alphas_normal=alphas_normal, alphas_reduce=alphas_reduce) + >>> input_data = torch.randn(1, 3, 32, 32) # Example input tensor + >>> model_size = get_model_size(model, input_data) # Get the estimated model size """ @@ -159,7 +355,7 @@ def __init__( self.cells = nn.ModuleList() reduction_prev = False - # for layers = 8, when layer_i = 2, 5, the cell is reduction cell. + for i in range(layers): if i in [layers // 3, 2 * layers // 3]: C_curr *= 2 @@ -207,6 +403,52 @@ def forward(self, input_data): class Network(nn.Module): + """ + DARTS-based neural network model for image classification. + + Args: + C (int): The number of channels in the input data. + num_classes (int): The number of output classes. + layers (int): The number of layers in the model. + criterion: The loss criterion used for training. + steps (int, optional): The number of steps (operations) in each cell. Default is 4. + multiplier (int, optional): The multiplier for the number of output channels. Default is 4. + stem_multiplier (int, optional): The multiplier for the number of channels in the stem. Default is 3. + + Input: + - Input tensor of shape (batch_size, 3, H, W), where `batch_size` is the number of input samples, + `H` and `W` are the spatial dimensions, and `3` represents the RGB channels. + + Output: + - Output tensor of shape (batch_size, num_classes), representing class predictions. + + Attributes: + - stem (nn.Sequential): Stem layer consisting of a convolutional layer and batch normalization. + - cells (nn.ModuleList): List of inner cells that make up the model. + - global_pooling (nn.AdaptiveAvgPool2d): Global pooling layer for spatial aggregation. + - classifier (nn.Linear): Fully connected layer for class prediction. + - alphas_normal (nn.Parameter): Learnable alpha values for normal cells. + - alphas_reduce (nn.Parameter): Learnable alpha values for reduction cells. + + Methods: + - new(self): Create a new instance of the network with the same architecture and initialize alpha values. + - new_arch_parameters(self): Generate new architecture parameters (alphas) for the network. + - arch_parameters(self): Get the current architecture parameters (alphas) of the network. + - genotype(self): Get the genotype of the network, which describes the architecture. + - get_current_model_size(self): Estimate the current model size in megabytes. + + Note: + - This class is based on the DARTS (Differentiable Architecture Search) architecture and is used for + neural architecture search (NAS) experiments. 
+ + Example: + To create an instance of the Network class and use it for architecture search: + >>> model = Network(C=16, num_classes=10, layers=8, criterion=nn.CrossEntropyLoss()) + >>> input_data = torch.randn(1, 3, 32, 32) # Example input tensor + >>> genotype, normal_count, reduce_count = model.genotype() # Get the architecture genotype + >>> model_size = model.get_current_model_size() # Get the estimated model size + """ + def __init__( self, C, @@ -263,6 +505,12 @@ def __init__( self._initialize_alphas() def new(self): + """ + Create a new instance of the network with the same architecture and initialize alpha values. + + Returns: + Network: A new instance of the Network class with the same architecture. + """ model_new = Network( self._C, self._num_classes, self._layers, self._criterion, self.device ).to(self.device) @@ -271,6 +519,16 @@ def new(self): return model_new def forward(self, input): + """ + Forward pass of the neural network. + + Args: + input (Tensor): Input tensor of shape (batch_size, 3, H, W), where `batch_size` is the number of + input samples, `H` and `W` are the spatial dimensions, and `3` represents the RGB channels. + + Returns: + Tensor: Output tensor of shape (batch_size, num_classes), representing class predictions. + """ s0 = s1 = self.stem(input) for i, cell in enumerate(self.cells): if cell.reduction: @@ -283,6 +541,9 @@ def forward(self, input): return logits def _initialize_alphas(self): + """ + Initialize alpha values for normal and reduction cells. + """ k = sum(1 for i in range(self._steps) for n in range(2 + i)) num_ops = len(PRIMITIVES) @@ -294,6 +555,12 @@ def _initialize_alphas(self): ] def new_arch_parameters(self): + """ + Generate new architecture parameters (alphas) for the network. + + Returns: + List[nn.Parameter]: List of architecture parameters (alphas). + """ k = sum(1 for i in range(self._steps) for n in range(2 + i)) num_ops = len(PRIMITIVES) @@ -306,9 +573,21 @@ def new_arch_parameters(self): return _arch_parameters def arch_parameters(self): + """ + Get the current architecture parameters (alphas) of the network. + + Returns: + List[nn.Parameter]: List of architecture parameters (alphas). + """ return self._arch_parameters def genotype(self): + """ + Get the genotype of the network, which describes the architecture. + + Returns: + Genotype: The genotype of the network. + """ def _isCNNStructure(k_best): return k_best >= 4 @@ -360,6 +639,12 @@ def _parse(weights): return genotype, cnn_structure_count_normal, cnn_structure_count_reduce def get_current_model_size(self): + """ + Estimate the current model size in megabytes. + + Returns: + float: The estimated model size in megabytes. + """ model = ModelForModelSizeMeasure( self._C, self._num_classes, diff --git a/python/fedml/model/cv/darts/model_search_gdas.py b/python/fedml/model/cv/darts/model_search_gdas.py index 144c4af567..756c477516 100644 --- a/python/fedml/model/cv/darts/model_search_gdas.py +++ b/python/fedml/model/cv/darts/model_search_gdas.py @@ -8,6 +8,17 @@ class MixedOp(nn.Module): def __init__(self, C, stride): + """ + Initialize a MixedOp module. + + Args: + C (int): The number of input channels. + stride (int): The stride for the operation. + + Note: + PRIMITIVES: a list of operation primitives. + OPS: a dictionary mapping operation primitives to corresponding operation classes. 
+ """ super(MixedOp, self).__init__() self._ops = nn.ModuleList() for primitive in PRIMITIVES: @@ -17,6 +28,17 @@ def __init__(self, C, stride): self._ops.append(op) def forward(self, x, weights, cpu_weights): + """ + Perform a forward pass through the MixedOp module. + + Args: + x (Tensor): Input tensor. + weights (Tensor): Weights for the operations. + cpu_weights (list): Weights converted to CPU. + + Returns: + Tensor: Output tensor after applying the mixed operations. + """ clist = [] for j, cpu_weight in enumerate(cpu_weights): if abs(cpu_weight) > 1e-10: @@ -31,6 +53,18 @@ class Cell(nn.Module): def __init__( self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev ): + """ + Initialize a Cell module. + + Args: + steps (int): The number of steps in the cell. + multiplier (int): Multiplier for the number of output channels. + C_prev_prev (int): Number of input channels from two steps back. + C_prev (int): Number of input channels from the previous step. + C (int): Number of output channels for the cell. + reduction (bool): Whether it's a reduction cell. + reduction_prev (bool): Whether the previous cell was a reduction cell. + """ super(Cell, self).__init__() self.reduction = reduction @@ -51,6 +85,17 @@ def __init__( self._ops.append(op) def forward(self, s0, s1, weights): + """ + Perform a forward pass through the Cell module. + + Args: + s0 (Tensor): Input tensor from two steps back. + s1 (Tensor): Input tensor from the previous step. + weights (Tensor): Weights for the operations. + + Returns: + Tensor: Output tensor after applying the cell operations. + """ s0 = self.preprocess0(s0) s1 = self.preprocess1(s1) @@ -64,7 +109,7 @@ def forward(self, s0, s1, weights): ) offset += len(states) states.append(s) - # logging.info(states) + return torch.cat(states[-self._multiplier :], dim=1) @@ -80,6 +125,19 @@ def __init__( multiplier=4, stem_multiplier=3, ): + """ + Initialize a Network_GumbelSoftmax model. + + Args: + C (int): Number of initial channels. + num_classes (int): Number of output classes. + layers (int): Number of layers. + criterion: Loss criterion. + device: The device to run the model on. + steps (int): Number of steps in each cell. + multiplier (int): Multiplier for the number of output channels. + stem_multiplier (int): Multiplier for the number of initial channels in the stem. + """ super(Network_GumbelSoftmax, self).__init__() self._C = C self._num_classes = num_classes @@ -89,7 +147,7 @@ def __init__( self._multiplier = multiplier self.device = device - C_curr = stem_multiplier * C # 3*16 + C_curr = stem_multiplier * C self.stem = nn.Sequential( nn.Conv2d(3, C_curr, 3, padding=1, bias=False), nn.BatchNorm2d(C_curr) ) @@ -98,7 +156,7 @@ def __init__( self.cells = nn.ModuleList() reduction_prev = False - # for layers = 8, when layer_i = 2, 5, the cell is reduction cell. + for i in range(layers): if i in [layers // 3, 2 * layers // 3]: C_curr *= 2 @@ -166,6 +224,14 @@ def arch_parameters(self): return self._arch_parameters def genotype(self): + """ + Get the architecture genotype of the model. + + Returns: + Genotype: The architecture genotype. + cnn_structure_count_normal (int): Count of CNN structures in normal cells. + cnn_structure_count_reduce (int): Count of CNN structures in reduction cells. 
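Where the plain search model mixes all operations with softmax weights, this GDAS variant samples a near-one-hot mixture, which is why its `MixedOp.forward` can skip operations whose weight is numerically zero. The sampling itself is standard Gumbel-Softmax, shown here as a generic sketch:

```python
import torch
import torch.nn.functional as F

alpha = torch.randn(8, requires_grad=True)             # logits, one per candidate op
weights = F.gumbel_softmax(alpha, tau=1.0, hard=True)  # one-hot sample, straight-through grads
print(weights)  # e.g. tensor([0., 0., 1., 0., 0., 0., 0., 0.], grad_fn=...)
```

With `hard=True` only one operation actually runs per edge, while gradients still reach every logit.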
+        """
         def _isCNNStructure(k_best):
             return k_best >= 4

diff --git a/python/fedml/model/cv/darts/operations.py b/python/fedml/model/cv/darts/operations.py
index 1827b2c7d1..5a8cd9ab49 100644
--- a/python/fedml/model/cv/darts/operations.py
+++ b/python/fedml/model/cv/darts/operations.py
@@ -35,6 +35,25 @@


 class ReLUConvBN(nn.Module):
+    """
+    A composite module that applies ReLU activation, followed by a 2D convolution, and then batch normalization.
+
+    Args:
+        C_in (int): Number of input channels.
+        C_out (int): Number of output channels.
+        kernel_size (int): Size of the convolution kernel.
+        stride (int): Stride of the convolution operation.
+        padding (int): Padding for the convolution operation.
+        affine (bool): Whether to apply affine transformation in batch normalization.
+
+    Input:
+        - Input tensor of shape (batch_size, C_in, height, width).
+
+    Output:
+        - Output tensor of shape (batch_size, C_out, new_height, new_width).
+
+    """
+
     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
         super(ReLUConvBN, self).__init__()
         self.op = nn.Sequential(
@@ -50,6 +69,26 @@ def forward(self, x):


 class DilConv(nn.Module):
+    """
+    A composite module that applies ReLU, a dilated depthwise convolution, a pointwise (1x1)
+    convolution, and batch normalization.
+
+    Args:
+        C_in (int): Number of input channels.
+        C_out (int): Number of output channels.
+        kernel_size (int): Size of the convolution kernel.
+        stride (int): Stride of the convolution operation.
+        padding (int): Padding for the convolution operation.
+        dilation (int): Dilation factor for the convolution operation.
+        affine (bool): Whether to apply affine transformation in batch normalization.
+
+    Input:
+        - Input tensor of shape (batch_size, C_in, height, width).
+
+    Output:
+        - Output tensor of shape (batch_size, C_out, new_height, new_width).
+
+    """
+
     def __init__(
         self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True
     ):
@@ -75,6 +114,25 @@ def forward(self, x):


 class SepConv(nn.Module):
+    """
+    A composite module that applies a separable convolution (ReLU, depthwise convolution,
+    pointwise convolution, batch normalization) twice in sequence.
+
+    Args:
+        C_in (int): Number of input channels.
+        C_out (int): Number of output channels.
+        kernel_size (int): Size of the convolution kernel.
+        stride (int): Stride of the convolution operation.
+        padding (int): Padding for the convolution operation.
+        affine (bool): Whether to apply affine transformation in batch normalization.
+
+    Input:
+        - Input tensor of shape (batch_size, C_in, height, width).
+
+    Output:
+        - Output tensor of shape (batch_size, C_out, new_height, new_width).
+
+    """
+
     def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
         super(SepConv, self).__init__()
         self.op = nn.Sequential(
@@ -108,7 +166,19 @@ def forward(self, x):
         return self.op(x)


+
 class Identity(nn.Module):
+    """
+    A module that represents the identity operation (no change).
+
+    Input:
+        - Input tensor of any shape.
+
+    Output:
+        - Output tensor with the same shape as the input.
+
+    """
+
     def __init__(self):
         super(Identity, self).__init__()

@@ -117,6 +187,20 @@ def forward(self, x):


 class Zero(nn.Module):
+    """
+    A module that represents the zero operation: it replaces its input with zeros.
+
+    Args:
+        stride (int): Stride used to subsample the spatial dimensions before zeroing.
+
+    Input:
+        - Input tensor of shape (batch_size, C, height, width).
+
+    Output:
+        - An all-zero tensor. With stride 1 it has the same shape as the input; with a larger
+          stride the height and width are subsampled by that stride, so the output is smaller.
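+
+    Example (illustrative):
+        >>> z = Zero(stride=2)  # all-zero output with H and W subsampled by 2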
+ + """ + def __init__(self, stride): super(Zero, self).__init__() self.stride = stride @@ -128,6 +212,22 @@ def forward(self, x): class FactorizedReduce(nn.Module): + """ + A module that applies factorized reduction to reduce spatial dimensions. + + Args: + C_in (int): Number of input channels. + C_out (int): Number of output channels. + affine (bool): Whether to apply affine transformation in batch normalization. + + Input: + - Input tensor of shape (batch_size, C_in, height, width). + + Output: + - Output tensor of shape (batch_size, C_out, new_height, new_width). + + """ + def __init__(self, C_in, C_out, affine=True): super(FactorizedReduce, self).__init__() assert C_out % 2 == 0 diff --git a/python/fedml/model/cv/darts/train.py b/python/fedml/model/cv/darts/train.py index 95ef17d38b..9d5acd7da2 100644 --- a/python/fedml/model/cv/darts/train.py +++ b/python/fedml/model/cv/darts/train.py @@ -184,6 +184,19 @@ def main(): def train(train_queue, model, criterion, optimizer): + """ + Perform training on the training dataset. + + Args: + train_queue (DataLoader): DataLoader for the training dataset. + model (nn.Module): The neural network model. + criterion (nn.Module): The loss function. + optimizer (Optimizer): The optimizer for updating model parameters. + + Returns: + float: Top-1 accuracy on the training dataset. + float: Average loss on the training dataset. + """ global is_multi_gpu objs = utils.AvgrageMeter() top1 = utils.AvgrageMeter() @@ -221,6 +234,18 @@ def train(train_queue, model, criterion, optimizer): def infer(valid_queue, model, criterion): + """ + Perform inference on the validation dataset using the trained model. + + Args: + valid_queue (DataLoader): DataLoader for the validation dataset. + model (nn.Module): The trained neural network model. + criterion (nn.Module): The loss function used for validation. + + Returns: + float: Top-1 accuracy on the validation dataset. + float: Average loss on the validation dataset. + """ global is_multi_gpu objs = utils.AvgrageMeter() diff --git a/python/fedml/model/cv/darts/train_search.py b/python/fedml/model/cv/darts/train_search.py index 1bdc7e8d90..52f43dabf6 100644 --- a/python/fedml/model/cv/darts/train_search.py +++ b/python/fedml/model/cv/darts/train_search.py @@ -352,6 +352,26 @@ def main(): def train(epoch, train_queue, valid_queue, model, architect, criterion, optimizer, lr): + """ + Train the neural network for one epoch. + + Args: + epoch (int): Current epoch number. + train_queue (DataLoader): DataLoader for the training dataset. + valid_queue (DataLoader): DataLoader for the validation dataset. + model (nn.Module): The neural network model to be trained. + architect (Architect): The architect responsible for updating architecture weights. + criterion (nn.Module): The loss function used for training. + optimizer (torch.optim.Optimizer): The optimizer for updating model weights. + lr (float): Learning rate. + + Returns: + float: Top-1 accuracy on the training dataset. + float: Average loss on the training dataset. + float: Loss value. + + """ + global is_multi_gpu objs = utils.AvgrageMeter() @@ -407,6 +427,20 @@ def train(epoch, train_queue, valid_queue, model, architect, criterion, optimize def infer(valid_queue, model, criterion): + """ + Perform inference on the validation dataset using the trained model. + + Args: + valid_queue (DataLoader): DataLoader for the validation dataset. + model (nn.Module): The trained neural network model. + criterion (nn.Module): The loss function used for validation. 
+
+    Returns:
+        float: Top-1 accuracy on the validation dataset.
+        float: Average loss on the validation dataset.
+        float: Loss value.
+
+    """
     global is_multi_gpu

     objs = utils.AvgrageMeter()
diff --git a/python/fedml/model/cv/darts/utils.py b/python/fedml/model/cv/darts/utils.py
index 0f024b4614..696f121d30 100644
--- a/python/fedml/model/cv/darts/utils.py
+++ b/python/fedml/model/cv/darts/utils.py
@@ -7,23 +7,60 @@
 from torch.autograd import Variable


 class AvgrageMeter(object):
+    """
+    Computes and stores the running average, sum, and count of values over time.
+
+    Note:
+        The class name keeps its original spelling because existing call sites
+        (e.g. ``utils.AvgrageMeter()`` in train.py and train_search.py) refer to
+        it by this name.
+
+    Attributes:
+        avg (float): The current average value.
+        sum (float): The current sum of values.
+        cnt (int): The current count of values.
+
+    Methods:
+        reset(): Reset the average, sum, and count to zero.
+        update(val, n=1): Update the meter with a new value and count.
+
+    """

     def __init__(self):
+        """
+        Initializes an AvgrageMeter object with initial values of zero.
+        """
         self.reset()

     def reset(self):
+        """
+        Reset the average, sum, and count to zero.
+        """
         self.avg = 0
         self.sum = 0
         self.cnt = 0

     def update(self, val, n=1):
+        """
+        Update the meter with a new value and count.
+
+        Args:
+            val (float): The new value to update the meter with.
+            n (int): The count associated with the new value. Default is 1.
+        """
         self.sum += val * n
         self.cnt += n
         self.avg = self.sum / self.cnt


 def accuracy(output, target, topk=(1,)):
+    """
+    Computes the accuracy of model predictions given the output and target labels.
+
+    Args:
+        output (Tensor): The model's output predictions.
+        target (Tensor): The ground truth labels.
+        topk (tuple of int): The top-k accuracy values to compute. Default is (1,).
+
+    Returns:
+        list of float: A list of top-k accuracy values.
+    """
     maxk = max(topk)
     batch_size = target.size(0)

@@ -39,10 +76,35 @@


 class Cutout(object):
+    """
+    Apply cutout augmentation to an image tensor.
+
+    Args:
+        length (int): The size of the cutout square region.
+
+    """

     def __init__(self, length):
+        """
+        Initializes the Cutout object with a specified cutout length.
+
+        Args:
+            length (int): The size of the cutout square region.
+
+        """
         self.length = length

     def __call__(self, img):
+        """
+        Apply cutout augmentation to an image tensor.
+
+        Args:
+            img (torch.Tensor): The input image tensor of shape (C, H, W); the code
+                indexes ``img.size(1)`` and ``img.size(2)``, so cutout is applied
+                after ``transforms.ToTensor``.
+
+        Returns:
+            torch.Tensor: The augmented tensor with a random square region zeroed out.
+
+        """
         h, w = img.size(1), img.size(2)
         mask = np.ones((h, w), np.float32)
         y = np.random.randint(h)
@@ -61,6 +123,16 @@ def __call__(self, img):


 def _data_transforms_cifar10(args):
+    """
+    Define data transformations for CIFAR-10 dataset.
+
+    Args:
+        args (argparse.Namespace): Command line arguments.
+
+    Returns:
+        tuple: A tuple of train and validation data transforms.
+
+    """
     CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
     CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

@@ -81,10 +153,29 @@


 def count_parameters_in_MB(model):
+    """
+    Count the number of parameters in a model in megabytes (MB).
+
+    Args:
+        model (nn.Module): The model for which to count parameters.
+
+    Returns:
+        float: The number of parameters in megabytes (MB).
+
+    """
     return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6


 def save_checkpoint(state, is_best, save):
+    """
+    Save a checkpoint of the model's state.
+
+    Args:
+        state (dict): The model's state dictionary.
+        is_best (bool): True if this is the best checkpoint, False otherwise.
+        save (str): The directory where the checkpoint will be saved.
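+
+    Example (illustrative; the keys in `state` are up to the caller):
+        save_checkpoint({"epoch": 1, "state_dict": model.state_dict()}, is_best=False, save="./exp")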
+ + """ filename = os.path.join(save, 'checkpoint.pth.tar') torch.save(state, filename) if is_best: @@ -93,14 +184,41 @@ def save_checkpoint(state, is_best, save): def save(model, model_path): + """ + Save the model's state dictionary to a file. + + Args: + model (nn.Module): The PyTorch model to be saved. + model_path (str): The path to the file where the model state will be saved. + + """ torch.save(model.state_dict(), model_path) def load(model, model_path): + """ + Load a model's state dictionary from a file into the model. + + Args: + model (nn.Module): The PyTorch model to which the state will be loaded. + model_path (str): The path to the file containing the model state. + + """ model.load_state_dict(torch.load(model_path)) def drop_path(x, drop_prob): + """ + Apply dropout to a tensor. + + Args: + x (Tensor): The input tensor to which dropout will be applied. + drop_prob (float): The probability of dropping out a value. + + Returns: + Tensor: The tensor after dropout has been applied. + + """ if drop_prob > 0.: keep_prob = 1. - drop_prob mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)) @@ -110,6 +228,14 @@ def drop_path(x, drop_prob): def create_exp_dir(path, scripts_to_save=None): + """ + Create an experiment directory and optionally save scripts. + + Args: + path (str): The directory path for the experiment. + scripts_to_save (list of str, optional): List of script file paths to save in the directory. + + """ if not os.path.exists(path): os.mkdir(path) print('Experiment dir : {}'.format(path)) diff --git a/python/fedml/model/cv/darts/visualize.py b/python/fedml/model/cv/darts/visualize.py index df539289e2..79ffbf5970 100644 --- a/python/fedml/model/cv/darts/visualize.py +++ b/python/fedml/model/cv/darts/visualize.py @@ -4,6 +4,19 @@ def plot(genotype, filename): + """ + Generate a visualization of a given genotype and save it as a PDF file. + + Args: + genotype (list of tuples): The genotype to visualize, specifying operations and connections. + filename (str): The name of the PDF file to save the visualization. + + Example usage: + ```python + >>> genotype = [("conv3x3", 0), ("conv3x3", 1), ("maxpool3x3", 0), ("conv1x1", 2), ...] + >>> plot(genotype, "genotype_visualization.pdf") + ``` + """ g = Digraph( format="pdf", edge_attr=dict(fontsize="20", fontname="times"), diff --git a/python/fedml/model/cv/efficientnet_utils.py b/python/fedml/model/cv/efficientnet_utils.py index c95de26259..da000d7fa1 100644 --- a/python/fedml/model/cv/efficientnet_utils.py +++ b/python/fedml/model/cv/efficientnet_utils.py @@ -76,6 +76,15 @@ # An ordinary implementation of Swish function class Swish(nn.Module): def forward(self, x): + """ + Applies the Swish activation function to the input tensor. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the Swish activation. + """ return x * torch.sigmoid(x) @@ -83,12 +92,32 @@ def forward(self, x): class SwishImplementation(torch.autograd.Function): @staticmethod def forward(ctx, i): + """ + Forward pass for the memory-efficient Swish function. + + Args: + ctx: Context object to save tensors for backward pass. + i (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the memory-efficient Swish activation. + """ result = i * torch.sigmoid(i) ctx.save_for_backward(i) return result @staticmethod def backward(ctx, grad_output): + """ + Backward pass for the memory-efficient Swish function. 
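+
+        Uses the identity d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x))),
+        recomputed from the input saved during the forward pass.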
+ + Args: + ctx: Context object containing saved tensors from forward pass. + grad_output (torch.Tensor): Gradient of the loss with respect to the output. + + Returns: + torch.Tensor: Gradient of the loss with respect to the input tensor. + """ i = ctx.saved_tensors[0] sigmoid_i = torch.sigmoid(i) return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) @@ -96,17 +125,33 @@ def backward(ctx, grad_output): class MemoryEfficientSwish(nn.Module): def forward(self, x): + """ + Applies the memory-efficient Swish activation function to the input tensor. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the memory-efficient Swish activation. + """ + return SwishImplementation.apply(x) def round_filters(filters, global_params): - """Calculate and round number of filters based on width multiplier. - Use width_coefficient, depth_divisor and min_depth of global_params. + """ + Calculate and round the number of filters based on the width multiplier. + Args: - filters (int): Filters number to be calculated. - global_params (namedtuple): Global params of the model. + filters (int): Number of filters to be calculated. + global_params (namedtuple): Global parameters of the model. + Returns: - new_filters: New filters number after calculating. + int: New number of filters after rounding. + + Example: + # Calculate and round filters based on width multiplier and global parameters. + new_filters = round_filters(64, global_params) """ multiplier = global_params.width_coefficient if not multiplier: @@ -126,13 +171,19 @@ def round_filters(filters, global_params): def round_repeats(repeats, global_params): - """Calculate module's repeat number of a block based on depth multiplier. - Use depth_coefficient of global_params. + """ + Calculate module's repeat number of a block based on the depth multiplier. + Args: - repeats (int): num_repeat to be calculated. - global_params (namedtuple): Global params of the model. + repeats (int): Number of repeats to be calculated. + global_params (namedtuple): Global parameters of the model. + Returns: - new repeat: New repeat number after calculating. + int: New number of repeats after calculation. + + Example: + # Calculate repeats based on depth multiplier and global parameters. + new_repeats = round_repeats(5, global_params) """ multiplier = global_params.depth_coefficient if not multiplier: @@ -142,13 +193,20 @@ def round_repeats(repeats, global_params): def drop_connect(inputs, p, training): - """Drop connect. + """ + Apply drop connect to the input tensor. + Args: - input (tensor: BCWH): Input of this structure. - p (float: 0.0~1.0): Probability of drop connection. - training (bool): The running mode. + inputs (torch.Tensor): Input tensor to which drop connect will be applied. + p (float): Probability of drop connection (0.0 <= p <= 1.0). + training (bool): The running mode (True for training, False for inference). + Returns: - output: Output after drop connection. + torch.Tensor: Output tensor after applying drop connect. + + Example: + # Apply drop connect with a probability of 0.5 during training. + output = drop_connect(inputs, 0.5, training=True) """ assert 0 <= p <= 1, "p must be in range of [0,1]" @@ -170,11 +228,22 @@ def drop_connect(inputs, p, training): def get_width_and_height_from_size(x): - """Obtain height and width from x. + """ + Obtain height and width from a size value. + Args: - x (int, tuple or list): Data size. + x (int, tuple, or list): Data size. + Returns: - size: A tuple or list (H,W). 
+        tuple: A tuple (height, width).
+
+    Raises:
+        TypeError: If the input is not an int, tuple, or list.
+
+    Example:
+        # Get height and width from an integer size.
+        size = get_width_and_height_from_size(32)
+        # Result: (32, 32)
     """
     if isinstance(x, int):
         return x, x
@@ -185,13 +254,20 @@


 def calculate_output_image_size(input_image_size, stride):
-    """Calculates the output image size when using Conv2dSamePadding with a stride.
-    Necessary for static padding. Thanks to mannatsingh for pointing this out.
+    """
+    Calculate the output image size when using Conv2dSamePadding with a given stride.
+    Necessary for static padding; thanks to mannatsingh for pointing this out.
+
     Args:
-        input_image_size (int, tuple or list): Size of input image.
-        stride (int, tuple or list): Conv2d operation's stride.
+        input_image_size (int, tuple, or list): Size of the input image.
+        stride (int, tuple, or list): Conv2d operation's stride.
+
     Returns:
-        output_image_size: A list [H,W].
+        list: A list [height, width] representing the output image size.
+
+    Example:
+        # Calculate the output size for an input image of size 128x128 with a stride of 2.
+        output_size = calculate_output_image_size((128, 128), 2)
+        # Result: [64, 64]
     """
     if input_image_size is None:
         return None
@@ -209,12 +285,18 @@


 def get_same_padding_conv2d(image_size=None):
-    """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
-    Static padding is necessary for ONNX exporting of models.
+    """
+    Choose static padding if an image size is specified, and dynamic padding otherwise.
+    Static padding is necessary for ONNX exporting of models.
+
     Args:
         image_size (int or tuple): Size of the image.
+
     Returns:
-        Conv2dDynamicSamePadding or Conv2dStaticSamePadding.
+        Conv2dDynamicSamePadding or Conv2dStaticSamePadding: The appropriate Conv2d class.
+
+    Example:
+        # An image size is given here, so the statically padded Conv2d class is returned.
+        conv2d_class = get_same_padding_conv2d((128, 128))
     """
     if image_size is None:
         return Conv2dDynamicSamePadding
@@ -223,8 +305,9 @@


 class Conv2dDynamicSamePadding(nn.Conv2d):
-    """2D Convolutions like TensorFlow, for a dynamic image size.
-    The padding is operated in forward function by calculating dynamically.
+    """
+    2D Convolution with dynamic padding based on the input image size.
+    The padding is calculated dynamically during the forward pass.
     """

     # Tips for 'SAME' mode padding.
@@ -279,8 +362,23 @@ def forward(self, x):


 class Conv2dStaticSamePadding(nn.Conv2d):
-    """2D Convolutions like TensorFlow's 'SAME' mode, with the given input image size.
-    The padding mudule is calculated in construction function, then used in forward.
+    """
+    2D Convolutions with static padding similar to TensorFlow's 'SAME' mode,
+    using the provided input image size for padding calculation.
+
+    This module calculates the padding during construction and applies it during the forward pass.
+
+    Args:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        kernel_size (int or tuple): Size of the convolutional kernel.
+        stride (int or tuple, optional): Stride of the convolution. Default is 1.
+        image_size (int or tuple, optional): Size of the input image. Must be provided for padding calculation.
+        **kwargs: Additional arguments for nn.Conv2d.
+
+    Example:
+        # Create a Conv2dStaticSamePadding layer with an input image size of 128x128.
+ conv_layer = Conv2dStaticSamePadding(in_channels=3, out_channels=64, kernel_size=3, image_size=(128, 128)) """ # With the same calculation as Conv2dDynamicSamePadding @@ -327,12 +425,15 @@ def forward(self, x): def get_same_padding_maxPool2d(image_size=None): - """Chooses static padding if you have specified an image size, and dynamic padding otherwise. - Static padding is necessary for ONNX exporting of models. + """ + Chooses static padding if you have specified an image size, and dynamic padding otherwise. + Static padding is necessary for ONNX exporting of models. + Args: - image_size (int or tuple): Size of the image. + image_size (int or tuple, optional): Size of the image. If provided, static padding will be used. + Returns: - MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding. + MaxPool2dDynamicSamePadding or MaxPool2dStaticSamePadding: A MaxPooling layer with the chosen padding. """ if image_size is None: return MaxPool2dDynamicSamePadding @@ -341,8 +442,21 @@ def get_same_padding_maxPool2d(image_size=None): class MaxPool2dDynamicSamePadding(nn.MaxPool2d): - """2D MaxPooling like TensorFlow's 'SAME' mode, with a dynamic image size. - The padding is operated in forward function by calculating dynamically. + """ + 2D MaxPooling with dynamic padding, similar to TensorFlow's 'SAME' mode, for a dynamic image size. + The padding is calculated dynamically during the forward pass. + + Args: + kernel_size (int or tuple): Size of the max-pooling kernel. + stride (int or tuple): Stride of the max-pooling operation. + padding (int or tuple, optional): Padding to be added. Default is 0. + dilation (int or tuple, optional): Dilation rate. Default is 1. + return_indices (bool, optional): Whether to return the indices. Default is False. + ceil_mode (bool, optional): Whether to use 'ceil' mode for output size. Default is False. + + Example: + # Create a MaxPool2dDynamicSamePadding layer. + maxpool_layer = MaxPool2dDynamicSamePadding(kernel_size=3, stride=2) """ def __init__( @@ -390,8 +504,19 @@ def forward(self, x): class MaxPool2dStaticSamePadding(nn.MaxPool2d): - """2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size. - The padding mudule is calculated in construction function, then used in forward. + """ + 2D MaxPooling like TensorFlow's 'SAME' mode, with the given input image size. + The padding module is calculated during construction and then applied in the forward pass. + + Args: + kernel_size (int or tuple): Size of the max-pooling kernel. + stride (int or tuple): Stride of the max-pooling operation. + image_size (int or tuple): Size of the input image. Required to calculate static padding. + **kwargs: Additional keyword arguments for MaxPool2d. + + Example: + # Create a MaxPool2dStaticSamePadding layer with a specified image size. + maxpool_layer = MaxPool2dStaticSamePadding(kernel_size=3, stride=2, image_size=(224, 224)) """ def __init__(self, kernel_size, stride, image_size=None, **kwargs): @@ -448,16 +573,22 @@ def forward(self, x): class BlockDecoder(object): - """Block Decoder for readability, - straight from the official TensorFlow repository. + """ + Block Decoder for readability, straight from the official TensorFlow repository. + + This class provides methods to decode and encode block configurations represented as strings. + These strings define the arguments of each block in a neural network architecture. """ @staticmethod def _decode_block_string(block_string): - """Get a block through a string notation of arguments. 
+ """ + Get a block through a string notation of arguments. + Args: block_string (str): A string notation of arguments. Examples: 'r1_k3_s11_e1_i32_o16_se0.25_noskip'. + Returns: BlockArgs: The namedtuple defined at the top of this file. """ @@ -489,9 +620,12 @@ def _decode_block_string(block_string): @staticmethod def _encode_block_string(block): - """Encode a block to a string. + """ + Encode a block to a string. + Args: block (namedtuple): A BlockArgs type argument. + Returns: block_string: A String form of BlockArgs. """ @@ -511,9 +645,12 @@ def _encode_block_string(block): @staticmethod def decode(string_list): - """Decode a list of string notations to specify blocks inside the network. + """ + Decode a list of string notations to specify blocks inside the network. + Args: string_list (list[str]): A list of strings, each string is a notation of block. + Returns: blocks_args: A list of BlockArgs namedtuples of block args. """ @@ -525,12 +662,16 @@ def decode(string_list): @staticmethod def encode(blocks_args): - """Encode a list of BlockArgs to a list of strings. + """ + Encode a list of BlockArgs to a list of strings. + Args: blocks_args (list[namedtuples]): A list of BlockArgs namedtuples of block args. + Returns: block_strings: A list of strings, each string is a notation of block. """ + block_strings = [] for block in blocks_args: block_strings.append(BlockDecoder._encode_block_string(block)) @@ -538,9 +679,12 @@ def encode(blocks_args): def efficientnet_params(model_name): - """Map EfficientNet model name to parameter coefficients. + """ + Map EfficientNet model name to parameter coefficients. + Args: model_name (str): Model name to be queried. + Returns: params_dict[model_name]: A (width,depth,res,dropout) tuple. """ @@ -569,7 +713,9 @@ def efficientnet( num_classes=1000, include_top=True, ): - """Create BlockArgs and GlobalParams for efficientnet model. + """ + Create BlockArgs and GlobalParams for the EfficientNet model. + Args: width_coefficient (float) depth_coefficient (float) @@ -577,7 +723,8 @@ def efficientnet( dropout_rate (float) drop_connect_rate (float) num_classes (int) - Meaning as the name suggests. + include_top (bool) + Returns: blocks_args, global_params. """ @@ -613,10 +760,13 @@ def efficientnet( def get_model_params(model_name, override_params): - """Get the block args and global params for a given model name. + """ + Get the block args and global params for a given model name. + Args: model_name (str): Model's name. override_params (dict): A dict to modify global_params. + Returns: blocks_args, global_params """ @@ -669,16 +819,17 @@ def get_model_params(model_name, override_params): def load_pretrained_weights( model, model_name, weights_path=None, load_fc=True, advprop=False ): - """Loads pretrained weights from weights path or download using url. + """ + Loads pretrained weights from weights path or download using URL. + Args: - model (Module): The whole model of efficientnet. - model_name (str): Model name of efficientnet. - weights_path (None or str): - str: path to pretrained weights file on the local disk. - None: use pretrained weights downloaded from the Internet. - load_fc (bool): Whether to load pretrained weights for fc layer at the end of the model. - advprop (bool): Whether to load pretrained weights - trained with advprop (valid when weights_path is None). + model (Module): The whole model of EfficientNet. + model_name (str): Model name of EfficientNet. + weights_path (str or None): + - str: Path to pretrained weights file on the local disk. 
+ - None: Use pretrained weights downloaded from the Internet. + load_fc (bool): Whether to load pretrained weights for the fully connected (fc) layer at the end of the model. + advprop (bool): Whether to load pretrained weights trained with advprop (valid when weights_path is None). """ if isinstance(weights_path, str): state_dict = torch.load(weights_path) diff --git a/python/fedml/model/cv/group_normalization.py b/python/fedml/model/cv/group_normalization.py index 0081e37e3f..5444fcdafe 100644 --- a/python/fedml/model/cv/group_normalization.py +++ b/python/fedml/model/cv/group_normalization.py @@ -15,11 +15,39 @@ def group_norm( momentum=0.1, eps=1e-5, ): - """Applies Group Normalization for channels in the same group in each data sample in a - batch. - See :class:`~torch.nn.GroupNorm1d`, :class:`~torch.nn.GroupNorm2d`, - :class:`~torch.nn.GroupNorm3d` for details. """ + Applies Group Normalization for channels in the same group in each data sample in a batch. + + Args: + input (Tensor): The input tensor of shape (N, C, *), where N is the batch size, + C is the number of channels, and * represents any number of additional dimensions. + group (int): The number of groups to divide the channels into. + running_mean (Tensor or None): A tensor of running means for each group, typically + from previous batches. Set to None if `use_input_stats` is True. + running_var (Tensor or None): A tensor of running variances for each group, typically + from previous batches. Set to None if `use_input_stats` is True. + weight (Tensor or None): A tensor to scale the normalized values for each channel. + bias (Tensor or None): A tensor to add an offset to the normalized values for each channel. + use_input_stats (bool): If True, batch statistics (mean and variance) are computed + from the input tensor for normalization. If False, `running_mean` and `running_var` + are used for normalization. + momentum (float): The momentum factor for updating running statistics. + eps (float): A small value added to the denominator for numerical stability. + + Returns: + Tensor: The normalized output tensor with the same shape as the input. + + Note: + Group Normalization is applied to the channels of the input tensor separately within each group. + If `use_input_stats` is True, running statistics (mean and variance) will not be used for + normalization, and batch statistics will be computed from the input tensor. + + See Also: + - :class:`~torch.nn.GroupNorm1d` for 1D input (sequence data). + - :class:`~torch.nn.GroupNorm2d` for 2D input (image data). + - :class:`~torch.nn.GroupNorm3d` for 3D input (volumetric data). + """ + if not use_input_stats and (running_mean is None or running_var is None): raise ValueError( "Expected running_mean and running_var to be not None when use_input_stats=False" @@ -42,6 +70,38 @@ def _instance_norm( momentum=None, eps=None, ): + """ + Applies Instance Normalization for channels within each group in the input tensor. + + Args: + input (Tensor): The input tensor of shape (N, C, *), where N is the batch size, + C is the number of channels, and * represents any number of additional dimensions. + group (int): The number of groups to divide the channels into. + running_mean (Tensor or None): A tensor of running means for each group, typically + from previous batches. Set to None if `use_input_stats` is True. + running_var (Tensor or None): A tensor of running variances for each group, typically + from previous batches. Set to None if `use_input_stats` is True. 
+ weight (Tensor or None): A tensor to scale the normalized values for each channel. + bias (Tensor or None): A tensor to add an offset to the normalized values for each channel. + use_input_stats (bool or None): If True, batch statistics (mean and variance) are computed + from the input tensor for normalization. If False, `running_mean` and `running_var` + are used for normalization. If None, it defaults to True during training and False during inference. + momentum (float): The momentum factor for updating running statistics. + eps (float): A small value added to the denominator for numerical stability. + + Returns: + Tensor: The normalized output tensor with the same shape as the input. + + Note: + Instance Normalization is applied to the channels of the input tensor separately within each group. + If `use_input_stats` is True, running statistics (mean and variance) will not be used for + normalization, and batch statistics will be computed from the input tensor. + + See Also: + - :class:`~torch.nn.InstanceNorm1d` for 1D input (sequence data). + - :class:`~torch.nn.InstanceNorm2d` for 2D input (image data). + - :class:`~torch.nn.InstanceNorm3d` for 3D input (volumetric data). + """ # Repeat stored stats and affine transform params if necessary if running_mean is not None: running_mean_orig = running_mean @@ -94,6 +154,36 @@ def _instance_norm( class _GroupNorm(_BatchNorm): + """ + Applies Group Normalization over a mini-batch of inputs. + + Group Normalization divides the channels into groups and computes statistics + (mean and variance) separately for each group, normalizing each group independently. + It can be used as a normalization layer in various neural network architectures. + + Args: + num_features (int): Number of channels in the input tensor. + num_groups (int): Number of groups to divide the channels into. + eps (float): A small value added to the denominator for numerical stability. + momentum (float): The momentum factor for updating running statistics. + affine (bool): If True, learnable affine parameters (weight and bias) are applied to + the normalized output. Default is False. + track_running_stats (bool): If True, running statistics (mean and variance) are tracked + during training. Default is False. + + Attributes: + num_groups (int): Number of groups the channels are divided into. + track_running_stats (bool): If True, running statistics (mean and variance) are tracked + during training. + + Note: + The input tensor should have shape (N, C, *), where N is the batch size, C is the + number of channels, and * represents any number of additional dimensions. + + See Also: + - :class:`~torch.nn.GroupNorm` for a user-friendly interface. + - :class:`~torch.nn.BatchNorm2d` for standard Batch Normalization. + """ def __init__( self, num_features, @@ -129,25 +219,27 @@ def forward(self, input): class GroupNorm2d(_GroupNorm): - r"""Applies Group Normalization over a 4D input (a mini-batch of 2D inputs - with additional channel dimension) as described in the paper - https://arxiv.org/pdf/1803.08494.pdf - `Group Normalization`_ . + """Applies Group Normalization over a 4D input (a mini-batch of 2D inputs + with an additional channel dimension) as described in the paper + "Group Normalization" (https://arxiv.org/pdf/1803.08494.pdf). + Args: - num_features: :math:`C` from an expected input of size - :math:`(N, C, H, W)` - num_groups: - eps: a value added to the denominator for numerical stability. 
Default: 1e-5 - momentum: the value used for the running_mean and running_var computation. Default: 0.1 - affine: a boolean value that when set to ``True``, this module has - learnable affine parameters. Default: ``True`` - track_running_stats: a boolean value that when set to ``True``, this - module tracks the running mean and variance, and when set to ``False``, - this module does not track such statistics and always uses batch - statistics in both training and eval modes. Default: ``False`` + num_features (int): Number of channels in the input tensor. + num_groups (int): Number of groups to divide the channels into. + eps (float): A small value added to the denominator for numerical stability. + Default: 1e-5. + momentum (float): The value used for computing running statistics (mean and variance). + Default: 0.1. + affine (bool): If True, learnable affine parameters (weight and bias) are applied to + the normalized output. Default: True. + track_running_stats (bool): If True, this module tracks running statistics + (mean and variance) during training. If False, it uses batch statistics in both + training and evaluation modes. Default: False. + Shape: - - Input: :math:`(N, C, H, W)` - - Output: :math:`(N, C, H, W)` (same shape as input) + - Input: (N, C, H, W) + - Output: (N, C, H, W) (same shape as input) + Examples: >>> # Without Learnable Parameters >>> m = GroupNorm2d(100, 4) @@ -155,8 +247,17 @@ class GroupNorm2d(_GroupNorm): >>> m = GroupNorm2d(100, 4, affine=True) >>> input = torch.randn(20, 100, 35, 45) >>> output = m(input) + + Note: + The input tensor should have shape (N, C, H, W), where N is the batch size, + C is the number of channels, H is the height, and W is the width. + + See Also: + - :class:`~torch.nn.GroupNorm` for a user-friendly interface. + - :class:`~torch.nn.BatchNorm2d` for standard Batch Normalization for 2D data. """ + def _check_input_dim(self, input): if input.dim() != 4: raise ValueError("expected 4D input (got {}D input)".format(input.dim())) @@ -164,7 +265,35 @@ def _check_input_dim(self, input): class GroupNorm3d(_GroupNorm): """ - Assume the data format is (B, C, D, H, W) + Applies 3D Group Normalization over a mini-batch of 3D inputs. + + Group Normalization divides the channels into groups and computes statistics + (mean and variance) separately for each group, normalizing each group independently. + It is designed for 3D data with the format (B, C, D, H, W), where B is the batch size, + C is the number of channels, D is the depth, H is the height, and W is the width. + + Args: + num_features (int): Number of channels in the input tensor. + num_groups (int): Number of groups to divide the channels into. + eps (float): A small value added to the denominator for numerical stability. + momentum (float): The momentum factor for updating running statistics. + affine (bool): If True, learnable affine parameters (weight and bias) are applied to + the normalized output. Default is False. + track_running_stats (bool): If True, running statistics (mean and variance) are tracked + during training. Default is False. + + Attributes: + num_groups (int): Number of groups the channels are divided into. + track_running_stats (bool): If True, running statistics (mean and variance) are tracked + during training. + + Note: + The input tensor should have shape (N, C, D, H, W), where N is the batch size, C is the + number of channels, D is the depth, H is the height, and W is the width. + + See Also: + - :class:`~torch.nn.GroupNorm` for a user-friendly interface. 
+        - :class:`~torch.nn.BatchNorm3d` for standard Batch Normalization for 3D data.
+    """

     def _check_input_dim(self, input):
diff --git a/python/fedml/model/cv/resnet56/resnet_client.py b/python/fedml/model/cv/resnet56/resnet_client.py
index 7e26488005..37d5dbe311 100644
--- a/python/fedml/model/cv/resnet56/resnet_client.py
+++ b/python/fedml/model/cv/resnet56/resnet_client.py
@@ -16,7 +16,19 @@


 def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
-    """3x3 convolution with padding"""
+    """
+    3x3 convolution with padding.
+
+    Args:
+        in_planes (int): Number of input channels.
+        out_planes (int): Number of output channels.
+        stride (int): Stride for the convolution operation.
+        groups (int): Number of groups for grouped convolution.
+        dilation (int): Dilation factor for the convolution operation.
+
+    Returns:
+        nn.Conv2d: A 3x3 convolutional layer.
+    """
     return nn.Conv2d(
         in_planes,
         out_planes,
@@ -30,11 +42,41 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):


 def conv1x1(in_planes, out_planes, stride=1):
-    """1x1 convolution"""
+    """
+    1x1 convolution.
+
+    Args:
+        in_planes (int): Number of input channels.
+        out_planes (int): Number of output channels.
+        stride (int): Stride for the convolution operation.
+
+    Returns:
+        nn.Conv2d: A 1x1 convolutional layer.
+    """
     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


 class BasicBlock(nn.Module):
+    """
+    Basic building block for ResNet.
+
+    Args:
+        inplanes (int): Number of input channels.
+        planes (int): Number of output channels.
+        stride (int, optional): Stride for the convolutional layers. Default is 1.
+        downsample (nn.Module, optional): Downsample layer for shortcut connection. Default is None.
+        groups (int, optional): Number of groups for grouped convolution. Default is 1.
+        base_width (int, optional): Width of each group. Default is 64.
+        dilation (int, optional): Dilation factor for convolution. Default is 1.
+        norm_layer (nn.Module, optional): Normalization layer. Default is None.
+
+    Attributes:
+        expansion (int): Expansion factor for the block.
+
+    Example:
+        # Keep the channel count so the shortcut matches; changing the channel count
+        # or stride also requires passing a matching `downsample` module.
+        block = BasicBlock(64, 64)
+    """
+
     expansion = 1

     def __init__(
@@ -65,6 +107,15 @@ def __init__(
         self.stride = stride

     def forward(self, x):
+        """
+        Forward pass through the BasicBlock.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Output tensor.
+        """
         identity = x

         out = self.conv1(x)
@@ -84,6 +135,26 @@ def forward(self, x):


 class Bottleneck(nn.Module):
+    """
+    Bottleneck building block for ResNet.
+
+    Args:
+        inplanes (int): Number of input channels.
+        planes (int): Base number of channels; the block outputs planes * expansion channels.
+        stride (int, optional): Stride for the convolutional layers. Default is 1.
+        downsample (nn.Module, optional): Downsample layer for shortcut connection. Default is None.
+        groups (int, optional): Number of groups for grouped convolution. Default is 1.
+        base_width (int, optional): Width of each group. Default is 64.
+        dilation (int, optional): Dilation factor for convolution. Default is 1.
+        norm_layer (nn.Module, optional): Normalization layer. Default is None.
+
+    Attributes:
+        expansion (int): Expansion factor for the block.
+
+    Example:
+        # With expansion = 4, planes=64 gives 64 * 4 = 256 output channels, matching
+        # the 256-channel input, so no `downsample` module is needed.
+        block = Bottleneck(256, 64)
+    """
+
     expansion = 4

     def __init__(
@@ -113,6 +184,15 @@ def __init__(
         self.stride = stride

     def forward(self, x):
+        """
+        Forward pass through the Bottleneck.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Output tensor.
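+
+        Example (illustrative; planes=16 keeps the shortcut aligned, since 16 * expansion = 64):
+            >>> block = Bottleneck(64, 16)
+            >>> out = block(torch.randn(1, 64, 32, 32))  # shape preserved: (1, 64, 32, 32)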
+        """
         identity = x

         out = self.conv1(x)
@@ -135,6 +215,7 @@ def forward(self, x):

         return out

+
 class ResNet(nn.Module):
     def __init__(
         self,
@@ -148,6 +229,29 @@ def __init__(
         norm_layer=None,
         KD=False,
     ):
+        """
+        ResNet model architecture.
+
+        Args:
+            block (nn.Module): The block type to use for constructing layers (e.g., BasicBlock or Bottleneck).
+            layers (list of int): List specifying the number of blocks in each layer.
+            num_classes (int, optional): Number of output classes. Default is 10.
+            zero_init_residual (bool, optional): Whether to initialize the last BN in each residual branch to zero. Default is False.
+            groups (int, optional): Number of groups for grouped convolution. Default is 1.
+            width_per_group (int, optional): Width of each group. Default is 64.
+            replace_stride_with_dilation (list of bool, optional): List indicating if stride should be replaced with dilation. Default is None.
+            norm_layer (nn.Module, optional): Normalization layer. Default is None.
+            KD (bool, optional): Knowledge distillation flag. Default is False.
+
+        Example:
+            # This CIFAR-style ResNet stacks three stages, so `layers` takes three entries;
+            # Bottleneck blocks with [6, 6, 6] give the ResNet-56 configuration.
+            model = ResNet(Bottleneck, [6, 6, 6], num_classes=10)
+        """
         super(ResNet, self).__init__()
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
@@ -242,6 +346,15 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
         return nn.Sequential(*layers)

     def forward(self, x):
+        """
+        Forward pass through the ResNet model.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            tuple: The output logits and the extracted features.
+        """
         x = self.conv1(x)
         x = self.bn1(x)
         x = self.relu(x)  # B x 16 x 32 x 32
@@ -260,12 +373,22 @@

 def resnet5_56(c, pretrained=False, path=None, **kwargs):
     """
-    Constructs a ResNet-32 model.
+    Constructs a ResNet-5-56 model.

     Args:
+        c (int): Number of output classes.
         pretrained (bool): If True, returns a model pre-trained.
-    """
+        path (str, optional): Path to a pre-trained checkpoint. Default is None.
+        **kwargs: Additional keyword arguments to pass to the ResNet constructor.
+
+    Returns:
+        nn.Module: A ResNet-5-56 model.

+    Example:
+        # Create a ResNet-5-56 model with 10 output classes.
+        model = resnet5_56(10)
+    """
+
     model = ResNet(BasicBlock, [1, 2, 2], num_classes=c, **kwargs)
     if pretrained:
         checkpoint = torch.load(path)
@@ -285,10 +408,20 @@ def resnet5_56(c, pretrained=False, path=None, **kwargs):

 def resnet8_56(c, pretrained=False, path=None, **kwargs):
     """
-    Constructs a ResNet-32 model.
+    Constructs a ResNet-8-56 model.

     Args:
+        c (int): Number of output classes.
         pretrained (bool): If True, returns a model pre-trained.
+        path (str, optional): Path to a pre-trained checkpoint. Default is None.
+        **kwargs: Additional keyword arguments to pass to the ResNet constructor.
+
+    Returns:
+        nn.Module: A ResNet-8-56 model.
+
+    Example:
+        # Create a ResNet-8-56 model with 10 output classes.
+        model = resnet8_56(10)
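+        # Hypothetical smoke test with a CIFAR-sized input; per forward(), the model
+        # yields class logits together with the pre-classifier features:
+        out = model(torch.randn(1, 3, 32, 32))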
+    """

     model = ResNet(Bottleneck, [2, 2, 2], num_classes=c, **kwargs)
     if pretrained:
         checkpoint = torch.load(path)
diff --git a/python/fedml/model/cv/resnet56/resnet_pretrained.py b/python/fedml/model/cv/resnet56/resnet_pretrained.py
index b1c6d93666..356db9fa5b 100644
--- a/python/fedml/model/cv/resnet56/resnet_pretrained.py
+++ b/python/fedml/model/cv/resnet56/resnet_pretrained.py
@@ -15,7 +15,23 @@


 def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
-    """3x3 convolution with padding"""
+    """
+    Create a 3x3 convolution layer with padding.
+
+    Args:
+        in_planes (int): Number of input channels.
+        out_planes (int): Number of output channels.
+        stride (int, optional): Stride for the convolution. Default is 1.
+        groups (int, optional): Number of groups for grouped convolution. Default is 1.
+        dilation (int, optional): Dilation rate for the convolution. Default is 1.
+
+    Returns:
+        nn.Conv2d: A 3x3 convolution layer.
+
+    Example:
+        # Create a 3x3 convolution layer with 64 input channels and 128 output channels.
+        conv_layer = conv3x3(64, 128)
+    """
     return nn.Conv2d(
         in_planes,
         out_planes,
@@ -29,11 +45,45 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):


 def conv1x1(in_planes, out_planes, stride=1):
-    """1x1 convolution"""
+    """
+    Create a 1x1 convolution layer.
+
+    Args:
+        in_planes (int): Number of input channels.
+        out_planes (int): Number of output channels.
+        stride (int, optional): Stride for the convolution. Default is 1.
+
+    Returns:
+        nn.Conv2d: A 1x1 convolution layer.
+
+    Example:
+        # Create a 1x1 convolution layer with 64 input channels and 128 output channels.
+        conv_layer = conv1x1(64, 128)
+    """
     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


 class BasicBlock(nn.Module):
+    """
+    Basic building block for ResNet.
+
+    Args:
+        inplanes (int): Number of input channels.
+        planes (int): Number of output channels.
+        stride (int, optional): Stride for the convolution. Default is 1.
+        downsample (nn.Module, optional): Downsample layer. Default is None.
+        groups (int, optional): Number of groups for grouped convolution. Default is 1.
+        base_width (int, optional): Base width for grouped convolution. Default is 64.
+        dilation (int, optional): Dilation rate for the convolution. Default is 1.
+        norm_layer (nn.Module, optional): Normalization layer. Default is None.
+
+    Attributes:
+        expansion (int): The expansion factor of the block.
+
+    Example:
+        # Keep the channel count so the shortcut matches; changing it also requires
+        # a matching `downsample` module.
+        block = BasicBlock(64, 64)
+    """
     expansion = 1

     def __init__(
@@ -83,6 +133,26 @@ def forward(self, x):


 class Bottleneck(nn.Module):
+    """
+    Bottleneck building block for ResNet.
+
+    Args:
+        inplanes (int): Number of input channels.
+        planes (int): Base number of channels; the block outputs planes * expansion channels.
+        stride (int, optional): Stride for the convolution. Default is 1.
+        downsample (nn.Module, optional): Downsample layer. Default is None.
+        groups (int, optional): Number of groups for grouped convolution. Default is 1.
+        base_width (int, optional): Base width for grouped convolution. Default is 64.
+        dilation (int, optional): Dilation rate for the convolution. Default is 1.
+        norm_layer (nn.Module, optional): Normalization layer. Default is None.
+
+    Attributes:
+        expansion (int): The expansion factor of the block (default is 4).
+
+    Example:
+        # With 64 input channels, planes=16 yields 16 * expansion = 64 output channels,
+        # so the shortcut matches without a `downsample` module.
+        block = Bottleneck(64, 16)
+    """
     expansion = 4

     def __init__(
@@ -135,6 +205,29 @@ def forward(self, x):


 class ResNet(nn.Module):
+    """
+    ResNet model architecture for image classification.
+
+    Args:
+        block (nn.Module): The building block for the network (e.g., BasicBlock or Bottleneck).
+        layers (list): List of integers specifying the number of blocks in each layer.
+        num_classes (int, optional): Number of classes for classification. Default is 10.
+        zero_init_residual (bool, optional): If True, zero-initialize the last BN in each residual branch.
+            Default is False.
+        groups (int, optional): Number of groups for grouped convolution. Default is 1.
+        width_per_group (int, optional): Base width for grouped convolution. Default is 64.
+        replace_stride_with_dilation (list or None, optional): List of booleans specifying if the 2x2 stride
+            should be replaced with dilated convolution in each layer. Default is None.
+        norm_layer (nn.Module, optional): Normalization layer. Default is None.
+        KD (bool, optional): Knowledge distillation flag. Default is False.
+
+    Example:
+        # Create a ResNet-56 model with 10 output classes.
+        model = ResNet(Bottleneck, [6, 6, 6], num_classes=10)
+    """
     def __init__(
         self,
         block,
@@ -200,6 +293,23 @@ def __init__(
                     nn.init.constant_(m.bn2.weight, 0)

     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        """
+        Create a layer of blocks for the ResNet model.
+
+        Args:
+            block (nn.Module): The building block for the layer (e.g., BasicBlock or Bottleneck).
+            planes (int): The number of output channels for the layer.
+            blocks (int): The number of blocks to stack in the layer.
+            stride (int, optional): The stride for the layer's convolutional operations. Default is 1.
+            dilate (bool, optional): If True, apply dilated convolutions in the layer. Default is False.
+
+        Returns:
+            nn.Sequential: A sequential container of blocks representing the layer.
+
+        Example:
+            # Create a layer of 2 Bottleneck blocks with 64 output channels and stride 1.
+            layer = self._make_layer(Bottleneck, 64, 2, stride=1)
+        """
         norm_layer = self._norm_layer
         downsample = None
         previous_dilation = self.dilation
@@ -241,6 +351,17 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
         return nn.Sequential(*layers)

     def forward(self, x):
+        """
+        Forward pass of the ResNet model.
+
+        Args:
+            x (torch.Tensor): Input tensor of shape (B, C, H, W), where B is the batch size, C is the number of
+                channels, H is the height, and W is the width.
+
+        Returns:
+            torch.Tensor: The output tensor of shape (B, num_classes) representing class logits.
+            torch.Tensor: Extracted features before the classification layer, of shape (B, C, H, W).
+        """

         x = self.conv1(x)
         x = self.bn1(x)
@@ -260,10 +381,20 @@

 def resnet32_pretrained(c, pretrained=False, path=None, **kwargs):
     """
-    Constructs a ResNet-32 model.
+    Constructs a pre-trained ResNet-32 model.

     Args:
-        pretrained (bool): If True, returns a model pre-trained.
+        c (int): The number of output classes.
+        pretrained (bool): If True, returns a model pre-trained on a given path.
+        path (str, optional): The path to the pre-trained model checkpoint. Default is None.
+        **kwargs: Additional keyword arguments to pass to the ResNet model.
+
+    Returns:
+        nn.Module: A pre-trained ResNet-32 model.
+
+    Example:
+        # Create a pre-trained ResNet-32 model with 10 output classes.
+        model = resnet32_pretrained(10, pretrained=True, path='pretrained_resnet32.pth')
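+        # Hypothetical smoke test; per the forward() docstring, the model returns
+        # class logits along with the features extracted before the classifier:
+        out = model(torch.randn(1, 3, 32, 32))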
     """

     model = ResNet(BasicBlock, [5, 5, 5], num_classes=c, **kwargs)
     if pretrained:
         checkpoint = torch.load(path)
@@ -285,10 +416,20 @@ def resnet32_pretrained(c, pretrained=False, path=None, **kwargs):

 def resnet56_pretrained(c, pretrained=False, path=None, **kwargs):
     """
-    Constructs a ResNet-110 model.
+    Constructs a pre-trained ResNet-56 model.

     Args:
-        pretrained (bool): If True, returns a model pre-trained.
+        c (int): The number of output classes.
+        pretrained (bool): If True, returns a model pre-trained on a given path.
+        path (str, optional): The path to the pre-trained model checkpoint. Default is None.
+        **kwargs: Additional keyword arguments to pass to the ResNet model.
+
+    Returns:
+        nn.Module: A pre-trained ResNet-56 model.
+
+    Example:
+        # Create a pre-trained ResNet-56 model with 10 output classes.
+        model = resnet56_pretrained(10, pretrained=True, path='pretrained_resnet56.pth')
     """
     logging.info("path = " + str(path))
     model = ResNet(Bottleneck, [6, 6, 6], num_classes=c, **kwargs)
diff --git a/python/fedml/model/cv/resnet56/resnet_server.py b/python/fedml/model/cv/resnet56/resnet_server.py
index a481461b1a..7ca1bf738c 100644
--- a/python/fedml/model/cv/resnet56/resnet_server.py
+++ b/python/fedml/model/cv/resnet56/resnet_server.py
@@ -17,7 +17,24 @@


 def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
-    """3x3 convolution with padding"""
+    """
+    3x3 convolution with padding.
+
+    Args:
+        in_planes (int): Number of input channels.
+        out_planes (int): Number of output channels.
+        stride (int, optional): Stride for the convolution. Default is 1.
+        groups (int, optional): Number of groups for grouped convolution. Default is 1.
+        dilation (int, optional): Dilation rate for the convolution. Default is 1.
+
+    Returns:
+        nn.Conv2d: A 3x3 convolutional layer.
+
+    Example:
+        # Create a 3x3 convolution with 64 input channels, 128 output channels, and a stride of 2.
+        conv_layer = conv3x3(64, 128, stride=2)
+    """
+
     return nn.Conv2d(
         in_planes,
         out_planes,
@@ -31,11 +48,47 @@ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):


 def conv1x1(in_planes, out_planes, stride=1):
-    """1x1 convolution"""
+    """
+    1x1 convolution.
+
+    Args:
+        in_planes (int): Number of input channels.
+        out_planes (int): Number of output channels.
+        stride (int, optional): Stride for the convolution. Default is 1.
+
+    Returns:
+        nn.Conv2d: A 1x1 convolutional layer.
+
+    Example:
+        # Create a 1x1 convolution with 64 input channels and 128 output channels.
+        conv_layer = conv1x1(64, 128)
+    """
+
     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


 class BasicBlock(nn.Module):
+    """
+    Basic building block for a ResNet.
+
+    Args:
+        inplanes (int): Number of input channels.
+        planes (int): Number of output channels.
+        stride (int, optional): Stride for the convolution. Default is 1.
+        downsample (nn.Module, optional): Downsample layer for shortcut connection. Default is None.
+        groups (int, optional): Number of groups for grouped convolution. Default is 1.
+        base_width (int, optional): Base width for width calculation. Default is 64.
+        dilation (int, optional): Dilation rate for the convolution. Default is 1.
+        norm_layer (nn.Module, optional): Normalization layer. Default is None.
+
+    Attributes:
+        expansion (int): Expansion factor for the number of output channels.
+
+    Example:
+        # Keep the channel count so the shortcut matches; changing it also requires
+        # a matching `downsample` module.
+        block = BasicBlock(64, 64)
+    """
+
     expansion = 1

     def __init__(
@@ -58,6 +111,19 @@ def __init__(
         self.stride = stride

     def forward(self, x):
+        """
+        Forward pass of the BasicBlock.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Output tensor.
+
+        Example:
+            # Forward pass through a BasicBlock.
+            output = block(input_tensor)
+        """
         identity = x

         out = self.conv1(x)
@@ -77,6 +143,27 @@ def forward(self, x):


 class Bottleneck(nn.Module):
+    """
+    Bottleneck building block for a ResNet.
+
+    Args:
+        inplanes (int): Number of input channels.
+        planes (int): Base number of channels; the block outputs planes * expansion channels.
+        stride (int, optional): Stride for the convolution. Default is 1.
+        downsample (nn.Module, optional): Downsample layer for shortcut connection. Default is None.
+        groups (int, optional): Number of groups for grouped convolution. Default is 1.
+        base_width (int, optional): Base width for width calculation. Default is 64.
+        dilation (int, optional): Dilation rate for the convolution. Default is 1.
+        norm_layer (nn.Module, optional): Normalization layer. Default is None.
+
+    Attributes:
+        expansion (int): Expansion factor for the number of output channels.
+
+    Example:
+        # With 64 input channels, planes=16 yields 16 * expansion = 64 output channels,
+        # so the shortcut matches without a `downsample` module.
+        bottleneck = Bottleneck(64, 16)
+    """

     expansion = 4

     def __init__(
@@ -98,6 +185,19 @@ def __init__(
         self.stride = stride

     def forward(self, x):
+        """
+        Forward pass of the Bottleneck.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Output tensor.
+
+        Example:
+            # Forward pass through a Bottleneck block.
+            output = bottleneck(input_tensor)
+        """
         identity = x

         out = self.conv1(x)
@@ -133,6 +233,42 @@ def __init__(
         norm_layer=None,
         KD=False,
     ):
+        """
+        ResNet model implementation.
+
+        Args:
+            block (nn.Module): The type of block to use in the network (e.g., BasicBlock or Bottleneck).
+            layers (list of int): Number of blocks in each layer of the network.
+            num_classes (int): Number of output classes. Default is 10.
+            zero_init_residual (bool): Whether to zero-init the last BN in each residual branch. Default is False.
+            groups (int): Number of groups for grouped convolution. Default is 1.
+            width_per_group (int): Number of channels per group for grouped convolution. Default is 64.
+            replace_stride_with_dilation (list of bool): List indicating whether to replace 2x2 stride with dilation.
+            norm_layer (nn.Module): Normalization layer. Default is None.
+            KD (bool): Whether to enable knowledge distillation. Default is False.
+
+        Attributes:
+            block (nn.Module): The type of block used in the network.
+            layers (list of int): Number of blocks in each layer of the network.
+            num_classes (int): Number of output classes.
+            zero_init_residual (bool): Whether to zero-init the last BN in each residual branch.
+            groups (int): Number of groups for grouped convolution.
+            base_width (int): Base width for width calculation.
+            dilation (int): Dilation rate for the convolution.
+            conv1 (nn.Conv2d): The initial convolutional layer.
+            bn1 (nn.BatchNorm2d): Batch normalization layer after the initial convolution.
+            relu (nn.ReLU): ReLU activation function.
+            layer1 (nn.Sequential): The first layer of the network.
+            layer2 (nn.Sequential): The second layer of the network.
+            layer3 (nn.Sequential): The third layer of the network.
+            avgpool (nn.AdaptiveAvgPool2d): Adaptive average pooling layer.
+            fc (nn.Linear): Fully connected layer for classification.
+            KD (bool): Whether knowledge distillation is enabled.
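+
+        Note:
+            In this server-side variant the stem (conv1/bn1/relu) is commented out in
+            forward(), so inputs are expected to be the intermediate feature maps
+            produced by the client-side model rather than raw images.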
+
+        Example:
+            # Create a CIFAR-style ResNet-56 model with 10 output classes.
+            resnet = ResNet(Bottleneck, [6, 6, 6], num_classes=10)
+        """
         super(ResNet, self).__init__()
         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
@@ -179,6 +315,19 @@ def __init__(
                     nn.init.constant_(m.bn2.weight, 0)
 
     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        """
+        Helper function to create a layer of blocks.
+
+        Args:
+            block (nn.Module): The type of block to use.
+            planes (int): Number of output channels for the layer.
+            blocks (int): Number of blocks in the layer.
+            stride (int, optional): Stride for the convolution. Default is 1.
+            dilate (bool, optional): Whether to use dilation. Default is False.
+
+        Returns:
+            nn.Sequential: A sequential container of blocks.
+        """
         norm_layer = self._norm_layer
         downsample = None
         previous_dilation = self.dilation
@@ -212,6 +361,19 @@ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
         return nn.Sequential(*layers)
 
     def forward(self, x):
+        """
+        Forward pass of the ResNet model.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Output tensor.
+
+        Example:
+            # Forward pass through a ResNet model.
+            output = resnet(input_tensor)
+        """
         # x = self.conv1(x)
         # x = self.bn1(x)
         # x = self.relu(x)  # B x 16 x 32 x 32
@@ -230,8 +392,20 @@ def resnet56_server(c, pretrained=False, path=None, **kwargs):
     """
-    Constructs a ResNet-110 model.
+    Constructs a ResNet-56 model.
 
+    This function creates a ResNet-56 model (Bottleneck blocks, [6, 6, 6]) for server-side applications with the specified number of output classes.
+
     Args:
+        c (int): Number of output classes.
         pretrained (bool): If True, returns a model pre-trained.
+        path (str, optional): Path to a pre-trained model checkpoint. Default is None.
+        **kwargs: Additional keyword arguments to pass to the ResNet model constructor.
+
+    Returns:
+        nn.Module: A ResNet-56 model.
+
+    Example:
+        # Create a ResNet-56 model with 10 output classes.
+        model = resnet56_server(10)
     """
     logging.info("path = " + str(path))
     model = ResNet(Bottleneck, [6, 6, 6], num_classes=c, **kwargs)
diff --git a/python/fedml/model/linear/lr.py b/python/fedml/model/linear/lr.py
index d5bca7fde2..d22e3dc4af 100644
--- a/python/fedml/model/linear/lr.py
+++ b/python/fedml/model/linear/lr.py
@@ -5,25 +5,42 @@
 class LogisticRegression(torch.nn.Module):
     """
     Logistic Regression Model.
 
+    This class implements a simple logistic regression model for binary or multi-class classification tasks.
+
     Args:
-        input_dim (int): The input dimension, typically the number of features in each input sample.
-        output_dim (int): The output dimension, representing the number of classes or a single output.
+        input_dim (int): The input dimension, typically representing the number of features in each input sample.
+        output_dim (int): The output dimension, representing the number of classes (for multi-class) or 1 (for binary).
 
     Input:
         - Input tensor of shape (batch_size, input_dim), where batch_size is the number of input samples.
 
     Output:
-        - Output tensor of shape (batch_size, output_dim), representing class probabilities or a single output.
+        - Output tensor of shape (batch_size, output_dim), representing class probabilities (for multi-class)
+          or a single output (for binary).
 
     Architecture:
        - Linear Layer:
           - Input: input_dim neurons
          - Output: output_dim neurons
          - Activation: Sigmoid (for binary classification) or Softmax (for multi-class classification)
-
+
     Note:
-        - For binary classification, output_dim is typically set to 1.
-        - For multi-class classification, output_dim is the number of classes.
+ - For binary classification, set output_dim to 1. + - For multi-class classification, output_dim should be set to the number of classes. + + Example: + To create a binary logistic regression model with 10 input features: + >>> model = LogisticRegression(input_dim=10, output_dim=1) + + Forward Method: + The forward method computes the forward pass of the Logistic Regression model. + + Args: + x (Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + outputs (Tensor): Output tensor of shape (batch_size, output_dim) with class probabilities (for multi-class) + or a single output (for binary). """ def __init__(self, input_dim, output_dim): @@ -38,13 +55,11 @@ def forward(self, x): x (Tensor): Input tensor of shape (batch_size, input_dim). Returns: - outputs (Tensor): Output tensor of shape (batch_size, output_dim) with class probabilities or a single output. + outputs (Tensor): Output tensor of shape (batch_size, output_dim) with class probabilities (for multi-class) + or a single output (for binary). """ - # try: + outputs = torch.sigmoid(self.linear(x)) - # except: - # print(x.size()) - # import pdb - # pdb.set_trace() + return outputs diff --git a/python/fedml/model/linear/lr_cifar10.py b/python/fedml/model/linear/lr_cifar10.py index 87d593a547..3c99818edb 100644 --- a/python/fedml/model/linear/lr_cifar10.py +++ b/python/fedml/model/linear/lr_cifar10.py @@ -5,8 +5,11 @@ class LogisticRegression_Cifar10(torch.nn.Module): """ Logistic Regression Model for CIFAR-10 Image Classification. + This class implements a logistic regression model for classifying images in the CIFAR-10 dataset. + Args: - input_dim (int): The input dimension, typically the number of features in each input sample. + input_dim (int): The input dimension, typically representing the number of features in each input sample + (flattened image vectors). output_dim (int): The output dimension, representing the number of classes in CIFAR-10. Input: @@ -21,6 +24,19 @@ class LogisticRegression_Cifar10(torch.nn.Module): - Output: output_dim neurons (class probabilities) - Activation: Sigmoid (to produce class probabilities) + Example: + To create a CIFAR-10 logistic regression model with 3072 input features (32x32x3 images): + >>> model = LogisticRegression_Cifar10(input_dim=3072, output_dim=10) + + Forward Method: + The forward method computes the forward pass of the Logistic Regression model. + + Args: + x (Tensor): Input tensor of shape (batch_size, input_dim). + + Returns: + outputs (Tensor): Output tensor of shape (batch_size, output_dim) with class probabilities. + """ def __init__(self, input_dim, output_dim): super(LogisticRegression_Cifar10, self).__init__() @@ -39,10 +55,13 @@ def forward(self, x): """ # Flatten images into vectors # print(f"size = {x.size()}") + x = x.view(x.size(0), -1) - outputs = torch.sigmoid(self.linear(x)) + # except: # print(x.size()) # import pdb # pdb.set_trace() + + outputs = torch.sigmoid(self.linear(x)) return outputs diff --git a/python/fedml/model/mobile/mnn_lenet.py b/python/fedml/model/mobile/mnn_lenet.py index 2378fc9695..e8527a67f5 100644 --- a/python/fedml/model/mobile/mnn_lenet.py +++ b/python/fedml/model/mobile/mnn_lenet.py @@ -48,7 +48,30 @@ def forward(self, x): def create_mnn_lenet5_model(mnn_file_path): + """ + Create and save a LeNet-5 model in the MNN format. + + Args: + mnn_file_path (str): The path to save the MNN model file. + + Note: + This function assumes you have a LeNet-5 model class defined in a 'lenet5' module. 
+ The LeNet-5 model class should have a 'forward' method that takes an input tensor and returns predictions. + + Example: + To create and save a LeNet-5 model to 'lenet5.mnn': + >>> create_mnn_lenet5_model('lenet5.mnn') + + """ + # Create an instance of the LeNet-5 model net = Lenet5() + + # Define an input tensor with the desired shape (1 batch, 1 channel, 28x28) input_var = MNN.expr.placeholder([1, 1, 28, 28], MNN.expr.NCHW) + + # Perform a forward pass to generate predictions predicts = net.forward(input_var) + + # Save the model to the specified file path F.save([predicts], mnn_file_path) + \ No newline at end of file diff --git a/python/fedml/model/mobile/mnn_resnet.py b/python/fedml/model/mobile/mnn_resnet.py index 4f9cf53744..d265ddf94c 100644 --- a/python/fedml/model/mobile/mnn_resnet.py +++ b/python/fedml/model/mobile/mnn_resnet.py @@ -173,7 +173,29 @@ def forward(self, x): def create_mnn_resnet20_model(mnn_file_path): + """ + Create and save a ResNet-20 model in the MNN format. + + Args: + mnn_file_path (str): The path to save the MNN model file. + + Note: + This function assumes you have a ResNet-20 model class defined in a 'resnet20' module. + The ResNet-20 model class should have a 'forward' method that takes an input tensor and returns predictions. + + Example: + To create and save a ResNet-20 model to 'resnet20.mnn': + >>> create_mnn_resnet20_model('resnet20.mnn') + + """ + # Create an instance of the ResNet-20 model net = Resnet20() + + # Define an input tensor with the desired shape (1 batch, 3 channels, 32x32) input_var = MNN.expr.placeholder([1, 3, 32, 32], MNN.expr.NCHW) + + # Perform a forward pass to generate predictions predicts = net.forward(input_var) + + # Save the model to the specified file path F.save([predicts], mnn_file_path) diff --git a/python/fedml/model/mobile/torch_lenet.py b/python/fedml/model/mobile/torch_lenet.py index ee3f30241f..f72f8bee65 100644 --- a/python/fedml/model/mobile/torch_lenet.py +++ b/python/fedml/model/mobile/torch_lenet.py @@ -29,7 +29,7 @@ class LeNet(nn.Module): - Activation: ReLU - Max Pooling: 2x2 - Fully Connected Layer 1: - - Input: 800 neurons (flattened 50x4x4 from previous layer) + - Input: 800 neurons (flattened 50x4x4 from the previous layer) - Output: 500 neurons - Activation: ReLU - Dropout: 50% dropout rate @@ -38,6 +38,16 @@ class LeNet(nn.Module): - Output: 10 neurons (class probabilities) - Activation: Softmax + Note: + - LeNet-5 is a classic convolutional neural network architecture designed for image classification tasks. + - This implementation follows the original LeNet-5 architecture. 
+
+    Example:
+        To create an instance of the LeNet model:
+        >>> model = LeNet()
+        >>> input_tensor = torch.randn(1, 1, 28, 28)  # MNIST-sized input; the 800-unit FC layer assumes 28x28 inputs
+        >>> output = model(input_tensor)  # Forward pass to obtain class probabilities
+
     """
 
     def __init__(self):

From 10c251c162adb59a4bc1e6a9ee906e3d6bd6b587 Mon Sep 17 00:00:00 2001
From: rajveer43
Date: Wed, 13 Sep 2023 14:29:35 +0530
Subject: [PATCH 58/70] 23

---
 python/fedml/ml/engine/ml_engine_adapter.py   | 252 +++++++++++++++++-
 .../ml/engine/torch_process_group_manager.py  |  23 ++
 .../fedml/ml/trainer/feddyn_trainer copy.py   |  78 +++++-
 python/fedml/ml/trainer/mime_trainer.py       |  64 +++++
 python/fedml/ml/trainer/my_model_trainer.py   |  75 +++++-
 .../my_model_trainer_classification.py        |  58 ++++
 .../fedml/ml/trainer/my_model_trainer_nwp.py  |  53 +++-
 .../my_model_trainer_tag_prediction.py        |  52 +++-
 python/fedml/ml/trainer/scaffold_trainer.py   |  57 +++-
 9 files changed, 687 insertions(+), 25 deletions(-)

diff --git a/python/fedml/ml/engine/ml_engine_adapter.py b/python/fedml/ml/engine/ml_engine_adapter.py
index dbae852142..4ec919964c 100644
--- a/python/fedml/ml/engine/ml_engine_adapter.py
+++ b/python/fedml/ml/engine/ml_engine_adapter.py
@@ -5,11 +5,23 @@
 from .torch_process_group_manager import TorchProcessGroupManager
 from ...core.common.ml_engine_backend import MLEngineBackend
 
+import tensorflow as tf
+import numpy as np
+from mxnet import np as mx_np
 
 def convert_numpy_to_torch_data_format(args, batched_x, batched_y):
-    import torch
-    import numpy as np
-
+    """
+    Convert batched data from NumPy format to PyTorch format.
+
+    Args:
+        args: Model-specific arguments or configuration.
+        batched_x (numpy.ndarray): Batched input data.
+        batched_y (numpy.ndarray): Batched output data.
+
+    Returns:
+        torch.Tensor: Batched input data in PyTorch format.
+        torch.Tensor: Batched output data in PyTorch format.
+    """
     if args.model == "cnn":
         batched_x = torch.from_numpy(np.asarray(batched_x)).float().reshape(-1, 28, 28)  # CNN_MINST
     else:
@@ -20,10 +32,18 @@ def convert_numpy_to_torch_data_format(args, batched_x, batched_y):
 
 
 def convert_numpy_to_tf_data_format(args, batched_x, batched_y):
-    # https://www.tensorflow.org/api_docs/python/tf/convert_to_tensor
-    import tensorflow as tf
-    import numpy as np
-
+    """
+    Convert batched data from NumPy format to TensorFlow format.
+
+    Args:
+        args: Model-specific arguments or configuration.
+        batched_x (numpy.ndarray): Batched input data.
+        batched_y (numpy.ndarray): Batched output data.
+
+    Returns:
+        tf.Tensor: Batched input data in TensorFlow format.
+        tf.Tensor: Batched output data in TensorFlow format.
+    """
     if args.model == "cnn":
         batched_x = tf.convert_to_tensor(np.asarray(batched_x), dtype=tf.float32)  # CNN_MINST
         batched_x = tf.reshape(batched_x, [-1, 28, 28])
@@ -35,8 +55,18 @@
 
 def convert_numpy_to_jax_data_format(args, batched_x, batched_y):
-    import numpy as np
-
+    """
+    Convert batched data from NumPy format to JAX format.
+
+    Args:
+        args: Model-specific arguments or configuration.
+        batched_x (numpy.ndarray): Batched input data.
+        batched_y (numpy.ndarray): Batched output data.
+
+    Returns:
+        numpy.ndarray: Batched input data in JAX format.
+        numpy.ndarray: Batched output data in JAX format.
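+
+    Example (illustrative; assumes an ``args`` object with ``args.model == "cnn"``):
+        >>> x_jax, y_jax = convert_numpy_to_jax_data_format(args, batched_x, batched_y)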
+ """ if args.model == "cnn": batched_x = np.asarray(batched_x, dtype=np.float32) # CNN_MINST batched_x = np.reshape(batched_x, [-1, 28, 28]) @@ -48,8 +78,18 @@ def convert_numpy_to_jax_data_format(args, batched_x, batched_y): def convert_numpy_to_mxnet_data_format(args, batched_x, batched_y): - from mxnet import np as mx_np - + """ + Convert batched data from NumPy format to MXNet format. + + Args: + args: Model-specific arguments or configuration. + batched_x (numpy.ndarray): Batched input data. + batched_y (numpy.ndarray): Batched output data. + + Returns: + mxnet.numpy.ndarray: Batched input data in MXNet format. + mxnet.numpy.ndarray: Batched output data in MXNet format. + """ if args.model == "cnn": batched_x = mx_np.array(batched_x) batched_x = mx_np.reshape(batched_x, [-1, 28, 28]) # pylint: disable=E1101 @@ -61,6 +101,17 @@ def convert_numpy_to_mxnet_data_format(args, batched_x, batched_y): def convert_numpy_to_ml_engine_data_format(args, batched_x, batched_y): + """ + Convert batched data from NumPy format to the format required by a specified machine learning engine. + + Args: + args: Model-specific arguments or configuration. + batched_x (numpy.ndarray): Batched input data. + batched_y (numpy.ndarray): Batched output data. + + Returns: + Data in the format required by the specified machine learning engine. + """ if hasattr(args, MLEngineBackend.ml_engine_args_flag): if args.ml_engine == MLEngineBackend.ml_engine_backend_tf: return convert_numpy_to_tf_data_format(args, batched_x, batched_y) @@ -75,6 +126,16 @@ def convert_numpy_to_ml_engine_data_format(args, batched_x, batched_y): def is_torch_device_available(args, device_type): + """ + Check if a Torch device of the specified type is available. + + Args: + args: Model-specific arguments or configuration. + device_type (str): The type of Torch device to check (e.g., "gpu", "mps", "cpu"). + + Returns: + bool: True if the Torch device is available, False otherwise. + """ if device_type == MLEngineBackend.ml_device_type_gpu: if torch.cuda.is_available(): return True @@ -99,6 +160,16 @@ def is_torch_device_available(args, device_type): def is_mxnet_device_available(args, device_type): + """ + Check if a MXNet device of the specified type is available. + + Args: + args: Model-specific arguments or configuration. + device_type (str): The type of MXNet device to check (e.g., "cpu", "gpu"). + + Returns: + bool: True if the MXNet device is available, False otherwise. + """ if device_type == MLEngineBackend.ml_device_type_cpu: return True elif device_type == MLEngineBackend.ml_device_type_gpu: @@ -116,6 +187,16 @@ def is_mxnet_device_available(args, device_type): def is_device_available(args, device_type=MLEngineBackend.ml_device_type_gpu): + """ + Check if a specified device type is available based on the provided arguments and ML engine. + + Args: + args: Model-specific arguments or configuration. + device_type (str): The type of device to check (e.g., "gpu", "mps", "cpu"). + + Returns: + bool: True if the device is available, False otherwise. + """ if hasattr(args, MLEngineBackend.ml_engine_args_flag): if args.ml_engine == MLEngineBackend.ml_engine_backend_tf: import tensorflow as tf @@ -144,6 +225,19 @@ def is_device_available(args, device_type=MLEngineBackend.ml_device_type_gpu): def get_torch_device(args, using_gpu, device_id, device_type): + """ + Get a Torch device based on the provided arguments and configuration. + + Args: + args: Model-specific arguments or configuration. 
+ using_gpu (bool): Indicates whether a GPU should be used. + device_id (int): The ID of the GPU device. + device_type (str): The type of device (e.g., "gpu", "mps", "cpu"). + + Returns: + torch.device: The Torch device. + """ + logging.info( "args = {}, using_gpu = {}, device_id = {}, device_type = {}".format(args, using_gpu, device_id, device_type) ) @@ -165,6 +259,18 @@ def get_torch_device(args, using_gpu, device_id, device_type): def get_tf_device(args, using_gpu, device_id, device_type): + """ + Get a TensorFlow device based on the provided arguments and configuration. + + Args: + args: Model-specific arguments or configuration. + using_gpu (bool): Indicates whether a GPU should be used. + device_id (int): The ID of the GPU device. + device_type (str): The type of device (e.g., "gpu", "mps", "cpu"). + + Returns: + tf.device: The TensorFlow device. + """ import tensorflow as tf if using_gpu: @@ -174,6 +280,18 @@ def get_tf_device(args, using_gpu, device_id, device_type): def get_jax_device(args, using_gpu, device_id, device_type): + """ + Get a JAX device based on the provided arguments and configuration. + + Args: + args: Model-specific arguments or configuration. + using_gpu (bool): Indicates whether a GPU should be used. + device_id (int): The ID of the GPU device. + device_type (str): The type of device (e.g., "gpu", "mps", "cpu"). + + Returns: + jax.devices.Device: The JAX device. + """ import jax devices = jax.devices(None) @@ -187,6 +305,18 @@ def get_jax_device(args, using_gpu, device_id, device_type): def get_mxnet_device(args, using_gpu, device_id, device_type): + """ + Get an MXNet device based on the provided arguments and configuration. + + Args: + args: Model-specific arguments or configuration. + using_gpu (bool): Indicates whether a GPU should be used. + device_id (int): The ID of the GPU device. + device_type (str): The type of device (e.g., "gpu", "mps", "cpu"). + + Returns: + mxnet.context.Context: The MXNet device. + """ import mxnet as mx if using_gpu: @@ -196,6 +326,17 @@ def get_mxnet_device(args, using_gpu, device_id, device_type): def get_device(args, device_id=None, device_type="cpu"): + """ + Get the appropriate device based on the provided arguments and configuration. + + Args: + args: Model-specific arguments or configuration. + device_id (int, optional): The ID of the GPU device. Defaults to None. + device_type (str, optional): The type of device (e.g., "cpu"). Defaults to "cpu". + + Returns: + torch.device, tf.device, jax.devices.Device, mxnet.context.Context: The selected device. + """ using_gpu = True if (hasattr(args, "using_gpu") and args.using_gpu is True) else False if hasattr(args, MLEngineBackend.ml_engine_args_flag): @@ -212,6 +353,17 @@ def get_device(args, device_id=None, device_type="cpu"): def dict_to_device(args, dict_obj, device): + """ + Move a dictionary of objects to the specified device. + + Args: + args: Model-specific arguments or configuration. + dict_obj (dict): A dictionary of objects. + device (torch.device, tf.device, jax.devices.Device, mxnet.context.Context): The target device. + + Returns: + dict: The dictionary with objects on the target device. + """ if hasattr(args, MLEngineBackend.ml_engine_args_flag): if args.ml_engine == MLEngineBackend.ml_engine_backend_tf: with device: @@ -232,6 +384,17 @@ def dict_to_device(args, dict_obj, device): def model_params_to_device(args, params_obj, device): + """ + Move model parameters to the specified device. + + Args: + args: Model-specific arguments or configuration. 
+ params_obj (dict): A dictionary of model parameters. + device (torch.device, tf.device, jax.devices.Device, mxnet.context.Context): The target device. + + Returns: + dict: The dictionary of model parameters on the target device. + """ if hasattr(args, MLEngineBackend.ml_engine_args_flag): if args.ml_engine == MLEngineBackend.ml_engine_backend_tf: with device: @@ -255,6 +418,17 @@ def model_params_to_device(args, params_obj, device): def model_to_device(args, model_obj, device): + """ + Move a model to the specified device. + + Args: + args: Model-specific arguments or configuration. + model_obj: The model to be moved to the device. + device: The target device (e.g., torch.device, tf.device, jax.devices.Device, mxnet.context.Context). + + Returns: + The model on the target device. + """ if hasattr(args, MLEngineBackend.ml_engine_args_flag): if args.ml_engine == MLEngineBackend.ml_engine_backend_tf: with device: @@ -271,6 +445,17 @@ def model_to_device(args, model_obj, device): def torch_model_ddp(args, model_obj, device): + """ + Create a Distributed Data Parallel (DDP) model for PyTorch. + + Args: + args: Model-specific arguments or configuration. + model_obj: The PyTorch model. + device: The target device (e.g., torch.device). + + Returns: + TorchProcessGroupManager, torch.nn.parallel.DistributedDataParallel: The process group manager and DDP model. + """ from torch.nn.parallel import DistributedDataParallel as DDP only_gpu = args.using_gpu @@ -283,23 +468,68 @@ def torch_model_ddp(args, model_obj, device): # Todo: add tf ddp def tf_model_ddp(args, model_obj, device): + """ + Create a Distributed Data Parallel (DDP) model for TensorFlow. + + Args: + args: Model-specific arguments or configuration. + model_obj: The TensorFlow model. + device: The target device (e.g., tf.device). + + Returns: + None, Model: The process group manager (None for TensorFlow) and DDP model. + """ process_group_manager, model = None, model_obj return process_group_manager, model # Todo: add jax ddp def jax_model_ddp(args, model_obj, device): + """ + Create a Distributed Data Parallel (DDP) model for JAX. + + Args: + args: Model-specific arguments or configuration. + model_obj: The JAX model. + device: The target device (e.g., jax.devices.Device). + + Returns: + None, Model: The process group manager (None for JAX) and DDP model. + """ process_group_manager, model = None, model_obj return process_group_manager, model # Todo: add mxnet ddp def mxnet_model_ddp(args, model_obj, device): + """ + Create a Distributed Data Parallel (DDP) model for MXNet. + + Args: + args: Model-specific arguments or configuration. + model_obj: The MXNet model. + device: The target device (e.g., mxnet.context.Context). + + Returns: + None, Model: The process group manager (None for MXNet) and DDP model. + """ process_group_manager, model = None, model_obj return process_group_manager, model def model_ddp(args, model_obj, device): + """ + Create a Distributed Data Parallel (DDP) model based on the selected ML engine. + + Args: + args: Model-specific arguments or configuration. + model_obj: The model to be wrapped with DDP. + device: The target device (e.g., torch.device, tf.device, jax.devices.Device, mxnet.context.Context). + + Returns: + TorchProcessGroupManager, torch.nn.parallel.DistributedDataParallel or + None, Model: The process group manager and DDP model (or None for non-Torch engines). 
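+
+    Example (illustrative sketch; assumes a PyTorch backend and an already
+    initialized distributed environment):
+        >>> process_group_manager, ddp_model = model_ddp(args, model, device)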
+ """ process_group_manager, model = None, model_obj if hasattr(args, MLEngineBackend.ml_engine_args_flag): if args.ml_engine == MLEngineBackend.ml_engine_backend_tf: diff --git a/python/fedml/ml/engine/torch_process_group_manager.py b/python/fedml/ml/engine/torch_process_group_manager.py index 931ff386c6..fb8fa534ef 100644 --- a/python/fedml/ml/engine/torch_process_group_manager.py +++ b/python/fedml/ml/engine/torch_process_group_manager.py @@ -7,6 +7,18 @@ class TorchProcessGroupManager: def __init__(self, rank, world_size, master_address, master_port, only_gpu): + """ + Initialize the TorchProcessGroupManager. + + Args: + rank (int): The rank of the current process. + world_size (int): The total number of processes in the distributed training. + master_address (str): The address of the master process for communication. + master_port (int): The port for communication with the master process. + only_gpu (bool): Flag indicating whether only GPUs are used for communication. + + Initializes the process group and creates a messaging process group for communication. + """ logging.info("Start process group") logging.info( "rank: %d, world_size: %d, master_address: %s, master_port: %s" @@ -38,7 +50,18 @@ def __init__(self, rank, world_size, master_address, master_port, only_gpu): logging.info("Initiated") def cleanup(self): + """ + Clean up the process group. + + Destroys the process group and performs cleanup. + """ dist.destroy_process_group() def get_process_group(self): + """ + Get the messaging process group. + + Returns: + torch.distributed.ProcessGroup: The messaging process group for communication. + """ return self.messaging_pg diff --git a/python/fedml/ml/trainer/feddyn_trainer copy.py b/python/fedml/ml/trainer/feddyn_trainer copy.py index bfabace0bd..f3f63fce06 100644 --- a/python/fedml/ml/trainer/feddyn_trainer copy.py +++ b/python/fedml/ml/trainer/feddyn_trainer copy.py @@ -7,24 +7,84 @@ +import torch + def model_parameter_vector(model): - param = [p.view(-1) for p in model.parameters()] - return torch.concat(param, dim=0) + """ + Flatten the parameters of a PyTorch model into a single 1D tensor. + Args: + model (torch.nn.Module): The PyTorch model. + + Returns: + torch.Tensor: A 1D tensor containing all the flattened model parameters. + """ + param = [p.view(-1) for p in model.parameters()] + return torch.cat(param, dim=0) # Use torch.cat to concatenate tensors def parameter_vector(parameters): + """ + Flatten a dictionary of PyTorch parameters into a single 1D tensor. + + Args: + parameters (dict): A dictionary of PyTorch parameters. + + Returns: + torch.Tensor: A 1D tensor containing all the flattened parameters. + """ param = [p.view(-1) for p in parameters.values()] - return torch.concat(param, dim=0) + return torch.cat(param, dim=0) # Use torch.cat to concatenate tensors + class FedDynModelTrainer(ClientTrainer): + """ + A federated dynamic model trainer that implements training and testing methods. + + Args: + ClientTrainer: The base class for client trainers. + + Attributes: + model (torch.nn.Module): The PyTorch model to be trained. + id (int): The identifier of the client. + + Methods: + get_model_params(): Get the model parameters as a state dictionary. + set_model_params(model_parameters): Set the model parameters from a state dictionary. + train(train_data, device, args, old_grad): Train the model with federated dynamic regularization. + test(test_data, device, args): Evaluate the model on test data and return evaluation metrics. 
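+
+    Example (illustrative; assumes FedDyn-style ``args`` and data loaders):
+        >>> trainer = FedDynModelTrainer(model, args)
+        >>> old_grad = trainer.train(train_data, device, args, old_grad)
+        >>> metrics = trainer.test(test_data, device, args)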
+ """ def get_model_params(self): + """ + Get the model parameters as a state dictionary. + + Returns: + dict: The model parameters as a state dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters from a state dictionary. + + Args: + model_parameters (dict): The model parameters as a state dictionary. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args, old_grad): + """ + Train the model with federated dynamic regularization. + + Args: + train_data: The training data. + device (torch.device): The device (CPU or GPU) to use for training. + args: Training arguments. + old_grad (torch.Tensor): The previous gradient. + + Returns: + torch.Tensor: The updated gradient. + """ model = self.model for params in model.parameters(): params.requires_grad = True @@ -137,6 +197,18 @@ def train(self, train_data, device, args, old_grad): def test(self, test_data, device, args): + """ + Evaluate the model on test data and return evaluation metrics. + + Args: + test_data: The test data. + device (torch.device): The device (CPU or GPU) to use for evaluation. + args: Training arguments. + + Returns: + dict: Evaluation metrics including test accuracy and test loss. + """ + model = self.model model.to(device) diff --git a/python/fedml/ml/trainer/mime_trainer.py b/python/fedml/ml/trainer/mime_trainer.py index 8f5f9d4078..a4fe013be2 100644 --- a/python/fedml/ml/trainer/mime_trainer.py +++ b/python/fedml/ml/trainer/mime_trainer.py @@ -13,6 +13,19 @@ def clip_norm(tensors, device, max_norm=1.0, norm_type=2.): + """ + Clip the gradients of a list of tensors to have a maximum norm. + + Args: + tensors (list of torch.Tensor): The list of tensors whose gradients need to be clipped. + device (torch.device): The device (CPU or GPU) on which the tensors are located. + max_norm (float): The maximum norm value for gradient clipping. + norm_type (float): The type of norm to use for computing the gradient norm. + + Returns: + float: The total gradient norm after clipping. + + """ total_norm = torch.norm(torch.stack( [torch.norm(p.detach(), norm_type).to(device) for p in tensors]), norm_type) clip_coef = max_norm / (total_norm + 1e-6) @@ -23,14 +36,55 @@ def clip_norm(tensors, device, max_norm=1.0, norm_type=2.): class MimeModelTrainer(ClientTrainer): + """ + A custom model trainer for Mime-based federated learning. + + Args: + model (torch.nn.Module): The PyTorch model to be trained. + args: Training arguments. + + Attributes: + model (torch.nn.Module): The PyTorch model to be trained. + args: Training arguments. + + Methods: + get_model_params(): Get the model parameters as a state dictionary. + set_model_params(model_parameters): Set the model parameters from a state dictionary. + accumulate_data_grad(train_data, device, args): Accumulate the gradients of the local data. + train(train_data, device, args, grad_global, global_named_states): Train the model with Mime-based federated learning. + test(test_data, device, args): Evaluate the model on test data and return evaluation metrics. + """ def get_model_params(self): + """ + Get the model parameters as a state dictionary. + + Returns: + dict: The model parameters as a state dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters from a state dictionary. + + Args: + model_parameters (dict): The model parameters as a state dictionary. 
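+
+        Example (illustrative):
+            >>> trainer.set_model_params(global_model.state_dict())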
+ """ self.model.load_state_dict(model_parameters) def accumulate_data_grad(self, train_data, device, args): + """ + Accumulate the gradients of the local data. + + Args: + train_data: The training data. + device (torch.device): The device (CPU or GPU) to use for gradient computation. + args: Training arguments. + + Returns: + dict: A dictionary containing the accumulated gradients for each parameter. + """ model = self.model model.to(device) @@ -58,6 +112,16 @@ def accumulate_data_grad(self, train_data, device, args): def train(self, train_data, device, args, grad_global, global_named_states): + """ + Train the model with Mime-based federated learning. + + Args: + train_data: The training data. + device (torch.device): The device (CPU or GPU) to use for training. + args: Training arguments. + grad_global: Global gradients. + global_named_states: Global model states. + """ model = self.model model.to(device) diff --git a/python/fedml/ml/trainer/my_model_trainer.py b/python/fedml/ml/trainer/my_model_trainer.py index f746a72379..e353018db6 100644 --- a/python/fedml/ml/trainer/my_model_trainer.py +++ b/python/fedml/ml/trainer/my_model_trainer.py @@ -6,23 +6,70 @@ class MyModelTrainer(ClientTrainer): + """ + A custom model trainer that implements training and testing methods. + + Args: + ClientTrainer: The base class for client trainers. + + Attributes: + model (torch.nn.Module): The PyTorch model to be trained. + args: Training arguments. + + Methods: + __init__(self, model, args): Initialize the trainer. + get_model_params(self): Get the model parameters as a state dictionary. + set_model_params(self, model_parameters): Set the model parameters from a state dictionary. + on_before_local_training(self, train_data, device, args): Perform actions before local training (optional). + train(self, train_data, device, args): Train the model. + test(self, test_data, device, args): Evaluate the model on test data and return evaluation metrics. + """ def __init__(self, model, args): super().__init__(model, args) - self.cpu_transfer = False if not hasattr(self.args, "cpu_transfer") else self.args.cpu_transfer + self.cpu_transfer = False if not hasattr( + self.args, "cpu_transfer") else self.args.cpu_transfer def get_model_params(self): + """ + Get the model parameters as a state dictionary. + + Returns: + dict: The model parameters as a state dictionary. + """ if self.cpu_transfer: return self.model.cpu().state_dict() return self.model.state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters from a state dictionary. + + Args: + model_parameters (dict): The model parameters as a state dictionary. + """ self.model.load_state_dict(model_parameters) def on_before_local_training(self, train_data, device, args): + """ + Execute code before local training (optional). + + Args: + train_data: The training data. + device (torch.device): The device (CPU or GPU) to use for training. + args: Training arguments. + """ pass def train(self, train_data, device, args): + """ + Train the model. + + Args: + train_data: The training data. + device (torch.device): The device (CPU or GPU) to use for training. + args: Training arguments. 
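+
+        Example (illustrative; assumes ``args`` provides client_optimizer,
+        learning_rate, weight_decay, and epochs):
+            >>> trainer = MyModelTrainer(model, args)
+            >>> trainer.train(train_loader, torch.device("cpu"), args)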
+ """ model = self.model model.to(device) @@ -31,7 +78,8 @@ def train(self, train_data, device, args): # train and update criterion = nn.CrossEntropyLoss().to(device) # pylint: disable=E1102 if args.client_optimizer == "sgd": - optimizer = torch.optim.SGD(self.model.parameters(), lr=args.learning_rate) + optimizer = torch.optim.SGD( + self.model.parameters(), lr=args.learning_rate) else: optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, self.model.parameters()), @@ -71,6 +119,17 @@ def train(self, train_data, device, args): # self.client_idx, epoch, sum(epoch_loss) / len(epoch_loss))) def test(self, test_data, device, args): + """ + Evaluate the model on test data and return evaluation metrics. + + Args: + test_data: The test data. + device (torch.device): The device (CPU or GPU) to use for evaluation. + args: Training arguments. + + Returns: + dict: Evaluation metrics including test accuracy, test loss, precision, and recall (if applicable). + """ model = self.model model.to(device) @@ -91,7 +150,8 @@ def test(self, test_data, device, args): https://github.com/google-research/federated/blob/49a43456aa5eaee3e1749855eed89c0087983541/optimization/stackoverflow_lr/federated_stackoverflow_lr.py#L131 """ if args.dataset == "stackoverflow_lr": - criterion = nn.BCELoss(reduction="sum").to(device) # pylint: disable=E1102 + criterion = nn.BCELoss(reduction="sum").to( + device) # pylint: disable=E1102 else: criterion = nn.CrossEntropyLoss().to(device) # pylint: disable=E1102 @@ -104,9 +164,12 @@ def test(self, test_data, device, args): if args.dataset == "stackoverflow_lr": predicted = (pred > 0.5).int() - correct = predicted.eq(target).sum(axis=-1).eq(target.size(1)).sum() - true_positive = ((target * predicted) > 0.1).int().sum(axis=-1) - precision = true_positive / (predicted.sum(axis=-1) + 1e-13) + correct = predicted.eq(target).sum( + axis=-1).eq(target.size(1)).sum() + true_positive = ((target * predicted) > + 0.1).int().sum(axis=-1) + precision = true_positive / \ + (predicted.sum(axis=-1) + 1e-13) recall = true_positive / (target.sum(axis=-1) + 1e-13) metrics["test_precision"] += precision.sum().item() metrics["test_recall"] += recall.sum().item() diff --git a/python/fedml/ml/trainer/my_model_trainer_classification.py b/python/fedml/ml/trainer/my_model_trainer_classification.py index a4251b4c3c..0aa9beea25 100644 --- a/python/fedml/ml/trainer/my_model_trainer_classification.py +++ b/python/fedml/ml/trainer/my_model_trainer_classification.py @@ -12,13 +12,51 @@ class ModelTrainerCLS(ClientTrainer): + """ + A custom model trainer for classification tasks. + + Args: + model (torch.nn.Module): The PyTorch model to be trained. + args: Training arguments. + + Attributes: + model (torch.nn.Module): The PyTorch model to be trained. + args: Training arguments. + + Methods: + get_model_params(): Get the model parameters as a state dictionary. + set_model_params(model_parameters): Set the model parameters from a state dictionary. + train(train_data, device, args): Train the model. + train_iterations(train_data, device, args): Train the model for a specified number of iterations. + test(test_data, device, args): Evaluate the model on test data and return evaluation metrics. + """ def get_model_params(self): + """ + Get the model parameters as a state dictionary. + + Returns: + dict: The model parameters as a state dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters from a state dictionary. 
+
+        Args:
+            model_parameters (dict): The model parameters as a state dictionary.
+        """
         self.model.load_state_dict(model_parameters)
 
     def train(self, train_data, device, args):
+        """
+        Train the model.
+
+        Args:
+            train_data: The training data.
+            device (torch.device): The device (CPU or GPU) to use for training.
+            args: Training arguments.
+        """
         model = self.model
 
         model.to(device)
@@ -77,6 +115,15 @@ def train(self, train_data, device, args):
         )
 
     def train_iterations(self, train_data, device, args):
+        """
+        Train the model for a specified number of iterations.
+
+        Args:
+            train_data: The training data.
+            device (torch.device): The device (CPU or GPU) to use for training.
+            args: Training arguments.
+        """
+
         model = self.model
 
         model.to(device)
@@ -137,6 +184,17 @@ def train_iterations(self, train_data, device, args):
         )
 
     def test(self, test_data, device, args):
+        """
+        Evaluate the model on test data and return evaluation metrics.
+
+        Args:
+            test_data: The test data.
+            device (torch.device): The device (CPU or GPU) to use for evaluation.
+            args: Training arguments.
+
+        Returns:
+            dict: Evaluation metrics including test accuracy, test loss, and the total number of test samples.
+        """
         model = self.model
         model.to(device)
diff --git a/python/fedml/ml/trainer/my_model_trainer_nwp.py b/python/fedml/ml/trainer/my_model_trainer_nwp.py
index a613e1d987..1077ff5e7b 100644
--- a/python/fedml/ml/trainer/my_model_trainer_nwp.py
+++ b/python/fedml/ml/trainer/my_model_trainer_nwp.py
@@ -8,20 +8,58 @@
 
 
 class ModelTrainerNWP(ClientTrainer):
+    """
+    A custom model trainer for Next Word Prediction (NWP) tasks.
+
+    Args:
+        model (torch.nn.Module): The PyTorch model to be trained.
+        args: Training arguments.
+
+    Attributes:
+        model (torch.nn.Module): The PyTorch model to be trained.
+        args: Training arguments.
+
+    Methods:
+        get_model_params(): Get the model parameters as a state dictionary.
+        set_model_params(model_parameters): Set the model parameters from a state dictionary.
+        train(train_data, device, args): Train the model.
+        test(test_data, device, args): Evaluate the model on test data and return evaluation metrics.
+    """
     def get_model_params(self):
+        """
+        Get the model parameters as a state dictionary.
+
+        Returns:
+            dict: The model parameters as a state dictionary.
+        """
         return self.model.cpu().state_dict()
 
     def set_model_params(self, model_parameters):
+        """
+        Set the model parameters from a state dictionary.
+
+        Args:
+            model_parameters (dict): The model parameters as a state dictionary.
+        """
         self.model.load_state_dict(model_parameters)
 
     def train(self, train_data, device, args):
+        """
+        Train the model.
+
+        Args:
+            train_data: The training data.
+            device (torch.device): The device (CPU or GPU) to use for training.
+            args: Training arguments.
+        """
         model = self.model
 
         model.to(device)
         model.train()
 
         # train and update
-        criterion = nn.CrossEntropyLoss(ignore_index=0).to(device)  # pylint: disable=E1102
+        criterion = nn.CrossEntropyLoss(ignore_index=0).to(
+            device)  # pylint: disable=E1102
         if args.client_optimizer == "sgd":
             optimizer = torch.optim.SGD(
                 filter(lambda p: p.requires_grad, self.model.parameters()),
@@ -34,7 +72,7 @@ def train(self, train_data, device, args):
                 weight_decay=args.weight_decay,
                 amsgrad=True,
             )
-
+
         epoch_loss = []
         for epoch in range(args.epochs):
             # begin_time = time.time()
@@ -66,6 +104,17 @@ def train(self, train_data, device, args):
     #             self.client_idx, epoch, sum(epoch_loss) / len(epoch_loss)))
 
     def test(self, test_data, device, args):
+        """
+        Evaluate the model on test data and return evaluation metrics.
+ + Args: + test_data: The test data. + device (torch.device): The device (CPU or GPU) to use for evaluation. + args: Training arguments. + + Returns: + dict: Evaluation metrics including test accuracy, test loss, and the total number of test samples. + """ model = self.model model.to(device) diff --git a/python/fedml/ml/trainer/my_model_trainer_tag_prediction.py b/python/fedml/ml/trainer/my_model_trainer_tag_prediction.py index 6cb07a7274..4a367cf679 100644 --- a/python/fedml/ml/trainer/my_model_trainer_tag_prediction.py +++ b/python/fedml/ml/trainer/my_model_trainer_tag_prediction.py @@ -5,13 +5,51 @@ class ModelTrainerTAGPred(ClientTrainer): + """ + A custom model trainer for TAG prediction tasks. + + Args: + model (torch.nn.Module): The PyTorch model to be trained. + args: Training arguments. + + Attributes: + model (torch.nn.Module): The PyTorch model to be trained. + args: Training arguments. + + Methods: + get_model_params(): Get the model parameters as a state dictionary. + set_model_params(model_parameters): Set the model parameters from a state dictionary. + train(train_data, device, args): Train the model. + test(test_data, device, args): Evaluate the model on test data and return evaluation metrics. + """ + def get_model_params(self): + """ + Get the model parameters as a state dictionary. + + Returns: + dict: The model parameters as a state dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters from a state dictionary. + + Args: + model_parameters (dict): The model parameters as a state dictionary. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model. + + Args: + train_data: The training data. + device (torch.device): The device (CPU or GPU) to use for training. + args: Training arguments. + """ model = self.model model.to(device) @@ -56,6 +94,17 @@ def train(self, train_data, device, args): # self.client_idx, epoch, sum(epoch_loss) / len(epoch_loss))) def test(self, test_data, device, args): + """ + Evaluate the model on test data and return evaluation metrics. + + Args: + test_data: The test data. + device (torch.device): The device (CPU or GPU) to use for evaluation. + args: Training arguments. + + Returns: + dict: Evaluation metrics including test accuracy, test loss, precision, and recall (if applicable). + """ model = self.model model.to(device) @@ -85,7 +134,8 @@ def test(self, test_data, device, args): loss = criterion(pred, target) # pylint: disable=E1102 predicted = (pred > 0.5).int() - correct = predicted.eq(target).sum(axis=-1).eq(target.size(1)).sum() + correct = predicted.eq(target).sum( + axis=-1).eq(target.size(1)).sum() true_positive = ((target * predicted) > 0.1).int().sum(axis=-1) precision = true_positive / (predicted.sum(axis=-1) + 1e-13) recall = true_positive / (target.sum(axis=-1) + 1e-13) diff --git a/python/fedml/ml/trainer/scaffold_trainer.py b/python/fedml/ml/trainer/scaffold_trainer.py index ea1f592064..e552946725 100644 --- a/python/fedml/ml/trainer/scaffold_trainer.py +++ b/python/fedml/ml/trainer/scaffold_trainer.py @@ -7,13 +7,55 @@ class ScaffoldModelTrainer(ClientTrainer): + """ + A scaffold model trainer that implements training and testing methods. + + Args: + ClientTrainer: The base class for client trainers. + + Attributes: + model (torch.nn.Module): The PyTorch model to be trained. + id (int): The identifier of the client. 
+ + Methods: + get_model_params(): Get the model parameters as a state dictionary. + set_model_params(model_parameters): Set the model parameters from a state dictionary. + train(train_data, device, args, c_model_global_params, c_model_local_params): Train the model. + test(test_data, device, args): Evaluate the model on test data and return evaluation metrics. + """ + def get_model_params(self): + """ + Get the model parameters as a state dictionary. + + Returns: + dict: The model parameters as a state dictionary. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters from a state dictionary. + + Args: + model_parameters (dict): The model parameters as a state dictionary. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args, c_model_global_params, c_model_local_params): + """ + Train the model. + + Args: + train_data: The training data. + device (torch.device): The device (CPU or GPU) to use for training. + args: Training arguments. + c_model_global_params (dict): Global model parameters. + c_model_local_params (dict): Local model parameters. + + Returns: + int: The number of training iterations. + """ model = self.model model.to(device) @@ -63,7 +105,8 @@ def train(self, train_data, device, args, c_model_global_params, c_model_local_p # logging.debug(f"c_model_global[name].device : {c_model_global[name].device}, \ # c_model_global_params[name].device : {c_model_local_params[name].device}") param.data = param.data - current_lr * \ - check_device((c_model_global_params[name] - c_model_local_params[name]), param.data.device) + check_device( + (c_model_global_params[name] - c_model_local_params[name]), param.data.device) iteration_cnt += 1 batch_loss.append(loss.item()) if len(batch_loss) == 0: @@ -77,8 +120,18 @@ def train(self, train_data, device, args, c_model_global_params, c_model_local_p ) return iteration_cnt - def test(self, test_data, device, args): + """ + Evaluate the model on test data and return evaluation metrics. + + Args: + test_data: The test data. + device (torch.device): The device (CPU or GPU) to use for evaluation. + args: Training arguments. + + Returns: + dict: Evaluation metrics including test accuracy and test loss. 
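+
+        Example (illustrative):
+            >>> metrics = trainer.test(test_data, torch.device("cpu"), args)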
+ """ model = self.model model.to(device) From 1e67bd2f438612d69fc185424a95c86befbf0234 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Fri, 15 Sep 2023 22:04:25 +0530 Subject: [PATCH 59/70] n --- python/fedml/device/device.py | 30 ++- python/fedml/device/gpu_mapping_cross_silo.py | 27 ++- python/fedml/device/gpu_mapping_mpi.py | 45 +++- python/fedml/device/ip_config_utils.py | 9 + python/fedml/fa/__init__.py | 48 +++- python/fedml/fa/aggregator/avg_aggregator.py | 34 +++ .../frequency_estimation_aggregator.py | 46 +++- .../fa/aggregator/global_analyzer_creator.py | 10 + .../heavy_hitter_triehh_aggregator.py | 49 +++- .../fa/aggregator/intersection_aggregator.py | 71 ++++-- .../k_percentile_element_aggregator.py | 43 +++- .../fedml/fa/aggregator/union_aggregator.py | 48 +++- python/fedml/fa/base_frame/client_analyzer.py | 83 +++++++ .../fedml/fa/base_frame/server_aggregator.py | 60 ++++- .../cross_silo/client/client_initializer.py | 54 +++++ .../fa/cross_silo/client/client_launcher.py | 37 ++- .../fa/cross_silo/client/fa_local_analyzer.py | 110 ++++++++- .../client/fedml_client_master_manager.py | 177 ++++++++++++++- .../client/fedml_client_slave_manager.py | 63 +++++- .../client/fedml_trainer_dist_adapter.py | 108 ++++++++- .../client/process_group_manager.py | 50 ++++- python/fedml/fa/cross_silo/fa_client.py | 38 ++++ python/fedml/fa/cross_silo/fa_server.py | 38 ++++ .../fa/cross_silo/server/fedml_aggregator.py | 134 +++++++++-- .../cross_silo/server/fedml_server_manager.py | 172 ++++++++++++++ .../cross_silo/server/server_initializer.py | 16 ++ python/fedml/fa/data/data_loader.py | 24 +- .../fa/data/fake_numeric_data/data_loader.py | 32 ++- .../fa/data/self_defined_data/data_loader.py | 30 ++- .../data/twitter_Sentiment140/data_loader.py | 34 ++- .../twitter_data_processing.py | 57 ++++- python/fedml/fa/data/utils.py | 50 ++++- python/fedml/fa/local_analyzer/avg.py | 25 ++- .../local_analyzer/client_analyzer_creator.py | 11 +- .../fa/local_analyzer/frequency_estimation.py | 32 ++- .../fa/local_analyzer/heavy_hitter_triehh.py | 81 ++++++- .../fedml/fa/local_analyzer/intersection.py | 24 +- .../fa/local_analyzer/k_percentage_element.py | 25 ++- python/fedml/fa/local_analyzer/union.py | 23 +- python/fedml/fa/runner.py | 63 +++++- python/fedml/fa/simulation/sp/client.py | 65 +++++- python/fedml/fa/simulation/sp/simulator.py | 45 ++++ python/fedml/fa/simulation/utils.py | 13 +- python/fedml/fa/utils/trie.py | 212 +++++++++++++++++- python/fedml/ml/aggregator/agg_operator.py | 73 ++++++ .../fedml/ml/aggregator/aggregator_creator.py | 10 + .../fedml/ml/aggregator/default_aggregator.py | 41 ++++ .../ml/aggregator/my_server_aggregator.py | 53 +++++ .../my_server_aggregator_classification.py | 34 +++ .../ml/aggregator/my_server_aggregator_nwp.py | 34 +++ .../my_server_aggregator_prediction.py | 34 +++ python/fedml/ml/trainer/feddyn_trainer.py | 80 +++++++ python/fedml/ml/trainer/fednova_trainer.py | 90 ++++++++ python/fedml/ml/trainer/fedprox_trainer.py | 69 ++++++ python/fedml/ml/trainer/trainer_creator.py | 12 +- 55 files changed, 2831 insertions(+), 145 deletions(-) diff --git a/python/fedml/device/device.py b/python/fedml/device/device.py index 1085a19412..7892dbde35 100644 --- a/python/fedml/device/device.py +++ b/python/fedml/device/device.py @@ -10,6 +10,18 @@ def get_device_type(args): + """ + Determine the type of device (CPU, GPU, or MPS) based on the provided arguments. 
+ + Args: + args (object): An object containing arguments, including 'device_type', 'using_gpu', 'gpu_id', and 'training_type'. + + Returns: + str: The type of device to use (e.g., 'cpu', 'gpu', or 'mps'). + + Raises: + Exception: If the provided 'device_type' is not supported. + """ if hasattr(args, "device_type"): if args.device_type == "cpu": device_type = "cpu" @@ -40,6 +52,18 @@ def get_device_type(args): def get_device(args): + """ + Get the device for training based on the provided arguments. + + Args: + args (object): An object containing arguments, including 'training_type', 'backend', 'gpu_id', 'using_gpu', 'process_id', and others. + + Returns: + str: The device (CPU or GPU) assigned to the current process. + + Raises: + Exception: If the 'training_type' is not defined. + """ if args.training_type == "simulation" and args.backend == "sp": if not hasattr(args, "gpu_id"): args.gpu_id = 0 @@ -104,14 +128,14 @@ def get_device(args): gpu_mapping_key = ( args.gpu_mapping_key if hasattr(args, "gpu_mapping_key") else None ) - gpu_id = args.gpu_id if hasattr(args, "gpu_id") else None # no no need to set gpu_id + gpu_id = args.gpu_id if hasattr(args, "gpu_id") else None # no need to set gpu_id else: gpu_mapping_file = None gpu_mapping_key = None gpu_id = None logging.info( - "devide_type = {}, gpu_mapping_file = {}, " + "device_type = {}, gpu_mapping_file = {}, " "gpu_mapping_key = {}, gpu_id = {}".format( device_type, gpu_mapping_file, gpu_mapping_key, gpu_id ) @@ -138,7 +162,7 @@ def get_device(args): if args.enable_cuda_rpc and is_master_process: assert ( device.index == args.cuda_rpc_gpu_mapping[args.rank] - ), f"GPU assignemnt inconsistent with cuda_rpc_gpu_mapping. Assigned to GPU {device.index} while expecting {args.cuda_rpc_gpu_mapping[args.rank]}" + ), f"GPU assignment inconsistent with cuda_rpc_gpu_mapping. Assigned to GPU {device.index} while expecting {args.cuda_rpc_gpu_mapping[args.rank]}" return device elif args.training_type == "cross_device": diff --git a/python/fedml/device/gpu_mapping_cross_silo.py b/python/fedml/device/gpu_mapping_cross_silo.py index f1fd8f9948..4c9c46c6ad 100644 --- a/python/fedml/device/gpu_mapping_cross_silo.py +++ b/python/fedml/device/gpu_mapping_cross_silo.py @@ -10,6 +10,27 @@ def mapping_processes_to_gpu_device_from_yaml_file_cross_silo( process_id, worker_number, gpu_util_file, gpu_util_key, device_type, scenario, gpu_id=None, args=None ): + """ + Map processes to GPU devices based on GPU utilization information from a YAML file in a cross-silo setting. + + Args: + process_id (int): The ID of the current process. + worker_number (int): The total number of worker processes. + gpu_util_file (str): The path to the GPU utilization YAML file. + gpu_util_key (str): The key to retrieve GPU utilization information from the YAML file. + device_type (str): The type of device to use (e.g., "gpu" or "cpu"). + scenario (str): The cross-silo training scenario (e.g., hierarchical or non-hierarchical). + gpu_id (int, optional): The GPU ID to use for the current process. Defaults to None. + args (object, optional): An object containing additional arguments (e.g., device settings). + + Returns: + str: The GPU or CPU device assigned to the current process. + + Raises: + Exception: If there is an issue with GPU device mapping, such as exceeding PyTorch DDP limits. + AssertionError: If the number of mapped processes does not match the worker number. 
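+
+    Example (illustrative; the YAML key, scenario string, and file name below are
+    placeholders):
+        >>> device = mapping_processes_to_gpu_device_from_yaml_file_cross_silo(
+        ...     0, 2, "gpu_mapping.yaml", "mapping_default", "gpu", "horizontal")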
+ + """ if device_type != "gpu": args.using_gpu = False device = ml_engine_adapter.get_device(args, device_id=gpu_id, device_type=device_type) @@ -27,8 +48,7 @@ def mapping_processes_to_gpu_device_from_yaml_file_cross_silo( with open(gpu_util_file, "r") as f: gpu_util_yaml = yaml.load(f, Loader=yaml.FullLoader) - # gpu_util_num_process = 'gpu_util_' + str(worker_number) - # gpu_util = gpu_util_yaml[gpu_util_num_process] + gpu_util = gpu_util_yaml[gpu_util_key] logging.info("gpu_util = {}".format(gpu_util)) gpu_util_map = {} @@ -38,7 +58,7 @@ def mapping_processes_to_gpu_device_from_yaml_file_cross_silo( # validate DDP gpu mapping if unique_gpu and num_process_on_gpu > 1: raise Exception( - "Cannot put {num_process_on_gpu} processes on GPU {gpu_j} of {host}." + f"Cannot put {num_process_on_gpu} processes on GPU {gpu_j} of {host}. " "PyTorch DDP supports up to one process on each GPU." ) for _ in range(num_process_on_gpu): @@ -57,3 +77,4 @@ def mapping_processes_to_gpu_device_from_yaml_file_cross_silo( logging.info("process_id = {}, GPU device = {}".format(process_id, device)) return device + \ No newline at end of file diff --git a/python/fedml/device/gpu_mapping_mpi.py b/python/fedml/device/gpu_mapping_mpi.py index 5bff4c3e26..568f790666 100644 --- a/python/fedml/device/gpu_mapping_mpi.py +++ b/python/fedml/device/gpu_mapping_mpi.py @@ -9,6 +9,23 @@ def mapping_processes_to_gpu_device_from_yaml_file_mpi( process_id, worker_number, gpu_util_file, gpu_util_key, args=None ): + """ + Map processes to GPU devices based on GPU utilization information from a YAML file. + + Args: + process_id (int): The ID of the current process. + worker_number (int): The total number of worker processes. + gpu_util_file (str): The path to the GPU utilization YAML file. + gpu_util_key (str): The key to retrieve GPU utilization information from the YAML file. + args (object, optional): An object containing additional arguments (e.g., device settings). + + Returns: + str: The GPU device assigned to the current process. + + Raises: + AssertionError: If the number of mapped processes does not match the worker number. + + """ if gpu_util_file is None: logging.info(" !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") logging.info(" ################## You do not indicate gpu_util_file, will use CPU training #################") @@ -16,10 +33,10 @@ def mapping_processes_to_gpu_device_from_yaml_file_mpi( logging.info(device) return device else: + # Load GPU utilization information from the YAML file with open(gpu_util_file, "r") as f: gpu_util_yaml = yaml.load(f, Loader=yaml.FullLoader) - # gpu_util_num_process = 'gpu_util_' + str(worker_number) - # gpu_util = gpu_util_yaml[gpu_util_num_process] + gpu_util = gpu_util_yaml[gpu_util_key] logging.info("gpu_util = {}".format(gpu_util)) gpu_util_map = {} @@ -43,15 +60,31 @@ def mapping_processes_to_gpu_device_from_yaml_file_mpi( def mapping_processes_to_gpu_device_from_gpu_util_parse(process_id, worker_number, gpu_util_parse, args=None): - if gpu_util_parse == None: + """ + Map processes to GPU devices based on parsed GPU utilization information. + + Args: + process_id (int): The ID of the current process. + worker_number (int): The total number of worker processes. + gpu_util_parse (str): The parsed GPU utilization information in string format. + args (object, optional): An object containing additional arguments (e.g., device settings). + + Returns: + str: The GPU device assigned to the current process. 
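+
+    Example (illustrative; two processes mapped onto host "gpu1", one per GPU):
+        >>> device = mapping_processes_to_gpu_device_from_gpu_util_parse(
+        ...     0, 2, "gpu1:1,1")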
+ + Raises: + AssertionError: If the number of mapped processes does not match the worker number. + + """ + if gpu_util_parse is None: device = ml_engine_adapter.get_device(args, device_type="cpu") logging.info(" !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") logging.info(" ################## Not Indicate gpu_util_file, using cpu #################") logging.info(device) - # return gpu_util_map[process_id][1] + return device else: - # example parse str `gpu_util_parse`: + # Example parse str `gpu_util_parse`: # "gpu1:0,1,1,2;gpu2:3,3,3;gpu3:0,0,0,1,2,4,4,0" gpu_util_parse_temp = gpu_util_parse.split(";") gpu_util_parse_temp = [(item.split(":")[0], item.split(":")[1]) for item in gpu_util_parse_temp] @@ -68,7 +101,7 @@ def mapping_processes_to_gpu_device_from_gpu_util_parse(process_id, worker_numbe gpu_util_map[i] = (host, gpu_j) i += 1 logging.info( - "Process %d running on host: %s,gethostname: %s, gpu: %d ..." + "Process %d running on host: %s, gethostname: %s, gpu: %d ..." % (process_id, gpu_util_map[process_id][0], socket.gethostname(), gpu_util_map[process_id][1]) ) assert i == worker_number diff --git a/python/fedml/device/ip_config_utils.py b/python/fedml/device/ip_config_utils.py index 1ebedfd73a..fad2cf126a 100644 --- a/python/fedml/device/ip_config_utils.py +++ b/python/fedml/device/ip_config_utils.py @@ -2,6 +2,15 @@ def build_ip_table(path): + """ + Build an IP table from a CSV file containing receiver IDs and their corresponding IP addresses. + + Args: + path (str): The path to the CSV file. + + Returns: + dict: A dictionary mapping receiver IDs to their respective IP addresses. + """ ip_config = dict() with open(path, newline="") as csv_file: csv_reader = csv.reader(csv_file) diff --git a/python/fedml/fa/__init__.py b/python/fedml/fa/__init__.py index e216fe2d20..56fd61b7cc 100644 --- a/python/fedml/fa/__init__.py +++ b/python/fedml/fa/__init__.py @@ -4,8 +4,25 @@ from .. import load_arguments, run_simulation, FEDML_TRAINING_PLATFORM_SIMULATION, FEDML_TRAINING_PLATFORM_CROSS_SILO, \ collect_env, mlops +from .runner import FARunner + +__all__ = [ + "FARunner", + "run_simulation", + "init" +] + def init(args=None): + """ + Initialize FedML Engine. + + Args: + args (object, optional): Arguments for initialization. If None, load default arguments. + + Returns: + object: Initialized arguments. + """ print(f"args={args}") if args is None: args = load_arguments(training_type=None, comm_backend=None) @@ -31,6 +48,12 @@ def init(args=None): return args def manage_mpi_args(args): + """ + Manage MPI-related arguments. + + Args: + args (object): Initialized arguments. + """ if hasattr(args, "backend") and args.backend == "MPI": from mpi4py import MPI @@ -48,6 +71,15 @@ def manage_mpi_args(args): args.comm = None def init_cross_silo(args): + """ + Initialize arguments for cross-silo training. + + Args: + args (object): Initialized arguments. + + Returns: + object: Updated arguments. + """ manage_mpi_args(args) # Set intra-silo arguments @@ -82,13 +114,13 @@ def init_cross_silo(args): def init_simulation_sp(args): - return args + """ + Initialize arguments for simulation with SP backend. + Args: + args (object): Initialized arguments. -from .runner import FARunner - -__all__ = [ - "FARunner", - "run_simulation", - "init" -] + Returns: + object: Updated arguments. 
+ """ + return args diff --git a/python/fedml/fa/aggregator/avg_aggregator.py b/python/fedml/fa/aggregator/avg_aggregator.py index a7493a42d1..db829d0ac0 100644 --- a/python/fedml/fa/aggregator/avg_aggregator.py +++ b/python/fedml/fa/aggregator/avg_aggregator.py @@ -3,12 +3,45 @@ class AVGAggregatorFA(FAServerAggregator): + """ + Aggregator for Federated Learning with Averaging. + + Args: + args (object): An object containing aggregator configuration parameters. + + Attributes: + total_sample_num (int): The total number of training samples aggregated. + server_data (float): The aggregated server data. + + Methods: + aggregate(local_submission_list): + Aggregate local submissions from clients and compute the weighted average. + + """ def __init__(self, args): + """ + Initialize the AVGAggregatorFA. + + Args: + args (object): An object containing aggregator configuration parameters. + + Returns: + None + """ super().__init__(args) self.total_sample_num = 0 self.set_server_data(server_data=0) def aggregate(self, local_submission_list: List[Tuple[float, Any]]): + """ + Aggregate local submissions from clients and compute the weighted average. + + Args: + local_submission_list (list): A list of tuples containing local sample number and local submissions. + + Returns: + float: The computed weighted average. + """ print(f"local_submission_list={local_submission_list}") training_num = 0 for idx in range(len(local_submission_list)): @@ -30,6 +63,7 @@ def aggregate(self, local_submission_list: List[Tuple[float, Any]]): return avg + """ todo: Mode 1: (online mode) each client stores its AVG result and the total number of data being sampled so far; later computation will use this result. diff --git a/python/fedml/fa/aggregator/frequency_estimation_aggregator.py b/python/fedml/fa/aggregator/frequency_estimation_aggregator.py index 1e7c1bbeed..1325c87c52 100644 --- a/python/fedml/fa/aggregator/frequency_estimation_aggregator.py +++ b/python/fedml/fa/aggregator/frequency_estimation_aggregator.py @@ -4,14 +4,51 @@ class FrequencyEstimationAggregatorFA(FAServerAggregator): + """ + Aggregator for Federated Learning with Frequency Estimation. + + Args: + args (object): An object containing aggregator configuration parameters. + + Attributes: + total_sample_num (int): The total number of training samples aggregated. + server_data (dict): Dictionary to store aggregated data. + round_idx (int): The current training round index. + total_round (int): The total number of training rounds. + + Methods: + aggregate(local_submission_list): + Aggregate local submissions from clients. + print_frequency_estimation_results(): + Print and display frequency estimation results as a histogram. + + """ def __init__(self, args): + """ + Initialize the FrequencyEstimationAggregatorFA. + + Args: + args (object): An object containing aggregator configuration parameters. + + Returns: + None + """ super().__init__(args) self.total_sample_num = 0 - self.set_server_data(server_data=[]) + self.set_server_data(server_data={}) self.round_idx = 0 self.total_round = args.comm_round def aggregate(self, local_submission_list: List[Tuple[float, Any]]): + """ + Aggregate local submissions from clients. + + Args: + local_submission_list (list): A list of tuples containing local sample number and local submissions. + + Returns: + dict: The aggregated server data. 
+ """ training_num = 0 (sample_num, averaged_params) = local_submission_list[0] for i in range(0, len(local_submission_list)): @@ -33,6 +70,12 @@ def aggregate(self, local_submission_list: List[Tuple[float, Any]]): return self.server_data def print_frequency_estimation_results(self): + """ + Print and display frequency estimation results as a histogram. + + Returns: + None + """ print("frequency estimation: ") for key in self.server_data: print(f"key = {key}, freq = {self.server_data[key] / self.total_sample_num}") @@ -41,3 +84,4 @@ def print_frequency_estimation_results(self): plt.ylabel('Occurrence # ') plt.title('Histogram') plt.show() + \ No newline at end of file diff --git a/python/fedml/fa/aggregator/global_analyzer_creator.py b/python/fedml/fa/aggregator/global_analyzer_creator.py index 142e8b14b6..55a6cc9898 100644 --- a/python/fedml/fa/aggregator/global_analyzer_creator.py +++ b/python/fedml/fa/aggregator/global_analyzer_creator.py @@ -9,6 +9,16 @@ def create_global_analyzer(args, train_data_num): + """ + Create a global analyzer based on the specified federated aggregation task. + + Args: + args: Additional arguments for creating the global analyzer. + train_data_num (int): The number of training data samples. + + Returns: + FAServerAggregator: An instance of a global analyzer based on the specified task. + """ task_type = args.fa_task if task_type == FA_TASK_AVG: return AVGAggregatorFA(args) diff --git a/python/fedml/fa/aggregator/heavy_hitter_triehh_aggregator.py b/python/fedml/fa/aggregator/heavy_hitter_triehh_aggregator.py index 6cfbdc1764..d228075390 100644 --- a/python/fedml/fa/aggregator/heavy_hitter_triehh_aggregator.py +++ b/python/fedml/fa/aggregator/heavy_hitter_triehh_aggregator.py @@ -12,6 +12,16 @@ class HeavyHitterTriehhAggregatorFA(FAServerAggregator): def __init__(self, args, train_data_num): + """ + Initialize the HeavyHitterTriehhAggregatorFA. + + Args: + args: Additional arguments for initialization. + train_data_num (int): The number of training data samples. + + Returns: + None + """ super().__init__(args) if hasattr(args, "max_word_len"): self.MAX_L = args.max_word_len @@ -43,7 +53,7 @@ def __init__(self, args, train_data_num): self.batch_size = int(train_data_num * (np.e ** (self.epsilon / self.MAX_L) - 1) / ( self.theta * np.e ** (self.epsilon / self.MAX_L))) self.init_msg = int(math.ceil(self.batch_size * 1.0 / args.client_num_per_round)) - self.w_global = {} # self.trie = {} + self.w_global = {} def get_init_msg(self): return self.init_msg @@ -52,6 +62,15 @@ def set_init_msg(self, init_msg): self.init_msg = init_msg def aggregate(self, local_submission_list: List[Tuple[float, Any]]): + """ + Aggregate local submissions. + + Args: + local_submission_list (List[Tuple[float, Any]]): A list of local submissions. + + Returns: + Dict: The aggregated data. + """ votes = {} for (num, local_vote_dict) in local_submission_list: for key in local_vote_dict.keys(): @@ -70,6 +89,12 @@ def aggregate(self, local_submission_list: List[Tuple[float, Any]]): return self.w_global def _set_theta(self): + """ + Calculate and set the value of theta. + + Returns: + int: The calculated theta value. 
+ """ theta = 5 # initial guess delta_inverse = 1 / self.delta while ((theta - 3) / (theta - 2)) * math.factorial(theta) < delta_inverse: @@ -80,11 +105,15 @@ def _set_theta(self): return theta def server_update(self, votes): - # It might make more sense to define a small class called server_state - # server_state can track 2 things: 1) updated trie, and 2) quit_sign - # server_state can be initialized in the constructor of SimulateTrieHH - # and server_update would just update server_state - # (i.e, it would update self.server_state.trie & self.server_state.quit_sign) + """ + Update the server based on received votes. + + Args: + votes (Dict): A dictionary of votes. + + Returns: + None + """ self.quit_sign = True for prefix in votes: if votes[prefix] >= self.theta: @@ -92,11 +121,17 @@ def server_update(self, votes): self.quit_sign = False def print_heavy_hitters(self): + """ + Print the discovered heavy hitters. + + Returns: + None + """ heavy_hitters = [] print(f"self.w_global = {self.w_global}") raw_result = self.w_global.keys() for word in raw_result: if word[-1:] == '$': heavy_hitters.append(word.rstrip('$')) - # print(f'Discovered {len(heavy_hitters)} heavy hitters in run #{self.round_counter + 1}: {heavy_hitters}') + print(f'Discovered {len(heavy_hitters)} heavy hitters: {heavy_hitters}') diff --git a/python/fedml/fa/aggregator/intersection_aggregator.py b/python/fedml/fa/aggregator/intersection_aggregator.py index c7f0b1e559..c85a24b191 100644 --- a/python/fedml/fa/aggregator/intersection_aggregator.py +++ b/python/fedml/fa/aggregator/intersection_aggregator.py @@ -3,48 +3,77 @@ from fedml.fa.base_frame.server_aggregator import FAServerAggregator -def get_intersection_of_two_lists_keep_duplicates(list1, list2): +def get_intersection_of_two_lists_keep_duplicates(list1: List[Any], list2: List[Any]) -> List[Any]: """ - Keep duplicates in the intersection, e.g., list1=[1,2,3,2,3], list2=[2,3,2,3]. intersect(list1, list2) = [2,3,2,3] - :param list1: first list - :param list2: second list - :return: intersection of the 2 lists + Return the intersection of two lists while keeping duplicates. + + Args: + list1 (List): The first list. + list2 (List): The second list. + + Returns: + List: The intersection of the two lists, keeping duplicates. """ intersection = [] - for i in range(len(list1)): - for j in range(len(list2) - 1, -1, -1): - if list1[i] == list2[j]: - intersection.append(list2[j]) - list2.remove(j) + for item in list1: + if item in list2: + intersection.append(item) + list2.remove(item) return intersection -def get_intersection_of_two_lists_remove_duplicates(list1, list2): +def get_intersection_of_two_lists_remove_duplicates(list1: List[Any], list2: List[Any]) -> List[Any]: """ - Remove duplicates in the intersection, e.g., list1=[1,2,3,2,3], list2=[2,3,2,3]. intersect(list1, list2) = [2,3] - :param list1: first list - :param list2: second list - :return: intersection of the 2 lists + Return the intersection of two lists and remove duplicate values. + + Args: + list1 (List): The first list. + list2 (List): The second list. + + Returns: + List: The intersection of the two lists with duplicates removed. """ return list(set(list1) & set(list2)) class IntersectionAggregatorFA(FAServerAggregator): def __init__(self, args): + """ + Initialize the IntersectionAggregatorFA. + + Args: + args: Additional arguments for initialization. 
+ + Returns: + None + """ super().__init__(args) self.set_server_data(server_data=[]) - def aggregate(self, local_submission_list: List[Tuple[float, Any]]): - for i in range(0, len(local_submission_list)): - _, local_submission = local_submission_list[i] + def aggregate(self, local_submission_list: List[Tuple[float, Any]]) -> List[Any]: + """ + Aggregate local submissions while maintaining intersection. + + Args: + local_submission_list (List[Tuple[float, Any]]): A list of local submissions. + + Returns: + List: The intersection of local submissions. + """ + for _, local_submission in local_submission_list: if len(self.server_data) == 0: - # no need to remove duplicates even in ``remove duplicate'' mode, - # as the duplicates will be removed in later computation + self.server_data = local_submission else: self.server_data = get_intersection_of_two_lists_remove_duplicates(self.server_data, local_submission) print(f"cardinality = {self.get_cardinality()}") return self.server_data - def get_cardinality(self): + def get_cardinality(self) -> int: + """ + Get the cardinality (number of elements) of the aggregated data. + + Returns: + int: The cardinality of the aggregated data. + """ return len(self.server_data) diff --git a/python/fedml/fa/aggregator/k_percentile_element_aggregator.py b/python/fedml/fa/aggregator/k_percentile_element_aggregator.py index 5b0eae456f..d0415716a6 100644 --- a/python/fedml/fa/aggregator/k_percentile_element_aggregator.py +++ b/python/fedml/fa/aggregator/k_percentile_element_aggregator.py @@ -17,6 +17,16 @@ class KPercentileElementAggregatorFA(FAServerAggregator): def __init__(self, args, train_data_num): + """ + Initialize the KPercentileElementAggregatorFA. + + Args: + args: Configuration arguments. + train_data_num (int): The total number of training data samples. + + Returns: + None + """ super().__init__(args) self.total_sample_num = 0 self.set_server_data(server_data=[]) @@ -24,46 +34,63 @@ def __init__(self, args, train_data_num): self.total_sample_num = 0 self.train_data_num_in_total = train_data_num self.percentage = args.k / 100 + + # Initialize server_data and previous_server_data if hasattr(args, "flag"): self.server_data = args.flag self.previous_server_data = args.flag else: self.server_data = 100 self.previous_server_data = 100 + + # Check if use_all_data attribute is specified in args if hasattr(args, "use_all_data") and args.use_all_data in [False]: - self.use_all_data = False # in each iteration, each client randomly sample some data to compute + self.use_all_data = False # In each iteration, each client randomly samples some data to compute else: - self.use_all_data = True # in each iteration, each client uses its all local data to compute + self.use_all_data = True # In each iteration, each client uses all its local data to compute + + # Initialize max_val and min_val self.max_val = self.previous_server_data self.min_val = self.previous_server_data def aggregate(self, local_submission_list: List[Tuple[float, Any]]): + """ + Aggregate local submissions from clients. + + Args: + local_submission_list (List[Tuple[float, Any]]): A list of tuples containing local submissions and weights. + + Returns: + float: The aggregated result. 
+ """ if self.quit: return self.server_data + total_sample_num_this_round = 0 local_satisfied_data_num_current_round = 0 - logging.info(f"flag={self.server_data}, local_submission_list={local_submission_list}") + for (sample_num, satisfied_counter) in local_submission_list: total_sample_num_this_round += sample_num local_satisfied_data_num_current_round += satisfied_counter + if total_sample_num_this_round * self.percentage == local_satisfied_data_num_current_round: self.quit = True self.previous_server_data = self.server_data elif total_sample_num_this_round * self.percentage > local_satisfied_data_num_current_round: - # decrease server_data + # Decrease server_data self.max_val = self.server_data if self.previous_server_data >= self.server_data: self.previous_server_data = self.server_data if self.server_data / 2 < self.min_val < self.max_val: - self.server_data = (self.server_data + self.min_val)/2 + self.server_data = (self.server_data + self.min_val) / 2 else: self.server_data = self.server_data / 2 - self.min_val = self.server_data # set lower bound for flag + self.min_val = self.server_data # Set lower bound for flag else: new_server_data = (self.previous_server_data + self.server_data) / 2 self.previous_server_data = self.server_data self.server_data = new_server_data - else: # increase server_data + else: # Increase server_data self.min_val = self.server_data if self.previous_server_data <= self.server_data: self.previous_server_data = self.server_data @@ -76,4 +103,6 @@ def aggregate(self, local_submission_list: List[Tuple[float, Any]]): new_server_data = (self.previous_server_data + self.server_data) / 2 self.previous_server_data = self.server_data self.server_data = new_server_data + return self.server_data + \ No newline at end of file diff --git a/python/fedml/fa/aggregator/union_aggregator.py b/python/fedml/fa/aggregator/union_aggregator.py index f3e3baf730..4a15b2f86a 100644 --- a/python/fedml/fa/aggregator/union_aggregator.py +++ b/python/fedml/fa/aggregator/union_aggregator.py @@ -4,39 +4,65 @@ def get_union_of_two_lists_keep_duplicates(list1, list2): """ - Keep duplicates in the union, e.g., list1=[1,2,3,2,3], list2=[2,3,2,3]. intersect(list1, list2) = [1,2,3,2,3] - :param list1: first list - :param list2: second list - :return: intersection of the 2 lists + Compute the union of two lists while keeping duplicates. + + Args: + list1 (List): The first list. + list2 (List): The second list. + + Returns: + List: The union of the two lists with duplicates. """ union = [] for item in list1: union.append(item) if item in list2: - list2.remove(list2.index(item)) + list2.remove(item) union.extend(list2) return union def get_union_of_two_lists_remove_duplicates(list1, list2): """ - Remove duplicates in the union, e.g., list1=[1,2,3,2,3], list2=[2,3,2,3]. intersect(list1, list2) = [1,2,3] - :param list1: first list - :param list2: second list - :return: intersection of the 2 lists + Compute the union of two lists and remove duplicates. + + Args: + list1 (List): The first list. + list2 (List): The second list. + + Returns: + List: The union of the two lists without duplicates. """ return list(set(list1 + list2)) class UnionAggregatorFA(FAServerAggregator): def __init__(self, args): + """ + Initialize the UnionAggregatorFA. + + Args: + args: Configuration arguments. 
+ + Returns: + None + """ super().__init__(args) self.set_server_data(server_data=[]) - self.union_function = get_union_of_two_lists_remove_duplicates # select the way to compute union + self.union_function = get_union_of_two_lists_remove_duplicates # Select the way to compute union def aggregate(self, local_submission_list: List[Tuple[float, Any]]): + """ + Aggregate local submissions from clients. + + Args: + local_submission_list (List[Tuple[float, Any]]): A list of tuples containing local submissions and weights. + + Returns: + List: The aggregated result. + """ for i in range(0, len(local_submission_list)): _, local_submission = local_submission_list[i] - # when server_data is [], i.e., the first round, will only process local_submission + # When server_data is [], i.e., the first round, will only process local_submission self.server_data = self.union_function(self.server_data, local_submission) return self.server_data diff --git a/python/fedml/fa/base_frame/client_analyzer.py b/python/fedml/fa/base_frame/client_analyzer.py index 624f7c290a..9fec99924e 100644 --- a/python/fedml/fa/base_frame/client_analyzer.py +++ b/python/fedml/fa/base_frame/client_analyzer.py @@ -4,6 +4,15 @@ class FAClientAnalyzer(ABC): def __init__(self, args): + """ + Initialize the client analyzer. + + Args: + args: Configuration arguments. + + Returns: + None + """ self.client_submission = 0 self.id = 0 self.args = args @@ -12,30 +21,104 @@ def __init__(self, args): self.init_msg = None def set_init_msg(self, init_msg): + """ + Set the initialization message. + + Args: + init_msg: The initialization message. + + Returns: + None + """ pass def get_init_msg(self): + """ + Get the initialization message. + + Returns: + Any: The initialization message. + """ pass def set_id(self, analyzer_id): + """ + Set the ID of the client analyzer. + + Args: + analyzer_id: The ID of the analyzer. + + Returns: + None + """ self.id = analyzer_id def get_client_submission(self): + """ + Get the client submission. + + Returns: + Any: The client submission. + """ return self.client_submission def set_client_submission(self, client_submission): + """ + Set the client submission. + + Args: + client_submission: The client submission. + + Returns: + None + """ self.client_submission = client_submission def get_server_data(self): + """ + Get the server data. + + Returns: + Any: The server data. + """ return self.server_data def set_server_data(self, server_data): + """ + Set the server data. + + Args: + server_data: The server data. + + Returns: + None + """ self.server_data = server_data @abstractmethod def local_analyze(self, train_data, args): + """ + Perform local analysis. + + Args: + train_data: The local training data. + args: Configuration arguments. + + Returns: + None + """ pass def update_dataset(self, local_train_dataset, local_sample_number): + """ + Update the local dataset. + + Args: + local_train_dataset: The local training dataset. + local_sample_number: The number of local samples. + + Returns: + None + """ self.local_train_dataset = local_train_dataset self.local_sample_number = local_sample_number diff --git a/python/fedml/fa/base_frame/server_aggregator.py b/python/fedml/fa/base_frame/server_aggregator.py index 76fc1a73bc..5ad46dd6a8 100644 --- a/python/fedml/fa/base_frame/server_aggregator.py +++ b/python/fedml/fa/base_frame/server_aggregator.py @@ -1,9 +1,17 @@ from abc import ABC from typing import List, Tuple, Any - class FAServerAggregator(ABC): def __init__(self, args): + """ + Initialize the server aggregator. 
+ + Args: + args: Configuration arguments. + + Returns: + None + """ self.id = 0 self.args = args self.eval_data = None @@ -11,21 +19,67 @@ def __init__(self, args): self.init_msg = None def get_init_msg(self): - # return self.init_msg + """ + Get the initialization message. + + Returns: + Any: The initialization message. + """ pass def set_init_msg(self, init_msg): - # self.init_msg = init_msg + """ + Set the initialization message. + + Args: + init_msg: The initialization message. + + Returns: + None + """ pass def set_id(self, aggregator_id): + """ + Set the ID of the server aggregator. + + Args: + aggregator_id: The ID of the aggregator. + + Returns: + None + """ self.id = aggregator_id def get_server_data(self): + """ + Get the server data. + + Returns: + Any: The server data. + """ return self.server_data def set_server_data(self, server_data): + """ + Set the server data. + + Args: + server_data: The server data. + + Returns: + None + """ self.server_data = server_data def aggregate(self, local_submissions: List[Tuple[float, Any]]): + """ + Aggregate local submissions from clients. + + Args: + local_submissions (List[Tuple[float, Any]]): A list of tuples containing local submissions and weights. + + Returns: + None + """ pass diff --git a/python/fedml/fa/cross_silo/client/client_initializer.py b/python/fedml/fa/cross_silo/client/client_initializer.py index 167d5e0a5a..9d4b21fd8f 100644 --- a/python/fedml/fa/cross_silo/client/client_initializer.py +++ b/python/fedml/fa/cross_silo/client/client_initializer.py @@ -13,6 +13,22 @@ def init_client( train_data_local_dict, local_analyzer=None, ): + """ + Initialize the federated learning client. + + Args: + args: Configuration arguments. + comm: Communication object. + client_rank (int): The rank of the client. + client_num (int): The total number of clients. + train_data_num (int): The total number of training data samples. + train_data_local_num_dict (dict): A dictionary mapping client indices to the number of local training samples. + train_data_local_dict (dict): A dictionary mapping client indices to their local training data. + local_analyzer: Local analyzer for the client (optional). + + Returns: + None + """ backend = args.backend trainer_dist_adapter = get_trainer_dist_adapter( @@ -38,6 +54,20 @@ def get_trainer_dist_adapter( train_data_local_dict, local_analyzer, ): + """ + Get the trainer distribution adapter. + + Args: + args: Configuration arguments. + client_rank (int): The rank of the client. + train_data_num (int): The total number of training data samples. + train_data_local_num_dict (dict): A dictionary mapping client indices to the number of local training samples. + train_data_local_dict (dict): A dictionary mapping client indices to their local training data. + local_analyzer: Local analyzer for the client. + + Returns: + TrainerDistAdapter: The trainer distribution adapter. + """ return TrainerDistAdapter( args, client_rank, @@ -49,10 +79,34 @@ def get_trainer_dist_adapter( def get_client_manager_master(args, trainer_dist_adapter, comm, client_rank, client_num, backend): + """ + Get the client master manager. + + Args: + args: Configuration arguments. + trainer_dist_adapter: Trainer distribution adapter. + comm: Communication object. + client_rank (int): The rank of the client. + client_num (int): The total number of clients. + backend: Backend for distributed training. + + Returns: + ClientMasterManager: The client master manager. 
+ """ return ClientMasterManager(args, trainer_dist_adapter, comm, client_rank, client_num, backend) def get_client_manager_salve(args, trainer_dist_adapter): + """ + Get the client slave manager. + + Args: + args: Configuration arguments. + trainer_dist_adapter: Trainer distribution adapter. + + Returns: + ClientSlaveManager: The client slave manager. + """ from .fedml_client_slave_manager import ClientSlaveManager return ClientSlaveManager(args, trainer_dist_adapter) diff --git a/python/fedml/fa/cross_silo/client/client_launcher.py b/python/fedml/fa/cross_silo/client/client_launcher.py index 0f3ea1f278..66c1045124 100644 --- a/python/fedml/fa/cross_silo/client/client_launcher.py +++ b/python/fedml/fa/cross_silo/client/client_launcher.py @@ -2,16 +2,47 @@ from fedml.arguments import load_arguments from fedml.constants import FEDML_TRAINING_PLATFORM_CROSS_SILO - class CrossSiloLauncher: + """ + A class for launching distributed trainers in a cross-silo federated learning setup. + + Attributes: + None + + Methods: + launch_dist_trainers(torch_client_filename, inputs): + Launch distributed trainers using the provided arguments. + + """ @staticmethod def launch_dist_trainers(torch_client_filename, inputs): - # this is only used by the client (DDP or single process), so there is no need to specify the backend. + """ + Launch distributed trainers using the provided arguments. + + Args: + torch_client_filename (str): The filename of the PyTorch client script. + inputs (list): A list of input arguments to be passed to the client script. + + Returns: + None + """ + # This is only used by the client (DDP or single process), so there is no need to specify the backend. args = load_arguments(FEDML_TRAINING_PLATFORM_CROSS_SILO) CrossSiloLauncher._run_cross_silo_horizontal(args, torch_client_filename, inputs) @staticmethod def _run_cross_silo_horizontal(args, torch_client_filename, inputs): + """ + Run the cross-silo horizontal federated learning process. + + Args: + args: Configuration arguments. + torch_client_filename (str): The filename of the PyTorch client script. + inputs (list): A list of input arguments to be passed to the client script. + + Returns: + None + """ python_path = subprocess.run(["which", "python"], capture_output=True, text=True).stdout.strip() process_arguments = [python_path, torch_client_filename] + inputs - subprocess.run(process_arguments) \ No newline at end of file + subprocess.run(process_arguments) diff --git a/python/fedml/fa/cross_silo/client/fa_local_analyzer.py b/python/fedml/fa/cross_silo/client/fa_local_analyzer.py index 0ec0c89dc9..9aeddc8365 100755 --- a/python/fedml/fa/cross_silo/client/fa_local_analyzer.py +++ b/python/fedml/fa/cross_silo/client/fa_local_analyzer.py @@ -3,6 +3,49 @@ class FALocalAnalyzer(object): + """ + A class representing a local analyzer for federated learning. + + Args: + client_index (int): The index of the client. + train_data_local_dict (dict): A dictionary containing local training data. + train_data_local_num_dict (dict): A dictionary containing the number of local training samples. + train_data_num (int): The total number of training samples. + args: Configuration arguments. + local_analyzer: An instance of the local analyzer. + + Attributes: + local_analyzer: An instance of the local analyzer. + client_index (int): The index of the client. + train_data_local_dict (dict): A dictionary containing local training data. + train_data_local_num_dict (dict): A dictionary containing the number of local training samples. 
+ all_train_data_num (int): The total number of training samples. + train_local: Local training data for the client. + local_sample_number: The number of local training samples for the client. + test_local: Local testing data for the client. + args: Configuration arguments. + init_msg: Initialization message for the client. + + Methods: + set_init_msg(init_msg): + Set the initialization message for the client. + + get_init_msg(): + Get the initialization message for the client. + + set_server_data(server_data): + Set the server data for the client. + + set_client_submission(client_submission): + Set the client's submission. + + update_dataset(client_index): + Update the client's dataset based on the provided client index. + + local_analyze(round_idx=None): + Perform local analysis for federated learning. + + """ def __init__( self, client_index, @@ -12,6 +55,20 @@ def __init__( args, local_analyzer, ): + """ + Initialize the FALocalAnalyzer. + + Args: + client_index (int): The index of the client. + train_data_local_dict (dict): A dictionary containing local training data. + train_data_local_num_dict (dict): A dictionary containing the number of local training samples. + train_data_num (int): The total number of training samples. + args: Configuration arguments. + local_analyzer: An instance of the local analyzer. + + Returns: + None + """ self.local_analyzer = local_analyzer self.client_index = client_index self.train_data_local_dict = train_data_local_dict @@ -24,18 +81,60 @@ def __init__( self.init_msg = None def set_init_msg(self, init_msg): + """ + Set the initialization message for the client. + + Args: + init_msg: Initialization message for the client. + + Returns: + None + """ self.local_analyzer.set_init_msg(init_msg) def get_init_msg(self): + """ + Get the initialization message for the client. + + Returns: + Initialization message for the client. + """ return self.local_analyzer.get_init_msg() def set_server_data(self, server_data): + """ + Set the server data for the client. + + Args: + server_data: Server data for the client. + + Returns: + None + """ self.local_analyzer.set_server_data(server_data) def set_client_submission(self, client_submission): + """ + Set the client's submission. + + Args: + client_submission: Client's submission data. + + Returns: + None + """ self.local_analyzer.set_client_submission(client_submission) def update_dataset(self, client_index): + """ + Update the client's dataset based on the provided client index. + + Args: + client_index (int): The index of the client. + + Returns: + None + """ self.client_index = client_index if self.train_data_local_dict is not None: @@ -51,10 +150,19 @@ def update_dataset(self, client_index): self.local_analyzer.update_dataset(self.train_local, self.local_sample_number) def local_analyze(self, round_idx=None): + """ + Perform local analysis for federated learning. + + Args: + round_idx (int): The current round index (default is None). + + Returns: + Tuple containing client submission data and the number of local samples. 
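A hypothetical helper that mirrors how the trainer adapter drives one round of an FALocalAnalyzer, shown only to make the call order explicit; the function name is not part of the library:

def run_local_round(fa_local_analyzer, server_data, client_index, round_idx):
    # Push the latest global state, bind this round's data silo, then analyze.
    fa_local_analyzer.set_server_data(server_data)
    fa_local_analyzer.update_dataset(client_index)
    client_submission, local_sample_num = fa_local_analyzer.local_analyze(round_idx)
    return client_submission, local_sample_num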
+ """ self.args.round_idx = round_idx tick = time.time() self.local_analyzer.local_analyze(self.train_local, self.args) MLOpsProfilerEvent.log_to_wandb({"Train/Time": time.time() - tick, "round": round_idx}) client_submission = self.local_analyzer.get_client_submission() - return client_submission, self.local_sample_number \ No newline at end of file + return client_submission, self.local_sample_number diff --git a/python/fedml/fa/cross_silo/client/fedml_client_master_manager.py b/python/fedml/fa/cross_silo/client/fedml_client_master_manager.py index 1623114f36..0f56970226 100644 --- a/python/fedml/fa/cross_silo/client/fedml_client_master_manager.py +++ b/python/fedml/fa/cross_silo/client/fedml_client_master_manager.py @@ -11,7 +11,71 @@ class ClientMasterManager(FedMLCommManager): + """ + Manages the communication and training process for a federated learning client master. + + Args: + args (object): An object containing client configuration parameters. + trainer_dist_adapter: An instance of the trainer distribution adapter. + comm: A communication backend (default is None). + rank (int): The rank of the client (default is 0). + size (int): The size of the communication group (default is 0). + backend (str): The communication backend (default is "MPI"). + + Attributes: + trainer_dist_adapter: An instance of the trainer distribution adapter. + args (object): An object containing client configuration parameters. + num_rounds (int): The total number of communication rounds. + round_idx (int): The current communication round index. + rank (int): The rank of the client. + client_real_ids (list): A list of client real IDs. + client_real_id (str): The client's real ID. + has_sent_online_msg (bool): A flag indicating if the online message has been sent. + + Methods: + register_message_receive_handlers(): + Register message receive handlers for various message types. + handle_message_connection_ready(msg_params): + Handle the connection-ready message. + handle_message_check_status(msg_params): + Handle the check-client-status message. + handle_message_init(msg_params): + Handle the initialization message. + handle_message_receive_model_from_server(msg_params): + Handle the message to receive a model from the server. + handle_message_finish(msg_params): + Handle the message indicating the completion of training. + cleanup(): + Perform cleanup after training finishes. + send_model_to_server(receive_id, weights, local_sample_num): + Send the model and related information to the server. + send_client_status(receive_id, status="ONLINE"): + Send the client's status to the server. + report_training_status(status): + Report the training status to MLOps. + sync_process_group(round_idx, model_params=None, client_index=None, src=0): + Synchronize the process group with round information. + __train(): + Perform the training for the current round. + run(): + Start the client master manager's communication and training process. + + """ def __init__(self, args, trainer_dist_adapter, comm=None, rank=0, size=0, backend="MPI"): + """ + Initialize the ClientMasterManager. + + Args: + args (object): An object containing client configuration parameters. + trainer_dist_adapter: An instance of the trainer distribution adapter. + comm: A communication backend (default is None). + rank (int): The rank of the client (default is 0). + size (int): The size of the communication group (default is 0). + backend (str): The communication backend (default is "MPI"). 
+ + Returns: + None + """ super().__init__(args, comm, rank, size, backend) self.trainer_dist_adapter = trainer_dist_adapter self.args = args @@ -20,11 +84,17 @@ def __init__(self, args, trainer_dist_adapter, comm=None, rank=0, size=0, backen self.rank = rank self.client_real_ids = json.loads(args.client_id_list) logging.info("self.client_real_ids = {}".format(self.client_real_ids)) - # for the client, len(self.client_real_ids)==1: we only specify its client id in the list, not including others. + # For the client, len(self.client_real_ids)==1: we only specify its client id in the list, not including others. self.client_real_id = self.client_real_ids[0] self.has_sent_online_msg = False def register_message_receive_handlers(self): + """ + Register message receive handlers for various message types. + + Returns: + None + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_message_connection_ready ) @@ -43,6 +113,15 @@ def register_message_receive_handlers(self): ) def handle_message_connection_ready(self, msg_params): + """ + Handle the connection-ready message. + + Args: + msg_params (dict): Message parameters. + + Returns: + None + """ if not self.has_sent_online_msg: self.has_sent_online_msg = True self.send_client_status(0) @@ -50,9 +129,27 @@ def handle_message_connection_ready(self, msg_params): mlops.log_sys_perf(self.args) def handle_message_check_status(self, msg_params): + """ + Handle the check-client-status message. + + Args: + msg_params (dict): Message parameters. + + Returns: + None + """ self.send_client_status(0) def handle_message_init(self, msg_params): + """ + Handle the initialization message. + + Args: + msg_params (dict): Message parameters. + + Returns: + None + """ global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) data_silo_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) init_msg = msg_params.get(MyMessage.MSG_INIT_MSG_TO_CLIENTS) @@ -70,6 +167,15 @@ def handle_message_init(self, msg_params): self.round_idx += 1 def handle_message_receive_model_from_server(self, msg_params): + """ + Handle the message to receive a model from the server. + + Args: + msg_params (dict): Message parameters. + + Returns: + None + """ logging.info("handle_message_receive_model_from_server.") model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -82,14 +188,40 @@ def handle_message_receive_model_from_server(self, msg_params): self.round_idx += 1 def handle_message_finish(self, msg_params): + """ + Handle the message indicating the completion of training. + + Args: + msg_params (dict): Message parameters. + + Returns: + None + """ logging.info(" ====================cleanup ====================") self.cleanup() def cleanup(self): + """ + Perform cleanup after training finishes. + + Returns: + None + """ self.finish() mlops.log_training_finished_status() def send_model_to_server(self, receive_id, weights, local_sample_num): + """ + Send the model and related information to the server. + + Args: + receive_id (int): The ID of the receiver. + weights (object): Model weights or parameters. + local_sample_num (int): The number of local samples. 
+ + Returns: + None + """ tick = time.time() mlops.event("comm_c2s", event_started=True, event_value=str(self.round_idx)) message = Message(MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.client_real_id, receive_id,) @@ -103,6 +235,16 @@ def send_model_to_server(self, receive_id, weights, local_sample_num): ) def send_client_status(self, receive_id, status="ONLINE"): + """ + Send the client's status to the server. + + Args: + receive_id (int): The ID of the receiver. + status (str): The client's status (default is "ONLINE"). + + Returns: + None + """ logging.info("send_client_status") logging.info("self.client_real_id = {}".format(self.client_real_id)) message = Message(MyMessage.MSG_TYPE_C2S_CLIENT_STATUS, self.client_real_id, receive_id) @@ -117,9 +259,30 @@ def send_client_status(self, receive_id, status="ONLINE"): self.send_message(message) def report_training_status(self, status): + """ + Report the training status to MLOps. + + Args: + status (str): The training status to report. + + Returns: + None + """ mlops.log_training_status(status) def sync_process_group(self, round_idx, model_params=None, client_index=None, src=0): + """ + Synchronize the process group with round information. + + Args: + round_idx (int): The current round index. + model_params (object): Model weights or parameters (default is None). + client_index (int): The index of the client (default is None). + src (int): The source process rank (default is 0). + + Returns: + None + """ logging.info("sending round number to pg") round_number = [round_idx, model_params, client_index] dist.broadcast_object_list( @@ -128,6 +291,12 @@ def sync_process_group(self, round_idx, model_params=None, client_index=None, sr logging.info("round number %d broadcast to process group" % round_number[0]) def __train(self): + """ + Perform the training for the current round. + + Returns: + None + """ logging.info("#######training########### round_id = %d" % self.round_idx) mlops.event("train", event_started=True, event_value=str(self.round_idx)) @@ -139,4 +308,10 @@ def __train(self): self.send_model_to_server(0, client_submission, local_sample_num) def run(self): + """ + Start the client master manager's communication and training process. + + Returns: + None + """ super().run() diff --git a/python/fedml/fa/cross_silo/client/fedml_client_slave_manager.py b/python/fedml/fa/cross_silo/client/fedml_client_slave_manager.py index 48f30d8263..f6cc3af6e2 100644 --- a/python/fedml/fa/cross_silo/client/fedml_client_slave_manager.py +++ b/python/fedml/fa/cross_silo/client/fedml_client_slave_manager.py @@ -4,7 +4,42 @@ class ClientSlaveManager: + """ + Manages the training process for a federated learning client slave. + + Args: + args (object): An object containing client configuration parameters. + trainer_dist_adapter: An instance of the trainer distribution adapter. + + Attributes: + trainer_dist_adapter: An instance of the trainer distribution adapter. + args (object): An object containing client configuration parameters. + round_idx (int): The current training round index. + num_rounds (int): The total number of training rounds. + finished (bool): A flag indicating if training has finished. + + Methods: + train(): + Perform training for the current round. + finish(): + Finish the client slave's training. + await_sync_process_group(src=0): + Await synchronization with the process group and receive round information. + run(): + Start the client slave's training process. 
+ + """ def __init__(self, args, trainer_dist_adapter): + """ + Initialize the ClientSlaveManager. + + Args: + args (object): An object containing client configuration parameters. + trainer_dist_adapter: An instance of the trainer distribution adapter. + + Returns: + None + """ self.trainer_dist_adapter = trainer_dist_adapter self.args = args self.round_idx = 0 @@ -12,6 +47,12 @@ def __init__(self, args, trainer_dist_adapter): self.finished = False def train(self): + """ + Perform training for the current round. + + Returns: + None + """ [round_idx, model_params, client_index] = self.await_sync_process_group() if round_idx: self.round_idx = round_idx @@ -28,7 +69,12 @@ def train(self): self.trainer_dist_adapter.train(self.round_idx) def finish(self): - # pass + """ + Finish the client slave's training. + + Returns: + None + """ self.trainer_dist_adapter.cleanup_pg() logging.info( "Training finished for slave client rank %s in silo %s" @@ -37,6 +83,15 @@ def finish(self): self.finished = True def await_sync_process_group(self, src=0): + """ + Await synchronization with the process group and receive round information. + + Args: + src (int): The source process rank to receive data from (default is 0). + + Returns: + list: A list containing round index, model parameters, and client index. + """ logging.info("process %d waiting for round number" % dist.get_rank()) objects = [None, None, None] dist.broadcast_object_list( @@ -46,5 +101,11 @@ def await_sync_process_group(self, src=0): return objects def run(self): + """ + Start the client slave's training process. + + Returns: + None + """ while not self.finished: self.train() diff --git a/python/fedml/fa/cross_silo/client/fedml_trainer_dist_adapter.py b/python/fedml/fa/cross_silo/client/fedml_trainer_dist_adapter.py index d758b11fec..00d8f9692f 100644 --- a/python/fedml/fa/cross_silo/client/fedml_trainer_dist_adapter.py +++ b/python/fedml/fa/cross_silo/client/fedml_trainer_dist_adapter.py @@ -4,6 +4,36 @@ class TrainerDistAdapter: + """ + Adapter for a Federated Learning Trainer with Distributed Training. + + Args: + args (object): An object containing trainer configuration parameters. + client_rank (int): The rank of the client. + train_data_num (int): The total number of training data samples. + train_data_local_num_dict (dict): A dictionary of client-specific training data sizes. + train_data_local_dict (dict): A dictionary of client-specific training data. + local_analyzer: An instance of the local analyzer (optional). + + Attributes: + client_index (int): The index of the client. + client_rank (int): The rank of the client. + local_analyzer: An instance of the local analyzer. + args (object): An object containing trainer configuration parameters. + + Methods: + local_analyze(round_idx): + Perform local analysis for a given training round. + set_server_data(server_data): + Set server data for the local analyzer. + set_init_msg(init_msg): + Set initialization message for the local analyzer. + set_client_submission(client_submission): + Set client submission for the local analyzer. + update_dataset(client_index=None): + Update the dataset for the local analyzer. + + """ def __init__( self, args, @@ -13,6 +43,23 @@ def __init__( train_data_local_dict, local_analyzer, ): + """ + Initialize the TrainerDistAdapter. + + Args: + args (object): An object containing trainer configuration parameters. + client_rank (int): The rank of the client. + train_data_num (int): The total number of training data samples. 
+ train_data_local_num_dict (dict): A dictionary of client-specific training data sizes. + train_data_local_dict (dict): A dictionary of client-specific training data. + local_analyzer: An instance of the local analyzer (optional). + + Note: + This constructor sets up the adapter and initializes it with the provided dataset and configuration. + + Returns: + None + """ if local_analyzer is None: local_analyzer = create_local_analyzer(args=args) @@ -42,6 +89,20 @@ def get_local_analyzer( args, local_analyzer, ): + """ + Get an instance of the local analyzer. + + Args: + client_index (int): The index of the client. + train_data_local_dict (dict): A dictionary of client-specific training data. + train_data_local_num_dict (dict): A dictionary of client-specific training data sizes. + train_data_num (int): The total number of training data samples. + args (object): An object containing trainer configuration parameters. + local_analyzer: An instance of the local analyzer (optional). + + Returns: + FALocalAnalyzer: An instance of the local analyzer. + """ return FALocalAnalyzer( client_index, train_data_local_dict, @@ -52,18 +113,63 @@ def get_local_analyzer( ) def local_analyze(self, round_idx): + """ + Perform local analysis for a given training round. + + Args: + round_idx (int): The index of the training round. + + Returns: + tuple: A tuple containing client submission and local sample count. + """ client_submission, local_sample_num = self.local_analyzer.local_analyze(round_idx) return client_submission, local_sample_num def set_server_data(self, server_data): + """ + Set server data for the local analyzer. + + Args: + server_data: Data received from the server. + + Returns: + None + """ self.local_analyzer.set_server_data(server_data) def set_init_msg(self, init_msg): + """ + Set initialization message for the local analyzer. + + Args: + init_msg: Initialization message received from the server. + + Returns: + None + """ self.local_analyzer.set_init_msg(init_msg) def set_client_submission(self, client_submission): + """ + Set client submission for the local analyzer. + + Args: + client_submission: Client's training submission. + + Returns: + None + """ self.local_analyzer.set_client_submission(client_submission) def update_dataset(self, client_index=None): + """ + Update the dataset for the local analyzer. + + Args: + client_index (int): The index of the client (optional). + + Returns: + None + """ _client_index = client_index or self.client_index - self.local_analyzer.update_dataset(int(_client_index)) \ No newline at end of file + self.local_analyzer.update_dataset(int(_client_index)) diff --git a/python/fedml/fa/cross_silo/client/process_group_manager.py b/python/fedml/fa/cross_silo/client/process_group_manager.py index 92519c6cc4..ff5970b89f 100644 --- a/python/fedml/fa/cross_silo/client/process_group_manager.py +++ b/python/fedml/fa/cross_silo/client/process_group_manager.py @@ -1,12 +1,46 @@ import logging import os - import torch import torch.distributed as dist - class ProcessGroupManager: + """ + Manages the initialization and cleanup of process groups for distributed training. + + Args: + rank (int): The rank of the current process. + world_size (int): The total number of processes in the group. + master_address (str): The address of the master process. + master_port (int): The port for communication with the master process. + only_gpu (bool): Whether to use NCCL backend if GPUs are available, otherwise use GLOO. 
+ + Attributes: + messaging_pg (dist.ProcessGroup): The initialized process group for messaging. + + Methods: + cleanup(): + Cleanup and destroy the process group. + get_process_group(): + Get the initialized process group. + + """ def __init__(self, rank, world_size, master_address, master_port, only_gpu): + """ + Initialize the ProcessGroupManager. + + Args: + rank (int): The rank of the current process. + world_size (int): The total number of processes in the group. + master_address (str): The address of the master process. + master_port (int): The port for communication with the master process. + only_gpu (bool): Whether to use NCCL backend if GPUs are available, otherwise use GLOO. + + Note: + This constructor sets up the process group and environment variables. + + Returns: + None + """ logging.info("Start process group") logging.info( "rank: %d, world_size: %d, master_address: %s, master_port: %s" @@ -31,7 +65,19 @@ def __init__(self, rank, world_size, master_address, master_port, only_gpu): logging.info("Initiated") def cleanup(self): + """ + Cleanup and destroy the process group. + + Returns: + None + """ dist.destroy_process_group() def get_process_group(self): + """ + Get the initialized process group. + + Returns: + dist.ProcessGroup: The initialized process group. + """ return self.messaging_pg diff --git a/python/fedml/fa/cross_silo/fa_client.py b/python/fedml/fa/cross_silo/fa_client.py index 2971c08193..2f43acb0c3 100644 --- a/python/fedml/fa/cross_silo/fa_client.py +++ b/python/fedml/fa/cross_silo/fa_client.py @@ -3,7 +3,39 @@ class FACrossSiloClient: + """ + Federated Learning Client for Cross-Silo Federated Learning. + + Args: + args (object): An object containing client configuration parameters. + dataset (tuple): A tuple containing dataset information, including size and partitions. + client_analyzer (FAClientAnalyzer): An instance of the client analyzer (optional). + + Attributes: + args (object): An object containing client configuration parameters. + dataset (tuple): A tuple containing dataset information, including size and partitions. + client_analyzer (FAClientAnalyzer): An instance of the client analyzer. + + Methods: + run(): + Start the Cross-Silo Federated Learning client. + + """ def __init__(self, args, dataset, client_analyzer: FAClientAnalyzer = None): + """ + Initialize the Cross-Silo Federated Learning client. + + Args: + args (object): An object containing client configuration parameters. + dataset (tuple): A tuple containing dataset information, including size and partitions. + client_analyzer (FAClientAnalyzer): An instance of the client analyzer (optional). + + Note: + This constructor sets up the client and initializes it with the provided dataset and configuration. + + Returns: + None + """ [ train_data_num, train_data_local_num_dict, @@ -21,4 +53,10 @@ def __init__(self, args, dataset, client_analyzer: FAClientAnalyzer = None): ) def run(self): + """ + Start the Cross-Silo Federated Learning client. + + Returns: + None + """ pass diff --git a/python/fedml/fa/cross_silo/fa_server.py b/python/fedml/fa/cross_silo/fa_server.py index a3242dde26..ab1e8f119b 100644 --- a/python/fedml/fa/cross_silo/fa_server.py +++ b/python/fedml/fa/cross_silo/fa_server.py @@ -3,7 +3,39 @@ class FACrossSiloServer: + """ + Federated Learning Server for Cross-Silo Federated Learning. + + Args: + args (object): An object containing server configuration parameters. + dataset (tuple): A tuple containing dataset information, including size and partitions. 
+ server_aggregator (FAServerAggregator): An instance of the server aggregator (optional). + + Attributes: + args (object): An object containing server configuration parameters. + dataset (tuple): A tuple containing dataset information, including size and partitions. + server_aggregator (FAServerAggregator): An instance of the server aggregator. + + Methods: + run(): + Start the Cross-Silo Federated Learning server. + + """ def __init__(self, args, dataset, server_aggregator: FAServerAggregator = None): + """ + Initialize the Cross-Silo Federated Learning server. + + Args: + args (object): An object containing server configuration parameters. + dataset (tuple): A tuple containing dataset information, including size and partitions. + server_aggregator (FAServerAggregator): An instance of the server aggregator (optional). + + Note: + This constructor sets up the server and initializes it with the provided dataset and configuration. + + Returns: + None + """ [ train_data_num, train_data_local_num_dict, @@ -21,4 +53,10 @@ def __init__(self, args, dataset, server_aggregator: FAServerAggregator = None): ) def run(self): + """ + Start the Cross-Silo Federated Learning server. + + Returns: + None + """ pass diff --git a/python/fedml/fa/cross_silo/server/fedml_aggregator.py b/python/fedml/fa/cross_silo/server/fedml_aggregator.py index 63d7331214..169f40608e 100644 --- a/python/fedml/fa/cross_silo/server/fedml_aggregator.py +++ b/python/fedml/fa/cross_silo/server/fedml_aggregator.py @@ -5,6 +5,41 @@ class FAAggregator(object): + """ + The FAAggregator class handles the aggregation of local models and sample numbers from clients. + + Args: + all_train_data_num (int): The total number of training data samples. + train_data_local_dict (dict): A dictionary containing the local training data for each client. + train_data_local_num_dict (dict): A dictionary containing the number of local training data samples for each client. + client_num (int): The number of clients. + args: Additional arguments. + server_aggregator: The server aggregator responsible for aggregation. + + Attributes: + aggregator: The server aggregator responsible for aggregation. + args: Additional arguments. + all_train_data_num (int): The total number of training data samples. + train_data_local_dict (dict): A dictionary containing the local training data for each client. + train_data_local_num_dict (dict): A dictionary containing the number of local training data samples for each client. + client_num (int): The number of clients. + model_dict (dict): A dictionary containing the model parameters from each client. + sample_num_dict (dict): A dictionary containing the number of samples from each client. + flag_client_model_uploaded_dict (dict): A dictionary tracking whether each client has uploaded its model. + + Methods: + get_init_msg(): Get the initialization message from the server aggregator. + set_init_msg(init_msg): Set the initialization message in the server aggregator. + get_server_data(): Get the server data from the server aggregator. + set_server_data(server_data): Set the server data in the server aggregator. + add_local_trained_result(index, model_params, sample_num): Add local model parameters and sample numbers from a client. + check_whether_all_receive(): Check if all clients have uploaded their models. + aggregate(): Aggregate local models and calculate the global result. + data_silo_selection(round_idx, client_num_in_total, client_num_per_round): Select data silos for clients in a round. 
+ client_selection(round_idx, client_id_list_in_total, client_num_per_round): Select clients for a round. + client_sampling(round_idx, client_num_in_total, client_num_per_round): Sample clients for a round. + """ + def __init__( self, all_train_data_num, @@ -27,24 +62,71 @@ def __init__( self.flag_client_model_uploaded_dict[idx] = False def get_init_msg(self): + """ + Get the initialization message from the server aggregator. + + Returns: + Any: The initialization message. + """ return self.aggregator.get_init_msg() def set_init_msg(self, init_msg): + """ + Set the initialization message in the server aggregator. + + Args: + init_msg: The initialization message to set. + + Returns: + None + """ self.aggregator.set_init_msg(init_msg) def get_server_data(self): + """ + Get the server data from the server aggregator. + + Returns: + Any: The server data. + """ return self.aggregator.get_server_data() def set_server_data(self, server_data): + """ + Set the server data in the server aggregator. + + Args: + server_data: The server data to set. + + Returns: + None + """ self.aggregator.set_server_data(server_data) def add_local_trained_result(self, index, model_params, sample_num): + """ + Add local model parameters and sample numbers from a client. + + Args: + index (int): The index of the client. + model_params: The local model parameters. + sample_num (int): The number of samples used for training. + + Returns: + None + """ logging.info("add_model. index = %d" % index) self.model_dict[index] = model_params self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all clients have uploaded their models. + + Returns: + bool: True if all clients have uploaded their models, False otherwise. + """ logging.debug("client_num = {}".format(self.client_num)) for idx in range(self.client_num): if not self.flag_client_model_uploaded_dict[idx]: @@ -54,6 +136,12 @@ def check_whether_all_receive(self): return True def aggregate(self): + """ + Aggregate local models and calculate the global result. + + Returns: + tuple: A tuple containing the global result and a list of local results. + """ start_time = time.time() local_result_list = [] @@ -70,16 +158,15 @@ def aggregate(self): def data_silo_selection(self, round_idx, client_num_in_total, client_num_per_round): """ + Select data silos for clients in a round. + Args: - round_idx: round index, starting from 0 - client_num_in_total: this is equal to the users in a synthetic data, - e.g., in synthetic_1_1, this value is 30 - client_num_per_round: the number of edge devices that can train + round_idx (int): The round index, starting from 0. + client_num_in_total (int): The total number of clients. + client_num_per_round (int): The number of clients that can train in a round. Returns: - data_silo_index_list: e.g., when client_num_in_total = 30, client_num_in_total = 3, - this value is the form of [0, 11, 20] - + list: A list of data silo indexes. 
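Taken together, the bookkeeping methods above amount to the following server-side flow; the handler name is hypothetical, and aggregate() is taken, per its docstring, to yield the global result together with the list of local results. A related detail worth noting: the selection helpers below seed NumPy with round_idx, so the same clients and silos are drawn for a given round across repeated runs:

def on_client_upload(fa_aggregator, client_index, submission, sample_num):
    # Buffer this client's result; fire aggregation once everyone has reported.
    fa_aggregator.add_local_trained_result(client_index, submission, sample_num)
    if fa_aggregator.check_whether_all_receive():
        return fa_aggregator.aggregate()
    return None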
""" logging.info( "client_num_in_total = %d, client_num_per_round = %d" % (client_num_in_total, client_num_per_round) @@ -89,37 +176,46 @@ def data_silo_selection(self, round_idx, client_num_in_total, client_num_per_rou if client_num_in_total == client_num_per_round: return [i for i in range(client_num_per_round)] else: - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) # Make sure for each comparison, we are selecting the same clients each round data_silo_index_list = np.random.choice(range(client_num_in_total), client_num_per_round, replace=False) return data_silo_index_list def client_selection(self, round_idx, client_id_list_in_total, client_num_per_round): """ + Select clients for a round. + Args: - round_idx: round index, starting from 0 - client_id_list_in_total: this is the real edge IDs. - In MLOps, its element is real edge ID, e.g., [64, 65, 66, 67]; - in simulated mode, its element is client index starting from 1, e.g., [1, 2, 3, 4] - client_num_per_round: + round_idx (int): The round index, starting from 0. + client_id_list_in_total (list): A list of real edge IDs or client indices. + client_num_per_round (int): The number of clients to select. Returns: - client_id_list_in_this_round: sampled real edge ID list, e.g., [64, 66] + list: A list of selected client IDs. """ if client_num_per_round == len(client_id_list_in_total): return client_id_list_in_total - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) # Make sure for each comparison, we are selecting the same clients each round client_id_list_in_this_round = np.random.choice(client_id_list_in_total, client_num_per_round, replace=False) return client_id_list_in_this_round def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Sample clients for a round. + + Args: + round_idx (int): The round index, starting from 0. + client_num_in_total (int): The total number of clients. + client_num_per_round (int): The number of clients to sample. + + Returns: + list: A list of sampled client indices. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: num_clients = min(client_num_per_round, client_num_in_total) - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) # Make sure for each comparison, we are selecting the same clients each round client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False) logging.info("client_indexes = %s" % str(client_indexes)) return client_indexes - - - + \ No newline at end of file diff --git a/python/fedml/fa/cross_silo/server/fedml_server_manager.py b/python/fedml/fa/cross_silo/server/fedml_server_manager.py index e0182bdab9..89baf3ed35 100644 --- a/python/fedml/fa/cross_silo/server/fedml_server_manager.py +++ b/python/fedml/fa/cross_silo/server/fedml_server_manager.py @@ -9,9 +9,84 @@ class FedMLServerManager(FedMLCommManager): + """ + Federated Learning Server Manager for Cross-Silo Federated Learning. + + Args: + args (object): An object containing server configuration parameters. + aggregator (FAAggregator): An instance of the server aggregator. + comm: The communication object. + client_rank (int): The rank of the client. + client_num (int): The total number of clients. 
+ backend (str): The backend for communication (e.g., "MQTT_S3"). + + Attributes: + args (object): An object containing server configuration parameters. + aggregator (FAAggregator): An instance of the server aggregator. + round_num (int): The number of communication rounds. + client_online_mapping (dict): A dictionary mapping client IDs to their online status. + client_real_ids (list): A list of real client IDs. + is_initialized (bool): A flag indicating whether the server is initialized. + client_id_list_in_this_round (list): A list of client IDs for the current round. + data_silo_index_list (list): A list of data silo indices for clients in the current round. + + Methods: + run(): + Start the Federated Learning server. + + send_init_msg(): + Send initialization messages to clients. + + register_message_receive_handlers(): + Register handlers for receiving messages. + + handle_message_connection_ready(msg_params): + Handle the connection ready message from clients. + + handle_message_client_status_update(msg_params): + Handle client status updates. + + handle_message_receive_model_from_client(msg_params): + Handle received models from clients. + + cleanup(): + Perform cleanup operations after completing a round of communication. + + send_message_init_config(receive_id, global_model_params, datasilo_index, + global_model_url=None, global_model_key=None): + Send initialization configuration messages to clients. + + send_message_check_client_status(receive_id, datasilo_index): + Send client status check messages to clients. + + send_message_finish(receive_id, datasilo_index): + Send finish messages to clients. + + send_message_sync_model_to_client(receive_id, global_model_params, client_index, + global_model_url=None, global_model_key=None): + Send synchronized model messages to clients. + + """ def __init__( self, args, aggregator, comm=None, client_rank=0, client_num=0, backend="MQTT_S3", ): + """ + Initialize the Federated Learning Server Manager. + + Args: + args (object): An object containing server configuration parameters. + aggregator (FAAggregator): An instance of the server aggregator. + comm: The communication object. + client_rank (int): The rank of the client. + client_num (int): The total number of clients. + backend (str): The backend for communication (e.g., "MQTT_S3"). + + Note: + This constructor sets up the server manager with the provided configuration and aggregator. + + Returns: + None + """ super().__init__(args, comm, client_rank, client_num, backend) self.args = args self.aggregator = aggregator @@ -24,9 +99,21 @@ def __init__( self.data_silo_index_list = None def run(self): + """ + Start the Federated Learning server. + + Returns: + None + """ super().run() def send_init_msg(self): + """ + Send initialization messages to clients. + + Returns: + None + """ global_result = self.aggregator.get_server_data() global_result_url = None @@ -43,6 +130,12 @@ def send_init_msg(self): mlops.event("server.wait", event_started=True, event_value=str(self.args.round_idx)) def register_message_receive_handlers(self): + """ + Register handlers for receiving messages. + + Returns: + None + """ logging.info("register_message_receive_handlers------") self.register_message_receive_handler( MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_message_connection_ready @@ -57,6 +150,15 @@ def register_message_receive_handlers(self): ) def handle_message_connection_ready(self, msg_params): + """ + Handle the connection ready message from clients. 
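+        On the first ready signal, the server performs client selection for
+        round 0 before normal message handling begins.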
+ + Args: + msg_params (dict): Message parameters. + + Returns: + None + """ if not self.is_initialized: self.client_id_list_in_this_round = self.aggregator.client_selection( self.args.round_idx, self.client_real_ids, self.args.client_num_per_round @@ -78,6 +180,15 @@ def handle_message_connection_ready(self, msg_params): client_idx_this_round += 1 def handle_message_client_status_update(self, msg_params): + """ + Handle client status updates. + + Args: + msg_params (dict): Message parameters. + + Returns: + None + """ client_status = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_STATUS) if client_status == "ONLINE": self.client_online_mapping[str(msg_params.get_sender_id())] = True @@ -100,6 +211,15 @@ def handle_message_client_status_update(self, msg_params): self.is_initialized = True def handle_message_receive_model_from_client(self, msg_params): + """ + Handle received models from clients. + + Args: + msg_params (dict): Message parameters. + + Returns: + None + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) mlops.event("comm_c2s", event_started=False, event_value=str(self.args.round_idx), event_edge_id=sender_id) local_results = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) @@ -157,6 +277,12 @@ def handle_message_receive_model_from_client(self, msg_params): self.aggregator.set_init_msg(init_msg=None) def cleanup(self): + """ + Perform cleanup operations after completing a round of communication. + + Returns: + None + """ client_idx_in_this_round = 0 for client_id in self.client_id_list_in_this_round: self.send_message_finish( @@ -169,6 +295,19 @@ def cleanup(self): def send_message_init_config(self, receive_id, global_model_params, datasilo_index, global_model_url=None, global_model_key=None): + """ + Send initialization configuration messages to clients. + + Args: + receive_id: The ID of the receiving client. + global_model_params: The global model parameters. + datasilo_index: The data silo index. + global_model_url (str, optional): The URL of global model parameters. Defaults to None. + global_model_key (str, optional): The key of global model parameters. Defaults to None. + + Returns: + tuple: A tuple containing the updated global_model_url and global_model_key. + """ tick = time.time() message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id) if global_model_url is not None: @@ -187,11 +326,31 @@ def send_message_init_config(self, receive_id, global_model_params, datasilo_ind return global_model_url, global_model_key def send_message_check_client_status(self, receive_id, datasilo_index): + """ + Send client status check messages to clients. + + Args: + receive_id: The ID of the receiving client. + datasilo_index: The data silo index. + + Returns: + None + """ message = Message(MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, self.get_sender_id(), receive_id) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) def send_message_finish(self, receive_id, datasilo_index): + """ + Send finish messages to clients. + + Args: + receive_id: The ID of the receiving client. + datasilo_index: The data silo index. 
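+
+        Note (editorial): this sends MSG_TYPE_S2C_FINISH so the receiving
+            client can shut down; cleanup() invokes it once per client in the
+            current round.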
+ + Returns: + None + """ message = Message(MyMessage.MSG_TYPE_S2C_FINISH, self.get_sender_id(), receive_id) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) @@ -201,6 +360,19 @@ def send_message_finish(self, receive_id, datasilo_index): def send_message_sync_model_to_client(self, receive_id, global_model_params, client_index, global_model_url=None, global_model_key=None): + """ + Send synchronized model messages to clients. + + Args: + receive_id: The ID of the receiving client. + global_model_params: The global model parameters. + client_index: The client index. + global_model_url (str, optional): The URL of global model parameters. Defaults to None. + global_model_key (str, optional): The key of global model parameters. Defaults to None. + + Returns: + tuple: A tuple containing the updated global_model_url and global_model_key. + """ tick = time.time() logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id) message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id, ) diff --git a/python/fedml/fa/cross_silo/server/server_initializer.py b/python/fedml/fa/cross_silo/server/server_initializer.py index cff55ecc26..c296bf1744 100644 --- a/python/fedml/fa/cross_silo/server/server_initializer.py +++ b/python/fedml/fa/cross_silo/server/server_initializer.py @@ -13,6 +13,22 @@ def init_server( train_data_local_num_dict, server_aggregator, ): + """ + Initialize the Federated Learning server for Cross-Silo Federated Learning. + + Args: + args (object): An object containing server configuration parameters. + comm: The communication object. + rank (int): The rank of the server. + worker_num (int): The total number of workers. + train_data_num (int): The total number of training data samples. + train_data_local_dict (dict): A dictionary of client-specific training data. + train_data_local_num_dict (dict): A dictionary of client-specific training data sizes. + server_aggregator: An instance of the server aggregator (optional). + + Returns: + None + """ if server_aggregator is None: server_aggregator = create_global_analyzer(args, train_data_num=train_data_num) server_aggregator.set_id(0) diff --git a/python/fedml/fa/data/data_loader.py b/python/fedml/fa/data/data_loader.py index a29877cc98..2a6e820da0 100644 --- a/python/fedml/fa/data/data_loader.py +++ b/python/fedml/fa/data/data_loader.py @@ -11,12 +11,31 @@ def fa_load_data(args): + """ + Load synthetic data based on the specified dataset. + + Args: + args (argparse.Namespace): Command-line arguments. + + Returns: + list: A list containing dataset information. + """ return load_synthetic_data(args) def load_synthetic_data(args): + """ + Load synthetic data based on the specified dataset name. + + Args: + args (argparse.Namespace): Command-line arguments. + + Returns: + list: A list containing dataset information. 
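+            The returned list has the form
+            [datasize, train_data_local_num_dict, local_data_dict],
+            matching the branches below.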
+ """ dataset_name = args.dataset if dataset_name == "fake": + # Load fake numeric data data_cache_dir = os.path.join(args.data_cache_dir, "fake_numeric_data") if not os.path.exists(data_cache_dir): os.makedirs(data_cache_dir, exist_ok=True) @@ -33,7 +52,7 @@ def load_synthetic_data(args): train_data_local_num_dict, local_data_dict, ] - # print(f"datasize, train_data_local_num_dict, local_data_dict,{dataset}") + elif dataset_name == "twitter": path = os.path.join(args.data_cache_dir, "twitter_Sentiment140") download_twitter_Sentiment140(data_cache_dir=path) @@ -70,7 +89,7 @@ def load_synthetic_data(args): if hasattr(args, "seperator"): separator = args.seperator else: - separator = "," # default seperator = "," + separator = "," # default separator = "," ( datasize, train_data_local_num_dict, @@ -110,6 +129,7 @@ def load_synthetic_data_test(): load_synthetic_data(args=args) + if __name__ == '__main__': # read_data(train_data_dir="fake_data") # download_twitter_Sentiment140("data") diff --git a/python/fedml/fa/data/fake_numeric_data/data_loader.py b/python/fedml/fa/data/fake_numeric_data/data_loader.py index a2e461b097..f615948204 100644 --- a/python/fedml/fa/data/fake_numeric_data/data_loader.py +++ b/python/fedml/fa/data/fake_numeric_data/data_loader.py @@ -5,16 +5,36 @@ def generate_fake_data(data_cache_dir): - file_path = data_cache_dir + "/fake_numeric_data.txt" + """ + Generate fake numeric data and save it to a text file in the specified directory. + + Args: + data_cache_dir (str): The directory where the fake numeric data file should be saved. + + Note: + This function generates random integer data and writes it to a text file. + + Returns: + None + """ + file_path = os.path.join(data_cache_dir, "fake_numeric_data.txt") if not os.path.exists(file_path): - f = open(file_path, "a") - for i in range(10000): - f.write(f"{random.randint(1, 100)}\n") - f.close() + with open(file_path, "a") as f: + for i in range(10000): + f.write(f"{random.randint(1, 100)}\n") def load_partition_data_fake(data_dir, client_num): + """ + Load and partition fake data from a specified directory into client-specific partitions. + + Args: + data_dir (str): The directory path where the fake data is located. + client_num (int): The total number of clients to partition the data for. + + Returns: + tuple: A tuple containing the dataset size, a dictionary of client data sizes, and a dictionary of client data. + """ dataset = read_data(data_dir=data_dir) return equally_partition_a_dataset(client_num, dataset) - diff --git a/python/fedml/fa/data/self_defined_data/data_loader.py b/python/fedml/fa/data/self_defined_data/data_loader.py index 9a249ca0b3..333a53ad40 100644 --- a/python/fedml/fa/data/self_defined_data/data_loader.py +++ b/python/fedml/fa/data/self_defined_data/data_loader.py @@ -6,6 +6,18 @@ def generate_fake_data(data_cache_dir): + """ + Generate fake numeric data and save it to a text file in the specified directory. + + Args: + data_cache_dir (str): The directory where the fake numeric data file should be saved. + + Note: + This function generates random integer data and writes it to a text file. + + Returns: + None + """ file_path = data_cache_dir + "/fake_numeric_data.txt" if not os.path.exists(file_path): @@ -16,9 +28,23 @@ def generate_fake_data(data_cache_dir): def load_partition_self_defined_data(file_folder_path, client_num, data_col_idx, separator=","): + """ + Load and partition self-defined data from a text file into client-specific partitions. 
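+    Each file in the folder is read, the requested column is extracted with
+    read_data_with_column_idx, and the rows are split equally across clients
+    via equally_partition_a_dataset.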
+
+    Args:
+        file_folder_path (str): The path to the folder containing the data files.
+        client_num (int): The total number of clients to partition the data for.
+        data_col_idx (int): The column index of the data to be used.
+        separator (str): The separator used in the data files (default is comma ',').
+
+    Raises:
+        Exception: If the specified data folder does not exist.
+
+    Returns:
+        tuple: A tuple containing the dataset size, a dictionary of client data sizes, and a dictionary of client data.
+    """
     if not os.path.exists(file_folder_path):
         raise Exception(f"No data file: {file_folder_path}")
     logging.info(f"file_folder_path = {file_folder_path}")
-    dataset = read_data_with_column_idx(file_folder_path=file_folder_path, column_idx=data_col_idx, seperator=separator)
+    dataset = read_data_with_column_idx(file_folder_path=file_folder_path, column_idx=data_col_idx, separator=separator)
     return equally_partition_a_dataset(client_num, dataset)
-
diff --git a/python/fedml/fa/data/twitter_Sentiment140/data_loader.py b/python/fedml/fa/data/twitter_Sentiment140/data_loader.py
index c99d9b8d77..007c79b524 100644
--- a/python/fedml/fa/data/twitter_Sentiment140/data_loader.py
+++ b/python/fedml/fa/data/twitter_Sentiment140/data_loader.py
@@ -9,6 +9,18 @@
 
 
 def download_twitter_Sentiment140(data_cache_dir):
+    """
+    Download the Sentiment140 Twitter dataset if it doesn't exist in the specified directory.
+
+    Args:
+        data_cache_dir (str): The directory where the dataset should be downloaded.
+
+    Note:
+        This function downloads the dataset from a URL and extracts it to the specified directory.
+
+    Returns:
+        None
+    """
     if not os.path.exists(data_cache_dir):
         os.makedirs(data_cache_dir, exist_ok=True)
         file_path = os.path.join(data_cache_dir, "trainingandtestdata.zip")
@@ -21,10 +33,30 @@
 
 
 def load_partition_data_twitter_sentiment140(dataset, client_num_in_total):
+    """
+    Load and partition the Sentiment140 Twitter dataset into client-specific partitions.
+
+    Args:
+        dataset (dict): A dictionary containing client usernames as keys and their data as values.
+        client_num_in_total (int): The total number of clients to partition the data for.
+
+    Returns:
+        tuple: A tuple containing the dataset size, a dictionary of client data sizes, and a dictionary of client data.
+    """
     return equally_partition_a_dataset_according_to_users(client_num_in_total, dataset)
 
 
 def load_partition_data_twitter_sentiment140_heavy_hitter(dataset, client_num_in_total):
+    """
+    Load and partition the Sentiment140 Twitter dataset for heavy hitters into client-specific partitions.
+
+    Args:
+        dataset (dict): A dictionary containing client usernames as keys and their data as values.
+        client_num_in_total (int): The total number of clients to partition the data for.
+
+    Returns:
+        tuple: A tuple containing the dataset size, a dictionary of client data sizes, and a dictionary of client data.
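+
+    Note (editorial): the values of ``dataset`` here are the per-user heavy
+        hitter words produced by preprocess_twitter_data_heavy_hitter, which
+        are then distributed across the clients.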
+ """ local_data_dict = dict() train_data_local_num_dict = dict() heavy_hitters = list(dataset.values()) @@ -41,4 +73,4 @@ def load_partition_data_twitter_sentiment140_heavy_hitter(dataset, client_num_in datasize, train_data_local_num_dict, local_data_dict, - ) \ No newline at end of file + ) diff --git a/python/fedml/fa/data/twitter_Sentiment140/twitter_data_processing.py b/python/fedml/fa/data/twitter_Sentiment140/twitter_data_processing.py index 8c048b555d..4fc1bb6566 100644 --- a/python/fedml/fa/data/twitter_Sentiment140/twitter_data_processing.py +++ b/python/fedml/fa/data/twitter_Sentiment140/twitter_data_processing.py @@ -10,9 +10,16 @@ def is_valid(word): - if len(word) < 3 or (word[-1] in [ - '?', '!', '.', ';', ',' - ]) or word.startswith('http') or word.startswith('www'): + """ + Check if a word is valid for processing. + + Args: + word (str): The word to check. + + Returns: + bool: True if the word is valid, False otherwise. + """ + if len(word) < 3 or (word[-1] in ['?', '!', '.', ';', ',']) or word.startswith('http') or word.startswith('www'): return False if re.match(r'^[a-z_\@\#\-\;\(\)\*\:\.\'\/]+$', word): return True @@ -20,6 +27,16 @@ def is_valid(word): def truncate_or_extend(word, max_word_len): + """ + Truncate or extend a word to a specified length. + + Args: + word (str): The word to modify. + max_word_len (int): The desired maximum length of the word. + + Returns: + str: The modified word. + """ if len(word) > max_word_len: word = word[:max_word_len] else: @@ -28,10 +45,26 @@ def truncate_or_extend(word, max_word_len): def add_end_symbol(word): + """ + Add an end symbol ('$') to a word. + + Args: + word (str): The word to modify. + + Returns: + str: The modified word with an end symbol. + """ return word + '$' def generate_triehh_clients(clients, path): + """ + Generate TrieHH clients from a list of clients and save them to a file. + + Args: + clients (list): List of client names. + path (str): The directory path to save the file. + """ clients_num = len(clients) triehh_clients = [add_end_symbol(clients[i]) for i in range(clients_num)] word_freq = collections.defaultdict(lambda: 0) @@ -44,6 +77,15 @@ def generate_triehh_clients(clients, path): def preprocess_twitter_data(path): + """ + Preprocess Twitter data from a CSV file. + + Args: + path (str): The directory path where the CSV file is located. + + Returns: + dict: A dictionary containing client usernames as keys and lists of preprocessed words as values. + """ filename = os.path.join(path, 'training.1600000.processed.noemoticon.csv') dataset = {} with open(filename, encoding='ISO-8859-1') as csv_file: @@ -66,6 +108,15 @@ def preprocess_twitter_data(path): def preprocess_twitter_data_heavy_hitter(path): + """ + Preprocess Twitter data and identify heavy hitters (most frequent words) for each client. + + Args: + path (str): The directory path where the CSV file is located. + + Returns: + dict: A dictionary containing client usernames as keys and their identified heavy hitter words as values. + """ # load dataset from csv file filename = os.path.join(path, 'training.1600000.processed.noemoticon.csv') clients = {} diff --git a/python/fedml/fa/data/utils.py b/python/fedml/fa/data/utils.py index ada768dca4..a41b2b6867 100644 --- a/python/fedml/fa/data/utils.py +++ b/python/fedml/fa/data/utils.py @@ -3,6 +3,17 @@ def equally_partition_a_dataset(client_num_in_total, dataset): + """ + Equally partition a dataset among clients. + + Args: + client_num_in_total (int): The total number of clients. 
+ dataset (list): The dataset to partition. + + Returns: + tuple: A tuple containing the total dataset size, a dictionary of local data counts per client, + and a dictionary of local data for each client. + """ client_data_num = int(len(dataset) / client_num_in_total) local_data_dict = dict() train_data_local_num_dict = dict() @@ -20,6 +31,17 @@ def equally_partition_a_dataset(client_num_in_total, dataset): def equally_partition_a_dataset_according_to_users(client_num_in_total, dataset): + """ + Equally partition a dataset among clients based on the number of users. + + Args: + client_num_in_total (int): The total number of clients. + dataset (dict): The dataset organized by user IDs. + + Returns: + tuple: A tuple containing the total dataset size, a dictionary of local data counts per client, + and a dictionary of local data for each client. + """ user_num_for_one_client = int(math.ceil(len(dataset) / client_num_in_total)) local_data_dict = dict() train_data_local_num_dict = dict() @@ -45,6 +67,15 @@ def equally_partition_a_dataset_according_to_users(client_num_in_total, dataset) def read_data(data_dir): + """ + Read data from text files in a directory. + + Args: + data_dir (str): The path to the directory containing text data files. + + Returns: + list: A list of integers representing the dataset. + """ train_files = os.listdir(data_dir) train_files = [f for f in train_files if f.endswith(".txt")] dataset = [] @@ -56,7 +87,18 @@ def read_data(data_dir): return dataset -def read_data_with_column_idx(file_folder_path, column_idx, seperator=","): +def read_data_with_column_idx(file_folder_path, column_idx, separator=","): + """ + Read data from text files in a directory, selecting a specific column. + + Args: + file_folder_path (str): The path to the directory containing text data files. + column_idx (int): The index of the column to extract. + separator (str, optional): The separator used in the text files (default is comma). + + Returns: + list: A list of values from the selected column. + """ train_files = os.listdir(file_folder_path) train_files = [f for f in train_files if not f.startswith(".")] dataset = [] @@ -64,6 +106,6 @@ def read_data_with_column_idx(file_folder_path, column_idx, seperator=","): file_path = os.path.join(file_folder_path, f) f2 = open(file_path, "r") for line in f2: - if len(line.split(seperator)[column_idx].strip()) > 0: - dataset.append(line.split(seperator)[column_idx].strip()) - return dataset \ No newline at end of file + if len(line.split(separator)[column_idx].strip()) > 0: + dataset.append(line.split(separator)[column_idx].strip()) + return dataset diff --git a/python/fedml/fa/local_analyzer/avg.py b/python/fedml/fa/local_analyzer/avg.py index 0e4761c66a..0866988d2e 100644 --- a/python/fedml/fa/local_analyzer/avg.py +++ b/python/fedml/fa/local_analyzer/avg.py @@ -1,10 +1,31 @@ from fedml.fa.base_frame.client_analyzer import FAClientAnalyzer - class AverageClientAnalyzer(FAClientAnalyzer): + """ + A client analyzer for calculating the average of values in the training data. + + Args: + None + + Methods: + local_analyze(train_data, args): + Analyze the training data to calculate the average of values and set the client submission. + + """ + def local_analyze(self, train_data, args): + """ + Analyze the training data to calculate the average of values and set the client submission. + + Args: + train_data (list): The training data containing values to analyze. + args: Additional arguments (not used in this method). 
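+
+        Example (illustrative sketch, not part of the original tests;
+            ``args`` is a placeholder configuration namespace and the inputs
+            are assumed to be numeric strings):
+            >>> analyzer = AverageClientAnalyzer(args)
+            >>> analyzer.local_analyze(["3", "6", "9"], args)
+            >>> analyzer.get_client_submission()
+            6.0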
+
+        Returns:
+            None
+        """
         sample_num = len(train_data)
         average = 0.0
         for value in train_data:
-            average = average + float(value) / float(sample_num)
+            average += float(value) / float(sample_num)
         self.set_client_submission(average)
diff --git a/python/fedml/fa/local_analyzer/client_analyzer_creator.py b/python/fedml/fa/local_analyzer/client_analyzer_creator.py
index 64c5a154f9..b5694ef361 100644
--- a/python/fedml/fa/local_analyzer/client_analyzer_creator.py
+++ b/python/fedml/fa/local_analyzer/client_analyzer_creator.py
@@ -7,8 +7,16 @@ from fedml.fa.local_analyzer.k_percentage_element import KPercentileElementClientAnalyzer
 from fedml.fa.local_analyzer.union import UnionClientAnalyzer
 
-
 def create_local_analyzer(args):
+    """
+    Create a specific type of local analyzer based on the task type.
+
+    Args:
+        args (object): Arguments for the local analyzer creation.
+
+    Returns:
+        object: A local analyzer instance based on the specified task type.
+    """
     task_type = args.fa_task
     if task_type == FA_TASK_AVG:
         return AverageClientAnalyzer(args)
@@ -24,4 +32,3 @@
         return FrequencyEstimationClientAnalyzer(args)
     if task_type == FA_TASK_HEAVY_HITTER_TRIEHH:
         return TrieHHClientAnalyzer(args)
-
diff --git a/python/fedml/fa/local_analyzer/frequency_estimation.py b/python/fedml/fa/local_analyzer/frequency_estimation.py
index 4477f068dd..98ca1154ed 100644
--- a/python/fedml/fa/local_analyzer/frequency_estimation.py
+++ b/python/fedml/fa/local_analyzer/frequency_estimation.py
@@ -2,12 +2,40 @@
 
 class FrequencyEstimationClientAnalyzer(FAClientAnalyzer):
+    """
+    A client analyzer for estimating the frequency of values in the training data.
+
+    Args:
+        args: Configuration arguments forwarded to the FAClientAnalyzer base class.
+
+    Methods:
+        local_analyze(train_data, args):
+            Analyze the training data to estimate the frequency of values and set the client submission.
+
+    """
+
     def local_analyze(self, train_data, args):
+        """
+        Analyze the training data to estimate the frequency of values and set the client submission.
+
+        Args:
+            train_data (list): The training data containing values to analyze.
+            args: Additional arguments (not used in this method).
+
+        Returns:
+            None
+        """
         counter_dict = dict()
         for value in train_data:
             if counter_dict.get(value) is None:
                 counter_dict[value] = 1
             else:
-                counter_dict[value] = counter_dict[value] + 1
-        self.set_client_submission(counter_dict)
\ No newline at end of file
+                counter_dict[value] += 1
+
+        self.set_client_submission(counter_dict)
diff --git a/python/fedml/fa/local_analyzer/heavy_hitter_triehh.py b/python/fedml/fa/local_analyzer/heavy_hitter_triehh.py
index 3933d6e9bb..fac7197baa 100644
--- a/python/fedml/fa/local_analyzer/heavy_hitter_triehh.py
+++ b/python/fedml/fa/local_analyzer/heavy_hitter_triehh.py
@@ -3,8 +3,39 @@
 from collections import defaultdict
 from fedml.fa.base_frame.client_analyzer import FAClientAnalyzer
 
-
 class TrieHHClientAnalyzer(FAClientAnalyzer):
+    """
+    A client analyzer for Trie-HH federated learning.
+
+    Args:
+        args: Additional arguments for configuration.
+
+    Attributes:
+        round_counter (int): Counter to keep track of rounds.
+        batch_size (int): Size of the sample batch for analysis.
+        client_num_per_round (int): Number of clients per round.
+
+    Methods:
+        __init__(self, args):
+            Initialize the TrieHHClientAnalyzer with provided arguments.
+ + set_init_msg(self, init_msg): + Set the initial message containing batch size. + + get_init_msg(self): + Get the initial message. + + local_analyze(self, train_data, args): + Analyze the local training data and set the client submission. + + client_vote(self, sample_local_dataset): + Perform voting based on local data and return the votes. + + one_word_vote(self, word): + Perform voting for a single word in the dataset. + + """ + def __init__(self, args): super().__init__(args=args) self.round_counter = 0 @@ -12,19 +43,53 @@ def __init__(self, args): self.client_num_per_round = args.client_num_per_round def set_init_msg(self, init_msg): + """ + Set the initial message containing batch size. + + Args: + init_msg: The initial message containing batch size. + + Returns: + None + """ self.init_msg = init_msg self.batch_size = self.init_msg def get_init_msg(self): + """ + Get the initial message. + + Returns: + int: The initial message containing batch size. + """ return self.init_msg def local_analyze(self, train_data, args): + """ + Analyze the training data and set the client submission. + + Args: + train_data (list): The training data for analysis. + args: Additional arguments (not used in this method). + + Returns: + None + """ idxs = np.random.choice(range(len(train_data)), self.batch_size, replace=False) sample_local_dataset = [train_data[i] for i in idxs] votes = self.client_vote(sample_local_dataset) self.set_client_submission(votes) def client_vote(self, sample_local_dataset): + """ + Perform voting based on local data and return the votes. + + Args: + sample_local_dataset (list): Sampled local dataset for voting. + + Returns: + dict: Dictionary containing votes. + """ votes = defaultdict(int) self.round_counter += 1 self.w_global = self.get_server_data() @@ -35,13 +100,21 @@ def client_vote(self, sample_local_dataset): return votes def one_word_vote(self, word): + """ + Perform voting for a single word in the dataset. + + Args: + word (str): A word from the dataset. + + Returns: + int: Voting result (1 if valid, 0 otherwise). + """ if len(word) < self.round_counter: return 0 pre = word[0:self.round_counter - 1] - # print(f"self.w_global={self.w_global}") - # print(f"pre = {pre}, type={type(self.w_global)}") + if self.w_global is None: return 1 if pre and (pre not in self.w_global): return 0 - return 1 \ No newline at end of file + return 1 diff --git a/python/fedml/fa/local_analyzer/intersection.py b/python/fedml/fa/local_analyzer/intersection.py index 76f5fe1b92..9102f31bd4 100644 --- a/python/fedml/fa/local_analyzer/intersection.py +++ b/python/fedml/fa/local_analyzer/intersection.py @@ -2,5 +2,27 @@ class IntersectionClientAnalyzer(FAClientAnalyzer): + """ + A client analyzer for finding the intersection of values in the training data. + + Args: + None + + Methods: + local_analyze(train_data, args): + Analyze the training data to find the intersection of values and set the client submission. + + """ + def local_analyze(self, train_data, args): - self.set_client_submission(list(set(train_data))) \ No newline at end of file + """ + Analyze the training data to find the intersection of values and set the client submission. + + Args: + train_data (list): The training data containing values to analyze. + args: Additional arguments (not used in this method). 
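+
+        Example (illustrative sketch; ``analyzer`` and ``args`` are
+            placeholders, and the submission is the de-duplicated local
+            value list):
+            >>> analyzer = IntersectionClientAnalyzer(args)
+            >>> analyzer.local_analyze(["a", "b", "a"], args)
+            >>> sorted(analyzer.get_client_submission())
+            ['a', 'b']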
+ + Returns: + None + """ + self.set_client_submission(list(set(train_data))) diff --git a/python/fedml/fa/local_analyzer/k_percentage_element.py b/python/fedml/fa/local_analyzer/k_percentage_element.py index 1c075842a6..4ea7819580 100644 --- a/python/fedml/fa/local_analyzer/k_percentage_element.py +++ b/python/fedml/fa/local_analyzer/k_percentage_element.py @@ -1,10 +1,31 @@ from fedml.fa.base_frame.client_analyzer import FAClientAnalyzer - class KPercentileElementClientAnalyzer(FAClientAnalyzer): + """ + A client analyzer for counting values larger than a given percentile. + + Args: + None + + Methods: + local_analyze(train_data, args): + Analyze the training data to count values larger than a given percentile and set the client submission. + + """ + def local_analyze(self, train_data, args): + """ + Analyze the training data to count values larger than a given percentile and set the client submission. + + Args: + train_data (list): The training data containing values to analyze. + args: Additional arguments (not used in this method). + + Returns: + None + """ counter = 0 for data in train_data: if data >= self.server_data: # flag counter += 1 - self.set_client_submission(counter) # number of values that are larger than flag + self.set_client_submission(counter) # number of values that are larger than the flag diff --git a/python/fedml/fa/local_analyzer/union.py b/python/fedml/fa/local_analyzer/union.py index 4ce99a39ad..2b0ad16b63 100644 --- a/python/fedml/fa/local_analyzer/union.py +++ b/python/fedml/fa/local_analyzer/union.py @@ -1,6 +1,27 @@ from fedml.fa.base_frame.client_analyzer import FAClientAnalyzer - class UnionClientAnalyzer(FAClientAnalyzer): + """ + A client analyzer for finding the union of values in the training data. + + Args: + None + + Methods: + local_analyze(train_data, args): + Analyze the training data to find the union of values and set the client submission. + + """ + def local_analyze(self, train_data, args): + """ + Analyze the training data to find the union of values and set the client submission. + + Args: + train_data (list): The training data containing values to analyze. + args: Additional arguments (not used in this method). + + Returns: + None + """ self.set_client_submission(list(set(train_data))) diff --git a/python/fedml/fa/runner.py b/python/fedml/fa/runner.py index f69dfc38ee..c446b79077 100644 --- a/python/fedml/fa/runner.py +++ b/python/fedml/fa/runner.py @@ -1,8 +1,21 @@ from fedml import FEDML_SIMULATION_TYPE_SP, FEDML_TRAINING_PLATFORM_SIMULATION, FEDML_TRAINING_PLATFORM_CROSS_SILO from fedml.fa.simulation.sp.simulator import FASimulatorSingleProcess - class FARunner: + """ + A class for running Federated Learning simulations. + + Args: + args: The arguments for configuring the simulation. + dataset: The dataset used for the simulation. + client_trainer: The client trainer for training clients (optional). + server_aggregator: The server aggregator for aggregating client updates (optional). + + Methods: + run(): + Run the Federated Learning simulation. + + """ def __init__( self, args, @@ -10,7 +23,19 @@ def __init__( client_trainer=None, server_aggregator=None, ): + """ + Initialize the FARunner with the provided arguments and components. + Args: + args: The arguments for configuring the simulation. + dataset: The dataset used for the simulation. + client_trainer: The client trainer for training clients (optional). + server_aggregator: The server aggregator for aggregating client updates (optional). 
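+
+        Note (editorial): args.training_type selects between the
+            single-process simulation runner and the cross-silo runner.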
+ + Raises: + Exception: If an invalid training type is specified in the arguments. + + """ if args.training_type == FEDML_TRAINING_PLATFORM_SIMULATION: init_runner_func = self._init_simulation_runner elif args.training_type == FEDML_TRAINING_PLATFORM_CROSS_SILO: @@ -25,6 +50,22 @@ def __init__( def _init_simulation_runner( self, args, dataset, client_analyzer=None, server_analyzer=None ): + """ + Initialize a simulation runner based on the provided arguments. + + Args: + args: The arguments for configuring the simulation. + dataset: The dataset used for the simulation. + client_analyzer: The client analyzer for analyzing client behavior (optional). + server_analyzer: The server analyzer for analyzing server behavior (optional). + + Returns: + FASimulatorSingleProcess: A simulation runner for single-process simulation. + + Raises: + Exception: If an unsupported simulation backend is specified in the arguments. + + """ if hasattr(args, "backend") and args.backend == FEDML_SIMULATION_TYPE_SP: runner = FASimulatorSingleProcess(args, dataset) else: @@ -33,6 +74,22 @@ def _init_simulation_runner( return runner def _init_cross_silo_runner(self, args, dataset, client_analyzer=None, server_analyzer=None): + """ + Initialize a cross-silo runner based on the provided arguments. + + Args: + args: The arguments for configuring the simulation. + dataset: The dataset used for the simulation. + client_analyzer: The client analyzer for analyzing client behavior (optional). + server_analyzer: The server analyzer for analyzing server behavior (optional). + + Returns: + FACrossSiloClient or FACrossSiloServer: A cross-silo client or server runner. + + Raises: + Exception: If an invalid role is specified in the arguments. + + """ if args.role == "client": from fedml.fa.cross_silo.fa_client import FACrossSiloClient as Client runner = Client(args, dataset, client_analyzer) @@ -45,4 +102,8 @@ def _init_cross_silo_runner(self, args, dataset, client_analyzer=None, server_an return runner def run(self): + """ + Run the Federated Learning simulation. + + """ self.runner.run() diff --git a/python/fedml/fa/simulation/sp/client.py b/python/fedml/fa/simulation/sp/client.py index 1902d165b3..7de16c8ed9 100644 --- a/python/fedml/fa/simulation/sp/client.py +++ b/python/fedml/fa/simulation/sp/client.py @@ -1,10 +1,48 @@ import numpy as np - class Client: + """ + Client class for Federated Analytics simulation. + + Args: + client_idx (int): Index of the client. + local_training_data (list): Local training data for the client. + local_datasize (int): Size of the local training data. + args (object): Arguments for the simulation. + local_analyzer (object): Local analyzer instance. + + Attributes: + client_idx (int): Index of the client. + local_training_data (list): Local training data for the client. + local_datasize (int): Size of the local training data. + local_sample_number (int): Number of local samples. + args (object): Arguments for the simulation. + local_analyzer (object): Local analyzer instance. + + Methods: + update_local_dataset(client_idx, local_training_data, local_sample_number): + Update the client's local dataset and sample number. + + get_sample_number(): + Get the number of local samples. + + local_analyze(w_global): + Perform local analysis and return client's submission. + """ + def __init__( self, client_idx, local_training_data, local_datasize, args, local_analyzer, ): + """ + Initialize the Client class. + + Args: + client_idx (int): Index of the client. 
+ local_training_data (list): Local training data for the client. + local_datasize (int): Size of the local training data. + args (object): Arguments for the simulation. + local_analyzer (object): Local analyzer instance. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_datasize = local_datasize @@ -13,18 +51,41 @@ def __init__( self.local_analyzer = local_analyzer def update_local_dataset(self, client_idx, local_training_data, local_sample_number): + """ + Update the client's local dataset and sample number. + + Args: + client_idx (int): Index of the client. + local_training_data (list): Updated local training data. + local_sample_number (int): Updated number of local samples. + """ self.client_idx = client_idx self.local_training_data = local_training_data self.local_sample_number = local_sample_number self.local_analyzer.set_id(client_idx) def get_sample_number(self): + """ + Get the number of local samples. + + Returns: + int: Number of local samples. + """ return self.local_sample_number def local_analyze(self, w_global): + """ + Perform local analysis and return client's submission. + + Args: + w_global (object): Global data from the server. + + Returns: + object: Client's submission after local analysis. + """ self.local_analyzer.set_server_data(w_global) idxs = np.random.choice(range(len(self.local_training_data)), self.local_sample_number, replace=False) train_data = [self.local_training_data[i] for i in idxs] - # print(f"train data = {train_data}") + self.local_analyzer.local_analyze(train_data, self.args) return self.local_analyzer.get_client_submission() diff --git a/python/fedml/fa/simulation/sp/simulator.py b/python/fedml/fa/simulation/sp/simulator.py index 265d257dc9..627e5d9f94 100644 --- a/python/fedml/fa/simulation/sp/simulator.py +++ b/python/fedml/fa/simulation/sp/simulator.py @@ -7,7 +7,38 @@ class FASimulatorSingleProcess: + """ + Simulator for Federated Analytics with a Single Process. + + Args: + args (object): Arguments for the simulation. + dataset (list): Dataset information including train data count, local datasize, and train data for each client. + + Attributes: + args (object): Arguments for the simulation. + train_data_num_in_total (int): Total number of training data points. + client_list (list): List of client instances. + local_datasize_dict (dict): Dictionary of local datasizes for each client. + train_data_local_dict (dict): Dictionary of local training data for each client. + local_analyzer (object): Local analyzer instance. + aggregator (object): Global aggregator instance. + + Methods: + analyze(): + Run the Federated Analytics simulation. + + run(): + Run the simulation. + """ + def __init__(self, args, dataset): + """ + Initialize the FASimulatorSingleProcess class. + + Args: + args (object): Arguments for the simulation. + dataset (list): Dataset information including train data count, local datasize, and train data for each client. + """ self.args = args [ train_data_num, @@ -30,6 +61,14 @@ def __init__(self, args, dataset): def _setup_clients( self, local_datasize_dict, train_data_local_dict, local_analyzer, ): + """ + Set up client instances for the simulation. + + Args: + local_datasize_dict (dict): Dictionary of local datasizes for each client. + train_data_local_dict (dict): Dictionary of local training data for each client. + local_analyzer (object): Local analyzer instance. 
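+
+        Note (editorial): one Client instance is created for each of the
+            args.client_num_per_round clients, as the loop below shows.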
+ """ logging.info("############setup_clients (START)#############") for client_idx in range(self.args.client_num_per_round): c = Client( @@ -43,6 +82,9 @@ def _setup_clients( logging.info("############setup_clients (END)#############") def analyze(self): + """ + Run the Federated Analytics simulation. + """ logging.info("self.local_analyzer = {}".format(self.local_analyzer)) local_sample_num = dict() for round_idx in range(self.args.comm_round): @@ -76,4 +118,7 @@ def analyze(self): print(f"round_idx={round_idx}, aggregation result = {result}") def run(self): + """ + Run the simulation. + """ self.analyze() diff --git a/python/fedml/fa/simulation/utils.py b/python/fedml/fa/simulation/utils.py index 3ae180397d..0bd40de181 100644 --- a/python/fedml/fa/simulation/utils.py +++ b/python/fedml/fa/simulation/utils.py @@ -2,11 +2,22 @@ def client_sampling(round_idx, client_num_in_total, client_num_per_round): + """ + Sample a subset of clients for federated learning. + + Args: + round_idx (int): The index of the current federated learning round. + client_num_in_total (int): The total number of available clients. + client_num_per_round (int): The number of clients to select for the current round. + + Returns: + list: A list of selected client indexes for the current round. + """ if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: num_clients = min(client_num_per_round, client_num_in_total) - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) # Make sure for each comparison, we select the same clients each round. client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False) print("client_indexes = %s" % str(client_indexes)) return client_indexes diff --git a/python/fedml/fa/utils/trie.py b/python/fedml/fa/utils/trie.py index f7b8166692..261f6df2cf 100644 --- a/python/fedml/fa/utils/trie.py +++ b/python/fedml/fa/utils/trie.py @@ -195,10 +195,70 @@ def _levenshtein(path, node, word, distance, cigar): class Trie(object): + """ + A Trie data structure for efficiently storing and searching words. + + Args: + words (list): List of words to initialize the Trie. + + Attributes: + root (dict): The root of the Trie. + + Methods: + __contains__(word): + Check if a word is present in the Trie. + + __iter__(): + Get an iterator for the words in the Trie. + + list(unique=True): + Get a list of words in the Trie. + + add(word, count=1): + Add a word to the Trie. + + get(word): + Get the count of a word in the Trie. + + remove(word, count=1): + Remove a word from the Trie. + + has_prefix(word): + Check if any word in the Trie has a given prefix. + + fill(alphabet, length): + Fill the Trie with words of a given length using characters from the alphabet. + + all_hamming_(word, distance): + Find all words in the Trie within a given Hamming distance. + + all_hamming(word, distance): + Find all words in the Trie within a given Hamming distance (returns words only). + + hamming(word, distance): + Find the first word in the Trie within a given Hamming distance. + + best_hamming(word, distance): + Find the best match for a word in the Trie within a given Hamming distance. + + all_levenshtein_(word, distance): + Find all words in the Trie within a given Levenshtein distance. + + all_levenshtein(word, distance): + Find all words in the Trie within a given Levenshtein distance (returns words only). 
+ + levenshtein(word, distance): + Find the first word in the Trie within a given Levenshtein distance. + + best_levenshtein(word, distance): + Find the best match for a word in the Trie within a given Levenshtein distance. + """ def __init__(self, words=None): - """Initialise the class. + """ + Initialize the Trie class. - :arg list words: List of words. + Args: + words (list): List of words to initialize the Trie. """ self.root = {} @@ -207,54 +267,153 @@ def __init__(self, words=None): self.add(word) def __contains__(self, word): + """ + Check if a word is present in the Trie. + + Args: + word (str): The word to check. + + Returns: + bool: True if the word is in the Trie, False otherwise. + """ return '' in _find(self.root, word) def __iter__(self): + """ + Get an iterator for the words in the Trie. + + Returns: + Iterator: An iterator object for iterating through words in the Trie. + """ return _iterate('', self.root, True) def list(self, unique=True): + """ + Get a list of words in the Trie. + + Args: + unique (bool): Whether to return unique words only (default is True). + + Returns: + list: A list of words in the Trie. + """ return _iterate('', self.root, unique) def add(self, word, count=1): + """ + Add a word to the Trie. + + Args: + word (str): The word to add. + count (int): The count to associate with the word (default is 1). + """ _add(self.root, word, count) def get(self, word): + """ + Get the count of a word in the Trie. + + Args: + word (str): The word to get the count for. + + Returns: + int: The count of the word in the Trie or None if not found. + """ node = _find(self.root, word) if '' in node: return node[''] return None def remove(self, word, count=1): + """ + Remove a word from the Trie. + + Args: + word (str): The word to remove. + count (int): The count to decrement (default is 1). + + Returns: + int: The remaining count of the word in the Trie or None if not found. + """ return _remove(self.root, word, count) def has_prefix(self, word): + """ + Check if any word in the Trie has a given prefix. + + Args: + word (str): The prefix to check. + + Returns: + bool: True if any word has the given prefix, False otherwise. + """ return _find(self.root, word) != {} def fill(self, alphabet, length): + """ + Fill the Trie with words of a given length using characters from the alphabet. + + Args: + alphabet (str): The characters to use for filling. + length (int): The length of words to generate and add to the Trie. + """ _fill(self.root, alphabet, length) def all_hamming_(self, word, distance): + """ + Find all words in the Trie within a given Hamming distance and return detailed results. + + Args: + word (str): Query word. + distance (int): Maximum allowed Hamming distance. + + Returns: + map: A map containing tuples with (word, remaining distance, count). + """ return map( lambda x: (x[0], distance - x[1], x[2]), _hamming('', self.root, word, distance, '')) def all_hamming(self, word, distance): + """ + Find all words in the Trie within a given Hamming distance and return words only. + + Args: + word (str): Query word. + distance (int): Maximum allowed Hamming distance. + + Returns: + map: A map containing words within the specified Hamming distance. + """ return map( lambda x: x[0], _hamming('', self.root, word, distance, '')) def hamming(self, word, distance): + """ + Find the first word in the Trie within a given Hamming distance. + + Args: + word (str): Query word. + distance (int): Maximum allowed Hamming distance. 
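+
+        Example (illustrative sketch, not from the original docs):
+            >>> t = Trie(['abc'])
+            >>> t.hamming('abd', 1)
+            'abc'
+            >>> t.hamming('xyz', 1) is None
+            True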
+ + Returns: + str: The first word within the specified Hamming distance or None if not found. + """ try: return next(self.all_hamming(word, distance)) except StopIteration: return None def best_hamming(self, word, distance): - """Find the best match with {word} in a trie. + """ + Find the best match with {word} in a trie using Hamming distance. - :arg str word: Query word. - :arg int distance: Maximum allowed distance. + Args: + word (str): Query word. + distance (int): Maximum allowed Hamming distance. - :returns str: Best match with {word}. + Returns: + str: Best match with {word}. """ if self.get(word): return word @@ -267,27 +426,60 @@ def best_hamming(self, word, distance): return None def all_levenshtein_(self, word, distance): + """ + Find all words in the Trie within a given Levenshtein distance and return detailed results. + + Args: + word (str): Query word. + distance (int): Maximum allowed Levenshtein distance. + + Returns: + map: A map containing tuples with (word, remaining distance, count). + """ return map( lambda x: (x[0], distance - x[1], x[2]), _levenshtein('', self.root, word, distance, '')) def all_levenshtein(self, word, distance): + """ + Find all words in the Trie within a given Levenshtein distance and return words only. + + Args: + word (str): Query word. + distance (int): Maximum allowed Levenshtein distance. + + Returns: + map: A map containing words within the specified Levenshtein distance. + """ return map( lambda x: x[0], _levenshtein('', self.root, word, distance, '')) def levenshtein(self, word, distance): + """ + Find the first word in the Trie within a given Levenshtein distance. + + Args: + word (str): Query word. + distance (int): Maximum allowed Levenshtein distance. + + Returns: + str: The first word within the specified Levenshtein distance or None if not found. + """ try: return next(self.all_levenshtein(word, distance)) except StopIteration: return None def best_levenshtein(self, word, distance): - """Find the best match with {word} in a trie. + """ + Find the best match with {word} in a trie using Levenshtein distance. - :arg str word: Query word. - :arg int distance: Maximum allowed distance. + Args: + word (str): Query word. + distance (int): Maximum allowed Levenshtein distance. - :returns str: Best match with {word}. + Returns: + str: Best match with {word}. """ if self.get(word): return word diff --git a/python/fedml/ml/aggregator/agg_operator.py b/python/fedml/ml/aggregator/agg_operator.py index ebcc939541..4f2a123e0e 100644 --- a/python/fedml/ml/aggregator/agg_operator.py +++ b/python/fedml/ml/aggregator/agg_operator.py @@ -8,6 +8,17 @@ class FedMLAggOperator: @staticmethod def agg(args, raw_grad_list: List[Tuple[float, OrderedDict]]) -> OrderedDict: + """ + Aggregate gradients from multiple clients using a federated learning aggregator. + + Args: + args: A dictionary containing training configuration parameters. + raw_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples containing + local sample counts and gradient updates from client models. + + Returns: + OrderedDict: The aggregated model parameters. + """ training_num = 0 if args.federated_optimizer == "SCAFFOLD": for i in range(len(raw_grad_list)): @@ -31,6 +42,20 @@ def agg(args, raw_grad_list: List[Tuple[float, OrderedDict]]) -> OrderedDict: def torch_aggregator(args, raw_grad_list, training_num): + """ + Aggregate gradients or parameters from multiple clients using a federated learning aggregator. + + Args: + args: A dictionary containing training configuration parameters. 
+        raw_grad_list (List[Union[Tuple[float, OrderedDict], Tuple[float, OrderedDict, OrderedDict]]]):
+            A list of tuples containing local sample counts and gradient updates from client models.
+            For some optimizers, it also includes an additional tuple element with local gradients.
+        training_num (int): The total number of training samples used for aggregation.
+
+    Returns:
+        Union[OrderedDict, Tuple[OrderedDict, OrderedDict]]: The aggregated model parameters or a tuple
+            containing aggregated model parameters and aggregated local gradients, depending on the optimizer.
+    """
     if args.federated_optimizer == "FedAvg":
 
         (num0, avg_params) = raw_grad_list[0]
@@ -135,6 +160,18 @@
 
 
 def tf_aggregator(args, raw_grad_list, training_num):
+    """
+    Aggregate gradients or parameters from multiple clients using a TensorFlow-based federated learning aggregator.
+
+    Args:
+        args: A dictionary containing training configuration parameters.
+        raw_grad_list (List[Tuple[float, List[float]]]): A list of tuples containing local sample counts and
+            gradient updates from client models.
+        training_num (int): The total number of training samples used for aggregation.
+
+    Returns:
+        List[float]: The aggregated model parameters.
+    """
     (num0, avg_params) = raw_grad_list[0]
 
     if args.federated_optimizer == "FedAvg":
@@ -161,6 +198,17 @@
 
 
 def jax_aggregator(args, raw_grad_list, training_num):
+    """
+    Aggregate gradients or parameters from multiple clients using a JAX-based federated learning aggregator.
+
+    Args:
+        args: A dictionary containing training configuration parameters.
+        raw_grad_list (List[Tuple[float, Dict[str, Dict[str, float]]]]): A list of tuples containing local sample counts
+            and gradient updates from client models. Each update is a dictionary containing 'w' and 'b' keys.
+        training_num (int): The total number of training samples used for aggregation.
+
+    Returns:
+        Dict[str, Dict[str, float]]: The aggregated model parameters containing 'w' and 'b' keys.
+    """
     (num0, avg_params) = raw_grad_list[0]
 
     if args.federated_optimizer == "FedAvg":
@@ -191,6 +239,17 @@
 
 
 def mxnet_aggregator(args, raw_grad_list, training_num):
+    """
+    Aggregate gradients or parameters from multiple clients using an MXNet-based federated learning aggregator.
+
+    Args:
+        args: A dictionary containing training configuration parameters.
+        raw_grad_list (List[Tuple[float, Dict[str, List[float]]]]): A list of tuples containing local sample counts
+            and gradient updates from client models. Each update is a dictionary containing lists of parameters.
+        training_num (int): The total number of training samples used for aggregation.
+
+    Returns:
+        Dict[str, List[float]]: The aggregated model parameters.
+    """
     (num0, avg_params) = raw_grad_list[0]
 
     if args.federated_optimizer == "FedAvg":
@@ -221,6 +280,20 @@
 
 
 def model_aggregator(args, raw_grad_list, training_num):
+    """
+    Aggregate gradients or parameters from multiple clients using a federated learning aggregator based on the
+    specified machine learning engine.
+
+    Args:
+        args: A dictionary containing training configuration parameters.
+        raw_grad_list (List[Union[Tuple[float, Dict[str, Dict[str, float]]], Tuple[float, Dict[str, List[float]]]]]):
+            A list of tuples containing local sample counts and gradient updates from client models. The format of
+            updates varies based on the machine learning engine.
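+        training_num (int): The total number of training samples used for
+            aggregation.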
+ + Returns: + Union[Dict[str, Dict[str, float]], Dict[str, List[float]]]: The aggregated model parameters or gradients + based on the selected machine learning engine. + """ if hasattr(args, MLEngineBackend.ml_engine_args_flag): if args.ml_engine == MLEngineBackend.ml_engine_backend_tf: return tf_aggregator(args, raw_grad_list, training_num) diff --git a/python/fedml/ml/aggregator/aggregator_creator.py b/python/fedml/ml/aggregator/aggregator_creator.py index 0ea2f506ee..e00475fe1e 100644 --- a/python/fedml/ml/aggregator/aggregator_creator.py +++ b/python/fedml/ml/aggregator/aggregator_creator.py @@ -4,6 +4,16 @@ def create_server_aggregator(model, args): + """ + Create a server aggregator instance based on the selected dataset and configuration parameters. + + Args: + model: The machine learning model to be used for aggregation. + args: A dictionary containing training configuration parameters, including the dataset. + + Returns: + ServerAggregator: An instance of a server aggregator class suitable for the specified dataset. + """ if args.dataset == "stackoverflow_lr": aggregator = MyServerAggregatorTAGPred(model, args) elif args.dataset in ["fed_shakespeare", "stackoverflow_nwp"]: diff --git a/python/fedml/ml/aggregator/default_aggregator.py b/python/fedml/ml/aggregator/default_aggregator.py index d81507d09a..a1a0f44162 100644 --- a/python/fedml/ml/aggregator/default_aggregator.py +++ b/python/fedml/ml/aggregator/default_aggregator.py @@ -11,18 +11,48 @@ class DefaultServerAggregator(ServerAggregator): def __init__(self, model, args): + """ + Initialize the DefaultServerAggregator. + + Args: + model: The machine learning model. + args: A dictionary containing configuration parameters. + """ super().__init__(model, args) self.cpu_transfer = False if not hasattr(self.args, "cpu_transfer") else self.args.cpu_transfer def get_model_params(self): + """ + Get the model parameters. + + Returns: + OrderedDict: The model parameters. + """ if self.cpu_transfer: return self.model.cpu().state_dict() return self.model.state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (OrderedDict): The model parameters to set. + """ self.model.load_state_dict(model_parameters) def _test(self, test_data, device, args): + """ + Internal method for testing the model on a given dataset. + + Args: + test_data: The test dataset. + device: The device to run the test on. + args: A dictionary containing configuration parameters. + + Returns: + dict: A dictionary containing test metrics. + """ model = self.model model.to(device) @@ -75,6 +105,17 @@ def _test(self, test_data, device, args): return metrics def test(self, test_data, device, args): + """ + Test the model on a given dataset and log the results. + + Args: + test_data: The test dataset. + device: The device to run the test on. + args: A dictionary containing configuration parameters. + + Returns: + tuple: A tuple containing test accuracy and loss. + """ # test data test_num_samples = [] test_tot_corrects = [] diff --git a/python/fedml/ml/aggregator/my_server_aggregator.py b/python/fedml/ml/aggregator/my_server_aggregator.py index 6f4125e6fa..4e7ca9b33d 100644 --- a/python/fedml/ml/aggregator/my_server_aggregator.py +++ b/python/fedml/ml/aggregator/my_server_aggregator.py @@ -11,18 +11,48 @@ class MyServerAggregator(ServerAggregator): def __init__(self, model, args): + """ + Initialize the MyServerAggregator. + + Args: + model: The model used for aggregation. 
+ args: A dictionary containing configuration parameters. + """ super().__init__(model, args) self.cpu_transfer = False if not hasattr(self.args, "cpu_transfer") else self.args.cpu_transfer def get_model_params(self): + """ + Get the model parameters. + + Returns: + OrderedDict: The model parameters. + """ if self.cpu_transfer: return self.model.cpu().state_dict() return self.model.state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (OrderedDict): The model parameters to set. + """ self.model.load_state_dict(model_parameters) def _test(self, test_data, device, args): + """ + Internal method for testing the model on a given dataset. + + Args: + test_data: The test dataset. + device: The device to run the test on. + args: A dictionary containing configuration parameters. + + Returns: + dict: A dictionary containing test metrics. + """ model = self.model model.to(device) @@ -75,6 +105,17 @@ def _test(self, test_data, device, args): return metrics def test(self, test_data, device, args): + """ + Test the model on a given dataset, log the results, and return test accuracy and loss. + + Args: + test_data: The test dataset. + device: The device to run the test on. + args: A dictionary containing configuration parameters. + + Returns: + tuple: A tuple containing test accuracy and loss. + """ # test data test_num_samples = [] test_tot_corrects = [] @@ -107,6 +148,18 @@ def test(self, test_data, device, args): return (test_acc, test_loss, None, None) def test_all(self, train_data_local_dict, test_data_local_dict, device, args) -> bool: + """ + Test the model on all client datasets, log the results, and return True. + + Args: + train_data_local_dict: A dictionary of training datasets for each client. + test_data_local_dict: A dictionary of test datasets for each client. + device: The device to run the test on. + args: A dictionary containing configuration parameters. + + Returns: + bool: Always returns True. + """ train_num_samples = [] train_tot_corrects = [] train_losses = [] diff --git a/python/fedml/ml/aggregator/my_server_aggregator_classification.py b/python/fedml/ml/aggregator/my_server_aggregator_classification.py index 7f93417641..e265beb01d 100644 --- a/python/fedml/ml/aggregator/my_server_aggregator_classification.py +++ b/python/fedml/ml/aggregator/my_server_aggregator_classification.py @@ -11,12 +11,35 @@ class MyServerAggregatorCLS(ServerAggregator): def get_model_params(self): + """ + Get the model parameters. + + Returns: + OrderedDict: The model parameters. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (OrderedDict): The model parameters to set. + """ self.model.load_state_dict(model_parameters) def _test(self, test_data, device, args): + """ + Internal method for testing the model on a given dataset. + + Args: + test_data: The test dataset. + device: The device to run the test on. + args: A dictionary containing configuration parameters. + + Returns: + dict: A dictionary containing test metrics. + """ model = self.model model.to(device) @@ -42,6 +65,17 @@ def _test(self, test_data, device, args): return metrics def test(self, test_data, device, args): + """ + Test the model on a given dataset, log the results, and return test accuracy and loss. + + Args: + test_data: The test dataset. + device: The device to run the test on. + args: A dictionary containing configuration parameters. 
+ + Returns: + tuple: A tuple containing test accuracy and loss. + """ # test data test_num_samples = [] test_tot_corrects = [] diff --git a/python/fedml/ml/aggregator/my_server_aggregator_nwp.py b/python/fedml/ml/aggregator/my_server_aggregator_nwp.py index 9306d42c8e..ae4ec10b7f 100644 --- a/python/fedml/ml/aggregator/my_server_aggregator_nwp.py +++ b/python/fedml/ml/aggregator/my_server_aggregator_nwp.py @@ -11,12 +11,35 @@ class MyServerAggregatorNWP(ServerAggregator): def get_model_params(self): + """ + Get the model parameters. + + Returns: + OrderedDict: The model parameters. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (OrderedDict): The model parameters to set. + """ self.model.load_state_dict(model_parameters) def _test(self, test_data, device, args): + """ + Internal method for testing the model on a given dataset. + + Args: + test_data: The test dataset. + device: The device to run the test on. + args: A dictionary containing configuration parameters. + + Returns: + dict: A dictionary containing test metrics. + """ model = self.model model.to(device) @@ -42,6 +65,17 @@ def _test(self, test_data, device, args): return metrics def test(self, test_data, device, args): + """ + Test the model on a given dataset, log the results, and return test accuracy and loss. + + Args: + test_data: The test dataset. + device: The device to run the test on. + args: A dictionary containing configuration parameters. + + Returns: + tuple: A tuple containing test accuracy and loss. + """ # test data test_num_samples = [] test_tot_corrects = [] diff --git a/python/fedml/ml/aggregator/my_server_aggregator_prediction.py b/python/fedml/ml/aggregator/my_server_aggregator_prediction.py index def0647e80..6d913bd864 100644 --- a/python/fedml/ml/aggregator/my_server_aggregator_prediction.py +++ b/python/fedml/ml/aggregator/my_server_aggregator_prediction.py @@ -11,12 +11,35 @@ class MyServerAggregatorTAGPred(ServerAggregator): def get_model_params(self): + """ + Get the model parameters. + + Returns: + OrderedDict: The model parameters. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model parameters. + + Args: + model_parameters (OrderedDict): The model parameters to set. + """ self.model.load_state_dict(model_parameters) def _test(self, test_data, device, args): + """ + Internal method for testing the model on a given dataset. + + Args: + test_data: The test dataset. + device: The device to run the test on. + args: A dictionary containing configuration parameters. + + Returns: + dict: A dictionary containing test metrics. + """ model = self.model model.to(device) @@ -59,6 +82,17 @@ def _test(self, test_data, device, args): return metrics def test(self, test_data, device, args): + """ + Test the model on a given dataset, log the results, and return test accuracy and loss. + + Args: + test_data: The test dataset. + device: The device to run the test on. + args: A dictionary containing configuration parameters. + + Returns: + tuple: A tuple containing test accuracy and loss. 
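+
+        Example:
+            An illustrative sketch only; ``model``, ``args``, ``test_loader``,
+            and ``device`` are placeholders assumed to be built elsewhere:
+
+                aggregator = MyServerAggregatorTAGPred(model, args)
+                results = aggregator.test(test_loader, device, args)
+                # results carries the test accuracy and loss described above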
+ """ # test data test_num_samples = [] test_tot_corrects = [] diff --git a/python/fedml/ml/trainer/feddyn_trainer.py b/python/fedml/ml/trainer/feddyn_trainer.py index 1f24f27e6a..2fe646c3e0 100644 --- a/python/fedml/ml/trainer/feddyn_trainer.py +++ b/python/fedml/ml/trainer/feddyn_trainer.py @@ -8,23 +8,92 @@ def model_parameter_vector(model): + """ + Flatten and concatenate the parameters of a PyTorch model. + + Args: + model (torch.nn.Module): The PyTorch model whose parameters need to be flattened. + + Returns: + torch.Tensor: A 1D tensor containing the concatenated flattened parameters. + """ param = [p.view(-1) for p in model.parameters()] return torch.concat(param, dim=0) def parameter_vector(parameters): + """ + Flatten and concatenate a dictionary of PyTorch parameters. + + Args: + parameters (dict): A dictionary of PyTorch parameters. + + Returns: + torch.Tensor: A 1D tensor containing the concatenated flattened parameters. + """ param = [p.view(-1) for p in parameters.values()] return torch.concat(param, dim=0) class FedDynModelTrainer(ClientTrainer): + """ + A class for training and testing federated dynamic models. + + Args: + model: The neural network model to train. + id: The client's unique identifier. + args: A dictionary containing training configuration parameters. + + Attributes: + model: The neural network model for training. + id: The unique identifier of the client. + args: A dictionary containing training configuration parameters. + + Methods: + get_model_params(): + Get the current state dictionary of the model. + + set_model_params(model_parameters): + Set the model's parameters using the provided state dictionary. + + train(train_data, device, args, old_grad): + Train the model on the given training data. + + test(test_data, device, args): + Test the model's performance on the provided test data. + + """ def get_model_params(self): + """ + Get the current state dictionary of the model. + + Returns: + dict: The state dictionary of the model. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model's parameters using the provided state dictionary. + + Args: + model_parameters (dict): The state dictionary containing model parameters. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args, old_grad): + """ + Train the model on the given training data. + + Args: + train_data (torch.utils.data.DataLoader): The DataLoader containing training data. + device (str): The device to perform training (e.g., 'cuda' or 'cpu'). + args (dict): A dictionary containing training configuration parameters. + old_grad (dict): Dictionary of old gradients for dynamic regularization. + + Returns: + dict: Updated old gradients after training. + """ model = self.model for params in model.parameters(): params.requires_grad = True @@ -117,6 +186,17 @@ def train(self, train_data, device, args, old_grad): def test(self, test_data, device, args): + """ + Test the model's performance on the provided test data. + + Args: + test_data (torch.utils.data.DataLoader): The DataLoader containing test data. + device (str): The device to perform testing (e.g., 'cuda' or 'cpu'). + args (dict): A dictionary containing testing configuration parameters. + + Returns: + dict: Metrics including test accuracy and test loss. 
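+
+        Example:
+            An illustrative sketch only; ``trainer`` is an already-constructed
+            FedDynModelTrainer and the loader and device are placeholders:
+
+                metrics = trainer.test(test_loader, device, args)
+                # metrics is the dict of test accuracy and loss described above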
+ """ model = self.model model.to(device) diff --git a/python/fedml/ml/trainer/fednova_trainer.py b/python/fedml/ml/trainer/fednova_trainer.py index c0fd4cc5c9..84122ac9c5 100644 --- a/python/fedml/ml/trainer/fednova_trainer.py +++ b/python/fedml/ml/trainer/fednova_trainer.py @@ -175,13 +175,71 @@ def step(self, closure=None): class FedNovaModelTrainer(ClientTrainer): + """ + A class for training and testing federated Nova (FedNova) models. + + Args: + model: The neural network model to train. + id: The client's unique identifier. + args: A dictionary containing training configuration parameters. + + Attributes: + model: The neural network model for training. + id: The unique identifier of the client. + args: A dictionary containing training configuration parameters. + + Methods: + get_model_params(): + Get the current state dictionary of the model. + + set_model_params(model_parameters): + Set the model's parameters using the provided state dictionary. + + get_local_norm_grad(opt, cur_params, init_params, weight=0): + Calculate the local normalized gradients. + + get_local_tau_eff(opt): + Calculate the effective tau for FedNova. + + train(train_data, device, args, **kwargs): + Train the model on the given training data using FedNova optimizer. + + test(test_data, device, args): + Test the model's performance on the provided test data. + + """ + def get_model_params(self): + """ + Get the current state dictionary of the model. + + Returns: + dict: The state dictionary of the model. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model's parameters using the provided state dictionary. + + Args: + model_parameters (dict): The state dictionary containing model parameters. + """ self.model.load_state_dict(model_parameters) def get_local_norm_grad(self, opt, cur_params, init_params, weight=0): + """ + Calculate the local normalized gradients. + + Args: + opt: The FedNova optimizer instance. + cur_params (dict): The current model's parameters. + init_params (dict): The initial model's parameters. + weight (float): The weight for gradient scaling (default is 0). + + Returns: + dict: Dictionary of local normalized gradients. + """ if weight == 0: weight = opt.ratio grad_dict = {} @@ -193,12 +251,33 @@ def get_local_norm_grad(self, opt, cur_params, init_params, weight=0): return grad_dict def get_local_tau_eff(self, opt): + """ + Calculate the effective tau for FedNova. + + Args: + opt: The FedNova optimizer instance. + + Returns: + float: The effective tau for FedNova. + """ if opt.mu != 0: return opt.local_steps * opt.ratio else: return opt.local_normalizing_vec * opt.ratio def train(self, train_data, device, args, **kwargs): + """ + Train the model on the given training data using the FedNova optimizer. + + Args: + train_data (torch.utils.data.DataLoader): The DataLoader containing training data. + device (str): The device to perform training (e.g., 'cuda' or 'cpu'). + args (dict): A dictionary containing training configuration parameters. + **kwargs: Additional keyword arguments. + + Returns: + Tuple[float, dict, float]: Tuple containing the average loss, local normalized gradients, and effective tau. + """ model = self.model model.to(device) @@ -248,6 +327,17 @@ def train(self, train_data, device, args, **kwargs): def test(self, test_data, device, args): + """ + Test the model's performance on the provided test data. + + Args: + test_data (torch.utils.data.DataLoader): The DataLoader containing test data. 
+ device (str): The device to perform testing (e.g., 'cuda' or 'cpu'). + args (dict): A dictionary containing testing configuration parameters. + + Returns: + dict: Metrics including test accuracy and test loss. + """ model = self.model model.to(device) diff --git a/python/fedml/ml/trainer/fedprox_trainer.py b/python/fedml/ml/trainer/fedprox_trainer.py index 06ebb4feab..e4741a1fa6 100644 --- a/python/fedml/ml/trainer/fedprox_trainer.py +++ b/python/fedml/ml/trainer/fedprox_trainer.py @@ -7,13 +7,63 @@ class FedProxModelTrainer(ClientTrainer): + """ + A class for training and testing federated Proximal (FedProx) models. + + Args: + model: The neural network model to train. + id: The client's unique identifier. + args: A dictionary containing training configuration parameters. + + Attributes: + model: The neural network model for training. + id: The unique identifier of the client. + args: A dictionary containing training configuration parameters. + + Methods: + get_model_params(): + Get the current state dictionary of the model. + + set_model_params(model_parameters): + Set the model's parameters using the provided state dictionary. + + train(train_data, device, args): + Train the model on the given training data with optional FedProx regularization. + + train_iterations(train_data, device, args): + Train the model for a specified number of local iterations. + + test(test_data, device, args): + Test the model's performance on the provided test data. + + """ def get_model_params(self): + """ + Get the current state dictionary of the model. + + Returns: + dict: The state dictionary of the model. + """ return self.model.cpu().state_dict() def set_model_params(self, model_parameters): + """ + Set the model's parameters using the provided state dictionary. + + Args: + model_parameters (dict): The state dictionary containing model parameters. + """ self.model.load_state_dict(model_parameters) def train(self, train_data, device, args): + """ + Train the model on the given training data with optional FedProx regularization. + + Args: + train_data (torch.utils.data.DataLoader): The DataLoader containing training data. + device (str): The device to perform training (e.g., 'cuda' or 'cpu'). + args (dict): A dictionary containing training configuration parameters. + """ model = self.model model.to(device) @@ -79,6 +129,14 @@ def train(self, train_data, device, args): def train_iterations(self, train_data, device, args): + """ + Train the model for a specified number of local iterations. + + Args: + train_data (torch.utils.data.DataLoader): The DataLoader containing training data. + device (str): The device to perform training (e.g., 'cuda' or 'cpu'). + args (dict): A dictionary containing training configuration parameters. + """ model = self.model model.to(device) @@ -145,6 +203,17 @@ def train_iterations(self, train_data, device, args): def test(self, test_data, device, args): + """ + Test the model's performance on the provided test data. + + Args: + test_data (torch.utils.data.DataLoader): The DataLoader containing test data. + device (str): The device to perform testing (e.g., 'cuda' or 'cpu'). + args (dict): A dictionary containing testing configuration parameters. + + Returns: + dict: Metrics including test accuracy and test loss. 
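+
+        Example:
+            An illustrative sketch only; ``trainer`` is an already-built
+            FedProxModelTrainer and the loaders and device are placeholders:
+
+                trainer.train(train_loader, device, args)
+                metrics = trainer.test(test_loader, device, args)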
+ """ model = self.model model.to(device) diff --git a/python/fedml/ml/trainer/trainer_creator.py b/python/fedml/ml/trainer/trainer_creator.py index 67f629d33e..0441159dd3 100644 --- a/python/fedml/ml/trainer/trainer_creator.py +++ b/python/fedml/ml/trainer/trainer_creator.py @@ -4,10 +4,20 @@ def create_model_trainer(model, args): + """ + Create and return an appropriate model trainer based on the dataset type. + + Args: + model: The neural network model to be trained. + args: A dictionary containing training configuration parameters, including the dataset type. + + Returns: + ModelTrainer: An instance of a model trainer tailored to the dataset type. + """ if args.dataset == "stackoverflow_lr": model_trainer = ModelTrainerTAGPred(model, args) elif args.dataset in ["fed_shakespeare", "stackoverflow_nwp"]: model_trainer = ModelTrainerNWP(model, args) - else: # default model trainer is for classification problem + else: # Default model trainer is for classification problem model_trainer = ModelTrainerCLS(model, args) return model_trainer From c12a139a3fd09304d43a7612fe91bc5eab00a6a0 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Sat, 16 Sep 2023 12:17:34 +0530 Subject: [PATCH 60/70] ed --- python/fedml/data/ImageNet/data_loader.py | 190 +++++++++++++--- python/fedml/data/ImageNet/datasets.py | 109 +++++++-- python/fedml/data/ImageNet/datasets_hdf5.py | 50 +++- python/fedml/data/Landmarks/data_loader.py | 110 +++++++++ python/fedml/data/Landmarks/datasets.py | 38 ++++ .../data/Landmarks/download_without_tf.py | 55 +++-- .../data/Landmarks/download_without_tff.py | 21 ++ python/fedml/data/Landmarks/utils.py | 5 + python/fedml/data/MNIST/data_loader.py | 74 ++++-- .../data/MNIST/mnist_mobile_preprocessor.py | 42 +++- python/fedml/data/MNIST/stats.py | 15 ++ .../fedml/data/NUS_WIDE/nus_wide_dataset.py | 94 ++++++++ .../data/UCI/data_loader_for_susy_and_ro.py | 104 +++++++++ .../lending_club_loan/lending_club_dataset.py | 126 ++++++++++- python/fedml/data/reddit/data_loader.py | 16 ++ python/fedml/data/reddit/datasets.py | 103 ++++++++- python/fedml/data/reddit/divide_data.py | 213 ++++++++++++++++-- python/fedml/data/reddit/nlp.py | 99 +++++++- .../data/stackoverflow_lr/data_loader.py | 36 +++ python/fedml/data/stackoverflow_lr/dataset.py | 16 +- python/fedml/data/stackoverflow_lr/utils.py | 86 +++++++ .../data/stackoverflow_nwp/data_loader.py | 26 +++ .../fedml/data/stackoverflow_nwp/dataset.py | 46 ++-- .../synthetic_0.5_0.5/generate_synthetic.py | 23 ++ .../data/synthetic_0_0/generate_synthetic.py | 23 ++ .../fedml/data/synthetic_1_1/data_loader.py | 37 ++- .../data/synthetic_1_1/generate_synthetic.py | 30 ++- python/fedml/data/synthetic_1_1/stats.py | 19 ++ 28 files changed, 1647 insertions(+), 159 deletions(-) diff --git a/python/fedml/data/ImageNet/data_loader.py b/python/fedml/data/ImageNet/data_loader.py index 22ab9a54f6..84150d1dd9 100644 --- a/python/fedml/data/ImageNet/data_loader.py +++ b/python/fedml/data/ImageNet/data_loader.py @@ -13,11 +13,51 @@ from .datasets_hdf5 import ImageNet_truncated_hdf5 +import numpy as np +import torch + class Cutout(object): + """ + Apply the Cutout data augmentation technique to an image. + + Cutout is a technique used for regularization during training deep neural networks. + It randomly masks out a rectangular region of the input image. + + Args: + length (int): The length of the square mask to apply. + + Usage: + transform = Cutout(length=16) # Create an instance of the Cutout transform. 
+        transformed_image = transform(input_image)  # Apply the Cutout transform to an image.
+
+    Note:
+        The Cutout transform is typically applied as part of a data augmentation pipeline.
+
+    References:
+        - Original paper: https://arxiv.org/abs/1708.04552
+
+    """
+
     def __init__(self, length):
+        """
+        Initialize the Cutout transform with the specified length.
+
+        Args:
+            length (int): The length of the square mask to apply.
+        """
         self.length = length
 
     def __call__(self, img):
+        """
+        Apply Cutout transformation to an input image.
+
+        Args:
+            img (torch.Tensor): The input image tensor to which Cutout will be applied.
+
+        Returns:
+            torch.Tensor: The input image with a randomly masked region.
+
+        """
         h, w = img.size(1), img.size(2)
         mask = np.ones((h, w), np.float32)
         y = np.random.randint(h)
@@ -36,6 +76,13 @@ def __call__(self, img):
 
 
 def _data_transforms_ImageNet():
+    """
+    Define data transforms for the ImageNet dataset.
+
+    Returns:
+        tuple: A pair (train_transform, valid_transform) of composed data
+        augmentation transforms for training and validation data.
+    """
     # IMAGENET_MEAN = [0.5071, 0.4865, 0.4409]
     # IMAGENET_STD = [0.2673, 0.2564, 0.2762]
 
@@ -43,41 +90,55 @@ def _data_transforms_ImageNet():
     IMAGENET_STD = [0.229, 0.224, 0.225]
 
     image_size = 224
-    train_transform = transforms.Compose(
-        [
-            # transforms.ToPILImage(),
-            transforms.RandomResizedCrop(image_size),
-            transforms.RandomHorizontalFlip(),
-            transforms.ToTensor(),
-            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
-        ]
-    )
+    train_transform = transforms.Compose([
+        transforms.RandomResizedCrop(image_size),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
+    ])
 
     train_transform.transforms.append(Cutout(16))
 
-    valid_transform = transforms.Compose(
-        [
-            transforms.CenterCrop(224),
-            transforms.ToTensor(),
-            transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
-        ]
-    )
+    valid_transform = transforms.Compose([
+        transforms.CenterCrop(224),
+        transforms.ToTensor(),
+        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
+    ])
 
     return train_transform, valid_transform
 
-
-# for centralized training
 def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None):
-    return get_dataloader_ImageNet(datadir, train_bs, test_bs, dataidxs)
+    """
+    Get data loaders for centralized training.
 
+    Args:
+        dataset (str): The dataset name.
+        datadir (str): The path to the dataset directory.
+        train_bs (int): Batch size for training data.
+        test_bs (int): Batch size for testing data.
+        dataidxs (list, optional): List of data indices to use for training (default: None).
 
-# for local devices
-def get_dataloader_test(
-    dataset, datadir, train_bs, test_bs, dataidxs_train, dataidxs_test
-):
-    return get_dataloader_test_ImageNet(
-        datadir, train_bs, test_bs, dataidxs_train, dataidxs_test
-    )
+    Returns:
+        tuple: Training and testing data loaders.
+    """
+    return get_dataloader_ImageNet(datadir, train_bs, test_bs, dataidxs)
+
+def get_dataloader_test(dataset, datadir, train_bs, test_bs, dataidxs_train, dataidxs_test):
+    """
+    Get data loaders for local devices.
+
+    Args:
+        dataset (str): The dataset name.
+        datadir (str): The path to the dataset directory.
+        train_bs (int): Batch size for training data.
+        test_bs (int): Batch size for testing data.
+        dataidxs_train (list): List of data indices to use for training.
+        dataidxs_test (list): List of data indices to use for testing.
+
+    Returns:
+        tuple: Training and testing data loaders.
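+
+    Example:
+        An illustrative sketch only; the path and the index lists are
+        placeholders, not values shipped with the library:
+
+            train_dl, test_dl = get_dataloader_test(
+                "ILSVRC2012", "./data/ImageNet", 64, 32,
+                dataidxs_train, dataidxs_test)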
+ """ + return get_dataloader_test_ImageNet(datadir, train_bs, test_bs, dataidxs_train, dataidxs_test) def get_dataloader_ImageNet_truncated( @@ -89,7 +150,25 @@ def get_dataloader_ImageNet_truncated( net_dataidx_map=None, ): """ - imagenet_dataset_train, imagenet_dataset_test should be ImageNet or ImageNet_hdf5 + Get data loaders for a truncated version of the ImageNet dataset. + + Args: + imagenet_dataset_train: The training dataset (ImageNet or ImageNet_hdf5). + imagenet_dataset_test: The testing dataset (ImageNet or ImageNet_hdf5). + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (list, optional): List of data indices to use for training (default: None). + net_dataidx_map (dict, optional): Mapping of data indices to network indices (default: None). + + Returns: + tuple: A tuple containing training and testing data loaders. + + Raises: + NotImplementedError: If the dataset type is not supported. + + Note: + - The `imagenet_dataset_train` and `imagenet_dataset_test` should be instances of `ImageNet` or `ImageNet_hdf5`. + """ if type(imagenet_dataset_train) == ImageNet: dl_obj = ImageNet_truncated @@ -138,6 +217,19 @@ def get_dataloader_ImageNet_truncated( def get_dataloader_ImageNet(datadir, train_bs, test_bs, dataidxs=None): + """ + Get data loaders for the ImageNet dataset. + + Args: + datadir (str): The path to the dataset directory. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (list, optional): List of data indices to use for training (default: None). + + Returns: + tuple: A tuple containing training and testing data loaders. + + """ dl_obj = ImageNet transform_train, transform_test = _data_transforms_ImageNet() @@ -176,6 +268,20 @@ def get_dataloader_ImageNet(datadir, train_bs, test_bs, dataidxs=None): def get_dataloader_test_ImageNet( datadir, train_bs, test_bs, dataidxs_train=None, dataidxs_test=None ): + """ + Get data loaders for the ImageNet dataset for testing. + + Args: + datadir (str): The path to the dataset directory. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs_train (list, optional): List of data indices to use for training (default: None). + dataidxs_test (list, optional): List of data indices to use for testing (default: None). + + Returns: + tuple: A tuple containing training and testing data loaders. + + """ dl_obj = ImageNet transform_train, transform_test = _data_transforms_ImageNet() @@ -219,14 +325,25 @@ def distributed_centralized_ImageNet_loader( dataset, data_dir, world_size, rank, batch_size ): """ - Used for generating distributed dataloader for - accelerating centralized training + Generate a distributed dataloader for accelerating centralized training. + + Args: + dataset (str): The dataset name ("ILSVRC2012" or "ILSVRC2012_hdf5"). + data_dir (str): The path to the dataset directory. + world_size (int): The total number of processes in the distributed training. + rank (int): The rank of the current process in the distributed training. + batch_size (int): Batch size for training and testing data. + + Returns: + tuple: A tuple containing various training and testing data related information. 
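+
+    Example:
+        An illustrative sketch only; the path and the process layout are
+        placeholders for a real distributed launch:
+
+            data = distributed_centralized_ImageNet_loader(
+                "ILSVRC2012", "./data/ImageNet",
+                world_size=8, rank=0, batch_size=64)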
+ """ train_bs = batch_size test_bs = batch_size transform_train, transform_test = _data_transforms_ImageNet() + if dataset == "ILSVRC2012": train_dataset = ImageNet( data_dir=data_dir, dataidxs=None, train=True, transform=transform_train @@ -278,6 +395,21 @@ def load_partition_data_ImageNet( client_number=100, batch_size=10, ): + """ + Load and partition data for the ImageNet dataset. + + Args: + dataset (str): The dataset name ("ILSVRC2012" or "ILSVRC2012_hdf5"). + data_dir (str): The path to the dataset directory. + partition_method (str, optional): The partitioning method (default: None). + partition_alpha (float, optional): The partitioning alpha value (default: None). + client_number (int, optional): The number of clients (default: 100). + batch_size (int, optional): Batch size for training and testing data (default: 10). + + Returns: + tuple: A tuple containing various data-related information. + + """ if dataset == "ILSVRC2012": train_dataset = ImageNet(data_dir=data_dir, dataidxs=None, train=True) diff --git a/python/fedml/data/ImageNet/datasets.py b/python/fedml/data/ImageNet/datasets.py index 5b47b65184..f4103a18a2 100644 --- a/python/fedml/data/ImageNet/datasets.py +++ b/python/fedml/data/ImageNet/datasets.py @@ -19,6 +19,15 @@ def has_file_allowed_extension(filename, extensions): def find_classes(dir): + """Find class names from subdirectories in a given directory. + + Args: + dir (str): The root directory containing subdirectories, each representing a class. + + Returns: + list: A sorted list of class names. + dict: A dictionary mapping class names to their respective indices. + """ classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] classes.sort() class_to_idx = {classes[i]: i for i in range(len(classes))} @@ -26,6 +35,18 @@ def find_classes(dir): def make_dataset(dir, class_to_idx, extensions): + """Create a dataset of image file paths and their corresponding class indices. + + Args: + dir (str): The root directory containing subdirectories, each representing a class. + class_to_idx (dict): A dictionary mapping class names to their respective indices. + extensions (tuple): A tuple of allowed file extensions. + + Returns: + list: A list of tuples, each containing the file path and class index. + dict: A dictionary mapping class indices to the number of samples per class. + dict: A dictionary mapping class indices to data index ranges. + """ images = [] data_local_num_dict = dict() @@ -55,14 +76,29 @@ def make_dataset(dir, class_to_idx, extensions): def pil_loader(path): - # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) + """Load an image using PIL (Python Imaging Library). + + Args: + path (str): The path to the image file. + + Returns: + PIL.Image.Image: The loaded image in RGB format. + """ with open(path, "rb") as f: img = Image.open(f) return img.convert("RGB") def accimage_loader(path): - import accimage # pylint: disable=E0401 + """Load an image using AccImage (optimized for CUDA). + + Args: + path (str): The path to the image file. + + Returns: + accimage.Image: The loaded image using AccImage. + """ + import accimage try: return accimage.Image(path) @@ -72,6 +108,14 @@ def accimage_loader(path): def default_loader(path): + """Load an image using the default loader (PIL or AccImage). + + Args: + path (str): The path to the image file. + + Returns: + PIL.Image.Image or accimage.Image: The loaded image. 
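+
+    Example:
+        Illustrative only; the image path is a placeholder:
+
+            img = default_loader("./data/ImageNet/train/n01440764/example.JPEG")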
+ """ from torchvision import get_image_backend if get_image_backend() == "accimage": @@ -91,8 +135,20 @@ def __init__( download=False, ): """ - Generating this class too many times will be time-consuming. - So it will be better calling this once and put it into ImageNet_truncated. + Initialize the ImageNet dataset. + + Args: + data_dir (str): Root directory of the dataset. + dataidxs (int or list, optional): List of indices to select specific data subsets. + train (bool, optional): If True, loads the training dataset; otherwise, loads the validation dataset. + transform (callable, optional): A function/transform to apply to the data. + target_transform (callable, optional): A function/transform to apply to the labels. + download (bool, optional): Whether to download the dataset if it's not found locally. + + Note: + Generating this class too many times will be time-consuming. + It's better to call this once and use ImageNet_truncated. + """ self.dataidxs = dataidxs self.train = train @@ -110,9 +166,10 @@ def __init__( self.data_local_num_dict, self.net_dataidx_map, ) = self.__getdatasets__() - if dataidxs == None: + + if dataidxs is None: self.local_data = self.all_data - elif type(dataidxs) == int: + elif isinstance(dataidxs, int): (begin, end) = self.net_dataidx_map[dataidxs] self.local_data = self.all_data[begin:end] else: @@ -130,20 +187,18 @@ def get_net_dataidx_map(self): def get_data_local_num_dict(self): return self.data_local_num_dict + def __getdatasets__(self): - # all_data = datasets.ImageFolder(data_dir, self.transform, self.target_transform) classes, class_to_idx = find_classes(self.data_dir) - IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif"] + all_data, data_local_num_dict, net_dataidx_map = make_dataset( self.data_dir, class_to_idx, IMG_EXTENSIONS ) if len(all_data) == 0: - raise ( - RuntimeError( - "Found 0 files in subfolders of: " + self.data_dir + "\n" - "Supported extensions are: " + ",".join(IMG_EXTENSIONS) - ) + raise RuntimeError( + f"Found 0 files in subfolders of: {self.data_dir}\n" + f"Supported extensions are: {','.join(IMG_EXTENSIONS)}" ) return all_data, data_local_num_dict, net_dataidx_map @@ -153,9 +208,8 @@ def __getitem__(self, index): index (int): Index Returns: - tuple: (image, target) where target is index of the target class. + tuple: (image, target) where target is the index of the target class. """ - # img, target = self.data[index], self.target[index] path, target = self.local_data[index] img = self.loader(path) @@ -174,7 +228,7 @@ def __len__(self): class ImageNet_truncated(data.Dataset): def __init__( self, - imagenet_dataset: ImageNet, + imagenet_dataset, dataidxs, net_dataidx_map, train=True, @@ -182,7 +236,19 @@ def __init__( target_transform=None, download=False, ): + """ + Initialize a truncated version of the ImageNet dataset. + Args: + imagenet_dataset (ImageNet): The original ImageNet dataset. + dataidxs (int or list): List of indices to select specific data subsets. + net_dataidx_map (dict): Mapping of data indices in the original dataset. + train (bool, optional): If True, loads the training dataset; otherwise, loads the validation dataset. + transform (callable, optional): A function/transform to apply to the data. + target_transform (callable, optional): A function/transform to apply to the labels. + download (bool, optional): Whether to download the dataset if it's not found locally. 
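+
+        Example:
+            An illustrative sketch only; the data directory is a placeholder
+            and the index map comes from a prior ImageNet instantiation:
+
+                whole = ImageNet(data_dir="./data/ImageNet", train=True)
+                client_ds = ImageNet_truncated(
+                    whole, dataidxs=0,
+                    net_dataidx_map=whole.get_net_dataidx_map(), train=True)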
+ + """ self.dataidxs = dataidxs self.train = train self.transform = transform @@ -191,9 +257,10 @@ def __init__( self.net_dataidx_map = net_dataidx_map self.loader = default_loader self.all_data = imagenet_dataset.get_local_data() - if dataidxs == None: + + if dataidxs is None: self.local_data = self.all_data - elif type(dataidxs) == int: + elif isinstance(dataidxs, int): (begin, end) = self.net_dataidx_map[dataidxs] self.local_data = self.all_data[begin:end] else: @@ -208,10 +275,9 @@ def __getitem__(self, index): index (int): Index Returns: - tuple: (image, target) where target is index of the target class. + tuple: (image, target) where target is the index of the target class. """ - # img, target = self.data[index], self.target[index] - + path, target = self.local_data[index] img = self.loader(path) if self.transform is not None: @@ -224,3 +290,4 @@ def __getitem__(self, index): def __len__(self): return len(self.local_data) + \ No newline at end of file diff --git a/python/fedml/data/ImageNet/datasets_hdf5.py b/python/fedml/data/ImageNet/datasets_hdf5.py index 042016fee8..c20f29da83 100644 --- a/python/fedml/data/ImageNet/datasets_hdf5.py +++ b/python/fedml/data/ImageNet/datasets_hdf5.py @@ -13,7 +13,14 @@ class DatasetHDF5(data.Dataset): def __init__(self, hdf5fn, t, transform=None, target_transform=None): """ - t: 'train' or 'val' + Initialize a custom dataset from an HDF5 file. + + Args: + hdf5fn (str): Filepath to the HDF5 file. + t (str): 'train' or 'val' to specify the dataset split. + transform (callable, optional): A function/transform to apply to the data. + target_transform (callable, optional): A function/transform to apply to the labels. + """ super(DatasetHDF5, self).__init__() self.hf = h5py.File(hdf5fn, "r", libver="latest", swmr=True) @@ -21,8 +28,7 @@ def __init__(self, hdf5fn, t, transform=None, target_transform=None): self.n_images = self.hf["%s_img" % self.t].shape[0] self.dlabel = self.hf["%s_labels" % self.t][...] self.d = self.hf["%s_img" % self.t] - # self.transform = transform - # self.target_transform = target_transform + def _get_dataset_x_and_target(self, index): img = self.d[index, ...] @@ -30,6 +36,13 @@ def _get_dataset_x_and_target(self, index): return img, np.int64(target) def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: (image, target) where target is the label of the image. + """ img, target = self._get_dataset_x_and_target(index) # if self.transform is not None: # img = self.transform(img) @@ -52,8 +65,20 @@ def __init__( download=False, ): """ - Generating this class too many times will be time-consuming. - So it will be better calling this once and put it into ImageNet_truncated. + Initialize the ImageNet dataset using HDF5 files. + + Args: + data_dir (str): Directory containing the HDF5 file. + dataidxs (int or list, optional): List of indices to select specific data subsets. + train (bool, optional): If True, loads the training dataset; otherwise, loads the validation dataset. + transform (callable, optional): A function/transform to apply to the data. + target_transform (callable, optional): A function/transform to apply to the labels. + download (bool, optional): Whether to download the dataset if it's not found locally. + + Note: + Generating this class too many times will be time-consuming. + It's better to call this once and use ImageNet_truncated. 
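+
+        Example:
+            An illustrative sketch only; the directory is a placeholder and
+            must already contain the preprocessed HDF5 file:
+
+                dataset = ImageNet_hdf5(data_dir="./data/imagenet_hdf5", train=True)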
+ """ self.dataidxs = dataidxs self.train = train @@ -117,7 +142,7 @@ def __getitem__(self, index): index (int): Index Returns: - tuple: (image, target) where target is index of the target class. + tuple: (image, target) where target is the label of the image. """ img, target = self.all_data_hdf5[self.local_data_idx[index]] @@ -146,6 +171,19 @@ def __init__( target_transform=None, download=False, ): + """ + Initialize a truncated version of the ImageNet dataset using HDF5 files. + + Args: + imagenet_dataset (ImageNet_hdf5): The original ImageNet HDF5 dataset. + dataidxs (int or list): List of indices to select specific data subsets. + net_dataidx_map (dict): Mapping of data indices in the original dataset. + train (bool, optional): If True, loads the training dataset; otherwise, loads the validation dataset. + transform (callable, optional): A function/transform to apply to the data. + target_transform (callable, optional): A function/transform to apply to the labels. + download (bool, optional): Whether to download the dataset if it's not found locally. + + """ self.dataidxs = dataidxs self.train = train diff --git a/python/fedml/data/Landmarks/data_loader.py b/python/fedml/data/Landmarks/data_loader.py index 63514bb088..923df16293 100644 --- a/python/fedml/data/Landmarks/data_loader.py +++ b/python/fedml/data/Landmarks/data_loader.py @@ -67,9 +67,24 @@ def _read_csv(path: str): class Cutout(object): def __init__(self, length): + """ + Initialize the Cutout transformation. + + Args: + length (int): The size of the square patch to cut out from the image. + """ self.length = length def __call__(self, img): + """ + Apply the Cutout transformation to the input image. + + Args: + img (PIL.Image): The input image. + + Returns: + PIL.Image: The transformed image with a square patch cut out. + """ h, w = img.size(1), img.size(2) mask = np.ones((h, w), np.float32) y = np.random.randint(h) @@ -126,6 +141,17 @@ def get_mapping_per_user(fn): [{'user_id': xxx, 'image_id': xxx, 'class': xxx} ... {'user_id': xxx, 'image_id': xxx, 'class': xxx} ... ] } + + Load mapping information per user from a CSV file. + + Args: + fn (str): The filename of the CSV file containing user-image mapping. + + Returns: + tuple: A tuple containing: + - data_files (list): A list of dictionaries containing mapping information. + - data_local_num_dict (dict): A dictionary mapping user IDs to the number of data entries they have. + - net_dataidx_map (dict): A dictionary mapping user IDs to data index ranges. """ mapping_table = _read_csv(fn) expected_cols = ["user_id", "image_id", "class"] @@ -163,6 +189,21 @@ def get_mapping_per_user(fn): def get_dataloader( dataset, datadir, train_files, test_files, train_bs, test_bs, dataidxs=None ): + """ + Get data loaders for centralized training. + + Args: + dataset (str): The name of the dataset. + datadir (str): The directory containing the data files. + train_files (list): A list of training data files. + test_files (list): A list of testing data files. + train_bs (int): The batch size for training. + test_bs (int): The batch size for testing. + dataidxs (list, optional): List of data indices to select specific data entries. Defaults to None. + + Returns: + DataLoader: Data loaders for training and testing. + """ return get_dataloader_Landmarks( datadir, train_files, test_files, train_bs, test_bs, dataidxs ) @@ -179,6 +220,22 @@ def get_dataloader_test( dataidxs_train, dataidxs_test, ): + """ + Get data loaders for testing with specified data indices. 
+ + Args: + dataset (str): The name of the dataset. + datadir (str): The directory containing the data files. + train_files (list): A list of training data files. + test_files (list): A list of testing data files. + train_bs (int): The batch size for training. + test_bs (int): The batch size for testing. + dataidxs_train (list): List of data indices to select specific training data entries. + dataidxs_test (list): List of data indices to select specific testing data entries. + + Returns: + DataLoader: Data loaders for training and testing. + """ return get_dataloader_test_Landmarks( datadir, train_files, @@ -193,6 +250,20 @@ def get_dataloader_test( def get_dataloader_Landmarks( datadir, train_files, test_files, train_bs, test_bs, dataidxs=None ): + """ + Get data loaders for Landmarks dataset. + + Args: + datadir (str): The directory containing the data files. + train_files (list): A list of training data files. + test_files (list): A list of testing data files. + train_bs (int): The batch size for training. + test_bs (int): The batch size for testing. + dataidxs (list, optional): List of data indices to select specific data entries. Defaults to None. + + Returns: + DataLoader: Data loaders for training and testing. + """ dl_obj = Landmarks transform_train, transform_test = _data_transforms_landmarks() @@ -233,6 +304,21 @@ def get_dataloader_test_Landmarks( dataidxs_train=None, dataidxs_test=None, ): + """ + Get data loaders for testing Landmarks dataset with specified data indices. + + Args: + datadir (str): The directory containing the data files. + train_files (list): A list of training data files. + test_files (list): A list of testing data files. + train_bs (int): The batch size for training. + test_bs (int): The batch size for testing. + dataidxs_train (list, optional): List of data indices to select specific training data entries. Defaults to None. + dataidxs_test (list, optional): List of data indices to select specific testing data entries. Defaults to None. + + Returns: + DataLoader: Data loaders for training and testing. + """ dl_obj = Landmarks transform_train, transform_test = _data_transforms_landmarks() @@ -274,6 +360,30 @@ def load_partition_data_landmarks( client_number=233, batch_size=10, ): + """ + Load partitioned data for the Landmarks dataset. + + Args: + dataset (str): The name of the dataset. + data_dir (str): The directory containing the data files. + fed_train_map_file (str): The path to the federated train data mapping file. + fed_test_map_file (str): The path to the federated test data mapping file. + partition_method (str, optional): The partitioning method for data. Defaults to None. + partition_alpha (float, optional): The alpha value for partitioning. Defaults to None. + client_number (int): The number of clients/participants. Defaults to 233. + batch_size (int): The batch size for data loaders. Defaults to 10. + + Returns: + Tuple: A tuple containing the following elements: + - train_data_num (int): The number of training data samples. + - test_data_num (int): The number of testing data samples. + - train_data_global (DataLoader): Global training data loader. + - test_data_global (DataLoader): Global testing data loader. + - data_local_num_dict (dict): Dictionary mapping client IDs to the number of local data samples. + - train_data_local_dict (dict): Dictionary mapping client IDs to their local training data loaders. + - test_data_local_dict (dict): Dictionary mapping client IDs to their local testing data loaders. 
+ - class_num (int): The number of unique classes in the dataset. + """ train_files, data_local_num_dict, net_dataidx_map = get_mapping_per_user( fed_train_map_file diff --git a/python/fedml/data/Landmarks/datasets.py b/python/fedml/data/Landmarks/datasets.py index ac8364c75a..020e855001 100644 --- a/python/fedml/data/Landmarks/datasets.py +++ b/python/fedml/data/Landmarks/datasets.py @@ -6,6 +6,18 @@ class Landmarks(data.Dataset): + """ + Custom dataset class for the Landmarks dataset. + + Args: + data_dir (str): The directory containing the data files. + allfiles (list): A list of data entries in the form of dictionaries with 'user_id', 'image_id', and 'class'. + dataidxs (list, optional): List of data indices to select specific data entries. Defaults to None. + train (bool, optional): Indicates whether the dataset is for training. Defaults to True. + transform (callable, optional): A function/transform to apply to the data. Defaults to None. + target_transform (callable, optional): A function/transform to apply to the target. Defaults to None. + download (bool, optional): Whether to download the data. Defaults to False. + """ def __init__( self, data_dir, @@ -19,6 +31,16 @@ def __init__( """ allfiles is [{'user_id': xxx, 'image_id': xxx, 'class': xxx} ... {'user_id': xxx, 'image_id': xxx, 'class': xxx} ... ] + Initialize the Landmarks dataset. + + Args: + data_dir (str): The directory containing the data files. + allfiles (list): A list of data entries in the form of dictionaries with 'user_id', 'image_id', and 'class'. + dataidxs (list, optional): List of data indices to select specific data entries. Defaults to None. + train (bool, optional): Indicates whether the dataset is for training. Defaults to True. + transform (callable, optional): A function/transform to apply to the data. Defaults to None. + target_transform (callable, optional): A function/transform to apply to the target. Defaults to None. + download (bool, optional): Whether to download the data. Defaults to False. """ self.allfiles = allfiles if dataidxs == None: @@ -32,6 +54,13 @@ def __init__( self.target_transform = target_transform def __len__(self): + """ + Get the number of data entries in the dataset. + + Returns: + int: The number of data entries. + """ + # if self.user_id != None: # return sum([len(local_data) for local_data in self.mapping_per_user.values()]) # else: @@ -39,6 +68,15 @@ def __len__(self): return len(self.local_files) def __getitem__(self, idx): + """ + Get a data sample and its corresponding label by index. + + Args: + idx (int): Index of the data sample to retrieve. + + Returns: + tuple: A tuple containing the data sample and its corresponding label. + """ # if self.user_id != None: # img_name = self.mapping_per_user[self.user_id][idx]['image_id'] # label = self.mapping_per_user[self.user_id][idx]['class'] diff --git a/python/fedml/data/Landmarks/download_without_tf.py b/python/fedml/data/Landmarks/download_without_tf.py index f38f3205b2..66c7bad004 100644 --- a/python/fedml/data/Landmarks/download_without_tf.py +++ b/python/fedml/data/Landmarks/download_without_tf.py @@ -49,12 +49,15 @@ def _listener_process(queue: multiprocessing.Queue, log_file: str): - """Sets up a separate process for handling logging messages. + """ + Sets up a separate process for handling logging messages. + This setup is required because without it, the logging messages will be duplicated when multiple processes are created for downloading GLD dataset. + Args: - queue: The queue to receive logging messages. 
- log_file: The file which the messages will be written to. + queue (multiprocessing.Queue): The queue to receive logging messages. + log_file (str): The file to which the messages will be written. """ root = logging.getLogger() h = logging.FileHandler(log_file) @@ -77,27 +80,32 @@ def _listener_process(queue: multiprocessing.Queue, log_file: str): def _read_csv(path: str) -> List[Dict[str, str]]: - """Reads a csv file, and returns the content inside a list of dictionaries. + """ + Reads a CSV file and returns the content inside a list of dictionaries. + Args: - path: The path to the csv file. + path (str): The path to the CSV file. + Returns: - A list of dictionaries. Each row in the csv file will be a list entry. The - dictionary is keyed by the column names. + List[Dict[str, str]]: A list of dictionaries. Each row in the CSV file will be a list entry. + The dictionary is keyed by the column names. """ with open(path, "r") as f: return list(csv.DictReader(f)) def _filter_images(shard: int, all_images: Set[str], image_dir: str, base_url: str): - """Download full GLDv2 dataset, only keep images that are included in the federated gld v2 dataset. + """ + Download full GLDv2 dataset, only keep images that are included in the federated GLD v2 dataset. + Args: - shard: The shard of the GLDv2 dataset. - all_images: A set which contains all images included in the federated GLD - dataset. - image_dir: The directory to keep all filtered images. - base_url: The base url for downloading GLD v2 dataset images. + shard (int): The shard of the GLDv2 dataset. + all_images (Set[str]): A set that contains all images included in the federated GLD dataset. + image_dir (str): The directory to keep all filtered images. + base_url (str): The base URL for downloading GLD v2 dataset images. + Raises: - IOError: when failed to download checksum. + IOError: When failed to download checksum. """ shard_str = "%03d" % shard images_tar_url = "%s/train/images_%s.tar" % (base_url, shard_str) @@ -135,10 +143,14 @@ def _download_data(num_worker: int, cache_dir: str, base_url: str): Download the entire GLD v2 dataset, subset the dataset to only include the images in the federated GLD v2 dataset, and create both gld23k and gld160k datasets. + Args: - num_worker: The number of threads for downloading the GLD v2 dataset. - cache_dir: The directory for caching temporary results. - base_url: The base url for downloading GLD images. + num_worker (int): The number of threads for downloading the GLD v2 dataset. + cache_dir (str): The directory for caching temporary results. + base_url (str): The base URL for downloading GLD images. + + Raises: + IOError: When failed to download checksum. """ logger = logging.getLogger(LOGGER) logging.info("Start to download fed gldv2 mapping files") @@ -194,6 +206,15 @@ def load_data( gld23k: bool = False, base_url: str = GLD_SHARD_BASE_URL, ): + """ + Load the GLD v2 dataset. + + Args: + num_worker (int): The number of threads for downloading the GLD v2 dataset. + cache_dir (str): The directory for caching temporary results. + gld23k (bool): Whether to load the gld23k dataset. + base_url (str): The base URL for downloading GLD images. 
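+
+    Example:
+        An illustrative sketch only; the cache directory is a placeholder:
+
+            load_data(num_worker=4, cache_dir="./gld_cache", gld23k=True)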
+ """ if not os.path.exists(cache_dir): os.mkdir(cache_dir) diff --git a/python/fedml/data/Landmarks/download_without_tff.py b/python/fedml/data/Landmarks/download_without_tff.py index eb351d433f..7e0d028fd7 100644 --- a/python/fedml/data/Landmarks/download_without_tff.py +++ b/python/fedml/data/Landmarks/download_without_tff.py @@ -48,6 +48,7 @@ def _listener_process(queue: multiprocessing.Queue, log_file: str): """Sets up a separate process for handling logging messages. This setup is required because without it, the logging messages will be duplicated when multiple processes are created for downloading GLD dataset. + Args: queue: The queue to receive logging messages. log_file: The file which the messages will be written to. @@ -74,8 +75,10 @@ def _listener_process(queue: multiprocessing.Queue, log_file: str): def _read_csv(path: str) -> List[Dict[str, str]]: """Reads a csv file, and returns the content inside a list of dictionaries. + Args: path: The path to the csv file. + Returns: A list of dictionaries. Each row in the csv file will be a list entry. The dictionary is keyed by the column names. @@ -88,10 +91,12 @@ def _create_dataset_with_mapping( image_dir: str, mapping: List[Dict[str, str]] ) -> List[tf.train.Example]: """Builds a dataset based on the mapping file and the images in the image dir. + Args: image_dir: The directory contains the image files. mapping: A list of dictionaries. Each dictionary contains 'image_id' and 'class' columns. + Returns: A list of `tf.train.Example`. """ @@ -126,6 +131,7 @@ def _create_dataset_with_mapping( def _create_train_data_files(cache_dir: str, image_dir: str, mapping_file: str): """Create the train data and persist it into a separate file per user. + Args: cache_dir: The directory caching the intermediate results. image_dir: The directory containing all the downloaded images. @@ -165,6 +171,7 @@ def _create_train_data_files(cache_dir: str, image_dir: str, mapping_file: str): def _create_test_data_file(cache_dir: str, image_dir: str, mapping_file: str): """Create the test data and persist it into a file. + Args: cache_dir: The directory caching the intermediate results. image_dir: The directory containing all the downloaded images. @@ -195,6 +202,7 @@ def _create_federated_gld_dataset( cache_dir: str, image_dir: str, train_mapping_file: str, test_mapping_file: str ): """Generate fedreated GLDv2 dataset with the downloaded images. + Args: cache_dir: The directory for caching the intermediate results. image_dir: The directory that contains the filtered images. @@ -217,6 +225,7 @@ def _create_federated_gld_dataset( def _create_mini_gld_dataset(cache_dir: str, image_dir: str): """Generate mini federated GLDv2 dataset with the downloaded images. + Args: cache_dir: The directory for caching the intermediate results. image_dir: The directory that contains the filtered images. @@ -249,12 +258,14 @@ def _create_mini_gld_dataset(cache_dir: str, image_dir: str): def _filter_images(shard: int, all_images: Set[str], image_dir: str, base_url: str): """Download full GLDv2 dataset, only keep images that are included in the federated gld v2 dataset. + Args: shard: The shard of the GLDv2 dataset. all_images: A set which contains all images included in the federated GLD dataset. image_dir: The directory to keep all filtered images. base_url: The base url for downloading GLD v2 dataset images. + Raises: IOError: when failed to download checksum. 
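+
+    Example:
+        An illustrative sketch only; ``all_images`` is assumed to hold the
+        image ids collected from the federated mapping files:
+
+            _filter_images(0, all_images, "./images", GLD_SHARD_BASE_URL)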
""" @@ -301,6 +312,7 @@ def _download_data(num_worker: int, cache_dir: str, base_url: str): Download the entire GLD v2 dataset, subset the dataset to only include the images in the federated GLD v2 dataset, and create both gld23k and gld160k datasets. + Args: num_worker: The number of threads for downloading the GLD v2 dataset. cache_dir: The directory for caching temporary results. @@ -362,6 +374,15 @@ def load_data( gld23k: bool = False, base_url: str = GLD_SHARD_BASE_URL, ): + """ + Load the GLD v2 dataset. + + Args: + num_worker (int, optional): The number of threads for downloading the GLD v2 dataset. + cache_dir (str, optional): The directory for caching temporary results. + gld23k (bool, optional): Whether to load the gld23k dataset. + base_url (str, optional): The base URL for downloading GLD images. + """ if not os.path.exists(cache_dir): os.mkdir(cache_dir) diff --git a/python/fedml/data/Landmarks/utils.py b/python/fedml/data/Landmarks/utils.py index aa75034d08..68e1b0e7d5 100644 --- a/python/fedml/data/Landmarks/utils.py +++ b/python/fedml/data/Landmarks/utils.py @@ -17,6 +17,7 @@ class Progbar(object): """Displays a progress bar. + Arguments: target: Total number of steps expected, None if unknown. width: Progress bar width on screen. @@ -242,6 +243,7 @@ def chunk_read(response, chunk_size=8192, reporthook=None): def _extract_archive(file_path, path=".", archive_format="auto"): """Extracts an archive if it matches tar, tar.gz, tar.bz, or zip formats. + Arguments: file_path: path to the archive file path: path to extract the archive file @@ -251,6 +253,7 @@ def _extract_archive(file_path, path=".", archive_format="auto"): The default 'auto' is ['tar', 'zip']. None or an empty list will return no matches found. Returns: + True if a match was found and an archive extraction was completed, False otherwise. """ @@ -301,6 +304,7 @@ def get_file( Files in tar, tar.gz, tar.bz, and zip formats can also be extracted. Passing a hash will verify the file after download. The command line programs `shasum` and `sha256sum` can compute the hash. + Arguments: fname: Name of the file. If an absolute path `/path/to/file.txt` is specified the file will be saved at that location. @@ -320,6 +324,7 @@ def get_file( defaults to the [Keras Directory](/faq/#where-is-the-keras-configuration-filed-stored). Returns: + Path to the downloaded file """ if cache_dir is None: diff --git a/python/fedml/data/MNIST/data_loader.py b/python/fedml/data/MNIST/data_loader.py index 18e0c29bcf..cbf587c708 100755 --- a/python/fedml/data/MNIST/data_loader.py +++ b/python/fedml/data/MNIST/data_loader.py @@ -14,6 +14,15 @@ def download_mnist(data_cache_dir): + """ + Download the MNIST dataset if it's not already downloaded. + + Args: + data_cache_dir (str): Directory where the dataset should be stored. + + Returns: + None + """ if not os.path.exists(data_cache_dir): os.makedirs(data_cache_dir, exist_ok=True) @@ -30,18 +39,18 @@ def download_mnist(data_cache_dir): zip_ref.extractall(data_cache_dir) def read_data(train_data_dir, test_data_dir): - """parses data in given train and test data directories - - assumes: - - the data in the input directories are .json files with - keys 'users' and 'user_data' - - the set of train set users is the same as the set of test set users - - Return: - clients: list of non-unique client ids - groups: list of group ids; empty list if none found - train_data: dictionary of train data - test_data: dictionary of test data + """ + Parses data in the given train and test data directories. 
+ + Args: + train_data_dir (str): Path to the directory containing train data. + test_data_dir (str): Path to the directory containing test data. + + Returns: + clients (list): List of non-unique client ids. + groups (list): List of group ids; empty list if none found. + train_data (dict): Dictionary of train data. + test_data (dict): Dictionary of test data. """ clients = [] groups = [] @@ -71,24 +80,29 @@ def read_data(train_data_dir, test_data_dir): return clients, groups, train_data, test_data - def batch_data(args, data, batch_size): - """ - data is a dict := {'x': [numpy array], 'y': [numpy array]} (on one client) - returns x, y, which are both numpy array of length: batch_size + Prepare data batches. + + Args: + args: Additional arguments (not specified). + data (dict): Data dictionary containing 'x' and 'y'. + batch_size (int): Size of each batch. + + Returns: + batch_data (list): List of data batches. """ data_x = data["x"] data_y = data["y"] - # randomly shuffle data + # Randomly shuffle data np.random.seed(100) rng_state = np.random.get_state() np.random.shuffle(data_x) np.random.set_state(rng_state) np.random.shuffle(data_y) - # loop through mini-batches + # Loop through mini-batches batch_data = list() for i in range(0, len(data_x), batch_size): batched_x = data_x[i : i + batch_size] @@ -99,6 +113,18 @@ def batch_data(args, data, batch_size): def load_partition_data_mnist_by_device_id(batch_size, device_id, train_path="MNIST_mobile", test_path="MNIST_mobile"): + """ + Load partitioned MNIST data by device ID. + + Args: + batch_size (int): Size of each batch. + device_id (str): ID of the device. + train_path (str): Path to the train data directory. + test_path (str): Path to the test data directory. + + Returns: + Tuple containing data information. + """ train_path += os.path.join("/", device_id, "train") test_path += os.path.join("/", device_id, "test") return load_partition_data_mnist(batch_size, train_path, test_path) @@ -108,6 +134,18 @@ def load_partition_data_mnist( args, batch_size, train_path=os.path.join(os.getcwd(), "MNIST", "train"), test_path=os.path.join(os.getcwd(), "MNIST", "test") ): + """ + Load partitioned MNIST data. + + Args: + args: Additional arguments (not specified). + batch_size (int): Size of each batch. + train_path (str): Path to the train data directory. + test_path (str): Path to the test data directory. + + Returns: + Tuple containing data information. + """ users, groups, train_data, test_data = read_data(train_path, test_path) if len(groups) == 0: diff --git a/python/fedml/data/MNIST/mnist_mobile_preprocessor.py b/python/fedml/data/MNIST/mnist_mobile_preprocessor.py index 0d65c0e95b..b058e3c38a 100644 --- a/python/fedml/data/MNIST/mnist_mobile_preprocessor.py +++ b/python/fedml/data/MNIST/mnist_mobile_preprocessor.py @@ -28,18 +28,24 @@ def add_args(parser): def read_data(train_data_dir, test_data_dir): - """parses data in given train and test data directories - - assumes: - - the data in the input directories are .json files with - keys 'users' and 'user_data' - - the set of train set users is the same as the set of test set users - - Return: - clients: list of client ids - groups: list of group ids; empty list if none found - train_data: dictionary of train data - test_data: dictionary of test data + """ + Parse data from train and test data directories. + + Assumes: + - Data in the input directories are .json files with keys 'users' and 'user_data'. + - The set of train set users is the same as the set of test set users. 
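[editor's note] The save/restore of the NumPy RNG state in `batch_data` is what keeps features and labels aligned after shuffling. A minimal demonstration of that trick:

import numpy as np

x = np.arange(10)
y = np.arange(10)

# Save the RNG state, shuffle x, then restore the state before shuffling y,
# so both arrays receive the identical permutation.
np.random.seed(100)
state = np.random.get_state()
np.random.shuffle(x)
np.random.set_state(state)
np.random.shuffle(y)
assert (x == y).all()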
+ + Args: + train_data_dir (str): Path to the directory containing train data. + test_data_dir (str): Path to the directory containing test data. + + Returns: + clients (list): List of client ids. + train_num_samples (list): List of the number of samples for each client in the training data. + test_num_samples (list): List of the number of samples for each client in the test data. + train_data (dict): Dictionary of training data. + test_data (dict): Dictionary of test data. + client_list (list): List of client arguments. """ clients = [] train_num_samples = [] @@ -94,6 +100,18 @@ def __init__(self, client_id, client_num_per_round, comm_round): def client_sampling(round_idx, client_num_in_total, client_num_per_round): + """ + Randomly select clients for federated learning. + + Args: + round_idx (int): Index of the current federated learning round. + client_num_in_total (int): Total number of clients available. + client_num_per_round (int): Number of clients to select for the current round. + + Returns: + client_indexes (list): List of selected client indexes for the current round. + """ + if client_num_in_total == client_num_per_round: client_indexes = [client_index for client_index in range(client_num_in_total)] else: diff --git a/python/fedml/data/MNIST/stats.py b/python/fedml/data/MNIST/stats.py index 761e1cd563..cf499a16bf 100755 --- a/python/fedml/data/MNIST/stats.py +++ b/python/fedml/data/MNIST/stats.py @@ -24,6 +24,15 @@ def load_data(name): + """ + Load user and sample data from JSON files in a specified directory. + + Args: + name (str): The name of the dataset. + + Returns: + tuple: A tuple containing lists of users and their corresponding number of samples. + """ users = [] num_samples = [] @@ -47,6 +56,12 @@ def load_data(name): def print_dataset_stats(name): + """ + Print statistics about the dataset, including user count, total samples, mean, std, skewness, and histogram. + + Args: + name (str): The name of the dataset. + """ users, num_samples = load_data(name) num_users = len(users) diff --git a/python/fedml/data/NUS_WIDE/nus_wide_dataset.py b/python/fedml/data/NUS_WIDE/nus_wide_dataset.py index e6931d60b2..15e00fa14b 100644 --- a/python/fedml/data/NUS_WIDE/nus_wide_dataset.py +++ b/python/fedml/data/NUS_WIDE/nus_wide_dataset.py @@ -6,6 +6,16 @@ def get_top_k_labels(data_dir, top_k=5): + """ + Get the top k labels based on their frequency in the dataset. + + Args: + data_dir (str): The directory containing the dataset. + top_k (int): The number of top labels to retrieve. + + Returns: + list: A list of the top k labels. + """ data_path = "Groundtruth/AllLabels" label_counts = {} for filename in os.listdir(os.path.join(data_dir, data_path)): @@ -21,6 +31,18 @@ def get_top_k_labels(data_dir, top_k=5): def get_labeled_data_with_2_party(data_dir, selected_labels, n_samples, dtype="Train"): + """ + Load labeled data for a two-party scenario. + + Args: + data_dir (str): The directory containing the dataset. + selected_labels (list): The selected labels for the data. + n_samples (int): The number of samples to load. + dtype (str): The data type (e.g., 'Train' or 'Test'). + + Returns: + tuple: A tuple containing XA (image features), XB (tags), and Y (labels). + """ # get labels data_path = "Groundtruth/TrainTestLabels/" dfs = [] @@ -71,6 +93,18 @@ def get_labeled_data_with_2_party(data_dir, selected_labels, n_samples, dtype="T def get_labeled_data_with_3_party(data_dir, selected_labels, n_samples, dtype="Train"): + """ + Load labeled data for a three-party scenario. 
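[editor's note] A plausible implementation of the `client_sampling` routine documented above; seeding with the round index is an assumption, chosen so that every process in the same round draws the same client set:

import numpy as np


def client_sampling(round_idx, client_num_in_total, client_num_per_round):
    # When every client participates, no sampling is needed.
    if client_num_in_total == client_num_per_round:
        return list(range(client_num_in_total))
    # Assumption: seed by round index for cross-process determinism.
    np.random.seed(round_idx)
    num = min(client_num_per_round, client_num_in_total)
    return np.random.choice(range(client_num_in_total), num, replace=False).tolist()


print(client_sampling(round_idx=3, client_num_in_total=100, client_num_per_round=10))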
+ + Args: + data_dir (str): The directory containing the dataset. + selected_labels (list): The selected labels for the data. + n_samples (int): The number of samples to load. + dtype (str): The data type (e.g., 'Train' or 'Test'). + + Returns: + tuple: A tuple containing XA (image features), XB1 (tags for party 1), XB2 (tags for party 2), and Y (labels). + """ Xa, Xb, Y = get_labeled_data_with_2_party( data_dir=data_dir, selected_labels=selected_labels, @@ -83,6 +117,18 @@ def get_labeled_data_with_3_party(data_dir, selected_labels, n_samples, dtype="T def NUS_WIDE_load_two_party_data(data_dir, selected_labels, neg_label=-1, n_samples=-1): + """ + Load two-party data for NUS-WIDE dataset. + + Args: + data_dir (str): The directory containing the dataset. + selected_labels (list): The selected labels for the data. + neg_label (int): The negative label value. + n_samples (int): The number of samples to load. + + Returns: + tuple: A tuple containing training data and testing data for two parties. + """ print("# load_two_party_data") Xa, Xb, y = get_labeled_data_with_2_party( @@ -134,6 +180,19 @@ def NUS_WIDE_load_two_party_data(data_dir, selected_labels, neg_label=-1, n_samp def NUS_WIDE_load_three_party_data( data_dir, selected_labels, neg_label=-1, n_samples=-1 ): + """ + Load three-party data for NUS-WIDE dataset. + + Args: + data_dir (str): The directory containing the dataset. + selected_labels (list): The selected labels for the data. + neg_label (int): The negative label value. + n_samples (int): The number of samples to load. + + Returns: + tuple: A tuple containing training data and testing data for three parties. + """ + print("# load_three_party_data") Xa, Xb, Xc, y = get_labeled_data_with_3_party( data_dir=data_dir, selected_labels=selected_labels, n_samples=n_samples @@ -185,6 +244,20 @@ def prepare_party_data( n_samples, is_three_party=False, ): + """ + Prepare data for a federated learning scenario. + + Args: + src_data_folder (str): The source data folder. + des_data_folder (str): The destination data folder. + selected_labels (list): The selected labels for the data. + neg_label (int): The negative label value. + n_samples (int): The number of samples to load. + is_three_party (bool): Whether it's a three-party scenario. + + Returns: + None + """ print("# preparing data ...") train_data_list, test_data_list = ( @@ -235,6 +308,16 @@ def prepare_party_data( def get_data_folder_name(sel_lbls, is_three_party): + """ + Generate a folder name based on selected labels and party type. + + Args: + sel_lbls (list): List of selected labels. + is_three_party (bool): Indicates whether it's a three-party scenario. + + Returns: + str: Generated folder name. + """ folder_name = sel_lbls[0] for idx, lbl in enumerate(sel_lbls): if idx == 0: @@ -246,6 +329,17 @@ def get_data_folder_name(sel_lbls, is_three_party): def load_prepared_parties_data(data_dir, sel_lbls, load_three_party): + """ + Load prepared party data from a specific directory. + + Args: + data_dir (str): The directory containing the prepared data. + sel_lbls (list): List of selected labels. + load_three_party (bool): Indicates whether to load three-party data. + + Returns: + tuple: A tuple containing training and testing data lists. 
+ """ print( "# load prepared {0} party data".format("three" if load_three_party else "two") ) diff --git a/python/fedml/data/UCI/data_loader_for_susy_and_ro.py b/python/fedml/data/UCI/data_loader_for_susy_and_ro.py index 5936c49c07..79916526a8 100644 --- a/python/fedml/data/UCI/data_loader_for_susy_and_ro.py +++ b/python/fedml/data/UCI/data_loader_for_susy_and_ro.py @@ -5,7 +5,60 @@ class DataLoader(object): + """ + DataLoader class for managing data loading and preprocessing. + + Args: + data_name (str): The name of the dataset. + data_path (str): The path to the dataset CSV file. + client_list (list): A list of client IDs. + sample_num_in_total (int): The total number of data samples. + beta (float): A parameter for data loading. + + Attributes: + data_name (str): The name of the dataset. + data_path (str): The path to the dataset CSV file. + client_list (list): A list of client IDs. + sample_num_in_total (int): The total number of data samples. + beta (float): A parameter for data loading. + streaming_full_dataset_X (list): A list to store data samples. + streaming_full_dataset_Y (list): A list to store data labels. + StreamingDataDict (dict): A dictionary to store streaming data for clients. + + Methods: + load_datastream(): + Load and preprocess the data for streaming and return it as a dictionary. + load_adversarial_data(): + Load adversarial data based on the beta parameter. + load_stochastic_data(): + Load stochastic data based on the beta parameter. + read_csv_file(percent): + Read and return data samples and labels from a CSV file. + read_csv_file_for_cluster(percent): + Read and cluster data samples based on the beta parameter. + kMeans(X): + Perform K-means clustering on the data. + preprocessing(): + Perform preprocessing on the data. + + """ def __init__(self, data_name, data_path, client_list, sample_num_in_total, beta): + """ + Initialize the DataLoader with dataset information and parameters. + + Args: + data_name (str): The name of the dataset. + data_path (str): The path to the dataset CSV file. + client_list (list): A list of client IDs. + sample_num_in_total (int): The total number of data samples. + beta (float): A parameter for data loading. + + Note: + This constructor initializes the DataLoader with dataset details and parameters. + + Returns: + None + """ # SUSY, Room Occupancy; self.data_name = data_name self.data_path = data_path @@ -24,6 +77,12 @@ def __init__(self, data_name, data_path, client_list, sample_num_in_total, beta) """ def load_datastream(self): + """ + Load and preprocess the data for streaming and return it as a dictionary. + + Returns: + dict: A dictionary containing streaming data for clients. + """ self.preprocessing() self.load_adversarial_data() self.load_stochastic_data() @@ -37,14 +96,35 @@ def load_datastream(self): # beta (clustering, GMM) def load_adversarial_data(self): + """ + Load adversarial data based on the beta parameter. + + Returns: + dict: A dictionary containing adversarial streaming data for clients. + """ streaming_data = self.read_csv_file_for_cluster(self.beta) return streaming_data def load_stochastic_data(self): + """ + Load stochastic data based on the beta parameter. + + Returns: + dict: A dictionary containing stochastic streaming data for clients. + """ streaming_data = self.read_csv_file(self.beta) return streaming_data def read_csv_file(self, percent): + """ + Read and return data samples and labels from a CSV file. + + Args: + percent (float): The percentage of data to read. 
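[editor's note] The streaming loaders above hand each client a slice of the dataset; `read_csv_file` derives a per-client iteration count from `sample_num_in_total`. A loose, illustrative sketch of that dealing step only (the real method also parses the CSV and applies `percent`):

def split_stream(samples, client_list):
    # Deal each client an equal-length contiguous slice of the stream;
    # a leftover tail of len(samples) % len(client_list) items is dropped.
    per_client = len(samples) // len(client_list)
    return {c: samples[i * per_client:(i + 1) * per_client]
            for i, c in enumerate(client_list)}


print(split_stream(list(range(9)), ["a", "b", "c"]))
# {'a': [0, 1, 2], 'b': [3, 4, 5], 'c': [6, 7, 8]}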
+ + Returns: + dict: A dictionary containing streaming data for clients. + """ # print("start from:") iteration_number = int(self.sample_num_in_total / len(self.client_list)) @@ -105,6 +185,15 @@ def read_csv_file(self, percent): return self.StreamingDataDict def read_csv_file_for_cluster(self, percent): + """ + Read and cluster data samples based on the beta parameter. + + Args: + percent (float): The percentage of data to read and cluster. + + Returns: + dict: A dictionary containing clustered streaming data for clients. + """ data = [] label = [] for client_id in self.client_list: @@ -134,11 +223,26 @@ def read_csv_file_for_cluster(self, percent): return self.StreamingDataDict def kMeans(self, X): + """ + Perform K-means clustering on the data. + + Args: + X (list): List of data samples. + + Returns: + array: Cluster labels for data samples. + """ kmeans = KMeans(n_clusters=len(self.client_list)) kmeans.fit(X) return kmeans.labels_ def preprocessing(self): + """ + Perform preprocessing on the data. + + Returns: + None + """ # print("sample_num_in_total = " + str(self.sample_num_in_total)) data = [] with open(self.data_path) as csvfile: diff --git a/python/fedml/data/lending_club_loan/lending_club_dataset.py b/python/fedml/data/lending_club_loan/lending_club_dataset.py index 15812e28d4..f93ac6f132 100644 --- a/python/fedml/data/lending_club_loan/lending_club_dataset.py +++ b/python/fedml/data/lending_club_loan/lending_club_dataset.py @@ -105,20 +105,45 @@ def normalize(x): + """ + Normalize a numerical array using StandardScaler. + + Args: + x (array-like): The data to normalize. + + Returns: + array-like: Normalized data. + """ scaler = StandardScaler() x_scaled = scaler.fit_transform(x) return x_scaled - def normalize_df(df): + """ + Normalize a DataFrame using StandardScaler. + + Args: + df (pd.DataFrame): The DataFrame to normalize. + + Returns: + pd.DataFrame: Normalized DataFrame. + """ column_names = df.columns x = df.values x_scaled = normalize(x) scaled_df = pd.DataFrame(data=x_scaled, columns=column_names) return scaled_df - def loan_condition(status): + """ + Determine if a loan is a good or bad loan based on its status. + + Args: + status (str): Loan status. + + Returns: + str: "Good Loan" or "Bad Loan". + """ bad_loan = [ "Charged Off", "Default", @@ -132,22 +157,45 @@ def loan_condition(status): else: return "Good Loan" - def compute_annual_income(row): + """ + Compute the annual income for a loan applicant. + + Args: + row (pd.Series): A row of loan data. + + Returns: + float: Annual income. + """ if row["verification_status"] == row["verification_status_joint"]: return row["annual_inc_joint"] return row["annual_inc"] - def determine_good_bad_loan(df_loan): - print("[INFO] determine good or bad loan") + """ + Determine if a loan is a good or bad loan based on its status. + Args: + df_loan (pd.DataFrame): DataFrame containing loan data. + + Returns: + pd.DataFrame: DataFrame with "target" column indicating loan condition. + """ + print("[INFO] determine good or bad loan") df_loan["target"] = np.nan df_loan["target"] = df_loan["loan_status"].apply(loan_condition) return df_loan - def determine_annual_income(df_loan): + """ + Determine annual income for loan applicants. + + Args: + df_loan (pd.DataFrame): DataFrame containing loan data. + + Returns: + pd.DataFrame: DataFrame with "annual_inc_comp" column for annual income. 
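[editor's note] `read_csv_file_for_cluster` uses K-means to build a deliberately skewed (adversarial) split. A minimal sketch of that idea with scikit-learn; the array shapes and cluster count here are illustrative:

import numpy as np
from sklearn.cluster import KMeans

# Cluster samples and treat each cluster as one client's stream, so every
# client sees a deliberately skewed slice of the feature distribution.
X = np.random.randn(300, 8)
labels = KMeans(n_clusters=4, n_init=10).fit(X).labels_
partitions = {c: X[labels == c] for c in range(4)}
print({c: len(v) for c, v in partitions.items()})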
+ """ print("[INFO] determine annual income") df_loan["annual_inc_comp"] = np.nan @@ -156,15 +204,32 @@ def determine_annual_income(df_loan): def determine_issue_year(df_loan): - print("[INFO] determine issue year") + """ + Determine the issue year of loans. - # transform the issue dates by year + Args: + df_loan (pd.DataFrame): DataFrame containing loan data. + + Returns: + pd.DataFrame: DataFrame with "issue_year" column for issue years. + """ + print("[INFO] determine issue year") + # Transform the issue dates by year dt_series = pd.to_datetime(df_loan["issue_d"]) df_loan["issue_year"] = dt_series.dt.year return df_loan def digitize_columns(data_frame): + """ + Digitize categorical columns in the DataFrame. + + Args: + data_frame (pd.DataFrame): The DataFrame to digitize. + + Returns: + pd.DataFrame: DataFrame with categorical columns converted to numerical values. + """ print("[INFO] digitize columns") data_frame = data_frame.replace( @@ -185,6 +250,15 @@ def digitize_columns(data_frame): def prepare_data(file_path): + """ + Prepare loan data from a CSV file. + + Args: + file_path (str): Path to the CSV file containing loan data. + + Returns: + pd.DataFrame: DataFrame with processed loan data. + """ print("[INFO] prepare loan data.") df_loan = pd.read_csv(file_path, low_memory=False) @@ -200,6 +274,15 @@ def prepare_data(file_path): def process_data(loan_df): + """ + Process loan data. + + Args: + loan_df (pd.DataFrame): DataFrame containing loan data. + + Returns: + pd.DataFrame: DataFrame with processed loan features and target. + """ loan_feat_df = loan_df[all_feature_list] loan_feat_df = loan_feat_df.fillna(-99) assert loan_feat_df.isnull().sum().sum() == 0 @@ -211,6 +294,15 @@ def process_data(loan_df): def load_processed_data(data_dir): + """ + Load processed loan data from a CSV file, or preprocess and save it if not available. + + Args: + data_dir (str): Directory path for data files. + + Returns: + pd.DataFrame: DataFrame with processed loan data. + """ file_path = data_dir + "processed_loan.csv" if os.path.exists(file_path): print(f"[INFO] load processed loan data from {file_path}") @@ -226,6 +318,15 @@ def load_processed_data(data_dir): def loan_load_two_party_data(data_dir): + """ + Load two-party loan data. + + Args: + data_dir (str): Directory path for data files. + + Returns: + tuple: Training and testing data for two parties. + """ print("[INFO] load two party data") processed_loan_df = load_processed_data(data_dir) party_a_feat_list = qualification_feat + loan_feat @@ -253,6 +354,15 @@ def loan_load_two_party_data(data_dir): def loan_load_three_party_data(data_dir): + """ + Load three-party loan data. + + Args: + data_dir (str): Directory path for data files. + + Returns: + tuple: Training and testing data for three parties (Party A, Party B, Party C). + """ print("[INFO] load three party data") processed_loan_df = load_processed_data(data_dir) party_a_feat_list = qualification_feat + loan_feat diff --git a/python/fedml/data/reddit/data_loader.py b/python/fedml/data/reddit/data_loader.py index 65dff93415..939a3fee0d 100644 --- a/python/fedml/data/reddit/data_loader.py +++ b/python/fedml/data/reddit/data_loader.py @@ -35,6 +35,22 @@ def load_partition_data_reddit( batch_size, n_proc_in_silo=0, ): + """ + Load and partition Reddit dataset for Federated Learning. + + Args: + args: An object containing configuration parameters. + dataset: The Reddit dataset. + data_dir: The directory containing the dataset. + partition_method: The method used for data partitioning. 
+ partition_alpha: A parameter for data partitioning. + client_number: The number of clients/partitions. + batch_size: The batch size for data loading. + n_proc_in_silo: The number of processes in the silo (default: 0). + + Returns: + tuple: A tuple containing various data components for Federated Learning. + """ from .nlp import load_and_cache_examples, mask_tokens from transformers import (AdamW, AlbertTokenizer, AutoConfig, diff --git a/python/fedml/data/reddit/datasets.py b/python/fedml/data/reddit/datasets.py index 24bf05c475..2199f67540 100644 --- a/python/fedml/data/reddit/datasets.py +++ b/python/fedml/data/reddit/datasets.py @@ -13,10 +13,53 @@ class Reddit_dataset(): + """ + Dataset class for Reddit data. + + Args: + root (str): The root directory where the data is stored. + train (bool): Whether to load the training or testing dataset. + + Attributes: + train_file (str): The file name for the training dataset. + test_file (str): The file name for the testing dataset. + vocab_tokens_size (int): The size of the token vocabulary. + vocab_tags_size (int): The size of the tag vocabulary. + raw_data (list): A list of tokenized text data. + dict (dict): A mapping dictionary from sample id to target tag. + + Methods: + __getitem__(self, index): + Get an item from the dataset by index. + __mapping_dict__(self): + Get the mapping dictionary. + __len__(self): + Get the length of the dataset. + raw_folder(self): + Get the raw data folder path. + processed_folder(self): + Get the processed data folder path. + class_to_idx(self): + Get a mapping from class names to class indices. + _check_exists(self): + Check if the dataset exists. + load_token_vocab(self, vocab_size, path): + Load token vocabulary from a file. + load_file(self, path, is_train): + Load the dataset from files. + + """ classes = [] MAX_SEQ_LEN = 20000 def __init__(self, root, train=True): + """ + Initialize the Reddit_dataset. + + Args: + root (str): The root directory where the data is stored. + train (bool): Whether to load the training or testing dataset. + """ self.train = train # training set or test set self.root = root @@ -61,34 +104,90 @@ def __getitem__(self, index): return tokens def __mapping_dict__(self): + """ + Get the mapping dictionary. + + Returns: + dict: A dictionary mapping sample IDs to target tags. + """ + return self.dict def __len__(self): + """ + Get the length of the dataset. + + Returns: + int: The number of samples in the dataset. + """ return len(self.raw_data) @property def raw_folder(self): + """ + Get the raw data folder path. + + Returns: + str: The path to the raw data folder. + """ return self.root @property def processed_folder(self): + """ + Get the processed data folder path. + + Returns: + str: The path to the processed data folder. + """ return self.root @property def class_to_idx(self): + """ + Get a mapping from class names to class indices. + + Returns: + dict: A dictionary mapping class names to class indices. + """ return {_class: i for i, _class in enumerate(self.classes)} def _check_exists(self): - return (os.path.exists(os.path.join(self.processed_folder, - self.data_file))) + """ + Check if the dataset exists. + + Returns: + bool: True if the dataset exists, False otherwise. + """ + return (os.path.exists(os.path.join(self.processed_folder, self.data_file))) def load_token_vocab(self, vocab_size, path): + """ + Load token vocabulary from a file. + + Args: + vocab_size (int): The size of the token vocabulary. + path (str): The path to the vocabulary file. 
+ + Returns: + list: A list of tokens from the vocabulary. + """ tokens_file = "reddit_vocab.pkl" with open(os.path.join(path, tokens_file), 'rb') as f: tokens = pickle.load(f) return tokens[:vocab_size] def load_file(self, path, is_train): + """ + Load the dataset from files. + + Args: + path (str): The path to the dataset files. + is_train (bool): Whether to load the training or testing dataset. + + Returns: + tuple: A tuple containing text data and a mapping dictionary. + """ file_name = os.path.join( path, 'train') if self.train else os.path.join(path, 'test') diff --git a/python/fedml/data/reddit/divide_data.py b/python/fedml/data/reddit/divide_data.py index 96f562c422..e4be2ed988 100644 --- a/python/fedml/data/reddit/divide_data.py +++ b/python/fedml/data/reddit/divide_data.py @@ -12,24 +12,129 @@ class Partition(object): - """ Dataset partitioning helper """ + """ + Helper class for dataset partitioning. + + Args: + data (list): The dataset to be partitioned. + index (list): A list of indices specifying the partition. + + Attributes: + data (list): The dataset to be partitioned. + index (list): A list of indices specifying the partition. + + Methods: + __len__(): + Get the length of the partition. + __getitem__(index): + Get an item from the partition by index. + + """ def __init__(self, data, index): + """ + Initialize a dataset partition. + + Args: + data (list): The dataset to be partitioned. + index (list): A list of indices specifying the partition. + + Returns: + None + """ self.data = data self.index = index def __len__(self): + """ + Get the length of the partition. + + Returns: + int: The length of the partition. + """ return len(self.index) def __getitem__(self, index): + """ + Get an item from the partition by index. + + Args: + index (int): The index of the item to retrieve. + + Returns: + object: The item from the partition. + """ data_idx = self.index[index] return self.data[data_idx] -class DataPartitioner(object): - """Partition data by trace or random""" +import csv +import logging +import numpy as np +from collections import defaultdict +from random import Random +class DataPartitioner(object): + """ + Partition data by trace or random for federated learning. + + Args: + data: The dataset to be partitioned. + args: An object containing configuration parameters. + numOfClass (int): The number of classes in the dataset (default: 0). + seed (int): The seed for randomization (default: 10). + isTest (bool): Whether the partitioning is for a test dataset (default: False). + + Attributes: + partitions (list): A list of partitions, where each partition is a list of sample indices. + rng (Random): A random number generator. + data: The dataset to be partitioned. + labels: The labels of the dataset. + args: An object containing configuration parameters. + isTest (bool): Whether the partitioning is for a test dataset. + data_len (int): The length of the dataset. + task: The task type. + numOfLabels (int): The number of labels in the dataset. + client_label_cnt (defaultdict): A dictionary to count labels for each client. + + Methods: + getNumOfLabels(): + Get the number of unique labels in the dataset. + getDataLen(): + Get the length of the dataset. + getClientLen(): + Get the number of clients/partitions. + getClientLabel(): + Get the number of unique labels for each client. + trace_partition(data_map_file): + Partition data based on a trace file. + partition_data_helper(num_clients, data_map_file=None): + Helper function for partitioning data. 
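[editor's note] For reference, the index-view behavior of `Partition` in a runnable form, with a usage example; the class body mirrors the one in the diff:

class Partition:
    # An index-based view over a shared dataset: item i of the partition is
    # data[index[i]], so the underlying dataset is never copied.
    def __init__(self, data, index):
        self.data, self.index = data, index

    def __len__(self):
        return len(self.index)

    def __getitem__(self, i):
        return self.data[self.index[i]]


view = Partition(list("abcdef"), [5, 0, 3])
print(len(view), view[0])  # 3 f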
+ uniform_partition(num_clients): + Uniformly partition data randomly. + use(partition, istest): + Get a partition of the dataset for a specific client. + getSize(): + Get the size of each partition (number of samples). + + """ def __init__(self, data, args, numOfClass=0, seed=10, isTest=False): + """ + Initialize the DataPartitioner. + + Args: + data: The dataset to be partitioned. + args: An object containing configuration parameters. + numOfClass (int): The number of classes in the dataset (default: 0). + seed (int): The seed for randomization (default: 10). + isTest (bool): Whether the partitioning is for a test dataset (default: False). + + Note: + This constructor sets up the DataPartitioner with the provided dataset and configuration. + + Returns: + None + """ self.partitions = [] self.rng = Random() self.rng.seed(seed) @@ -46,24 +151,57 @@ def __init__(self, data, args, numOfClass=0, seed=10, isTest=False): self.client_label_cnt = defaultdict(set) def getNumOfLabels(self): + """ + Get the number of unique labels in the dataset. + + Returns: + int: The number of unique labels. + """ return self.numOfLabels def getDataLen(self): + """ + Get the length of the dataset. + + Returns: + int: The length of the dataset. + """ return self.data_len def getClientLen(self): + """ + Get the number of clients/partitions. + + Returns: + int: The number of clients/partitions. + """ return len(self.partitions) def getClientLabel(self): + """ + Get the number of unique labels for each client. + + Returns: + list: A list of the number of unique labels for each client. + """ return [len(self.client_label_cnt[i]) for i in range(self.getClientLen())] def trace_partition(self, data_map_file): - """Read data mapping from data_map_file. Format: """ - logging.info(f"Partitioning data by profile {data_map_file}...") + """ + Partition data based on a trace file. + Args: + data_map_file (str): The path to the data mapping file. + + Returns: + None + """ + logging.info(f"Partitioning data by profile {data_map_file}...") + clientId_maps = {} unique_clientIds = {} - # load meta data from the data_map_file + + # Load meta data from the data_map_file with open(data_map_file) as csv_file: csv_reader = csv.reader(csv_file, delimiter=',') read_first = True @@ -80,8 +218,7 @@ def trace_partition(self, data_map_file): unique_clientIds[client_id] = len(unique_clientIds) clientId_maps[sample_id] = unique_clientIds[client_id] - self.client_label_cnt[unique_clientIds[client_id]].add( - row[-1]) + self.client_label_cnt[unique_clientIds[client_id]].add(row[-1]) sample_id += 1 # Partition data given mapping @@ -91,15 +228,33 @@ def trace_partition(self, data_map_file): self.partitions[clientId_maps[idx]].append(idx) def partition_data_helper(self, num_clients, data_map_file=None): + """ + Helper function for partitioning data. + + Args: + num_clients (int): The number of clients/partitions. + data_map_file (str): The path to the data mapping file (default: None). - # read mapping file to partition trace + Returns: + None + """ + # Read mapping file to partition trace if data_map_file is not None: self.trace_partition(data_map_file) else: self.uniform_partition(num_clients=num_clients) def uniform_partition(self, num_clients): - # random partition + """ + Uniformly partition data randomly. + + Args: + num_clients (int): The number of clients/partitions. 
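[editor's note] `uniform_partition` reduces to: shuffle the indices once, then slice off equal-sized chunks. A standalone sketch under that reading:

import random


def uniform_partition(num_clients, data_len, seed=10):
    # Shuffle once, then hand each client a contiguous slice of size
    # floor(data_len / num_clients); any remainder is left unassigned.
    rng = random.Random(seed)
    indexes = list(range(data_len))
    rng.shuffle(indexes)
    part_len = data_len // num_clients
    return [indexes[i * part_len:(i + 1) * part_len] for i in range(num_clients)]


print([len(p) for p in uniform_partition(num_clients=3, data_len=10)])  # [3, 3, 3]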
+ + Returns: + None + """ + # Random partition numOfLabels = self.getNumOfLabels() data_len = self.getDataLen() logging.info(f"Randomly partitioning data, {data_len} samples...") @@ -108,11 +263,21 @@ def uniform_partition(self, num_clients): self.rng.shuffle(indexes) for _ in range(num_clients): - part_len = int(1./num_clients * data_len) + part_len = int(1. / num_clients * data_len) self.partitions.append(indexes[0:part_len]) indexes = indexes[part_len:] def use(self, partition, istest): + """ + Get a partition of the dataset for a specific client. + + Args: + partition (int): The index of the client/partition. + istest (bool): Whether the partition is for a test dataset. + + Returns: + Partition: A partition of the dataset for the specified client. + """ resultIndex = self.partitions[partition] exeuteLength = len(resultIndex) if not istest else int( @@ -123,12 +288,31 @@ def use(self, partition, istest): return Partition(self.data, resultIndex) def getSize(self): - # return the size of samples + """ + Get the size of each partition (number of samples). + + Returns: + dict: A dictionary containing the size of each partition. + """ + # Return the size of samples return {'size': [len(partition) for partition in self.partitions]} def select_dataset(rank, partition, batch_size, args, isTest=False, collate_fn=None): - """Load data given client Id""" + """ + Load data for a specific client based on client ID. + + Args: + rank (int): The client's rank or ID. + partition (Partition): A partition of the dataset for the client. + batch_size (int): The batch size for data loading. + args: An object containing configuration parameters. + isTest (bool): Whether the data loading is for a test dataset (default: False). + collate_fn (callable, optional): A function used to collate data samples into batches (default: None). + + Returns: + DataLoader: A DataLoader object for loading the client's data. + """ partition = partition.use(rank - 1, isTest) dropLast = False if isTest else True num_loaders = min(int(len(partition)/args.batch_size/2), args.num_loaders) @@ -145,6 +329,3 @@ def select_dataset(rank, partition, batch_size, args, isTest=False, collate_fn=N return DataLoader(partition, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_loaders, drop_last=dropLast, collate_fn=collate_fn) return DataLoader(partition, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_loaders, drop_last=dropLast) - - - diff --git a/python/fedml/data/reddit/nlp.py b/python/fedml/data/reddit/nlp.py index 4711a4846f..a62a681289 100644 --- a/python/fedml/data/reddit/nlp.py +++ b/python/fedml/data/reddit/nlp.py @@ -44,13 +44,34 @@ def chunks_idx(l, n): + """ + Split a list into 'n' roughly equal-sized chunks and yield the start and end indices of each chunk. + + Args: + l (list): The list to be split. + n (int): The number of chunks to split the list into. + + Yields: + tuple: A tuple containing the start and end indices of each chunk. + """ d, r = divmod(len(l), n) for i in range(n): si = (d+1)*(i if i < r else r) + d*(0 if i < r else i - r) yield si, si+(d+1 if i < r else d) - def feature_creation_worker(files, tokenizer, block_size, worker_idx): + """ + Worker function for creating features from a list of text files. + + Args: + files (list): A list of file paths containing text data. + tokenizer: The tokenizer to convert text to tokens. + block_size (int): The maximum block size for tokenized text. + worker_idx (int): The index of the worker. 
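[editor's note] `chunks_idx` is the one subtle helper in nlp.py; reproduced here with a worked example to show how the remainder r = len(l) % n spreads over the first r chunks:

def chunks_idx(l, n):
    # Yield (start, end) bounds that split l into n near-equal chunks;
    # the first r chunks each receive one extra element.
    d, r = divmod(len(l), n)
    for i in range(n):
        si = (d + 1) * (i if i < r else r) + d * (0 if i < r else i - r)
        yield si, si + (d + 1 if i < r else d)


print(list(chunks_idx(list(range(10)), 3)))  # [(0, 4), (4, 7), (7, 10)]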
+ + Returns: + tuple: A tuple containing examples (tokenized text), client mapping, and sample client IDs. + """ examples = [] sample_client = [] client_mapping = collections.defaultdict(list) @@ -83,8 +104,43 @@ def feature_creation_worker(files, tokenizer, block_size, worker_idx): class TextDataset(Dataset): + """ + Dataset for text data used in language modeling tasks. + + Args: + tokenizer: The tokenizer to convert text to tokens. + args: An object containing dataset configuration parameters. + file_path (str): The directory containing the dataset files. + block_size (int): The maximum block size for tokenized text (default: 512). + + Attributes: + examples (list): A list of tokenized text examples. + sample_client (list): A list of sample client IDs. + client_mapping (dict): A dictionary mapping client IDs to tokenized text examples. + + Methods: + __len__(): + Get the number of examples in the dataset. + __getitem__(item): + Get an example from the dataset. + + """ def __init__(self, tokenizer, args, file_path, block_size=512): + """ + Initialize the TextDataset. + + Args: + tokenizer: The tokenizer to convert text to tokens. + args: An object containing dataset configuration parameters. + file_path (str): The directory containing the dataset files. + block_size (int): The maximum block size for tokenized text (default: 512). + + Note: + This constructor processes and loads the dataset from files or creates features if not cached. + Returns: + None + """ block_size = block_size - \ (tokenizer.model_max_length - tokenizer.max_len_single_sentence) @@ -135,7 +191,7 @@ def __init__(self, tokenizer, args, file_path, block_size=512): self.client_mapping[true_user_id] = client_mapping[user_id] user_id_base = true_sample_client[-1] + 1 - # Note that we are loosing the last truncated example here for the sake of simplicity (no padding) + # Note that we are losing the last truncated example here for the sake of simplicity (no padding) # If your dataset is small, first you should look for a bigger one :-) and second you # can change this behavior by adding (model specific) padding. logger.info("Saving features into cached file %s", @@ -159,13 +215,39 @@ def __init__(self, tokenizer, args, file_path, block_size=512): self.targets = [0 for i in range(len(self.data))] def __len__(self): + """ + Get the number of examples in the dataset. + + Returns: + int: The number of examples in the dataset. + """ return len(self.examples) def __getitem__(self, item): + """ + Get an example from the dataset. + + Args: + item: The index of the example to retrieve. + + Returns: + torch.Tensor: The tokenized text example as a PyTorch tensor. + """ return torch.tensor(self.examples[item], dtype=torch.long) def load_and_cache_examples(args, tokenizer, evaluate=False): + """ + Load and cache examples from the dataset for training or evaluation. + + Args: + args: An object containing dataset configuration parameters. + tokenizer: The tokenizer to convert text to tokens. + evaluate (bool): Whether to load examples for evaluation (default: False). + + Returns: + TextDataset: A dataset containing tokenized text examples. + """ file_path = os.path.join(args.data_cache_dir, 'test') if evaluate else os.path.join( args.data_cache_dir, 'train') @@ -173,7 +255,18 @@ def load_and_cache_examples(args, tokenizer, evaluate=False): def mask_tokens(inputs, tokenizer, args, device='cpu') -> Tuple[torch.Tensor, torch.Tensor]: - """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. 
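[editor's note] `TextDataset` concatenates tokenized text and cuts it into fixed `block_size` examples, dropping the truncated tail. A tiny sketch of that grouping step; `group_into_blocks` is a hypothetical helper name:

def group_into_blocks(token_ids, block_size):
    # Keep only full blocks; the trailing partial block is dropped, matching
    # the "losing the last truncated example" comment in TextDataset.
    return [token_ids[i:i + block_size]
            for i in range(0, len(token_ids) - block_size + 1, block_size)]


print(group_into_blocks(list(range(10)), 4))  # [[0, 1, 2, 3], [4, 5, 6, 7]]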
""" + """ + Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. + + Args: + inputs (torch.Tensor): The input token IDs. + tokenizer: The tokenizer to convert text to tokens. + args: An object containing configuration parameters. + device (str): The device to use for computations (default: 'cpu'). + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing masked input tokens and labels for masked language modeling. + """ labels = inputs.clone().to(device=device) # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) probability_matrix = torch.full( diff --git a/python/fedml/data/stackoverflow_lr/data_loader.py b/python/fedml/data/stackoverflow_lr/data_loader.py index 0aa5087017..6edd993bec 100644 --- a/python/fedml/data/stackoverflow_lr/data_loader.py +++ b/python/fedml/data/stackoverflow_lr/data_loader.py @@ -21,6 +21,19 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): + """ + Get DataLoader objects for training and testing data. + + Args: + dataset: The dataset to use. + data_dir (str): The directory containing the data. + train_bs (int): The batch size for training. + test_bs (int): The batch size for testing. + client_idx (int, optional): The client index (None for global data). + + Returns: + tuple: A tuple containing training and testing DataLoader objects. + """ if client_idx is None: train_dl = data.DataLoader( @@ -94,6 +107,18 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): def load_partition_data_distributed_federated_stackoverflow_lr( process_id, dataset, data_dir, batch_size=DEFAULT_BATCH_SIZE ): + """ + Load partitioned data for distributed federated stackoverflow_lr. + + Args: + process_id (int): The process ID. + dataset: The dataset to use. + data_dir (str): The directory containing the data. + batch_size (int, optional): The batch size (default is 64). + + Returns: + tuple: A tuple containing data for distributed federated stackoverflow_lr. + """ # get global dataset if process_id == 0: train_data_global, test_data_global = get_dataloader( @@ -131,6 +156,17 @@ def load_partition_data_distributed_federated_stackoverflow_lr( def load_partition_data_federated_stackoverflow_lr( dataset, data_dir, batch_size=DEFAULT_BATCH_SIZE ): + """ + Load partitioned data for federated stackoverflow_lr. + + Args: + dataset: The dataset to use. + data_dir (str): The directory containing the data. + batch_size (int, optional): The batch size (default is 64). + + Returns: + tuple: A tuple containing data for federated stackoverflow_lr. + """ logging.info("load_partition_data_federated_stackoverflow_lr START") global cache_data diff --git a/python/fedml/data/stackoverflow_lr/dataset.py b/python/fedml/data/stackoverflow_lr/dataset.py index 7d7dab6ecb..3b312d90c5 100644 --- a/python/fedml/data/stackoverflow_lr/dataset.py +++ b/python/fedml/data/stackoverflow_lr/dataset.py @@ -4,7 +4,15 @@ class StackOverflowDataset(data.Dataset): - """StackOverflow dataset""" + """ + StackOverflow dataset. + + Args: + h5_path (str): Path to the h5 file. + client_idx (int): Index of the train file. + datast (str): "train" or "test" denoting the train set or test set. + preprocess (dict of callable, optional): Optional preprocessing functions with keys "input" and "target". 
+ """ __train_client_id_list = None __test_client_id_list = None @@ -33,6 +41,12 @@ def __init__(self, h5_path, client_idx, datast, preprocess=None): self.target_fn = preprocess["target"] def get_client_id_list(self): + """ + Get a list of client IDs based on the dataset type. + + Returns: + list: List of client IDs. + """ if self.datast == "train": if StackOverflowDataset.__train_client_id_list is None: with h5py.File(self.h5_path, "r") as h5_file: diff --git a/python/fedml/data/stackoverflow_lr/utils.py b/python/fedml/data/stackoverflow_lr/utils.py index 7bcb011106..5fbf007e60 100644 --- a/python/fedml/data/stackoverflow_lr/utils.py +++ b/python/fedml/data/stackoverflow_lr/utils.py @@ -17,6 +17,15 @@ def get_word_count_file(data_dir): + """ + Get the path to the word count file. + + Args: + data_dir (str): The directory where the file is located. + + Returns: + str: The full path to the word count file. + """ # word_count_file_path global word_count_file_path if word_count_file_path is None: @@ -25,6 +34,15 @@ def get_word_count_file(data_dir): def get_tag_count_file(data_dir): + """ + Get the path to the tag count file. + + Args: + data_dir (str): The directory where the file is located. + + Returns: + str: The full path to the tag count file. + """ # tag_count_file_path global tag_count_file_path if tag_count_file_path is None: @@ -33,6 +51,16 @@ def get_tag_count_file(data_dir): def get_most_frequent_words(data_dir=None, vocab_size=10000): + """ + Get a list of the most frequent words. + + Args: + data_dir (str, optional): The directory where the word count file is located. + vocab_size (int, optional): The number of most frequent words to retrieve. + + Returns: + list: A list of the most frequent words. + """ frequent_words = [] with open(get_word_count_file(data_dir), "r") as f: frequent_words = [next(f).split()[0] for i in range(vocab_size)] @@ -40,12 +68,31 @@ def get_most_frequent_words(data_dir=None, vocab_size=10000): def get_tags(data_dir=None, tag_size=500): + """ + Get a list of tags. + + Args: + data_dir (str, optional): The directory where the tag count file is located. + tag_size (int, optional): The number of tags to retrieve. + + Returns: + list: A list of tags. + """ f = open(get_tag_count_file(data_dir), "r") frequent_tags = json.load(f) return list(frequent_tags.keys())[:tag_size] def get_word_dict(data_dir): + """ + Get a dictionary that maps words to their IDs. + + Args: + data_dir (str): The directory where the word count file is located. + + Returns: + collections.OrderedDict: A dictionary mapping words to their IDs. + """ global word_dict if word_dict == None: words = get_most_frequent_words(data_dir) @@ -56,6 +103,15 @@ def get_word_dict(data_dir): def get_tag_dict(data_dir): + """ + Get a dictionary that maps tags to their IDs. + + Args: + data_dir (str): The directory where the tag count file is located. + + Returns: + collections.OrderedDict: A dictionary mapping tags to their IDs. + """ global tag_dict if tag_dict == None: tags = get_tags(data_dir) @@ -66,6 +122,16 @@ def get_tag_dict(data_dir): def preprocess_inputs(sentences, data_dir): + """ + Preprocess a list of sentences into a bag-of-words representation. + + Args: + sentences (list): List of sentences to preprocess. + data_dir (str): The directory where the word count file is located. + + Returns: + list: List of preprocessed bag-of-words representations. 
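[editor's note] `preprocess_inputs` builds multi-hot bag-of-words vectors over the frequent-word vocabulary. A minimal sketch; the two-word vocabulary is illustrative:

import numpy as np


def to_bag_of_words(tokens, word_to_id, vocab_size):
    # Multi-hot encoding: set index word_to_id[t] for each in-vocabulary
    # token; out-of-vocabulary tokens are simply dropped.
    vec = np.zeros(vocab_size, dtype=np.float32)
    for t in tokens:
        if t in word_to_id:
            vec[word_to_id[t]] = 1.0
    return vec


print(to_bag_of_words("how to sort a list".split(), {"sort": 0, "list": 1}, 2))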
+ """ sentences = [sentence.split(" ") for sentence in sentences] vocab_size = len(get_word_dict(data_dir)) @@ -87,6 +153,16 @@ def to_bag_of_words(sentence): def preprocess_targets(tags, data_dir): + """ + Preprocess a list of tags into a bag-of-words representation. + + Args: + tags (list): List of tags to preprocess. + data_dir (str): The directory where the tag count file is located. + + Returns: + list: List of preprocessed bag-of-words representations. + """ tags = [tag.split("|") for tag in tags] tag_size = len(get_tag_dict(data_dir)) @@ -129,6 +205,16 @@ def to_bag_of_words(sentence): def preprocess_target(tag, data_dir): + """ + Preprocess a single sentence into a bag-of-words representation. + + Args: + sentence (str): The sentence to preprocess. + data_dir (str): The directory where the word count file is located. + + Returns: + numpy.ndarray: Preprocessed bag-of-words representation. + """ tag = tag.split("|") tag_size = len(get_tag_dict(data_dir)) diff --git a/python/fedml/data/stackoverflow_nwp/data_loader.py b/python/fedml/data/stackoverflow_nwp/data_loader.py index c1a1b2d008..ccc8bf8250 100644 --- a/python/fedml/data/stackoverflow_nwp/data_loader.py +++ b/python/fedml/data/stackoverflow_nwp/data_loader.py @@ -21,6 +21,19 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): + """ + Get data loaders for training and testing. + + Args: + dataset: The dataset object. + data_dir (str): The directory containing the data. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + client_idx (int or None): Index of the client (None for global dataset). + + Returns: + tuple: A tuple containing train and test data loaders (train_dl, test_dl). + """ def _tokenizer(x): return utils.tokenizer(x, data_dir) @@ -79,6 +92,19 @@ def _tokenizer(x): def load_partition_data_distributed_federated_stackoverflow_nwp( process_id, dataset, data_dir, batch_size=DEFAULT_BATCH_SIZE ): + """ + Load partitioned data for distributed federated StackOverflow NWP. + + Args: + process_id (int): The process ID or rank. + dataset: The dataset object. + data_dir (str): The directory containing the data. + batch_size (int): Batch size. + + Returns: + tuple: A tuple containing client number, train data number, global train data, + global test data, local data number, local train data, local test data, and vocabulary length. + """ # get global dataset if process_id == 0: diff --git a/python/fedml/data/stackoverflow_nwp/dataset.py b/python/fedml/data/stackoverflow_nwp/dataset.py index 8ebf6f07a7..c9ae22fdbc 100644 --- a/python/fedml/data/stackoverflow_nwp/dataset.py +++ b/python/fedml/data/stackoverflow_nwp/dataset.py @@ -4,29 +4,44 @@ class StackOverflowDataset(data.Dataset): - """StackOverflow dataset""" + """ + StackOverflow dataset. + + Args: + h5_path (str): Path to the h5 file. + client_idx (int): Index of the train file. + datast (str): "train" or "test" denoting the train set or test set. + preprocess (callable, optional): Optional preprocessing function. + + Attributes: + _EXAMPLE (str): Name of the "examples" attribute in the h5 file. + _TOKENS (str): Name of the "tokens" attribute in the h5 file. 
+ + """ __train_client_id_list = None __test_client_id_list = None - def __init__(self, h5_path, client_idx, datast, preprocess): - """ - Args: - h5_path (string) : path to the h5 file - client_idx (idx) : index of train file - datast (string) : "train" or "test" denoting on train set or test set - preprocess (callable, optional) : Optional preprocessing - """ - + def __init__(self, h5_path, client_idx, datast, preprocess=None): self._EXAMPLE = "examples" self._TOKENS = "tokens" self.h5_path = h5_path self.datast = datast - self.client_id = self.get_client_id_list()[client_idx] # pylint: disable=E1136 + self.client_id = self.get_client_id_list()[client_idx] + self.preprocess = preprocess def get_client_id_list(self): + """ + Get the list of client IDs for the specified dataset. + + Returns: + list: List of client IDs. + + Raises: + Exception: If an invalid dataset is specified. + """ if self.datast == "train": if StackOverflowDataset.__train_client_id_list is None: with h5py.File(self.h5_path, "r") as h5_file: @@ -42,7 +57,7 @@ def get_client_id_list(self): ) return StackOverflowDataset.__test_client_id_list else: - raise Exception("Please specify either train or test set!") + raise Exception("Please specify either 'train' or 'test' set!") def __len__(self): with h5py.File(self.h5_path, "r") as h5_file: @@ -50,8 +65,7 @@ def __len__(self): def __getitem__(self, idx): with h5py.File(self.h5_path, "r") as h5_file: - sample = h5_file[self._EXAMPLE][self.client_id][self._TOKENS][()][ - idx - ].decode("utf8") - sample = self.preprocess(sample) + sample = h5_file[self._EXAMPLE][self.client_id][self._TOKENS][()][idx].decode("utf8") + if self.preprocess is not None: + sample = self.preprocess(sample) return np.asarray(sample[:-1]), np.asarray(sample[1:]) diff --git a/python/fedml/data/synthetic_0.5_0.5/generate_synthetic.py b/python/fedml/data/synthetic_0.5_0.5/generate_synthetic.py index 014587f6d6..bde7ab7304 100644 --- a/python/fedml/data/synthetic_0.5_0.5/generate_synthetic.py +++ b/python/fedml/data/synthetic_0.5_0.5/generate_synthetic.py @@ -8,12 +8,35 @@ def softmax(x): + """ + Compute the softmax function for an array of values. + + Args: + x (numpy.ndarray): Input array. + + Returns: + numpy.ndarray: Softmax probabilities for the input array. + """ ex = np.exp(x) sum_ex = np.sum(np.exp(x)) return ex / sum_ex def generate_synthetic(alpha, beta, iid): + """ + Generate synthetic data for federated learning. + + Args: + NUM_USER (int): Number of users/clients. + alpha (float): Mean of the normal distribution for generating model weights. + beta (float): Mean of the normal distribution for generating model bias. + iid (int): Indicator for generating independent (1) or non-independent (0) data. + + Returns: + tuple: A tuple containing synthetic data for X (features) and y (labels). + - X_split (list): List of lists containing feature data for each user. + - y_split (list): List of lists containing label data for each user. + """ dimension = 60 NUM_CLASS = 10 np.random.seed(0) diff --git a/python/fedml/data/synthetic_0_0/generate_synthetic.py b/python/fedml/data/synthetic_0_0/generate_synthetic.py index 53c544944d..be1f67efc8 100644 --- a/python/fedml/data/synthetic_0_0/generate_synthetic.py +++ b/python/fedml/data/synthetic_0_0/generate_synthetic.py @@ -8,12 +8,35 @@ def softmax(x): + """ + Compute the softmax function for an array of values. + + Args: + x (numpy.ndarray): Input array. + + Returns: + numpy.ndarray: Softmax probabilities for the input array. 
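[editor's note] The nwp `__getitem__` returns `(sample[:-1], sample[1:])`, i.e. the next-word-prediction shift, demonstrated here:

import numpy as np

tokens = np.asarray([101, 7, 42, 13, 102])
# Next-word prediction: the model reads tokens[:-1] and is trained to
# predict tokens[1:], so each target is the input shifted by one position.
x, y = tokens[:-1], tokens[1:]
print(x, y)  # [101   7  42  13] [  7  42  13 102]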
+ """ ex = np.exp(x) sum_ex = np.sum(np.exp(x)) return ex / sum_ex def generate_synthetic(alpha, beta, iid): + """ + Generate synthetic data for federated learning. + + Args: + NUM_USER (int): Number of users/clients. + alpha (float): Mean of the normal distribution for generating model weights. + beta (float): Mean of the normal distribution for generating model bias. + iid (int): Indicator for generating independent (1) or non-independent (0) data. + + Returns: + tuple: A tuple containing synthetic data for X (features) and y (labels). + - X_split (list): List of lists containing feature data for each user. + - y_split (list): List of lists containing label data for each user. + """ dimension = 60 NUM_CLASS = 10 np.random.seed(0) diff --git a/python/fedml/data/synthetic_1_1/data_loader.py b/python/fedml/data/synthetic_1_1/data_loader.py index 021963596c..5ad41fed63 100644 --- a/python/fedml/data/synthetic_1_1/data_loader.py +++ b/python/fedml/data/synthetic_1_1/data_loader.py @@ -14,8 +14,28 @@ def load_partition_data_federated_synthetic_1_1( - data_dir=None, batch_size=DEFAULT_BATCH_SIZE + train_file_path, test_file_path, batch_size=DEFAULT_BATCH_SIZE ): + """ + Load federated synthetic data for training and testing. + + Args: + train_file_path (str): Path to the training data JSON file. + test_file_path (str): Path to the testing data JSON file. + batch_size (int): Batch size for data loaders. + + Returns: + tuple: A tuple containing client and data-related information: + - client_num (int): Number of clients. + - train_data_num (int): Number of samples in the global training dataset. + - test_data_num (int): Number of samples in the global testing dataset. + - train_data_global (torch.utils.data.DataLoader): DataLoader for the global training dataset. + - test_data_global (torch.utils.data.DataLoader): DataLoader for the global testing dataset. + - data_local_num_dict (dict): Dictionary containing the number of samples for each client. + - train_data_local_dict (dict): Dictionary of DataLoader objects for local training data. + - test_data_local_dict (dict): Dictionary of DataLoader objects for local testing data. + - output_dim (int): Dimension of the output (e.g., number of classes). + """ logging.info("load_partition_data_federated_synthetic_1_1 START") with open(train_file_path, "r") as train_f, open(test_file_path, "r") as test_f: @@ -118,7 +138,13 @@ def load_partition_data_federated_synthetic_1_1( ) -def test_data_loader(): +def test_data_loader(train_file_path): + """ + Test the data loader function by comparing the number of samples with the original data. + + Args: + train_file_path (str): Path to the training data JSON file. + """ ( client_num, train_data_num, @@ -129,9 +155,10 @@ def test_data_loader(): train_data_local_dict, test_data_local_dict, output_dim, - ) = load_partition_data_federated_synthetic_1_1() - f = open(train_file_path, "r") - train_data = json.load(f) + ) = load_partition_data_federated_synthetic_1_1(train_file_path, train_file_path) + + with open(train_file_path, "r") as f: + train_data = json.load(f) assert train_data["num_samples"] == list(data_local_num_dict.values()) diff --git a/python/fedml/data/synthetic_1_1/generate_synthetic.py b/python/fedml/data/synthetic_1_1/generate_synthetic.py index 9a46ad43ca..edc2b15b65 100644 --- a/python/fedml/data/synthetic_1_1/generate_synthetic.py +++ b/python/fedml/data/synthetic_1_1/generate_synthetic.py @@ -8,23 +8,43 @@ def softmax(x): + """ + Compute the softmax values for a given array. 
+ + Args: + x (numpy.ndarray): Input array. + + Returns: + numpy.ndarray: Softmax values. + """ ex = np.exp(x) sum_ex = np.sum(np.exp(x)) return ex / sum_ex def generate_synthetic(alpha, beta, iid): + """ + Generate synthetic data for federated learning. + + Args: + alpha (float): Mean of user weights. + beta (float): Mean of user biases. + iid (int): Unused parameter. + + Returns: + list: List of user data samples. + list: List of labels for user data samples. + """ dimension = 60 NUM_CLASS = 10 np.random.seed(0) samples_per_user = np.random.lognormal(4, 2, (NUM_USER)).astype(int) + 50 - print(samples_per_user) - # num_samples = np.sum(samples_per_user) + X_split = [[] for _ in range(NUM_USER)] y_split = [[] for _ in range(NUM_USER)] - #### define some eprior #### + mean_W = np.random.normal(0, alpha, NUM_USER) mean_b = mean_W B = np.random.normal(0, beta, NUM_USER) @@ -37,7 +57,7 @@ def generate_synthetic(alpha, beta, iid): for i in range(NUM_USER): mean_x[i] = np.random.normal(B[i], 1, dimension) - # print(mean_x[i]) + for i in range(NUM_USER): @@ -54,7 +74,7 @@ def generate_synthetic(alpha, beta, iid): X_split[i] = xx.tolist() y_split[i] = yy.tolist() - print("{}-th users has {} exampls".format(i, len(y_split[i]))) + return X_split, y_split diff --git a/python/fedml/data/synthetic_1_1/stats.py b/python/fedml/data/synthetic_1_1/stats.py index 5328562587..43f57bccd0 100755 --- a/python/fedml/data/synthetic_1_1/stats.py +++ b/python/fedml/data/synthetic_1_1/stats.py @@ -23,6 +23,16 @@ def load_data(name): + """ + Load user and sample data from JSON files in a dataset directory. + + Args: + name (str): The name of the dataset. + + Returns: + list: List of user names. + list: List of the number of samples per user. + """ users = [] num_samples = [] @@ -57,6 +67,15 @@ def load_data(name): def print_dataset_stats(name): + """ + Print statistics of a dataset, including the number of users and samples. + + Args: + name (str): The name of the dataset. + + Returns: + None + """ users, num_samples = load_data(name) num_users = len(users) From 843b0e8b9b15d4030634aa99912592e0159fb3b0 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Mon, 18 Sep 2023 11:41:09 +0530 Subject: [PATCH 61/70] add --- .../fedml/data/FederatedEMNIST/data_loader.py | 39 ++++ python/fedml/data/cinic10/data_loader.py | 157 +++++++++++++- python/fedml/data/cinic10/datasets.py | 20 +- .../data/edge_case_examples/data_loader.py | 100 ++++++++- .../fedml/data/edge_case_examples/datasets.py | 157 ++++++++++++-- python/fedml/data/fed_cifar100/data_loader.py | 48 ++++- python/fedml/data/fed_cifar100/utils.py | 23 ++- .../fedml/data/fed_shakespeare/data_loader.py | 65 ++++-- python/fedml/data/fed_shakespeare/utils.py | 65 +++++- .../base/data_manager/base_data_manager.py | 111 +++++++++- .../base/preprocess/base_preprocessor.py | 30 +++ .../base/raw_data/base_raw_data_loader.py | 192 ++++++++++++++++++ .../data/fednlp/base/raw_data/partition.py | 21 ++ python/fedml/data/fednlp/base/utils.py | 92 ++++++++- 14 files changed, 1057 insertions(+), 63 deletions(-) diff --git a/python/fedml/data/FederatedEMNIST/data_loader.py b/python/fedml/data/FederatedEMNIST/data_loader.py index c8160b2a7f..fb4c1414b5 100644 --- a/python/fedml/data/FederatedEMNIST/data_loader.py +++ b/python/fedml/data/FederatedEMNIST/data_loader.py @@ -21,6 +21,19 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): + """ + Create data loaders for training and testing data. + + Args: + dataset (str): The dataset name. 
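# --- Illustrative sketch (reviewer annotation, not part of the patch) ---
# The per-user generative model that generate_synthetic's docstring describes:
# each sample's label is argmax(softmax(W x + b)) for user-specific W and b.
# Shapes below follow the dimension/NUM_CLASS constants in the code.
import numpy as np

dimension, NUM_CLASS = 60, 10
W = np.random.normal(0, 1, (dimension, NUM_CLASS))
b = np.random.normal(0, 1, NUM_CLASS)
x = np.random.normal(0, 1, dimension)
logits = x @ W + b
y = int(np.argmax(np.exp(logits) / np.sum(np.exp(logits))))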
+ data_dir (str): The directory where the dataset is located. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + client_idx (int or None): Index of the client to load data for. If None, load data for all clients. + + Returns: + tuple: A tuple containing the training and testing data loaders. + """ train_h5 = h5py.File(os.path.join(data_dir, DEFAULT_TRAIN_FILE), "r") test_h5 = h5py.File(os.path.join(data_dir, DEFAULT_TEST_FILE), "r") @@ -76,6 +89,19 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): def load_partition_data_distributed_federated_emnist( process_id, dataset, data_dir, batch_size=DEFAULT_BATCH_SIZE ): + """ + Load partitioned data for a federated EMNIST dataset. + + Args: + process_id (int): The ID of the current process (0 for server, >0 for clients). + dataset (str): The dataset name. + data_dir (str): The directory where the dataset is located. + batch_size (int): Batch size for data loaders. + + Returns: + tuple: A tuple containing information about the dataset, including the number of clients, + the number of samples in the global training data, global and local data loaders, and class number. + """ if process_id == 0: # get global dataset @@ -133,6 +159,19 @@ def load_partition_data_distributed_federated_emnist( def load_partition_data_federated_emnist( dataset, data_dir, batch_size=DEFAULT_BATCH_SIZE ): + """ + Load partitioned data for federated EMNIST dataset. + + Args: + dataset (str): The dataset name. + data_dir (str): The directory where the dataset is located. + batch_size (int): Batch size for data loaders. + + Returns: + tuple: A tuple containing information about the dataset, including the number of clients, + the number of samples in the global training and testing data, global and local data loaders, + the number of samples per client, and the class number. + """ # client ids train_file_path = os.path.join(data_dir, DEFAULT_TRAIN_FILE) diff --git a/python/fedml/data/cinic10/data_loader.py b/python/fedml/data/cinic10/data_loader.py index 10515a1569..13a28552c4 100644 --- a/python/fedml/data/cinic10/data_loader.py +++ b/python/fedml/data/cinic10/data_loader.py @@ -14,6 +14,15 @@ def read_data_distribution( filename="./data_preprocessing/non-iid-distribution/CIFAR10/distribution.txt", ): + """ + Reads the data distribution from a text file. + + Args: + filename (str): The path to the distribution file. + + Returns: + dict: A dictionary representing the data distribution. + """ distribution = {} with open(filename, "r") as data: for x in data.readlines(): @@ -33,6 +42,15 @@ def read_data_distribution( def read_net_dataidx_map( filename="./data_preprocessing/non-iid-distribution/CIFAR10/net_dataidx_map.txt", ): + """ + Reads the network data index map from a text file. + + Args: + filename (str): The path to the network data index map file. + + Returns: + dict: A dictionary mapping network IDs to data indices. + """ net_dataidx_map = {} with open(filename, "r") as data: for x in data.readlines(): @@ -48,6 +66,16 @@ def read_net_dataidx_map( def record_net_data_stats(y_train, net_dataidx_map): + """ + Records network-specific data statistics. + + Args: + y_train (numpy.ndarray): Array of ground truth labels for the entire dataset. + net_dataidx_map (dict): A dictionary mapping network IDs to data indices. + + Returns: + dict: A dictionary containing network-specific class counts. 
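# --- Illustrative sketch (reviewer annotation, not part of the patch) ---
# What record_net_data_stats produces, in the spirit of the implementation:
# a per-client histogram of labels over each client's index list.
import numpy as np

y_train = np.array([0, 1, 1, 2, 0, 2])
net_dataidx_map = {0: [0, 1, 2], 1: [3, 4, 5]}
stats = {}
for net_i, idx in net_dataidx_map.items():
    unq, counts = np.unique(y_train[idx], return_counts=True)
    stats[net_i] = {int(u): int(c) for u, c in zip(unq, counts)}
# stats == {0: {0: 1, 1: 2}, 1: {0: 1, 2: 2}}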
+ """ net_cls_counts = {} for net_i, dataidx in net_dataidx_map.items(): @@ -63,6 +91,15 @@ def __init__(self, length): self.length = length def __call__(self, img): + """ + Applies the Cutout augmentation to an image. + + Args: + img (torch.Tensor): The input image. + + Returns: + torch.Tensor: The image with the Cutout augmentation applied. + """ h, w = img.size(1), img.size(2) mask = np.ones((h, w), np.float32) y = np.random.randint(h) @@ -81,8 +118,15 @@ def __call__(self, img): def _data_transforms_cinic10(): + """ + Define data transformations for the CIFAR-10 dataset. + + Returns: + tuple: A tuple containing two transformation functions, one for training and one for validation/test. + """ cinic_mean = [0.47889522, 0.47227842, 0.43047404] cinic_std = [0.24205776, 0.23828046, 0.25874835] + # Transformer for train set: random crops and horizontal flip train_transform = transforms.Compose( [ @@ -120,6 +164,15 @@ def _data_transforms_cinic10(): def load_cinic10_data(datadir): + """ + Load CIFAR-10 data from the specified directory. + + Args: + datadir (str): The directory containing CIFAR-10 data. + + Returns: + tuple: A tuple containing training and testing data. + """ _train_dir = datadir + str("/train") logging.info("_train_dir = " + str(_train_dir)) _test_dir = datadir + str("/test") @@ -168,6 +221,19 @@ def load_cinic10_data(datadir): def partition_data(dataset, datadir, partition, n_nets, alpha): + """ + Partition the dataset into subsets for federated learning. + + Args: + dataset: The dataset to be partitioned. + datadir (str): The directory containing the dataset. + partition (str): The type of partitioning to be applied ("homo", "hetero", "hetero-fix"). + n_nets (int): The number of clients (networks) to partition the data for. + alpha (float): A hyperparameter controlling the heterogeneity of the data partition. + + Returns: + tuple: A tuple containing partitioned data and related information. + """ logging.info("*********partition data***************") X_train, y_train, X_test, y_test = load_cinic10_data(datadir) @@ -176,7 +242,7 @@ def partition_data(dataset, datadir, partition, n_nets, alpha): y_train = np.array(y_train) y_test = np.array(y_test) n_train = len(X_train) - # n_test = len(X_test) + if partition == "homo": total_num = n_train @@ -193,12 +259,12 @@ def partition_data(dataset, datadir, partition, n_nets, alpha): while min_size < 10: idx_batch = [[] for _ in range(n_nets)] - # for each class in the dataset + for k in range(K): idx_k = np.where(y_train == k)[0] np.random.shuffle(idx_k) proportions = np.random.dirichlet(np.repeat(alpha, n_nets)) - ## Balance + proportions = np.array( [ p * (len(idx_j) < N / n_nets) @@ -234,21 +300,60 @@ def partition_data(dataset, datadir, partition, n_nets, alpha): return X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts -# for centralized training + def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None): + """ + Get data loaders for centralized training using the CIFAR-10 dataset. + + Args: + dataset (str): The dataset name. + datadir (str): The directory containing the dataset. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (list, optional): List of data indices to use for training. Default is None. + + Returns: + tuple: A tuple containing the training and testing data loaders. 
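# --- Illustrative sketch (reviewer annotation, not part of the patch) ---
# Core of the "hetero" Dirichlet split used by partition_data: for each class,
# draw client proportions from Dirichlet(alpha) and slice that class's shuffled
# indices accordingly (the balancing step is omitted here for brevity).
import numpy as np

n_nets, alpha, K = 4, 0.5, 10
y_train = np.random.randint(0, K, 1000)
idx_batch = [[] for _ in range(n_nets)]
for k in range(K):
    idx_k = np.where(y_train == k)[0]
    np.random.shuffle(idx_k)
    proportions = np.random.dirichlet(np.repeat(alpha, n_nets))
    cuts = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
    for j, part in enumerate(np.split(idx_k, cuts)):
        idx_batch[j] += part.tolist()
net_dataidx_map = {j: idx_batch[j] for j in range(n_nets)}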
+ """ return get_dataloader_cinic10(datadir, train_bs, test_bs, dataidxs) -# for local devices + def get_dataloader_test( dataset, datadir, train_bs, test_bs, dataidxs_train, dataidxs_test ): + """ + Get data loaders for decentralized (local devices) testing using the CIFAR-10 dataset. + + Args: + dataset (str): The dataset name. + datadir (str): The directory containing the dataset. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs_train (list): List of data indices to use for training. + dataidxs_test (list): List of data indices to use for testing. + + Returns: + tuple: A tuple containing the training and testing data loaders for local devices. + """ return get_dataloader_test_cinic10( datadir, train_bs, test_bs, dataidxs_train, dataidxs_test ) def get_dataloader_cinic10(datadir, train_bs, test_bs, dataidxs=None): + """ + Get data loaders for centralized training using the CIFAR-10 dataset. + + Args: + datadir (str): The directory containing the dataset. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (list, optional): List of data indices to use for training. Default is None. + + Returns: + tuple: A tuple containing the training and testing data loaders. + """ dl_obj = ImageFolderTruncated transform_train, transform_test = _data_transforms_cinic10() @@ -272,6 +377,19 @@ def get_dataloader_cinic10(datadir, train_bs, test_bs, dataidxs=None): def get_dataloader_test_cinic10( datadir, train_bs, test_bs, dataidxs_train=None, dataidxs_test=None ): + """ + Get data loaders for decentralized (local devices) testing using the CIFAR-10 dataset. + + Args: + datadir (str): The directory containing the dataset. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs_train (list): List of data indices to use for training. + dataidxs_test (list): List of data indices to use for testing. + + Returns: + tuple: A tuple containing the training and testing data loaders for local devices. + """ dl_obj = ImageFolderTruncated transform_train, transform_test = _data_transforms_cinic10() @@ -301,6 +419,21 @@ def load_partition_data_distributed_cinic10( client_number, batch_size, ): + """ + Load partitioned data for distributed training using the CIFAR-10 dataset. + + Args: + process_id (int): The ID of the current process. + dataset (str): The dataset name. + data_dir (str): The directory containing the dataset. + partition_method (str): The data partitioning method (e.g., 'homo' or 'hetero'). + partition_alpha (float): The alpha parameter for data partitioning. + client_number (int): The number of clients. + batch_size (int): Batch size for data loaders. + + Returns: + tuple: A tuple containing training and testing data information for distributed training. + """ ( X_train, y_train, @@ -360,6 +493,20 @@ def load_partition_data_distributed_cinic10( def load_partition_data_cinic10( dataset, data_dir, partition_method, partition_alpha, client_number, batch_size ): + """ + Load partitioned data for centralized training using the CIFAR-10 dataset. + + Args: + dataset (str): The dataset name. + data_dir (str): The directory containing the dataset. + partition_method (str): The data partitioning method (e.g., 'homo' or 'hetero'). + partition_alpha (float): The alpha parameter for data partitioning. + client_number (int): The number of clients. + batch_size (int): Batch size for data loaders. 
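# --- Illustrative sketch (reviewer annotation, not part of the patch) ---
# Fetching centralized CINIC-10 loaders and pulling one batch; the data
# directory below is a hypothetical placeholder.
train_dl, test_dl = get_dataloader_cinic10("./data/cinic10", train_bs=64, test_bs=64)
images, labels = next(iter(train_dl))  # one training batch of images/labels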
+ + Returns: + tuple: A tuple containing training and testing data information for centralized training. + """ ( X_train, y_train, diff --git a/python/fedml/data/cinic10/datasets.py b/python/fedml/data/cinic10/datasets.py index a82d553ac0..515cf63de3 100644 --- a/python/fedml/data/cinic10/datasets.py +++ b/python/fedml/data/cinic10/datasets.py @@ -16,11 +16,29 @@ def default_loader(path): + """ + Default image loader function. + + Args: + path (str): The file path to the image. + + Returns: + PIL.Image.Image: An RGB image loaded from the specified path. + """ return pil_loader(path) def pil_loader(path): - # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) + """ + Image loader function using the PIL library. + + Args: + path (str): The file path to the image. + + Returns: + PIL.Image.Image: An RGB image loaded from the specified path. + """ + # Open the path as a file to avoid ResourceWarning with open(path, "rb") as f: img = Image.open(f) return img.convert("RGB") diff --git a/python/fedml/data/edge_case_examples/data_loader.py b/python/fedml/data/edge_case_examples/data_loader.py index 726d262919..8ecbd44d08 100644 --- a/python/fedml/data/edge_case_examples/data_loader.py +++ b/python/fedml/data/edge_case_examples/data_loader.py @@ -28,6 +28,12 @@ def download_edgecase_data(data_cache_dir): + """ + Download edge case attack data and extract it to the specified directory. + + Args: + data_cache_dir (str): The directory where the data should be downloaded and extracted. + """ file_path = data_cache_dir + "/edge_case_examples.zip" logging.info(file_path) URL = "http://pages.cs.wisc.edu/~hongyiwang/edge_case_attack/edge_case_examples.zip" @@ -38,8 +44,17 @@ def download_edgecase_data(data_cache_dir): with zipfile.ZipFile(file_path, "r") as zip_ref: zip_ref.extractall(data_cache_dir) - def record_net_data_stats(y_train, net_dataidx_map): + """ + Record data statistics for each network based on the provided data index mapping. + + Args: + y_train (numpy.ndarray): The labels of the training data. + net_dataidx_map (dict): A dictionary mapping network indices to data indices. + + Returns: + dict: A dictionary containing class counts for each network. + """ net_cls_counts = {} for net_i, dataidx in net_dataidx_map.items(): @@ -51,6 +66,15 @@ def record_net_data_stats(y_train, net_dataidx_map): def load_mnist_data(datadir): + """ + Load the MNIST dataset from the specified directory. + + Args: + datadir (str): The directory where the dataset is stored. + + Returns: + tuple: A tuple containing training and testing data and labels for MNIST. + """ transform = transforms.Compose([transforms.ToTensor()]) mnist_train_ds = MNIST_truncated( @@ -72,6 +96,15 @@ def load_mnist_data(datadir): def load_emnist_data(datadir): + """ + Load the EMNIST dataset from the specified directory. + + Args: + datadir (str): The directory where the dataset is stored. + + Returns: + tuple: A tuple containing training and testing data and labels for EMNIST. + """ transform = transforms.Compose([transforms.ToTensor()]) emnist_train_ds = EMNIST_truncated( @@ -93,6 +126,15 @@ def load_emnist_data(datadir): def load_cifar10_data(datadir): + """ + Load the CIFAR-10 dataset from the specified directory. + + Args: + datadir (str): The directory where the dataset is stored. + + Returns: + tuple: A tuple containing training and testing data and labels for CIFAR-10. 
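# --- Illustrative sketch (reviewer annotation, not part of the patch) ---
# The loader pattern used by pil_loader above: open the file explicitly so the
# handle is closed promptly (avoiding ResourceWarning), then convert to RGB
# before the handle goes away.
from PIL import Image

def load_rgb(path):
    with open(path, "rb") as f:
        return Image.open(f).convert("RGB")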
+ """ transform = transforms.Compose([transforms.ToTensor()]) cifar10_train_ds = CIFAR10_truncated( @@ -109,6 +151,20 @@ def load_cifar10_data(datadir): def partition_data(dataset, datadir, partition, n_nets, alpha, args): + """ + Partition the dataset based on the specified method and parameters. + + Args: + dataset (str): The name of the dataset (e.g., "mnist", "emnist", "cifar10"). + datadir (str): The directory where the dataset is stored. + partition (str): The partitioning method ("homo" or "hetero-dir"). + n_nets (int): The number of clients/networks. + alpha (float): A parameter for data partitioning. + args: Additional arguments. + + Returns: + dict: A dictionary mapping network indices to data indices. + """ if dataset == "mnist": X_train, y_train, X_test, y_test = load_mnist_data(datadir) n_train = X_train.shape[0] @@ -244,6 +300,19 @@ def partition_data(dataset, datadir, partition, n_nets, alpha, args): def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None): + """ + Get data loaders for the specified dataset. + + Args: + dataset (str): The name of the dataset (e.g., "mnist", "emnist", "cifar10"). + datadir (str): The directory where the dataset is stored. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (dict): A dictionary mapping network indices to data indices. + + Returns: + tuple: A tuple containing training and testing data loaders. + """ if dataset in ("mnist", "emnist", "cifar10"): if dataset == "mnist": dl_obj = MNIST_truncated @@ -320,6 +389,24 @@ def get_dataloader_normal_case( ardis_dataset=None, attack_case="normal-case", ): + """ + Get data loaders for the specified dataset with support for poison attacks. + + Args: + dataset (str): The name of the dataset (e.g., "mnist", "emnist", "cifar10"). + datadir (str): The directory where the dataset is stored. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (dict): A dictionary mapping network indices to data indices. + user_id (int): The user ID for poison attack. + num_total_users (int): The total number of users. + poison_type (str): The type of poison attack (e.g., "southwest"). + ardis_dataset: ARDIS dataset for poison attack (if applicable). + attack_case (str): The type of attack case (e.g., "normal-case"). + + Returns: + tuple: A tuple containing training and testing data loaders. + """ if dataset in ("mnist", "emnist", "cifar10"): if dataset == "mnist": dl_obj = MNIST_truncated @@ -391,6 +478,17 @@ def get_dataloader_normal_case( def load_poisoned_dataset(args): + """ + Load a poisoned dataset based on the provided arguments. + + Args: + args (Namespace): Command-line arguments containing dataset details. + + Returns: + DataLoader: DataLoader for the poisoned dataset. + DataLoader: DataLoader for the clean test dataset. + DataLoader: DataLoader for the targetted task test dataset. 
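# --- Illustrative sketch (reviewer annotation, not part of the patch) ---
# The download-and-extract flow behind download_edgecase_data, written here
# with the standard library; the original may use a different HTTP helper, and
# the cache directory is a hypothetical placeholder.
import os
import urllib.request
import zipfile

data_cache_dir = "./data"
os.makedirs(data_cache_dir, exist_ok=True)
URL = "http://pages.cs.wisc.edu/~hongyiwang/edge_case_attack/edge_case_examples.zip"
file_path = data_cache_dir + "/edge_case_examples.zip"
urllib.request.urlretrieve(URL, file_path)
with zipfile.ZipFile(file_path, "r") as zip_ref:
    zip_ref.extractall(data_cache_dir)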
+ """ use_cuda = not args.using_gpu and torch.cuda.is_available() kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} if args.dataset in ("mnist", "emnist"): diff --git a/python/fedml/data/edge_case_examples/datasets.py b/python/fedml/data/edge_case_examples/datasets.py index e493a16750..2cf26a649f 100644 --- a/python/fedml/data/edge_case_examples/datasets.py +++ b/python/fedml/data/edge_case_examples/datasets.py @@ -37,6 +37,18 @@ class MNIST_truncated(data.Dataset): def __init__( self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False, ): + """ + Initialize the MNIST_truncated dataset. + + Args: + root (str): Root directory where the dataset is stored. + dataidxs (list, optional): List of data indices to include in the dataset. + train (bool, optional): Whether to load the training or testing data. + transform (callable, optional): A function/transform to apply to the data. + target_transform (callable, optional): A function/transform to apply to the target. + download (bool, optional): Whether to download the dataset if it's not found. + """ + self.root = root self.dataidxs = dataidxs @@ -48,6 +60,13 @@ def __init__( self.data, self.target = self.__build_truncated_dataset__() def __build_truncated_dataset__(self): + """ + Build the truncated dataset based on the provided data indices. + + Returns: + torch.Tensor: The truncated data. + torch.Tensor: The corresponding labels/targets. + """ mnist_dataobj = MNIST(self.root, self.train, self.transform, self.target_transform, self.download) @@ -94,6 +113,17 @@ class EMNIST_truncated(data.Dataset): def __init__( self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False, ): + """ + Initialize the EMNIST_truncated dataset. + + Args: + root (str): Root directory where the dataset is stored. + dataidxs (list, optional): List of data indices to include in the dataset. + train (bool, optional): Whether to load the training or testing data. + transform (callable, optional): A function/transform to apply to the data. + target_transform (callable, optional): A function/transform to apply to the target. + download (bool, optional): Whether to download the dataset if it's not found. + """ self.root = root self.dataidxs = dataidxs @@ -105,6 +135,13 @@ def __init__( self.data, self.target = self.__build_truncated_dataset__() def __build_truncated_dataset__(self): + """ + Build the truncated dataset based on the provided data indices. + + Returns: + torch.Tensor: The truncated data. + torch.Tensor: The corresponding labels/targets. + """ emnist_dataobj = EMNIST( self.root, split="digits", @@ -154,20 +191,28 @@ def __len__(self): def get_ardis_dataset(): - # load the data from csv's + """Load the ARDIS dataset and prepare it for training. + + This function loads the ARDIS dataset from CSV files, reshapes the images, + and prepares the dataset for training. + + Returns: + torch.utils.data.Dataset: The ARDIS dataset prepared for training. 
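# --- Illustrative sketch (reviewer annotation, not part of the patch) ---
# How the *_truncated datasets restrict a torchvision dataset to a subset of
# indices via dataidxs; "./data" is a hypothetical data root.
from torchvision import transforms

ds = MNIST_truncated(
    "./data",
    dataidxs=list(range(100)),   # keep only the first 100 samples
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
assert len(ds) == 100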
+ """ + # Load the data from CSV files ardis_images = np.loadtxt("./../../../data/edge_case_examples/ARDIS/ARDIS_train_2828.csv", dtype="float") ardis_labels = np.loadtxt("./../../../data/edge_case_examples/ARDIS/ARDIS_train_labels.csv", dtype="float") - #### reshape to be [samples][width][height] + # Reshape the images to [samples][width][height] ardis_images = ardis_images.reshape(ardis_images.shape[0], 28, 28).astype("float32") - # labels are one-hot encoded + # Labels are one-hot encoded; extract images and labels for digit 7 indices_seven = np.where(ardis_labels[:, 7] == 1)[0] images_seven = ardis_images[indices_seven, :] images_seven = torch.tensor(images_seven).type(torch.uint8) + labels_seven = torch.tensor([7 for _ in ardis_labels]) - labels_seven = torch.tensor([7 for y in ardis_labels]) - + # Create an EMNIST dataset for digit 7 ardis_dataset = EMNIST( "./../../../data", split="digits", @@ -176,13 +221,23 @@ def get_ardis_dataset(): transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), ) + # Set the data and targets to the extracted images and labels ardis_dataset.data = images_seven ardis_dataset.targets = labels_seven return ardis_dataset - def get_southwest_dataset(attack_case="normal-case"): + """Load the Southwest dataset for a specified attack case. + + This function loads the Southwest dataset for a given attack case. + + Args: + attack_case (str): The attack case to load. Options are "normal-case" and "almost-edge-case". + + Returns: + pickle.Unpickler: The loaded Southwest dataset for the specified attack case. + """ if attack_case == "normal-case": with open( "./../../../data/edge_case_examples/southwest_cifar10/southwest_images_honest_full_normal.pkl", "rb", @@ -200,8 +255,8 @@ def get_southwest_dataset(attack_case="normal-case"): class EMNIST_NormalCase_truncated(data.Dataset): """ - we use this class for normal case attack where normal - users also hold the poisoned data point with true label + Dataset class for normal case attack where normal + users also hold the poisoned data point with true label. """ def __init__( @@ -218,7 +273,22 @@ def __init__( ardis_dataset_train=None, attack_case="normal-case", ): + """ + Initializes the EMNIST_NormalCase_truncated dataset. + Args: + root (str): Root directory where the dataset is stored. + dataidxs (list, optional): List of indices to select specific data points. Default is None. + train (bool): True for training dataset, False for testing dataset. + transform (callable, optional): A function/transform to apply to the data. Default is None. + target_transform (callable, optional): A function/transform to apply to the target. Default is None. + download (bool): Whether to download the dataset if it's not found in the root directory. Default is False. + user_id (int): ID of the user accessing the dataset. + num_total_users (int): Total number of users in the scenario. + poison_type (str): Type of poisoning data. Default is "ardis". + ardis_dataset_train (torch.utils.data.Dataset): ARDIS dataset used for poisoning. Default is None. + attack_case (str): The type of attack case. Options are "normal-case" and "almost-edge-case". Default is "normal-case". + """ self.root = root self.dataidxs = dataidxs self.train = train @@ -229,9 +299,9 @@ def __init__( if attack_case == "normal-case": self._num_users_hold_edge_data = int( 3383 / 20 - ) # we allow 1/20 of the users (other than the attacker) to hold the edge data. 
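# --- Illustrative sketch (reviewer annotation, not part of the patch) ---
# Selecting every sample of digit 7 from one-hot labels, as get_ardis_dataset
# does with np.where on column 7.
import numpy as np

labels_onehot = np.eye(10)[np.array([3, 7, 7, 1])]     # toy one-hot labels
indices_seven = np.where(labels_onehot[:, 7] == 1)[0]  # -> array([1, 2])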
+ ) # We allow 1/20 of the users (other than the attacker) to hold the edge data. else: - # almost edge case + # Almost edge case self._num_users_hold_edge_data = 66 # ~2% of users hold data if poison_type == "ardis": @@ -249,17 +319,18 @@ def __init__( self.saved_ardis_dataset_train = self.ardis_dataset_train.data[user_partition] self.saved_ardis_label_train = self.ardis_dataset_train.targets[user_partition] else: - NotImplementedError("Unsupported poison type for normal case attack ...") + raise NotImplementedError("Unsupported poison type for normal case attack ...") - # logging.info("USER: {} got {} points".format(user_id, len(self.saved_ardis_dataset_train.data))) self.data, self.target = self.__build_truncated_dataset__() - # if self.dataidxs is not None: - # print("$$$$$$$$ Inside data loader: user ID: {}, Combined data: {}, Ori data shape: {}".format( - # user_id, self.data.shape, len(dataidxs))) - def __build_truncated_dataset__(self): + """ + Builds the truncated dataset by combining the EMNIST dataset with the ARDIS dataset. + Returns: + np.ndarray: Combined data. + np.ndarray: Combined target labels. + """ emnist_dataobj = EMNIST( self.root, split="digits", @@ -290,7 +361,7 @@ def __getitem__(self, index): index (int): Index Returns: - tuple: (image, target) where target is index of the target class. + tuple: (image, target) where target is the index of the target class. """ img, target = self.data[index], self.target[index] @@ -307,6 +378,21 @@ def __len__(self): class CIFAR10_truncated(data.Dataset): + """ + Dataset class for a truncated version of the CIFAR-10 dataset. + + This class allows you to create a truncated version of the CIFAR-10 dataset + by selecting specific data indices. + + Args: + root (str): Root directory where the dataset is stored. + dataidxs (list, optional): List of indices to select specific data points. Default is None. + train (bool): True for training dataset, False for testing dataset. + transform (callable, optional): A function/transform to apply to the data. Default is None. + target_transform (callable, optional): A function/transform to apply to the target. Default is None. + download (bool): Whether to download the dataset if it's not found in the root directory. Default is False. + """ + def __init__( self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False, ): @@ -321,12 +407,16 @@ def __init__( self.data, self.target = self.__build_truncated_dataset__() def __build_truncated_dataset__(self): + """ + Builds the truncated dataset by selecting specific data indices. + Returns: + np.ndarray: Combined data. + np.ndarray: Combined target labels. + """ cifar_dataobj = CIFAR10(self.root, self.train, self.transform, self.target_transform, self.download) if self.train: - # print("train member of the class: {}".format(self.train)) - # data = cifar_dataobj.train_data data = cifar_dataobj.data target = np.array(cifar_dataobj.targets) else: @@ -345,7 +435,7 @@ def __getitem__(self, index): index (int): Index Returns: - tuple: (image, target) where target is index of the target class. + tuple: (image, target) where target is the index of the target class. 
""" img, target = self.data[index], self.target[index] @@ -363,8 +453,8 @@ def __len__(self): class CIFAR10NormalCase_truncated(data.Dataset): """ - we use this class for normal case attack where normal - users also hold the poisoned data point with true label + Dataset class for normal case attack where normal + users also hold the poisoned data point with true label. """ def __init__( @@ -381,6 +471,22 @@ def __init__( ardis_dataset_train=None, attack_case="normal-case", ): + """ + Initializes the CIFAR10NormalCase_truncated dataset. + + Args: + root (str): Root directory where the dataset is stored. + dataidxs (list, optional): List of indices to select specific data points. Default is None. + train (bool): True for training dataset, False for testing dataset. + transform (callable, optional): A function/transform to apply to the data. Default is None. + target_transform (callable, optional): A function/transform to apply to the target. Default is None. + download (bool): Whether to download the dataset if it's not found in the root directory. Default is False. + user_id (int): ID of the user accessing the dataset. + num_total_users (int): Total number of users in the scenario. + poison_type (str): Type of poisoning data. Default is "southwest". + ardis_dataset_train (np.ndarray): ARDIS dataset used for poisoning. Default is None. + attack_case (str): The type of attack case. Options are "normal-case" and "almost-edge-case". Default is "normal-case". + """ self.root = root self.dataidxs = dataidxs @@ -447,6 +553,13 @@ def __init__( # user_id, self.data.shape, len(dataidxs))) def __build_truncated_dataset__(self): + """ + Builds the truncated dataset by combining the CIFAR-10 dataset with the poisoned ARDIS dataset. + + Returns: + np.ndarray: Combined data. + np.ndarray: Combined target labels. + """ cifar_dataobj = CIFAR10(self.root, self.train, self.transform, self.target_transform, self.download) diff --git a/python/fedml/data/fed_cifar100/data_loader.py b/python/fedml/data/fed_cifar100/data_loader.py index e54111e575..c05d909ca1 100644 --- a/python/fedml/data/fed_cifar100/data_loader.py +++ b/python/fedml/data/fed_cifar100/data_loader.py @@ -23,7 +23,19 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): - + """ + Get data loaders for training and testing. + + Args: + dataset (str): Dataset name. + data_dir (str): Directory containing the data. + train_bs (int): Batch size for training data loader. + test_bs (int): Batch size for testing data loader. + client_idx (int, optional): Index of the client to load data for. + + Returns: + tuple: A tuple containing the training data loader and testing data loader. 
+ """ train_h5 = h5py.File(os.path.join(data_dir, DEFAULT_TRAIN_FILE), "r") test_h5 = h5py.File(os.path.join(data_dir, DEFAULT_TEST_FILE), "r") train_x = [] @@ -31,7 +43,7 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): test_x = [] test_y = [] - # load data in numpy format from h5 file + # Load data in numpy format from h5 file if client_idx is None: train_x = np.vstack( [ @@ -62,14 +74,14 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): [test_h5[_EXAMPLE][client_id_test][_LABEL][()]] ).squeeze() - # preprocess + # Preprocess train_x = utils.preprocess_cifar_img(torch.tensor(train_x), train=True) train_y = torch.tensor(train_y) if len(test_x) != 0: test_x = utils.preprocess_cifar_img(torch.tensor(test_x), train=False) test_y = torch.tensor(test_y) - # generate dataloader + # Generate data loader train_ds = data.TensorDataset(train_x, train_y) train_dl = data.DataLoader( dataset=train_ds, batch_size=train_bs, shuffle=True, drop_last=False @@ -91,11 +103,22 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): def load_partition_data_distributed_federated_cifar100( process_id, dataset, data_dir, batch_size=DEFAULT_BATCH_SIZE ): - + """ + Load distributed federated CIFAR-100 dataset for a specific client. + + Args: + process_id (int): Identifier of the client process. + dataset (str): Dataset name. + data_dir (str): Directory containing the data. + batch_size (int, optional): Batch size for data loader. + + Returns: + tuple: A tuple containing information about the dataset, including the number of classes. + """ class_num = 100 if process_id == 0: - # get global dataset + # Get global dataset train_data_global, test_data_global = get_dataloader( dataset, data_dir, batch_size, batch_size ) @@ -107,7 +130,7 @@ def load_partition_data_distributed_federated_cifar100( test_data_local = None local_data_num = 0 else: - # get local dataset + # Get local dataset train_data_local, test_data_local = get_dataloader( dataset, data_dir, batch_size, batch_size, process_id - 1 ) @@ -132,6 +155,17 @@ def load_partition_data_distributed_federated_cifar100( def load_partition_data_federated_cifar100( dataset, data_dir, batch_size=DEFAULT_BATCH_SIZE ): + """ + Load federated CIFAR-100 dataset for multiple clients. + + Args: + dataset (str): Dataset name. + data_dir (str): Directory containing the data. + batch_size (int, optional): Batch size for data loader. + + Returns: + tuple: A tuple containing information about the dataset, including the number of classes. + """ class_num = 100 diff --git a/python/fedml/data/fed_cifar100/utils.py b/python/fedml/data/fed_cifar100/utils.py index 323654200c..d2108ce4dc 100644 --- a/python/fedml/data/fed_cifar100/utils.py +++ b/python/fedml/data/fed_cifar100/utils.py @@ -9,7 +9,18 @@ # def cifar100_transform(img_mean, img_std, train=True, crop_size=(24, 24)): def cifar100_transform(img_mean, img_std, train=True, crop_size=32): - """cropping, flipping, and normalizing.""" + """ + Define data transformations for CIFAR-100 dataset. + + Args: + img_mean (tuple): Mean values for image normalization. + img_std (tuple): Standard deviation values for image normalization. + train (bool): Whether the transformations are for training or testing data. + crop_size (int): Size of the crop (default is 32). + + Returns: + torchvision.transforms.Compose: A composition of data transformations. 
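# --- Illustrative sketch (reviewer annotation, not part of the patch) ---
# The TensorDataset/DataLoader pattern the CIFAR-100 loader above uses once
# the h5 arrays are preprocessed into tensors (toy tensors shown here).
import torch
from torch.utils import data

train_x = torch.randn(8, 3, 32, 32)    # toy preprocessed images
train_y = torch.randint(0, 100, (8,))  # toy labels
train_dl = data.DataLoader(
    data.TensorDataset(train_x, train_y),
    batch_size=4, shuffle=True, drop_last=False,
)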
+ """ if train: return transforms.Compose( [ @@ -40,6 +51,16 @@ def cifar100_transform(img_mean, img_std, train=True, crop_size=32): def preprocess_cifar_img(img, train): + """ + Preprocess CIFAR-100 images for use in a PyTorch model. + + Args: + img (torch.Tensor): Input images. + train (bool): Whether the data is for training or testing. + + Returns: + torch.Tensor: Preprocessed images as a PyTorch tensor. + """ # scale img to range [0,1] to fit ToTensor api img = torch.div(img, 255.0) transoformed_img = torch.stack( diff --git a/python/fedml/data/fed_shakespeare/data_loader.py b/python/fedml/data/fed_shakespeare/data_loader.py index 5f21334dd6..aad409a6b3 100644 --- a/python/fedml/data/fed_shakespeare/data_loader.py +++ b/python/fedml/data/fed_shakespeare/data_loader.py @@ -21,19 +21,31 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): - + """ + Get data loaders for the specified dataset. + + Args: + dataset (str): The name of the dataset. + data_dir (str): The directory containing the dataset. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + client_idx (int): Index of the client (None for all clients). + + Returns: + tuple: A tuple of DataLoader objects for training and testing data. + """ train_h5 = h5py.File(os.path.join(data_dir, DEFAULT_TRAIN_FILE), "r") test_h5 = h5py.File(os.path.join(data_dir, DEFAULT_TEST_FILE), "r") train_ds = [] test_ds = [] - # load data + # Load data if client_idx is None: - # get ids of all clients + # Get IDs of all clients train_ids = client_ids_train test_ids = client_ids_test else: - # get ids of single client + # Get IDs of a single client train_ids = [client_ids_train[client_idx]] test_ids = [client_ids_test[client_idx]] @@ -46,7 +58,7 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): raw_test = [x.decode("utf8") for x in raw_test] test_ds.extend(utils.preprocess(raw_test)) - # split data + # Split data train_x, train_y = utils.split(train_ds) test_x, test_y = utils.split(test_ds) train_ds = data.TensorDataset(torch.tensor(train_x[:, :]), torch.tensor(train_y[:])) @@ -62,26 +74,36 @@ def get_dataloader(dataset, data_dir, train_bs, test_bs, client_idx=None): test_h5.close() return train_dl, test_dl - def load_partition_data_distributed_federated_shakespeare( process_id, dataset, data_dir, batch_size=DEFAULT_BATCH_SIZE ): - + """ + Load partitioned data for distributed federated learning with Shakespearean text data. + + Args: + process_id (int): The process ID of the current worker (0 for the server). + dataset (str): The name of the dataset. + data_dir (str): The directory containing the dataset. + batch_size (int): Batch size for data loaders. + + Returns: + tuple: A tuple containing information about the data partitions and vocabulary size. 
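# --- Illustrative sketch (reviewer annotation, not part of the patch) ---
# The image-preprocessing pattern from preprocess_cifar_img: scale raw uint8
# pixels to [0, 1] with torch.div, then re-stack image by image (the real code
# also applies the torchvision transform per image).
import torch

imgs = torch.randint(0, 256, (4, 32, 32, 3), dtype=torch.uint8)
scaled = torch.div(imgs, 255.0)                                 # [0, 1] range
batch = torch.stack([img.permute(2, 0, 1) for img in scaled])   # NHWC -> NCHW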
+ """ if process_id == 0: - # get global dataset + # Get global dataset train_data_global, test_data_global = get_dataloader( dataset, data_dir, batch_size, batch_size, process_id - 1 ) - train_data_num = len(train_data_global) - test_data_num = len(test_data_global) + train_data_num = len(train_data_global.dataset) + test_data_num = len(test_data_global.dataset) logging.info("train_dl_global number = " + str(train_data_num)) logging.info("test_dl_global number = " + str(test_data_num)) train_data_local = None test_data_local = None local_data_num = 0 else: - # get local dataset - # client id list + # Get local dataset + # Client ID list train_file_path = os.path.join(data_dir, DEFAULT_TRAIN_FILE) test_file_path = os.path.join(data_dir, DEFAULT_TEST_FILE) with h5py.File(train_file_path, "r") as train_h5, h5py.File( @@ -117,8 +139,18 @@ def load_partition_data_distributed_federated_shakespeare( def load_partition_data_federated_shakespeare( dataset, data_dir, batch_size=DEFAULT_BATCH_SIZE ): - - # client id list + """ + Load partitioned data for federated learning with Shakespearean text data. + + Args: + dataset (str): The name of the dataset. + data_dir (str): The directory containing the dataset. + batch_size (int): Batch size for data loaders (default is DEFAULT_BATCH_SIZE). + + Returns: + tuple: A tuple containing information about the data partitions and vocabulary size. + """ + # Client ID list train_file_path = os.path.join(data_dir, DEFAULT_TRAIN_FILE) test_file_path = os.path.join(data_dir, DEFAULT_TEST_FILE) with h5py.File(train_file_path, "r") as train_h5, h5py.File( @@ -128,7 +160,7 @@ def load_partition_data_federated_shakespeare( client_ids_train = list(train_h5[_EXAMPLE].keys()) client_ids_test = list(test_h5[_EXAMPLE].keys()) - # get local dataset + # Get local dataset data_local_num_dict = dict() train_data_local_dict = dict() test_data_local_dict = dict() @@ -149,7 +181,7 @@ def load_partition_data_federated_shakespeare( train_data_local_dict[client_idx] = train_data_local test_data_local_dict[client_idx] = test_data_local - # global dataset + # Global dataset train_data_global = data.DataLoader( data.ConcatDataset( list(dl.dataset for dl in list(train_data_local_dict.values())) @@ -185,3 +217,4 @@ def load_partition_data_federated_shakespeare( VOCAB_LEN, ) + diff --git a/python/fedml/data/fed_shakespeare/utils.py b/python/fedml/data/fed_shakespeare/utils.py index 8393710249..13db267b2e 100644 --- a/python/fedml/data/fed_shakespeare/utils.py +++ b/python/fedml/data/fed_shakespeare/utils.py @@ -21,8 +21,14 @@ def get_word_dict(): + """ + Get a dictionary mapping words to their corresponding IDs. + + Returns: + collections.OrderedDict: A dictionary with words as keys and their IDs as values. + """ global word_dict - if word_dict == None: + if word_dict is None: words = [_pad] + CHAR_VOCAB + [_bos] + [_eos] word_dict = collections.OrderedDict() for i, w in enumerate(words): @@ -31,18 +37,42 @@ def get_word_dict(): def get_word_list(): + """ + Get a list of words in the vocabulary. + + Returns: + list: A list of words in the vocabulary. + """ global word_list - if word_list == None: + if word_list is None: word_dict = get_word_dict() word_list = list(word_dict.keys()) return word_list def id_to_word(idx): + """ + Convert a word ID to the corresponding word. + + Args: + idx (int): The word ID. + + Returns: + str: The corresponding word. + """ return get_word_list()[idx] def char_to_id(char): + """ + Convert a character to its corresponding ID using the word_dict. 
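# --- Illustrative sketch (reviewer annotation, not part of the patch) ---
# The vocabulary layout built by get_word_dict: pad token first, then the
# character vocabulary, then BOS and EOS. The token spellings and toy CHAR_VOCAB
# below are hypothetical.
import collections

_pad, _bos, _eos = "<pad>", "<bos>", "<eos>"
CHAR_VOCAB = list("abc")
words = [_pad] + CHAR_VOCAB + [_bos] + [_eos]
word_dict = collections.OrderedDict((w, i) for i, w in enumerate(words))
# word_dict -> {"<pad>": 0, "a": 1, "b": 2, "c": 3, "<bos>": 4, "<eos>": 5}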
+ + Args: + char (str): The character to convert. + + Returns: + int: The corresponding ID for the character. + """ word_dict = get_word_dict() if char in word_dict: return word_dict[char] @@ -51,15 +81,29 @@ def char_to_id(char): def preprocess(sentences, max_seq_len=SEQUENCE_LENGTH): + """ + Preprocess a list of sentences by converting characters to IDs and padding. + Args: + sentences (list): A list of sentences, where each sentence is a string. + max_seq_len (int): Maximum sequence length (including start and end tokens). + + Returns: + list: A list of sequences, where each sequence is a list of token IDs. + """ sequences = [] def to_ids(sentence, num_oov_buckets=1): """ - map list of sentence to list of [idx..] and pad to max_seq_len + 1 + Map a sentence to a list of token IDs and pad it to the specified length. + Args: - num_oov_buckets : The number of out of vocabulary buckets. - max_seq_len: Integer determining shape of padded batches. + sentence (str): The input sentence. + num_oov_buckets (int): The number of out-of-vocabulary (OOV) buckets. + max_seq_len (int): Maximum sequence length (including start and end tokens). + + Returns: + list: A list of token IDs, padded to max_seq_len. """ tokens = [char_to_id(c) for c in sentence] tokens = [char_to_id(_bos)] + tokens + [char_to_id(_eos)] @@ -77,12 +121,23 @@ def to_ids(sentence, num_oov_buckets=1): def split(dataset): + """ + Split a dataset into input sequences (x) and target sequences (y). + + Args: + dataset (list): A list of sequences, where each sequence is a list of token IDs. + + Returns: + tuple: A tuple containing two arrays, x and y, where x represents input sequences + and y represents target sequences. + """ ds = np.asarray(dataset) x = ds[:, :-1] y = ds[:, 1:] return x, y + if __name__ == "__main__": print( split( diff --git a/python/fedml/data/fednlp/base/data_manager/base_data_manager.py b/python/fedml/data/fednlp/base/data_manager/base_data_manager.py index 4afc77e375..b0addb59f7 100644 --- a/python/fedml/data/fednlp/base/data_manager/base_data_manager.py +++ b/python/fedml/data/fednlp/base/data_manager/base_data_manager.py @@ -12,8 +12,29 @@ class BaseDataManager(ABC): + """Abstract base class for managing data in federated learning scenarios. + + This class defines the common interface and functionality for managing data in federated learning, + including loading, partitioning, and distributing datasets to clients. + + Attributes: + args: The command-line arguments passed to the manager. + model_args: The model-specific arguments. + train_batch_size: The batch size for training data. + eval_batch_size: The batch size for evaluation data. + process_id: The identifier of the current process. + num_workers: The total number of workers (including the server). + """ @abstractmethod def __init__(self, args, model_args, process_id, num_workers): + """Initialize the BaseDataManager. + + Args: + args: Command-line arguments. + model_args: Model-specific arguments. + process_id: Identifier of the current process. + num_workers: Total number of workers (including the server). + """ self.model_args = model_args self.args = args self.train_batch_size = model_args.train_batch_size @@ -44,6 +65,14 @@ def __init__(self, args, model_args, process_id, num_workers): @staticmethod def load_attributes(data_path): + """Load data attributes from an HDF5 data file. + + Args: + data_path: Path to the HDF5 data file. + + Returns: + Dictionary containing data attributes. 
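# --- Illustrative sketch (reviewer annotation, not part of the patch) ---
# The next-token shift performed by split(): inputs are tokens[:-1], targets
# are tokens[1:], so the model predicts each character from its prefix.
import numpy as np

ds = np.asarray([[1, 2, 3, 4], [5, 6, 7, 8]])
x, y = ds[:, :-1], ds[:, 1:]
# x -> [[1, 2, 3], [5, 6, 7]]; y -> [[2, 3, 4], [6, 7, 8]]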
+ """ data_file = h5py.File(data_path, "r", swmr=True) attributes = json.loads(data_file["attributes"][()]) data_file.close() @@ -51,6 +80,15 @@ def load_attributes(data_path): @staticmethod def load_num_clients(partition_file_path, partition_name): + """Load the number of clients from a partition file. + + Args: + partition_file_path: Path to the partition file. + partition_name: Name of the partition. + + Returns: + The number of clients. + """ data_file = h5py.File(partition_file_path, "r", swmr=True) num_clients = int(data_file[partition_name]["n_clients"][()]) data_file.close() @@ -58,11 +96,27 @@ def load_num_clients(partition_file_path, partition_name): @abstractmethod def read_instance_from_h5(self, data_file, index_list, desc): + """Read instances from an HDF5 data file. + + Args: + data_file: HDF5 data file object. + index_list: List of indices to read. + desc: Description of the read operation. + + Returns: + Data instances. + """ pass def sample_client_index(self, process_id, num_workers): - """ - Sample client indices according to process_id + """Sample client indices according to the process_id. + + Args: + process_id (int): The identifier of the current process. + num_workers (int): The total number of workers. + + Returns: + list or None: A list of client indices if process_id is not 0, else None. """ # process_id = 0 means this process is the server process if process_id == 0: @@ -71,6 +125,14 @@ def sample_client_index(self, process_id, num_workers): return self._simulated_sampling(process_id) def _simulated_sampling(self, process_id): + """Simulated client sampling for federated learning. + + Args: + process_id (int): The identifier of the current process. + + Returns: + list: A list of sampled client indices. + """ res_client_indexes = list() for round_idx in range(self.args.comm_round): if self.num_clients == self.num_workers: @@ -92,6 +154,14 @@ def get_all_clients(self): return list(range(0, self.num_clients)) def load_centralized_data(self, cut_off=None): + """Load centralized training and testing data. + + Args: + cut_off (int, optional): The maximum number of data points to load. + + Returns: + tuple: A tuple containing centralized training and testing data loaders. + """ state, res = self._load_data_loader_from_cache(-1) if state: ( @@ -169,6 +239,14 @@ def load_centralized_data(self, cut_off=None): return train_dl, test_dl def load_federated_data(self, test_cut_off=None): + """Load federated training and testing data. + + Args: + test_cut_off (int, optional): The maximum number of testing data points to load. + + Returns: + tuple: A tuple containing federated training and testing data and related information. + """ ( train_data_num, test_data_num, @@ -193,6 +271,16 @@ def load_federated_data(self, test_cut_off=None): ) def _load_federated_data_server(self, test_only=False, test_cut_off=None): + """Load federated training and testing data from the server. + + Args: + test_only (bool, optional): Whether to load only testing data. Defaults to False. + test_cut_off (int, optional): The maximum number of testing data points to load. + + Returns: + tuple: A tuple containing the number of training data points, the number of testing data points, + federated training data loader, and federated testing data loader. 
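# --- Illustrative sketch (reviewer annotation, not part of the patch) ---
# Per-round uniform client sampling in the spirit of _simulated_sampling:
# each communication round draws num_workers distinct client indices.
import numpy as np

num_clients, num_workers, comm_round = 100, 10, 3
sampled_per_round = [
    np.random.choice(num_clients, num_workers, replace=False).tolist()
    for _ in range(comm_round)
]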
+ """ # state, res = self._load_data_loader_from_cache(-1) state = False train_data_local_dict = None @@ -288,6 +376,12 @@ def _load_federated_data_server(self, test_only=False, test_cut_off=None): return (train_data_num, test_data_num, train_data_global, test_data_global) def _load_federated_data_local(self): + """Load federated training and testing data for local clients. + + Returns: + tuple: A tuple containing dictionaries with local client data loaders, the number of clients, + and the number of training data points and testing data points. + """ data_file = h5py.File(self.args.data_file_path, "r", swmr=True) partition_file = h5py.File(self.args.partition_file_path, "r", swmr=True) @@ -397,8 +491,17 @@ def _load_federated_data_local(self): ) def _load_data_loader_from_cache(self, client_id): - """ - Different clients has different cache file. client_id = -1 means loading the cached file on server end. + """Load cached data loader from cache file for a specific client. + + Different clients has different cache file. client_id = -1 means + loading the cached file on server end. + + Args: + client_id (int): The ID of the client for which to load the cached data loader. + + Returns: + tuple: A tuple containing a boolean indicating whether the data loader was loaded from cache, + and the cached data loader if available. """ args = self.args model_args = self.model_args diff --git a/python/fedml/data/fednlp/base/preprocess/base_preprocessor.py b/python/fedml/data/fednlp/base/preprocess/base_preprocessor.py index 352b4d38fc..db1b6bd8b3 100644 --- a/python/fedml/data/fednlp/base/preprocess/base_preprocessor.py +++ b/python/fedml/data/fednlp/base/preprocess/base_preprocessor.py @@ -1,11 +1,41 @@ from abc import ABC, abstractmethod +from abc import ABC, abstractmethod + class BasePreprocessor(ABC): + """Abstract base class for data preprocessors. + + This class defines the common interface for data preprocessors, which are responsible for transforming + and preparing data for further processing or analysis. + + Attributes: + **kwargs: Additional keyword arguments specific to the preprocessor implementation. + + Methods: + transform(*args): Abstract method to transform data. + + """ + @abstractmethod def __init__(self, **kwargs): + """Initialize the BasePreprocessor with optional keyword arguments. + + Args: + **kwargs: Additional keyword arguments specific to the preprocessor implementation. + """ self.__dict__.update(kwargs) @abstractmethod def transform(self, *args): + """Transform data using the preprocessor. + + This method should be implemented by subclasses to apply data transformation operations. + + Args: + *args: Variable-length arguments representing the input data to be transformed. + + Returns: + Transformed data or processed result. + """ pass diff --git a/python/fedml/data/fednlp/base/raw_data/base_raw_data_loader.py b/python/fedml/data/fednlp/base/raw_data/base_raw_data_loader.py index 5f50a80a95..9665cb276c 100644 --- a/python/fedml/data/fednlp/base/raw_data/base_raw_data_loader.py +++ b/python/fedml/data/fednlp/base/raw_data/base_raw_data_loader.py @@ -7,27 +7,96 @@ class BaseRawDataLoader(ABC): + """Abstract base class for raw data loaders. + + This class defines the common interface for raw data loaders, which are responsible for loading + and processing raw data from various sources. + + Attributes: + data_path (str): The path to the raw data. + attributes (dict): A dictionary to store attributes related to the loaded data. 
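# --- Illustrative sketch (reviewer annotation, not part of the patch) ---
# A minimal concrete preprocessor implementing the BasePreprocessor interface
# defined above; the lowercase transform is a toy example.
class LowercasePreprocessor(BasePreprocessor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)  # stores kwargs on the instance

    def transform(self, *args):
        return [text.lower() for text in args]

pre = LowercasePreprocessor(language="en")
assert pre.transform("Hello", "WORLD") == ["hello", "world"]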
+ + Methods: + load_data(): Abstract method to load the raw data. + process_data_file(file_path): Abstract method to process a data file. + generate_h5_file(file_path): Abstract method to generate an HDF5 file from the loaded data. + + """ + @abstractmethod def __init__(self, data_path): + """Initialize the BaseRawDataLoader. + + Args: + data_path (str): The path to the raw data. + """ self.data_path = data_path self.attributes = dict() self.attributes["index_list"] = None @abstractmethod def load_data(self): + """Load the raw data. + + This method should be implemented by subclasses to load raw data from the specified data_path. + + Returns: + None + """ pass @abstractmethod def process_data_file(self, file_path): + """Process a data file. + + This method should be implemented by subclasses to process a specific data file. + + Args: + file_path (str): The path to the data file to be processed. + + Returns: + None + """ pass @abstractmethod def generate_h5_file(self, file_path): + """Generate an HDF5 file from the loaded data. + + This method should be implemented by subclasses to generate an HDF5 file from the loaded data. + + Args: + file_path (str): The path to the HDF5 file to be generated. + + Returns: + None + """ pass class TextClassificationRawDataLoader(BaseRawDataLoader): + """Raw data loader for text classification tasks. + + This class extends the BaseRawDataLoader and provides specific functionality for loading and processing + raw data for text classification tasks. + + Attributes: + X (dict): A dictionary to store input data. + Y (dict): A dictionary to store target labels. + attributes (dict): Additional attributes related to the loaded data, including 'num_labels', + 'label_vocab', and 'task_type' which is set to "text_classification". + + Methods: + generate_h5_file(file_path): Generate an HDF5 file from the loaded data. + + """ + def __init__(self, data_path): + """Initialize the TextClassificationRawDataLoader. + + Args: + data_path (str): The path to the raw data. + """ super(TextClassificationRawDataLoader, self).__init__(data_path) self.X = dict() self.Y = dict() @@ -36,6 +105,14 @@ def __init__(self, data_path): self.attributes["task_type"] = "text_classification" def generate_h5_file(self, file_path): + """Generate an HDF5 file from the loaded data. + + Args: + file_path (str): The path to the HDF5 file to be generated. + + Returns: + None + """ f = h5py.File(file_path, "w") f["attributes"] = json.dumps(self.attributes) for key in self.X.keys(): @@ -45,7 +122,30 @@ def generate_h5_file(self, file_path): class SpanExtractionRawDataLoader(BaseRawDataLoader): + """Raw data loader for span extraction tasks. + + This class extends the BaseRawDataLoader and provides specific functionality for loading and processing + raw data for span extraction tasks. + + Attributes: + context_X (dict): A dictionary to store context input data. + question_X (dict): A dictionary to store question input data. + Y (dict): A dictionary to store target spans. + Y_answer (dict): A dictionary to store target answers. + attributes (dict): Additional attributes related to the loaded data, including 'task_type' which is + set to "span_extraction". + + Methods: + generate_h5_file(file_path): Generate an HDF5 file from the loaded data. + + """ + def __init__(self, data_path): + """Initialize the SpanExtractionRawDataLoader. + + Args: + data_path (str): The path to the raw data. 
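# --- Illustrative sketch (reviewer annotation, not part of the patch) ---
# Reading back a file written by generate_h5_file: JSON-encoded attributes plus
# per-index datasets. The file name and the "X/0" key are hypothetical and
# assume the per-key layout used by the writers above.
import json
import h5py

with h5py.File("text_classification.h5", "r") as f:
    attributes = json.loads(f["attributes"][()])
    sample = f["X/0"][()]   # raw stored value for index 0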
+ """ super(SpanExtractionRawDataLoader, self).__init__(data_path) self.context_X = dict() self.question_X = dict() @@ -54,6 +154,14 @@ def __init__(self, data_path): self.Y_answer = dict() def generate_h5_file(self, file_path): + """Generate an HDF5 file from the loaded data. + + Args: + file_path (str): The path to the HDF5 file to be generated. + + Returns: + None + """ f = h5py.File(file_path, "w") f["attributes"] = json.dumps(self.attributes) for key in self.context_X.keys(): @@ -65,7 +173,28 @@ def generate_h5_file(self, file_path): class SeqTaggingRawDataLoader(BaseRawDataLoader): + """Raw data loader for sequence tagging tasks. + + This class extends the BaseRawDataLoader and provides specific functionality for loading and processing + raw data for sequence tagging tasks. + + Attributes: + X (dict): A dictionary to store input sequences. + Y (dict): A dictionary to store target labels. + attributes (dict): Additional attributes related to the loaded data, including 'num_labels', + 'label_vocab', and 'task_type' which is set to "seq_tagging". + + Methods: + generate_h5_file(file_path): Generate an HDF5 file from the loaded data. + + """ + def __init__(self, data_path): + """Initialize the SeqTaggingRawDataLoader. + + Args: + data_path (str): The path to the raw data. + """ super(SeqTaggingRawDataLoader, self).__init__(data_path) self.X = dict() self.Y = dict() @@ -74,6 +203,14 @@ def __init__(self, data_path): self.attributes["task_type"] = "seq_tagging" def generate_h5_file(self, file_path): + """Generate an HDF5 file from the loaded data. + + Args: + file_path (str): The path to the HDF5 file to be generated. + + Returns: + None + """ f = h5py.File(file_path, "w") f["attributes"] = json.dumps(self.attributes) utf8_type = h5py.string_dtype("utf-8", None) @@ -84,13 +221,41 @@ def generate_h5_file(self, file_path): class Seq2SeqRawDataLoader(BaseRawDataLoader): + """Raw data loader for sequence-to-sequence (seq2seq) tasks. + + This class extends the BaseRawDataLoader and provides specific functionality for loading and processing + raw data for sequence-to-sequence tasks. + + Attributes: + X (dict): A dictionary to store source sequences. + Y (dict): A dictionary to store target sequences. + task_type (str): The type of the task, which is set to "seq2seq". + + Methods: + generate_h5_file(file_path): Generate an HDF5 file from the loaded data. + + """ + def __init__(self, data_path): + """Initialize the Seq2SeqRawDataLoader. + + Args: + data_path (str): The path to the raw data. + """ super(Seq2SeqRawDataLoader, self).__init__(data_path) self.X = dict() self.Y = dict() self.task_type = "seq2seq" def generate_h5_file(self, file_path): + """Generate an HDF5 file from the loaded data. + + Args: + file_path (str): The path to the HDF5 file to be generated. + + Returns: + None + """ f = h5py.File(file_path, "w") f["attributes"] = json.dumps(self.attributes) for key in self.X.keys(): @@ -100,12 +265,39 @@ def generate_h5_file(self, file_path): class LanguageModelRawDataLoader(BaseRawDataLoader): + """Raw data loader for language modeling tasks. + + This class extends the BaseRawDataLoader and provides specific functionality for loading and processing + raw data for language modeling tasks. + + Attributes: + X (dict): A dictionary to store language model input data. + task_type (str): The type of the task, which is set to "lm". + + Methods: + generate_h5_file(file_path): Generate an HDF5 file from the loaded data. 
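# --- Illustrative sketch (reviewer annotation, not part of the patch) ---
# Storing variable-length UTF-8 strings in HDF5 with h5py.string_dtype, as the
# seq-tagging writer above does; the output file name is hypothetical.
import h5py

utf8_type = h5py.string_dtype("utf-8", None)
with h5py.File("seq_tagging.h5", "w") as f:
    f.create_dataset("X/0", data=["hello", "world"], dtype=utf8_type)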
+ + """ + def __init__(self, data_path): + """Initialize the LanguageModelRawDataLoader. + + Args: + data_path (str): The path to the raw data. + """ super(LanguageModelRawDataLoader, self).__init__(data_path) self.X = dict() self.task_type = "lm" def generate_h5_file(self, file_path): + """Generate an HDF5 file from the loaded data. + + Args: + file_path (str): The path to the HDF5 file to be generated. + + Returns: + None + """ f = h5py.File(file_path, "w") f["attributes"] = json.dumps(self.attributes) for key in tqdm(self.X.keys(), desc="generate data h5 file"): diff --git a/python/fedml/data/fednlp/base/raw_data/partition.py b/python/fedml/data/fednlp/base/raw_data/partition.py index e56e77e8bd..83d7f83634 100644 --- a/python/fedml/data/fednlp/base/raw_data/partition.py +++ b/python/fedml/data/fednlp/base/raw_data/partition.py @@ -5,6 +5,27 @@ def uniform_partition(train_index_list, test_index_list=None, n_clients=N_CLIENTS): + """Uniformly partition data indices into multiple clients. + + This function partitions a list of training data indices into 'n_clients' subsets, + ensuring a roughly equal distribution of data among clients. Optionally, it can also + partition a list of test data indices in a similar manner. + + Args: + train_index_list (list): List of training data indices. + test_index_list (list, optional): List of test data indices. Default is None. + n_clients (int): Number of clients to partition the data for. + + Returns: + dict: A dictionary containing the data partition information. + - 'n_clients': Number of clients. + - 'partition_data': A dictionary where each key represents a client ID (0 to n_clients-1), + and the value is another dictionary containing the partitioned data for that client. + For each client: + - 'train': List of training data indices. + - 'test': List of test data indices (if 'test_index_list' is provided). + + """ partition_dict = dict() partition_dict["n_clients"] = n_clients partition_dict["partition_data"] = dict() diff --git a/python/fedml/data/fednlp/base/utils.py b/python/fedml/data/fednlp/base/utils.py index 6048cfef05..2bd53a6c0e 100644 --- a/python/fedml/data/fednlp/base/utils.py +++ b/python/fedml/data/fednlp/base/utils.py @@ -18,6 +18,18 @@ class SpacyTokenizer: + """Tokenizer class for different languages using spaCy models. + + Attributes: + __zh_tokenizer: Chinese tokenizer instance. + __en_tokenizer: English tokenizer instance. + __cs_tokenizer: Czech tokenizer instance. + __de_tokenizer: German tokenizer instance. + __ru_tokenizer: Russian tokenizer instance. + + Methods: + get_tokenizer(lang): Get a spaCy tokenizer for the specified language. + """ def __init__(self): self.__zh_tokenizer = None self.__en_tokenizer = None @@ -27,6 +39,17 @@ def __init__(self): @staticmethod def get_tokenizer(lang): + """Get a spaCy tokenizer for the specified language. + + Args: + lang (str): The language code (e.g., "zh" for Chinese, "en" for English). + + Returns: + spacy.language.Language: A spaCy tokenizer instance. + + Raises: + Exception: If an unacceptable language code is provided. 
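+
+        Example (illustrative; assumes spaCy and its language data are
+        available locally):
+
+            doc = SpacyTokenizer.get_tokenizer("en")("Hello world!")
+            tokens = [t.text for t in doc]  # ["Hello", "world", "!"]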
+ """ if lang == "zh": # nlp = spacy.load("zh_core_web_sm") nlp = Chinese() @@ -46,37 +69,49 @@ def get_tokenizer(lang): @property def zh_tokenizer(self): + """Chinese tokenizer property.""" if self.__zh_tokenizer is None: self.__zh_tokenizer = self.get_tokenizer("zh") return self.__zh_tokenizer @property def en_tokenizer(self): + """English tokenizer property.""" if self.__en_tokenizer is None: self.__en_tokenizer = self.get_tokenizer("en") return self.__en_tokenizer @property def cs_tokenizer(self): + """Czech tokenizer property.""" if self.__cs_tokenizer is None: self.__cs_tokenizer = self.get_tokenizer("cs") return self.__cs_tokenizer @property def de_tokenizer(self): + """German tokenizer property.""" if self.__de_tokenizer is None: self.__de_tokenizer = self.get_tokenizer("de") return self.__de_tokenizer @property def ru_tokenizer(self): + """Russian tokenizer property.""" if self.__ru_tokenizer is None: self.__ru_tokenizer = self.get_tokenizer("ru") return self.__ru_tokenizer def build_vocab(x): - # x -> [num_seqs, num_tokens] + """Build a vocabulary from a list of tokenized sequences. + + Args: + x (list): List of tokenized sequences, where each sequence is a list of tokens. + + Returns: + dict: A vocabulary where tokens are keys and their corresponding indices are values. + """ vocab = dict() for single_x in x: for token in single_x: @@ -88,6 +123,14 @@ def build_vocab(x): def build_freq_vocab(x): + """Build a frequency-based vocabulary from a list of tokenized sequences. + + Args: + x (list): List of tokenized sequences, where each sequence is a list of tokens. + + Returns: + dict: A vocabulary where tokens are keys and their frequencies are values. + """ freq_vocab = dict() for single_x in x: for token in single_x: @@ -99,6 +142,16 @@ def build_freq_vocab(x): def padding_data(x, max_sequence_length): + """Pad sequences in a list to a specified maximum sequence length. + + Args: + x (list): List of sequences, where each sequence is a list of tokens. + max_sequence_length (int): The desired maximum sequence length for padding. + + Returns: + list: Padded sequences with a length of max_sequence_length. + list: Sequence lengths before padding. + """ padding_x = [] seq_lens = [] for single_x in x: @@ -115,6 +168,17 @@ def padding_data(x, max_sequence_length): def padding_char_data(x, max_sequence_length, max_word_length): + """Pad character-level sequences in a list to specified maximum lengths. + + Args: + x (list): List of sequences, where each sequence is a list of character tokens. + max_sequence_length (int): The desired maximum sequence length for padding. + max_word_length (int): The desired maximum word length for character tokens. + + Returns: + list: Padded character sequences with specified word and sequence lengths. + list: Word lengths before padding. + """ padding_x = [] word_lens = [] for sent in x: @@ -142,6 +206,15 @@ def padding_char_data(x, max_sequence_length, max_word_length): def token_to_idx(x, vocab): + """Convert tokenized sequences to indices using a vocabulary. + + Args: + x (list): List of tokenized sequences, where each sequence is a list of tokens. + vocab (dict): A vocabulary where tokens are keys and their corresponding indices are values. + + Returns: + list: Sequences with tokens replaced by their corresponding indices. + """ idx_x = [] for single_x in x: new_single_x = [] @@ -247,6 +320,15 @@ def NER_data_formatter(ner_data): def generate_h5_from_dict(file_name, data_dict): + """Generate an HDF5 file from a nested dictionary. 
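+
+    Example (an illustrative sketch; nested dict keys become nested HDF5
+    groups, with leaves stored as datasets):
+
+        generate_h5_from_dict("out.h5", {"train": {"0": ["a", "b"]}})
+        # writes a dataset at path /train/0 inside out.h5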
+ + Args: + file_name (str): The name of the HDF5 file to be created. + data_dict (dict): The nested dictionary containing data to be stored in the HDF5 file. + + Returns: + None + """ def dict_to_h5_recursive(h5_file, path, dic): for key, value in dic.items(): if isinstance(value, dict): @@ -270,6 +352,14 @@ def dict_to_h5_recursive(h5_file, path, dic): def decode_data_from_h5(data): + """Decode data from bytes to UTF-8 string if necessary. + + Args: + data (bytes or any): The input data, which may be in bytes. + + Returns: + str or any: The decoded data as a UTF-8 string, or the input data if it's not in bytes. + """ if isinstance(data, bytes): return data.decode("utf8") return data From f6705353d14f638c183b0ecc1aec22da410c5d9d Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Mon, 18 Sep 2023 12:58:19 +0530 Subject: [PATCH 62/70] fg --- python/fedml/data/cifar10/data_loader.py | 159 +++++++++++++- python/fedml/data/cifar10/datasets.py | 56 ++++- python/fedml/data/cifar10/efficient_loader.py | 200 +++++++++++++++++- python/fedml/data/cifar10/without_reload.py | 12 ++ python/fedml/data/cifar100/data_loader.py | 126 ++++++++++- python/fedml/data/cifar100/datasets.py | 44 +++- 6 files changed, 564 insertions(+), 33 deletions(-) diff --git a/python/fedml/data/cifar10/data_loader.py b/python/fedml/data/cifar10/data_loader.py index 459c1bbc53..cb85eb4a27 100644 --- a/python/fedml/data/cifar10/data_loader.py +++ b/python/fedml/data/cifar10/data_loader.py @@ -11,6 +11,15 @@ def read_data_distribution( filename="./data_preprocessing/non-iid-distribution/CIFAR10/distribution.txt", ): + """ + Read data distribution from a file. + + Args: + filename (str, optional): Path to the distribution file (default: "./data_preprocessing/non-iid-distribution/CIFAR10/distribution.txt"). + + Returns: + dict: A dictionary representing the data distribution. + """ distribution = {} with open(filename, "r") as data: for x in data.readlines(): @@ -26,10 +35,18 @@ def read_data_distribution( ) return distribution - def read_net_dataidx_map( filename="./data_preprocessing/non-iid-distribution/CIFAR10/net_dataidx_map.txt", ): + """ + Read network data index map from a file. + + Args: + filename (str, optional): Path to the network data index map file (default: "./data_preprocessing/non-iid-distribution/CIFAR10/net_dataidx_map.txt"). + + Returns: + dict: A dictionary representing the network data index map. + """ net_dataidx_map = {} with open(filename, "r") as data: for x in data.readlines(): @@ -43,8 +60,17 @@ def read_net_dataidx_map( net_dataidx_map[key] = [int(i.strip()) for i in tmp_array] return net_dataidx_map - def record_net_data_stats(y_train, net_dataidx_map): + """ + Record network data statistics. + + Args: + y_train (numpy.ndarray): Labels for the training data. + net_dataidx_map (dict): Network data index map. + + Returns: + dict: A dictionary containing network data statistics. + """ net_cls_counts = {} for net_i, dataidx in net_dataidx_map.items(): @@ -54,12 +80,27 @@ def record_net_data_stats(y_train, net_dataidx_map): logging.debug("Data statistics: %s" % str(net_cls_counts)) return net_cls_counts - class Cutout(object): + """ + Apply cutout augmentation to an image. + + Args: + length (int): Length of the cutout square. + """ + def __init__(self, length): self.length = length def __call__(self, img): + """ + Apply cutout to the image. + + Args: + img (PIL.Image.Image): Input image. + + Returns: + PIL.Image.Image: Image with cutout applied. 
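+
+        Example (an illustrative sketch; note that ``_data_transforms_cifar10``
+        appends Cutout after ``ToTensor()``, so in practice the input is a
+        CHW tensor rather than a PIL image):
+
+            img = torch.rand(3, 32, 32)
+            img = Cutout(16)(img)  # zeroes one randomly placed (up to) 16x16 patch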
+ """ h, w = img.size(1), img.size(2) mask = np.ones((h, w), np.float32) y = np.random.randint(h) @@ -78,6 +119,12 @@ def __call__(self, img): def _data_transforms_cifar10(): + """ + Define data transformations for CIFAR-10 dataset. + + Returns: + tuple: A tuple of two transformations, one for training and one for validation. + """ CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124] CIFAR_STD = [0.24703233, 0.24348505, 0.26158768] @@ -104,6 +151,15 @@ def _data_transforms_cifar10(): def load_cifar10_data(datadir): + """ + Load CIFAR-10 dataset. + + Args: + datadir (str): Directory where the CIFAR-10 dataset is located. + + Returns: + tuple: A tuple containing training and testing data and labels. + """ train_transform, test_transform = _data_transforms_cifar10() cifar10_train_ds = CIFAR10_truncated( @@ -120,11 +176,24 @@ def load_cifar10_data(datadir): def partition_data(dataset, datadir, partition, n_nets, alpha): + """ + Partition the CIFAR-10 dataset for federated learning. + + Args: + dataset: Not used, included for compatibility with your code. + datadir (str): Directory where the CIFAR-10 dataset is located. + partition (str): Partitioning method, can be "homo," "hetero," or "hetero-fix." + n_nets (int): Number of clients (networks). + alpha (float): Dirichlet distribution parameter for data partitioning. + + Returns: + tuple: A tuple containing data, labels, data index map, and data statistics. + """ np.random.seed(10) logging.info("*********partition data***************") X_train, y_train, X_test, y_test = load_cifar10_data(datadir) n_train = X_train.shape[0] - # n_test = X_test.shape[0] + if partition == "homo": total_num = n_train @@ -184,6 +253,19 @@ def partition_data(dataset, datadir, partition, n_nets, alpha): # for centralized training def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None): + """ + Get data loaders for centralized training. + + Args: + dataset: Not used, included for compatibility with your code. + datadir (str): Directory where the CIFAR-10 dataset is located. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (list, optional): List of data indices to include (default: None). + + Returns: + DataLoader: Training and testing data loaders. + """ return get_dataloader_CIFAR10(datadir, train_bs, test_bs, dataidxs) @@ -191,12 +273,38 @@ def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None): def get_dataloader_test( dataset, datadir, train_bs, test_bs, dataidxs_train, dataidxs_test ): + """ + Get data loaders for testing in CIFAR-10 dataset. + + Args: + dataset: Not used, included for compatibility with your code. + datadir (str): Directory where the CIFAR-10 dataset is located. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs_train (list): List of data indices to include in the training set. + dataidxs_test (list): List of data indices to include in the testing set. + + Returns: + DataLoader: Training and testing data loaders for CIFAR-10. + """ return get_dataloader_test_CIFAR10( datadir, train_bs, test_bs, dataidxs_train, dataidxs_test ) def get_dataloader_CIFAR10(datadir, train_bs, test_bs, dataidxs=None): + """ + Get data loaders for CIFAR-10 dataset. + + Args: + datadir (str): Directory where the CIFAR-10 dataset is located. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (list, optional): List of data indices to include (default: None). 
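+
+    Example (illustrative; "./data" is a placeholder directory):
+
+        train_dl, test_dl = get_dataloader_CIFAR10("./data", 64, 64)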
+ + Returns: + DataLoader: Training and testing data loaders for CIFAR-10. + """ dl_obj = CIFAR10_truncated transform_train, transform_test = _data_transforms_cifar10() @@ -219,6 +327,19 @@ def get_dataloader_CIFAR10(datadir, train_bs, test_bs, dataidxs=None): def get_dataloader_test_CIFAR10( datadir, train_bs, test_bs, dataidxs_train=None, dataidxs_test=None ): + """ + Get data loaders for testing CIFAR-10 dataset. + + Args: + datadir (str): Directory where the CIFAR-10 dataset is located. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs_train (list, optional): List of data indices to include in the training set. + dataidxs_test (list, optional): List of data indices to include in the testing set. + + Returns: + DataLoader: Training and testing data loaders for CIFAR-10. + """ dl_obj = CIFAR10_truncated transform_train, transform_test = _data_transforms_cifar10() @@ -257,6 +378,21 @@ def load_partition_data_distributed_cifar10( client_number, batch_size, ): + """ + Load partitioned CIFAR-10 dataset for distributed training. + + Args: + process_id (int): ID of the current process. + dataset: Not used, included for compatibility with your code. + data_dir (str): Directory where the CIFAR-10 dataset is located. + partition_method (str): Partitioning method, can be "homo," "hetero," or "hetero-fix." + partition_alpha (float): Dirichlet distribution parameter for data partitioning. + client_number (int): Number of clients (networks). + batch_size (int): Batch size for training and testing. + + Returns: + tuple: A tuple containing training and testing data loaders, data statistics, and class number. + """ ( X_train, y_train, @@ -318,6 +454,21 @@ def load_partition_data_cifar10( batch_size, n_proc_in_silo=0, ): + """ + Load partitioned CIFAR-10 dataset for federated learning. + + Args: + dataset: Not used, included for compatibility with your code. + data_dir (str): Directory where the CIFAR-10 dataset is located. + partition_method (str): Partitioning method, can be "homo," "hetero," or "hetero-fix." + partition_alpha (float): Dirichlet distribution parameter for data partitioning. + client_number (int): Number of clients (networks). + batch_size (int): Batch size for training and testing. + n_proc_in_silo (int, optional): Number of processes in a silo (default: 0). + + Returns: + tuple: A tuple containing training and testing data loaders, data statistics, and class number. + """ ( X_train, y_train, diff --git a/python/fedml/data/cifar10/datasets.py b/python/fedml/data/cifar10/datasets.py index 8f64d1d76e..3df54283ca 100644 --- a/python/fedml/data/cifar10/datasets.py +++ b/python/fedml/data/cifar10/datasets.py @@ -18,17 +18,45 @@ def default_loader(path): - return pil_loader(path) + """ + Default loader function for loading images. + + Args: + path (str): Path to the image file. + Returns: + PIL.Image.Image: Loaded image in RGB format. + """ + return pil_loader(path) def pil_loader(path): + """ + Custom PIL image loader function. + + Args: + path (str): Path to the image file. + + Returns: + PIL.Image.Image: Loaded image in RGB format. + """ # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) with open(path, "rb") as f: img = Image.open(f) return img.convert("RGB") - class CIFAR10_truncated(data.Dataset): + """ + Custom dataset class for truncated CIFAR-10 data. + + Args: + root (str): Root directory where CIFAR-10 dataset is located. 
+ dataidxs (list, optional): List of data indices to include (default: None). + train (bool, optional): Whether the dataset is for training (default: True). + transform (callable, optional): Optional transform to be applied to the image (default: None). + target_transform (callable, optional): Optional transform to be applied to the target (default: None). + download (bool, optional): Whether to download the dataset if not found (default: False). + """ + def __init__( self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False, ): @@ -43,12 +71,17 @@ def __init__( self.data, self.target = self.__build_truncated_dataset__() def __build_truncated_dataset__(self): + """ + Build the truncated CIFAR-10 dataset. + + Returns: + tuple: Tuple containing data and target arrays. + """ print("download = " + str(self.download)) cifar_dataobj = CIFAR10(self.root, self.train, self.transform, self.target_transform, self.download) if self.train: - # print("train member of the class: {}".format(self.train)) - # data = cifar_dataobj.train_data + data = cifar_dataobj.data target = np.array(cifar_dataobj.targets) else: @@ -62,6 +95,12 @@ def __build_truncated_dataset__(self): return data, target def truncate_channel(self, index): + """ + Truncate channels (G and B) in the images specified by the given index. + + Args: + index (numpy.ndarray): Array of indices specifying which images to truncate. + """ for i in range(index.shape[0]): gs_index = index[i] self.data[gs_index, :, :, 1] = 0.0 @@ -73,7 +112,7 @@ def __getitem__(self, index): index (int): Index Returns: - tuple: (image, target) where target is index of the target class. + tuple: (image, target) where target is the index of the target class. """ img, target = self.data[index], self.target[index] @@ -86,4 +125,11 @@ def __getitem__(self, index): return img, target def __len__(self): + """ + Get the number of samples in the dataset. + + Returns: + int: Number of samples in the dataset. + """ return len(self.data) + \ No newline at end of file diff --git a/python/fedml/data/cifar10/efficient_loader.py b/python/fedml/data/cifar10/efficient_loader.py index d86edbb753..8751c7293b 100644 --- a/python/fedml/data/cifar10/efficient_loader.py +++ b/python/fedml/data/cifar10/efficient_loader.py @@ -10,7 +10,16 @@ # generate the non-IID distribution for all methods -def read_data_distribution(filename="./data_preprocessing/non-iid-distribution/CIFAR10/distribution.txt",): +def read_data_distribution(filename="./data_preprocessing/non-iid-distribution/CIFAR10/distribution.txt"): + """ + Read data distribution from a file and return it as a dictionary. + + Args: + filename (str): The path to the file containing data distribution information. + + Returns: + dict: A dictionary representing the data distribution. + """ distribution = {} with open(filename, "r") as data: for x in data.readlines(): @@ -24,8 +33,16 @@ def read_data_distribution(filename="./data_preprocessing/non-iid-distribution/C distribution[first_level_key][second_level_key] = int(tmp[1].strip().replace(",", "")) return distribution +def read_net_dataidx_map(filename="./data_preprocessing/non-iid-distribution/CIFAR10/net_dataidx_map.txt"): + """ + Read network data index mapping from a file and return it as a dictionary. + + Args: + filename (str): The path to the file containing network data index mapping information. 
-def read_net_dataidx_map(filename="./data_preprocessing/non-iid-distribution/CIFAR10/net_dataidx_map.txt",): + Returns: + dict: A dictionary representing the network data index mapping. + """ net_dataidx_map = {} with open(filename, "r") as data: for x in data.readlines(): @@ -41,6 +58,16 @@ def read_net_dataidx_map(filename="./data_preprocessing/non-iid-distribution/CIF def record_net_data_stats(y_train, net_dataidx_map): + """ + Record data statistics for each network based on network data index mapping. + + Args: + y_train (numpy.ndarray): The labels of the training data. + net_dataidx_map (dict): The network data index mapping. + + Returns: + dict: A dictionary containing data statistics for each network. + """ net_cls_counts = {} for net_i, dataidx in net_dataidx_map.items(): @@ -56,6 +83,15 @@ def __init__(self, length): self.length = length def __call__(self, img): + """ + Apply Cutout augmentation to the input image. + + Args: + img (PIL.Image): The input image. + + Returns: + PIL.Image: The image after applying Cutout augmentation. + """ h, w = img.size(1), img.size(2) mask = np.ones((h, w), np.float32) y = np.random.randint(h) @@ -74,6 +110,12 @@ def __call__(self, img): def _data_transforms_cifar10(): + """ + Define data transforms for CIFAR-10 dataset. + + Returns: + transforms.Compose: Training and validation data transforms. + """ CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124] CIFAR_STD = [0.24703233, 0.24348505, 0.26158768] @@ -89,15 +131,32 @@ def _data_transforms_cifar10(): train_transform.transforms.append(Cutout(16)) - valid_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(CIFAR_MEAN, CIFAR_STD),]) + valid_transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize(CIFAR_MEAN, CIFAR_STD),] + ) return train_transform, valid_transform def load_cifar10_data(datadir, process_id, synthetic_data_url, private_local_data, resize=32, augmentation=True, data_efficient_load=False): + """ + Load CIFAR-10 dataset with specified configurations. + + Args: + datadir (str): Directory where CIFAR-10 dataset is stored. + process_id (int): ID of the current process. + synthetic_data_url (str): URL for synthetic data (not used in the provided code). + private_local_data (bool): Whether to use private local data (not used in the provided code). + resize (int): Resize images to this size (not used in the provided code). + augmentation (bool): Perform data augmentation (not used in the provided code). + data_efficient_load (bool): Load data efficiently (not used in the provided code). + + Returns: + tuple: Tuple containing X_train, y_train, X_test, y_test, cifar10_train_ds, and cifar10_test_ds. + """ train_transform, test_transform = _data_transforms_cifar10() - is_download = True; + is_download = True if data_efficient_load: cifar10_train_ds = CIFAR10(datadir, train=True, download=True, transform=train_transform) @@ -113,11 +172,27 @@ def load_cifar10_data(datadir, process_id, synthetic_data_url, private_local_dat def partition_data(dataset, datadir, partition, n_nets, alpha, process_id, synthetic_data_url, private_local_data): + """ + Partition the CIFAR-10 dataset into subsets for federated learning. + + Args: + dataset (str): Name of the dataset (not used in the provided code). + datadir (str): Directory where CIFAR-10 dataset is stored. + partition (str): Partitioning method (homo, hetero, hetero-fix). + n_nets (int): Number of clients (networks). + alpha (float): Alpha value for partitioning (not used in the provided code). 
+ process_id (int): ID of the current process. + synthetic_data_url (str): URL for synthetic data (not used in the provided code). + private_local_data (bool): Whether to use private local data (not used in the provided code). + + Returns: + tuple: Tuple containing X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts, cifar10_train_ds, and cifar10_test_ds. + """ np.random.seed(10) logging.info("*********partition data***************") X_train, y_train, X_test, y_test, cifar10_train_ds, cifar10_test_ds = load_cifar10_data(datadir, process_id, synthetic_data_url, private_local_data) n_train = X_train.shape[0] - # n_test = X_test.shape[0] + if partition == "homo": total_num = n_train @@ -174,6 +249,22 @@ def get_dataloader( full_train_dataset=None, full_test_dataset=None, ): + """ + Get data loaders for CIFAR-10 dataset. + + Args: + dataset (str): Name of the dataset. + datadir (str): Directory where CIFAR-10 dataset is stored. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (list): List of data indices for custom data loading (default: None). + data_efficient_load (bool): Use data-efficient loading (default: False). + full_train_dataset: Full training dataset (default: None). + full_test_dataset: Full testing dataset (default: None). + + Returns: + tuple: Tuple containing training and testing data loaders. + """ return get_dataloader_CIFAR10( datadir, train_bs, @@ -184,12 +275,24 @@ def get_dataloader( full_test_dataset=full_test_dataset, ) - # for local devices def get_dataloader_test(dataset, datadir, train_bs, test_bs, dataidxs_train, dataidxs_test): + """ + Get data loaders for testing CIFAR-10 dataset on local devices. + + Args: + dataset (str): Name of the dataset. + datadir (str): Directory where CIFAR-10 dataset is stored. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs_train (list): List of training data indices. + dataidxs_test (list): List of testing data indices. + + Returns: + tuple: Tuple containing training and testing data loaders. + """ return get_dataloader_test_CIFAR10(datadir, train_bs, test_bs, dataidxs_train, dataidxs_test) - def get_dataloader_CIFAR10( datadir, train_bs, @@ -199,6 +302,21 @@ def get_dataloader_CIFAR10( full_train_dataset=None, full_test_dataset=None, ): + """ + Get data loaders for CIFAR-10 dataset. + + Args: + datadir (str): Directory where CIFAR-10 dataset is stored. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (list): List of data indices for custom data loading (default: None). + data_efficient_load (bool): Use data-efficient loading (default: False). + full_train_dataset: Full training dataset (default: None). + full_test_dataset: Full testing dataset (default: None). + + Returns: + tuple: Tuple containing training and testing data loaders. + """ transform_train, transform_test = _data_transforms_cifar10() if data_efficient_load: @@ -217,8 +335,20 @@ def get_dataloader_CIFAR10( return train_dl, test_dl - def get_dataloader_test_CIFAR10(datadir, train_bs, test_bs, dataidxs_train=None, dataidxs_test=None): + """ + Get data loaders for testing CIFAR-10 dataset on local devices. + + Args: + datadir (str): Directory where CIFAR-10 dataset is stored. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs_train (list): List of training data indices. + dataidxs_test (list): List of testing data indices. 
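+
+    Example (illustrative; the index lists are placeholders):
+
+        train_dl, test_dl = get_dataloader_test_CIFAR10(
+            "./data", 64, 64, dataidxs_train=[0, 1, 2], dataidxs_test=[3, 4])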
+ + Returns: + tuple: Tuple containing training and testing data loaders. + """ dl_obj = CIFAR10_truncated transform_train, transform_test = _data_transforms_cifar10() @@ -231,7 +361,6 @@ def get_dataloader_test_CIFAR10(datadir, train_bs, test_bs, dataidxs_train=None, return train_dl, test_dl - def load_partition_data_distributed_cifar10( process_id, dataset, @@ -242,6 +371,24 @@ def load_partition_data_distributed_cifar10( batch_size, data_efficient_load=True, ): + """ + Load partitioned CIFAR-10 data for distributed learning. + + Args: + process_id (int): ID of the current process. + dataset (str): Name of the dataset. + data_dir (str): Directory where CIFAR-10 dataset is stored. + partition_method (str): Partitioning method (homo, hetero, hetero-fix). + partition_alpha (float): Alpha value for partitioning. + client_number (int): Number of clients (networks). + batch_size (int): Batch size for training and testing. + data_efficient_load (bool): Use data-efficient loading (default: True). + + Returns: + tuple: Tuple containing training data size, global training data loader, + global testing data loader, local data size, local training data loader, + local testing data loader, and class count. + """ ( X_train, y_train, @@ -318,6 +465,28 @@ def efficient_load_partition_data_cifar10( n_proc_in_silo=0, data_efficient_load=True, ): + """ + Efficiently load partitioned CIFAR-10 data for distributed learning. + + Args: + dataset (str): Name of the dataset. + data_dir (str): Directory where CIFAR-10 dataset is stored. + partition_method (str): Partitioning method (homo, hetero, hetero-fix). + partition_alpha (float): Alpha value for partitioning. + client_number (int): Number of clients (networks). + batch_size (int): Batch size for training and testing. + process_id (int): ID of the current process (default: 0). + synthetic_data_url (str): URL for synthetic data (default: ""). + private_local_data (str): Path to private local data (default: ""). + n_proc_in_silo (int): Number of processes in the silo (default: 0). + data_efficient_load (bool): Use data-efficient loading (default: True). + + Returns: + tuple: Tuple containing training data size, global testing data size, + global training data loader, global testing data loader, dictionary of + local data sample numbers, dictionary of local training data loaders, + dictionary of local testing data loaders, and class count. 
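+
+    Example (illustrative call; the directory and hyper-parameters are
+    placeholders):
+
+        (train_data_num, test_data_num, train_data_global, test_data_global,
+         data_local_num_dict, train_data_local_dict, test_data_local_dict,
+         class_num) = efficient_load_partition_data_cifar10(
+            "cifar10", "./data", "hetero", 0.5, client_number=10, batch_size=64)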
+ """ ( X_train, y_train, @@ -327,7 +496,16 @@ def efficient_load_partition_data_cifar10( traindata_cls_counts, cifar10_train_ds, cifar10_test_ds, - ) = partition_data(dataset, data_dir, partition_method, client_number, partition_alpha, process_id, synthetic_data_url, private_local_data) + ) = partition_data( + dataset, + data_dir, + partition_method, + client_number, + partition_alpha, + process_id, + synthetic_data_url, + private_local_data, + ) class_num = len(np.unique(y_train)) logging.info("traindata_cls_counts = " + str(traindata_cls_counts)) train_data_num = sum([len(net_dataidx_map[r]) for r in range(client_number)]) @@ -382,4 +560,4 @@ def efficient_load_partition_data_cifar10( train_data_local_dict, test_data_local_dict, class_num, - ) + ) \ No newline at end of file diff --git a/python/fedml/data/cifar10/without_reload.py b/python/fedml/data/cifar10/without_reload.py index c483f9f2a1..6f60f27e22 100644 --- a/python/fedml/data/cifar10/without_reload.py +++ b/python/fedml/data/cifar10/without_reload.py @@ -28,6 +28,12 @@ def __init__(self, root, dataidxs=None, train=True, transform=None, target_trans self.data, self.targets = self.__build_truncated_dataset__() def __build_truncated_dataset__(self): + """ + Build the truncated CIFAR-10 dataset by loading data based on data indices. + + Returns: + tuple: A tuple containing the data and targets (class labels). + """ print("download = " + str(self.download)) cifar_dataobj = CIFAR10(self.root, self.train, self.transform, self.target_transform, self.download) @@ -47,6 +53,12 @@ def __build_truncated_dataset__(self): return data, targets def truncate_channel(self, index): + """ + Truncate the green and blue channels of specified images in the dataset. + + Args: + index (numpy.ndarray): An array of indices indicating which images to truncate. + """ for i in range(index.shape[0]): gs_index = index[i] self.data[gs_index, :, :, 1] = 0.0 diff --git a/python/fedml/data/cifar100/data_loader.py b/python/fedml/data/cifar100/data_loader.py index bee691bd13..96b9f8ab72 100644 --- a/python/fedml/data/cifar100/data_loader.py +++ b/python/fedml/data/cifar100/data_loader.py @@ -78,6 +78,12 @@ def __call__(self, img): def _data_transforms_cifar100(): + """ + Get data transforms for CIFAR-100 dataset. + + Returns: + tuple: A tuple containing train and validation data transforms. + """ CIFAR_MEAN = [0.5071, 0.4865, 0.4409] CIFAR_STD = [0.2673, 0.2564, 0.2762] @@ -103,6 +109,15 @@ def _data_transforms_cifar100(): return train_transform, valid_transform def load_cifar100_data(datadir): + """ + Load CIFAR-100 dataset. + + Args: + datadir (str): The directory where CIFAR-100 dataset is stored. + + Returns: + tuple: A tuple containing training data, training labels, testing data, and testing labels. + """ train_transform, test_transform = _data_transforms_cifar100() cifar100_train_ds = CIFAR100_truncated( @@ -115,13 +130,26 @@ def load_cifar100_data(datadir): X_train, y_train = cifar100_train_ds.data, cifar100_train_ds.target X_test, y_test = cifar100_test_ds.data, cifar100_test_ds.target - return (X_train, y_train, X_test, y_test) + return X_train, y_train, X_test, y_test def partition_data(dataset, datadir, partition, n_nets, alpha): + """ + Partition CIFAR-100 data for federated learning. + + Args: + dataset (str): The dataset name. + datadir (str): The directory where CIFAR-100 dataset is stored. + partition (str): The data partitioning method ("homo", "hetero", or "hetero-fix"). + n_nets (int): The number of clients (networks). 
+ alpha (float): Alpha parameter for data partitioning. + + Returns: + tuple: A tuple containing training data, training labels, testing data, testing labels, network data index map, and class counts. + """ logging.info("*********partition data***************") X_train, y_train, X_test, y_test = load_cifar100_data(datadir) n_train = X_train.shape[0] - # n_test = X_test.shape[0] + if partition == "homo": total_num = n_train @@ -179,21 +207,58 @@ def partition_data(dataset, datadir, partition, n_nets, alpha): return X_train, y_train, X_test, y_test, net_dataidx_map, traindata_cls_counts -# for centralized training + def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None): + """ + Get data loader for centralized training. + + Args: + dataset (str): The dataset name. + datadir (str): The directory where CIFAR-100 dataset is stored. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (list, optional): List of data indices to use. Defaults to None. + + Returns: + tuple: A tuple containing training data loader and testing data loader. + """ return get_dataloader_CIFAR100(datadir, train_bs, test_bs, dataidxs) - -# for local devices def get_dataloader_test( dataset, datadir, train_bs, test_bs, dataidxs_train, dataidxs_test ): + """ + Get data loader for local devices. + + Args: + dataset (str): The dataset name. + datadir (str): The directory where CIFAR-100 dataset is stored. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs_train (list): List of training data indices. + dataidxs_test (list): List of testing data indices. + + Returns: + tuple: A tuple containing training data loader and testing data loader. + """ return get_dataloader_test_CIFAR100( datadir, train_bs, test_bs, dataidxs_train, dataidxs_test ) def get_dataloader_CIFAR100(datadir, train_bs, test_bs, dataidxs=None): + """ + Get data loader for CIFAR-100 dataset. + + Args: + datadir (str): The directory where CIFAR-100 dataset is stored. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs (list, optional): List of data indices to use. Defaults to None. + + Returns: + tuple: A tuple containing training data loader and testing data loader. + """ dl_obj = CIFAR100_truncated transform_train, transform_test = _data_transforms_cifar100() @@ -216,6 +281,19 @@ def get_dataloader_CIFAR100(datadir, train_bs, test_bs, dataidxs=None): def get_dataloader_test_CIFAR100( datadir, train_bs, test_bs, dataidxs_train=None, dataidxs_test=None ): + """ + Get data loader for testing CIFAR-100 dataset. + + Args: + datadir (str): The directory where CIFAR-100 dataset is stored. + train_bs (int): Batch size for training data. + test_bs (int): Batch size for testing data. + dataidxs_train (list, optional): List of training data indices. Defaults to None. + dataidxs_test (list, optional): List of testing data indices. Defaults to None. + + Returns: + tuple: A tuple containing training data loader and testing data loader. + """ dl_obj = CIFAR100_truncated transform_train, transform_test = _data_transforms_cifar100() @@ -254,6 +332,21 @@ def load_partition_data_distributed_cifar100( client_number, batch_size, ): + """ + Load partitioned CIFAR-100 data for distributed training. + + Args: + process_id (int): The process ID. + dataset (str): The dataset name. + data_dir (str): The directory where CIFAR-100 dataset is stored. 
+ partition_method (str): The data partitioning method ("homo", "hetero", or "hetero-fix"). + partition_alpha (float): Alpha parameter for data partitioning. + client_number (int): The number of clients (networks). + batch_size (int): Batch size for training and testing data. + + Returns: + tuple: A tuple containing various data loaders and class count information. + """ ( X_train, y_train, @@ -268,7 +361,7 @@ def load_partition_data_distributed_cifar100( logging.info("traindata_cls_counts = " + str(traindata_cls_counts)) train_data_num = sum([len(net_dataidx_map[r]) for r in range(client_number)]) - # get global test data + if process_id == 0: train_data_global, test_data_global = get_dataloader( dataset, data_dir, batch_size, batch_size @@ -279,13 +372,13 @@ def load_partition_data_distributed_cifar100( test_data_local = None local_data_num = 0 else: - # get local dataset + dataidxs = net_dataidx_map[process_id - 1] local_data_num = len(dataidxs) logging.info( "rank = %d, local_sample_number = %d" % (process_id, local_data_num) ) - # training batch size = 64; algorithms batch size = 32 + train_data_local, test_data_local = get_dataloader( dataset, data_dir, batch_size, batch_size, dataidxs ) @@ -310,6 +403,21 @@ def load_partition_data_distributed_cifar100( def load_partition_data_cifar100( dataset, data_dir, partition_method, partition_alpha, client_number, batch_size ): + """ + Load and partition CIFAR-100 data for federated learning. + + Args: + dataset (str): The dataset name. + data_dir (str): The directory where CIFAR-100 dataset is stored. + partition_method (str): The data partitioning method ("homo", "hetero", or "hetero-fix"). + partition_alpha (float): Alpha parameter for data partitioning. + client_number (int): The number of clients (networks). + batch_size (int): Batch size for training and testing data. + + Returns: + tuple: A tuple containing various data loaders and class count information. + + """ ( X_train, y_train, @@ -363,4 +471,4 @@ def load_partition_data_cifar100( train_data_local_dict, test_data_local_dict, class_num, - ) + ) \ No newline at end of file diff --git a/python/fedml/data/cifar100/datasets.py b/python/fedml/data/cifar100/datasets.py index ee0b332bdc..c7a2cec84a 100644 --- a/python/fedml/data/cifar100/datasets.py +++ b/python/fedml/data/cifar100/datasets.py @@ -17,21 +17,46 @@ def default_loader(path): - return pil_loader(path) + """ + Default image loader function using PIL to open and convert an image to RGB format. + + Args: + path (str): The path to the image file. + Returns: + PIL.Image: The loaded image in RGB format. + """ + return pil_loader(path) def pil_loader(path): - # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) + """ + Image loader function using PIL to open and convert an image to RGB format. + + Args: + path (str): The path to the image file. + + Returns: + PIL.Image: The loaded image in RGB format. + """ with open(path, "rb") as f: img = Image.open(f) return img.convert("RGB") - class CIFAR100_truncated(data.Dataset): def __init__( self, root, dataidxs=None, train=True, transform=None, target_transform=None, download=False, ): + """ + Custom dataset class for truncated CIFAR-100 dataset. + Args: + root (str): The root directory where the dataset is stored. + dataidxs (list or None): List of data indices to include in the dataset. If None, includes all data. + train (bool): Indicates whether the dataset is for training (True) or testing (False). 
+ transform (callable, optional): A function/transform to apply to the data. + target_transform (callable, optional): A function/transform to apply to the target (class label). + download (bool, optional): Whether to download the dataset if not found locally. + """ self.root = root self.dataidxs = dataidxs self.train = train @@ -42,7 +67,12 @@ def __init__( self.data, self.target = self.__build_truncated_dataset__() def __build_truncated_dataset__(self): + """ + Build the truncated CIFAR-100 dataset by loading data based on data indices. + Returns: + tuple: A tuple containing the data and target (class labels). + """ cifar_dataobj = CIFAR100(self.root, self.train, self.transform, self.target_transform, self.download) data = cifar_dataobj.data @@ -55,6 +85,12 @@ def __build_truncated_dataset__(self): return data, target def truncate_channel(self, index): + """ + Truncate the green and blue channels of specified images in the dataset. + + Args: + index (numpy.ndarray): An array of indices indicating which images to truncate. + """ for i in range(index.shape[0]): gs_index = index[i] self.data[gs_index, :, :, 1] = 0.0 @@ -66,7 +102,7 @@ def __getitem__(self, index): index (int): Index Returns: - tuple: (image, target) where target is index of the target class. + tuple: (image, target) where target is the index of the target class. """ img, target = self.data[index], self.target[index] From ac91c951094d7ef28947f18d3c15f13d454bed0e Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Tue, 19 Sep 2023 13:13:50 +0530 Subject: [PATCH 63/70] update --- python/fedml/cross_silo/fedml_client.py | 28 ++ python/fedml/cross_silo/fedml_server.py | 28 ++ .../lightsecagg/lsa_fedml_aggregator.py | 213 +++++++++-- .../cross_silo/lightsecagg/lsa_fedml_api.py | 76 +++- .../lightsecagg/lsa_fedml_client_manager.py | 117 ++++++ .../lightsecagg/lsa_fedml_server_manager.py | 198 ++++++++-- .../cross_silo/secagg/sa_fedml_aggregator.py | 196 +++++++--- .../fedml/cross_silo/secagg/sa_fedml_api.py | 72 +++- .../secagg/sa_fedml_client_manager.py | 354 ++++++++++++++++-- .../secagg/sa_fedml_server_manager.py | 351 +++++++++++++++-- .../cross_silo/server/fedml_aggregator.py | 243 ++++++++++-- .../cross_silo/server/fedml_server_manager.py | 350 ++++++++++++++--- .../cross_silo/server/server_initializer.py | 27 +- 13 files changed, 1983 insertions(+), 270 deletions(-) diff --git a/python/fedml/cross_silo/fedml_client.py b/python/fedml/cross_silo/fedml_client.py index 5e009977d0..b1198997ca 100644 --- a/python/fedml/cross_silo/fedml_client.py +++ b/python/fedml/cross_silo/fedml_client.py @@ -3,6 +3,25 @@ class FedMLCrossSiloClient: + """ + Represents a client for a cross-silo federated learning setup. + + Args: + args (object): An object containing various configuration parameters. + device (torch.device): The device (e.g., 'cpu' or 'cuda') for computation. + dataset (tuple): A tuple containing dataset-related information. + model (torch.nn.Module): The PyTorch model used in federated learning. + model_trainer (ClientTrainer, optional): An optional client trainer. + + Raises: + Exception: If an unsupported federated optimizer is specified in args. + + Attributes: + None + + Methods: + run(): Placeholder method for client execution. 
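+
+    Example (an illustrative sketch; ``args``, ``device``, ``dataset`` and
+    ``model`` are assumed to be prepared by the surrounding FedML setup):
+
+        client = FedMLCrossSiloClient(args, device, dataset, model)
+        client.run()  # run() is currently a placeholder (no-op)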
+ """ def __init__(self, args, device, dataset, model, model_trainer: ClientTrainer = None): if args.federated_optimizer == "FedAvg": [ @@ -61,4 +80,13 @@ def __init__(self, args, device, dataset, model, model_trainer: ClientTrainer = raise Exception("Exception") def run(self): + """ + Placeholder method for client execution. + + Args: + None + + Returns: + None + """ pass diff --git a/python/fedml/cross_silo/fedml_server.py b/python/fedml/cross_silo/fedml_server.py index 6778469b52..97d9890c66 100644 --- a/python/fedml/cross_silo/fedml_server.py +++ b/python/fedml/cross_silo/fedml_server.py @@ -2,6 +2,25 @@ class FedMLCrossSiloServer: + """ + Represents a server for a cross-silo federated learning setup. + + Args: + args (object): An object containing various configuration parameters. + device (torch.device): The device (e.g., 'cpu' or 'cuda') for computation. + dataset (tuple): A tuple containing dataset-related information. + model (torch.nn.Module): The PyTorch model used in federated learning. + server_aggregator (ServerAggregator, optional): An optional server aggregator. + + Raises: + Exception: If an unsupported federated optimizer is specified in args. + + Attributes: + None + + Methods: + run(): Placeholder method for server execution. + """ def __init__(self, args, device, dataset, model, server_aggregator: ServerAggregator = None): if args.federated_optimizer == "FedAvg": from fedml.cross_silo.server import server_initializer @@ -65,4 +84,13 @@ def __init__(self, args, device, dataset, model, server_aggregator: ServerAggreg raise Exception("Exception") def run(self): + """ + Placeholder method for server execution. + + Args: + None + + Returns: + None + """ pass diff --git a/python/fedml/cross_silo/lightsecagg/lsa_fedml_aggregator.py b/python/fedml/cross_silo/lightsecagg/lsa_fedml_aggregator.py index 68cbd66f85..8c64fb1217 100644 --- a/python/fedml/cross_silo/lightsecagg/lsa_fedml_aggregator.py +++ b/python/fedml/cross_silo/lightsecagg/lsa_fedml_aggregator.py @@ -16,6 +16,49 @@ class LightSecAggAggregator(object): + """ + Initialize a LightSecAggAggregator for federated learning. + + Args: + train_global (Dataset): The global training dataset. + test_global (Dataset): The global test dataset. + all_train_data_num (int): The total number of training data points globally. + train_data_local_dict (dict): A dictionary of local training datasets for each client. + test_data_local_dict (dict): A dictionary of local test datasets for each client. + train_data_local_num_dict (dict): A dictionary of the number of local training data points for each client. + client_num (int): The number of client nodes participating in federated learning. + device (torch.device): The device on which the server runs. + args (argparse.Namespace): Command-line arguments and configurations. + model_trainer: An instance of the model trainer for federated learning. + + Attributes: + trainer: The model trainer for federated learning. + args (argparse.Namespace): Command-line arguments and configurations. + train_global (Dataset): The global training dataset. + test_global (Dataset): The global test dataset. + val_global: The validation dataset generated from the global test dataset. + all_train_data_num (int): The total number of training data points globally. + train_data_local_dict (dict): A dictionary of local training datasets for each client. + test_data_local_dict (dict): A dictionary of local test datasets for each client. 
+ train_data_local_num_dict (dict): A dictionary of the number of local training data points for each client. + client_num (int): The number of client nodes participating in federated learning. + device (torch.device): The device on which the server runs. + model_dict (dict): A dictionary to store the local models submitted by clients. + sample_num_dict (dict): A dictionary to store the number of samples each client used for training. + aggregate_encoded_mask_dict (dict): A dictionary to store encoded aggregate masks from clients. + flag_client_model_uploaded_dict (dict): A dictionary to track whether a client has uploaded its model. + flag_client_aggregate_encoded_mask_uploaded_dict (dict): A dictionary to track whether a client has uploaded its encoded aggregate mask. + total_dimension: The total dimension of the model's parameters. + dimensions (list): A list of dimensions for each parameter of the model. + targeted_number_active_clients (int): The targeted number of active clients for aggregation. + privacy_guarantee (int): The privacy guarantee parameter. + prime_number: The prime number used in aggregation. + precision_parameter: The precision parameter used in aggregation. + + Returns: + None + """ + def __init__( self, train_global, @@ -62,14 +105,35 @@ def __init__( self.precision_parameter = args.precision_parameter def get_global_model_params(self): + """ + Get the global model parameters from the model trainer. + + Returns: + dict: The global model parameters. + """ global_model_params = self.trainer.get_model_params() - self.dimensions, self.total_dimension = model_dimension(global_model_params) + self.dimensions, self.total_dimension = model_dimension( + global_model_params) return global_model_params def set_global_model_params(self, model_parameters): + """ + Set the global model parameters in the model trainer. + + Args: + model_parameters (dict): The global model parameters to be set. + """ self.trainer.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, sample_num): + """ + Add the locally trained model results for a client. + + Args: + index (int): The index of the client. + model_params (dict): The locally trained model parameters. + sample_num (int): The number of samples used for training. + """ logging.info("add_model. index = %d" % index) # for key in model_params.keys(): # model_params[key] = model_params[key].to(self.device) @@ -78,11 +142,24 @@ def add_local_trained_result(self, index, model_params, sample_num): self.flag_client_model_uploaded_dict[index] = True def add_local_aggregate_encoded_mask(self, index, aggregate_encoded_mask): + """ + Add the locally generated aggregate encoded mask for a client. + + Args: + index (int): The index of the client. + aggregate_encoded_mask (array): The encoded aggregate mask. + """ logging.info("add_aggregate_encoded_mask index = %d" % index) self.aggregate_encoded_mask_dict[index] = aggregate_encoded_mask self.flag_client_aggregate_encoded_mask_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check whether all clients have uploaded their local models. + + Returns: + bool: True if all clients have uploaded their models, False otherwise. + """ for idx in range(self.client_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -91,6 +168,12 @@ def check_whether_all_receive(self): return True def check_whether_all_aggregate_encoded_mask_receive(self): + """ + Check whether all clients have uploaded their aggregate encoded masks. 
+ + Returns: + bool: True if all clients have uploaded their masks, False otherwise. + """ for idx in range(self.client_num): if not self.flag_client_aggregate_encoded_mask_uploaded_dict[idx]: return False @@ -100,38 +183,60 @@ def check_whether_all_aggregate_encoded_mask_receive(self): def aggregate_mask_reconstruction(self, active_clients): """ - Recover the aggregate-mask via decoding + Recover the aggregate-mask via decoding. + + Args: + active_clients (list): List of active client indices for aggregation. + + Returns: + array: The reconstructed aggregate mask. """ d = self.total_dimension N = self.client_num U = self.targeted_number_active_clients T = self.privacy_guarantee p = self.prime_number - logging.debug("d = {}, N = {}, U = {}, T = {}, p = {}".format(d, N, U, T, p)) + logging.debug( + "d = {}, N = {}, U = {}, T = {}, p = {}".format(d, N, U, T, p)) d = int(np.ceil(float(d) / (U - T))) * (U - T) alpha_s = np.array(range(N)) + 1 beta_s = np.array(range(U)) + (N + 1) logging.info("Server starts the reconstruction of aggregate_mask") - aggregate_encoded_mask_buffer = np.zeros((U, d // (U - T)), dtype="int64") + aggregate_encoded_mask_buffer = np.zeros( + (U, d // (U - T)), dtype="int64") # logging.info( # "active_clients = {}, aggregate_encoded_mask_dict = {}".format( # active_clients, self.aggregate_encoded_mask_dict # ) # ) for i, client_idx in enumerate(active_clients): - aggregate_encoded_mask_buffer[i, :] = self.aggregate_encoded_mask_dict[client_idx] + aggregate_encoded_mask_buffer[i, + :] = self.aggregate_encoded_mask_dict[client_idx] eval_points = alpha_s[active_clients] - aggregate_mask = LCC_decoding_with_points(aggregate_encoded_mask_buffer, eval_points, beta_s, p) - logging.info("Server finish the reconstruction of aggregate_mask via LCC decoding") + aggregate_mask = LCC_decoding_with_points( + aggregate_encoded_mask_buffer, eval_points, beta_s, p) + logging.info( + "Server finish the reconstruction of aggregate_mask via LCC decoding") aggregate_mask = np.reshape(aggregate_mask, (U * (d // (U - T)), 1)) aggregate_mask = aggregate_mask[0:d] # logging.info("aggregated mask = {}".format(aggregate_mask)) return aggregate_mask def aggregate_model_reconstruction(self, active_clients_first_round, active_clients_second_round): + """ + Perform aggregate model reconstruction using encoded masks. + + Args: + active_clients_first_round (list): List of active client indices in the first round. + active_clients_second_round (list): List of active client indices in the second round. + + Returns: + dict: The averaged global model parameters after reconstruction. 
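+
+        Example (a toy sketch of the finite-field mask cancellation this
+        method performs; p = 31 stands in for ``self.prime_number``):
+
+            p = 31
+            model, mask = 7, 12
+            masked = (model + mask) % p          # what a client uploads
+            assert (masked - mask) % p == model  # subtracting the mask recovers it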
+ """ start_time = time.time() - aggregate_mask = self.aggregate_mask_reconstruction(active_clients_second_round) + aggregate_mask = self.aggregate_mask_reconstruction( + active_clients_second_round) p = self.prime_number q_bits = self.precision_parameter logging.info("Server starts the reconstruction of aggregate_model") @@ -146,7 +251,7 @@ def aggregate_model_reconstruction(self, active_clients_first_round, active_clie averaged_params[k] += local_model_params[k] cur_shape = np.shape(averaged_params[k]) d = self.dimensions[j] - cur_mask = np.array(aggregate_mask[pos : pos + d, :]) + cur_mask = np.array(aggregate_mask[pos: pos + d, :]) cur_mask = np.reshape(cur_mask, cur_shape) # Cancel out the aggregate-mask to recover the aggregate-model @@ -157,7 +262,8 @@ def aggregate_model_reconstruction(self, active_clients_first_round, active_clie # Convert the model from finite to real # logging.info("Server converts the aggregate_model from finite to tensor") # logging.info("aggregate model before transform = {}".format(averaged_params)) - averaged_params = transform_finite_to_tensor(averaged_params, p, q_bits) + averaged_params = transform_finite_to_tensor( + averaged_params, p, q_bits) # do the avg after transform for j, k in enumerate(averaged_params): @@ -188,15 +294,18 @@ def data_silo_selection(self, round_idx, client_num_in_total, client_num_per_rou """ logging.info( - "client_num_in_total = %d, client_num_per_round = %d" % (client_num_in_total, client_num_per_round) + "client_num_in_total = %d, client_num_per_round = %d" % ( + client_num_in_total, client_num_per_round) ) assert client_num_in_total >= client_num_per_round if client_num_in_total == client_num_per_round: return [i for i in range(client_num_per_round)] else: - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round - data_silo_index_list = np.random.choice(range(client_num_in_total), client_num_per_round, replace=False) + # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) + data_silo_index_list = np.random.choice( + range(client_num_in_total), client_num_per_round, replace=False) return data_silo_index_list def client_selection(self, round_idx, client_id_list_in_total, client_num_per_round): @@ -213,31 +322,78 @@ def client_selection(self, round_idx, client_id_list_in_total, client_num_per_ro """ if client_num_per_round == len(client_id_list_in_total): return client_id_list_in_total - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round - client_id_list_in_this_round = np.random.choice(client_id_list_in_total, client_num_per_round, replace=False) + # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) + client_id_list_in_this_round = np.random.choice( + client_id_list_in_total, client_num_per_round, replace=False) return client_id_list_in_this_round def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Sample a list of clients for the current training round. + + Args: + round_idx (int): Round index, starting from 0. + client_num_in_total (int): Total number of clients in the dataset. + client_num_per_round (int): The number of clients to sample for the current round. + + Returns: + list: List of sampled client indices for the current round. 
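+
+        Example (illustrative; the draw is deterministic because the seed is
+        the round index):
+
+            # round_idx=3, client_num_in_total=10, client_num_per_round=4
+            np.random.seed(3)
+            client_indexes = np.random.choice(range(10), 4, replace=False)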
+ """ if client_num_in_total == client_num_per_round: - client_indexes = [client_index for client_index in range(client_num_in_total)] + client_indexes = [ + client_index for client_index in range(client_num_in_total)] else: num_clients = min(client_num_per_round, client_num_in_total) - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round - client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False) + # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) + client_indexes = np.random.choice( + range(client_num_in_total), num_clients, replace=False) logging.info("client_indexes = %s" % str(client_indexes)) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation dataset subset. + + Args: + num_samples (int): The number of samples to include in the validation set. + + Returns: + torch.utils.data.DataLoader: DataLoader for the validation dataset subset. + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) - sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) - subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) - sample_testset = torch.utils.data.DataLoader(subset, batch_size=self.args.batch_size) + sample_indices = random.sample( + range(test_data_num), min(num_samples, test_data_num)) + subset = torch.utils.data.Subset( + self.test_global.dataset, sample_indices) + sample_testset = torch.utils.data.DataLoader( + subset, batch_size=self.args.batch_size) return sample_testset else: return self.test_global def test_on_server_for_all_clients(self, round_idx): + """ + Perform testing on the server for all clients and log the results. + + Args: + round_idx (int): Round index, starting from 0. + + This method tests the performance of the global model on both the training and testing datasets for all clients + and logs the results. It calculates and logs the training accuracy, training loss, test accuracy, and test loss. + + If the `round_idx` is a multiple of the specified `frequency_of_the_test` or it's the final round (`comm_round - 1`), + testing is performed; otherwise, it is skipped. + + The results are logged using the `wandb` library if the `enable_wandb` flag is set. + + Note: The method assumes that the `trainer` attribute has appropriate testing methods defined. 
+
+        Returns:
+            None
+        """
         # if self.trainer.test_on_the_server(
         #     self.train_data_local_dict,
         #     self.test_data_local_dict,
@@ -247,13 +403,15 @@ def test_on_server_for_all_clients(self, round_idx):
         #     return

         if round_idx % self.args.frequency_of_the_test == 0 or round_idx == self.args.comm_round - 1:
-            logging.info("################test_on_server_for_all_clients : {}".format(round_idx))
+            logging.info(
+                "################test_on_server_for_all_clients : {}".format(round_idx))
             train_num_samples = []
             train_tot_corrects = []
             train_losses = []
             for client_idx in range(self.args.client_num_in_total):
                 # train data
-                metrics = self.trainer.test(self.train_data_local_dict[client_idx], self.device, self.args)
+                metrics = self.trainer.test(
+                    self.train_data_local_dict[client_idx], self.device, self.args)
                 train_tot_correct, train_num_sample, train_loss = (
                     metrics["test_correct"],
                     metrics["test_total"],
@@ -272,7 +430,8 @@ def test_on_server_for_all_clients(self, round_idx):
             stats = {"training_acc": train_acc, "training_loss": train_loss}
             logging.info(stats)

-            mlops.log({"accuracy": round(train_acc, 4), "loss": round(train_loss, 4)})
+            mlops.log({"accuracy": round(train_acc, 4),
+                       "loss": round(train_loss, 4)})

             # test data
             test_num_samples = []
             test_tot_corrects = []
             test_losses = []

             if round_idx == self.args.comm_round - 1:
-                metrics = self.trainer.test(self.test_global, self.device, self.args)
+                metrics = self.trainer.test(
+                    self.test_global, self.device, self.args)
             else:
-                metrics = self.trainer.test(self.val_global, self.device, self.args)
+                metrics = self.trainer.test(
+                    self.val_global, self.device, self.args)

             test_tot_correct, test_num_sample, test_loss = (
                 metrics["test_correct"],
diff --git a/python/fedml/cross_silo/lightsecagg/lsa_fedml_api.py b/python/fedml/cross_silo/lightsecagg/lsa_fedml_api.py
index b50f342c62..9f7b99c9fa 100644
--- a/python/fedml/cross_silo/lightsecagg/lsa_fedml_api.py
+++ b/python/fedml/cross_silo/lightsecagg/lsa_fedml_api.py
@@ -8,6 +8,33 @@
 def FedML_LSA_Horizontal(
     args, client_rank, client_num, comm, device, dataset, model, model_trainer=None, preprocessed_sampling_lists=None,
 ):
+    """
+    Initialize and run Federated Learning with LightSecAgg (LSA) in a horizontal setup.
+
+    Args:
+        args (object): Command-line arguments and configuration.
+        client_rank (int): Rank or identifier of the current client (0 for the server).
+        client_num (int): Total number of clients participating in the federated learning.
+        comm (object): Communication backend for distributed training.
+        device (object): The device on which the training will be performed (e.g., GPU or CPU).
+        dataset (list): A list containing dataset-related information:
+            - train_data_num (int): Number of samples in the global training dataset.
+            - test_data_num (int): Number of samples in the global test dataset.
+            - train_data_global (object): Global training dataset.
+            - test_data_global (object): Global test dataset.
+            - train_data_local_num_dict (dict): Dictionary mapping client indices to the number of local training samples.
+            - train_data_local_dict (dict): Dictionary mapping client indices to their local training dataset.
+            - test_data_local_dict (dict): Dictionary mapping client indices to their local test dataset.
+            - class_num (int): Number of classes in the dataset.
+        model (object): The federated learning model to be trained.
+        model_trainer (object, optional): The model trainer responsible for training and testing. If not provided,
+            it will be created based on the model and args.
+        preprocessed_sampling_lists (list, optional): Preprocessed client sampling lists. If provided, the server will
+            use these preprocessed sampling lists during initialization.
+
+    Returns:
+        None
+    """
     [
         train_data_num,
         test_data_num,
@@ -67,6 +94,29 @@ def init_server(
     model_trainer,
     preprocessed_sampling_lists=None,
 ):
+    """
+    Initialize the server for Federated Learning with LightSecAgg (LSA) in a horizontal setup.
+
+    Args:
+        args (object): Command-line arguments and configuration.
+        device (object): The device on which the training will be performed (e.g., GPU or CPU).
+        comm (object): Communication backend for distributed training.
+        client_rank (int): Rank or identifier of the server (0 for the server).
+        client_num (int): Total number of clients participating in the federated learning.
+        model (object): The federated learning model to be trained.
+        train_data_num (int): Number of samples in the global training dataset.
+        train_data_global (object): Global training dataset.
+        test_data_global (object): Global test dataset.
+        train_data_local_dict (dict): Dictionary mapping client indices to their local training dataset.
+        test_data_local_dict (dict): Dictionary mapping client indices to their local test dataset.
+        train_data_local_num_dict (dict): Dictionary mapping client indices to the number of local training samples.
+        model_trainer (object): The model trainer responsible for training and testing.
+        preprocessed_sampling_lists (list, optional): Preprocessed client sampling lists. If provided, the server will
+            use these preprocessed sampling lists during initialization.
+
+    Returns:
+        None
+    """
     if model_trainer is None:
         model_trainer = create_model_trainer(model, args)
     model_trainer.set_id(0)
@@ -88,7 +138,8 @@ def init_server(
     # start the distributed training
     backend = args.backend
     if preprocessed_sampling_lists is None:
-        server_manager = FedMLServerManager(args, aggregator, comm, client_rank, client_num, backend)
+        server_manager = FedMLServerManager(
+            args, aggregator, comm, client_rank, client_num, backend)
     else:
         server_manager = FedMLServerManager(
             args,
@@ -117,6 +168,26 @@ def init_client(
     test_data_local_dict,
     model_trainer=None,
 ):
+    """
+    Initialize a client for Federated Learning with LightSecAgg (LSA) in a horizontal setup.
+
+    Args:
+        args (object): Command-line arguments and configuration.
+        device (object): The device on which the training will be performed (e.g., GPU or CPU).
+        comm (object): Communication backend for distributed training.
+        client_rank (int): Rank or identifier of the current client.
+        client_num (int): Total number of clients participating in the federated learning.
+        model (object): The federated learning model to be trained.
+        train_data_num (int): Number of samples in the global training dataset.
+        train_data_local_num_dict (dict): Dictionary mapping client indices to the number of local training samples.
+        train_data_local_dict (dict): Dictionary mapping client indices to their local training dataset.
+        test_data_local_dict (dict): Dictionary mapping client indices to their local test dataset.
+        model_trainer (object, optional): The model trainer responsible for training and testing. If not provided,
+            it will be created based on the model and args.
+ + Returns: + None + """ if model_trainer is None: model_trainer = create_model_trainer(model, args) model_trainer.set_id(client_rank) @@ -131,5 +202,6 @@ def init_client( args, model_trainer, ) - client_manager = FedMLClientManager(args, trainer, comm, client_rank, client_num, backend) + client_manager = FedMLClientManager( + args, trainer, comm, client_rank, client_num, backend) client_manager.run() diff --git a/python/fedml/cross_silo/lightsecagg/lsa_fedml_client_manager.py b/python/fedml/cross_silo/lightsecagg/lsa_fedml_client_manager.py index f46372e529..dcdb627c84 100644 --- a/python/fedml/cross_silo/lightsecagg/lsa_fedml_client_manager.py +++ b/python/fedml/cross_silo/lightsecagg/lsa_fedml_client_manager.py @@ -19,6 +19,17 @@ class FedMLClientManager(FedMLCommManager): def __init__(self, args, trainer, comm=None, client_rank=0, client_num=0, backend="MPI"): + """ + Initialize the FedMLClientManager. + + Args: + args (argparse.Namespace): The command-line arguments. + trainer: The trainer for the client. + comm: The communication backend. + client_rank (int): The rank of the client. + client_num (int): The total number of clients. + backend (str): The communication backend (default is "MPI"). + """ super().__init__(args, comm, client_rank, client_num, backend) self.args = args self.trainer = trainer @@ -51,6 +62,9 @@ def __init__(self, args, trainer, comm=None, client_rank=0, client_num=0, backen self.sys_stats_process = None def register_message_receive_handlers(self): + """ + Register message receive handlers for various message types. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_message_connection_ready ) @@ -74,6 +88,12 @@ def register_message_receive_handlers(self): ) def handle_message_connection_ready(self, msg_params): + """ + Handle the connection-ready message. + + Args: + msg_params (dict): Parameters of the message. + """ if not self.has_sent_online_msg: self.has_sent_online_msg = True self.send_client_status(0) @@ -81,9 +101,21 @@ def handle_message_connection_ready(self, msg_params): mlops.log_sys_perf(self.args) def handle_message_check_status(self, msg_params): + """ + Handle the check-client-status message. + + Args: + msg_params (dict): Parameters of the message. + """ self.send_client_status(0) def handle_message_init(self, msg_params): + """ + Handle the initialization message. + + Args: + msg_params (dict): Parameters of the message. + """ global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -100,6 +132,12 @@ def handle_message_init(self, msg_params): self.__offline() def handle_message_receive_encoded_mask_from_server(self, msg_params): + """ + Handle the received encoded mask from the server. + + Args: + msg_params (dict): Parameters of the message. + """ encoded_mask = msg_params.get(MyMessage.MSG_ARG_KEY_ENCODED_MASK) client_id = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_ID) # logging.info( @@ -114,6 +152,12 @@ def handle_message_receive_encoded_mask_from_server(self, msg_params): self.__train() def handle_message_receive_model_from_server(self, msg_params): + """ + Handle the received model from the server. + + Args: + msg_params (dict): Parameters of the message. 
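`register_message_receive_handlers` is, in effect, a dispatch table from message type to callback. The toy `Dispatcher` below shows the pattern outside of FedML; the class and the message constant are illustrative only.

```python
from typing import Callable, Dict

class Dispatcher:
    """Toy message dispatcher: one callback per message type."""

    def __init__(self) -> None:
        self._handlers: Dict[int, Callable[[dict], None]] = {}

    def register(self, msg_type: int, handler: Callable[[dict], None]) -> None:
        self._handlers[msg_type] = handler

    def dispatch(self, msg_type: int, msg_params: dict) -> None:
        # Raises KeyError for unregistered types, surfacing protocol bugs early.
        self._handlers[msg_type](msg_params)

MSG_TYPE_INIT = 1
dispatcher = Dispatcher()
dispatcher.register(MSG_TYPE_INIT, lambda params: print("init:", params))
dispatcher.dispatch(MSG_TYPE_INIT, {"client_index": 0})
```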
+ """ logging.info("handle_message_receive_model_from_server.") model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -130,6 +174,12 @@ def handle_message_receive_model_from_server(self, msg_params): self.__offline() def handle_message_receive_active_from_server(self, msg_params): + """ + Handle the received active clients message from the server. + + Args: + msg_params (dict): Parameters of the message. + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) # Receive the set of active client id in first round active_clients_first_round = msg_params.get(MyMessage.MSG_ARG_KEY_ACTIVE_CLIENTS) @@ -146,10 +196,20 @@ def handle_message_receive_active_from_server(self, msg_params): self.send_aggregate_encoded_mask_to_server(0, aggregate_encoded_mask) def start_training(self): + """ + Start the training process. + """ self.round_idx = 0 self.__train() def send_client_status(self, receive_id, status="ONLINE"): + """ + Send the client status to another entity. + + Args: + receive_id: The ID of the entity receiving the status. + status (str): The status to send (default is "ONLINE"). + """ logging.info("send_client_status") message = Message(MyMessage.MSG_TYPE_C2S_CLIENT_STATUS, self.client_real_id, receive_id) sys_name = platform.system() @@ -163,9 +223,23 @@ def send_client_status(self, receive_id, status="ONLINE"): self.send_message(message) def report_training_status(self, status): + """ + Report the training status to MLOps. + + Args: + status: The training status to report. + """ mlops.log_training_status(status) def send_model_to_server(self, receive_id, weights, local_sample_num): + """ + Send the model to the server. + + Args: + receive_id: The ID of the entity receiving the model. + weights: The model parameters to send. + local_sample_num: The number of local samples used for training. + """ mlops.event("comm_c2s", event_started=True, event_value=str(self.round_idx)) message = Message(MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), receive_id,) @@ -178,21 +252,49 @@ def send_model_to_server(self, receive_id, weights, local_sample_num): ) def send_encoded_mask_to_server(self, receive_id, encoded_mask): + """ + Send the encoded mask to the server. + + Args: + receive_id: The ID of the entity receiving the encoded mask. + encoded_mask: The encoded mask to send. + """ message = Message(MyMessage.MSG_TYPE_C2S_SEND_ENCODED_MASK_TO_SERVER, self.get_sender_id(), 0) message.add_params(MyMessage.MSG_ARG_KEY_ENCODED_MASK, encoded_mask) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_ID, receive_id) self.send_message(message) def send_aggregate_encoded_mask_to_server(self, receive_id, aggregate_encoded_mask): + """ + Send the aggregate encoded mask to the server. + + Args: + receive_id: The ID of the entity receiving the aggregate encoded mask. + aggregate_encoded_mask: The aggregate encoded mask to send. + """ message = Message(MyMessage.MSG_TYPE_C2S_SEND_MASK_TO_SERVER, self.get_sender_id(), receive_id) message.add_params(MyMessage.MSG_ARG_KEY_AGGREGATE_ENCODED_MASK, aggregate_encoded_mask) self.send_message(message) def add_encoded_mask(self, index, encoded_mask): + """ + Add an encoded mask to the internal dictionary. + + Args: + index: The index of the encoded mask. + encoded_mask: The encoded mask to add. 
+ """ + self.encoded_mask_dict[index] = encoded_mask self.flag_encoded_mask_dict[index] = True def check_whether_all_encoded_mask_receive(self): + """ + Check if all encoded masks have been received. + + Returns: + bool: True if all encoded masks have been received, False otherwise. + """ for idx in range(self.worker_num): if not self.flag_encoded_mask_dict[idx]: return False @@ -201,6 +303,12 @@ def check_whether_all_encoded_mask_receive(self): return True def encoded_mask_sharing(self, encoded_mask_set): + """ + Share encoded masks with other clients. + + Args: + encoded_mask_set (list): A list of encoded masks. + """ for receive_id in range(1, self.size + 1): print(receive_id) print("the size is ", self.size) @@ -213,6 +321,9 @@ def encoded_mask_sharing(self, encoded_mask_set): self.flag_encoded_mask_dict[receive_id - 1] = True def __offline(self): + """ + Perform the offline phase, including mask encoding and sharing. + """ # Encoding the local generated mask logging.info("#######Client %d offline encoding round_id = %d######" % (self.get_sender_id(), self.round_idx)) @@ -237,6 +348,9 @@ def __offline(self): logging.info("finish share") def __train(self): + """ + Perform the training for the client. + """ logging.info("#######training########### round_id = %d" % self.round_idx) mlops.event("train", event_started=True, event_value=str(self.round_idx)) @@ -262,4 +376,7 @@ def __train(self): self.send_model_to_server(0, masked_weights, local_sample_num) def run(self): + """ + Start the client's execution. + """ super().run() diff --git a/python/fedml/cross_silo/lightsecagg/lsa_fedml_server_manager.py b/python/fedml/cross_silo/lightsecagg/lsa_fedml_server_manager.py index 89269c689e..38595164ac 100644 --- a/python/fedml/cross_silo/lightsecagg/lsa_fedml_server_manager.py +++ b/python/fedml/cross_silo/lightsecagg/lsa_fedml_server_manager.py @@ -12,6 +12,8 @@ class FedMLServerManager(FedMLCommManager): + """FedML Server Manager class.""" + def __init__( self, args, @@ -23,6 +25,19 @@ def __init__( is_preprocessed=False, preprocessed_client_lists=None, ): + """ + Initialize the FedMLServerManager. + + Args: + args: Arguments for the manager. + aggregator: The aggregator for global model updates. + comm: Communication object. + client_rank: Rank of the client. + client_num: Number of clients. + backend: Communication backend. + is_preprocessed: Whether the data is preprocessed. + preprocessed_client_lists: List of preprocessed client data. + """ super().__init__(args, comm, client_rank, client_num, backend) self.args = args self.aggregator = aggregator @@ -55,6 +70,9 @@ def run(self): super().run() def send_init_msg(self): + """ + Send initialization messages to clients. + """ global_model_params = self.aggregator.get_global_model_params() client_idx_in_this_round = 0 @@ -64,9 +82,13 @@ def send_init_msg(self): ) client_idx_in_this_round += 1 - mlops.event("server.wait", event_started=True, event_value=str(self.round_idx)) + mlops.event("server.wait", event_started=True, + event_value=str(self.round_idx)) def register_message_receive_handlers(self): + """ + Register message receive handlers. + """ print("register_message_receive_handlers------") self.register_message_receive_handler( MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_messag_connection_ready @@ -89,11 +111,19 @@ def register_message_receive_handlers(self): ) def handle_messag_connection_ready(self, msg_params): + """ + Handle the 'connection is ready' message. + + Args: + msg_params: Parameters of the message. 
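`add_encoded_mask` and `check_whether_all_encoded_mask_receive` implement a collect-then-barrier step: buffer one encoded mask per peer and proceed only once all have arrived. A stand-alone sketch of the same bookkeeping (the `MaskBuffer` name is illustrative):

```python
class MaskBuffer:
    """Collect one encoded mask per peer; signal when the set is complete."""

    def __init__(self, worker_num):
        self.worker_num = worker_num
        self.masks = {}
        self.received = {i: False for i in range(worker_num)}

    def add(self, index, encoded_mask):
        self.masks[index] = encoded_mask
        self.received[index] = True

    def all_received(self):
        if not all(self.received.values()):
            return False
        # Reset the flags so the buffer can be reused in the next round,
        # mirroring the reset in check_whether_all_encoded_mask_receive.
        self.received = {i: False for i in range(self.worker_num)}
        return True

buf = MaskBuffer(worker_num=2)
buf.add(0, b"mask-a")
buf.add(1, b"mask-b")
assert buf.all_received()
```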
+ """ + self.client_id_list_in_this_round = self.aggregator.client_selection( self.round_idx, self.client_real_ids, self.args.client_num_per_round ) self.data_silo_index_list = self.aggregator.data_silo_selection( - self.round_idx, self.args.client_num_in_total, len(self.client_id_list_in_this_round), + self.round_idx, self.args.client_num_in_total, len( + self.client_id_list_in_this_round), ) if not self.is_initialized: mlops.log_round_info(self.round_num, -1) @@ -107,6 +137,12 @@ def handle_messag_connection_ready(self, msg_params): client_idx_in_this_round += 1 def handle_message_client_status_update(self, msg_params): + """ + Handle client status update message. + + Args: + msg_params: Parameters of the message. + """ client_status = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_STATUS) if client_status == "ONLINE": self.client_online_mapping[str(msg_params.get_sender_id())] = True @@ -120,7 +156,8 @@ def handle_message_client_status_update(self, msg_params): break logging.info( - "sender_id = %d, all_client_is_online = %s" % (msg_params.get_sender_id(), str(all_client_is_online)) + "sender_id = %d, all_client_is_online = %s" % ( + msg_params.get_sender_id(), str(all_client_is_online)) ) if all_client_is_online: @@ -129,12 +166,25 @@ def handle_message_client_status_update(self, msg_params): self.is_initialized = True def handle_message_receive_encoded_mask_from_client(self, msg_params): + """ + Handle received encoded mask from client. + + Args: + msg_params: Parameters of the message. + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) receive_id = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_ID) encoded_mask = msg_params.get(MyMessage.MSG_ARG_KEY_ENCODED_MASK) - self.send_message_encoded_mask_to_client(sender_id, receive_id, encoded_mask) + self.send_message_encoded_mask_to_client( + sender_id, receive_id, encoded_mask) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle received model from client. + + Args: + msg_params: Parameters of the message. + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) mlops.event( "comm_c2s", event_started=False, event_value=str(self.round_idx), event_edge_id=sender_id, @@ -144,7 +194,8 @@ def handle_message_receive_model_from_client(self, msg_params): local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) self.aggregator.add_local_trained_result( - self.client_real_ids.index(sender_id), model_params, local_sample_number + self.client_real_ids.index( + sender_id), model_params, local_sample_number ) self.active_clients_first_round.append(sender_id - 1) b_all_received = self.aggregator.check_whether_all_receive() @@ -153,13 +204,23 @@ def handle_message_receive_model_from_client(self, msg_params): if b_all_received: # Specify the active clients for the first round and inform them for receiver_id in range(1, self.size + 1): - self.send_message_to_active_client(receiver_id, self.active_clients_first_round) + self.send_message_to_active_client( + receiver_id, self.active_clients_first_round) def handle_message_receive_aggregate_encoded_mask_from_client(self, msg_params): + """ + Handle received aggregate encoded mask from client. + + Args: + msg_params: Parameters of the message. 
+ """ + # Receive the aggregate of encoded masks for active clients sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) - aggregate_encoded_mask = msg_params.get(MyMessage.MSG_ARG_KEY_AGGREGATE_ENCODED_MASK) - self.aggregator.add_local_aggregate_encoded_mask(sender_id - 1, aggregate_encoded_mask) + aggregate_encoded_mask = msg_params.get( + MyMessage.MSG_ARG_KEY_AGGREGATE_ENCODED_MASK) + self.aggregator.add_local_aggregate_encoded_mask( + sender_id - 1, aggregate_encoded_mask) logging.info( "Server handle_message_receive_aggregate_mask = %d from_client = %d" % (len(aggregate_encoded_mask), sender_id) @@ -167,12 +228,14 @@ def handle_message_receive_aggregate_encoded_mask_from_client(self, msg_params): # Active clients for the second round self.active_clients_second_round.append(sender_id - 1) b_all_received = self.aggregator.check_whether_all_aggregate_encoded_mask_receive() - logging.info("Server: mask_all_received = " + str(b_all_received) + " in round_idx %d" % self.round_idx) + logging.info("Server: mask_all_received = " + + str(b_all_received) + " in round_idx %d" % self.round_idx) # TODO: add a timeout step # After receiving enough aggregate of encoded masks, server recovers the aggregate-model if b_all_received: - mlops.event("server.wait", event_started=False, event_value=str(self.round_idx)) + mlops.event("server.wait", event_started=False, + event_value=str(self.round_idx)) mlops.event( "server.agg_and_eval", event_started=True, event_value=str(self.round_idx), ) @@ -197,17 +260,20 @@ def handle_message_receive_aggregate_encoded_mask_from_client(self, msg_params): self.round_idx, self.client_real_ids, self.args.client_num_per_round ) self.data_silo_index_list = self.aggregator.data_silo_selection( - self.round_idx, self.args.client_num_in_total, len(self.client_id_list_in_this_round), + self.round_idx, self.args.client_num_in_total, len( + self.client_id_list_in_this_round), ) client_idx_in_this_round = 0 for receiver_id in self.client_id_list_in_this_round: self.send_message_sync_model_to_client( - receiver_id, global_model_params, self.data_silo_index_list[client_idx_in_this_round], + receiver_id, global_model_params, self.data_silo_index_list[ + client_idx_in_this_round], ) client_idx_in_this_round += 1 - mlops.log_aggregated_model_info(self.round_idx + 1, self.aggregated_model_url) + mlops.log_aggregated_model_info( + self.round_idx + 1, self.aggregated_model_url) self.aggregated_model_url = None # start the next round @@ -216,18 +282,24 @@ def handle_message_receive_aggregate_encoded_mask_from_client(self, msg_params): self.active_clients_second_round = [] if self.round_idx == self.round_num: - logging.info("=================TRAINING IS FINISHED!=============") + logging.info( + "=================TRAINING IS FINISHED!=============") sleep(3) self.finish() if self.is_preprocessed: mlops.log_training_finished_status() - logging.info("=============training is finished. Cleanup...============") + logging.info( + "=============training is finished. Cleanup...============") self.cleanup() else: logging.info("waiting for another round...") - mlops.event("server.wait", event_started=True, event_value=str(self.round_idx)) + mlops.event("server.wait", event_started=True, + event_value=str(self.round_idx)) def cleanup(self): + """ + Cleanup the server after training. 
+        """
         client_idx_in_this_round = 0
         for client_id in self.client_id_list_in_this_round:
@@ -239,34 +311,86 @@ def cleanup(self):
             self.finish()

     def send_message_init_config(self, receive_id, global_model_params, datasilo_index):
-        message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id)
-        message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params)
-        message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index))
+        """
+        Send an initialization configuration message to a client.
+
+        Args:
+            receive_id: ID of the receiving client.
+            global_model_params: Global model parameters.
+            datasilo_index: Index of the data silo.
+        """
+        message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG,
+                          self.get_sender_id(), receive_id)
+        message.add_params(
+            MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params)
+        message.add_params(
+            MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index))
         message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, "PythonClient")
         self.send_message(message)

     def send_message_encoded_mask_to_client(self, sender_id, receive_id, encoded_mask):
-        message = Message(MyMessage.MSG_TYPE_S2C_ENCODED_MASK_TO_CLIENT, self.get_sender_id(), receive_id,)
+        """
+        Send an encoded mask to a client.
+
+        Args:
+            sender_id: ID of the sender client.
+            receive_id: ID of the receiving client.
+            encoded_mask: Encoded mask to be sent.
+        """
+        message = Message(
+            MyMessage.MSG_TYPE_S2C_ENCODED_MASK_TO_CLIENT, self.get_sender_id(), receive_id,)
         message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_ID, sender_id)
         message.add_params(MyMessage.MSG_ARG_KEY_ENCODED_MASK, encoded_mask)
         self.send_message(message)

     def send_message_check_client_status(self, receive_id, datasilo_index):
-        message = Message(MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, self.get_sender_id(), receive_id)
-        message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index))
+        """
+        Send a message to check the status of a client.
+
+        Args:
+            receive_id: ID of the receiving client.
+            datasilo_index: Index of the data silo.
+        """
+
+        message = Message(MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS,
+                          self.get_sender_id(), receive_id)
+        message.add_params(
+            MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index))
         self.send_message(message)

     def send_message_finish(self, receive_id, datasilo_index):
-        message = Message(MyMessage.MSG_TYPE_S2C_FINISH, self.get_sender_id(), receive_id)
-        message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index))
+        """
+        Send a finish message to a client.
+
+        Args:
+            receive_id: ID of the receiving client.
+            datasilo_index: Index of the data silo.
+        """
+        message = Message(MyMessage.MSG_TYPE_S2C_FINISH,
+                          self.get_sender_id(), receive_id)
+        message.add_params(
+            MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index))
         self.send_message(message)
-        logging.info(" ====================send cleanup message to {}====================".format(str(datasilo_index)))
+        logging.info(" ====================send cleanup message to {}====================".format(
+            str(datasilo_index)))

     def send_message_sync_model_to_client(self, receive_id, global_model_params, client_index):
-        logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id)
-        message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id,)
-        message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params)
-        message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index))
+        """
+        Send a synchronization message with the global model to a client.
+
+        Args:
+            receive_id: ID of the receiving client.
+            global_model_params: Global model parameters to be synchronized.
+            client_index: Index of the client.
+        """
+        logging.info(
+            "send_message_sync_model_to_client. receive_id = %d" % receive_id)
+        message = Message(
+            MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id,)
+        message.add_params(
+            MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params)
+        message.add_params(
+            MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index))
         message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, "PythonClient")
         self.send_message(message)

@@ -275,7 +399,17 @@ def send_message_sync_model_to_client(self, receive_id, global_model_params, cli
     )

     def send_message_to_active_client(self, receive_id, active_clients):
-        logging.info("Server send_message_to_active_client. receive_id = %d" % receive_id)
-        message = Message(MyMessage.MSG_TYPE_S2C_SEND_TO_ACTIVE_CLIENT, self.get_sender_id(), receive_id,)
-        message.add_params(MyMessage.MSG_ARG_KEY_ACTIVE_CLIENTS, active_clients)
+        """
+        Send a message to active clients.
+
+        Args:
+            receive_id: ID of the receiving client.
+            active_clients: List of active client IDs.
+        """
+        logging.info(
+            "Server send_message_to_active_client. receive_id = %d" % receive_id)
+        message = Message(
+            MyMessage.MSG_TYPE_S2C_SEND_TO_ACTIVE_CLIENT, self.get_sender_id(), receive_id,)
+        message.add_params(
+            MyMessage.MSG_ARG_KEY_ACTIVE_CLIENTS, active_clients)
         self.send_message(message)
diff --git a/python/fedml/cross_silo/secagg/sa_fedml_aggregator.py b/python/fedml/cross_silo/secagg/sa_fedml_aggregator.py
index c8d3c57668..46162edc1d 100644
--- a/python/fedml/cross_silo/secagg/sa_fedml_aggregator.py
+++ b/python/fedml/cross_silo/secagg/sa_fedml_aggregator.py
@@ -29,6 +29,21 @@ def __init__(
         args,
         model_trainer,
     ):
+        """
+        Initialize the secure aggregation (SecAgg) aggregator.
+
+        Args:
+            train_global: Global training data.
+            test_global: Global test data.
+            all_train_data_num: Total number of training samples.
+            train_data_local_dict: Local training data for all clients.
+            test_data_local_dict: Local test data for all clients.
+            train_data_local_num_dict: Number of local training samples for all clients.
+            client_num: Total number of clients.
+            device: Computing device (e.g., 'cuda' or 'cpu').
+            args: Command-line arguments.
+            model_trainer: Model trainer instance.
+ """ self.trainer = model_trainer self.args = args @@ -54,9 +68,12 @@ def __init__( self.privacy_guarantee = int(np.floor(args.worker_num / 2)) self.prime_number = args.prime_number self.precision_parameter = args.precision_parameter - self.public_key_others = np.empty(self.num_pk_per_user * self.args.worker_num).astype("int64") - self.b_u_SS_others = np.empty((self.args.worker_num, self.args.worker_num), dtype="int64") - self.s_sk_SS_others = np.empty((self.args.worker_num, self.args.worker_num), dtype="int64") + self.public_key_others = np.empty( + self.num_pk_per_user * self.args.worker_num).astype("int64") + self.b_u_SS_others = np.empty( + (self.args.worker_num, self.args.worker_num), dtype="int64") + self.s_sk_SS_others = np.empty( + (self.args.worker_num, self.args.worker_num), dtype="int64") for idx in range(self.client_num): self.flag_client_model_uploaded_dict[idx] = False @@ -66,14 +83,36 @@ def __init__( self.dimensions = [] def get_global_model_params(self): + """ + Get the global model parameters. + + Returns: + global_model_params: Global model parameters. + """ global_model_params = self.trainer.get_model_params() - self.dimensions, self.total_dimension = model_dimension(global_model_params) + self.dimensions, self.total_dimension = model_dimension( + global_model_params) return global_model_params def set_global_model_params(self, model_parameters): + """ + Set the global model parameters. + + Args: + model_parameters: Global model parameters to be set. + """ self.trainer.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, sample_num): + """ + Add the locally trained model and sample count from a client. + + Args: + index: Index of the client. + model_params: Locally trained model parameters. + sample_num: Number of samples used for training. + """ + logging.info("add_model. index = %d" % index) # for key in model_params.keys(): # model_params[key] = model_params[key].to(self.device) @@ -82,6 +121,12 @@ def add_local_trained_result(self, index, model_params, sample_num): self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all clients have uploaded their locally trained models. + + Returns: + bool: True if all clients have uploaded their models, False otherwise. + """ for idx in range(self.client_num): if not self.flag_client_model_uploaded_dict[idx]: return False @@ -91,7 +136,15 @@ def check_whether_all_receive(self): def aggregate_mask_reconstruction(self, active_clients, SS_rx, public_key_list): """ - Recover the aggregate-mask via decoding + Recover the aggregate-mask via decoding. + + Args: + active_clients (list): List of active client indices. + SS_rx (numpy.ndarray): Received secret shares. + public_key_list (numpy.ndarray): List of public keys. + + Returns: + numpy.ndarray: The reconstructed aggregate mask. 
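The reconstruction described above works because each self-mask is the output of a seeded PRG: once `BGW_decoding` recovers a client's seed `b_u` from the secret shares, the server can regenerate the identical mask, exactly as the `np.random.seed(b_u)` call in the next hunk does. A small demonstration of that determinism (prime and dimension are toy values):

```python
import numpy as np

def prg_mask(seed, prime, dim):
    """Regenerate a self-mask from its seed, as in mask reconstruction."""
    np.random.seed(seed)
    return np.random.randint(0, prime, size=dim).astype(int)

p, d = 32749, 6          # toy prime and flattened model dimension
b_u = 12345              # seed recovered via secret-share decoding

client_mask = prg_mask(b_u, p, d)   # what the client added
server_mask = prg_mask(b_u, p, d)   # what the server regenerates
assert np.array_equal(client_mask, server_mask)
```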
""" d = self.total_dimension T = self.privacy_guarantee @@ -102,7 +155,8 @@ def aggregate_mask_reconstruction(self, active_clients, SS_rx, public_key_list): for i in range(self.targeted_number_active_clients): if self.flag_client_model_uploaded_dict[i]: - SS_input = np.reshape(SS_rx[i, active_clients[: T + 1]], (T + 1, 1)) + SS_input = np.reshape( + SS_rx[i, active_clients[: T + 1]], (T + 1, 1)) b_u = BGW_decoding(SS_input, active_clients[: T + 1], p) np.random.seed(b_u[0][0]) mask = np.random.randint(0, p, size=d).astype(int) @@ -110,7 +164,8 @@ def aggregate_mask_reconstruction(self, active_clients, SS_rx, public_key_list): # z = np.mod(z - temp, p) else: mask = np.zeros(d, dtype="int") - SS_input = np.reshape(SS_rx[i, active_clients[: T + 1]], (T + 1, 1)) + SS_input = np.reshape( + SS_rx[i, active_clients[: T + 1]], (T + 1, 1)) s_sk_dec = BGW_decoding(SS_input, active_clients[: T + 1], p) for j in range(self.targeted_number_active_clients): s_pk_list_ = public_key_list[1, :] @@ -138,8 +193,21 @@ def aggregate_mask_reconstruction(self, active_clients, SS_rx, public_key_list): def aggregate_model_reconstruction( self, active_clients_first_round, active_clients_second_round, SS_rx, public_key_list ): + """ + Reconstruct the aggregate model using secret shares and aggregate masks. + + Args: + active_clients_first_round (list): List of active client indices in the first round. + active_clients_second_round (list): List of active client indices in the second round. + SS_rx (numpy.ndarray): Received secret shares. + public_key_list (numpy.ndarray): List of public keys. + + Returns: + dict: The reconstructed aggregate model parameters. + """ start_time = time.time() - aggregate_mask = self.aggregate_mask_reconstruction(active_clients_second_round, SS_rx, public_key_list) + aggregate_mask = self.aggregate_mask_reconstruction( + active_clients_second_round, SS_rx, public_key_list) p = self.prime_number q_bits = self.precision_parameter logging.info("Server starts the reconstruction of aggregate_model") @@ -164,9 +232,9 @@ def aggregate_model_reconstruction( cur_shape = np.shape(averaged_params[k]) d = self.dimensions[j] - #aggregate_mask = aggregate_mask.reshape((aggregate_mask.shape[0], 1)) + # aggregate_mask = aggregate_mask.reshape((aggregate_mask.shape[0], 1)) # logging.info('aggregate_mask shape = {}'.format(np.shape(aggregate_mask))) - cur_mask = np.array(aggregate_mask[pos : pos + d]) + cur_mask = np.array(aggregate_mask[pos: pos + d]) cur_mask = np.reshape(cur_mask, cur_shape) # Cancel out the aggregate-mask to recover the aggregate-model @@ -174,10 +242,11 @@ def aggregate_model_reconstruction( averaged_params[k] = np.mod(averaged_params[k], p) pos += d - # Convert the model from finite to real - logging.info("Server converts the aggregate_model from finite to tensor") - averaged_params = transform_finite_to_tensor(averaged_params, p, q_bits) + logging.info( + "Server converts the aggregate_model from finite to tensor") + averaged_params = transform_finite_to_tensor( + averaged_params, p, q_bits) # do the avg after transform for j, k in enumerate(averaged_params): w = 1 / len(active_clients_first_round) @@ -189,69 +258,107 @@ def aggregate_model_reconstruction( def data_silo_selection(self, round_idx, client_num_in_total, client_num_per_round): """ + Select a subset of clients for data siloing. 
Args: - round_idx: round index, starting from 0 - client_num_in_total: this is equal to the users in a synthetic data, - e.g., in synthetic_1_1, this value is 30 - client_num_per_round: the number of edge devices that can train + round_idx (int): Round index, starting from 0. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to select. Returns: - data_silo_index_list: e.g., when client_num_in_total = 30, client_num_in_total = 3, - this value is the form of [0, 11, 20] - + list: List of selected client indices. """ logging.info( - "client_num_in_total = %d, client_num_per_round = %d" % (client_num_in_total, client_num_per_round) + "client_num_in_total = %d, client_num_per_round = %d" % ( + client_num_in_total, client_num_per_round) ) assert client_num_in_total >= client_num_per_round if client_num_in_total == client_num_per_round: return [i for i in range(client_num_per_round)] else: - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round - data_silo_index_list = np.random.choice(range(client_num_in_total), client_num_per_round, replace=False) + # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) + data_silo_index_list = np.random.choice( + range(client_num_in_total), client_num_per_round, replace=False) return data_silo_index_list def client_selection(self, round_idx, client_id_list_in_total, client_num_per_round): """ + Select a subset of clients for training. + Args: - round_idx: round index, starting from 0 - client_id_list_in_total: this is the real edge IDs. - In MLOps, its element is real edge ID, e.g., [64, 65, 66, 67]; - in simulated mode, its element is client index starting from 1, e.g., [1, 2, 3, 4] - client_num_per_round: + round_idx (int): Round index, starting from 0. + client_id_list_in_total (list): List of real edge IDs or client indices. + client_num_per_round (int): Number of clients to select. Returns: - client_id_list_in_this_round: sampled real edge ID list, e.g., [64, 66] + list: List of selected client IDs or indices. """ if client_num_per_round == len(client_id_list_in_total): return client_id_list_in_total - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round - client_id_list_in_this_round = np.random.choice(client_id_list_in_total, client_num_per_round, replace=False) + # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) + client_id_list_in_this_round = np.random.choice( + client_id_list_in_total, client_num_per_round, replace=False) return client_id_list_in_this_round def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Randomly sample a subset of clients for training. + + Args: + round_idx (int): Round index, starting from 0. + client_num_in_total (int): Total number of clients. + client_num_per_round (int): Number of clients to select. + + Returns: + list: List of selected client indices. 
+ """ if client_num_in_total == client_num_per_round: - client_indexes = [client_index for client_index in range(client_num_in_total)] + client_indexes = [ + client_index for client_index in range(client_num_in_total)] else: num_clients = min(client_num_per_round, client_num_in_total) - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round - client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False) + # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) + client_indexes = np.random.choice( + range(client_num_in_total), num_clients, replace=False) logging.info("client_indexes = %s" % str(client_indexes)) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set. + + Args: + num_samples (int): Number of samples in the validation set. + + Returns: + DataLoader: DataLoader for the validation set. + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) - sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) - subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) - sample_testset = torch.utils.data.DataLoader(subset, batch_size=self.args.batch_size) + sample_indices = random.sample( + range(test_data_num), min(num_samples, test_data_num)) + subset = torch.utils.data.Subset( + self.test_global.dataset, sample_indices) + sample_testset = torch.utils.data.DataLoader( + subset, batch_size=self.args.batch_size) return sample_testset else: return self.test_global def test_on_server_for_all_clients(self, round_idx): + """ + Perform testing on the server for all clients. + + Args: + round_idx (int): Round index. 
+ + Returns: + None + """ # if self.trainer.test_on_the_server( # self.train_data_local_dict, # self.test_data_local_dict, @@ -261,13 +368,15 @@ def test_on_server_for_all_clients(self, round_idx): # return if round_idx % self.args.frequency_of_the_test == 0 or round_idx == self.args.comm_round - 1: - logging.info("################test_on_server_for_all_clients : {}".format(round_idx)) + logging.info( + "################test_on_server_for_all_clients : {}".format(round_idx)) train_num_samples = [] train_tot_corrects = [] train_losses = [] for client_idx in range(self.args.client_num_in_total): # train data - metrics = self.trainer.test(self.train_data_local_dict[client_idx], self.device, self.args) + metrics = self.trainer.test( + self.train_data_local_dict[client_idx], self.device, self.args) train_tot_correct, train_num_sample, train_loss = ( metrics["test_correct"], metrics["test_total"], @@ -286,7 +395,8 @@ def test_on_server_for_all_clients(self, round_idx): stats = {"training_acc": train_acc, "training_loss": train_loss} logging.info(stats) - mlops.log({"accuracy": round(train_acc, 4), "loss": round(train_loss, 4)}) + mlops.log({"accuracy": round(train_acc, 4), + "loss": round(train_loss, 4)}) # test data test_num_samples = [] @@ -294,9 +404,11 @@ def test_on_server_for_all_clients(self, round_idx): test_losses = [] if round_idx == self.args.comm_round - 1: - metrics = self.trainer.test(self.test_global, self.device, self.args) + metrics = self.trainer.test( + self.test_global, self.device, self.args) else: - metrics = self.trainer.test(self.val_global, self.device, self.args) + metrics = self.trainer.test( + self.val_global, self.device, self.args) test_tot_correct, test_num_sample, test_loss = ( metrics["test_correct"], diff --git a/python/fedml/cross_silo/secagg/sa_fedml_api.py b/python/fedml/cross_silo/secagg/sa_fedml_api.py index ba0b6cbb8f..b8f3dbd8df 100644 --- a/python/fedml/cross_silo/secagg/sa_fedml_api.py +++ b/python/fedml/cross_silo/secagg/sa_fedml_api.py @@ -8,6 +8,26 @@ def FedML_SA_Horizontal( args, client_rank, client_num, comm, device, dataset, model, model_trainer=None, preprocessed_sampling_lists=None, ): + """ + Initialize and run the Secure Aggregation-based Horizontal Federated Learning. + + This function initializes either the server or client based on the client_rank and runs + the Secure Aggregation-based Horizontal Federated Learning. + + Args: + args: Command-line arguments. + client_rank: Rank of the client. + client_num: Total number of clients. + comm: Communication backend. + device: Computing device (e.g., 'cuda' or 'cpu'). + dataset: Federated dataset containing data and metadata. + model: Federated model. + model_trainer: Model trainer instance (default: None). + preprocessed_sampling_lists: Preprocessed sampling lists (default: None). + + Returns: + None + """ [ train_data_num, test_data_num, @@ -67,6 +87,31 @@ def init_server( model_trainer, preprocessed_sampling_lists=None, ): + """ + Initialize the server for Secure Aggregation-based Horizontal Federated Learning. + + This function initializes the server for Secure Aggregation-based Horizontal Federated Learning. + + Args: + args: Command-line arguments. + device: Computing device (e.g., 'cuda' or 'cpu'). + comm: Communication backend. + client_rank: Rank of the client (server rank is 0). + client_num: Total number of clients. + model: Federated model. + train_data_num: Total number of training samples. + train_data_global: Global training data. + test_data_global: Global test data. 
+ train_data_local_dict: Local training data for all clients. + test_data_local_dict: Local test data for all clients. + train_data_local_num_dict: Number of local training samples for all clients. + model_trainer: Model trainer instance. + preprocessed_sampling_lists: Preprocessed sampling lists (default: None). + + Returns: + None + """ + if model_trainer is None: model_trainer = create_model_trainer(model, args) model_trainer.set_id(0) @@ -88,7 +133,8 @@ def init_server( # start the distributed training backend = args.backend if preprocessed_sampling_lists is None: - server_manager = FedMLServerManager(args, aggregator, comm, client_rank, client_num, backend) + server_manager = FedMLServerManager( + args, aggregator, comm, client_rank, client_num, backend) else: server_manager = FedMLServerManager( args, @@ -117,6 +163,27 @@ def init_client( test_data_local_dict, model_trainer=None, ): + """ + Initialize a client for Secure Aggregation-based Horizontal Federated Learning. + + This function initializes a client for Secure Aggregation-based Horizontal Federated Learning. + + Args: + args: Command-line arguments. + device: Computing device (e.g., 'cuda' or 'cpu'). + comm: Communication backend. + client_rank: Rank of the client. + client_num: Total number of clients. + model: Federated model. + train_data_num: Total number of training samples. + train_data_local_num_dict: Number of local training samples for all clients. + train_data_local_dict: Local training data for all clients. + test_data_local_dict: Local test data for all clients. + model_trainer: Model trainer instance (default: None). + + Returns: + None + """ if model_trainer is None: model_trainer = create_model_trainer(model, args) model_trainer.set_id(client_rank) @@ -131,5 +198,6 @@ def init_client( args, model_trainer, ) - client_manager = FedMLClientManager(args, trainer, comm, client_rank, client_num, backend) + client_manager = FedMLClientManager( + args, trainer, comm, client_rank, client_num, backend) client_manager.run() diff --git a/python/fedml/cross_silo/secagg/sa_fedml_client_manager.py b/python/fedml/cross_silo/secagg/sa_fedml_client_manager.py index 8eff9828ea..652c8f36ee 100644 --- a/python/fedml/cross_silo/secagg/sa_fedml_client_manager.py +++ b/python/fedml/cross_silo/secagg/sa_fedml_client_manager.py @@ -19,6 +19,20 @@ class FedMLClientManager(FedMLCommManager): def __init__(self, args, trainer, comm=None, client_rank=0, client_num=0, backend="MPI"): + """ + Initialize the client object. + + Args: + args: Command-line arguments passed to the client. + trainer: The trainer object responsible for training. + comm: Communication handler (optional). + client_rank: Rank of the client (optional). + client_num: Number of clients (optional). + backend: Communication backend (optional). 
+ + Returns: + None + """ super().__init__(args, comm, client_rank, client_num, backend) self.args = args self.trainer = trainer @@ -35,9 +49,12 @@ def __init__(self, args, trainer, comm=None, client_rank=0, client_num=0, backen self.privacy_guarantee = int(np.floor(args.worker_num / 2)) self.prime_number = args.prime_number self.precision_parameter = args.precision_parameter - self.public_key_others = np.empty(self.num_pk_per_user * self.worker_num).astype("int64") - self.b_u_SS_others = np.empty((self.worker_num, self.worker_num), dtype="int64") - self.s_sk_SS_others = np.empty((self.worker_num, self.worker_num), dtype="int64") + self.public_key_others = np.empty( + self.num_pk_per_user * self.worker_num).astype("int64") + self.b_u_SS_others = np.empty( + (self.worker_num, self.worker_num), dtype="int64") + self.s_sk_SS_others = np.empty( + (self.worker_num, self.worker_num), dtype="int64") self.client_real_ids = json.loads(args.client_id_list) logging.info("self.client_real_ids = {}".format(self.client_real_ids)) @@ -48,6 +65,18 @@ def __init__(self, args, trainer, comm=None, client_rank=0, client_num=0, backen self.sys_stats_process = None def register_message_receive_handlers(self): + """ + Register message receive handlers for different message types. + + This method registers handlers for various message types that the client + can receive from the server. + + Args: + self: The client instance. + + Returns: + None + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_message_connection_ready ) @@ -56,7 +85,8 @@ def register_message_receive_handlers(self): MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, self.handle_message_check_status ) - self.register_message_receive_handler(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init) + self.register_message_receive_handler( + MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init) self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.handle_message_receive_model_from_server, @@ -75,6 +105,19 @@ def register_message_receive_handlers(self): ) def handle_message_connection_ready(self, msg_params): + """ + Handle a connection-ready message from the server. + + This method handles the initial connection-ready message from the server, + sends a client status message, and logs system performance. + + Args: + self: The client instance. + msg_params: A dictionary containing message parameters. + + Returns: + None + """ if not self.has_sent_online_msg: self.has_sent_online_msg = True self.send_client_status(0) @@ -82,10 +125,37 @@ def handle_message_connection_ready(self, msg_params): mlops.log_sys_perf(self.args) def handle_message_check_status(self, msg_params): + """ + Handle a message to check the client's status. + + This method handles a message from the server to check the client's status + and responds accordingly. + + Args: + self: The client instance. + msg_params: A dictionary containing message parameters. + + Returns: + None + """ self.send_client_status(0) def handle_message_init(self, msg_params): - global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) + """ + Handle an initialization message from the server. + + This method handles an initialization message from the server, updates + the client's dataset and model, and reports the training status to MLOps. + + Args: + self: The client instance. + msg_params: A dictionary containing message parameters. 
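The `b_u_SS_others` / `s_sk_SS_others` buffers initialized above hold threshold secret shares produced by `BGW_encoding` and later recovered with `BGW_decoding`. Those helpers are not shown in this patch; the sketch below is a plain Shamir-style sharing over GF(p) with the same any-`t + 1`-of-`n` reconstruction property, as an illustration rather than the FedML implementation.

```python
import random

def share_secret(secret, n, t, p):
    """Split `secret` into n shares; any t + 1 of them reconstruct it."""
    coeffs = [secret] + [random.randrange(p) for _ in range(t)]

    def f(x):
        return sum(c * pow(x, k, p) for k, c in enumerate(coeffs)) % p

    return {i: f(i) for i in range(1, n + 1)}

def reconstruct(shares, p):
    """Lagrange interpolation at x = 0 using modular inverses."""
    secret = 0
    for i, y_i in shares.items():
        num = den = 1
        for j in shares:
            if j != i:
                num = num * (-j) % p
                den = den * (i - j) % p
        secret = (secret + y_i * num * pow(den, p - 2, p)) % p
    return secret

p = 2**31 - 1
shares = share_secret(424242, n=5, t=2, p=p)
subset = {k: shares[k] for k in (1, 3, 5)}   # any 3 of the 5 shares
assert reconstruct(subset, p) == 424242
```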
+ + Returns: + None + """ + global_model_params = msg_params.get( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) logging.info("client_index = %s" % str(client_index)) @@ -93,7 +163,8 @@ def handle_message_init(self, msg_params): # Notify MLOps with training status. self.report_training_status(MyMessage.MSG_MLOPS_CLIENT_STATUS_TRAINING) - self.dimensions, self.total_dimension = model_dimension(global_model_params) + self.dimensions, self.total_dimension = model_dimension( + global_model_params) self.trainer.update_dataset(int(client_index)) self.trainer.update_model(global_model_params) @@ -102,6 +173,19 @@ def handle_message_init(self, msg_params): self.__offline() def handle_message_receive_model_from_server(self, msg_params): + """ + Handle the reception of a model from the server. + + This method updates the client's dataset and model based on the received + model parameters and handles the completion of training if it's the last round. + + Args: + self: The client instance. + msg_params: A dictionary containing message parameters. + + Returns: + None + """ logging.info("handle_message_receive_model_from_server.") model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -116,22 +200,67 @@ def handle_message_receive_model_from_server(self, msg_params): return self.round_idx += 1 if (not self.dimensions) or (not self.total_dimension): - self.dimensions, self.total_dimension = model_dimension(model_params) + self.dimensions, self.total_dimension = model_dimension( + model_params) self.__offline() def handle_message_receive_pk_others(self, msg_params): - self.public_key_others = msg_params.get(MyMessage.MSG_ARG_KEY_PK_OTHERS) - logging.info(" self.public_key_others = {}".format( self.public_key_others)) - self.public_key_others = np.reshape(self.public_key_others, (self.num_pk_per_user, self.worker_num)) + """ + Handle the reception of public keys from other clients. + + This method handles the reception of public keys from other clients for secure aggregation. + + Args: + self: The client instance. + msg_params: A dictionary containing message parameters. + + Returns: + None + """ + + self.public_key_others = msg_params.get( + MyMessage.MSG_ARG_KEY_PK_OTHERS) + logging.info(" self.public_key_others = {}".format( + self.public_key_others)) + self.public_key_others = np.reshape( + self.public_key_others, (self.num_pk_per_user, self.worker_num)) def handle_message_receive_ss_others(self, msg_params): - self.s_sk_SS_others = msg_params.get(MyMessage.MSG_ARG_KEY_SK_SS_OTHERS).flatten() - self.b_u_SS_others = msg_params.get(MyMessage.MSG_ARG_KEY_B_SS_OTHERS).flatten() + """ + Handle the reception of encoded masks from other clients. + + This method handles the reception of encoded masks (s_sk_SS and b_u_SS) from other clients + for secure aggregation. + + Args: + self: The client instance. + msg_params: A dictionary containing message parameters. + + Returns: + None + """ + self.s_sk_SS_others = msg_params.get( + MyMessage.MSG_ARG_KEY_SK_SS_OTHERS).flatten() + self.b_u_SS_others = msg_params.get( + MyMessage.MSG_ARG_KEY_B_SS_OTHERS).flatten() self.s_pk_list = self.public_key_others[1, :] self.s_uv = np.mod(self.s_pk_list * self.my_s_sk, self.prime_number) self.__train() def handle_message_receive_active_from_server(self, msg_params): + """ + Handle the reception of active client IDs from the server. 
+ + This method handles the reception of active client IDs from the server and decides which + encoded masks to send based on active clients. + + Args: + self: The client instance. + msg_params: A dictionary containing message parameters. + + Returns: + None + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) # Receive the set of active client id in first round active_clients = msg_params.get(MyMessage.MSG_ARG_KEY_ACTIVE_CLIENTS) @@ -148,8 +277,22 @@ def handle_message_receive_active_from_server(self, msg_params): self._send_others_ss_to_server(SS_info) def send_client_status(self, receive_id, status="ONLINE"): + """ + Send a client status message to the server. + + This method sends a client status message to the server to indicate the client's status. + + Args: + self: The client instance. + receive_id: The ID of the receiving entity (usually the server). + status: The status message (default is "ONLINE"). + + Returns: + None + """ logging.info("send_client_status") - message = Message(MyMessage.MSG_TYPE_C2S_CLIENT_STATUS, self.client_real_id, receive_id) + message = Message(MyMessage.MSG_TYPE_C2S_CLIENT_STATUS, + self.client_real_id, receive_id) sys_name = platform.system() if sys_name == "Darwin": sys_name = "Mac" @@ -161,11 +304,40 @@ def send_client_status(self, receive_id, status="ONLINE"): self.send_message(message) def report_training_status(self, status): + """ + Report the training status to MLOps. + + This method reports the training status to MLOps for tracking. + + Args: + self: The client instance. + status: The training status message. + + Returns: + None + """ mlops.log_training_status(status) def send_model_to_server(self, receive_id, weights, local_sample_num): - mlops.event("comm_c2s", event_started=True, event_value=str(self.round_idx)) - message = Message(MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), receive_id,) + """ + Send the trained model to the server. + + This method sends the trained model and relevant information to the server. + + Args: + self: The client instance. + receive_id: The ID of the receiving entity (usually the server). + weights: The model parameters/weights. + local_sample_num: The number of local training samples. + + Returns: + None + """ + + mlops.event("comm_c2s", event_started=True, + event_value=str(self.round_idx)) + message = Message( + MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.get_sender_id(), receive_id,) message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights) message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num) @@ -176,6 +348,18 @@ def send_model_to_server(self, receive_id, weights, local_sample_num): ) def _send_public_key_to_sever(self, public_key): + """ + Send the public key to the server. + + This method sends the client's public key to the server for secure aggregation. + + Args: + self: The client instance. + public_key: The public key to send. + + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_PK_TO_SERVER, self.get_sender_id(), 0 ) @@ -183,6 +367,19 @@ def _send_public_key_to_sever(self, public_key): self.send_message(message) def _send_secret_share_to_sever(self, b_u_SS, s_sk_SS): + """ + Send the secret shares to the server. + + This method sends the secret shares (b_u_SS and s_sk_SS) to the server for secure aggregation. + + Args: + self: The client instance. + b_u_SS: The encoded mask (b values). + s_sk_SS: The encoded mask (s_sk values). 
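The `mask` method in the next hunk adds one PRG mask per pair of clients, with the sign decided by rank order, so every pairwise term cancels when the server sums the submissions (the per-client self-mask `b_u` is removed separately via reconstruction). The sketch below demonstrates the cancellation on toy data; the explicit `seeds` table stands in for the `s_uv` values derived from the exchanged public keys.

```python
import numpy as np

p, d, n = 2**31 - 1, 4, 3              # toy prime, dimension, client count
seeds = {(u, v): 1000 + 10 * u + v     # one shared seed per pair u < v
         for u in range(n) for v in range(u + 1, n)}

def pairwise_mask(rank):
    """Add the pair mask when this client has the higher rank, subtract otherwise."""
    total = np.zeros(d, dtype=np.int64)
    for (u, v), seed in seeds.items():
        np.random.seed(seed)
        m = np.random.randint(0, p, size=d)
        if rank == v:                  # the higher rank adds the mask
            total = (total + m) % p
        elif rank == u:                # the lower rank subtracts it
            total = (total - m) % p
    return total

# Summed over all clients, the pairwise masks vanish modulo p.
agg = sum(pairwise_mask(r) for r in range(n)) % p
assert np.array_equal(agg, np.zeros(d))
```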
+ + Returns: + None + """ message = Message( MyMessage.MSG_TYPE_C2S_SEND_SS_TO_SERVER, self.get_sender_id(), 0 ) @@ -191,12 +388,24 @@ def _send_secret_share_to_sever(self, b_u_SS, s_sk_SS): self.send_message(message) def _send_others_ss_to_server(self, ss_info): + """ + Send secret shares to the server. + + This method sends secret shares (ss_info) to the server for secure aggregation. + + Args: + self: The client instance. + ss_info: Secret shares to send. + + Returns: + None + """ # for j, k in enumerate(self.finite_w): - # if j == 0: - # logging.info("Sent from %d" % (self.rank - 1)) - # logging.info(self.finite_w[k][0]) - # break + # if j == 0: + # logging.info("Sent from %d" % (self.rank - 1)) + # logging.info(self.finite_w[k][0]) + # break message = Message( MyMessage.MSG_TYPE_C2S_SEND_SS_OTHERS_TO_SERVER, @@ -210,16 +419,42 @@ def _send_others_ss_to_server(self, ss_info): self.send_message(message) def get_model_dimension(self, weights): + """ + Get the dimensions of the model. + + This method calculates and returns the dimensions of the model based on its weights. + + Args: + self: The client instance. + weights: Model weights. + + Returns: + None + """ self.dimensions, self.total_dimension = model_dimension(weights) def mask(self, weights): + """ + Apply masking to the model weights. + + This method applies masking to the model weights to protect privacy during aggregation. + + Args: + self: The client instance. + weights: Model weights. + + Returns: + Masked model weights. + """ if (not self.dimensions) or (not self.total_dimension): - self.dimensions, self.total_dimension = self.get_model_dimension(weights) + self.dimensions, self.total_dimension = self.get_model_dimension( + weights) q_bits = self.precision_parameter self.infinite_w = copy.deepcopy(weights) - weights_finite = transform_tensor_to_finite(weights, self.prime_number, q_bits) + weights_finite = transform_tensor_to_finite( + weights, self.prime_number, q_bits) self.finite_w = copy.deepcopy(weights_finite) @@ -228,10 +463,12 @@ def mask(self, weights): for i in range(1, self.worker_num + 1): if self.rank == i: np.random.seed(self.b_u) - temp = np.random.randint(0, self.prime_number, size=d).astype(int) + temp = np.random.randint( + 0, self.prime_number, size=d).astype(int) logging.info("b for %d to %d" % (self.rank, i)) logging.info(temp) - self.local_mask = np.mod(self.local_mask + temp, self.prime_number) + self.local_mask = np.mod( + self.local_mask + temp, self.prime_number) # temp = np.zeros(d,dtype='int') elif self.rank > i: np.random.seed(self.s_uv[i - 1]) @@ -242,12 +479,14 @@ def mask(self, weights): logging.info("{},{}".format(self.rank - 1, i - 1)) # Debugging Block End # ################################## - temp = np.random.randint(0, self.prime_number, size=d).astype(int) + temp = np.random.randint( + 0, self.prime_number, size=d).astype(int) logging.info("s for %d to %d" % (self.rank, i)) logging.info(temp) # if self.rank == 1: # print '############ (seed, temp)=', self.s_uv[i-1], temp - self.local_mask = np.mod(self.local_mask + temp, self.prime_number) + self.local_mask = np.mod( + self.local_mask + temp, self.prime_number) else: np.random.seed(self.s_uv[i - 1]) ################################## @@ -257,23 +496,40 @@ def mask(self, weights): logging.info("{},{}".format(self.rank - 1, i - 1)) # Debugging Block End # ################################## - temp = -np.random.randint(0, self.prime_number, size=d).astype(int) + temp = - \ + np.random.randint(0, self.prime_number, size=d).astype(int) 
logging.info("s for %d to %d" % (self.rank, i)) logging.info(temp) # if self.rank == 1: # print '############ (seed, temp)=', self.s_uv[i-1], temp - self.local_mask = np.mod(self.local_mask + temp, self.prime_number) + self.local_mask = np.mod( + self.local_mask + temp, self.prime_number) logging.info("Client") logging.info(self.rank) - masked_weights = model_masking(weights_finite, self.dimensions, self.local_mask, self.prime_number) + masked_weights = model_masking( + weights_finite, self.dimensions, self.local_mask, self.prime_number) return masked_weights def __offline(self): + """ + Perform offline setup for secure aggregation. + + This method performs the necessary offline setup for secure aggregation, including generating + keys, secret shares, and sending them to the server. + + Args: + self: The client instance. + + Returns: + None + """ np.random.seed(self.rank) - self.sk = np.random.randint(0, self.prime_number, size=(2)).astype("int64") + self.sk = np.random.randint( + 0, self.prime_number, size=(2)).astype("int64") self.pk = my_pk_gen(self.sk, self.prime_number, 0) - self.key = np.concatenate((self.pk, self.sk)) # length=4 : c_pk, s_pk, c_sk, s_sk + # length=4 : c_pk, s_pk, c_sk, s_sk + self.key = np.concatenate((self.pk, self.sk)) self._send_public_key_to_sever(self.key[0:2]) @@ -282,8 +538,10 @@ def __offline(self): self.b_u = self.my_c_sk - self.SS_input = np.reshape(np.array([self.my_c_sk, self.my_s_sk]), (2, 1)) - self.my_SS = BGW_encoding(self.SS_input, self.worker_num, self.privacy_guarantee, self.prime_number) + self.SS_input = np.reshape( + np.array([self.my_c_sk, self.my_s_sk]), (2, 1)) + self.my_SS = BGW_encoding( + self.SS_input, self.worker_num, self.privacy_guarantee, self.prime_number) self.b_u_SS = self.my_SS[:, 0, 0].astype("int64") self.s_sk_SS = self.my_SS[:, 1, 0].astype("int64") @@ -293,14 +551,29 @@ def __offline(self): self._send_secret_share_to_sever(self.b_u_SS, self.s_sk_SS) def __train(self): - logging.info("#######training########### round_id = %d" % self.round_idx) - mlops.event("train", event_started=True, event_value=str(self.round_idx)) + """ + Perform the training for a round. + + This method initiates the training process for the current round and sends the trained model + to the server after applying masking. + + Args: + self: The client instance. + + Returns: + None + """ + logging.info("#######training########### round_id = %d" % + self.round_idx) + mlops.event("train", event_started=True, + event_value=str(self.round_idx)) weights, local_sample_num = self.trainer.train(self.round_idx) # logging.info( # "Client %d original weights = %s" % (self.get_sender_id(), weights) # ) - mlops.event("train", event_started=False, event_value=str(self.round_idx)) + mlops.event("train", event_started=False, + event_value=str(self.round_idx)) # Mask the local model masked_weights = self.mask(weights) @@ -312,4 +585,15 @@ def __train(self): self.send_model_to_server(0, masked_weights, local_sample_num) def run(self): + """ + Run the client. + + This method starts the client and its communication loop. + + Args: + self: The client instance. 
+ + Returns: + None + """ super().run() diff --git a/python/fedml/cross_silo/secagg/sa_fedml_server_manager.py b/python/fedml/cross_silo/secagg/sa_fedml_server_manager.py index 0614b7e966..a24a7aa7fc 100644 --- a/python/fedml/cross_silo/secagg/sa_fedml_server_manager.py +++ b/python/fedml/cross_silo/secagg/sa_fedml_server_manager.py @@ -23,6 +23,19 @@ def __init__( is_preprocessed=False, preprocessed_client_lists=None, ): + """ + Initialize the Federated Learning Server Manager. + + Args: + args (object): Arguments object containing configuration parameters. + aggregator (object): Federated learning aggregator. + comm (object, optional): Communication manager (default: None). + client_rank (int, optional): Rank of the client (default: 0). + client_num (int, optional): Number of clients (default: 0). + backend (str, optional): Backend for communication (default: "MQTT_S3"). + is_preprocessed (bool, optional): Whether the data is preprocessed (default: False). + preprocessed_client_lists (list, optional): List of preprocessed clients (default: None). + """ super().__init__(args, comm, client_rank, client_num, backend) self.args = args self.aggregator = aggregator @@ -46,15 +59,19 @@ def __init__( self.ss_received = 0 self.num_pk_per_user = 2 self.public_key_list = np.empty( - shape=(self.num_pk_per_user, self.targeted_number_active_clients), dtype="int64" + shape=(self.num_pk_per_user, + self.targeted_number_active_clients), dtype="int64" ) self.b_u_SS_list = np.empty( - (self.targeted_number_active_clients, self.targeted_number_active_clients), dtype="int64" + (self.targeted_number_active_clients, + self.targeted_number_active_clients), dtype="int64" ) self.s_sk_SS_list = np.empty( - (self.targeted_number_active_clients, self.targeted_number_active_clients), dtype="int64" + (self.targeted_number_active_clients, + self.targeted_number_active_clients), dtype="int64" ) - self.SS_rx = np.empty((self.targeted_number_active_clients, self.targeted_number_active_clients), dtype="int64") + self.SS_rx = np.empty((self.targeted_number_active_clients, + self.targeted_number_active_clients), dtype="int64") self.aggregated_model_url = None @@ -63,9 +80,26 @@ def __init__( self.data_silo_index_list = None def run(self): + """ + Start the Federated Learning Server Manager. + + This method starts the server manager and begins the federated learning process. + """ super().run() def send_init_msg(self): + """ + Send initialization messages to clients. + + This method sends initialization messages to all clients, providing them with the + global model parameters to start training. + + Args: + None + + Returns: + None + """ global_model_params = self.aggregator.get_global_model_params() client_idx_in_this_round = 0 @@ -75,9 +109,22 @@ def send_init_msg(self): ) client_idx_in_this_round += 1 - mlops.event("server.wait", event_started=True, event_value=str(self.round_idx)) + mlops.event("server.wait", event_started=True, + event_value=str(self.round_idx)) def register_message_receive_handlers(self): + """ + Register message receive handlers for server communication. + + This method registers various message receive handlers for different types of + communication messages received by the server. 
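+
+        Example:
+            A hypothetical extension would follow the same pattern (the
+            message type and handler below are illustrative, not part of
+            ``MyMessage``):
+
+                self.register_message_receive_handler(
+                    MyMessage.MSG_TYPE_C2S_HEARTBEAT,   # hypothetical type
+                    self.handle_message_heartbeat,      # hypothetical handler
+                )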
+ + Args: + None + + Returns: + None + """ print("register_message_receive_handlers------") self.register_message_receive_handler( MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_messag_connection_ready @@ -104,11 +151,25 @@ def register_message_receive_handlers(self): ) def handle_messag_connection_ready(self, msg_params): + """ + Handle a connection-ready message from clients. + + This function processes client connection requests and initializes necessary + parameters for the server's operation. + + Args: + self: The server instance. + msg_params (dict): A dictionary containing message parameters. + + Returns: + None + """ self.client_id_list_in_this_round = self.aggregator.client_selection( self.round_idx, self.client_real_ids, self.args.client_num_per_round ) self.data_silo_index_list = self.aggregator.data_silo_selection( - self.round_idx, self.args.client_num_in_total, len(self.client_id_list_in_this_round), + self.round_idx, self.args.client_num_in_total, len( + self.client_id_list_in_this_round), ) if not self.is_initialized: mlops.log_round_info(self.round_num, -1) @@ -122,6 +183,19 @@ def handle_messag_connection_ready(self, msg_params): client_idx_in_this_round += 1 def handle_message_client_status_update(self, msg_params): + """ + Handle a message containing client status updates. + + This function updates the server's record of client statuses and takes + appropriate actions when all clients are online. + + Args: + self: The server instance. + msg_params (dict): A dictionary containing message parameters. + + Returns: + None + """ client_status = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_STATUS) if client_status == "ONLINE": self.client_online_mapping[str(msg_params.get_sender_id())] = True @@ -135,7 +209,8 @@ def handle_message_client_status_update(self, msg_params): break logging.info( - "sender_id = %d, all_client_is_online = %s" % (msg_params.get_sender_id(), str(all_client_is_online)) + "sender_id = %d, all_client_is_online = %s" % ( + msg_params.get_sender_id(), str(all_client_is_online)) ) if all_client_is_online: @@ -144,18 +219,45 @@ def handle_message_client_status_update(self, msg_params): self.is_initialized = True def _handle_message_receive_public_key(self, msg_params): + """ + Handle the reception of public keys from clients. + + This function receives and processes public keys from active clients, + combining them for further use. + + Args: + self: The server instance. + msg_params (dict): A dictionary containing message parameters. + + Returns: + None + """ # Receive the aggregate of encoded masks for active clients sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) public_key = msg_params.get(MyMessage.MSG_ARG_KEY_PK) self.public_key_list[:, sender_id - 1] = public_key self.public_keys_received += 1 if self.public_keys_received == self.targeted_number_active_clients: - data = np.reshape(self.public_key_list, self.num_pk_per_user * self.targeted_number_active_clients) + data = np.reshape( + self.public_key_list, self.num_pk_per_user * self.targeted_number_active_clients) for i in range(self.targeted_number_active_clients): logging.info("sending data = {}".format(data)) self._send_public_key_others_to_user(i + 1, data) def _handle_message_receive_ss(self, msg_params): + """ + Handle the reception of encoded masks from clients. + + This function receives and processes encoded masks from active clients, + aggregating them for further use. + + Args: + self: The server instance. + msg_params (dict): A dictionary containing message parameters. 
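+
+        Example:
+            The redistribution performed once all shares arrive, sketched
+            with a toy 3x3 share matrix (illustrative only):
+
+                import numpy as np
+
+                b_u_SS_list = np.arange(9).reshape(3, 3)
+                # row j holds the shares produced by client j + 1;
+                # client i + 1 is later sent column i, i.e. the share that
+                # every peer generated for it
+                for i in range(3):
+                    print(i + 1, b_u_SS_list[:, i])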
+ + Returns: + None + """ # Receive the aggregate of encoded masks for active clients sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) b_u_SS = msg_params.get(MyMessage.MSG_ARG_KEY_B_SS) @@ -165,9 +267,23 @@ def _handle_message_receive_ss(self, msg_params): self.ss_received += 1 if self.ss_received == self.targeted_number_active_clients: for i in range(self.targeted_number_active_clients): - self._send_ss_others_to_user(i + 1, self.b_u_SS_list[:, i], self.s_sk_SS_list[:, i]) + self._send_ss_others_to_user( + i + 1, self.b_u_SS_list[:, i], self.s_sk_SS_list[:, i]) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle the reception of a trained model from a client. + + This function receives and processes a trained model from a client, + updating the server's records and taking appropriate actions. + + Args: + self: The server instance. + msg_params (dict): A dictionary containing message parameters. + + Returns: + None + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) mlops.event( "comm_c2s", event_started=False, event_value=str(self.round_idx), event_edge_id=sender_id, @@ -177,7 +293,8 @@ def handle_message_receive_model_from_client(self, msg_params): local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) self.aggregator.add_local_trained_result( - self.client_real_ids.index(sender_id), model_params, local_sample_number + self.client_real_ids.index( + sender_id), model_params, local_sample_number ) self.active_clients_first_round.append(sender_id - 1) b_all_received = self.aggregator.check_whether_all_receive() @@ -186,9 +303,23 @@ def handle_message_receive_model_from_client(self, msg_params): if b_all_received: # Specify the active clients for the first round and inform them for receiver_id in range(1, self.size + 1): - self._send_message_to_active_client(receiver_id, self.active_clients_first_round) + self._send_message_to_active_client( + receiver_id, self.active_clients_first_round) def _handle_message_receive_ss_others_from_client(self, msg_params): + """ + Handle the reception of encoded masks from clients in the second round. + + This function receives and processes encoded masks from clients in the + second round, and performs model aggregation and evaluation. + + Args: + self: The server instance. + msg_params (dict): A dictionary containing message parameters. 
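+
+        Example:
+            Recovering a dropped client's seed from its shares is assumed to
+            be Lagrange interpolation at x = 0 over the prime field; a
+            minimal sketch (illustrative, not the FedML decoder):
+
+                def reconstruct(points, p):
+                    # points: list of (x_i, share_i) pairs
+                    secret = 0
+                    for i, (xi, yi) in enumerate(points):
+                        num = den = 1
+                        for j, (xj, _) in enumerate(points):
+                            if i != j:
+                                num = (num * -xj) % p
+                                den = (den * (xi - xj)) % p
+                        secret = (secret + yi * num * pow(den, -1, p)) % p
+                    return secret
+
+                assert reconstruct([(1, 7), (2, 9)], 13) == 5  # f(x) = 5 + 2x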
+ + Returns: + None + """ # Receive the aggregate of encoded masks for active clients sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) ss_others = msg_params.get(MyMessage.MSG_ARG_KEY_SS_OTHERS) @@ -213,13 +344,15 @@ def _handle_message_receive_ss_others_from_client(self, msg_params): self.round_idx, self.client_real_ids, self.args.client_num_per_round ) self.data_silo_index_list = self.aggregator.data_silo_selection( - self.round_idx, self.args.client_num_in_total, len(self.client_id_list_in_this_round), + self.round_idx, self.args.client_num_in_total, len( + self.client_id_list_in_this_round), ) client_idx_in_this_round = 0 for receiver_id in self.client_id_list_in_this_round: self.send_message_sync_model_to_client( - receiver_id, global_model_params, self.data_silo_index_list[client_idx_in_this_round], + receiver_id, global_model_params, self.data_silo_index_list[ + client_idx_in_this_round], ) client_idx_in_this_round += 1 @@ -232,31 +365,50 @@ def _handle_message_receive_ss_others_from_client(self, msg_params): self.ss_received = 0 self.num_pk_per_user = 2 self.public_key_list = np.empty( - shape=(self.num_pk_per_user, self.targeted_number_active_clients), dtype="int64" + shape=(self.num_pk_per_user, + self.targeted_number_active_clients), dtype="int64" ) self.b_u_SS_list = np.empty( - (self.targeted_number_active_clients, self.targeted_number_active_clients), dtype="int64" + (self.targeted_number_active_clients, + self.targeted_number_active_clients), dtype="int64" ) self.s_sk_SS_list = np.empty( - (self.targeted_number_active_clients, self.targeted_number_active_clients), dtype="int64" + (self.targeted_number_active_clients, + self.targeted_number_active_clients), dtype="int64" ) self.SS_rx = np.empty( - (self.targeted_number_active_clients, self.targeted_number_active_clients), dtype="int64" + (self.targeted_number_active_clients, + self.targeted_number_active_clients), dtype="int64" ) if self.round_idx == self.round_num: - logging.info("=================TRAINING IS FINISHED!=============") + logging.info( + "=================TRAINING IS FINISHED!=============") sleep(3) self.finish() if self.is_preprocessed: mlops.log_training_finished_status() - logging.info("=============training is finished. Cleanup...============") + logging.info( + "=============training is finished. Cleanup...============") self.cleanup() else: logging.info("waiting for another round...") - mlops.event("server.wait", event_started=True, event_value=str(self.round_idx)) + mlops.event("server.wait", event_started=True, + event_value=str(self.round_idx)) def cleanup(self): + """ + Cleanup function to finish the training process. + + This function is responsible for cleaning up after the training process, + sending finish messages to clients, and finalizing the server's state. + + Args: + self: The server instance. + + Returns: + None + """ client_idx_in_this_round = 0 for client_id in self.client_id_list_in_this_round: @@ -268,28 +420,98 @@ def cleanup(self): self.finish() def send_message_init_config(self, receive_id, global_model_params, datasilo_index): - message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id) - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) + """ + Send an initialization configuration message to a client. + + This function sends an initialization message containing global model + parameters and other configuration details to a specific client. 
+ + Args: + self: The server instance. + receive_id (int): The ID of the receiving client. + global_model_params (dict): Global model parameters. + datasilo_index (int): The index of the data silo associated with the client. + + Returns: + None + """ + message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, + self.get_sender_id(), receive_id) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, "PythonClient") self.send_message(message) def send_message_check_client_status(self, receive_id, datasilo_index): - message = Message(MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, self.get_sender_id(), receive_id) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) + """ + Send a message to check the status of a client. + + This function sends a message to a client to check its status and readiness. + + Args: + self: The server instance. + receive_id (int): The ID of the receiving client. + datasilo_index (int): The index of the data silo associated with the client. + + Returns: + None + """ + + message = Message(MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, + self.get_sender_id(), receive_id) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) def send_message_finish(self, receive_id, datasilo_index): - message = Message(MyMessage.MSG_TYPE_S2C_FINISH, self.get_sender_id(), receive_id) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) + """ + Send a finish message to a client. + + This function sends a finish message to a client to signal the end of the + training process. + + Args: + self: The server instance. + receive_id (int): The ID of the receiving client. + datasilo_index (int): The index of the data silo associated with the client. + + Returns: + None + """ + message = Message(MyMessage.MSG_TYPE_S2C_FINISH, + self.get_sender_id(), receive_id) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) - logging.info(" ====================send cleanup message to {}====================".format(str(datasilo_index))) + logging.info(" ====================send cleanup message to {}====================".format( + str(datasilo_index))) def send_message_sync_model_to_client(self, receive_id, global_model_params, client_index): - logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id) - message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id,) - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) + """ + Send a message to synchronize the global model with a client. + + This function sends a synchronization message to a specific client, + containing the global model parameters and client index. + + Args: + self: The server instance. + receive_id (int): The ID of the receiving client. + global_model_params (dict): Global model parameters. + client_index (int): The index of the client. + + Returns: + None + """ + logging.info( + "send_message_sync_model_to_client. 
receive_id = %d" % receive_id) + message = Message( + MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id,) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, "PythonClient") self.send_message(message) @@ -298,20 +520,71 @@ def send_message_sync_model_to_client(self, receive_id, global_model_params, cli ) def _send_public_key_others_to_user(self, receive_id, public_key_other): - logging.info("Server send_message_to_active_client. receive_id = %d" % receive_id) - message = Message(MyMessage.MSG_TYPE_S2C_OTHER_PK_TO_CLIENT, self.get_sender_id(), receive_id) + """ + Send public keys to a user/client. + + This function sends public keys to a specific user/client, typically during + a secure communication setup. + + Args: + self: The server instance. + receive_id (int): The ID of the receiving user/client. + public_key_other: The public keys to send. + + Returns: + None + """ + + logging.info( + "Server send_message_to_active_client. receive_id = %d" % receive_id) + message = Message(MyMessage.MSG_TYPE_S2C_OTHER_PK_TO_CLIENT, + self.get_sender_id(), receive_id) message.add_params(MyMessage.MSG_ARG_KEY_PK_OTHERS, public_key_other) self.send_message(message) def _send_ss_others_to_user(self, receive_id, b_ss_others, sk_ss_others): - logging.info("Server send_message_to_active_client. receive_id = %d" % receive_id) - message = Message(MyMessage.MSG_TYPE_S2C_OTHER_SS_TO_CLIENT, self.get_sender_id(), receive_id) + """ + Send encoded masks to a user/client. + + This function sends encoded masks to a specific user/client, typically during + a secure communication setup. + + Args: + self: The server instance. + receive_id (int): The ID of the receiving user/client. + b_ss_others: Encoded masks (b values) to send. + sk_ss_others: Encoded masks (sk values) to send. + + Returns: + None + """ + logging.info( + "Server send_message_to_active_client. receive_id = %d" % receive_id) + message = Message(MyMessage.MSG_TYPE_S2C_OTHER_SS_TO_CLIENT, + self.get_sender_id(), receive_id) message.add_params(MyMessage.MSG_ARG_KEY_B_SS_OTHERS, b_ss_others) message.add_params(MyMessage.MSG_ARG_KEY_SK_SS_OTHERS, sk_ss_others) self.send_message(message) def _send_message_to_active_client(self, receive_id, active_clients): - logging.info("Server send_message_to_active_client. receive_id = %d" % receive_id) - message = Message(MyMessage.MSG_TYPE_S2C_ACTIVE_CLIENT_LIST, self.get_sender_id(), receive_id) - message.add_params(MyMessage.MSG_ARG_KEY_ACTIVE_CLIENTS, active_clients) + """ + Send a message to active clients. + + This function sends a message to a specific user/client containing a list of + active clients, typically during initialization. + + Args: + self: The server instance. + receive_id (int): The ID of the receiving user/client. + active_clients (list): A list of active client IDs. + + Returns: + None + """ + logging.info( + "Server send_message_to_active_client. 
receive_id = %d" % receive_id) + message = Message(MyMessage.MSG_TYPE_S2C_ACTIVE_CLIENT_LIST, + self.get_sender_id(), receive_id) + message.add_params( + MyMessage.MSG_ARG_KEY_ACTIVE_CLIENTS, active_clients) self.send_message(message) diff --git a/python/fedml/cross_silo/server/fedml_aggregator.py b/python/fedml/cross_silo/server/fedml_aggregator.py index 5d56b8c1cf..7a4c13a2af 100644 --- a/python/fedml/cross_silo/server/fedml_aggregator.py +++ b/python/fedml/cross_silo/server/fedml_aggregator.py @@ -11,6 +11,33 @@ class FedMLAggregator(object): + """ + Represents an aggregator for federated learning. + + Args: + train_global (object): The global training dataset. + test_global (object): The global testing dataset. + all_train_data_num (int): The total number of training data points. + train_data_local_dict (dict): A dictionary containing local training datasets. + test_data_local_dict (dict): A dictionary containing local testing datasets. + train_data_local_num_dict (dict): A dictionary containing the number of local training data points. + client_num (int): The number of clients. + device (torch.device): The device (e.g., 'cpu' or 'cuda') for computation. + args (object): An object containing various configuration parameters. + server_aggregator (ServerAggregator, optional): An optional server aggregator. + + Attributes: + None + + Methods: + get_global_model_params(): Get the global model parameters. + set_global_model_params(model_parameters): Set the global model parameters. + add_local_trained_result(index, model_params, sample_num): Add locally trained model results. + check_whether_all_receive(): Check if all clients have uploaded their models. + aggregate(): Aggregate model updates from clients. + assess_contribution(): Assess the contribution of clients. + """ + def __init__( self, train_global, @@ -49,23 +76,62 @@ def __init__( self.flag_client_model_uploaded_dict[idx] = False def get_global_model_params(self): + """ + Get the global model parameters. + + Args: + None + + Returns: + object: The global model parameters. + """ return self.aggregator.get_model_params() def set_global_model_params(self, model_parameters): + """ + Set the global model parameters. + + Args: + model_parameters (object): The global model parameters. + + Returns: + None + """ self.aggregator.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, sample_num): + """ + Add locally trained model results. + + Args: + index (int): The index of the client. + model_params (object): The locally trained model parameters. + sample_num (int): The number of training samples used. + + Returns: + None + """ logging.info("add_model. index = %d" % index) - # for dictionary model_params, we let the user level code to control the device + # for dictionary model_params, we let the user level code control the device if type(model_params) is not dict: - model_params = ml_engine_adapter.model_params_to_device(self.args, model_params, self.device) + model_params = ml_engine_adapter.model_params_to_device( + self.args, model_params, self.device) self.model_dict[index] = model_params self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all clients have uploaded their models. + + Args: + None + + Returns: + bool: True if all clients have uploaded their models, False otherwise. 
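+
+        Example:
+            A typical caller triggers aggregation as soon as this returns
+            True, as the message handlers in this package do (sketch):
+
+                if aggregator.check_whether_all_receive():
+                    params, model_list, idxes = aggregator.aggregate()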
+ """ logging.debug("client_num = {}".format(self.client_num)) for idx in range(self.client_num): if not self.flag_client_model_uploaded_dict[idx]: @@ -75,27 +141,44 @@ def check_whether_all_receive(self): return True def aggregate(self): + """ + Aggregate model updates from clients. + + Args: + None + + Returns: + object: The aggregated model parameters. + list: The list of models after outlier removal. + list: The list of model indexes. + """ start_time = time.time() model_list = [] for idx in range(self.client_num): - model_list.append((self.sample_num_dict[idx], self.model_dict[idx])) + model_list.append( + (self.sample_num_dict[idx], self.model_dict[idx])) # model_list is the list after outlier removal - model_list, model_list_idxes = self.aggregator.on_before_aggregation(model_list) + model_list, model_list_idxes = self.aggregator.on_before_aggregation( + model_list) Context().add(Context.KEY_CLIENT_MODEL_LIST, model_list) averaged_params = self.aggregator.aggregate(model_list) if type(averaged_params) is dict: - if len(averaged_params) == self.client_num + 1: # aggregator pass extra {-1 : global_parms_dict} as global_params - itr_count = len(averaged_params) - 1 # do not apply on_after_aggregation to client -1 + # aggregator pass extra {-1 : global_parms_dict} as global_params + if len(averaged_params) == self.client_num + 1: + # do not apply on_after_aggregation to client -1 + itr_count = len(averaged_params) - 1 else: itr_count = len(averaged_params) for client_index in range(itr_count): - averaged_params[client_index] = self.aggregator.on_after_aggregation(averaged_params[client_index]) + averaged_params[client_index] = self.aggregator.on_after_aggregation( + averaged_params[client_index]) else: - averaged_params = self.aggregator.on_after_aggregation(averaged_params) + averaged_params = self.aggregator.on_after_aggregation( + averaged_params) self.set_global_model_params(averaged_params) @@ -104,6 +187,17 @@ def aggregate(self): return averaged_params, model_list, model_list_idxes def assess_contribution(self): + """ + Assess the contribution of clients. + + If enabled, this method assesses the contribution of clients in the federated learning process. 
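+
+        Example:
+            Gated by configuration; e.g. a fedml_config.yaml entry of the
+            following form (illustrative):
+
+                enable_contribution: true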
+ + Args: + None + + Returns: + None + """ if hasattr(self.args, "enable_contribution") and \ self.args.enable_contribution is not None and self.args.enable_contribution: self.aggregator.assess_contribution() @@ -123,15 +217,18 @@ def data_silo_selection(self, round_idx, client_num_in_total, client_num_per_rou """ logging.info( - "client_num_in_total = %d, client_num_per_round = %d" % (client_num_in_total, client_num_per_round) + "client_num_in_total = %d, client_num_per_round = %d" % ( + client_num_in_total, client_num_per_round) ) assert client_num_in_total >= client_num_per_round if client_num_in_total == client_num_per_round: return [i for i in range(client_num_per_round)] else: - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round - data_silo_index_list = np.random.choice(range(client_num_in_total), client_num_per_round, replace=False) + # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) + data_silo_index_list = np.random.choice( + range(client_num_in_total), client_num_per_round, replace=False) return data_silo_index_list def client_selection(self, round_idx, client_id_list_in_total, client_num_per_round): @@ -148,33 +245,73 @@ def client_selection(self, round_idx, client_id_list_in_total, client_num_per_ro """ if client_num_per_round == len(client_id_list_in_total): return client_id_list_in_total - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round - client_id_list_in_this_round = np.random.choice(client_id_list_in_total, client_num_per_round, replace=False) + # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) + client_id_list_in_this_round = np.random.choice( + client_id_list_in_total, client_num_per_round, replace=False) return client_id_list_in_this_round def client_sampling(self, round_idx, client_num_in_total, client_num_per_round): + """ + Sample a subset of clients for a federated learning round. + + Args: + round_idx (int): The round index, starting from 0. + client_num_in_total (int): The total number of clients. + client_num_per_round (int): The number of clients to sample for the round. + + Returns: + list: A list of sampled client indexes. + + """ if client_num_in_total == client_num_per_round: - client_indexes = [client_index for client_index in range(client_num_in_total)] + client_indexes = [ + client_index for client_index in range(client_num_in_total)] else: num_clients = min(client_num_per_round, client_num_in_total) - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round - client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False) + # make sure for each comparison, we are selecting the same clients each round + np.random.seed(round_idx) + client_indexes = np.random.choice( + range(client_num_in_total), num_clients, replace=False) logging.info("client_indexes = %s" % str(client_indexes)) return client_indexes def _generate_validation_set(self, num_samples=10000): + """ + Generate a validation set. + + Args: + num_samples (int): The number of samples to include in the validation set (default is 10,000). + + Returns: + object: The validation dataset. 
+ + """ if self.args.dataset.startswith("stackoverflow"): test_data_num = len(self.test_global.dataset) - sample_indices = random.sample(range(test_data_num), min(num_samples, test_data_num)) - subset = torch.utils.data.Subset(self.test_global.dataset, sample_indices) - sample_testset = torch.utils.data.DataLoader(subset, batch_size=self.args.batch_size) + sample_indices = random.sample( + range(test_data_num), min(num_samples, test_data_num)) + subset = torch.utils.data.Subset( + self.test_global.dataset, sample_indices) + sample_testset = torch.utils.data.DataLoader( + subset, batch_size=self.args.batch_size) return sample_testset else: return self.test_global def test_on_server_for_all_clients(self, round_idx): + """ + Test the global model on all clients. + + Args: + round_idx (int): The round index. + + Returns: + None + """ if round_idx % self.args.frequency_of_the_test == 0 or round_idx == self.args.comm_round - 1: - logging.info("################test_on_server_for_all_clients : {}".format(round_idx)) + logging.info( + "################test_on_server_for_all_clients : {}".format(round_idx)) self.aggregator.test_all( self.train_data_local_dict, self.test_data_local_dict, @@ -184,25 +321,39 @@ def test_on_server_for_all_clients(self, round_idx): if round_idx == self.args.comm_round - 1: # we allow to return four metrics, such as accuracy, AUC, loss, etc. - metric_result_in_current_round = self.aggregator.test(self.test_global, self.device, self.args) + metric_result_in_current_round = self.aggregator.test( + self.test_global, self.device, self.args) else: - metric_result_in_current_round = self.aggregator.test(self.val_global, self.device, self.args) - logging.info("metric_result_in_current_round = {}".format(metric_result_in_current_round)) - metric_results_in_the_last_round = Context().get(Context.KEY_METRICS_ON_AGGREGATED_MODEL) + metric_result_in_current_round = self.aggregator.test( + self.val_global, self.device, self.args) + logging.info("metric_result_in_current_round = {}".format( + metric_result_in_current_round)) + metric_results_in_the_last_round = Context().get( + Context.KEY_METRICS_ON_AGGREGATED_MODEL) Context().add(Context.KEY_METRICS_ON_AGGREGATED_MODEL, metric_result_in_current_round) if metric_results_in_the_last_round is not None: Context().add(Context.KEY_METRICS_ON_LAST_ROUND, metric_results_in_the_last_round) else: Context().add(Context.KEY_METRICS_ON_LAST_ROUND, metric_result_in_current_round) key_metrics_on_last_round = Context().get(Context.KEY_METRICS_ON_LAST_ROUND) - logging.info("key_metrics_on_last_round = {}".format(key_metrics_on_last_round)) + logging.info("key_metrics_on_last_round = {}".format( + key_metrics_on_last_round)) if round_idx == self.args.comm_round - 1: mlops.log({"round_idx": round_idx}) else: mlops.log({"round_idx": round_idx}) - + def get_dummy_input_tensor(self): + """ + Get a dummy input tensor for testing purposes. + + This method retrieves a dummy input tensor from the test dataset. + + Returns: + list: A list of dummy input tensors. 
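+
+        Example:
+            One assumed downstream use is exporting the trained model for
+            serving (the names and output path below are illustrative):
+
+                import torch
+
+                dummy = aggregator.get_dummy_input_tensor()
+                torch.onnx.export(model, tuple(dummy), "model.onnx")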
+ + """ test_data = None if self.test_global: test_data = self.test_global @@ -210,18 +361,29 @@ def get_dummy_input_tensor(self): for k, v in self.test_data_local_dict.items(): if v: test_data = v - break - + break + with torch.no_grad(): - batch_idx, features_label_tensors = next(enumerate(test_data)) # test_data -> dataloader obj + batch_idx, features_label_tensors = next( + enumerate(test_data)) # test_data -> dataloader obj dummy_list = [] for tensor in features_label_tensors: - dummy_tensor = tensor[:1] # only take the first element as dummy input + # only take the first element as dummy input + dummy_tensor = tensor[:1] dummy_list.append(dummy_tensor) features = dummy_list[:-1] # Can adapt Process Multi-Label return features def get_input_shape_type(self): + """ + Get the input shape and type information. + + This method retrieves the input shape and type information from the test dataset. + + Returns: + tuple: A tuple containing two lists - input shape and input type. + + """ test_data = None if self.test_global: test_data = self.test_global @@ -230,12 +392,14 @@ def get_input_shape_type(self): if v: test_data = v break - + with torch.no_grad(): - batch_idx, features_label_tensors = next(enumerate(test_data)) # test_data -> dataloader obj + batch_idx, features_label_tensors = next( + enumerate(test_data)) # test_data -> dataloader obj dummy_list = [] for tensor in features_label_tensors: - dummy_tensor = tensor[:1] # only take the first element as dummy input + # only take the first element as dummy input + dummy_tensor = tensor[:1] dummy_list.append(dummy_tensor) features = dummy_list[:-1] # Can adapt Multi-Label @@ -248,10 +412,19 @@ def get_input_shape_type(self): input_type.append("int") else: input_type.append("float") - + return input_shape, input_type - + def save_dummy_input_tensor(self): + """ + Save the dummy input tensor to a file. + + This method saves the input shape and type information to a file named 'dummy_input_tensor.pkl'. + + Returns: + None + + """ import pickle features = self.get_input_size_type() with open('dummy_input_tensor.pkl', 'wb') as handle: diff --git a/python/fedml/cross_silo/server/fedml_server_manager.py b/python/fedml/cross_silo/server/fedml_server_manager.py index bb6739edf0..8a0974a485 100644 --- a/python/fedml/cross_silo/server/fedml_server_manager.py +++ b/python/fedml/cross_silo/server/fedml_server_manager.py @@ -13,6 +13,28 @@ class FedMLServerManager(FedMLCommManager): + """ + Represents the server manager for federated learning. + + Args: + args: The configuration arguments. + aggregator: The aggregator for federated learning. + comm: The communication backend (default is None). + client_rank: The rank of the client (default is 0). + client_num: The number of clients (default is 0). + backend: The communication backend (default is "MQTT_S3"). + + Attributes: + ONLINE_STATUS_FLAG (str): Flag indicating online status. + RUN_FINISHED_STATUS_FLAG (str): Flag indicating run finished status. + + Methods: + is_main_process(): Check if the current process is the main process. + run(): Run the server manager. + send_init_msg(): Send initialization messages to clients. + register_message_receive_handlers(): Register message receive handlers for communication. + + """ ONLINE_STATUS_FLAG = "ONLINE" RUN_FINISHED_STATUS_FLAG = "FINISHED" @@ -35,12 +57,26 @@ def __init__( self.data_silo_index_list = None def is_main_process(self): + """ + Check if the current process is the main process. 
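+
+        Example:
+            Typically used to guard side effects that must run exactly once
+            per cluster, as elsewhere in this class (sketch):
+
+                if self.is_main_process():
+                    mlops.log_aggregated_model_info(
+                        self.args.round_idx, model_url=global_model_url)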
+ + Returns: + bool: True if the current process is the main process, False otherwise. + """ return getattr(self.aggregator, "aggregator", None) is None or self.aggregator.aggregator.is_main_process() def run(self): super().run() def send_init_msg(self): + """ + Send initialization messages to clients. + + This method sends initialization messages to clients, including model parameters and configuration. + + Returns: + None + """ global_model_params = self.aggregator.get_global_model_params() global_model_url = None @@ -54,25 +90,37 @@ def send_init_msg(self): ) client_idx_in_this_round += 1 - mlops.event("server.wait", event_started=True, event_value=str(self.args.round_idx)) + mlops.event("server.wait", event_started=True, + event_value=str(self.args.round_idx)) try: # get input type and shape for inference dummy_input_tensor = self.aggregator.get_dummy_input_tensor() if not getattr(self.args, "skip_log_model_net", False): - model_net_url = mlops.log_training_model_net_info(self.aggregator.aggregator.model, dummy_input_tensor) + model_net_url = mlops.log_training_model_net_info( + self.aggregator.aggregator.model, dummy_input_tensor) # type and shape for later configuration input_shape, input_type = self.aggregator.get_input_shape_type() # Send output input size and type (saved as json) to s3, # and transfer when click "Create Model Card" - model_input_url = mlops.log_training_model_input_info(list(input_shape), list(input_type)) + model_input_url = mlops.log_training_model_input_info( + list(input_shape), list(input_type)) except Exception as e: - logging.info("Cannot get dummy input size or shape for model serving") + logging.info( + "Cannot get dummy input size or shape for model serving") def register_message_receive_handlers(self): + """ + Register message receive handlers for communication. + + This method registers message receive handlers for handling different types of messages. + + Returns: + None + """ logging.info("register_message_receive_handlers------") self.register_message_receive_handler( MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_message_connection_ready @@ -87,12 +135,24 @@ def register_message_receive_handlers(self): ) def handle_message_connection_ready(self, msg_params): + """ + Handle the message indicating that the client connection is ready. + + This method processes the message from a client indicating that its connection is ready for communication. + + Args: + msg_params (dict): The message parameters. 
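+
+        Example:
+            Because selection is seeded with the round index, every run of
+            the same experiment picks the same cohort; a toy demonstration:
+
+                import numpy as np
+
+                np.random.seed(3)      # round_idx = 3
+                first = np.random.choice(range(10), 4, replace=False)
+                np.random.seed(3)
+                second = np.random.choice(range(10), 4, replace=False)
+                assert (first == second).all()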
+ + Returns: + None + """ if not self.is_initialized: self.client_id_list_in_this_round = self.aggregator.client_selection( self.args.round_idx, self.client_real_ids, self.args.client_num_per_round ) self.data_silo_index_list = self.aggregator.data_silo_selection( - self.args.round_idx, self.args.client_num_in_total, len(self.client_id_list_in_this_round), + self.args.round_idx, self.args.client_num_in_total, len( + self.client_id_list_in_this_round), ) mlops.log_round_info(self.round_num, -1) @@ -104,15 +164,30 @@ def handle_message_connection_ready(self, msg_params): self.send_message_check_client_status( client_id, self.data_silo_index_list[client_idx_in_this_round], ) - logging.info("Connection ready for client" + str(client_id)) + logging.info( + "Connection ready for client" + str(client_id)) except Exception as e: - logging.info("Connection not ready for client" + str(client_id)) + logging.info( + "Connection not ready for client" + str(client_id)) client_idx_in_this_round += 1 def process_online_status(self, client_status, msg_params): + """ + Process the online status message from a client. + + This method processes the online status message from a client and checks if all clients are online. + + Args: + client_status (str): The client status. + msg_params (dict): The message parameters. + + Returns: + None + """ self.client_online_mapping[str(msg_params.get_sender_id())] = True - logging.info("self.client_online_mapping = {}".format(self.client_online_mapping)) + logging.info("self.client_online_mapping = {}".format( + self.client_online_mapping)) all_client_is_online = True for client_id in self.client_id_list_in_this_round: @@ -121,17 +196,31 @@ def process_online_status(self, client_status, msg_params): break logging.info( - "sender_id = %d, all_client_is_online = %s" % (msg_params.get_sender_id(), str(all_client_is_online)) + "sender_id = %d, all_client_is_online = %s" % ( + msg_params.get_sender_id(), str(all_client_is_online)) ) if all_client_is_online: - mlops.log_aggregation_status(MyMessage.MSG_MLOPS_SERVER_STATUS_RUNNING) + mlops.log_aggregation_status( + MyMessage.MSG_MLOPS_SERVER_STATUS_RUNNING) # send initialization message to all clients to start training self.send_init_msg() self.is_initialized = True def process_finished_status(self, client_status, msg_params): + """ + Process the finished status message from a client. + + This method processes the finished status message from a client and checks if all clients have finished. + + Args: + client_status (str): The client status. + msg_params (dict): The message parameters. + + Returns: + None + """ self.client_finished_mapping[str(msg_params.get_sender_id())] = True all_client_is_finished = True @@ -141,7 +230,8 @@ def process_finished_status(self, client_status, msg_params): break logging.info( - "sender_id = %d, all_client_is_finished = %s" % (msg_params.get_sender_id(), str(all_client_is_finished)) + "sender_id = %d, all_client_is_finished = %s" % ( + msg_params.get_sender_id(), str(all_client_is_finished)) ) if all_client_is_finished: @@ -152,6 +242,17 @@ def process_finished_status(self, client_status, msg_params): self.finish() def handle_message_client_status_update(self, msg_params): + """ + Handle the client status update message. + + This method processes the client status update message and takes appropriate actions based on the status. + + Args: + msg_params (dict): The message parameters. 
+ + Returns: + None + """ client_status = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_STATUS) logging.info(f"received client status {client_status}") if client_status == FedMLServerManager.ONLINE_STATUS_FLAG: @@ -160,38 +261,59 @@ def handle_message_client_status_update(self, msg_params): self.process_finished_status(client_status, msg_params) def handle_message_receive_model_from_client(self, msg_params): + """ + Handle the message receiving the model from a client. + + This method handles the message that receives the model parameters from a client and performs aggregation. + + Args: + msg_params (dict): The message parameters. + + Returns: + None + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) - mlops.event("comm_c2s", event_started=False, event_value=str(self.args.round_idx), event_edge_id=sender_id) + mlops.event("comm_c2s", event_started=False, event_value=str( + self.args.round_idx), event_edge_id=sender_id) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) self.aggregator.add_local_trained_result( - self.client_real_ids.index(sender_id), model_params, local_sample_number + self.client_real_ids.index( + sender_id), model_params, local_sample_number ) b_all_received = self.aggregator.check_whether_all_receive() logging.info("b_all_received = " + str(b_all_received)) if b_all_received: - mlops.event("server.wait", event_started=False, event_value=str(self.args.round_idx)) - mlops.event("server.agg_and_eval", event_started=True, event_value=str(self.args.round_idx)) + mlops.event("server.wait", event_started=False, + event_value=str(self.args.round_idx)) + mlops.event("server.agg_and_eval", event_started=True, + event_value=str(self.args.round_idx)) tick = time.time() global_model_params, model_list, model_list_idxes = self.aggregator.aggregate() - logging.info("self.client_id_list_in_this_round = {}".format(self.client_id_list_in_this_round)) + logging.info("self.client_id_list_in_this_round = {}".format( + self.client_id_list_in_this_round)) new_client_id_list_in_this_round = [] for client_idx in model_list_idxes: - new_client_id_list_in_this_round.append(self.client_id_list_in_this_round[client_idx]) - logging.info("new_client_id_list_in_this_round = {}".format(new_client_id_list_in_this_round)) - Context().add(Context.KEY_CLIENT_ID_LIST_IN_THIS_ROUND, new_client_id_list_in_this_round) + new_client_id_list_in_this_round.append( + self.client_id_list_in_this_round[client_idx]) + logging.info("new_client_id_list_in_this_round = {}".format( + new_client_id_list_in_this_round)) + Context().add(Context.KEY_CLIENT_ID_LIST_IN_THIS_ROUND, + new_client_id_list_in_this_round) if self.is_main_process(): - MLOpsProfilerEvent.log_to_wandb({"AggregationTime": time.time() - tick, "round": self.args.round_idx}) + MLOpsProfilerEvent.log_to_wandb( + {"AggregationTime": time.time() - tick, "round": self.args.round_idx}) self.aggregator.test_on_server_for_all_clients(self.args.round_idx) self.aggregator.assess_contribution() - mlops.event("server.agg_and_eval", event_started=False, event_value=str(self.args.round_idx)) + mlops.event("server.agg_and_eval", event_started=False, + event_value=str(self.args.round_idx)) # send round info to the MQTT backend mlops.log_round_info(self.round_num, self.args.round_idx) @@ -200,12 +322,15 @@ def handle_message_receive_model_from_client(self, msg_params): self.args.round_idx, self.client_real_ids, self.args.client_num_per_round ) self.data_silo_index_list = 
self.aggregator.data_silo_selection( - self.args.round_idx, self.args.client_num_in_total, len(self.client_id_list_in_this_round), + self.args.round_idx, self.args.client_num_in_total, len( + self.client_id_list_in_this_round), ) - Context().add(Context.KEY_CLIENT_ID_LIST_IN_THIS_ROUND, self.client_id_list_in_this_round) + Context().add(Context.KEY_CLIENT_ID_LIST_IN_THIS_ROUND, + self.client_id_list_in_this_round) if self.args.round_idx == 0 and self.is_main_process(): - MLOpsProfilerEvent.log_to_wandb({"BenchmarkStart": time.time()}) + MLOpsProfilerEvent.log_to_wandb( + {"BenchmarkStart": time.time()}) client_idx_in_this_round = 0 global_model_url = None @@ -232,13 +357,24 @@ def handle_message_receive_model_from_client(self, msg_params): self.args.round_idx += 1 if self.is_main_process(): - mlops.log_aggregated_model_info(self.args.round_idx, model_url=global_model_url) + mlops.log_aggregated_model_info( + self.args.round_idx, model_url=global_model_url) - logging.info("\n\n==========end {}-th round training===========\n".format(self.args.round_idx)) + logging.info( + "\n\n==========end {}-th round training===========\n".format(self.args.round_idx)) if self.args.round_idx < self.round_num: - mlops.event("server.wait", event_started=True, event_value=str(self.args.round_idx)) + mlops.event("server.wait", event_started=True, + event_value=str(self.args.round_idx)) def cleanup(self): + """ + Send cleanup messages to clients. + + This method sends cleanup messages to all clients to signal the end of communication. + + Returns: + None + """ client_idx_in_this_round = 0 for client_id in self.client_id_list_in_this_round: self.send_message_finish( @@ -248,73 +384,175 @@ def cleanup(self): def send_message_init_config(self, receive_id, global_model_params, datasilo_index, global_model_url=None, global_model_key=None): + """ + Send an initialization message with configuration to a client. + + This method sends an initialization message to a client containing configuration information and model parameters. + + Args: + receive_id (int): The receiver's ID. + global_model_params (dict): Global model parameters. + datasilo_index (int): The data silo index of the client. + global_model_url (str): The URL of the global model (optional). + global_model_key (str): The key of the global model (optional). + + Returns: + str: The URL of the global model. + str: The key of the global model. 
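+
+        Example:
+            Typical call from the initialization loop (sketch; identifiers
+            are placeholders for values built by the caller):
+
+                url, key = self.send_message_init_config(
+                    client_id, global_model_params, silo_index,
+                    global_model_url, global_model_key)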
+ """ if self.is_main_process(): tick = time.time() - message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id) + message = Message(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, + self.get_sender_id(), receive_id) if global_model_url is not None: - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL, global_model_url) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL, global_model_url) if global_model_key is not None: - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY, global_model_key) - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY, global_model_key) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, "PythonClient") self.send_message(message) - global_model_url = message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL) - global_model_key = message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY) - MLOpsProfilerEvent.log_to_wandb({"Communiaction/Send_Total": time.time() - tick}) + global_model_url = message.get( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL) + global_model_key = message.get( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY) + MLOpsProfilerEvent.log_to_wandb( + {"Communiaction/Send_Total": time.time() - tick}) return global_model_url, global_model_key def send_message_check_client_status(self, receive_id, datasilo_index): - message = Message(MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, self.get_sender_id(), receive_id) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) + """ + Send a message to check the status of a client. + + This method sends a message to a client to check its status. + + Args: + receive_id (int): The receiver's ID. + datasilo_index (int): The data silo index of the client. + + Returns: + None + """ + message = Message(MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, + self.get_sender_id(), receive_id) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) def send_message_finish(self, receive_id, datasilo_index): - message = Message(MyMessage.MSG_TYPE_S2C_FINISH, self.get_sender_id(), receive_id) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) + """ + Send a finish message to a client. + + This method sends a finish message to a client to signal the end of communication. + + Args: + receive_id (int): The receiver's ID. + datasilo_index (int): The data silo index of the client. + + Returns: + None + """ + message = Message(MyMessage.MSG_TYPE_S2C_FINISH, + self.get_sender_id(), receive_id) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) logging.info( "finish from send id {} to receive id {}.".format(message.get_sender_id(), message.get_receiver_id())) - logging.info(" ====================send cleanup message to {}====================".format(str(datasilo_index))) + logging.info(" ====================send cleanup message to {}====================".format( + str(datasilo_index))) def send_message_sync_model_to_client(self, receive_id, global_model_params, client_index, global_model_url=None, global_model_key=None): + """ + Send a message to synchronize the global model to a client. 
+ + This method sends a message to a client to synchronize the global model parameters. + + Args: + receive_id (int): The receiver's ID. + global_model_params (dict): Global model parameters. + client_index (int): The client index. + global_model_url (str): The URL of the global model (optional). + global_model_key (str): The key of the global model (optional). + + Returns: + str: The URL of the global model. + str: The key of the global model. + """ + if self.is_main_process(): tick = time.time() - logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id) - message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id, ) - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) + logging.info( + "send_message_sync_model_to_client. receive_id = %d" % receive_id) + message = Message( + MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id, ) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) if global_model_url is not None: - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL, global_model_url) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL, global_model_url) if global_model_key is not None: - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY, global_model_key) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY, global_model_key) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, "PythonClient") self.send_message(message) - MLOpsProfilerEvent.log_to_wandb({"Communiaction/Send_Total": time.time() - tick}) + MLOpsProfilerEvent.log_to_wandb( + {"Communiaction/Send_Total": time.time() - tick}) - global_model_url = message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL) - global_model_key = message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY) + global_model_url = message.get( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL) + global_model_key = message.get( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY) return global_model_url, global_model_key def send_message_diff_sync_model_to_client(self, receive_id, client_model_params, client_index): + """ + Send a message to synchronize a different global model to a client. + + This method sends a message to a client to synchronize a different global model parameters. + Unlike `send_message_sync_model_to_client`, this method does not synchronize the global model for all clients, + but rather sends a specific client's model. + + Args: + receive_id (int): The receiver's ID. + client_model_params (dict): The client's model parameters. + client_index (int): The client index. + + Returns: + str: The URL of the global model. + str: The key of the global model. + """ global_model_url = None global_model_key = None if self.is_main_process(): tick = time.time() - logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id) - message = Message(MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id, ) - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, client_model_params) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) + logging.info( + "send_message_sync_model_to_client. 
receive_id = %d" % receive_id) + message = Message( + MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id, ) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS, client_model_params) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, "PythonClient") self.send_message(message) - MLOpsProfilerEvent.log_to_wandb({"Communiaction/Send_Total": time.time() - tick}) + MLOpsProfilerEvent.log_to_wandb( + {"Communiaction/Send_Total": time.time() - tick}) - global_model_url = message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL) - global_model_key = message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY) + global_model_url = message.get( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL) + global_model_key = message.get( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY) return global_model_url, global_model_key diff --git a/python/fedml/cross_silo/server/server_initializer.py b/python/fedml/cross_silo/server/server_initializer.py index 5877d96fea..f402be918a 100644 --- a/python/fedml/cross_silo/server/server_initializer.py +++ b/python/fedml/cross_silo/server/server_initializer.py @@ -18,6 +18,30 @@ def init_server( train_data_local_num_dict, server_aggregator, ): + """ + Initialize the server for federated learning. + + This function sets up the server for federated learning, including creating an aggregator, + starting distributed training, and running the server manager. + + Args: + args (argparse.Namespace): Command-line arguments and configurations. + device (torch.device): The device on which the server runs. + comm (Communicator): The communication backend. + rank (int): The rank of the server in the distributed environment. + worker_num (int): The number of worker nodes participating in federated learning. + model (torch.nn.Module): The model used for federated learning. + train_data_num (int): The number of training data points globally. + train_data_global (Dataset): The global training dataset. + test_data_global (Dataset): The global test dataset. + train_data_local_dict (dict): A dictionary of local training datasets for each client. + test_data_local_dict (dict): A dictionary of local test datasets for each client. + train_data_local_num_dict (dict): A dictionary of the number of local training data points for each client. + server_aggregator (ServerAggregator, optional): The server aggregator. If not provided, it will be created. 
+ + Returns: + None + """ if server_aggregator is None: server_aggregator = create_server_aggregator(model, args) server_aggregator.set_id(0) @@ -38,5 +62,6 @@ def init_server( # start the distributed training backend = args.backend - server_manager = FedMLServerManager(args, aggregator, comm, rank, worker_num, backend) + server_manager = FedMLServerManager( + args, aggregator, comm, rank, worker_num, backend) server_manager.run() From 96e311842f5c66ac191f9239bd52a5e2a655b2f8 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Wed, 20 Sep 2023 19:13:00 +0530 Subject: [PATCH 64/70] add --- .../core/security/defense/RFA_defense.py | 60 +++++- .../core/security/defense/bulyan_defense.py | 71 +++++++ .../core/security/defense/cclip_defense.py | 73 ++++++- .../defense/coordinate_wise_median_defense.py | 28 ++- .../coordinate_wise_trimmed_mean_defense.py | 34 +++- .../core/security/defense/crfl_defense.py | 63 +++++- .../security/defense/cross_round_defense.py | 62 ++++++ .../core/security/defense/defense_base.py | 34 +++- .../security/defense/foolsgold_defense.py | 42 ++++ .../defense/geometric_median_defense.py | 32 ++- .../core/security/defense/krum_defense.py | 25 +++ .../defense/norm_diff_clipping_defense.py | 50 ++++- .../security/defense/outlier_detection.py | 34 +++- .../residual_based_reweighting_defense.py | 114 ++++++++++- .../defense/robust_learning_rate_defense.py | 70 ++++++- .../core/security/defense/slsgd_defense.py | 63 ++++++ .../core/security/defense/soteria_defense.py | 47 +++++ .../security/defense/three_sigma_defense.py | 86 +++++++- .../defense/three_sigma_geomedian_defense.py | 107 +++++++++- .../defense/three_sigma_krum_defense.py | 133 +++++++++++++ .../core/security/defense/wbc_defense.py | 52 ++++- .../core/security/defense/weak_dp_defense.py | 44 ++++ python/fedml/core/security/fedml_attacker.py | 129 +++++++++++- python/fedml/core/security/fedml_defender.py | 144 ++++++++++++++ python/fedml/cross_device/mnn_server.py | 39 +++- .../cross_silo/client/client_initializer.py | 73 ++++++- .../cross_silo/client/client_launcher.py | 70 ++++++- .../client/fedml_client_master_manager.py | 188 ++++++++++++++++-- .../client/fedml_client_slave_manager.py | 56 ++++++ .../fedml/cross_silo/client/fedml_trainer.py | 122 +++++++++++- .../client/fedml_trainer_dist_adapter.py | 89 ++++++++- .../client/process_group_manager.py | 35 +++- python/fedml/cross_silo/client/utils.py | 51 ++++- 33 files changed, 2189 insertions(+), 131 deletions(-) diff --git a/python/fedml/core/security/defense/RFA_defense.py b/python/fedml/core/security/defense/RFA_defense.py index ceedcf6b65..1bba3a6809 100644 --- a/python/fedml/core/security/defense/RFA_defense.py +++ b/python/fedml/core/security/defense/RFA_defense.py @@ -12,15 +12,65 @@ class RFADefense(BaseDefenseMethod): - def __init__(self, config): - pass + """ + Robust Aggregation for Federated Learning (RFA) Defense. - def defend_on_aggregation( - self, + This defense method computes a geometric median in aggregation. + + Args: + config: Configuration parameters (currently unused). + + Attributes: + None + + Methods: + defend_on_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], base_aggregation_func: Callable = None, extra_auxiliary_info: Any = None, - ): + ) -> OrderedDict: + Defend against potential adversarial behavior during aggregation. + + References: + - "RFA: Robust Aggregation for Federated Learning." + https://arxiv.org/pdf/1912.13445.pdf + """ + + def __init__(self, config): + """ + Initialize the RFADefense. 
+ + Args: + config: Configuration parameters (currently unused). + """ + pass + + def defend_on_aggregation( + self, + raw_client_grad_list: List[Tuple[float, OrderedDict]], + base_aggregation_func: Callable = None, + extra_auxiliary_info: Any = None, + ) -> OrderedDict: + """ + Defend against potential adversarial behavior during aggregation. + + This method computes a geometric median aggregation of client gradients. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + base_aggregation_func (Callable, optional): + Base aggregation function (currently unused). + extra_auxiliary_info (Any, optional): + Extra auxiliary information (currently unused). + + Returns: + OrderedDict: + Aggregated parameters after applying the defense. + + Notes: + This defense method computes a geometric median aggregation of client gradients. + """ (num0, avg_params) = raw_client_grad_list[0] weights = {num for (num, params) in raw_client_grad_list} weights = {weight / sum(weights, 0.0) for weight in weights} diff --git a/python/fedml/core/security/defense/bulyan_defense.py b/python/fedml/core/security/defense/bulyan_defense.py index cea55960ae..253488934d 100644 --- a/python/fedml/core/security/defense/bulyan_defense.py +++ b/python/fedml/core/security/defense/bulyan_defense.py @@ -21,6 +21,21 @@ class BulyanDefense(BaseDefenseMethod): + """ + Bulyan Defense for Federated Learning. + + Bulyan Defense is a defense method for federated learning that aims to mitigate the impact of Byzantine clients + by selecting a subset of clients' gradients for aggregation. + + Args: + config: Configuration parameters for the defense. + - byzantine_client_num (int): The number of Byzantine (malicious) clients. + - client_num_per_round (int): The total number of clients participating in each aggregation round. + + Attributes: + byzantine_client_num (int): The number of Byzantine (malicious) clients. + client_num_per_round (int): The total number of clients participating in each aggregation round. + """ def __init__(self, config): self.byzantine_client_num = config.byzantine_client_num self.client_num_per_round = config.client_num_per_round @@ -37,6 +52,18 @@ def run( base_aggregation_func: Callable = None, extra_auxiliary_info: Any = None, ) -> OrderedDict: + """ + Run the Bulyan Defense to aggregate gradients. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples containing the number of samples + and gradients for each client. + base_aggregation_func (Callable, optional): The base aggregation function to use. Default is None. + extra_auxiliary_info (Any, optional): Additional auxiliary information. Default is None. + + Returns: + OrderedDict: The aggregated gradients after applying the Bulyan Defense. + """ # note: raw_client_grad_list is a list, each item is (sample_num, gradients). num_clients = len(raw_client_grad_list) (num0, localw0) = raw_client_grad_list[0] @@ -70,6 +97,18 @@ def run( return aggregated_params def _bulyan(self, users_params, users_count, corrupted_count): + """ + Perform the Bulyan aggregation. + + Args: + users_params (numpy.ndarray): Gradients of users' parameters. + users_count (int): The total number of users. + corrupted_count (int): The number of corrupted (Byzantine) users. + + Returns: + Tuple[List[int], List[numpy.ndarray], numpy.ndarray]: A tuple containing the selected indices, + selected set of gradients, and the aggregated gradients. 
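# A minimal sketch of the weighted geometric-median step that RFA applies to
# each parameter key, using a plain smoothed Weiszfeld iteration. The function
# name, fixed iteration count, and toy tensors are assumptions for
# illustration; this is not the compute_geometric_median helper itself.
import torch

def weiszfeld_geometric_median(weights, points, iterations=10, eps=1e-8):
    # weights: floats summing to 1; points: equally shaped tensors.
    median = sum(w * p for w, p in zip(weights, points))
    for _ in range(iterations):
        dists = [max(torch.norm(p - median).item(), eps) for p in points]
        coeffs = [w / d for w, d in zip(weights, dists)]
        median = sum(c * p for c, p in zip(coeffs, points)) / sum(coeffs)
    return median

# The outlier at (100, 100) barely moves the median, unlike a plain average.
pts = [torch.tensor([1.0, 1.0]), torch.tensor([1.2, 0.9]), torch.tensor([100.0, 100.0])]
print(weiszfeld_geometric_median([1 / 3] * 3, pts))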
+ """ assert users_count >= 4 * corrupted_count + 3 set_size = users_count - 2 * corrupted_count selection_set = [] @@ -98,6 +137,16 @@ def _bulyan(self, users_params, users_count, corrupted_count): @staticmethod def trimmed_mean(users_params, corrupted_count): + """ + Compute the trimmed mean of users' gradients. + + Args: + users_params (numpy.ndarray): Gradients of users' parameters. + corrupted_count (int): The number of corrupted (Byzantine) users. + + Returns: + numpy.ndarray: The trimmed mean of gradients. + """ users_params = np.array(users_params) number_to_consider = int(users_params.shape[0] - corrupted_count) - 1 @@ -120,6 +169,19 @@ def _krum( distances=None, return_index=False, ): + """ + Perform the Krum selection. + + Args: + users_params (numpy.ndarray): Gradients of users' parameters. + users_count (int): The total number of users. + corrupted_count (int): The number of corrupted (Byzantine) users. + distances (dict, optional): Precomputed distances between users. Default is None. + return_index (bool, optional): Whether to return the selected index. Default is False. + + Returns: + numpy.ndarray or int: The selected gradients or index. + """ non_malicious_count = users_count - corrupted_count minimal_error = 1e20 @@ -141,6 +203,15 @@ def _krum( @staticmethod def _krum_create_distances(users_params): + """ + Create pairwise distances between users' gradients. + + Args: + users_params (numpy.ndarray): Gradients of users' parameters. + + Returns: + dict: A dictionary containing pairwise distances between users' gradients. + """ distances = defaultdict(dict) for i in range(len(users_params)): for j in range(i): diff --git a/python/fedml/core/security/defense/cclip_defense.py b/python/fedml/core/security/defense/cclip_defense.py index eba983cb48..63e232ba8e 100755 --- a/python/fedml/core/security/defense/cclip_defense.py +++ b/python/fedml/core/security/defense/cclip_defense.py @@ -13,10 +13,29 @@ class CClipDefense(BaseDefenseMethod): + """ + CClip Defense for Federated Learning. + + CClip (Coordinate-wise Clipping) Defense is a defense method for federated learning that clips gradients at each + coordinate to mitigate the impact of Byzantine clients. + + Args: + config: Configuration parameters for the defense. + - tau (float, optional): The clipping radius. Default is 10. + - bucket_size (int, optional): The number of elements in each bucket when partitioning gradients. + Default is None. + + Attributes: + tau (float): The clipping radius. + bucket_size (int): The number of elements in each bucket when partitioning gradients. + initial_guess (OrderedDict): The initial guess for the global model. + """ + def __init__(self, config): self.config = config if hasattr(config, "tau") and type(config.tau) in [int, float] and config.tau > 0: - self.tau = config.tau # clipping raduis; tau = 10 / (1-beta), beta is the coefficient of momentum + # clipping raduis; tau = 10 / (1-beta), beta is the coefficient of momentum + self.tau = config.tau else: self.tau = 10 # default: no momentum, beta = 0 # element # in each bucket; a grad_list is partitioned into floor(len(grad_list)/bucket_size) buckets @@ -28,10 +47,23 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Apply CClip Defense before aggregation. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples containing the number of samples + and gradients for each client. 
+ extra_auxiliary_info (Any, optional): Additional auxiliary information. Default is None. + + Returns: + List[Tuple[float, OrderedDict]]: The modified gradients after applying CClip Defense. + """ + client_grad_buckets = Bucket.bucketization( raw_client_grad_list, self.bucket_size ) - self.initial_guess = self._compute_an_initial_guess(client_grad_buckets) + self.initial_guess = self._compute_an_initial_guess( + client_grad_buckets) bucket_num = len(client_grad_buckets) vec_local_w = [ ( @@ -47,25 +79,58 @@ def defend_before_aggregation( tuple = OrderedDict() sample_num, bucket_params = client_grad_buckets[i] for k in bucket_params.keys(): - tuple[k] = (bucket_params[k] - self.initial_guess[k]) * cclip_score[i] + tuple[k] = (bucket_params[k] - + self.initial_guess[k]) * cclip_score[i] new_grad_list.append((sample_num, tuple)) return new_grad_list def defend_after_aggregation(self, global_model): + """ + Apply CClip Defense after aggregation. + + Args: + global_model (OrderedDict): The global model after aggregation. + + Returns: + OrderedDict: The modified global model after applying CClip Defense. + """ + for k in global_model.keys(): global_model[k] = self.initial_guess[k] + global_model[k] return global_model @staticmethod def _compute_an_initial_guess(client_grad_list): + """ + Compute an initial guess for the global model. + + Args: + client_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples containing the number of samples + and gradients for each client. + + Returns: + OrderedDict: The initial guess for the global model. + """ # randomly select a gradient as the initial guess return client_grad_list[np.random.randint(0, len(client_grad_list))][1] def _compute_cclip_score(self, local_w, refs): + """ + Compute the CClip score for each local gradient. + + Args: + local_w (List[Tuple[float, numpy.ndarray]]): A list of tuples containing the number of samples and + vectorized local gradients. + refs (numpy.ndarray): Vectorized reference gradient. + + Returns: + List[float]: A list of CClip scores for each local gradient. + """ cclip_score = [] num_client = len(local_w) for i in range(0, num_client): - dist = utils.compute_euclidean_distance(local_w[i][1], refs).item() + 1e-8 + dist = utils.compute_euclidean_distance( + local_w[i][1], refs).item() + 1e-8 score = min(1, self.tau / dist) cclip_score.append(score) return cclip_score diff --git a/python/fedml/core/security/defense/coordinate_wise_median_defense.py b/python/fedml/core/security/defense/coordinate_wise_median_defense.py index 30357a1d67..6412b55cdc 100644 --- a/python/fedml/core/security/defense/coordinate_wise_median_defense.py +++ b/python/fedml/core/security/defense/coordinate_wise_median_defense.py @@ -12,6 +12,19 @@ class CoordinateWiseMedianDefense(BaseDefenseMethod): + """ + Coordinate-wise Median Defense for Federated Learning. + + Coordinate-wise Median Defense is a defense method for federated learning that computes the median of the gradients + for each coordinate to mitigate the impact of Byzantine clients. + + Args: + config: Configuration parameters for the defense. (Currently, no specific parameters are required.) + + Attributes: + None + """ + def __init__(self, config): pass @@ -21,6 +34,18 @@ def defend_on_aggregation( base_aggregation_func: Callable = None, extra_auxiliary_info: Any = None, ): + """ + Apply Coordinate-wise Median Defense on aggregation. 
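# Toy numbers for the clipping score computed by _compute_cclip_score above: a
# bucket close to the initial guess keeps score 1, while a distant one is
# scaled by tau / dist. tau and the vectors are illustrative values.
import torch

tau = 10.0
ref = torch.zeros(3)  # stands in for the randomly chosen initial guess
for vec in [torch.tensor([1.0, 2.0, 2.0]), torch.tensor([30.0, 40.0, 0.0])]:
    dist = torch.norm(vec - ref).item() + 1e-8
    print(min(1.0, tau / dist))  # 1.0 for the near bucket, 0.2 for the far one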
+ + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples containing the number of samples + and gradients for each client. + base_aggregation_func (Callable, optional): The base aggregation function. Default is None. + extra_auxiliary_info (Any, optional): Additional auxiliary information. Default is None. + + Returns: + OrderedDict: The aggregated global model after applying Coordinate-wise Median Defense. + """ vectorized_params = [] for i in range(0, len(raw_client_grad_list)): @@ -35,11 +60,10 @@ def defend_on_aggregation( index = 0 (num0, averaged_params) = raw_client_grad_list[0] for k, params in averaged_params.items(): - median_params = vec_median_params[index : index + params.numel()].view( + median_params = vec_median_params[index: index + params.numel()].view( params.size() ) index += params.numel() averaged_params[k] = median_params return averaged_params - diff --git a/python/fedml/core/security/defense/coordinate_wise_trimmed_mean_defense.py b/python/fedml/core/security/defense/coordinate_wise_trimmed_mean_defense.py index 1a717946d7..6bbc4f97bf 100644 --- a/python/fedml/core/security/defense/coordinate_wise_trimmed_mean_defense.py +++ b/python/fedml/core/security/defense/coordinate_wise_trimmed_mean_defense.py @@ -12,15 +12,45 @@ class CoordinateWiseTrimmedMeanDefense(BaseDefenseMethod): + """ + Coordinate-wise Trimmed Mean Defense for Federated Learning. + + Coordinate-wise Trimmed Mean Defense is a defense method for federated learning that computes the trimmed mean of + gradients for each coordinate to mitigate the impact of Byzantine clients. + + Args: + config: Configuration parameters for the defense, including 'beta' which represents the fraction of trimmed + values; total trimmed values: client_num * beta * 2. + + Attributes: + beta (float): The fraction of trimmed values, which determines the number of gradients to be trimmed on each side. + """ + def __init__(self, config): - self.beta = config.beta # fraction of trimmed values; total trimmed values: client_num * beta * 2 + """ + Initialize the CoordinateWiseTrimmedMeanDefense with the specified configuration. + Args: + config: Configuration parameters for the defense. + """ + self.beta = config.beta def defend_before_aggregation( self, raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Apply Coordinate-wise Trimmed Mean Defense before aggregation. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples containing the number of samples + and gradients for each client. + extra_auxiliary_info (Any, optional): Additional auxiliary information. Default is None. + + Returns: + OrderedDict: The aggregated global model after applying Coordinate-wise Trimmed Mean Defense. 
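# The coordinate-wise median aggregation above, condensed to three clients
# holding one flat parameter vector each; the poisoned third vector is voted
# out coordinate by coordinate. Shapes and values are toy stand-ins.
import torch

clients = torch.stack([
    torch.tensor([0.9, 1.0, 1.1]),
    torch.tensor([1.0, 1.1, 0.9]),
    torch.tensor([50.0, -50.0, 50.0]),  # outlier
])
print(torch.median(clients, dim=0).values)  # tensor([1.0000, 1.0000, 1.1000])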
+ """ if self.beta > 1 / 2 or self.beta < 0: - raise ValueError("the bound of beta is [0, 1/2)") + raise ValueError("The bound of 'beta' is [0, 1/2)") return trimmed_mean(raw_client_grad_list, int(self.beta * len(raw_client_grad_list))) diff --git a/python/fedml/core/security/defense/crfl_defense.py b/python/fedml/core/security/defense/crfl_defense.py index 13712da1f3..bff5d53691 100644 --- a/python/fedml/core/security/defense/crfl_defense.py +++ b/python/fedml/core/security/defense/crfl_defense.py @@ -9,8 +9,38 @@ """ +from .base_defense_method import BaseDefenseMethod +from .utils import compute_model_norm +from .gaussian import compute_noise_using_sigma +from collections import OrderedDict + + class CRFLDefense(BaseDefenseMethod): + """ + CRFL (Clip and Randomly Flip) Defense for Federated Learning. + + CRFL Defense is a defense method for federated learning that clips the global model's weights if they exceed a + dynamic threshold and adds Gaussian noise to the clipped weights to improve privacy. + + Args: + config: Configuration parameters for the defense, including 'clip_threshold' (optional), 'sigma', 'comm_round', + and 'dataset'. + + Attributes: + epoch (int): The current training epoch. + user_defined_clip_threshold (float, optional): A user-defined clipping threshold for model weights. + sigma (float): The standard deviation of Gaussian noise added to clipped weights. + total_ite_num (int): The total number of communication rounds. + dataset_param_function (function): A function to compute the dynamic clipping threshold based on the dataset. + """ + def __init__(self, config): + """ + Initialize the CRFLDefense with the specified configuration. + + Args: + config: Configuration parameters for the defense. + """ self.config = config self.epoch = 1 if hasattr(config, "clip_threshold"): @@ -20,7 +50,7 @@ def __init__(self, config): if hasattr(config, "sigma") and isinstance(config.sigma, float): self.sigma = config.sigma else: - self.sigma = 0.01 # in the code of CRFL, the author set sigma to 0.01 + self.sigma = 0.01 # Default sigma value as used in CRFL code self.total_ite_num = config.comm_round if config.dataset == "mnist": self.dataset_param_function = self._crfl_compute_param_for_mnist @@ -31,15 +61,18 @@ def __init__(self, config): elif self.user_defined_clip_threshold is not None: self.dataset_param_function = self._crfl_self_defined_dataset_param else: - raise Exception(f"dataset not supported: {config.dataset} and clip_threshold not defined ") + raise Exception( + f"Dataset not supported: {config.dataset} and clip_threshold not defined.") def defend_after_aggregation(self, global_model): """ - clip the global model; dynamic threshold is adjusted according to the dataset; - in the experiment, the authors set the dynamic threshold as follows: - dataset == MNIST: dynamic_thres = epoch * 0.1 + 2 - dataseet == LOAN: dynamic_thres = epoch * 0.025 + 2 - datset == EMNIST: dynamic_thres = epoch * 0.25 + 4 + Apply CRFL Defense after model aggregation. + + Args: + global_model (OrderedDict): The global model to be defended. + + Returns: + OrderedDict: The defended global model after clipping and adding Gaussian noise. 
""" clip_threshold = self.dataset_param_function() if self.user_defined_clip_threshold is not None and self.user_defined_clip_threshold < clip_threshold: @@ -51,7 +84,8 @@ def defend_after_aggregation(self, global_model): self.epoch += 1 new_global_model = OrderedDict() for k in global_model.keys(): - new_global_model[k] = global_model[k] + Gaussian.compute_noise_using_sigma(self.sigma, global_model[k].shape) + new_global_model[k] = global_model[k] + \ + compute_noise_using_sigma(self.sigma, global_model[k].shape) return new_global_model def _crfl_self_defined_dataset_param(self): @@ -68,8 +102,17 @@ def _crfl_compute_param_for_emnist(self): @staticmethod def clip_weight_norm(model, clip_threshold): - total_norm = utils.compute_model_norm(model) - print(f"total_norm = {total_norm}") + """ + Clip the weight norm of the model. + + Args: + model (OrderedDict): The model whose weights are to be clipped. + clip_threshold (float): The threshold value for clipping. + + Returns: + OrderedDict: The model with clipped weights. + """ + total_norm = compute_model_norm(model) if total_norm > clip_threshold: clip_coef = clip_threshold / (total_norm + 1e-6) new_model = OrderedDict() diff --git a/python/fedml/core/security/defense/cross_round_defense.py b/python/fedml/core/security/defense/cross_round_defense.py index 0d8eb34bd5..ae8174563a 100644 --- a/python/fedml/core/security/defense/cross_round_defense.py +++ b/python/fedml/core/security/defense/cross_round_defense.py @@ -13,6 +13,24 @@ # too much difference: malicious, need further defense # todo: pretraining round? class CrossRoundDefense(BaseDefenseMethod): + """ + CrossRoundDefense for Federated Learning. + + CrossRoundDefense is a defense method for federated learning that detects potentially poisoned workers + based on cosine similarity between client and global model features across training rounds. + + Args: + config: Configuration parameters for the defense, including 'upperbound' and 'lowerbound'. + + Attributes: + potentially_poisoned_worker_list (list): List of potentially poisoned worker indices. + lazy_worker_list (list): List of lazy worker indices. + upperbound (float): Threshold for detecting potential attacks. + lowerbound (float): Threshold for defining "very limited difference." + client_cache (list): Cache of client features for comparison across training rounds. + training_round (int): The current training round. + is_attack_existing (bool): Flag indicating whether an attack exists in the current round. + """ def __init__(self, config): self.potentially_poisoned_worker_list = [] self.lazy_worker_list = None @@ -28,6 +46,16 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Apply CrossRoundDefense before model aggregation. + + Args: + raw_client_grad_list (list): List of client gradients for the current round. + extra_auxiliary_info: Global model or auxiliary information. + + Returns: + list: List of potentially poisoned client gradients. + """ self.is_attack_existing = False client_features = self._get_importance_feature(raw_client_grad_list) if self.training_round == 1: @@ -71,9 +99,25 @@ def defend_before_aggregation( return raw_client_grad_list def get_potential_poisoned_clients(self): + """ + Get the list of potentially poisoned client indices. + + Returns: + list: List of potentially poisoned client indices. 
+ """ return self.potentially_poisoned_worker_list def compute_client_cosine_scores(self, client_features, global_model_feature): + """ + Compute cosine similarity scores between client features and global model features. + + Args: + client_features (list): List of client feature vectors. + global_model_feature (list): Feature vector of the global model. + + Returns: + tuple: Two lists of cosine similarity scores for each client (client-wise and global-wise). + """ client_wise_scores = [] global_wise_scores = [] num_client = len(client_features) @@ -85,6 +129,15 @@ def compute_client_cosine_scores(self, client_features, global_model_feature): return client_wise_scores, global_wise_scores def _get_importance_feature(self, raw_client_grad_list): + """ + Extract importance features from client gradients. + + Args: + raw_client_grad_list (list): List of client gradients. + + Returns: + list: List of extracted importance feature vectors. + """ ret_feature_vector_list = [] for idx in range(len(raw_client_grad_list)): raw_grad = raw_client_grad_list[idx] @@ -96,6 +149,15 @@ def _get_importance_feature(self, raw_client_grad_list): @classmethod def _get_importance_feature_of_a_model(self, grad): + """ + Extract importance feature from a client gradient. + + Args: + grad (OrderedDict): Client gradient. + + Returns: + numpy.ndarray: Importance feature vector. + """ # Get last key-value tuple (weight_name, importance_feature) = list(grad.items())[-2] # print(importance_feature) diff --git a/python/fedml/core/security/defense/defense_base.py b/python/fedml/core/security/defense/defense_base.py index 4abc3bbecf..77a1adbaa9 100644 --- a/python/fedml/core/security/defense/defense_base.py +++ b/python/fedml/core/security/defense/defense_base.py @@ -4,8 +4,20 @@ class BaseDefenseMethod(ABC): + """ + Base class for defense methods in Federated Learning. + + Attributes: + config: Configuration parameters for the defense method. + """ @abstractmethod def __init__(self, config): + """ + Initialize the defense method with the specified configuration. + + Args: + config: Configuration parameters for the defense method. + """ pass def defend_before_aggregation( @@ -14,12 +26,14 @@ def defend_before_aggregation( extra_auxiliary_info: Any = None, ) -> List[Tuple[float, OrderedDict]]: """ - args: - client_grad_list: client_grad_list is a list, each item is (sample_num, gradients) - extra_auxiliary_info: for methods which need extra info (e.g., data, previous model/gradient), - please use this variable. - return: - Note: the data type of the return variable should be the same as the input + Apply defense before model aggregation. + + Args: + raw_client_grad_list (list): List of client gradients for the current round. + extra_auxiliary_info: Additional information required for defense. + + Returns: + list: List of defended client gradients. """ pass @@ -41,4 +55,10 @@ def defend_on_aggregation( pass def get_malicious_client_idxs(self): - return [] \ No newline at end of file + """ + Get the indices of potentially malicious clients. + + Returns: + list: List of indices of potentially malicious clients. 
+ """ + return [] diff --git a/python/fedml/core/security/defense/foolsgold_defense.py b/python/fedml/core/security/defense/foolsgold_defense.py index 5db59eecad..4814637d96 100644 --- a/python/fedml/core/security/defense/foolsgold_defense.py +++ b/python/fedml/core/security/defense/foolsgold_defense.py @@ -12,7 +12,21 @@ class FoolsGoldDefense(BaseDefenseMethod): + """ + Defense method using FoolsGold for federated learning. + + Attributes: + config: Configuration parameters for the defense method. + memory: Memory for storing client importance features. + """ + def __init__(self, config): + """ + Initialize the FoolsGoldDefense. + + Args: + config: Configuration parameters for the defense method. + """ super().__init__(config) self.config = config self.memory = None @@ -22,6 +36,16 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Apply FoolsGold defense before model aggregation. + + Args: + raw_client_grad_list (list): List of client gradients for the current round. + extra_auxiliary_info: Additional information required for defense. + + Returns: + list: List of defended client gradients. + """ client_num = len(raw_client_grad_list) importance_feature_list = self._get_importance_feature(raw_client_grad_list) # print(len(importance_feature_list)) @@ -47,6 +71,15 @@ def defend_before_aggregation( # Takes in grad, compute similarity, get weightings @classmethod def fools_gold_score(cls, feature_vec_list): + """ + Compute FoolsGold scores for client importance features. + + Args: + feature_vec_list (list): List of client importance features. + + Returns: + list: List of FoolsGold scores. + """ import sklearn.metrics.pairwise as smp n_clients = len(feature_vec_list) cs = smp.cosine_similarity(feature_vec_list) - np.eye(n_clients) @@ -75,6 +108,15 @@ def fools_gold_score(cls, feature_vec_list): return alpha def _get_importance_feature(self, raw_client_grad_list): + """ + Get the importance feature from client gradients. + + Args: + raw_client_grad_list (list): List of client gradients. + + Returns: + list: List of importance features. + """ # Foolsgold uses the last layer's gradient/weights as the importance feature. ret_feature_vector_list = [] for idx in range(len(raw_client_grad_list)): diff --git a/python/fedml/core/security/defense/geometric_median_defense.py b/python/fedml/core/security/defense/geometric_median_defense.py index adf60edaa2..edd2dc733d 100644 --- a/python/fedml/core/security/defense/geometric_median_defense.py +++ b/python/fedml/core/security/defense/geometric_median_defense.py @@ -20,7 +20,23 @@ class GeometricMedianDefense(BaseDefenseMethod): + """ + Defense method using Geometric Median for federated learning. + + Attributes: + byzantine_client_num: Number of Byzantine clients in the system. + client_num_per_round: Number of clients participating in each round. + batch_num: Number of batches used for geometric median computation. + batch_size: Size of each batch for gradient aggregation. + """ + def __init__(self, config): + """ + Initialize the GeometricMedianDefense. + + Args: + config: Configuration parameters for the defense method. 
+ """ self.byzantine_client_num = config.byzantine_client_num self.client_num_per_round = config.client_num_per_round # 2(1 + ε )q ≤ batch_num ≤ client_num_per_round @@ -37,7 +53,19 @@ def defend_on_aggregation( base_aggregation_func: Callable = None, extra_auxiliary_info: Any = None, ): - batch_grad_list = Bucket.bucketization(raw_client_grad_list, self.batch_size) + """ + Apply Geometric Median defense on gradient aggregation. + + Args: + raw_client_grad_list (list): List of client gradients for the current round. + base_aggregation_func (Callable): Base aggregation function to use (optional). + extra_auxiliary_info: Additional information required for defense (optional). + + Returns: + OrderedDict: Aggregated global model parameters. + """ + batch_grad_list = Bucket.bucketization( + raw_client_grad_list, self.batch_size) (num0, avg_params) = batch_grad_list[0] alphas = {alpha for (alpha, params) in batch_grad_list} alphas = {alpha / sum(alphas, 0.0) for alpha in alphas} @@ -45,5 +73,3 @@ def defend_on_aggregation( batch_grads = [params[k] for (alpha, params) in batch_grad_list] avg_params[k] = compute_geometric_median(alphas, batch_grads) return avg_params - - diff --git a/python/fedml/core/security/defense/krum_defense.py b/python/fedml/core/security/defense/krum_defense.py index 5201cc8c09..19d8af73cf 100755 --- a/python/fedml/core/security/defense/krum_defense.py +++ b/python/fedml/core/security/defense/krum_defense.py @@ -16,6 +16,12 @@ class KrumDefense(BaseDefenseMethod): def __init__(self, config): + """ + Initialize the KrumDefense method. + + Args: + config (object): Configuration object containing defense parameters. + """ self.config = config self.byzantine_client_num = config.byzantine_client_num @@ -29,6 +35,16 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Perform defense before aggregation using the KrumDefense method. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): List of client gradients. + extra_auxiliary_info (Any): Additional information (optional). + + Returns: + List[Tuple[float, OrderedDict]]: List of defended client gradients. + """ num_client = len(raw_client_grad_list) # in the Krum paper, it says 2 * byzantine_client_num + 2 < client # if not 2 * self.byzantine_client_num + 2 <= num_client - self.krum_param_m: @@ -48,6 +64,15 @@ def defend_before_aggregation( return [raw_client_grad_list[i] for i in score_index] def _compute_krum_score(self, vec_grad_list): + """ + Compute Krum scores for the given list of gradient vectors. + + Args: + vec_grad_list (List[torch.Tensor]): List of gradient vectors. + + Returns: + List[float]: List of Krum scores. + """ krum_scores = [] num_client = len(vec_grad_list) for i in range(0, num_client): diff --git a/python/fedml/core/security/defense/norm_diff_clipping_defense.py b/python/fedml/core/security/defense/norm_diff_clipping_defense.py index a01306e7f6..bbb064478b 100644 --- a/python/fedml/core/security/defense/norm_diff_clipping_defense.py +++ b/python/fedml/core/security/defense/norm_diff_clipping_defense.py @@ -13,20 +13,38 @@ class NormDiffClippingDefense(BaseDefenseMethod): def __init__(self, config): + """ + Initialize the NormDiffClippingDefense method. + + Args: + config (object): Configuration object containing defense parameters. + """ self.config = config - self.norm_bound = config.norm_bound # for norm diff clipping; in the paper, they set it to 0.1, 0.17, and 0.33. 
+ # for norm diff clipping; in the paper, they set it to 0.1, 0.17, and 0.33. + self.norm_bound = config.norm_bound def defend_before_aggregation( self, raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Perform defense before aggregation using norm difference clipping. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): List of client gradients. + extra_auxiliary_info (Any): Global model for clipping (optional). + + Returns: + List[Tuple[float, OrderedDict]]: List of defended client gradients. + """ global_model = extra_auxiliary_info vec_global_w = utils.vectorize_weight(global_model) new_grad_list = [] for (sample_num, local_w) in raw_client_grad_list: vec_local_w = utils.vectorize_weight(local_w) - clipped_weight_diff = self._get_clipped_norm_diff(vec_local_w, vec_global_w) + clipped_weight_diff = self._get_clipped_norm_diff( + vec_local_w, vec_global_w) clipped_w = self._get_clipped_weights( local_w, global_model, clipped_weight_diff ) @@ -34,20 +52,44 @@ def defend_before_aggregation( return new_grad_list def _get_clipped_norm_diff(self, vec_local_w, vec_global_w): + """ + Compute the clipped norm difference between local and global weights. + + Args: + vec_local_w (torch.Tensor): Vectorized local weights. + vec_global_w (torch.Tensor): Vectorized global weights. + + Returns: + torch.Tensor: Clipped weight difference. + """ vec_diff = vec_local_w - vec_global_w weight_diff_norm = torch.norm(vec_diff).item() - clipped_weight_diff = vec_diff / max(1, weight_diff_norm / self.norm_bound) + clipped_weight_diff = vec_diff / \ + max(1, weight_diff_norm / self.norm_bound) return clipped_weight_diff @staticmethod def _get_clipped_weights(local_w, global_w, weight_diff): + """ + Compute clipped weights based on global and local weights. + + Args: + local_w (OrderedDict): Local model weights. + global_w (OrderedDict): Global model weights. + weight_diff (torch.Tensor): Clipped weight difference. + + Returns: + OrderedDict: Clipped local model weights. + """ + # rule: global_w + clipped(local_w - global_w) recons_local_w = OrderedDict() index_bias = 0 for item_index, (k, v) in enumerate(local_w.items()): if utils.is_weight_param(k): recons_local_w[k] = ( - weight_diff[index_bias: index_bias + v.numel()].view(v.size()) + weight_diff[index_bias: index_bias + + v.numel()].view(v.size()) + global_w[k] ) index_bias += v.numel() diff --git a/python/fedml/core/security/defense/outlier_detection.py b/python/fedml/core/security/defense/outlier_detection.py index e24f6c594f..793d4ffd11 100644 --- a/python/fedml/core/security/defense/outlier_detection.py +++ b/python/fedml/core/security/defense/outlier_detection.py @@ -6,7 +6,14 @@ class OutlierDetection(BaseDefenseMethod): + def __init__(self, config): + """ + Initialize the OutlierDetection method. + + Args: + config (object): Configuration object containing defense parameters. + """ self.cross_round_check = CrossRoundDefense(config) self.three_sigma_check = ThreeSigmaKrumDefense(config) @@ -15,11 +22,30 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): - raw_client_grad_list = self.cross_round_check.defend_before_aggregation(raw_client_grad_list, extra_auxiliary_info) + """ + Perform outlier detection defense before aggregation. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): List of client gradients. + extra_auxiliary_info (Any): Additional information (optional). 
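# Norm-difference clipping on one toy weight vector: the client update becomes
# global + clip(local - global), with the difference rescaled so its L2 norm
# cannot exceed norm_bound. Values are chosen to make the clipping visible.
import torch

norm_bound = 0.1
global_w = torch.tensor([1.0, 1.0, 1.0])
local_w = torch.tensor([2.0, 1.0, 1.0])  # a norm-1.0 jump in one coordinate

diff = local_w - global_w
clipped = diff / max(1.0, torch.norm(diff).item() / norm_bound)
print(global_w + clipped)  # tensor([1.1000, 1.0000, 1.0000])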
+ + Returns: + List[Tuple[float, OrderedDict]]: List of defended client gradients. + """ + raw_client_grad_list = self.cross_round_check.defend_before_aggregation( + raw_client_grad_list, extra_auxiliary_info) if self.cross_round_check.is_attack_existing: - self.three_sigma_check.set_potential_malicious_clients(self.cross_round_check.get_potential_poisoned_clients()) - raw_client_grad_list = self.three_sigma_check.defend_before_aggregation(raw_client_grad_list, extra_auxiliary_info) + self.three_sigma_check.set_potential_malicious_clients( + self.cross_round_check.get_potential_poisoned_clients()) + raw_client_grad_list = self.three_sigma_check.defend_before_aggregation( + raw_client_grad_list, extra_auxiliary_info) return raw_client_grad_list def get_malicious_client_idxs(self): - return self.three_sigma_check.get_malicious_client_idxs() \ No newline at end of file + """ + Get the indices of potential malicious clients. + + Returns: + List[int]: List of indices of potential malicious clients. + """ + return self.three_sigma_check.get_malicious_client_idxs() diff --git a/python/fedml/core/security/defense/residual_based_reweighting_defense.py b/python/fedml/core/security/defense/residual_based_reweighting_defense.py index 32c1c07b14..f71fabb977 100644 --- a/python/fedml/core/security/defense/residual_based_reweighting_defense.py +++ b/python/fedml/core/security/defense/residual_based_reweighting_defense.py @@ -16,6 +16,12 @@ class ResidualBasedReweightingDefense(BaseDefenseMethod): def __init__(self, config): + """ + Initialize the ResidualBasedReweightingDefense method. + + Args: + config (object): Configuration object containing defense parameters. + """ if hasattr(config, "lambda_param"): self.lambda_param = config.lambda_param else: @@ -31,16 +37,36 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Perform defense before aggregation using residual-based reweighting. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): List of client gradients. + extra_auxiliary_info (Any): Additional information (optional). + + Returns: + List[Tuple[float, OrderedDict]]: List of defended client gradients. + """ return self.IRLS_other_split_restricted(raw_client_grad_list) def IRLS_other_split_restricted(self, raw_client_grad_list): + """ + Perform the Iteratively Reweighted Least Squares (IRLS) defense with restricted mode. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): List of client gradients. + + Returns: + List[Tuple[float, OrderedDict]]: List of defended client gradients. + """ reweight_algorithm = median_reweight_algorithm_restricted if self.mode == "median": reweight_algorithm = median_reweight_algorithm_restricted elif self.mode == "theilsen": reweight_algorithm = theilsen_reweight_algorithm_restricted elif self.mode == "gaussian": - reweight_algorithm = gaussian_reweight_algorithm_restricted # in gaussian reweight algorithm, lambda is sigma + # in gaussian reweight algorithm, lambda is sigma + reweight_algorithm = gaussian_reweight_algorithm_restricted SHARD_SIZE = 2000 w = [grad for (_, grad) in raw_client_grad_list] @@ -70,13 +96,15 @@ def IRLS_other_split_restricted(self, raw_client_grad_list): else: num_shards = int(math.ceil(total_num / SHARD_SIZE)) for i in range(num_shards): - y = transposed_y_list[i * SHARD_SIZE : (i + 1) * SHARD_SIZE, ...] + y = transposed_y_list[i * + SHARD_SIZE: (i + 1) * SHARD_SIZE, ...] 
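# The two-stage gate OutlierDetection implements above, reduced to plain
# functions: a cheap cross-round check runs every round and only reports
# whether an attack seems present (and who looks suspicious); the heavier
# three-sigma screen runs only in that case, seeded with the suspects. The
# thresholding logic here is a stand-in, not the real checks.
def cross_round_check(grads):
    suspects = [i for i, g in enumerate(grads) if abs(g) > 5.0]
    return bool(suspects), suspects

def three_sigma_screen(grads, suspects):
    return [g for i, g in enumerate(grads) if i not in suspects]

grads = [1.0, 1.2, 0.9, 9.0]
attack_suspected, suspects = cross_round_check(grads)
if attack_suspected:
    grads = three_sigma_screen(grads, suspects)
print(grads)  # [1.0, 1.2, 0.9]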
reweight, restricted_y = reweight_algorithm( y, self.lambda_param, self.thresh ) print(reweight.sum(dim=0)) reweight_sum += reweight.sum(dim=0) - y_result[i * SHARD_SIZE : (i + 1) * SHARD_SIZE, ...] = restricted_y + y_result[i * SHARD_SIZE: (i + 1) + * SHARD_SIZE, ...] = restricted_y # put restricted y back to w y_result = torch.t(y_result) @@ -89,13 +117,25 @@ def IRLS_other_split_restricted(self, raw_client_grad_list): def median_reweight_algorithm_restricted(y, LAMBDA, thresh): + """ + Perform reweighting using the Median Reweight Algorithm with restricted mode. + + Args: + y (torch.Tensor): Input data. + LAMBDA (float): Lambda parameter. + thresh (float): Threshold value. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple containing reweight values and restricted data. + """ num_models = y.shape[1] total_num = y.shape[0] X_pure = y.sort()[1].sort()[1].type(torch.float) # calculate H matrix X_pure = X_pure.unsqueeze(2) - X = torch.cat((torch.ones(total_num, num_models, 1).to(y.device), X_pure), dim=-1) + X = torch.cat( + (torch.ones(total_num, num_models, 1).to(y.device), X_pure), dim=-1) X_X = torch.matmul(X.transpose(1, 2), X) X_X = torch.matmul(X, torch.inverse(X_X)) H = torch.matmul(X_X, X.transpose(1, 2)) @@ -121,6 +161,16 @@ def median_reweight_algorithm_restricted(y, LAMBDA, thresh): def median(input): + """ + Calculate the median of the input data. + + Args: + input (torch.Tensor): Input data. + + Returns: + torch.Tensor: Median value. + """ + shape = input.shape input = input.sort()[0] if shape[-1] % 2 != 0: @@ -133,6 +183,17 @@ def median(input): def theilsen_reweight_algorithm_restricted(y, LAMBDA, thresh): + """ + Perform reweighting using the Theil-Sen Reweight Algorithm with restricted mode. + + Args: + y (torch.Tensor): Input data. + LAMBDA (float): Lambda parameter. + thresh (float): Threshold value. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple containing reweight values and restricted data. + """ num_models = y.shape[1] total_num = y.shape[0] slopes, intercepts = theilsen(y) @@ -140,7 +201,8 @@ def theilsen_reweight_algorithm_restricted(y, LAMBDA, thresh): # calculate H matrix X_pure = X_pure.unsqueeze(2) - X = torch.cat((torch.ones(total_num, num_models, 1).to(y.device), X_pure), dim=-1) + X = torch.cat( + (torch.ones(total_num, num_models, 1).to(y.device), X_pure), dim=-1) X_X = torch.matmul(X.transpose(1, 2), X) X_X = torch.matmul(X, torch.inverse(X_X)) H = torch.matmul(X_X, X.transpose(1, 2)) @@ -173,12 +235,24 @@ def theilsen_reweight_algorithm_restricted(y, LAMBDA, thresh): def gaussian_reweight_algorithm_restricted(y, sig, thresh): + """ + Perform reweighting using the Gaussian Reweight Algorithm with restricted mode. + + Args: + y (torch.Tensor): Input data. + sig (float): Sigma parameter for the Gaussian distribution. + thresh (float): Threshold value. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple containing reweight values and restricted data. + """ num_models = y.shape[1] total_num = y.shape[0] slopes, intercepts = repeated_median(y) X_pure = y.sort()[1].sort()[1].type(torch.float) X_pure = X_pure.unsqueeze(2) - X = torch.cat((torch.ones(total_num, num_models, 1).to(y.device), X_pure), dim=-1) + X = torch.cat( + (torch.ones(total_num, num_models, 1).to(y.device), X_pure), dim=-1) beta = torch.cat( ( @@ -205,10 +279,29 @@ def gaussian_reweight_algorithm_restricted(y, sig, thresh): def gaussian_zero_mean(x, sig=1): + """ + Compute the Gaussian reweighting with zero mean. + + Args: + x (torch.Tensor): Input data. 
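# Both branches of the median() helper above, checked on toy tensors: an
# odd-length input returns the middle element, an even-length input averages
# the two middle ones.
import torch

def median_sketch(x):
    x = x.sort()[0]
    n = x.shape[-1]
    if n % 2 != 0:
        return x[..., n // 2]
    return (x[..., n // 2 - 1] + x[..., n // 2]) / 2.0

print(median_sketch(torch.tensor([3.0, 1.0, 2.0])))       # tensor(2.)
print(median_sketch(torch.tensor([4.0, 1.0, 2.0, 3.0])))  # tensor(2.5000)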
+ sig (float, optional): Sigma parameter for the Gaussian distribution. Default is 1. + + Returns: + torch.Tensor: Reweighted data. + """ return torch.exp(-x * x / (2 * sig * sig)) def repeated_median(y): + """ + Compute the repeated median and intercepts for the Theil-Sen Reweight Algorithm. + + Args: + y (torch.Tensor): Input data. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple containing slopes and intercepts. + """ num_models = y.shape[1] total_num = y.shape[0] y = y.sort()[0] @@ -238,6 +331,15 @@ def repeated_median(y): def theilsen(y): + """ + Compute the Theil-Sen estimator for slopes and intercepts. + + Args: + y (torch.Tensor): Input data. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple containing slopes and intercepts. + """ num_models = y.shape[1] total_num = y.shape[0] y = y.sort()[0] diff --git a/python/fedml/core/security/defense/robust_learning_rate_defense.py b/python/fedml/core/security/defense/robust_learning_rate_defense.py index ccbb1c24df..bacacfcfa3 100644 --- a/python/fedml/core/security/defense/robust_learning_rate_defense.py +++ b/python/fedml/core/security/defense/robust_learning_rate_defense.py @@ -24,22 +24,66 @@ class RobustLearningRateDefense(BaseDefenseMethod): + """ + Robust Learning Rate Defense. + + This defense method adjusts the learning rates of clients based on the robust threshold. + + Args: + config: Configuration parameters. + + Attributes: + robust_threshold (int): The robust threshold used for learning rate adjustment. + server_learning_rate (int): The server's learning rate. + + Methods: + run( + raw_client_grad_list: List[Tuple[float, OrderedDict]], + base_aggregation_func: Callable = None, + extra_auxiliary_info: Any = None, + ) -> OrderedDict: + Adjust the learning rates of clients based on the robust threshold. + + """ + def __init__(self, config): + """ + Initialize the RobustLearningRateDefense. + + Args: + config: Configuration parameters. + """ self.robust_threshold = config.robust_threshold # e.g., robust threshold = 4 self.server_learning_rate = 1 def run( - self, - raw_client_grad_list: List[Tuple[float, OrderedDict]], - base_aggregation_func: Callable = None, - extra_auxiliary_info: Any = None, - ): + self, + raw_client_grad_list: List[Tuple[float, OrderedDict]], + base_aggregation_func: Callable = None, + extra_auxiliary_info: Any = None, + ) -> OrderedDict: + """ + Adjust the learning rates of clients based on the robust threshold. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + base_aggregation_func (Callable, optional): + Base aggregation function (currently unused). + extra_auxiliary_info (Any, optional): + Extra auxiliary information (currently unused). + + Returns: + OrderedDict: + Aggregated parameters after adjusting learning rates based on the robust threshold. 
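# The residual-to-weight mapping behind the Gaussian variant: a zero-mean
# Gaussian kernel exp(-r^2 / (2 * sig^2)) keeps the weight near 1 for small
# residuals and suppresses large ones smoothly. sig = 1 matches the helper's
# default; the residuals are toy values.
import torch

residuals = torch.tensor([0.0, 0.5, 1.0, 3.0])
sig = 1.0
print(torch.exp(-residuals ** 2 / (2 * sig ** 2)))
# tensor([1.0000, 0.8825, 0.6065, 0.0111])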
+ """ if self.robust_threshold == 0: return base_aggregation_func(raw_client_grad_list) # avg_params total_sample_num = get_total_sample_num(raw_client_grad_list) (num0, avg_params) = raw_client_grad_list[0] for k in avg_params.keys(): - client_update_sign = [] # self._compute_robust_learning_rates(model_list) + # self._compute_robust_learning_rates(model_list) + client_update_sign = [] for i in range(0, len(raw_client_grad_list)): local_sample_number, local_model_params = raw_client_grad_list[i] client_update_sign.append(torch.sign(local_model_params[k])) @@ -53,7 +97,19 @@ def run( return avg_params def _compute_robust_learning_rates(self, client_update_sign): + """ + Compute robust learning rates based on the client update signs. + + Args: + client_update_sign (list of torch.Tensor): + List of tensors containing the sign of client updates. + + Returns: + torch.Tensor: + Adjusted learning rates for clients. + """ client_lr = torch.abs(sum(client_update_sign)) - client_lr[client_lr < self.robust_threshold] = -self.server_learning_rate + client_lr[client_lr < self.robust_threshold] = - \ + self.server_learning_rate client_lr[client_lr >= self.robust_threshold] = self.server_learning_rate return client_lr diff --git a/python/fedml/core/security/defense/slsgd_defense.py b/python/fedml/core/security/defense/slsgd_defense.py index c39b60da7c..ac8d848da9 100644 --- a/python/fedml/core/security/defense/slsgd_defense.py +++ b/python/fedml/core/security/defense/slsgd_defense.py @@ -27,7 +27,42 @@ class SLSGDDefense(BaseDefenseMethod): + """ + Stochastic Leader Selection for SGD Defense. + + This defense method performs leader selection and aggregation for federated learning. + + Args: + config: Configuration parameters. + + Attributes: + b (int): Parameter of trimmed mean. + alpha (float): Weighting factor for aggregation. + option_type (int): Type of option. + config: Configuration parameters. + + Methods: + defend_before_aggregation( + raw_client_grad_list: List[Tuple[float, OrderedDict]], + extra_auxiliary_info: Any = None, + ) -> List[Tuple[float, OrderedDict]]: + Perform preprocessing and leader selection on client gradients before aggregation. + + defend_on_aggregation( + raw_client_grad_list: List[Tuple[float, OrderedDict]], + base_aggregation_func: Callable = None, + extra_auxiliary_info: Any = None, + ) -> OrderedDict: + Perform aggregation with leader selection based on the given configuration. + + """ def __init__(self, config): + """ + Initialize the SLSGDDefense. + + Args: + config: Configuration parameters. + """ self.b = config.trim_param_b # parameter of trimmed mean if config.alpha > 1 or config.alpha < 0: raise ValueError("the bound of alpha is [0, 1]") @@ -40,6 +75,19 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Perform preprocessing and leader selection on client gradients before aggregation. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + extra_auxiliary_info (Any, optional): + Extra auxiliary information (currently unused). + + Returns: + List[Tuple[float, OrderedDict]]: + Processed and selected client gradients. 
+ """ if self.b > math.ceil(len(raw_client_grad_list) / 2) - 1 or self.b < 0: raise ValueError( "the bound of b is [0, {}])".format( @@ -60,6 +108,21 @@ def defend_on_aggregation( base_aggregation_func: Callable = None, extra_auxiliary_info: Any = None, ): + """ + Perform aggregation with leader selection based on the given configuration. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + base_aggregation_func (Callable, optional): + Base aggregation function (currently unused). + extra_auxiliary_info (Any, optional): + Extra auxiliary information (currently unused). + + Returns: + OrderedDict: + Aggregated parameters after leader selection and aggregation. + """ global_model = extra_auxiliary_info avg_params = base_aggregation_func(args=self.config, raw_grad_list=raw_client_grad_list) for k in avg_params.keys(): diff --git a/python/fedml/core/security/defense/soteria_defense.py b/python/fedml/core/security/defense/soteria_defense.py index a85203eade..e9f372737e 100644 --- a/python/fedml/core/security/defense/soteria_defense.py +++ b/python/fedml/core/security/defense/soteria_defense.py @@ -26,6 +26,29 @@ class SoteriaDefense(BaseDefenseMethod): + """ + Soteria Defense for Federated Learning. + + This defense method performs a Soteria-based defense for federated learning. + + Args: + num_class (int): Number of classes in the dataset. + model: The federated learning model. + defense_data: Defense data for the model. + defense_label (int): Defense label. + + Methods: + run( + raw_client_grad_list: List[Tuple[float, OrderedDict]], + base_aggregation_func: Callable = None, + extra_auxiliary_info: Any = None, + ) -> Dict: + Perform Soteria-based defense on the federated learning model. + + label_to_onehot(target, num_classes=100) -> torch.Tensor: + Convert labels to one-hot encoding. + + """ def __init__( self, num_class, @@ -33,6 +56,15 @@ def __init__( defense_data, defense_label=84, ): + """ + Initialize the SoteriaDefense. + + Args: + num_class (int): Number of classes in the dataset. + model: The federated learning model. + defense_data: Defense data for the model. + defense_label (int): Defense label. + """ self.num_class = num_class # number of classess of the dataset self.model = model self.defense_data = defense_data @@ -46,6 +78,21 @@ def run( base_aggregation_func: Callable = None, extra_auxiliary_info: Any = None, ) -> Dict: + """ + Perform Soteria-based defense on the federated learning model. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + base_aggregation_func (Callable, optional): + Base aggregation function (currently unused). + extra_auxiliary_info (Any, optional): + Extra auxiliary information (currently unused). + + Returns: + Dict: + Aggregation result after Soteria-based defense. 
+ """ # load local model self.model.load_state_dict(raw_client_grad_list, strict=True) original_dy_dx = extra_auxiliary_info # refs for local gradient diff --git a/python/fedml/core/security/defense/three_sigma_defense.py b/python/fedml/core/security/defense/three_sigma_defense.py index ebdb56ce5b..efe7719103 100644 --- a/python/fedml/core/security/defense/three_sigma_defense.py +++ b/python/fedml/core/security/defense/three_sigma_defense.py @@ -6,7 +6,7 @@ from ..common import utils from scipy import spatial -### Original paper: https://arxiv.org/pdf/2107.05252.pdf +# Original paper: https://arxiv.org/pdf/2107.05252.pdf # training: In each iteration, each client k splits its local dataset into batches of size B, # and runs for E local epochs batched-gradient descent through the local dataset # to obtain local model, and sends it to the server. @@ -41,7 +41,39 @@ class ThreeSigmaDefense(BaseDefenseMethod): + """ + Three-Sigma Defense for Federated Learning. + + This defense method performs a Three-Sigma-based defense for federated learning. + + Args: + config: Configuration object for defense parameters. + + Methods: + defend_before_aggregation( + raw_client_grad_list: List[Tuple[float, OrderedDict]], + extra_auxiliary_info: Any = None, + ) -> List[Tuple[float, OrderedDict]]: + Perform defense before aggregation. + + compute_gaussian_distribution() -> Tuple[float, float]: + Compute the Gaussian distribution parameters. + + compute_client_scores(raw_client_grad_list) -> List[float]: + Compute client scores. + + fools_gold_score(feature_vec_list) -> List[float]: + Compute Fool's Gold scores. + + """ + def __init__(self, config): + """ + Initialize the ThreeSigmaDefense. + + Args: + config: Configuration object for defense parameters. + """ self.memory = None self.iteration_num = 1 self.score_list = [] @@ -74,6 +106,18 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Perform defense before aggregation. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + extra_auxiliary_info (Any, optional): + Extra auxiliary information (currently unused). + + Returns: + List[Tuple[float, OrderedDict]]: Batched gradient list after defense. + """ # grad_list = [grad for (_, grad) in raw_client_grad_list] client_scores = self.compute_client_scores(raw_client_grad_list) if self.iteration_num < self.pretraining_round_number: @@ -96,8 +140,6 @@ def defend_before_aggregation( raw_client_grad_list.pop(i) print(f"pop -- i = {i}") - - batch_grad_list = Bucket.bucketization( raw_client_grad_list, self.bucketing_batch_size ) @@ -120,6 +162,12 @@ def defend_before_aggregation( # return avg_params def compute_gaussian_distribution(self): + """ + Compute the Gaussian distribution parameters. + + Returns: + Tuple[float, float]: Mean (mu) and standard deviation (sigma). + """ n = len(self.score_list) mu = sum(list(self.score_list)) / n temp = 0 @@ -131,8 +179,18 @@ def compute_gaussian_distribution(self): return mu, sigma def compute_client_scores(self, raw_client_grad_list): + """ + Compute client scores. + + Args: + raw_client_grad_list: List of tuples containing client gradients as OrderedDict. + + Returns: + List[float]: List of client scores. 
+ """ if self.score_function == "foolsgold": - importance_feature_list = self._get_importance_feature(raw_client_grad_list) + importance_feature_list = self._get_importance_feature( + raw_client_grad_list) if self.memory is None: self.memory = importance_feature_list else: # memory: potential bugs: grads in different iterations may be from different clients @@ -141,6 +199,15 @@ def compute_client_scores(self, raw_client_grad_list): return self.fools_gold_score(self.memory) def _get_importance_feature(self, raw_client_grad_list): + """ + Get importance features for Fool's Gold score computation. + + Args: + raw_client_grad_list: List of tuples containing client gradients as OrderedDict. + + Returns: + List[float]: List of importance features. + """ # print(f"raw_client_grad_list = {raw_client_grad_list}") # Foolsgold uses the last layer's gradient/weights as the importance feature. ret_feature_vector_list = [] @@ -162,6 +229,15 @@ def _get_importance_feature(self, raw_client_grad_list): @staticmethod def fools_gold_score(feature_vec_list): + """ + Compute Fool's Gold scores. + + Args: + feature_vec_list: List of importance features. + + Returns: + List[float]: List of Fool's Gold scores. + """ n_clients = len(feature_vec_list) cs = np.zeros((n_clients, n_clients)) for i in range(n_clients): @@ -183,7 +259,7 @@ def fools_gold_score(feature_vec_list): alpha[alpha <= 0.0] = 1e-15 # Rescale so that max value is alpha - # print(np.max(alpha)) + # print(np.max(alpha)) alpha = alpha / np.max(alpha) alpha[(alpha == 1.0)] = 0.999999 diff --git a/python/fedml/core/security/defense/three_sigma_geomedian_defense.py b/python/fedml/core/security/defense/three_sigma_geomedian_defense.py index 73d4ac9a05..9c78d19646 100644 --- a/python/fedml/core/security/defense/three_sigma_geomedian_defense.py +++ b/python/fedml/core/security/defense/three_sigma_geomedian_defense.py @@ -9,7 +9,42 @@ class ThreeSigmaGeoMedianDefense(BaseDefenseMethod): + """ + Three-Sigma Defense with Geometric Median for Federated Learning. + + This defense method performs a Three-Sigma-based defense with geometric median for federated learning. + + Args: + config: Configuration object for defense parameters. + + Methods: + defend_before_aggregation( + raw_client_grad_list: List[Tuple[float, OrderedDict]], + extra_auxiliary_info: Any = None, + ) -> List[Tuple[float, OrderedDict]]: + Perform defense before aggregation. + + compute_gaussian_distribution() -> Tuple[float, float]: + Compute the Gaussian distribution parameters. + + compute_client_scores(raw_client_grad_list) -> List[float]: + Compute client scores. + + fools_gold_score(feature_vec_list) -> List[float]: + Compute Fool's Gold scores. + + l2_scores(importance_feature_list) -> List[float]: + Compute L2 scores. + + """ + def __init__(self, config): + """ + Initialize the ThreeSigmaGeoMedianDefense. + + Args: + config: Configuration object for defense parameters. + """ self.memory = None self.iteration_num = 1 self.score_list = [] @@ -39,6 +74,18 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Perform defense before aggregation. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): + List of tuples containing client gradients as OrderedDict. + extra_auxiliary_info (Any, optional): + Extra auxiliary information (currently unused). + + Returns: + List[Tuple[float, OrderedDict]]: Gradient list after defense. 
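+
+        Example (conceptual sketch of the three-sigma rule applied here;
+        the numbers are made up)::
+
+            mu, sigma = 1.0, 0.1    # from compute_gaussian_distribution()
+            bound = mu + 3 * sigma  # clients scoring above the bound
+                                    # are dropped before aggregation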
+ """ # grad_list = [grad for (_, grad) in raw_client_grad_list] client_scores = self.compute_client_scores(raw_client_grad_list) print(f"client scores = {client_scores}") @@ -64,6 +111,13 @@ def defend_before_aggregation( return raw_client_grad_list def compute_gaussian_distribution(self): + """ + Compute the Gaussian distribution parameters. + + Returns: + Tuple[float, float]: Mean (mu) and standard deviation (sigma). + """ + n = len(self.score_list) mu = sum(list(self.score_list)) / n temp = 0 @@ -75,7 +129,17 @@ def compute_gaussian_distribution(self): return mu, sigma def compute_client_scores(self, raw_client_grad_list): - importance_feature_list = self._get_importance_feature(raw_client_grad_list) + """ + Compute client scores. + + Args: + raw_client_grad_list: List of tuples containing client gradients as OrderedDict. + + Returns: + List[float]: List of client scores. + """ + importance_feature_list = self._get_importance_feature( + raw_client_grad_list) if self.score_function == "foolsgold": if self.memory is None: self.memory = importance_feature_list @@ -88,19 +152,39 @@ def compute_client_scores(self, raw_client_grad_list): # (num0, avg_params) = raw_client_grad_list[0] # alphas = {alpha for (alpha, params) in raw_client_grad_list} # alphas = {alpha / sum(alphas, 0.0) for alpha in alphas} - alphas = [1/len(raw_client_grad_list)] * len(raw_client_grad_list) - self.geo_median = compute_geometric_median(alphas, importance_feature_list) + alphas = [1/len(raw_client_grad_list)] * \ + len(raw_client_grad_list) + self.geo_median = compute_geometric_median( + alphas, importance_feature_list) return self.l2_scores(importance_feature_list) def l2_scores(self, importance_feature_list): + """ + Compute L2 scores. + + Args: + importance_feature_list: List of importance features. + + Returns: + List[float]: List of L2 scores. + """ scores = [] for feature in importance_feature_list: - score = compute_euclidean_distance(torch.Tensor(feature), self.geo_median) + score = compute_euclidean_distance( + torch.Tensor(feature), self.geo_median) scores.append(score) return scores - def _get_importance_feature(self, raw_client_grad_list): + """ + Get importance features for score computation. + + Args: + raw_client_grad_list: List of tuples containing client gradients as OrderedDict. + + Returns: + List[float]: List of importance features. + """ # print(f"raw_client_grad_list = {raw_client_grad_list}") # Foolsgold uses the last layer's gradient/weights as the importance feature. ret_feature_vector_list = [] @@ -122,6 +206,15 @@ def _get_importance_feature(self, raw_client_grad_list): @staticmethod def fools_gold_score(feature_vec_list): + """ + Compute Fool's Gold scores. + + Args: + feature_vec_list: List of importance features. + + Returns: + List[float]: List of Fool's Gold scores. 
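+
+        Example (toy feature vectors; the two identical vectors mimic
+        colluding clients and receive near-zero weight)::
+
+            scores = ThreeSigmaGeoMedianDefense.fools_gold_score(
+                [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])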
+        """
         n_clients = len(feature_vec_list)
         cs = np.zeros((n_clients, n_clients))
         for i in range(n_clients):
@@ -143,7 +236,7 @@ def fools_gold_score(feature_vec_list):
         alpha[alpha <= 0.0] = 1e-15
 
         # Rescale so that max value is alpha
-        # print(np.max(alpha)) 
+        # print(np.max(alpha))
         alpha = alpha / np.max(alpha)
         alpha[(alpha == 1.0)] = 0.999999
@@ -154,4 +247,4 @@ def fools_gold_score(feature_vec_list):
 
         print("alpha = {}".format(alpha))
 
-        return alpha
\ No newline at end of file
+        return alpha
diff --git a/python/fedml/core/security/defense/three_sigma_krum_defense.py b/python/fedml/core/security/defense/three_sigma_krum_defense.py
index 565aa0f962..8d476ad45e 100644
--- a/python/fedml/core/security/defense/three_sigma_krum_defense.py
+++ b/python/fedml/core/security/defense/three_sigma_krum_defense.py
@@ -14,7 +14,52 @@
 class ThreeSigmaKrumDefense(BaseDefenseMethod):
+    """
+    Three-Sigma Defense with Krum-based Malicious Client Detection for Federated Learning.
+
+    Client updates are scored by their L2 distance to a Krum-filtered average
+    feature; updates whose scores fall outside the three-sigma bound are
+    discarded before aggregation.
+
+    Args:
+        config: Configuration object for defense parameters.
+
+    Methods:
+        defend_before_aggregation(
+            raw_client_grad_list: List[Tuple[float, OrderedDict]],
+            extra_auxiliary_info: Any = None,
+        ) -> List[Tuple[float, OrderedDict]]:
+            Perform defense before aggregation.
+
+        kick_out_poisoned_local_models(
+            client_scores: List[float],
+            raw_client_grad_list: List[Tuple[float, OrderedDict]]
+        ) -> Tuple[List[Tuple[float, OrderedDict]], List[float]]:
+            Remove poisoned local models based on client scores.
+
+        get_malicious_client_idxs() -> List[int]:
+            Get indices of detected malicious clients.
+
+        set_potential_malicious_clients(potential_malicious_client_idxs: List[int]):
+            Set potential malicious client indices.
+
+        compute_avg_with_krum(raw_client_grad_list: List[Tuple[float, OrderedDict]]) -> List[float]:
+            Compute an average feature with Krum-based malicious client detection.
+
+        compute_l2_scores(raw_client_grad_list: List[Tuple[float, OrderedDict]]) -> List[float]:
+            Compute L2 scores for client models.
+
+        compute_client_cosine_scores(raw_client_grad_list: List[Tuple[float, OrderedDict]]) -> List[float]:
+            Compute cosine similarity scores between client models.
+
+        _get_importance_feature(raw_client_grad_list: List[Tuple[float, OrderedDict]]) -> List[float]:
+            Get importance features from raw client gradients.
+    """
     def __init__(self, config):
+        """
+        Initialize the ThreeSigmaKrumDefense.
+
+        Args:
+            config: Configuration object for defense parameters.
+        """
         self.average = None
         self.upper_bound = 0
         self.malicious_client_idxs = []
@@ -31,6 +76,18 @@ def defend_before_aggregation(
         raw_client_grad_list: List[Tuple[float, OrderedDict]],
         extra_auxiliary_info: Any = None,
     ):
+        """
+        Perform defense before aggregation.
+
+        Args:
+            raw_client_grad_list (List[Tuple[float, OrderedDict]]):
+                List of tuples containing client gradients as OrderedDict.
+            extra_auxiliary_info (Any, optional):
+                Extra auxiliary information (currently unused).
+
+        Returns:
+            List[Tuple[float, OrderedDict]]: Gradient list after defense.
+        """
         if self.average is None:
             self.average = self.compute_avg_with_krum(raw_client_grad_list)
         client_scores = self.compute_l2_scores(raw_client_grad_list)
@@ -46,6 +103,19 @@ def defend_before_aggregation(
         return new_client_models
 
     def compute_an_average_feature(self, importance_feature_list):
+        """
+        Compute the uniformly weighted middle point of the importance features.
+
+        Args:
+            importance_feature_list: List of importance feature vectors.
+
+        Returns:
+            The middle point of the given importance features.
+        """
         alphas = [1 / len(importance_feature_list)] * len(importance_feature_list)
         return compute_middle_point(alphas, importance_feature_list)
 
@@ -99,6 +169,13 @@ def compute_an_average_feature(self, importance_feature_list):
 #     return raw_client_grad_list
 
     def kick_out_poisoned_local_models(self, client_scores, raw_client_grad_list):
+        """
+        Remove poisoned local models based on client scores.
+
+        Args:
+            client_scores (List[float]): List of client scores.
+            raw_client_grad_list (List[Tuple[float, OrderedDict]]):
+                List of tuples containing client gradients as OrderedDict.
+
+        Returns:
+            Tuple[List[Tuple[float, OrderedDict]], List[float]]:
+                Tuple containing the gradient list after removing poisoned
+                models and the updated client scores.
+        """
+
         print(f"upper bound = {self.upper_bound}")
         # traverse the score list in a reversed order
         self.malicious_client_idxs = []
@@ -112,12 +189,38 @@ def kick_out_poisoned_local_models(self, client_scores, raw_client_grad_list):
         return raw_client_grad_list, client_scores
 
     def get_malicious_client_idxs(self):
+        """
+        Get indices of detected malicious clients.
+
+        Returns:
+            List[int]: List of indices of malicious clients.
+        """
         return self.malicious_client_idxs
 
     def set_potential_malicious_clients(self, potential_malicious_client_idxs):
+        """
+        Set potential malicious client indices.
+
+        Args:
+            potential_malicious_client_idxs: List of potential malicious client indices.
+        """
         self.potential_malicious_client_idxs = None  # potential_malicious_client_idxs todo
 
     def compute_avg_with_krum(self, raw_client_grad_list):
+        """
+        Compute an average feature with Krum-based malicious client detection.
+
+        Args:
+            raw_client_grad_list (List[Tuple[float, OrderedDict]]):
+                List of tuples containing client gradients as OrderedDict.
+
+        Returns:
+            List[float]: List representing an average feature.
+        """
         importance_feature_list = self._get_importance_feature(raw_client_grad_list)
         krum_scores = compute_krum_score(
             importance_feature_list,
@@ -133,6 +236,16 @@ def compute_avg_with_krum(self, raw_client_grad_list):
         return self.compute_an_average_feature(honest_importance_feature_list)
 
     def compute_l2_scores(self, raw_client_grad_list):
+        """
+        Compute L2 scores for client models.
+
+        Args:
+            raw_client_grad_list (List[Tuple[float, OrderedDict]]):
+                List of tuples containing client gradients as OrderedDict.
+
+        Returns:
+            List[float]: List of L2 scores.
+        """
         importance_feature_list = self._get_importance_feature(raw_client_grad_list)
         scores = []
         for feature in importance_feature_list:
@@ -141,6 +254,16 @@ def compute_l2_scores(self, raw_client_grad_list):
         return scores
 
     def compute_client_cosine_scores(self, raw_client_grad_list):
+        """
+        Compute cosine similarity scores between client models.
+
+        Args:
+            raw_client_grad_list (List[Tuple[float, OrderedDict]]):
+                List of tuples containing client gradients as OrderedDict.
+
+        Returns:
+            List[float]: List of cosine similarity scores.
+        """
         importance_feature_list = self._get_importance_feature(raw_client_grad_list)
         cosine_scores = []
         num_client = len(importance_feature_list)
@@ -158,6 +281,16 @@ def compute_client_cosine_scores(self, raw_client_grad_list):
         return cosine_scores
 
     def _get_importance_feature(self, raw_client_grad_list):
+        """
+        Get importance features from raw client gradients.
+
+        Args:
+            raw_client_grad_list (List[Tuple[float, OrderedDict]]):
+                List of tuples containing client gradients as OrderedDict.
+
+        Returns:
+            List[float]: List of importance feature vectors.
+        """
         # print(f"raw_client_grad_list = {raw_client_grad_list}")
         # Foolsgold uses the last layer's gradient/weights as the importance feature.
         ret_feature_vector_list = []
diff --git a/python/fedml/core/security/defense/wbc_defense.py b/python/fedml/core/security/defense/wbc_defense.py
index e65dc4d597..e804a1f8ba 100644
--- a/python/fedml/core/security/defense/wbc_defense.py
+++ b/python/fedml/core/security/defense/wbc_defense.py
@@ -23,7 +23,36 @@
 class WbcDefense(BaseDefenseMethod):
+    """
+    Weight-Based Client Defense for Federated Learning.
+
+    The designated client perturbs its model weights using the change in its
+    gradients between batches before the models are aggregated.
+
+    Args:
+        args: Argument object containing client and batch indices.
+
+    Methods:
+        run(
+            raw_client_grad_list: List[Tuple[float, OrderedDict]],
+            base_aggregation_func: Callable = None,
+            extra_auxiliary_info: Any = None,
+        ) -> Dict:
+            Run the weight-based client defense.
+
+    Attributes:
+        args: Argument object containing client and batch indices.
+        client_idx: Index of the client.
+        batch_idx: Index of the batch.
+        old_gradient: Dictionary storing old gradients for weight perturbation.
+    """
+
     def __init__(self, args):
+        """
+        Initialize the WbcDefense.
+
+        Args:
+            args: Argument object containing client and batch indices.
+        """
         self.args = args
         self.client_idx = args.client_idx
         self.batch_idx = args.batch_idx
@@ -35,6 +64,20 @@ def run(
         base_aggregation_func: Callable = None,
         extra_auxiliary_info: Any = None,
     ) -> Dict:
+        """
+        Run the weight-based client defense.
+
+        Args:
+            raw_client_grad_list (List[Tuple[float, OrderedDict]]):
+                List of tuples containing client gradients as OrderedDict.
+            base_aggregation_func (Callable, optional):
+                Base aggregation function used to aggregate the
+                (possibly perturbed) client models.
+            extra_auxiliary_info (Any, optional):
+                Extra auxiliary information (currently unused).
+
+        Returns:
+            Dict: Dictionary containing aggregated model parameters.
+        """
         num_client = len(raw_client_grad_list)
         vec_local_w = [
             (
@@ -53,7 +96,8 @@ def run(
         for (k, v) in model_param.items():
             if "weight" in k:
                 grad_tensor = (
-                    raw_client_grad_list[self.client_idx][1][k].cpu().numpy()
+                    raw_client_grad_list[self.client_idx][1][k].cpu(
+                    ).numpy()
                 )
                 # for testing, simply pre-defin old gradient
                 self.old_gradient[k] = grad_tensor * 0.2
@@ -67,7 +111,8 @@ def run(
                 )
                 learning_rate = 0.1
                 new_model_param[k] = torch.from_numpy(
-                    model_param[k].cpu().numpy() + pertubation * learning_rate
+                    model_param[k].cpu().numpy() +
+                    pertubation * learning_rate
                 )
             else:
                 new_model_param[k] = model_param[k]
@@ -82,7 +127,8 @@ def run(
         if i != self.client_idx or self.batch_idx == 0:
             param_list.append(models_param[i])
         else:
-            param_list.append((models_param[self.client_idx][0], new_model_param))
+            param_list.append(
+                (models_param[self.client_idx][0], new_model_param))
         logging.info(f"New. param: {param_list[i]}")
 
     return base_aggregation_func(self.args, param_list)  # avg_params
diff --git a/python/fedml/core/security/defense/weak_dp_defense.py b/python/fedml/core/security/defense/weak_dp_defense.py
index 06e465c3f4..25c3372c5e 100644
--- a/python/fedml/core/security/defense/weak_dp_defense.py
+++ b/python/fedml/core/security/defense/weak_dp_defense.py
@@ -9,6 +9,27 @@
 class WeakDPDefense(BaseDefenseMethod):
+    """
+    Weak Differential Privacy (DP) Defense for Federated Learning.
+
+    This defense adds weak differential privacy noise to client gradients to enhance privacy.
+
+    Args:
+        config: Configuration object containing defense parameters.
+
+    Methods:
+        run(
+            raw_client_grad_list: List[Tuple[float, OrderedDict]],
+            base_aggregation_func: Callable = None,
+            extra_auxiliary_info: Any = None,
+        ) -> Dict:
+            Run the weak DP defense.
+
+    Attributes:
+        config: Configuration object containing defense parameters.
+        stddev: Standard deviation of the noise added to gradients.
+    """
+
     def __init__(self, config):
         self.config = config
         self.stddev = config.stddev  # for weak DP defenses
@@ -19,6 +40,20 @@ def run(
         base_aggregation_func: Callable = None,
         extra_auxiliary_info: Any = None,
     ) -> Dict:
+        """
+        Run the weak DP defense by adding noise to client gradients.
+
+        Args:
+            raw_client_grad_list (List[Tuple[float, OrderedDict]]):
+                List of tuples containing client gradients as OrderedDict.
+            base_aggregation_func (Callable, optional):
+                Base aggregation function applied to the noised gradients.
+            extra_auxiliary_info (Any, optional):
+                Extra auxiliary information (currently unused).
+
+        Returns:
+            Dict: Dictionary containing aggregated model parameters with added noise.
+        """
         new_grad_list = []
         for (sample_num, local_w) in raw_client_grad_list:
             new_w = self._add_noise(local_w)
@@ -26,6 +61,15 @@ def run(
         return base_aggregation_func(self.config, new_grad_list)  # avg_params
 
     def _add_noise(self, param):
+        """
+        Add Gaussian noise to the parameters.
+
+        Args:
+            param (OrderedDict): Client parameters.
+
+        Returns:
+            OrderedDict: Parameters with added noise.
+        """
         dp_param = dict()
         for k in param.keys():
             dp_param[k] = param[k] + torch.randn(param[k].size()) * self.stddev
diff --git a/python/fedml/core/security/fedml_attacker.py b/python/fedml/core/security/fedml_attacker.py
index 6264ccfb1c..34739231bc 100644
--- a/python/fedml/core/security/fedml_attacker.py
+++ b/python/fedml/core/security/fedml_attacker.py
@@ -11,6 +11,18 @@
 class FedMLAttacker:
+    """
+    Represents an attacker in a federated learning system.
+
+    The `FedMLAttacker` class manages different types of attacks, including
+    model poisoning, data poisoning, and data reconstruction attacks, within
+    a federated learning setting.
+
+    Attributes:
+        _attacker_instance (FedMLAttacker): A singleton instance of the `FedMLAttacker` class.
+        is_enabled (bool): Whether the attacker is enabled.
+        attack_type (str): The type of attack being used.
+        attacker (Any): The specific attacker object.
+    """
    _attacker_instance = None

    @staticmethod
@@ -21,11 +33,31 @@ def get_instance():
         return FedMLAttacker._attacker_instance
 
     def __init__(self):
+        """
+        Initialize a FedMLAttacker instance with attacks disabled.
+
+        Attributes:
+            is_enabled (bool): Whether the attacker is enabled.
+            attack_type (str): The type of attack being used.
+            attacker (Any): The specific attacker object.
+        """
         self.is_enabled = False
         self.attack_type = None
         self.attacker = None
 
     def init(self, args):
+        """
+        Initialize the attacker based on the provided arguments.
+
+        Args:
+            args: The arguments used to configure the attacker.
+        """
         if hasattr(args, "enable_attack") and args.enable_attack:
             logging.info("------init attack..." + args.attack_type.strip())
             self.is_enabled = True
@@ -56,13 +88,35 @@ def init(self, args):
             self.is_enabled = False
 
     def is_attack_enabled(self):
+        """
+        Check if the attacker is enabled.
+ + Returns: + bool: True if the attacker is enabled, False otherwise. + + """ return self.is_enabled def get_attack_types(self): + """ + Get the type of attack. + + Returns: + str: The type of attack being used. + + """ return self.attack_type # --------------- for model poisoning attacks --------------- # def is_model_attack(self): + """ + Check if the attack is a model poisoning attack. + + Returns: + bool: True if it's a model poisoning attack, False otherwise. + + """ + if self.is_attack_enabled() and self.attack_type in [ ATTACK_METHOD_BYZANTINE_ATTACK, BACKDOOR_ATTACK_MODEL_REPLACEMENT ]: @@ -70,6 +124,23 @@ def is_model_attack(self): return False def attack_model(self, raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None): + """ + Attack the model with poisoned gradients. + + This method is used for model poisoning attacks. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples, each containing a weight and a + dictionary of client gradients. + extra_auxiliary_info (Any, optional): Additional auxiliary information for the attack. + + Returns: + Any: The poisoned client gradients. + + Raises: + Exception: If the attacker is not initialized. + + """ if self.attacker is None: raise Exception("attacker is not initialized!") return self.attacker.attack_model(raw_client_grad_list, extra_auxiliary_info) @@ -77,16 +148,48 @@ def attack_model(self, raw_client_grad_list: List[Tuple[float, OrderedDict]], ex # --------------- for data poisoning attacks --------------- # def is_data_poisoning_attack(self): + """ + Check if the attack is a data poisoning attack. + + Returns: + bool: True if it's a data poisoning attack, False otherwise. + + """ if self.is_attack_enabled() and self.attack_type in [ATTACK_LABEL_FLIPPING]: return True return False def is_to_poison_data(self): + """ + Check if data should be poisoned. + + Returns: + bool: True if data should be poisoned, False otherwise. + + Raises: + Exception: If the attacker is not initialized. + + """ if self.attacker is None: raise Exception("attacker is not initialized!") return self.attacker.is_to_poison_data() def poison_data(self, dataset): + """ + Poison the dataset. + + This method is used for data poisoning attacks. + + Args: + dataset: The dataset to be poisoned. + + Returns: + Any: The poisoned dataset. + + Raises: + Exception: If the attacker is not initialized. + + """ if self.attacker is None: raise Exception("attacker is not initialized!") return self.attacker.poison_data(dataset) @@ -94,12 +197,34 @@ def poison_data(self, dataset): # --------------- for data reconstructing attacks --------------- # def is_data_reconstruction_attack(self): + """ + Check if the attack is a data reconstruction attack. + + Returns: + bool: True if it's a data reconstruction attack, False otherwise. + + """ if self.is_attack_enabled() and self.attack_type in [ATTACK_METHOD_DLG]: return True return False def reconstruct_data(self, raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None): + """ + Reconstruct the data from gradients. + + This method is used for data reconstruction attacks. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples, each containing a weight and a + dictionary of client gradients. + extra_auxiliary_info (Any, optional): Additional auxiliary information for the attack. + + Raises: + Exception: If the attacker is not initialized. 
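+
+        Example (sketch; assumes ``args`` enables a data reconstruction
+        attack such as DLG)::
+
+            attacker = FedMLAttacker.get_instance()
+            attacker.init(args)
+            if attacker.is_data_reconstruction_attack():
+                attacker.reconstruct_data(
+                    client_grads, extra_auxiliary_info=global_model)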
+ + """ if self.attacker is None: raise Exception("attacker is not initialized!") - self.attacker.reconstruct_data(raw_client_grad_list, extra_auxiliary_info=extra_auxiliary_info) - # --------------- for data reconstructing attacks --------------- # \ No newline at end of file + self.attacker.reconstruct_data( + raw_client_grad_list, extra_auxiliary_info=extra_auxiliary_info) + # --------------- for data reconstructing attacks --------------- # diff --git a/python/fedml/core/security/fedml_defender.py b/python/fedml/core/security/fedml_defender.py index d88f8dfcd9..56b688317c 100644 --- a/python/fedml/core/security/fedml_defender.py +++ b/python/fedml/core/security/fedml_defender.py @@ -38,21 +38,64 @@ class FedMLDefender: + """ + A class for managing defense mechanisms in federated learning. + + This class handles the configuration and execution of defense mechanisms to enhance the robustness + of federated learning against adversarial attacks. + + Methods: + get_instance: Get an instance of the FedMLDefender class. + init: Initialize the defense mechanism based on configuration. + is_defense_enabled: Check if defense mechanisms are enabled. + defend: Defend against adversarial attacks on client gradients. + is_defense_on_aggregation: Check if defense occurs during aggregation. + is_defense_before_aggregation: Check if defense occurs before aggregation. + is_defense_after_aggregation: Check if defense occurs after aggregation. + defend_before_aggregation: Apply defense before gradient aggregation. + defend_on_aggregation: Apply defense during gradient aggregation. + defend_after_aggregation: Apply defense after gradient aggregation. + get_malicious_client_idxs: Get the indices of malicious clients. + get_benign_client_idxs: Get the indices of benign clients. + + Attributes: + None + """ + _defender_instance = None @staticmethod def get_instance(): + """ + Get an instance of the FedMLDefender class. + + Returns: + FedMLDefender: An instance of the FedMLDefender class. + """ + if FedMLDefender._defender_instance is None: FedMLDefender._defender_instance = FedMLDefender() return FedMLDefender._defender_instance def __init__(self): + """ + Initialize a FedMLDefender instance. + """ self.is_enabled = False self.defense_type = None self.defender = None def init(self, args): + """ + Initialize the defense mechanism based on configuration. + + Args: + args: The command-line arguments. + + Raises: + Exception: If the defense mechanism type is not defined. + """ if hasattr(args, "enable_defense") and args.enable_defense: self.args = args logging.info("------init defense..." + args.defense_type) @@ -114,6 +157,12 @@ def init(self, args): self.is_enabled = False def is_defense_enabled(self): + """ + Check if defense mechanisms are enabled. + + Returns: + bool: True if defense is enabled, False otherwise. + """ return self.is_enabled def defend( @@ -122,6 +171,21 @@ def defend( base_aggregation_func: Callable = None, extra_auxiliary_info: Any = None, ): + """ + Defend against adversarial attacks on client gradients. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples, each containing a weight and a + dictionary of client gradients. + base_aggregation_func (Callable, optional): The base aggregation function for gradient aggregation. + extra_auxiliary_info (Any, optional): Additional auxiliary information for the defense mechanism. + + Returns: + Any: The defended client gradients or the result of the aggregation function. 
+ + Raises: + Exception: If the defender is not initialized. + """ if self.defender is None: raise Exception("defender is not initialized!") return self.defender.run( @@ -129,9 +193,22 @@ def defend( ) def is_defense_on_aggregation(self): + """ + Check if defense occurs during gradient aggregation. + + Returns: + bool: True if defense occurs during aggregation, False otherwise. + """ return self.is_defense_enabled() and self.defense_type in [DEFENSE_SLSGD, DEFENSE_RFA, DEFENSE_WISE_MEDIAN, DEFENSE_GEO_MEDIAN] def is_defense_before_aggregation(self): + """ + Check if defense occurs before gradient aggregation. + + Returns: + bool: True if defense occurs before aggregation, False otherwise. + """ + return self.is_defense_enabled() and self.defense_type in [ DEFENSE_SLSGD, DEFENSE_FOOLSGOLD, @@ -147,6 +224,13 @@ def is_defense_before_aggregation(self): ] def is_defense_after_aggregation(self): + """ + Check if defense occurs after gradient aggregation. + + Returns: + bool: True if defense occurs after aggregation, False otherwise. + """ + return self.is_defense_enabled() and self.defense_type in [DEFENSE_CRFL, DEFENSE_CCLIP] def defend_before_aggregation( @@ -154,6 +238,20 @@ def defend_before_aggregation( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Apply defense before gradient aggregation. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples, each containing a weight and a + dictionary of client gradients. + extra_auxiliary_info (Any, optional): Additional auxiliary information for the defense mechanism. + + Returns: + List[Tuple[float, OrderedDict]]: The defended client gradients. + + Raises: + Exception: If the defender is not initialized. + """ if self.defender is None: raise Exception("defender is not initialized!") if self.is_defense_before_aggregation(): @@ -168,6 +266,21 @@ def defend_on_aggregation( base_aggregation_func: Callable = None, extra_auxiliary_info: Any = None, ): + """ + Apply defense during gradient aggregation. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): A list of tuples, each containing a weight and a + dictionary of client gradients. + base_aggregation_func (Callable, optional): The base aggregation function for gradient aggregation. + extra_auxiliary_info (Any, optional): Additional auxiliary information for the defense mechanism. + + Returns: + Any: The defended client gradients or the result of the aggregation function. + + Raises: + Exception: If the defender is not initialized. + """ if self.defender is None: raise Exception("defender is not initialized!") if self.is_defense_on_aggregation(): @@ -177,6 +290,18 @@ def defend_on_aggregation( return base_aggregation_func(args=self.args, raw_grad_list=raw_client_grad_list) def defend_after_aggregation(self, global_model): + """ + Apply defense after gradient aggregation. + + Args: + global_model: The global model after gradient aggregation. + + Returns: + Any: The defended global model or its equivalent. + + Raises: + Exception: If the defender is not initialized. + """ if self.defender is None: raise Exception("defender is not initialized!") if self.is_defense_after_aggregation(): @@ -184,7 +309,26 @@ def defend_after_aggregation(self, global_model): return global_model def get_malicious_client_idxs(self): + """ + Get the indices of malicious clients. + + Returns: + List[int]: A list of indices corresponding to malicious clients. 
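+
+        Example (sketch; assumes a defense that tracks malicious clients,
+        such as one of the three-sigma defenses, has already run)::
+
+            defender = FedMLDefender.get_instance()
+            bad_idxs = defender.get_malicious_client_idxs()
+            good_idxs = defender.get_benign_client_idxs(list(range(10)))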
+ """ + return self.defender.get_malicious_client_idxs() def get_benign_client_idxs(self, client_idxs): + """ + Get the indices of benign clients from a list of client indices. + + Args: + client_idxs (List[int]): A list of client indices. + + Returns: + List[int]: A list of indices corresponding to benign clients. + + Notes: + This method assumes that malicious clients have been identified using defense mechanisms. + """ return [i for i in client_idxs if i not in self.defender.get_malicious_client_idxs()] diff --git a/python/fedml/cross_device/mnn_server.py b/python/fedml/cross_device/mnn_server.py index 3502929944..599c72978a 100644 --- a/python/fedml/cross_device/mnn_server.py +++ b/python/fedml/cross_device/mnn_server.py @@ -4,15 +4,50 @@ class ServerMNN: + """ + A class representing the server in federated learning using MNN (Mobile Neural Networks). + + This class is responsible for coordinating and aggregating model updates from client devices. + + Args: + args: The command-line arguments. + device: The device for computations. + test_dataloader: The DataLoader for testing data. + model: The federated learning model. + server_aggregator: The server aggregator (optional). + + Attributes: + None + + Methods: + run: Run the server for federated learning. + """ + def __init__(self, args, device, test_dataloader, model, server_aggregator=None): + """ + Initialize a ServerMNN instance. + + Args: + args: The command-line arguments. + device: The device for computations. + test_dataloader: The DataLoader for testing data. + model: The federated learning model. + server_aggregator: The server aggregator (optional). + """ if args.federated_optimizer == "FedAvg": - logging.info("test_data_global.iter_number = {}".format(test_dataloader.iter_number)) + logging.info("test_data_global.iter_number = {}".format( + test_dataloader.iter_number)) fedavg_cross_device( args, 0, args.worker_num, None, device, test_dataloader, model, server_aggregator=server_aggregator ) else: - raise Exception("Exception") + raise Exception("Unsupported federated optimizer") def run(self): + """ + Run the server for federated learning. + + This method coordinates and aggregates model updates from client devices. + """ pass diff --git a/python/fedml/cross_silo/client/client_initializer.py b/python/fedml/cross_silo/client/client_initializer.py index 54fd865710..0b6503b0d7 100644 --- a/python/fedml/cross_silo/client/client_initializer.py +++ b/python/fedml/cross_silo/client/client_initializer.py @@ -20,6 +20,25 @@ def init_client( test_data_local_dict, model_trainer=None, ): + """ + Initialize and run a federated learning client. + + Args: + args: The command-line arguments. + device: The device to perform computations on. + comm: The communication backend. + client_rank: The rank of the client. + client_num: The total number of clients. + model: The federated learning model. + train_data_num: The total number of training data samples. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training data samples. + train_data_local_dict: A dictionary mapping client IDs to their local training data. + test_data_local_dict: A dictionary mapping client IDs to their local testing data. + model_trainer: The model trainer (optional). 
+ + Returns: + None + """ backend = args.backend trainer_dist_adapter = get_trainer_dist_adapter( @@ -36,8 +55,8 @@ def init_client( if ( args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL or ( - args.scenario == FEDML_CROSS_SILO_SCENARIO_HORIZONTAL and - getattr(args, FEDML_CROSS_SILO_CUSTOMIZED_HIERARCHICAL_KEY, False) + args.scenario == FEDML_CROSS_SILO_SCENARIO_HORIZONTAL and + getattr(args, FEDML_CROSS_SILO_CUSTOMIZED_HIERARCHICAL_KEY, False) ) ): if args.proc_rank_in_silo == 0: @@ -46,13 +65,16 @@ def init_client( ) else: - client_manager = get_client_manager_salve(args, trainer_dist_adapter) + client_manager = get_client_manager_salve( + args, trainer_dist_adapter) elif args.scenario == FEDML_CROSS_SILO_SCENARIO_HORIZONTAL: - client_manager = get_client_manager_master(args, trainer_dist_adapter, comm, client_rank, client_num, backend) + client_manager = get_client_manager_master( + args, trainer_dist_adapter, comm, client_rank, client_num, backend) else: - raise RuntimeError("we do not support {}. Please check whether this is typo.".format(args.scenario)) + raise RuntimeError( + "we do not support {}. Please check whether this is typo.".format(args.scenario)) client_manager.run() @@ -68,6 +90,23 @@ def get_trainer_dist_adapter( test_data_local_dict, model_trainer, ): + """ + Get a trainer distributed adapter. + + Args: + args: The command-line arguments. + device: The device to perform computations on. + client_rank: The rank of the client. + model: The federated learning model. + train_data_num: The total number of training data samples. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training data samples. + train_data_local_dict: A dictionary mapping client IDs to their local training data. + test_data_local_dict: A dictionary mapping client IDs to their local testing data. + model_trainer: The model trainer (optional). + + Returns: + TrainerDistAdapter: The trainer distributed adapter. + """ return TrainerDistAdapter( args, device, @@ -82,10 +121,34 @@ def get_trainer_dist_adapter( def get_client_manager_master(args, trainer_dist_adapter, comm, client_rank, client_num, backend): + """ + Get the federated learning client manager for the master. + + Args: + args: The command-line arguments. + trainer_dist_adapter: The trainer distributed adapter. + comm: The communication backend. + client_rank: The rank of the client. + client_num: The total number of clients. + backend: The communication backend. + + Returns: + ClientMasterManager: The federated learning client manager for the master. + """ return ClientMasterManager(args, trainer_dist_adapter, comm, client_rank, client_num, backend) def get_client_manager_salve(args, trainer_dist_adapter): + """ + Get the federated learning client manager for a slave. + + Args: + args: The command-line arguments. + trainer_dist_adapter: The trainer distributed adapter. + + Returns: + ClientSlaveManager: The federated learning client manager for a slave. 
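+
+    Example (sketch)::
+
+        manager = get_client_manager_salve(args, trainer_dist_adapter)
+        manager.run()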
+ """ from .fedml_client_slave_manager import ClientSlaveManager return ClientSlaveManager(args, trainer_dist_adapter) diff --git a/python/fedml/cross_silo/client/client_launcher.py b/python/fedml/cross_silo/client/client_launcher.py index 1a4831b11e..76ff8ee703 100644 --- a/python/fedml/cross_silo/client/client_launcher.py +++ b/python/fedml/cross_silo/client/client_launcher.py @@ -27,25 +27,73 @@ class CrossSiloLauncher: @staticmethod def launch_dist_trainers(torch_client_filename, inputs): + """ + Launch distributed trainers for cross-silo federated learning. + + Args: + torch_client_filename (str): The filename of the torch client script to run. + inputs (list): A list of input arguments to pass to the torch client script. + + Returns: + None + """ # this is only used by the client (DDP or single process), so there is no need to specify the backend. args = load_arguments(FEDML_TRAINING_PLATFORM_CROSS_SILO) if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL: - CrossSiloLauncher._run_cross_silo_hierarchical(args, torch_client_filename, inputs) + CrossSiloLauncher._run_cross_silo_hierarchical( + args, torch_client_filename, inputs) elif args.scenario == FEDML_CROSS_SILO_SCENARIO_HORIZONTAL: - CrossSiloLauncher._run_cross_silo_horizontal(args, torch_client_filename, inputs) + CrossSiloLauncher._run_cross_silo_horizontal( + args, torch_client_filename, inputs) else: - raise Exception("we do not support {}, check whether this is typo in args.scenario".format(args.scenario)) + raise Exception( + "we do not support {}, check whether this is typo in args.scenario".format(args.scenario)) @staticmethod def _run_cross_silo_horizontal(args, torch_client_filename, inputs): - python_path = subprocess.run(["which", "python"], capture_output=True, text=True).stdout.strip() + """ + Run cross-silo federated learning in horizontal scenario. + + Args: + args: The command-line arguments. + torch_client_filename (str): The filename of the torch client script to run. + inputs (list): A list of input arguments to pass to the torch client script. + + Returns: + None + """ + + python_path = subprocess.run( + ["which", "python"], capture_output=True, text=True).stdout.strip() process_arguments = [python_path, torch_client_filename] + inputs subprocess.run(process_arguments) @staticmethod def _run_cross_silo_hierarchical(args, torch_client_filename, inputs): + """ + Run cross-silo federated learning in hierarchical scenario. + + Args: + args: The command-line arguments. + torch_client_filename (str): The filename of the torch client script to run. + inputs (list): A list of input arguments to pass to the torch client script. + + Returns: + None + """ + def get_torchrun_arguments(node_rank): - torchrun_path = subprocess.run(["which", "torchrun"], capture_output=True, text=True).stdout.strip() + """ + Get the torchrun command arguments for launching on each node. + + Args: + node_rank (int): The rank of the current node. + + Returns: + list: List of command arguments for torchrun. 
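+
+            Example (sketch; the exact torchrun flags and paths depend on
+            the local environment)::
+
+                cmd = get_torchrun_arguments(node_rank=0)
+                # e.g. ['/usr/bin/torchrun', ..., torch_client_filename,
+                #       *inputs]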
+            """
+            torchrun_path = subprocess.run(
+                ["which", "torchrun"], capture_output=True, text=True).stdout.strip()
 
             return [
                 torchrun_path,
@@ -58,8 +106,10 @@ def get_torchrun_arguments(node_rank):
                 torch_client_filename,
             ] + inputs
 
-        network_interface = None if not hasattr(args, "network_interface") else args.network_interface
-        print(f"Using network interface {network_interface} for process group and TRPC communication")
+        network_interface = None if not hasattr(
+            args, "network_interface") else args.network_interface
+        print(
+            f"Using network interface {network_interface} for process group and TRPC communication")
         env_variables = {
             "OMP_NUM_THREADS": "4",
         }
@@ -78,7 +128,8 @@ def get_torchrun_arguments(node_rank):
         device_type = get_device_type(args)
         if torch.cuda.is_available() and device_type == "gpu":
             gpu_count = torch.cuda.device_count()
-            print(f"Using number of GPUs ({gpu_count}) as number of processeses.")
+            print(
+                f"Using number of GPUs ({gpu_count}) as number of processes.")
             args.n_proc_per_node = gpu_count
         else:
             print(f"Using number 1 as number of processeses.")
@@ -95,7 +146,8 @@ def get_torchrun_arguments(node_rank):
         else:
             print(f"Automatic Client Launcher")
 
-            which_pdsh = subprocess.run(["which", "pdsh"], capture_output=True, text=True).stdout.strip()
+            which_pdsh = subprocess.run(
+                ["which", "pdsh"], capture_output=True, text=True).stdout.strip()
 
             if not which_pdsh:
                 raise Exception(
diff --git a/python/fedml/cross_silo/client/fedml_client_master_manager.py b/python/fedml/cross_silo/client/fedml_client_master_manager.py
index 1b4ea7f81e..95ae7a2475 100644
--- a/python/fedml/cross_silo/client/fedml_client_master_manager.py
+++ b/python/fedml/cross_silo/client/fedml_client_master_manager.py
@@ -24,6 +24,17 @@
 class ClientMasterManager(FedMLCommManager):
     RUN_FINISHED_STATUS_FLAG = "FINISHED"
 
     def __init__(self, args, trainer_dist_adapter, comm=None, rank=0, size=0, backend="MPI"):
+        """
+        Initialize the ClientMasterManager.
+
+        Args:
+            args: The command-line arguments.
+            trainer_dist_adapter: The trainer distributed adapter.
+            comm: The communication backend.
+            rank: The rank of the client.
+            size: The total number of clients.
+            backend: The communication backend (default is "MPI").
+        """
         super().__init__(args, comm, rank, size, backend)
         self.trainer_dist_adapter = trainer_dist_adapter
         self.args = args
@@ -50,21 +61,42 @@ def __init__(self, args, trainer_dist_adapter, comm=None, rank=0, size=0, backen
 
     @property
     def use_customized_hierarchical(self) -> bool:
+        """
+        Check if customized hierarchical cross-silo is enabled.
+
+        Returns:
+            bool: True if customized hierarchical is enabled, False otherwise.
+        """
         return getattr(self.args, FEDML_CROSS_SILO_CUSTOMIZED_HIERARCHICAL_KEY, False)
 
     @property
     def has_customized_sync_process_group(self) -> bool:
+        """
+        Check if a customized sync process group method is available in the trainer.
+
+        Returns:
+            bool: True if a customized sync process group method is available, False otherwise.
+        """
         return check_method_override(
             cls_obj=self.trainer_dist_adapter.trainer.trainer,
             method_name="sync_process_group"
         )
 
     def is_main_process(self):
+        """
+        Check if the current process is the main process.
+
+        Returns:
+            bool: True if the current process is the main process, False otherwise.
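+
+        Example (sketch mirroring how this check is used in the class)::
+
+            if self.is_main_process():
+                mlops.log_training_finished_status()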
+ """ return getattr(self.trainer_dist_adapter, "trainer", None) is None or \ getattr(self.trainer_dist_adapter.trainer, "trainer", None) is None or \ self.trainer_dist_adapter.trainer.trainer.is_main_process() def register_message_receive_handlers(self): + """ + Register message receive handlers for various message types. + """ self.register_message_receive_handler( MyMessage.MSG_TYPE_CONNECTION_IS_READY, self.handle_message_connection_ready ) @@ -73,7 +105,8 @@ def register_message_receive_handlers(self): MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, self.handle_message_check_status ) - self.register_message_receive_handler(MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init) + self.register_message_receive_handler( + MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.handle_message_init) self.register_message_receive_handler( MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.handle_message_receive_model_from_server, ) @@ -83,6 +116,12 @@ def register_message_receive_handlers(self): ) def handle_message_connection_ready(self, msg_params): + """ + Handle the connection-ready message. + + Args: + msg_params (dict): Parameters of the message. + """ if not self.has_sent_online_msg: self.has_sent_online_msg = True self.send_client_status(0) @@ -90,15 +129,28 @@ def handle_message_connection_ready(self, msg_params): mlops.log_sys_perf(self.args) def handle_message_check_status(self, msg_params): + """ + Handle the check-client-status message. + + Args: + msg_params (dict): Parameters of the message. + """ self.send_client_status(0) def handle_message_init(self, msg_params): + """ + Handle the initialization message. + + Args: + msg_params (dict): Parameters of the message. + """ if self.is_inited: return self.is_inited = True - global_model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) + global_model_params = msg_params.get( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS) data_silo_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) logging.info("data_silo_index = %s" % str(data_silo_index)) @@ -107,10 +159,12 @@ def handle_message_init(self, msg_params): self.report_training_status(MyMessage.MSG_MLOPS_CLIENT_STATUS_TRAINING) if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL: - global_model_params = convert_model_params_to_ddp(global_model_params) + global_model_params = convert_model_params_to_ddp( + global_model_params) self.sync_process_group(0, global_model_params, data_silo_index) elif self.use_customized_hierarchical: - self.customized_sync_process_group(0, global_model_params, data_silo_index) + self.customized_sync_process_group( + 0, global_model_params, data_silo_index) self.trainer_dist_adapter.update_dataset(int(data_silo_index)) self.trainer_dist_adapter.update_model(global_model_params) @@ -121,6 +175,12 @@ def handle_message_init(self, msg_params): self.round_idx += 1 def handle_message_receive_model_from_server(self, msg_params): + """ + Handle the received model from the server. + + Args: + msg_params (dict): Parameters of the message. 
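+
+        Example (sketch of the payload keys this handler reads)::
+
+            model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS)
+            client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX)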
+ """ logging.info("handle_message_receive_model_from_server.") model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) client_index = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_INDEX) @@ -129,10 +189,12 @@ def handle_message_receive_model_from_server(self, msg_params): model_params = convert_model_params_to_ddp(model_params) self.sync_process_group(self.round_idx, model_params, client_index) elif self.use_customized_hierarchical: - self.customized_sync_process_group(self.round_idx, model_params, client_index) + self.customized_sync_process_group( + self.round_idx, model_params, client_index) self.trainer_dist_adapter.update_dataset(int(client_index)) - logging.info("current round index {}, total rounds {}".format(self.round_idx, self.num_rounds)) + logging.info("current round index {}, total rounds {}".format( + self.round_idx, self.num_rounds)) self.trainer_dist_adapter.update_model(model_params) if self.round_idx < self.num_rounds: self.__test(is_before_aggregation=False) # After aggregation @@ -141,40 +203,75 @@ def handle_message_receive_model_from_server(self, msg_params): self.round_idx += 1 else: mlops.stop_sys_perf() - self.send_client_status(0, ClientMasterManager.RUN_FINISHED_STATUS_FLAG) + self.send_client_status( + 0, ClientMasterManager.RUN_FINISHED_STATUS_FLAG) if self.is_main_process(): mlops.log_training_finished_status() self.finish() def handle_message_finish(self, msg_params): + """ + Handle the finish message. + + Args: + msg_params (dict): Parameters of the message. + """ logging.info(" ====================cleanup ====================") self.cleanup() def cleanup(self): - self.send_client_status(0, ClientMasterManager.RUN_FINISHED_STATUS_FLAG) + self.send_client_status( + 0, ClientMasterManager.RUN_FINISHED_STATUS_FLAG) if self.is_main_process(): mlops.log_training_finished_status() self.finish() def send_model_to_server(self, receive_id, weights, local_sample_num): + """ + Send the model to the server. + + Args: + receive_id: The ID of the entity receiving the model. + weights: The model weights to send. + local_sample_num: The number of local training samples. + + Returns: + None + """ if self.is_main_process(): tick = time.time() - mlops.event("comm_c2s", event_started=True, event_value=str(self.round_idx)) - message = Message(MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.client_real_id, receive_id) + mlops.event("comm_c2s", event_started=True, + event_value=str(self.round_idx)) + message = Message( + MyMessage.MSG_TYPE_C2S_SEND_MODEL_TO_SERVER, self.client_real_id, receive_id) message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, weights) - message.add_params(MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num) + message.add_params( + MyMessage.MSG_ARG_KEY_NUM_SAMPLES, local_sample_num) self.send_message(message) - MLOpsProfilerEvent.log_to_wandb({"Communication/Send_Total": time.time() - tick}) + MLOpsProfilerEvent.log_to_wandb( + {"Communication/Send_Total": time.time() - tick}) mlops.log_client_model_info( self.round_idx + 1, self.num_rounds, model_url=message.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL), ) def send_client_status(self, receive_id, status=ONLINE_STATUS_FLAG): + """ + Send the client status to another entity. + + Args: + receive_id: The ID of the entity receiving the status. + status (str): The status to send (default is "ONLINE"). 
+ + Returns: + None + """ if self.is_main_process(): logging.info("send_client_status") - logging.info("self.client_real_id = {}".format(self.client_real_id)) - message = Message(MyMessage.MSG_TYPE_C2S_CLIENT_STATUS, self.client_real_id, receive_id) + logging.info("self.client_real_id = {}".format( + self.client_real_id)) + message = Message(MyMessage.MSG_TYPE_C2S_CLIENT_STATUS, + self.client_real_id, receive_id) sys_name = platform.system() if sys_name == "Darwin": sys_name = "Mac" @@ -185,11 +282,21 @@ def send_client_status(self, receive_id, status=ONLINE_STATUS_FLAG): message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, sys_name) if getattr(self.args, "using_mlops", False) and status == ClientMasterManager.RUN_FINISHED_STATUS_FLAG: - mlops.log_server_payload(self.args.run_id, self.client_real_id, json.dumps(message.get_params())) + mlops.log_server_payload( + self.args.run_id, self.client_real_id, json.dumps(message.get_params())) else: self.send_message(message) def report_training_status(self, status): + """ + Report the training status to MLOps. + + Args: + status: The training status to report. + + Returns: + None + """ mlops.log_training_status(status) def sync_process_group( @@ -199,12 +306,25 @@ def sync_process_group( client_index: Optional[int] = None, src: int = 0 ) -> None: + """ + Synchronize the process group for hierarchical cross-silo scenarios. + + Args: + round_idx (int): The round index. + model_params: The model parameters. + client_index (int): The client index. + src (int): The source index. + + Returns: + None + """ logging.info("sending round number to pg") round_number = [round_idx, model_params, client_index] dist.broadcast_object_list( round_number, src=src, group=self.trainer_dist_adapter.process_group_manager.get_process_group(), ) - logging.info("round number %d broadcast to process group" % round_number[0]) + logging.info("round number %d broadcast to process group" % + round_number[0]) def customized_sync_process_group( self, @@ -213,6 +333,18 @@ def customized_sync_process_group( client_index: Optional[int] = None, src: int = 0 ) -> None: + """ + Synchronize the process group using a customized method for hierarchical cross-silo scenarios. + + Args: + round_idx (int): The round index. + model_params: The model parameters. + client_index (int): The client index. + src (int): The source index. + + Returns: + None + """ trainer = self.trainer_dist_adapter.trainer.trainer trainer_class_name = trainer.__class__.__name__ @@ -225,13 +357,23 @@ def customized_sync_process_group( trainer.sync_process_group(round_idx, model_params, client_index, src) def __train(self): - logging.info("#######training########### round_id = %d" % self.round_idx) + """ + Perform the training process. 
- mlops.event("train", event_started=True, event_value=str(self.round_idx)) + Returns: + None + """ + logging.info("#######training########### round_id = %d" % + self.round_idx) - weights, local_sample_num = self.trainer_dist_adapter.train(self.round_idx) + mlops.event("train", event_started=True, + event_value=str(self.round_idx)) - mlops.event("train", event_started=False, event_value=str(self.round_idx)) + weights, local_sample_num = self.trainer_dist_adapter.train( + self.round_idx) + + mlops.event("train", event_started=False, + event_value=str(self.round_idx)) # the current model is still DDP-wrapped under cross-silo-hi setting if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL: @@ -253,4 +395,10 @@ def __test(self, is_before_aggregation=False): self.trainer_dist_adapter.test(self.round_idx) def run(self): + """ + Run the client manager. + + Returns: + None + """ super().run() diff --git a/python/fedml/cross_silo/client/fedml_client_slave_manager.py b/python/fedml/cross_silo/client/fedml_client_slave_manager.py index a5320fed95..401f18672c 100644 --- a/python/fedml/cross_silo/client/fedml_client_slave_manager.py +++ b/python/fedml/cross_silo/client/fedml_client_slave_manager.py @@ -8,6 +8,13 @@ class ClientSlaveManager: def __init__(self, args, trainer_dist_adapter): + """ + Initialize a federated learning client manager for a slave. + + Args: + args: The command-line arguments. + trainer_dist_adapter: The trainer distributed adapter. + """ self.trainer_dist_adapter = trainer_dist_adapter self.args = args self.round_idx = 0 @@ -31,10 +38,22 @@ def __init__(self, args, trainer_dist_adapter): @property def use_customized_hierarchical(self) -> bool: + """ + Determine whether customized hierarchical cross-silo is enabled. + + Returns: + bool: True if customized hierarchical cross-silo is enabled, False otherwise. + """ return getattr(self.args, FEDML_CROSS_SILO_CUSTOMIZED_HIERARCHICAL_KEY, False) @property def has_customized_await_sync_process_group(self) -> bool: + """ + Check if the trainer has a customized "await_sync_process_group" method. + + Returns: + bool: True if the method is overridden, False otherwise. + """ return check_method_override( cls_obj=self.trainer_dist_adapter.trainer.trainer, method_name="await_sync_process_group" @@ -42,12 +61,21 @@ def has_customized_await_sync_process_group(self) -> bool: @property def has_customized_cleanup_process_group(self) -> bool: + """ + Check if the trainer has a customized "cleanup_process_group" method. + + Returns: + bool: True if the method is overridden, False otherwise. + """ return check_method_override( cls_obj=self.trainer_dist_adapter.trainer.trainer, method_name="cleanup_process_group" ) def train(self): + """ + Perform a training round for the federated learning client. + """ if self.use_customized_hierarchical: [round_idx, model_params, client_index] = self.customized_await_sync_process_group() else: @@ -67,6 +95,9 @@ def train(self): self.trainer_dist_adapter.train(self.round_idx) def finish(self): + """ + Finish the federated learning client's training process. + """ if self.use_customized_hierarchical: self.customized_cleanup_process_group() else: @@ -78,6 +109,16 @@ def finish(self): self.finished = True def await_sync_process_group(self, src: int = 0) -> list: + """ + Await synchronization of the process group. + + Args: + src (int): The source rank for synchronization. + + Returns: + list: A list containing round number, model parameters, and client index. 
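+
+        Example (sketch mirroring the call in ``train``)::
+
+            round_idx, model_params, client_index = \
+                self.await_sync_process_group(src=0)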
+ """ + logging.info("process %d waiting for round number" % dist.get_rank()) objects = [None, None, None] dist.broadcast_object_list( @@ -87,6 +128,15 @@ def await_sync_process_group(self, src: int = 0) -> list: return objects def customized_await_sync_process_group(self, src: int = 0) -> list: + """ + Perform a customized await synchronization of the process group. + + Args: + src (int): The source rank for synchronization. + + Returns: + list: A list containing round number, model parameters, and client index. + """ trainer = self.trainer_dist_adapter.trainer.trainer trainer_class_name = trainer.__class__.__name__ @@ -99,10 +149,16 @@ def customized_await_sync_process_group(self, src: int = 0) -> list: return trainer.await_sync_process_group(src) def customized_cleanup_process_group(self) -> None: + """ + Perform a customized cleanup of the process group. + """ trainer = self.trainer_dist_adapter.trainer.trainer if self.has_customized_cleanup_process_group: trainer.cleanup_process_group() def run(self): + """ + Run the federated learning client manager. + """ while not self.finished: self.train() diff --git a/python/fedml/cross_silo/client/fedml_trainer.py b/python/fedml/cross_silo/client/fedml_trainer.py index 46d92479f6..8244d26766 100755 --- a/python/fedml/cross_silo/client/fedml_trainer.py +++ b/python/fedml/cross_silo/client/fedml_trainer.py @@ -6,6 +6,41 @@ class FedMLTrainer(object): + """ + A class representing a Federated Machine Learning Trainer. + + This class manages the training process for federated learning on a client. + + Args: + client_index: The index of the client. + train_data_local_dict: A dictionary mapping client IDs to their local training data. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training data samples. + test_data_local_dict: A dictionary mapping client IDs to their local testing data. + train_data_num: The total number of training data samples. + device: The device for computations. + args: The command-line arguments. + model_trainer: The model trainer. + + Attributes: + trainer: The model trainer. + client_index: The index of the client. + train_data_local_dict: A dictionary mapping client IDs to their local training data. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training data samples. + test_data_local_dict: A dictionary mapping client IDs to their local testing data. + all_train_data_num: The total number of training data samples. + train_local: The local training data for the client. + local_sample_number: The number of local training data samples. + test_local: The local testing data for the client. + device: The device for computations. + args: The command-line arguments. + + Methods: + update_model: Update the federated learning model with new weights. + update_dataset: Update the local dataset for training. + train: Train the federated learning model for a specified round. + test: Test the federated learning model. + """ + def __init__( self, client_index, @@ -17,12 +52,26 @@ def __init__( args, model_trainer, ): + """ + Initialize a Federated Machine Learning Trainer. + + Args: + client_index: The index of the client. + train_data_local_dict: A dictionary mapping client IDs to their local training data. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training data samples. + test_data_local_dict: A dictionary mapping client IDs to their local testing data. + train_data_num: The total number of training data samples. 
+ device: The device for computations. + args: The command-line arguments. + model_trainer: The model trainer. + """ self.trainer = model_trainer self.client_index = client_index if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL: - self.train_data_local_dict = split_data_for_dist_trainers(train_data_local_dict, args.n_proc_in_silo) + self.train_data_local_dict = split_data_for_dist_trainers( + train_data_local_dict, args.n_proc_in_silo) else: self.train_data_local_dict = train_data_local_dict @@ -38,9 +87,21 @@ def __init__( self.args.device = device def update_model(self, weights): + """ + Update the federated learning model with new weights. + + Args: + weights: The new model weights. + """ self.trainer.set_model_params(weights) def update_dataset(self, client_index): + """ + Update the local dataset for training. + + Args: + client_index: The index of the client. + """ self.client_index = client_index if self.train_data_local_dict is not None: @@ -61,22 +122,65 @@ def update_dataset(self, client_index): else: self.test_local = None - self.trainer.update_dataset(self.train_local, self.test_local, self.local_sample_number) + self.trainer.update_dataset( + self.train_local, self.test_local, self.local_sample_number) def train(self, round_idx=None): + """ + Train the federated learning model for a specified round. + + Args: + round_idx: The index of the training round (optional). + + Returns: + tuple: A tuple containing weights and the number of local training data samples. + """ self.args.round_idx = round_idx tick = time.time() - self.trainer.on_before_local_training(self.train_local, self.device, self.args) + self.trainer.on_before_local_training( + self.train_local, self.device, self.args) self.trainer.train(self.train_local, self.device, self.args) - self.trainer.on_after_local_training(self.train_local, self.device, self.args) + self.trainer.on_after_local_training( + self.train_local, self.device, self.args) - MLOpsProfilerEvent.log_to_wandb({"Train/Time": time.time() - tick, "round": round_idx}) + MLOpsProfilerEvent.log_to_wandb( + {"Train/Time": time.time() - tick, "round": round_idx}) weights = self.trainer.get_model_params() # transform Tensor to list return weights, self.local_sample_number - def test(self, round_idx=None): - self.args.round_idx = round_idx - if hasattr(self.trainer, "test"): - self.trainer.test(self.test_local, self.device, self.args) \ No newline at end of file + def test(self): + """ + Test the federated learning model. + + Returns: + tuple: A tuple containing training accuracy, training loss, the number of training samples, + testing accuracy, testing loss, and the number of testing samples. 
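+
+        Example (illustrative sketch; `fedml_trainer` is a FedMLTrainer instance,
+        and the unpacking order follows the tuple returned below):
+            (train_correct, train_loss, n_train,
+             test_correct, test_loss, n_test) = fedml_trainer.test()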
+ """ + # train data + train_metrics = self.trainer.test( + self.train_local, self.device, self.args) + train_tot_correct, train_num_sample, train_loss = ( + train_metrics["test_correct"], + train_metrics["test_total"], + train_metrics["test_loss"], + ) + + # test data + test_metrics = self.trainer.test( + self.test_local, self.device, self.args) + test_tot_correct, test_num_sample, test_loss = ( + test_metrics["test_correct"], + test_metrics["test_total"], + test_metrics["test_loss"], + ) + + return ( + train_tot_correct, + train_loss, + train_num_sample, + test_tot_correct, + test_loss, + test_num_sample, + ) diff --git a/python/fedml/cross_silo/client/fedml_trainer_dist_adapter.py b/python/fedml/cross_silo/client/fedml_trainer_dist_adapter.py index 004d4cb04e..3f3a4cf90f 100644 --- a/python/fedml/cross_silo/client/fedml_trainer_dist_adapter.py +++ b/python/fedml/cross_silo/client/fedml_trainer_dist_adapter.py @@ -7,6 +7,38 @@ class TrainerDistAdapter: + """ + A class representing a Trainer Distribution Adapter for federated learning. + + This adapter facilitates training a federated learning model with distributed computing support. + + Args: + args: The command-line arguments. + device: The device for computations. + client_rank: The rank of the client. + model: The federated learning model. + train_data_num: The total number of training data samples. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training data samples. + train_data_local_dict: A dictionary mapping client IDs to their local training data. + test_data_local_dict: A dictionary mapping client IDs to their local testing data. + model_trainer: The model trainer (optional). + + Attributes: + process_group_manager: The process group manager for distributed training. + client_index: The index of the client. + client_rank: The rank of the client. + device: The device for computations. + trainer: The federated learning trainer. + args: The command-line arguments. + + Methods: + get_trainer: Get the federated learning trainer. + train: Train the federated learning model for a round. + update_model: Update the federated learning model with new parameters. + update_dataset: Update the dataset for training. + cleanup_pg: Clean up the process group for distributed training. + """ + def __init__( self, args, @@ -19,11 +51,26 @@ def __init__( test_data_local_dict, model_trainer, ): + """ + Initialize a Trainer Distribution Adapter. + + Args: + args: The command-line arguments. + device: The device for computations. + client_rank: The rank of the client. + model: The federated learning model. + train_data_num: The total number of training data samples. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training data samples. + train_data_local_dict: A dictionary mapping client IDs to their local training data. + test_data_local_dict: A dictionary mapping client IDs to their local testing data. + model_trainer: The model trainer (optional). + """ ml_engine_adapter.model_to_device(args, model, device) if args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL: - self.process_group_manager, model = ml_engine_adapter.model_ddp(args, model, device) + self.process_group_manager, model = ml_engine_adapter.model_ddp( + args, model, device) if model_trainer is None: model_trainer = create_model_trainer(model, args) @@ -62,6 +109,22 @@ def get_trainer( args, model_trainer, ): + """ + Get the federated learning trainer. + + Args: + client_index: The index of the client. 
+ train_data_local_dict: A dictionary mapping client IDs to their local training data. + train_data_local_num_dict: A dictionary mapping client IDs to the number of local training data samples. + test_data_local_dict: A dictionary mapping client IDs to their local testing data. + train_data_num: The total number of training data samples. + device: The device for computations. + args: The command-line arguments. + model_trainer: The model trainer. + + Returns: + FedMLTrainer: The federated learning trainer. + """ return FedMLTrainer( client_index, train_data_local_dict, @@ -74,6 +137,15 @@ def get_trainer( ) def train(self, round_idx): + """ + Train the federated learning model for a round. + + Args: + round_idx: The index of the training round. + + Returns: + tuple: A tuple containing weights and local sample number. + """ weights, local_sample_num = self.trainer.train(round_idx) return weights, local_sample_num @@ -81,13 +153,28 @@ def test(self, round_idx): self.trainer.test(round_idx) def update_model(self, model_params): + """ + Update the federated learning model with new parameters. + + Args: + model_params: The new model parameters. + """ self.trainer.update_model(model_params) def update_dataset(self, client_index=None): + """ + Update the dataset for training. + + Args: + client_index: The index of the client (optional). + """ _client_index = client_index or self.client_index self.trainer.update_dataset(int(_client_index)) def cleanup_pg(self): + """ + Clean up the process group for distributed training. + """ if self.args.scenario == FEDML_CROSS_SILO_SCENARIO_HIERARCHICAL: logging.info( "Cleaningup process group for client %s in silo %s" diff --git a/python/fedml/cross_silo/client/process_group_manager.py b/python/fedml/cross_silo/client/process_group_manager.py index 92519c6cc4..571ad3c2ab 100644 --- a/python/fedml/cross_silo/client/process_group_manager.py +++ b/python/fedml/cross_silo/client/process_group_manager.py @@ -6,6 +6,23 @@ class ProcessGroupManager: + """ + A class for managing the process group for distributed training. + + This class initializes and manages the process group for distributed training using PyTorch's distributed library. + + Args: + rank (int): The rank of the current process. + world_size (int): The total number of processes in the group. + master_address (str): The address of the master node for coordination. + master_port (int): The port number for coordination with the master node. + only_gpu (bool): Whether to use NCCL backend for GPU-based communication. + + Methods: + cleanup: Clean up the process group and release resources. + get_process_group: Get the initialized process group. 
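+
+    Example (illustrative sketch; the address and port are placeholder values):
+        pg_manager = ProcessGroupManager(rank=0, world_size=2,
+                                         master_address="127.0.0.1",
+                                         master_port=29500, only_gpu=False)
+        process_group = pg_manager.get_process_group()
+        # ... run distributed training ...
+        pg_manager.cleanup()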
+ """ + def __init__(self, rank, world_size, master_address, master_port, only_gpu): logging.info("Start process group") logging.info( @@ -17,10 +34,13 @@ def __init__(self, rank, world_size, master_address, master_port, only_gpu): os.environ["WORLD_SIZE"] = str(world_size) os.environ["RANK"] = str(rank) - env_dict = {key: os.environ[key] for key in ("MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE",)} - logging.info(f"[{os.getpid()}] Initializing process group with: {env_dict}") + env_dict = {key: os.environ[key] for key in ( + "MASTER_ADDR", "MASTER_PORT", "RANK", "WORLD_SIZE",)} + logging.info( + f"[{os.getpid()}] Initializing process group with: {env_dict}") - backend = dist.Backend.NCCL if (only_gpu and torch.cuda.is_available()) else dist.Backend.GLOO + backend = dist.Backend.NCCL if ( + only_gpu and torch.cuda.is_available()) else dist.Backend.GLOO logging.info(f"Process group backend: {backend}") # initialize the process group @@ -31,7 +51,16 @@ def __init__(self, rank, world_size, master_address, master_port, only_gpu): logging.info("Initiated") def cleanup(self): + """ + Clean up the process group and release associated resources. + """ dist.destroy_process_group() def get_process_group(self): + """ + Get the initialized process group. + + Returns: + dist.ProcessGroup: The initialized process group. + """ return self.messaging_pg diff --git a/python/fedml/cross_silo/client/utils.py b/python/fedml/cross_silo/client/utils.py index 960aa5e3ac..308cc5b38e 100644 --- a/python/fedml/cross_silo/client/utils.py +++ b/python/fedml/cross_silo/client/utils.py @@ -3,25 +3,54 @@ # ref: https://discuss.pytorch.org/t/failed-to-load-model-trained-by-ddp-for-inference/84841/2?u=amir_zsh def convert_model_params_from_ddp(ddp_model_params): - model_params = OrderedDict() - for k, v in ddp_model_params.items(): - name = k[7:] # remove 'module.' of DataParallel/DistributedDataParallel - model_params[name] = v - return model_params + """ + Convert model parameters from DataParallel/DistributedDataParallel format to a regular model format. + Args: + ddp_model_params (dict): Model parameters in DataParallel/DistributedDataParallel format. -def convert_model_params_to_ddp(ddp_model_params): + Returns: + OrderedDict: Model parameters in the regular format. + """ model_params = OrderedDict() for k, v in ddp_model_params.items(): - name = f"module.{k}" # add 'module.' of DataParallel/DistributedDataParallel + name = k[7:] # Remove 'module.' of DataParallel/DistributedDataParallel model_params[name] = v return model_params +def convert_model_params_to_ddp(model_params): + """ + Convert model parameters from a regular format to DataParallel/DistributedDataParallel format. + + Args: + model_params (dict): Model parameters in the regular format. + + Returns: + OrderedDict: Model parameters in DataParallel/DistributedDataParallel format. + """ + ddp_model_params = OrderedDict() + for k, v in model_params.items(): + # Add 'module.' for DataParallel/DistributedDataParallel + name = f"module.{k}" + ddp_model_params[name] = v + return ddp_model_params + + def check_method_override(cls_obj, method_name: str) -> bool: - # check if method has been overriden by class + """ + Check if a method has been overridden by a class. + + Args: + cls_obj (object): The class object. + method_name (str): The name of the method to check for override. + + Returns: + bool: True if the method has been overridden, False otherwise. 
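+
+    Example (illustrative sketch; `MyTrainer` is a hypothetical trainer subclass
+    that overrides the method being checked):
+        is_overridden = check_method_override(cls_obj=MyTrainer(), method_name="await_sync_process_group")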
+ """ + # Check if method has been overridden by class return ( - method_name in cls_obj.__class__.__dict__ and - hasattr(cls_obj, method_name) and - callable(getattr(cls_obj, method_name)) + method_name in cls_obj.__class__.__dict__ and + hasattr(cls_obj, method_name) and + callable(getattr(cls_obj, method_name)) ) From 9cd3e32377890df9239ffb70232a4ced75cdc326 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Thu, 21 Sep 2023 11:45:48 +0530 Subject: [PATCH 65/70] add --- python/fedml/core/data/noniid_partition.py | 42 +- python/fedml/core/mpc/lightsecagg.py | 173 ++++++- python/fedml/core/mpc/secagg.py | 421 ++++++++++++++++- .../fedml/core/schedule/runtime_estimate.py | 56 +++ .../core/schedule/seq_train_scheduler.py | 121 ++++- .../core/security/attack/backdoor_attack.py | 86 +++- .../core/security/attack/byzantine_attack.py | 37 ++ .../fedml/core/security/attack/dlg_attack.py | 55 ++- .../attack/edge_case_backdoor_attack.py | 20 + .../security/attack/invert_gradient_attack.py | 430 +++++++++++++++--- .../security/attack/label_flipping_attack.py | 37 +- .../fedml/core/security/attack/lazy_worker.py | 73 ++- .../model_replacement_backdoor_attack.py | 26 ++ .../revealing_labels_from_gradients_attack.py | 63 +++ .../common/attack_defense_data_loader.py | 36 +- python/fedml/core/security/common/bucket.py | 13 + python/fedml/core/security/common/net.py | 22 + python/fedml/core/security/common/utils.py | 192 +++++++- .../server_mnn/fedml_aggregator.py | 175 ++++++- .../server_mnn/fedml_server_manager.py | 242 ++++++++-- .../cross_device/server_mnn/server_mnn_api.py | 53 ++- 21 files changed, 2177 insertions(+), 196 deletions(-) diff --git a/python/fedml/core/data/noniid_partition.py b/python/fedml/core/data/noniid_partition.py index 368710ddd9..065102c063 100644 --- a/python/fedml/core/data/noniid_partition.py +++ b/python/fedml/core/data/noniid_partition.py @@ -55,7 +55,8 @@ def non_iid_partition_with_dirichlet_distribution( ) else: idx_k = np.asarray( - [np.any(label_list[i] == cat) for i in range(len(label_list))] + [np.any(label_list[i] == cat) + for i in range(len(label_list))] ) # Get the indices of images that have category = c @@ -87,6 +88,26 @@ def non_iid_partition_with_dirichlet_distribution( def partition_class_samples_with_dirichlet_distribution( N, alpha, client_num, idx_batch, idx_k ): + """ + Partition class samples using the Dirichlet distribution. + + Parameters: + N (int): Total number of samples to partition. + alpha (float): Parameter for the Dirichlet distribution. + client_num (int): Number of clients. + idx_batch (list of arrays): List of arrays containing sample indices for each client. + idx_k (array): Array of sample indices to be partitioned. + + Returns: + tuple: A tuple containing the updated idx_batch and the minimum batch size. + + This function partitions class samples using the Dirichlet distribution to create unbalanced proportions + for each client. It shuffles the sample indices, calculates the proportions, and generates batch lists + for each client. The minimum batch size is also computed. 
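+
+    Note:
+        Smaller alpha values produce more skewed (non-IID) partitions, while
+        larger alpha values approach a uniform split across clients.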
+ + Example: + idx_batch, min_size = partition_class_samples_with_dirichlet_distribution(N, alpha, client_num, idx_batch, idx_k) + """ np.random.shuffle(idx_k) # using dirichlet distribution to determine the unbalanced proportion for each client (client_num in total) # e.g., when client_num = 4, proportions = [0.29543505 0.38414498 0.31998781 0.00043216], sum(proportions) = 1 @@ -94,7 +115,8 @@ def partition_class_samples_with_dirichlet_distribution( # get the index in idx_k according to the dirichlet distribution proportions = np.array( - [p * (len(idx_j) < N / client_num) for p, idx_j in zip(proportions, idx_batch)] + [p * (len(idx_j) < N / client_num) + for p, idx_j in zip(proportions, idx_batch)] ) proportions = proportions / proportions.sum() proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1] @@ -110,6 +132,22 @@ def partition_class_samples_with_dirichlet_distribution( def record_data_stats(y_train, net_dataidx_map, task="classification"): + """ + Record data statistics for each client. + + Parameters: + y_train (array): Labels for the entire dataset. + net_dataidx_map (dict): Mapping of client indices to their respective data indices. + task (str): Task type, either "classification" or "segmentation". + + Returns: + dict: A dictionary containing class counts for each client. + + This function records data statistics for each client, specifically the count of each class in their data. + + Example: + net_cls_counts = record_data_stats(y_train, net_dataidx_map, task="classification") + """ net_cls_counts = {} for net_i, dataidx in net_dataidx_map.items(): diff --git a/python/fedml/core/mpc/lightsecagg.py b/python/fedml/core/mpc/lightsecagg.py index bc77bdf15d..34fa72ac19 100644 --- a/python/fedml/core/mpc/lightsecagg.py +++ b/python/fedml/core/mpc/lightsecagg.py @@ -6,6 +6,16 @@ def modular_inv(a, p): + """ + Compute the modular multiplicative inverse of 'a' modulo 'p'. + + Parameters: + a (int): The integer for which to find the modular inverse. + p (int): The prime number modulo which to compute the inverse. + + Returns: + int: The modular multiplicative inverse of 'a' modulo 'p'. + """ x, y, m = 1, 0, p while a > 1: q = a // m @@ -23,14 +33,35 @@ def modular_inv(a, p): def divmod(_num, _den, _p): - # compute num / den modulo prime p + """ + Compute the result of _num / _den modulo prime _p. + + Parameters: + _num (int): The numerator. + _den (int): The denominator. + _p (int): The prime number modulo which to compute the result. + + Returns: + int: The result of (_num / _den) modulo _p. + """ + # Compute the modulus of inputs _num = np.mod(_num, _p) _den = np.mod(_den, _p) _inv = modular_inv(_den, _p) return np.mod(np.int64(_num) * np.int64(_inv), _p) -def PI(vals, p): # upper-case PI -- product of inputs +def PI(vals, p): + """ + Compute the product of values in 'vals' modulo prime 'p'. + + Parameters: + vals (list of int): List of integers to be multiplied. + p (int): The prime number modulo which to compute the product. + + Returns: + int: The product of values in 'vals' modulo 'p'. + """ accum = 1 for v in vals: tmp = np.mod(v, p) @@ -39,6 +70,18 @@ def PI(vals, p): # upper-case PI -- product of inputs def LCC_encoding_with_points(X, alpha_s, beta_s, p): + """ + Perform Lagrange-Cauchy Coding encoding of data 'X' using specified points. + + Parameters: + X (numpy.ndarray): The input data matrix. + alpha_s (list of int): List of alpha values for encoding. + beta_s (list of int): List of beta values for encoding. 
+ p (int): The prime number modulo which to perform encoding. + + Returns: + numpy.ndarray: The encoded data matrix. + """ m, d = np.shape(X) U = gen_Lagrange_coeffs(beta_s, alpha_s, p).astype("int64") X_LCC = np.zeros((len(beta_s), d), dtype="int64") @@ -48,6 +91,18 @@ def LCC_encoding_with_points(X, alpha_s, beta_s, p): def LCC_decoding_with_points(f_eval, eval_points, target_points, p): + """ + Perform Lagrange-Cauchy Coding decoding of data 'f_eval' using specified evaluation and target points. + + Parameters: + f_eval (numpy.ndarray): The data to decode. + eval_points (list of int): List of evaluation points. + target_points (list of int): List of target points. + p (int): The prime number modulo which to perform decoding. + + Returns: + numpy.ndarray: The decoded data. + """ alpha_s_eval = eval_points beta_s = target_points U_dec = gen_Lagrange_coeffs(beta_s, alpha_s_eval, p) @@ -57,6 +112,18 @@ def LCC_decoding_with_points(f_eval, eval_points, target_points, p): def gen_Lagrange_coeffs(alpha_s, beta_s, p, is_K1=0): + """ + Generate Lagrange coefficients for encoding and decoding. + + Parameters: + alpha_s (list of int): List of alpha values. + beta_s (list of int): List of beta values. + p (int): The prime number modulo which to compute the coefficients. + is_K1 (int, optional): A flag indicating whether it's for K=1 (1 for K=1, 0 otherwise). + + Returns: + numpy.ndarray: The Lagrange coefficients. + """ if is_K1 == 1: num_alpha = 1 else: @@ -81,12 +148,24 @@ def gen_Lagrange_coeffs(alpha_s, beta_s, p, is_K1=0): def model_masking(weights_finite, dimensions, local_mask, prime_number): + """ + Apply masking to model weights. + + Parameters: + weights_finite (dict): A dictionary of model weights. + dimensions (list of int): List of dimensions corresponding to weights. + local_mask (numpy.ndarray): The masking values. + prime_number (int): The prime number modulo which to perform masking. + + Returns: + dict: The masked model weights. + """ pos = 0 for i, k in enumerate(weights_finite): tmp = weights_finite[k] cur_shape = tmp.shape d = dimensions[i] - cur_mask = local_mask[pos : pos + d, :] + cur_mask = local_mask[pos: pos + d, :] cur_mask = np.reshape(cur_mask, cur_shape) weights_finite[k] += cur_mask weights_finite[k] = np.mod(weights_finite[k], prime_number) @@ -102,6 +181,20 @@ def mask_encoding( prime_number, local_mask, ): + """ + Encode a masking scheme for privacy-preserving federated learning. + + Parameters: + total_dimension (int): Total dimension. + num_clients (int): Number of clients. + targeted_number_active_clients (int): Targeted number of active clients. + privacy_guarantee (int): Privacy guarantee parameter. + prime_number (int): The prime number modulo which to perform encoding. + local_mask (numpy.ndarray): The local mask. + + Returns: + numpy.ndarray: The encoded mask set. + """ d = total_dimension N = num_clients U = targeted_number_active_clients @@ -124,6 +217,17 @@ def mask_encoding( def compute_aggregate_encoded_mask(encoded_mask_dict, p, active_clients): + """ + Compute the aggregate encoded mask from a dictionary of encoded masks for active clients. + + Parameters: + encoded_mask_dict (dict): A dictionary containing encoded masks for clients. + p (int): The prime number modulo which to compute the aggregate mask. + active_clients (list): List of active client IDs. + + Returns: + list: The aggregate encoded mask as a list. 
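+
+    Example (illustrative sketch; the encoded masks are NumPy arrays of identical shape):
+        aggregate_encoded_mask = compute_aggregate_encoded_mask(encoded_mask_dict, p, [0, 2, 3])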
+ """ aggregate_encoded_mask = np.zeros((np.shape(encoded_mask_dict[0]))) for client_id in active_clients: aggregate_encoded_mask += encoded_mask_dict[client_id] @@ -133,8 +237,14 @@ def compute_aggregate_encoded_mask(encoded_mask_dict, p, active_clients): def aggregate_models_in_finite(weights_finite, prime_number): """ - weights_finite : array of state_dict() - prime_number : size of the finite field + Aggregate model weights in a finite field. + + Parameters: + weights_finite (list): List of model weights (state_dict) from different clients. + prime_number (int): The size of the finite field. + + Returns: + dict: The aggregated model weights in the finite field. """ w_sum = copy.deepcopy(weights_finite[0]) @@ -148,6 +258,17 @@ def aggregate_models_in_finite(weights_finite, prime_number): def my_q(X, q_bit, p): + """ + Quantize input values using fixed-point representation. + + Parameters: + X (numpy.ndarray): Input values to be quantized. + q_bit (int): Number of quantization bits. + p (int): The prime number modulo which to quantize. + + Returns: + numpy.ndarray: Quantized values. + """ X_int = np.round(X * (2**q_bit)) is_negative = (abs(np.sign(X_int)) - np.sign(X_int)) / 2 out = X_int + p * is_negative @@ -155,6 +276,17 @@ def my_q(X, q_bit, p): def my_q_inv(X_q, q_bit, p): + """ + Inverse quantize values back to their original range. + + Parameters: + X_q (numpy.ndarray): Quantized values to be de-quantized. + q_bit (int): Number of quantization bits. + p (int): The prime number modulo which to perform inverse quantization. + + Returns: + numpy.ndarray: De-quantized values. + """ flag = X_q - (p - 1) / 2 is_negative = (abs(np.sign(flag)) + np.sign(flag)) / 2 X_q = X_q - p * is_negative @@ -162,6 +294,17 @@ def my_q_inv(X_q, q_bit, p): def transform_finite_to_tensor(model_params, p, q_bits): + """ + Transform model parameters from finite field representation to tensor representation. + + Parameters: + model_params (dict): Model parameters represented in a finite field. + p (int): The prime number used for finite field representation. + q_bits (int): Number of quantization bits. + + Returns: + dict: Transformed model parameters in tensor representation. + """ for k in model_params.keys(): tmp = np.array(model_params[k]) tmp_real = my_q_inv(tmp, q_bits, p) @@ -185,6 +328,17 @@ def transform_finite_to_tensor(model_params, p, q_bits): def transform_tensor_to_finite(model_params, p, q_bits): + """ + Transform model parameters from tensor representation to finite field representation. + + Parameters: + model_params (dict): Model parameters represented as tensors. + p (int): The prime number used for finite field representation. + q_bits (int): Number of quantization bits. + + Returns: + dict: Transformed model parameters in finite field representation. + """ for k in model_params.keys(): tmp = np.array(model_params[k]) tmp_finite = my_q(tmp, q_bits, p) @@ -193,6 +347,15 @@ def transform_tensor_to_finite(model_params, p, q_bits): def model_dimension(weights): + """ + Compute the dimensions and total dimension of model weights. + + Parameters: + weights (dict): Model weights (state_dict). + + Returns: + tuple: A tuple containing dimensions (list) and total dimension (int). 
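+
+    Example (illustrative sketch; any PyTorch state_dict can be passed in):
+        dimensions, total_dimension = model_dimension(model.state_dict())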
+ """ logging.info("Get model dimension") dimensions = [] for k in weights.keys(): diff --git a/python/fedml/core/mpc/secagg.py b/python/fedml/core/mpc/secagg.py index 45874faba8..1660cbb27e 100644 --- a/python/fedml/core/mpc/secagg.py +++ b/python/fedml/core/mpc/secagg.py @@ -6,6 +6,16 @@ def modular_inv(a, p): + """ + Compute the modular inverse of 'a' modulo 'p' using the extended Euclidean algorithm. + + Parameters: + a (int): The number for which to find the modular inverse. + p (int): The modulus. + + Returns: + int: The modular inverse of 'a' modulo 'p'. + """ x, y, m = 1, 0, p while a > 1: q = a // m @@ -23,6 +33,18 @@ def modular_inv(a, p): def divmod(_num, _den, _p): + """ + Compute 'num' divided by 'den' modulo prime 'p'. + + Parameters: + _num (int): The numerator. + _den (int): The denominator. + _p (int): The prime modulus. + + Returns: + int: The result of 'num' / 'den' modulo 'p'. + """ + # compute num / den modulo prime p _num = np.mod(_num, _p) _den = np.mod(_den, _p) @@ -31,6 +53,16 @@ def divmod(_num, _den, _p): def PI(vals, p): # upper-case PI -- product of inputs + """ + Compute the product of a list of values modulo 'p'. + + Parameters: + vals (list): List of values. + p (int): The modulus. + + Returns: + int: The product of the values modulo 'p'. + """ accum = np.int64(1) for v in vals: tmp = np.mod(v, p) @@ -39,6 +71,18 @@ def PI(vals, p): # upper-case PI -- product of inputs def LCC_encoding_with_points(X, alpha_s, beta_s, p): + """ + Linear Code with Complementary coefficients (LCC) encoding of a matrix 'X' with given alpha and beta points. + + Parameters: + X (numpy.ndarray): Input matrix to be encoded. + alpha_s (list): List of alpha points. + beta_s (list): List of beta points. + p (int): The modulus. + + Returns: + numpy.ndarray: Encoded matrix using LCC encoding. + """ m, d = np.shape(X) U = gen_Lagrange_coeffs(beta_s, alpha_s, p).astype("int64") X_LCC = np.zeros((len(beta_s), d), dtype="int64") @@ -48,6 +92,18 @@ def LCC_encoding_with_points(X, alpha_s, beta_s, p): def LCC_decoding_with_points(f_eval, eval_points, target_points, p): + """ + Linear Code with Complementary coefficients (LCC) decoding with given evaluation and target points. + + Parameters: + f_eval (numpy.ndarray): Evaluation points. + eval_points (list): List of evaluation points. + target_points (list): List of target points. + p (int): The modulus. + + Returns: + int: Decoded result using LCC decoding. + """ alpha_s_eval = eval_points beta_s = target_points U_dec = gen_Lagrange_coeffs(beta_s, alpha_s_eval, p) @@ -57,6 +113,18 @@ def LCC_decoding_with_points(f_eval, eval_points, target_points, p): def gen_Lagrange_coeffs(alpha_s, beta_s, p, is_K1=0): + """ + Generate Lagrange coefficients for given alpha and beta points. + + Parameters: + alpha_s (list): List of alpha points. + beta_s (list): List of beta points. + p (int): The modulus. + is_K1 (int): Indicator for K1 coefficient generation. + + Returns: + numpy.ndarray: Lagrange coefficients matrix. + """ if is_K1 == 1: num_alpha = 1 else: @@ -81,6 +149,23 @@ def gen_Lagrange_coeffs(alpha_s, beta_s, p, is_K1=0): def model_masking(weights_finite, dimensions, local_mask, prime_number): + """ + Apply masking to model weights. + + Parameters: + weights_finite (dict): Dictionary of model weights. + dimensions (list): List of dimensions for each weight. + local_mask (numpy.ndarray): Local mask to be applied. + prime_number (int): The prime number for modulo operation. + + Returns: + dict: Updated model weights after masking. 
+ + This function applies a local mask to model weights by element-wise addition and modulo operation. + + Example: + updated_weights = model_masking(weights_finite, dimensions, local_mask, prime_number) + """ pos = 0 reshaped_local_mask = local_mask.reshape((local_mask.shape[0], 1)) for i, k in enumerate(weights_finite): @@ -95,7 +180,7 @@ def model_masking(weights_finite, dimensions, local_mask, prime_number): tmp = weights_finite[k] cur_shape = tmp.shape d = dimensions[i] - cur_mask = reshaped_local_mask[pos : pos + d, :] + cur_mask = reshaped_local_mask[pos: pos + d, :] cur_mask = np.reshape(cur_mask, cur_shape) weights_finite[k] += cur_mask weights_finite[k] = np.mod(weights_finite[k], prime_number) @@ -118,6 +203,26 @@ def model_masking(weights_finite, dimensions, local_mask, prime_number): def mask_encoding( total_dimension, num_clients, targeted_number_active_clients, privacy_guarantee, prime_number, local_mask ): + """ + Encode a local mask for privacy. + + Parameters: + total_dimension (int): Total dimension. + num_clients (int): Total number of clients. + targeted_number_active_clients (int): Targeted number of active clients. + privacy_guarantee (int): Privacy guarantee parameter. + prime_number (int): The prime number for modulo operation. + local_mask (numpy.ndarray): Local mask. + + Returns: + numpy.ndarray: Encoded mask. + + This function encodes a local mask for privacy using parameters like total dimension, number of clients, etc. + + Example: + encoded_mask = mask_encoding(total_dimension, num_clients, targeted_number_active_clients, privacy_guarantee, prime_number, local_mask) + """ + d = total_dimension N = num_clients U = targeted_number_active_clients @@ -132,12 +237,30 @@ def mask_encoding( LCC_in = np.concatenate([local_mask, n_i], axis=0) LCC_in = np.reshape(LCC_in, (U, d // (U - T))) - encoded_mask_set = LCC_encoding_with_points(LCC_in, alpha_s, beta_s, p).astype("int64") + encoded_mask_set = LCC_encoding_with_points( + LCC_in, alpha_s, beta_s, p).astype("int64") return encoded_mask_set def compute_aggregate_encoded_mask(encoded_mask_dict, p, active_clients): + """ + Compute the aggregate encoded mask. + + Parameters: + encoded_mask_dict (dict): Dictionary of encoded masks for each client. + p (int): The prime number for modulo operation. + active_clients (list): List of active client IDs. + + Returns: + numpy.ndarray: Aggregate encoded mask. + + This function computes the aggregate encoded mask from individual client masks. + + Example: + aggregate_mask = compute_aggregate_encoded_mask(encoded_mask_dict, p, active_clients) + """ + aggregate_encoded_mask = np.zeros((np.shape(encoded_mask_dict[0]))) for client_id in active_clients: aggregate_encoded_mask += encoded_mask_dict[client_id] @@ -147,8 +270,19 @@ def compute_aggregate_encoded_mask(encoded_mask_dict, p, active_clients): def aggregate_models_in_finite(weights_finite, prime_number): """ - weights_finite : array of state_dict() - prime_number : size of the finite field + Aggregate model weights in a finite field. + + Parameters: + weights_finite (list of dict): List of model weight dictionaries. + prime_number (int): The prime number for modulo operation. + + Returns: + dict: Aggregated model weights. + + This function aggregates model weights in a finite field using modulo operation. 
+ + Example: + aggregated_weights = aggregate_models_in_finite(weights_finite, prime_number) """ w_sum = copy.deepcopy(weights_finite[0]) @@ -162,6 +296,23 @@ def aggregate_models_in_finite(weights_finite, prime_number): def BGW_encoding(X, N, T, p): + """ + Encode data using BGW encoding. + + Parameters: + X (numpy.ndarray): Data to be encoded. + N (int): Number of evaluation points. + T (int): Degree of polynomial. + p (int): Prime number. + + Returns: + numpy.ndarray: Encoded data. + + This function encodes data using BGW encoding scheme. + + Example: + encoded_data = BGW_encoding(X, N, T, p) + """ m = len(X) d = len(X[0]) @@ -173,11 +324,27 @@ def BGW_encoding(X, N, T, p): for i in range(N): for t in range(T + 1): - X_BGW[i, :, :] = np.mod(X_BGW[i, :, :] + R[t, :, :] * (alpha_s[i] ** t), p) + X_BGW[i, :, :] = np.mod( + X_BGW[i, :, :] + R[t, :, :] * (alpha_s[i] ** t), p) return X_BGW def gen_BGW_lambda_s(alpha_s, p): + """ + Generate lambda values for BGW encoding. + + Parameters: + alpha_s (numpy.ndarray): Array of alpha values. + p (int): Prime number. + + Returns: + numpy.ndarray: Generated lambda values. + + This function generates lambda values for BGW encoding. + + Example: + lambda_values = gen_BGW_lambda_s(alpha_s, p) + """ lambda_s = np.zeros((1, len(alpha_s)), dtype="int64") for i in range(len(alpha_s)): @@ -190,6 +357,23 @@ def gen_BGW_lambda_s(alpha_s, p): def BGW_decoding(f_eval, worker_idx, p): # decode the output from T+1 evaluation points + """ + Decode data using BGW decoding. + + Parameters: + f_eval (numpy.ndarray): Evaluated data. + worker_idx (list): List of worker indices. + p (int): Prime number. + + Returns: + numpy.ndarray: Decoded data. + + This function decodes data using BGW decoding scheme. + + Example: + decoded_data = BGW_decoding(f_eval, worker_idx, p) + """ + # f_eval : [RT X d ] # worker_idx : [ 1 X RT] # output : [ 1 X d ] @@ -211,12 +395,30 @@ def BGW_decoding(f_eval, worker_idx, p): # decode the output from T+1 evaluatio def LCC_encoding(X, N, K, T, p): + """ + Encode data using LCC encoding. + + Parameters: + X (numpy.ndarray): Data to be encoded. + N (int): Number of evaluation points. + K (int): Number of known points. + T (int): Number of random points. + p (int): Prime number. + + Returns: + numpy.ndarray: Encoded data. + + This function encodes data using LCC encoding scheme. + + Example: + encoded_data = LCC_encoding(X, N, K, T, p) + """ m = len(X) d = len(X[0]) # print(m,d,m//K) X_sub = np.zeros((K + T, m // K, d), dtype="int64") for i in range(K): - X_sub[i] = X[i * m // K : (i + 1) * m // K :] + X_sub[i] = X[i * m // K: (i + 1) * m // K:] for i in range(K, K + T): X_sub[i] = np.random.randint(p, size=(m // K, d)) @@ -232,17 +434,37 @@ def LCC_encoding(X, N, K, T, p): X_LCC = np.zeros((N, m // K, d), dtype="int64") for i in range(N): for j in range(K + T): - X_LCC[i, :, :] = np.mod(X_LCC[i, :, :] + np.mod(U[i][j] * X_sub[j, :, :], p), p) + X_LCC[i, :, :] = np.mod( + X_LCC[i, :, :] + np.mod(U[i][j] * X_sub[j, :, :], p), p) return X_LCC def LCC_encoding_w_Random(X, R_, N, K, T, p): + """ + Encode data using LCC encoding with random values. + + Parameters: + X (numpy.ndarray): Data to be encoded. + R_ (numpy.ndarray): Random values for encoding. + N (int): Number of evaluation points. + K (int): Number of known points. + T (int): Number of random points. + p (int): Prime number. + + Returns: + numpy.ndarray: Encoded data. + + This function encodes data using LCC encoding scheme with random values. 
+ + Example: + encoded_data = LCC_encoding_w_Random(X, R_, N, K, T, p) + """ m = len(X) d = len(X[0]) # print(m,d,m//K) X_sub = np.zeros((K + T, m // K, d), dtype="int64") for i in range(K): - X_sub[i] = X[i * m // K : (i + 1) * m // K :] + X_sub[i] = X[i * m // K: (i + 1) * m // K:] for i in range(K, K + T): X_sub[i] = R_[i - K, :, :].astype("int64") @@ -262,17 +484,39 @@ def LCC_encoding_w_Random(X, R_, N, K, T, p): X_LCC = np.zeros((N, m // K, d), dtype="int64") for i in range(N): for j in range(K + T): - X_LCC[i, :, :] = np.mod(X_LCC[i, :, :] + np.mod(U[i][j] * X_sub[j, :, :], p), p) + X_LCC[i, :, :] = np.mod( + X_LCC[i, :, :] + np.mod(U[i][j] * X_sub[j, :, :], p), p) return X_LCC def LCC_encoding_w_Random_partial(X, R_, N, K, T, p, worker_idx): + """ + Encode data using LCC encoding with random values for a subset of workers. + + Parameters: + X (numpy.ndarray): Data to be encoded. + R_ (numpy.ndarray): Random values for encoding. + N (int): Number of evaluation points. + K (int): Number of known points. + T (int): Number of random points. + p (int): Prime number. + worker_idx (list): List of worker indices. + + Returns: + numpy.ndarray: Encoded data. + + This function encodes data using LCC encoding scheme with random values for a subset of workers. + + Example: + encoded_data = LCC_encoding_w_Random_partial(X, R_, N, K, T, p, worker_idx) + """ + m = len(X) d = len(X[0]) # print(m,d,m//K) X_sub = np.zeros((K + T, m // K, d), dtype="int64") for i in range(K): - X_sub[i] = X[i * m // K : (i + 1) * m // K :] + X_sub[i] = X[i * m // K: (i + 1) * m // K:] for i in range(K, K + T): X_sub[i] = R_[i - K, :, :].astype("int64") @@ -290,11 +534,33 @@ def LCC_encoding_w_Random_partial(X, R_, N, K, T, p, worker_idx): X_LCC = np.zeros((N_out, m // K, d), dtype="int64") for i in range(N_out): for j in range(K + T): - X_LCC[i, :, :] = np.mod(X_LCC[i, :, :] + np.mod(U[i][j] * X_sub[j, :, :], p), p) + X_LCC[i, :, :] = np.mod( + X_LCC[i, :, :] + np.mod(U[i][j] * X_sub[j, :, :], p), p) return X_LCC def LCC_decoding(f_eval, f_deg, N, K, T, worker_idx, p): + """ + Decode the encoded data using LCC decoding. + + Parameters: + f_eval (numpy.ndarray): Encoded data to be decoded. + f_deg (int): Degree of the encoded data. + N (int): Number of evaluation points. + K (int): Number of known points. + T (int): Number of random points. + worker_idx (list): List of worker indices. + p (int): Prime number. + + Returns: + numpy.ndarray: Decoded data. + + This function decodes the encoded data using LCC decoding scheme. + + Example: + decoded_data = LCC_decoding(f_eval, f_deg, N, K, T, worker_idx, p) + """ + RT_LCC = f_deg * (K + T - 1) + 1 n_beta = K # +T @@ -314,6 +580,23 @@ def LCC_decoding(f_eval, f_deg, N, K, T, worker_idx, p): def Gen_Additive_SS(d, n_out, p): + """ + Generate additive secret sharing. + + Parameters: + d (int): Dimension of the secret. + n_out (int): Number of output shares. + p (int): Prime number. + + Returns: + numpy.ndarray: Additive secret sharing matrix. + + This function generates additive secret sharing matrix. + + Example: + secret_sharing_matrix = Gen_Additive_SS(d, n_out, p) + """ + # x_model should be one dimension temp = np.random.randint(0, p, size=(n_out - 1, d)) @@ -327,6 +610,22 @@ def Gen_Additive_SS(d, n_out, p): def my_pk_gen(my_sk, p, g): + """ + Generate public key. + + Parameters: + my_sk (int): Private key. + p (int): Prime number. + g (int): Generator. + + Returns: + int: Public key. + + This function generates a public key from a private key. 
+ + Example: + public_key = my_pk_gen(my_sk, p, g) + """ # print 'my_pk_gen option: g=',g if g == 0: return my_sk @@ -335,6 +634,23 @@ def my_pk_gen(my_sk, p, g): def my_key_agreement(my_sk, u_pk, p, g): + """ + Perform key agreement. + + Parameters: + my_sk (int): Private key. + u_pk (int): Other party's public key. + p (int): Prime number. + g (int): Generator. + + Returns: + int: Shared secret key. + + This function performs key agreement between two parties. + + Example: + shared_secret_key = my_key_agreement(my_sk, u_pk, p, g) + """ if g == 0: return np.mod(my_sk * u_pk, p) else: @@ -342,6 +658,22 @@ def my_key_agreement(my_sk, u_pk, p, g): def my_q(X, q_bit, p): + """ + Quantize data to a finite field. + + Parameters: + X (numpy.ndarray): Data to be quantized. + q_bit (int): Number of bits for quantization. + p (int): Prime number. + + Returns: + numpy.ndarray: Quantized data. + + This function quantizes data to a specific number of bits within a finite field. + + Example: + quantized_data = my_q(X, q_bit, p) + """ X_int = np.round(X * (2 ** q_bit)) is_negative = (abs(np.sign(X_int)) - np.sign(X_int)) / 2 out = X_int + p * is_negative @@ -349,6 +681,23 @@ def my_q(X, q_bit, p): def transform_tensor_to_finite(model_params, p, q_bits): + """ + Transform model tensor parameters to finite field. + + Parameters: + model_params (dict): Dictionary of model parameters. + p (int): Prime number for the finite field. + q_bits (int): Number of bits for quantization. + + Returns: + dict: Transformed model parameters in the finite field. + + This function takes a dictionary of model parameters (typically tensors) and transforms them to the specified finite field. + + Example: + finite_model_params = transform_tensor_to_finite(model_params, p, q_bits) + """ + for k in model_params.keys(): tmp = np.array(model_params[k]) tmp_finite = my_q(tmp, q_bits, p) @@ -357,6 +706,22 @@ def transform_tensor_to_finite(model_params, p, q_bits): def my_q_inv(X_q, q_bit, p): + """ + Inverse quantize data from a finite field. + + Parameters: + X_q (numpy.ndarray): Data in the finite field to be inverse quantized. + q_bit (int): Number of bits for quantization. + p (int): Prime number. + + Returns: + numpy.ndarray: Inverse quantized data in the real field. + + This function performs inverse quantization of data from a finite field to the real field. + + Example: + real_data = my_q_inv(X_q, q_bit, p) + """ flag = X_q - (p - 1) / 2 is_negative = (abs(np.sign(flag)) + np.sign(flag)) / 2 X_q = X_q - p * is_negative @@ -364,6 +729,22 @@ def my_q_inv(X_q, q_bit, p): def transform_finite_to_tensor(model_params, p, q_bits): + """ + Transform model parameters from a finite field to tensor. + + Parameters: + model_params (dict): Dictionary of model parameters in the finite field. + p (int): Prime number for the finite field. + q_bits (int): Number of bits for quantization. + + Returns: + dict: Transformed model parameters as tensors in the real field. + + This function takes a dictionary of model parameters in the finite field and transforms them to tensors in the real field. 
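+
+    Note:
+        This is the inverse of transform_tensor_to_finite: applying both with the
+        same p and q_bits recovers the original tensors up to quantization error.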
+ + Example: + tensor_model_params = transform_finite_to_tensor(model_params, p, q_bits) + """ for k in model_params.keys(): tmp = np.array(model_params[k]) tmp_real = my_q_inv(tmp, q_bits, p) @@ -377,12 +758,28 @@ def transform_finite_to_tensor(model_params, p, q_bits): 0 - Wed, 13 Oct 2021 07:50:59 utils.py[line:33] DEBUG tmp_real = 256812209.4375 """ # logging.debug("tmp_real = {}".format(tmp_real)) - tmp_real = torch.Tensor([tmp_real]) if isinstance(tmp_real, np.floating) else torch.Tensor(tmp_real) + tmp_real = torch.Tensor([tmp_real]) if isinstance( + tmp_real, np.floating) else torch.Tensor(tmp_real) model_params[k] = tmp_real return model_params def model_dimension(weights): + """ + Get the dimension of a model. + + Parameters: + weights (dict): Dictionary of model weights. + + Returns: + list: List of dimensions of model parameters. + int: Total dimension of the model. + + This function calculates the dimensions of model parameters and the total dimension of the model. + + Example: + dimensions, total_dimension = model_dimension(weights) + """ logging.info("Get model dimension") dimensions = [] for k in weights.keys(): diff --git a/python/fedml/core/schedule/runtime_estimate.py b/python/fedml/core/schedule/runtime_estimate.py index f4984407d1..4500478e0d 100644 --- a/python/fedml/core/schedule/runtime_estimate.py +++ b/python/fedml/core/schedule/runtime_estimate.py @@ -2,6 +2,24 @@ def linear_fit(x, y): + """ + Fit a linear model to the given data. + + Parameters: + x (array-like): The independent variable data. + y (array-like): The dependent variable data. + + Returns: + z1 (array-like): Coefficients of the linear fit. + p1 (numpy.poly1d): The polynomial representing the linear fit. + yvals (array-like): Predicted values based on the linear fit. + fit_error (float): Mean absolute percentage error of the fit. + + Example: + x = [1, 2, 3, 4, 5] + y = [2, 4, 5, 4, 5] + z1, p1, yvals, fit_error = linear_fit(x, y) + """ z1 = np.polyfit(x, y, 1) p1 = np.poly1d(z1) print(p1) @@ -21,6 +39,44 @@ def t_sample_fit( 0: {0: [], 1: [], 2: []...}, 1: {0: [], 1: [], 2: []...}, } + + Fit linear models to runtime data for each worker and client combination. + + Parameters: + num_workers (int): The number of workers. + num_clients (int): The number of clients. + runtime_history (dict): A dictionary containing runtime history data. + Format: { + worker_id: { + client_id: [list of runtimes] + } + } + train_data_local_num_dict (dict): A dictionary containing the number of local training data samples for each client. + Format: { + client_id: num_samples + } + uniform_client (bool): Whether all clients have the same number of GPUs. + uniform_gpu (bool): Whether all clients have the same number of GPUs. + + Returns: + fit_params (dict): Fitted parameters (slope and intercept) of the linear models for each worker and client. + Format: { + worker_id: { + client_id: (slope, intercept) + } + } + fit_funcs (dict): Fitted linear functions for each worker and client. + Format: { + worker_id: { + client_id: p1 (linear function) + } + } + fit_errors (dict): Fit errors (mean absolute percentage error) for each worker and client. 
+ Format: { + worker_id: { + client_id: fit_error + } + } """ fit_params = {} fit_funcs = {} diff --git a/python/fedml/core/schedule/seq_train_scheduler.py b/python/fedml/core/schedule/seq_train_scheduler.py index cd155df271..2e2b7082a7 100644 --- a/python/fedml/core/schedule/seq_train_scheduler.py +++ b/python/fedml/core/schedule/seq_train_scheduler.py @@ -7,6 +7,41 @@ class SeqTrainScheduler: + """ + Initialize the Sequential Training Scheduler. + + Parameters: + workloads (list): List of client workloads. + constraints (list): List of constraints corresponding to each resource. + memory (list): List of memory constraints for each resource. + cost_funcs (list of lists or list of functions): Cost functions for assigning workloads. + uniform_client (bool): Whether the client workloads are uniform. + uniform_gpu (bool): Whether the GPU resources are uniform. + prune_equal_sub_solution (bool): Whether to prune equal sub-solutions. + + Attributes: + workloads (list): List of client workloads. + constraints (list): List of constraints corresponding to each resource. + memory (list): List of memory constraints for each resource. + cost_funcs (list of lists or list of functions): Cost functions for assigning workloads. + uniform_client (bool): Whether the client workloads are uniform. + uniform_gpu (bool): Whether the GPU resources are uniform. + len_x (int): Number of workloads (clients). + len_y (int): Number of constraints (resources). + iter_times (int): Iteration counter. + + Example: + scheduler = SeqTrainScheduler( + workloads=[100, 200, 150], + constraints=[10, 20], + memory=[300, 400], + cost_funcs=[[cost_func1, cost_func2], [cost_func3, cost_func4]], + uniform_client=True, + uniform_gpu=False, + prune_equal_sub_solution=True, + ) + """ + def __init__( self, workloads, @@ -33,6 +68,23 @@ def __init__( self.iter_times = 0 def obtain_client_cost(self, resource_id, client_id): + """ + Calculate the cost of assigning a workload to a resource. + + Parameters: + resource_id (int): Index of the resource. + client_id (int): Index of the client. + + Returns: + float: The calculated cost. + + This method calculates the cost of assigning a workload to a resource based on the specified cost functions + and resource and client characteristics. It handles different scenarios based on the values of + `uniform_client` and `uniform_gpu`. + + Example: + cost = scheduler.obtain_client_cost(0, 1) + """ if self.uniform_client and self.uniform_gpu: # cost = self.cost_funcs[0][0](self.client_data_nums[client_id]) cost = self.cost_funcs[0][0](self.workloads[client_id]) @@ -44,12 +96,29 @@ def obtain_client_cost(self, resource_id, client_id): cost = self.cost_funcs[resource_id][0](self.workloads[client_id]) else: # cost = self.cost_funcs[resource_id][client_id](self.client_data_nums[client_id]) - cost = self.cost_funcs[resource_id][client_id](self.workloads[client_id]) + cost = self.cost_funcs[resource_id][client_id]( + self.workloads[client_id]) if cost < 0.0: cost = 0.0 return cost def assign_a_workload_serial(self, x_maps, cost_maps): + """ + Assign workloads to resources sequentially. + + Parameters: + x_maps (list): List of workload assignment maps. + cost_maps (list): List of cost maps corresponding to workload assignments. + + Returns: + tuple: A tuple containing updated x_maps and cost_maps. + + This method assigns workloads to resources sequentially while minimizing the cost. It explores various workload + assignments and prunes suboptimal solutions based on the `prune_equal_sub_solution` attribute. 
+ + Example: + x_maps, cost_maps = scheduler.assign_a_workload_serial(x_maps, cost_maps) + """ # Find the case with the minimum cost. self.iter_times += 1 costs = [] @@ -108,6 +177,24 @@ def assign_a_workload_serial(self, x_maps, cost_maps): return self.assign_a_workload_serial(x_maps, cost_maps) def assign_a_workload(self, x_maps, cost_maps, resource_maps): + """ + Assign workloads to resources considering both parallel and serial execution. + + Parameters: + x_maps (list): List of workload assignment maps. + cost_maps (list): List of cost maps corresponding to workload assignments. + resource_maps (list): List of resource maps. + + Returns: + tuple: A tuple containing updated x_maps, cost_maps, and resource_maps. + + This method assigns workloads to resources while considering both parallel and serial execution possibilities. + It explores various workload assignments and prunes suboptimal solutions based on the `prune_equal_sub_solution` + attribute. + + Example: + x_maps, cost_maps, resource_maps = scheduler.assign_a_workload(x_maps, cost_maps, resource_maps) + """ # Find the case with the minimum cost. costs = [] for i in range(len(cost_maps)): @@ -139,7 +226,8 @@ def assign_a_workload(self, x_maps, cost_maps, resource_maps): new_maps.append(np.copy(x_map)) new_maps[-1][target_index] = i new_costs.append(np.copy(cost_map)) - new_costs[-1][i] = max((self.y[i] * self.x[target_index]), new_costs[-1][i]) + new_costs[-1][i] = max((self.y[i] * + self.x[target_index]), new_costs[-1][i]) new_resources.append(np.copy(resource_map)) new_resources[-1][i] += self.x[target_index] @@ -163,6 +251,22 @@ def assign_a_workload(self, x_maps, cost_maps, resource_maps): return self.assign_a_workload(x_maps, cost_maps, resource_maps) def DP_schedule(self, mode): + """ + Perform Dynamic Programming (DP) based scheduling. + + Parameters: + mode (int): Scheduling mode, 0 for serial, 1 for parallel. + + Returns: + tuple: A tuple containing the schedules and output_schedules. + + This method performs dynamic programming-based scheduling to assign workloads to resources while minimizing + the cost. It explores various workload assignments and prunes suboptimal solutions based on the scheduling mode. + The schedules are returned in the format of a list of dictionaries. 
+ + Example: + schedules, output_schedules = scheduler.DP_schedule(1) + """ x_maps = [] x_maps.append(np.negative(np.ones((self.len_x)))) cost_maps = [] @@ -172,9 +276,11 @@ def DP_schedule(self, mode): if mode == 1: resource_maps = [] resource_maps.append(np.zeros((self.len_y))) - x_maps, cost_maps, resource_maps = self.assign_a_workload(x_maps, cost_maps, resource_maps) + x_maps, cost_maps, resource_maps = self.assign_a_workload( + x_maps, cost_maps, resource_maps) else: - x_maps, cost_maps = self.assign_a_workload_serial(x_maps, cost_maps) + x_maps, cost_maps = self.assign_a_workload_serial( + x_maps, cost_maps) # print(f"x_maps: {x_maps} len(x_maps): {len(x_maps)}") # print(f"cost_maps: {cost_maps} len(cost_maps): {len(cost_maps)}") @@ -195,9 +301,11 @@ def DP_schedule(self, mode): # logging.info(f"schedules: {schedules} len(schedules): {len(schedules)}") logging.info(f"self.iter_times: {self.iter_times}") logging.info( - "The optimal maximum cost: %f, assignment: %s\n" % (costs[target_index], str(x_maps[target_index])) + "The optimal maximum cost: %f, assignment: %s\n" % ( + costs[target_index], str(x_maps[target_index])) ) - logging.info(f"target_index: {target_index} cost_map: {cost_maps[target_index]}") + logging.info( + f"target_index: {target_index} cost_map: {cost_maps[target_index]}") # print(f"schedules: {schedules} len(schedules): {len(schedules)}") # print(f"self.iter_times: {self.iter_times}") @@ -239,4 +347,3 @@ def DP_schedule(self, mode): schedule[num_bunches] = jobs output_schedules.append(schedule) return schedules, output_schedules - diff --git a/python/fedml/core/security/attack/backdoor_attack.py b/python/fedml/core/security/attack/backdoor_attack.py index f0f882bdb2..ae7734097b 100644 --- a/python/fedml/core/security/attack/backdoor_attack.py +++ b/python/fedml/core/security/attack/backdoor_attack.py @@ -32,6 +32,16 @@ class BackdoorAttack(BaseAttackMethod): def __init__( self, backdoor_client_num, client_num, num_std=None, dataset=None, backdoor_type="pattern", ): + """ + Initialize the BackdoorAttack. + + Args: + backdoor_client_num (int): Number of malicious clients for the backdoor attack. + client_num (int): Total number of clients. + num_std (float): Number of standard deviations for clipping gradients (default=None). + dataset (Tuple[Tensor, Tensor] or None): Dataset for generating backdoor (default=None). + backdoor_type (str): Type of backdoor ("pattern" or "random"). + """ self.backdoor_client_num = backdoor_client_num self.client_num = client_num self.num_std = num_std @@ -52,9 +62,20 @@ def __init__( pass def attack_model(self, raw_client_grad_list: List[Tuple[float, OrderedDict]], - extra_auxiliary_info: Any = None): + extra_auxiliary_info: Any = None): + """ + Attack the model using a backdoor attack strategy. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): List of client gradients. + extra_auxiliary_info (Any): Extra auxiliary information. + + Returns: + np.ndarray: New gradients for malicious clients. 
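+
+        Example (illustrative sketch; `backdoor_attack` is a configured
+        BackdoorAttack instance):
+            new_user_grads = backdoor_attack.attack_model(raw_client_grad_list)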
+ """ # the local_w comes from local training (regular) - backdoor_idxs = self._get_malicious_client_idx(len(raw_client_grad_list)) + backdoor_idxs = self._get_malicious_client_idx( + len(raw_client_grad_list)) (num0, averaged_params) = raw_client_grad_list[0] # fake grad @@ -64,54 +85,105 @@ def attack_model(self, raw_client_grad_list: List[Tuple[float, OrderedDict]], for i in backdoor_idxs: (_, param) = raw_client_grad_list[i] # grad = np.concatenate([param.grad.data.cpu().numpy().flatten() for param in model.parameters()]) // for real net - grad = np.concatenate([param[p_name].numpy().flatten() * 0.5 for p_name in param]) + grad = np.concatenate( + [param[p_name].numpy().flatten() * 0.5 for p_name in param]) grads.append(grad) grads_mean = np.mean(grads, axis=0) grads_stdev = np.var(grads, axis=0) ** 0.5 learning_rate = 0.1 - original_params_flat = np.concatenate([averaged_params[p_name].numpy().flatten() for p_name in averaged_params]) + original_params_flat = np.concatenate( + [averaged_params[p_name].numpy().flatten() for p_name in averaged_params]) initial_params_flat = ( original_params_flat - learning_rate * grads_mean ) # the corrected param after the user optimized, because we still want the model to improve - mal_net_params = self.train_malicious_network(initial_params_flat, original_params_flat) + mal_net_params = self.train_malicious_network( + initial_params_flat, original_params_flat) # Getting from the final required mal_net_params to the gradients that needs to be applied on the parameters of the previous round. new_params = mal_net_params + learning_rate * grads_mean new_grads = (initial_params_flat - new_params) / learning_rate # authors in the paper claims to limit the range of parameters but the code limits the gradient. new_user_grads = np.clip( - new_grads, grads_mean - self.num_std * grads_stdev, grads_mean + self.num_std * grads_stdev, + new_grads, grads_mean - self.num_std * + grads_stdev, grads_mean + self.num_std * grads_stdev, ) # the returned gradient controls the local update for malicious clients return new_user_grads @staticmethod def add_pattern(img): + """ + Add a pattern to an image (currently disabled). + + Args: + img (Tensor): Input image. + + Returns: + Tensor: Image with added pattern (disabled). + """ # disable img[:, :5, :5] = 2.8 return img def train_malicious_network(self, initial_params_flat, param): + """ + Train a malicious network (currently skipped). + + Args: + initial_params_flat (np.ndarray): Initial flattened model parameters. + param (np.ndarray): Original model parameters. + + Returns: + np.ndarray: Flattened malicious model parameters. + """ # skip training process # return flatten_params(param) return param def _get_malicious_client_idx(self, client_num): + """ + Get indices of malicious clients. + + Args: + client_num (int): Total number of clients. + + Returns: + List[int]: List of indices of malicious clients. + """ return random.sample(range(client_num), self.backdoor_client_num) def flatten_params(params): + """ + Flatten model parameters. + + Args: + params (Iterable[Tensor]): Model parameters. + + Returns: + np.ndarray: Flattened parameters as a NumPy array. + """ # for real net return np.concatenate([i.data.cpu().numpy().flatten() for i in params]) def row_into_parameters(row, parameters): + """ + Map a flattened row of parameters to the original model parameters. + + Args: + row (np.ndarray): Flattened row of parameters. + parameters (Iterable[Tensor]): Model parameters to map to. 
+ + Returns: + None + """ # for real net offset = 0 for param in parameters: new_size = functools.reduce(lambda x, y: x * y, param.shape) - current_data = row[offset : offset + new_size] + current_data = row[offset: offset + new_size] param.data[:] = torch.from_numpy(current_data.reshape(param.shape)) offset += new_size diff --git a/python/fedml/core/security/attack/byzantine_attack.py b/python/fedml/core/security/attack/byzantine_attack.py index c4f6b63257..7dfeead988 100644 --- a/python/fedml/core/security/attack/byzantine_attack.py +++ b/python/fedml/core/security/attack/byzantine_attack.py @@ -13,13 +13,30 @@ class ByzantineAttack(BaseAttackMethod): + def __init__(self, args): + """ + Initialize the ByzantineAttack. + + Args: + args (Namespace): Command-line arguments containing attack configuration. + """ self.byzantine_client_num = args.byzantine_client_num self.attack_mode = args.attack_mode # random: randomly generate a weight; zero: set the weight to 0 self.device = fedml.device.get_device(args) def attack_model(self, raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None): + """ + Attack the model using Byzantine clients. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): List of client gradients. + extra_auxiliary_info (Any): Extra auxiliary information (global model). + + Returns: + List[Tuple[float, OrderedDict]]: List of modified client gradients. + """ if len(raw_client_grad_list) < self.byzantine_client_num: self.byzantine_client_num = len(raw_client_grad_list) byzantine_idxs = sample_some_clients(len(raw_client_grad_list), self.byzantine_client_num) @@ -35,6 +52,16 @@ def attack_model(self, raw_client_grad_list: List[Tuple[float, OrderedDict]], return byzantine_local_w def _attack_zero_mode(self, model_list, byzantine_idxs): + """ + Perform zero-value Byzantine attack on the model gradients. + + Args: + model_list (List[Tuple[float, OrderedDict]]): List of client gradients. + byzantine_idxs (List[int]): Indices of Byzantine clients. + + Returns: + List[Tuple[float, OrderedDict]]: List of modified client gradients. + """ new_model_list = [] for i in range(0, len(model_list)): if i not in byzantine_idxs: @@ -48,6 +75,16 @@ def _attack_zero_mode(self, model_list, byzantine_idxs): return new_model_list def _attack_random_mode(self, model_list, byzantine_idxs): + """ + Perform random Byzantine attack on the model gradients. + + Args: + model_list (List[Tuple[float, OrderedDict]]): List of client gradients. + byzantine_idxs (List[int]): Indices of Byzantine clients. + + Returns: + List[Tuple[float, OrderedDict]]: List of modified client gradients. + """ new_model_list = [] for i in range(0, len(model_list)): diff --git a/python/fedml/core/security/attack/dlg_attack.py b/python/fedml/core/security/attack/dlg_attack.py index f9424625a3..176cd4cb7b 100644 --- a/python/fedml/core/security/attack/dlg_attack.py +++ b/python/fedml/core/security/attack/dlg_attack.py @@ -25,10 +25,15 @@ class DLGAttack(BaseAttackMethod): def __init__(self, args): + """ + Initialize the DLGAttack. + + Args: + args (Namespace): Command-line arguments containing attack configuration. 
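The two Byzantine modes are easy to state compactly: a selected client submits either all-zero weights or pure noise of the right shapes. A minimal sketch of the zero mode, with the random mode noted inline (illustrative, not the repository implementation):

# Sketch: zero-mode Byzantine attack on a list of (sample_num, params).
import random
from collections import OrderedDict
import torch

def attack_zero_mode(model_list, byzantine_idxs):
    new_model_list = []
    for i, (sample_num, params) in enumerate(model_list):
        if i in byzantine_idxs:
            # random mode would use torch.randn_like(v) instead
            params = OrderedDict((k, torch.zeros_like(v))
                                 for k, v in params.items())
        new_model_list.append((sample_num, params))
    return new_model_list

clients = [(10, OrderedDict(w=torch.randn(3, 3))) for _ in range(4)]
byzantine = random.sample(range(4), 2)
print(attack_zero_mode(clients, byzantine))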
+ """ self.model = None self.model_type = args.model - if args.dataset in ["cifar10", "cifar100"]: self.original_data_size = torch.Size([1, 3, 32, 32]) if args.dataset == "cifar10": @@ -39,7 +44,8 @@ def __init__(self, args): self.original_data_size = torch.Size([1, 28, 28]) self.original_label_size = torch.Size([1, 10]) else: - raise Exception(f"do not support this dataset for DLG attack: {args.dataset}") + raise Exception( + f"do not support this dataset for DLG attack: {args.dataset}") self.criterion = cross_entropy_for_onehot # cifar 100: # original data size = torch.Size([1, 3, 32, 32]) @@ -60,6 +66,13 @@ def __init__(self, args): # attack the last iteration, as it contains more information def get_model(self): + """ + Get the model based on the specified model type. + + Returns: + torch.nn.Module: The model instance. + """ + if self.model_type == "LeNet": return LeNet() elif self.model_type == "resnet56": @@ -70,13 +83,36 @@ def get_model(self): raise Exception(f"do not support this model: {self.model_type}") def reconstruct_data(self, raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None): + """ + Reconstruct the data using the provided client gradients. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): List of client gradients. + extra_auxiliary_info (Any): Extra auxiliary information (global model of last round). + + Note: + This method performs data reconstruction based on specified conditions. + + """ if self.iteration_num in self.attack_iteration_idxs: for (_, local_model) in raw_client_grad_list: print(f"-----------attack---------------") - self.reconstruct_data_using_a_model(a_model=local_model, extra_auxiliary_info=extra_auxiliary_info) + self.reconstruct_data_using_a_model( + a_model=local_model, extra_auxiliary_info=extra_auxiliary_info) self.iteration_num += 1 def reconstruct_data_using_a_model(self, a_model: OrderedDict, extra_auxiliary_info: Any = None): + """ + Reconstruct data using a specific model and auxiliary information. + + Args: + a_model (OrderedDict): Client model parameters. + extra_auxiliary_info (Any): Extra auxiliary information (global model of last round). + + Returns: + torch.Tensor: Reconstructed data. + torch.Tensor: Reconstructed labels. 
+ """ self.model = self.get_model() global_model_of_last_round = extra_auxiliary_info gradient = [] @@ -85,10 +121,13 @@ def reconstruct_data_using_a_model(self, a_model: OrderedDict, extra_auxiliary_i for k, _ in global_model_of_last_round.items(): if "weight" in k or "bias" in k: if self.protected_layers is not None and layer_counter in self.protected_layers: - gradient.append(torch.from_numpy(np.zeros(global_model_of_last_round[k].size())).float()) + gradient.append(torch.from_numpy( + np.zeros(global_model_of_last_round[k].size())).float()) # if the layer is protected, set to 0 else: - gradient.append(a_model[k] - global_model_of_last_round[k].to(self.device)) # !!!!!!!!!!!!!!!!!!todo: to double check + # !!!!!!!!!!!!!!!!!!todo: to double check + gradient.append( + a_model[k] - global_model_of_last_round[k].to(self.device)) layer_counter += 1 gradient = tuple(gradient) dummy_data = torch.randn(self.original_data_size) @@ -101,8 +140,10 @@ def reconstruct_data_using_a_model(self, a_model: OrderedDict, extra_auxiliary_i def closure(): optimizer.zero_grad() dummy_pred = self.model(dummy_data) - dummy_loss = self.criterion(dummy_pred, F.softmax(dummy_label, dim=-1)) - dummy_grad = torch.autograd.grad(dummy_loss, self.model.parameters(), create_graph=True) + dummy_loss = self.criterion( + dummy_pred, F.softmax(dummy_label, dim=-1)) + dummy_grad = torch.autograd.grad( + dummy_loss, self.model.parameters(), create_graph=True) dummy_grad = tuple( g.to(self.device) for g in dummy_grad) # extract tensor from tuple and move to device diff --git a/python/fedml/core/security/attack/edge_case_backdoor_attack.py b/python/fedml/core/security/attack/edge_case_backdoor_attack.py index 2e2a0d64ea..4b016fcda2 100644 --- a/python/fedml/core/security/attack/edge_case_backdoor_attack.py +++ b/python/fedml/core/security/attack/edge_case_backdoor_attack.py @@ -25,6 +25,16 @@ def __init__( backdoor_dataset, batch_size, ): + """ + Initialize the EdgeCaseBackdoorAttack. + + Args: + client_num (int): Total number of clients in the system. + poisoned_client_num (int): Number of clients to poison with backdoor samples. + backdoor_sample_percentage (float): Percentage of backdoor samples to insert. + backdoor_dataset (Dataset): Backdoor dataset containing poisoned samples. + batch_size (int): Batch size for data loaders. + """ self.client_num = client_num self.attack_epoch = 0 self.poisoned_client_num = poisoned_client_num @@ -34,6 +44,16 @@ def __init__( self.batch_size = batch_size def poison_data(self, dataset): + """ + Poison the training data of selected clients with backdoor samples. + + Args: + dataset (list): List containing various data related to clients and the dataset. + + Returns: + list: List of data loaders for each client, including backdoored clients. + """ + [ train_data_num, test_data_num, diff --git a/python/fedml/core/security/attack/invert_gradient_attack.py b/python/fedml/core/security/attack/invert_gradient_attack.py index 0f65fd2871..a11e0c80c7 100644 --- a/python/fedml/core/security/attack/invert_gradient_attack.py +++ b/python/fedml/core/security/attack/invert_gradient_attack.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass import logging import math import time @@ -37,6 +38,16 @@ class InvertAttack(BaseAttackMethod): def __init__( self, attack_client_idx=0, trained_model=False, model=None, num_images=1, use_updates=False, ): + """ + Initialize the Invert Attack. + + Args: + attack_client_idx (int): Index of the target client for the attack. 
+ trained_model (bool): Whether the model is already trained. + model: The model used for the attack. + num_images (int): Number of images to use for the attack. + use_updates (bool): Whether to use model updates for the attack. + """ defs = ConservativeStrategy() loss_fn = Classification() self.use_updates = use_updates @@ -48,15 +59,27 @@ def __init__( self.num_images = num_images # = batch_size in local training def reconstruct_data(self, a_gradient: dict, extra_auxiliary_info: Any = None): + """ + Reconstruct the data after the attack. + + Args: + a_gradient (dict): Gradient information. + extra_auxiliary_info: Additional auxiliary information. + + Returns: + tuple: A tuple containing the reconstructed data and statistics. + """ self.ground_truth = extra_auxiliary_info[0][0] self.labels = extra_auxiliary_info[0][1] if not self.use_updates: rec_machine = GradientReconstructor( - self.model, (self.dm, self.ds), config=extra_auxiliary_info[1], num_images=self.num_images, + self.model, (self.dm, + self.ds), config=extra_auxiliary_info[1], num_images=self.num_images, ) self.input_gradient = a_gradient - output, stats = rec_machine.reconstruct(self.input_gradient, self.labels, self.img_shape) + output, stats = rec_machine.reconstruct( + self.input_gradient, self.labels, self.img_shape) else: rec_machine = FedAvgReconstructor( self.model, @@ -67,10 +90,12 @@ def reconstruct_data(self, a_gradient: dict, extra_auxiliary_info: Any = None): use_updates=self.use_updates, ) self.input_parameters = a_gradient - output, stats = rec_machine.reconstruct(self.input_parameters, self.labels, self.img_shape) + output, stats = rec_machine.reconstruct( + self.input_parameters, self.labels, self.img_shape) test_mse = (output.detach() - self.ground_truth).pow(2).mean() - feat_mse = (self.model(output.detach()) - self.model(self.ground_truth)).pow(2).mean() + feat_mse = (self.model(output.detach()) - + self.model(self.ground_truth)).pow(2).mean() test_psnr = psnr(output, self.ground_truth, factor=1 / self.ds) logging.info( f"Rec. loss: {stats['opt']:2.4f} | MSE: {test_mse:2.4f} | PSNR: {test_psnr:4.2f} | FMSE: {feat_mse:2.4e} |" @@ -83,8 +108,6 @@ def reconstruct_data(self, a_gradient: dict, extra_auxiliary_info: Any = None): """Optimization setups.""" -from dataclasses import dataclass - @dataclass # class ConservativeStrategy(Strategy): @@ -108,8 +131,10 @@ def __init__(self, lr=None, epochs=None, dryrun=False): class Loss: """Abstract class, containing necessary methods. - Abstract class to collect information about the 'higher-level' loss function, used to train an energy-based model - containing the evaluation of the loss function, its gradients w.r.t. to first and second argument and evaluations + + Abstract class to collect information about the 'higher-level' loss function, + used to train an energy-based model containing the evaluation of the loss + function, its gradients w.r.t. to first and second argument and evaluations of the actual metric that is targeted. """ @@ -181,13 +206,34 @@ def metric(self, x=None, y=None): def _label_to_onehot(target, num_classes=100): + """Convert class labels to one-hot encoded tensors. + + Args: + target (torch.Tensor): Class labels. + num_classes (int): Number of classes. + + Returns: + torch.Tensor: One-hot encoded tensor with shape (target.size(0), num_classes). 
+ """ target = torch.unsqueeze(target, 1) - onehot_target = torch.zeros(target.size(0), num_classes, device=target.device) + onehot_target = torch.zeros(target.size( + 0), num_classes, device=target.device) onehot_target.scatter_(1, target, 1) return onehot_target def _validate_config(config): + """Validate and fill in missing configuration values with defaults. + + Args: + config (dict): Configuration dictionary. + + Returns: + dict: Validated configuration dictionary with missing keys filled in with defaults. + + Raises: + ValueError: If deprecated keys are found in the configuration. + """ for key in DEFAULT_CONFIG.keys(): if config.get(key) is None: config[key] = DEFAULT_CONFIG[key] @@ -198,13 +244,32 @@ def _validate_config(config): class GradientReconstructor: - """Instantiate a reconstruction algorithm.""" + """ + Instantiate a reconstruction algorithm for gradients. + + Args: + model: The PyTorch model used for the reconstruction. + mean_std: Tuple of mean and standard deviation used for normalization. + config: Configuration dictionary for algorithm setup. + num_images: Number of images to use for reconstruction. + + Attributes: + config (dict): Algorithm configuration parameters. + model: The PyTorch model used for reconstruction. + setup (dict): Device and data type setup for the model. + mean_std (tuple): Mean and standard deviation used for normalization. + num_images (int): Number of images to use for reconstruction. + inception (InceptionScore): Inception score calculator (optional). + loss_fn (torch.nn.Module): Loss function used for reconstruction. + iDLG (bool): Flag indicating whether to use the iDLG trick. + """ def __init__(self, model, mean_std=(0.0, 1.0), config=DEFAULT_CONFIG, num_images=1): """Initialize with algorithm setup.""" self.config = _validate_config(config) self.model = model - self.setup = dict(device=next(model.parameters()).device, dtype=next(model.parameters()).dtype) + self.setup = dict(device=next(model.parameters()).device, + dtype=next(model.parameters()).dtype) self.mean_std = mean_std self.num_images = num_images @@ -218,7 +283,20 @@ def __init__(self, model, mean_std=(0.0, 1.0), config=DEFAULT_CONFIG, num_images def reconstruct( self, input_data, labels, img_shape=(3, 32, 32), dryrun=False, eval=True, tol=None, ): - """Reconstruct image from gradient.""" + """ + Reconstruct images from gradients. + + Args: + input_data (torch.Tensor): Input gradient data. + labels (torch.Tensor): Labels associated with the input data. + img_shape (tuple): Image shape (channels, height, width). + dryrun (bool): Whether to perform a dry run. + eval (bool): Whether to set the model to evaluation mode. + tol (float): Tolerance threshold for reconstruction. + + Returns: + tuple: A tuple containing the reconstructed data and statistics. 
+ """ start_time = time.time() if eval: self.model.eval() @@ -230,7 +308,8 @@ def reconstruct( if labels is None: if self.num_images == 1 and self.iDLG: # iDLG trick: - last_weight_min = torch.argmin(torch.sum(input_data[-2], dim=-1), dim=-1) + last_weight_min = torch.argmin( + torch.sum(input_data[-2], dim=-1), dim=-1) labels = last_weight_min.detach().reshape((1,)).requires_grad_(False) self.reconstruct_label = False else: @@ -249,7 +328,8 @@ def loss_fn(pred, labels): try: for trial in range(self.config["restarts"]): - x_trial, labels = self._run_trial(x[trial], input_data, labels, dryrun=dryrun) + x_trial, labels = self._run_trial( + x[trial], input_data, labels, dryrun=dryrun) # Finalize scores[trial] = self._score_trial(x_trial, input_data, labels) x[trial] = x_trial @@ -263,7 +343,8 @@ def loss_fn(pred, labels): # Choose optimal result: print("Choosing optimal result ...") - scores = scores[torch.isfinite(scores)] # guard against NaN/-Inf scores? + # guard against NaN/-Inf scores? + scores = scores[torch.isfinite(scores)] optimal_index = torch.argmin(scores) print(f"Optimal result score: {scores[optimal_index]:2.4f}") stats["opt"] = scores[optimal_index].item() @@ -273,26 +354,50 @@ def loss_fn(pred, labels): return x_optimal.detach(), stats def _init_images(self, img_shape): + """ + Initialize images for reconstruction. + + Args: + img_shape (tuple): Image shape (channels, height, width). + + Returns: + torch.Tensor: Initialized image data. + """ if self.config["init"] == "randn": return torch.randn((self.config["restarts"], self.num_images, *img_shape)) else: raise ValueError() def _run_trial(self, x_trial, input_data, labels, dryrun=False): + """ + Run a reconstruction trial. + + Args: + x_trial (torch.Tensor): Image data for the trial. + input_data (torch.Tensor): Input gradient data. + labels (torch.Tensor): Labels associated with the input data. + dryrun (bool): Whether to perform a dry run. + + Returns: + tuple: A tuple containing the reconstructed image data and labels. 
+ """ x_trial.requires_grad = True if self.reconstruct_label: output_test = self.model(x_trial) - labels = torch.randn(output_test.shape[1]).to(**self.setup).requires_grad_(True) + labels = torch.randn(output_test.shape[1]).to( + **self.setup).requires_grad_(True) if self.config["optim"] == "adam": - optimizer = torch.optim.Adam([x_trial, labels], lr=self.config["lr"]) + optimizer = torch.optim.Adam( + [x_trial, labels], lr=self.config["lr"]) else: raise ValueError() else: if self.config["optim"] == "adam": optimizer = torch.optim.Adam([x_trial], lr=self.config["lr"]) elif self.config["optim"] == "sgd": # actually gd - optimizer = torch.optim.SGD([x_trial], lr=0.01, momentum=0.9, nesterov=True) + optimizer = torch.optim.SGD( + [x_trial], lr=0.01, momentum=0.9, nesterov=True) elif self.config["optim"] == "LBFGS": optimizer = torch.optim.LBFGS([x_trial]) else: @@ -303,12 +408,14 @@ def _run_trial(self, x_trial, input_data, labels, dryrun=False): if self.config["lr_decay"]: scheduler = torch.optim.lr_scheduler.MultiStepLR( optimizer, - milestones=[max_iterations // 2.667, max_iterations // 1.6, max_iterations // 1.142,], + milestones=[max_iterations // 2.667, + max_iterations // 1.6, max_iterations // 1.142,], gamma=0.1, ) # 3/8 5/8 7/8 try: for iteration in range(max_iterations): - closure = self._gradient_closure(optimizer, x_trial, input_data, labels) + closure = self._gradient_closure( + optimizer, x_trial, input_data, labels) rec_loss = optimizer.step(closure) if self.config["lr_decay"]: scheduler.step() @@ -316,16 +423,19 @@ def _run_trial(self, x_trial, input_data, labels, dryrun=False): with torch.no_grad(): # Project into image space if self.config["boxed"]: - x_trial.data = torch.max(torch.min(x_trial, (1 - dm) / ds), -dm / ds) + x_trial.data = torch.max( + torch.min(x_trial, (1 - dm) / ds), -dm / ds) if (iteration + 1 == max_iterations) or iteration % 500 == 0: - print(f"It: {iteration}. Rec. loss: {rec_loss.item():2.4f}.") + print( + f"It: {iteration}. Rec. loss: {rec_loss.item():2.4f}.") if (iteration + 1) % 500 == 0: if self.config["filter"] == "none": pass elif self.config["filter"] == "median": - x_trial.data = MedianPool2d(kernel_size=3, stride=1, padding=1, same=False)(x_trial) + x_trial.data = MedianPool2d( + kernel_size=3, stride=1, padding=1, same=False)(x_trial) else: raise ValueError() @@ -337,11 +447,24 @@ def _run_trial(self, x_trial, input_data, labels, dryrun=False): return x_trial.detach(), labels def _gradient_closure(self, optimizer, x_trial, input_gradient, label): + """ + Create a closure for gradient computation. + + Args: + optimizer: The optimizer used for reconstruction. + x_trial (torch.Tensor): Image data for the trial. + input_gradient (torch.Tensor): Input gradient data. + label (torch.Tensor): Labels associated with the input data. + + Returns: + function: A closure for gradient computation. 
+ """ def closure(): optimizer.zero_grad() self.model.zero_grad() loss = self.loss_fn(self.model(x_trial), label) - gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=True) + gradient = torch.autograd.grad( + loss, self.model.parameters(), create_graph=True) rec_loss = reconstruction_costs( [gradient], input_gradient, @@ -351,7 +474,8 @@ def closure(): ) if self.config["total_variation"] > 0: - rec_loss += self.config["total_variation"] * total_variation(x_trial) + rec_loss += self.config["total_variation"] * \ + total_variation(x_trial) rec_loss.backward() if self.config["signed"]: x_trial.grad.sign_() @@ -360,11 +484,23 @@ def closure(): return closure def _score_trial(self, x_trial, input_gradient, label): + """ + Score a reconstruction trial. + + Args: + x_trial (torch.Tensor): Reconstructed image data. + input_gradient (torch.Tensor): Input gradient data. + label (torch.Tensor): Labels associated with the input data. + + Returns: + float: The score for the reconstruction trial. + """ if self.config["scoring_choice"] == "loss": self.model.zero_grad() x_trial.grad = None loss = self.loss_fn(self.model(x_trial), label) - gradient = torch.autograd.grad(loss, self.model.parameters(), create_graph=False) + gradient = torch.autograd.grad( + loss, self.model.parameters(), create_graph=False) return reconstruction_costs( [gradient], input_gradient, @@ -377,7 +513,33 @@ def _score_trial(self, x_trial, input_gradient, label): class FedAvgReconstructor(GradientReconstructor): - """Reconstruct an image from weights after n gradient descent steps.""" + """ + Reconstruct an image from model weights after performing gradient descent steps. + + Args: + model: The PyTorch model used for the reconstruction. + mean_std: Tuple of mean and standard deviation used for normalization. + local_steps: Number of local gradient descent steps. + local_lr: Learning rate for local gradient descent. + config: Configuration dictionary for algorithm setup. + num_images: Number of images to use for reconstruction. + use_updates: Flag indicating whether to use weight updates. + batch_size: Batch size for local gradient descent. + + Attributes: + config (dict): Algorithm configuration parameters. + model: The PyTorch model used for reconstruction. + setup (dict): Device and data type setup for the model. + mean_std (tuple): Mean and standard deviation used for normalization. + num_images (int): Number of images to use for reconstruction. + inception (InceptionScore): Inception score calculator (optional). + loss_fn (torch.nn.Module): Loss function used for reconstruction. + iDLG (bool): Flag indicating whether to use the iDLG trick. + local_steps (int): Number of local gradient descent steps. + local_lr (float): Learning rate for local gradient descent. + use_updates (bool): Flag indicating whether to use weight updates. + batch_size (int): Batch size for local gradient descent. + """ def __init__( self, @@ -390,7 +552,19 @@ def __init__( use_updates=True, batch_size=0, ): - """Initialize with model, (mean, std) and config.""" + """ + Initialize the FedAvgReconstructor with the given parameters. + + Args: + model: The PyTorch model used for the reconstruction. + mean_std: Tuple of mean and standard deviation used for normalization. + local_steps: Number of local gradient descent steps. + local_lr: Learning rate for local gradient descent. + config: Configuration dictionary for algorithm setup. + num_images: Number of images to use for reconstruction. 
+ use_updates: Flag indicating whether to use weight updates. + batch_size: Batch size for local gradient descent. + """ super().__init__(model, mean_std, config, num_images) self.local_steps = local_steps self.local_lr = local_lr @@ -398,6 +572,18 @@ def __init__( self.batch_size = batch_size def _gradient_closure(self, optimizer, x_trial, input_parameters, labels): + """ + Closure function for computing gradients during optimization. + + Args: + optimizer (torch.optim.Optimizer): The optimizer used for gradient descent. + x_trial (torch.Tensor): The input image to be optimized. + input_parameters (torch.Tensor): The ground truth model weights. + labels (torch.Tensor): The labels used for reconstruction. + + Returns: + Callable: A closure function for computing gradients and loss. + """ def closure(): optimizer.zero_grad() self.model.zero_grad() @@ -419,7 +605,8 @@ def closure(): weights=self.config["weights"], ) if self.config["total_variation"] > 0: - rec_loss += self.config["total_variation"] * total_variation(x_trial) + rec_loss += self.config["total_variation"] * \ + total_variation(x_trial) rec_loss.backward() if self.config["signed"]: x_trial.grad.sign_() @@ -428,6 +615,17 @@ def closure(): return closure def _score_trial(self, x_trial, input_parameters, labels): + """ + Compute the score of a trial reconstruction. + + Args: + x_trial (torch.Tensor): The reconstructed image. + input_parameters (torch.Tensor): The ground truth model weights. + labels (torch.Tensor): The labels used for reconstruction. + + Returns: + float: The score of the trial reconstruction. + """ if self.config["scoring_choice"] == "loss": self.model.zero_grad() parameters = loss_steps( @@ -451,7 +649,22 @@ def _score_trial(self, x_trial, input_parameters, labels): def loss_steps( model, inputs, labels, loss_fn=torch.nn.CrossEntropyLoss(), lr=1e-4, local_steps=4, use_updates=True, batch_size=0, ): - """Take a few gradient descent steps to fit the model to the given input.""" + """ + Perform gradient descent steps to fit the model to the given input data. + + Args: + model (nn.Module): The neural network model to be optimized. + inputs (torch.Tensor): The input data for optimization. + labels (torch.Tensor): The labels for the input data. + loss_fn (torch.nn.Module, optional): The loss function used for optimization. Default is CrossEntropyLoss. + lr (float, optional): The learning rate for gradient descent. Default is 1e-4. + local_steps (int, optional): The number of gradient descent steps to perform. Default is 4. + use_updates (bool, optional): Whether to use parameter updates during optimization. Default is True. + batch_size (int, optional): Batch size for mini-batch gradient descent. Default is 0 (full batch). + + Returns: + List[torch.Tensor]: A list of model parameter tensors after optimization. 
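loss_steps is the FedAvg-specific twist: the attacker differentiates through several local SGD steps, which requires taking those steps functionally with create_graph=True so the whole local update stays on the autograd graph. A minimal sketch of that unrolling on a single linear layer (illustrative, not the repository code):

# Sketch: differentiable, unrolled local SGD steps.
import torch
import torch.nn.functional as F

def unrolled_sgd(weight, bias, x, y, lr=1e-2, steps=4):
    for _ in range(steps):
        loss = F.cross_entropy(F.linear(x, weight, bias), y)
        g_w, g_b = torch.autograd.grad(loss, (weight, bias),
                                       create_graph=True)
        weight = weight - lr * g_w   # out-of-place: stays on the graph
        bias = bias - lr * g_b
    return weight, bias

w0 = torch.randn(3, 8, requires_grad=True)
b0 = torch.zeros(3, requires_grad=True)
x, y = torch.randn(4, 8), torch.randint(0, 3, (4,))
w4, b4 = unrolled_sgd(w0, b0, x, y)
# The final weights are differentiable w.r.t. the starting point:
print(torch.autograd.grad(w4.sum(), w0)[0].shape)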
+ """ patched_model = MetaMonkey(model) if use_updates: patched_model_origin = deepcopy(patched_model) @@ -461,8 +674,9 @@ def loss_steps( labels_ = labels else: idx = i % (inputs.shape[0] // batch_size) - outputs = patched_model(inputs[idx * batch_size : (idx + 1) * batch_size], patched_model.parameters,) - labels_ = labels[idx * batch_size : (idx + 1) * batch_size] + outputs = patched_model( + inputs[idx * batch_size: (idx + 1) * batch_size], patched_model.parameters,) + labels_ = labels[idx * batch_size: (idx + 1) * batch_size] loss = loss_fn(outputs, labels_).sum() grad = torch.autograd.grad( loss, patched_model.parameters.values(), retain_graph=True, create_graph=True, only_inputs=True, @@ -483,21 +697,38 @@ def loss_steps( def reconstruction_costs(gradients, input_gradient, cost_fn="l2", indices="def", weights="equal"): - """Input gradient is given data.""" + """ + Calculate reconstruction costs between gradients and input gradient. + + Args: + gradients (List[torch.Tensor]): List of gradients to be compared with the input gradient. + input_gradient (torch.Tensor): The input gradient (data). + cost_fn (str, optional): The reconstruction cost function to use ("l2" or "sim"). Default is "l2". + indices (Union[str, List[int]], optional): The indices of gradients to consider or method to choose them. + Default is "def" (all gradients). + weights (Union[str, List[float]], optional): The weights for each gradient during reconstruction cost calculation. + Default is "equal" (equal weights). + + Returns: + float: The total reconstruction cost averaged over the provided gradients. + """ if isinstance(indices, list): pass elif indices == "def": indices = torch.arange(len(input_gradient)) elif indices == "top10": - _, indices = torch.topk(torch.stack([p.norm() for p in input_gradient], dim=0), 10) + _, indices = torch.topk(torch.stack( + [p.norm() for p in input_gradient], dim=0), 10) else: raise ValueError() ex = input_gradient[0] if weights == "linear": - weights = torch.arange(len(input_gradient), 0, -1, dtype=ex.dtype, device=ex.device) / len(input_gradient) + weights = torch.arange(len( + input_gradient), 0, -1, dtype=ex.dtype, device=ex.device) / len(input_gradient) elif weights == "exp": - weights = torch.arange(len(input_gradient), 0, -1, dtype=ex.dtype, device=ex.device) + weights = torch.arange(len(input_gradient), 0, -1, + dtype=ex.dtype, device=ex.device) weights = weights.softmax(dim=0) weights = weights / weights[0] else: @@ -509,7 +740,8 @@ def reconstruction_costs(gradients, input_gradient, cost_fn="l2", indices="def", costs = 0 for i in indices: if cost_fn == "sim": - costs -= (trial_gradient[i] * input_gradient[i]).sum() * weights[i] + costs -= (trial_gradient[i] * + input_gradient[i]).sum() * weights[i] pnorm[0] += trial_gradient[i].pow(2).sum() * weights[i] pnorm[1] += input_gradient[i].pow(2).sum() * weights[i] if cost_fn == "sim": @@ -529,13 +761,27 @@ class MetaMonkey(torch.nn.Module): """ def __init__(self, net): - """Init with network.""" + """ + Initialize MetaMonkey with a neural network. + + Args: + net (torch.nn.Module): The neural network to be patched. + """ super().__init__() self.net = net self.parameters = OrderedDict(net.named_parameters()) def forward(self, inputs, parameters=None): - """Live Patch ... :> ...""" + """ + Forward pass through the network with optional live patching of modules. + + Args: + inputs (torch.Tensor): The input data. + parameters (OrderedDict, optional): Dictionary of parameters to be used for live patching. 
+ + Returns: + torch.Tensor: The output tensor. + """ # If no parameter dictionary is given, everything is normal if parameters is None: return self.net(inputs) @@ -573,7 +819,8 @@ def forward(self, inputs, parameters=None): if module.num_batches_tracked is not None: module.num_batches_tracked += 1 if module.momentum is None: # use cumulative moving average - exponential_average_factor = 1.0 / float(module.num_batches_tracked) + exponential_average_factor = 1.0 / \ + float(module.num_batches_tracked) else: # use exponential moving average exponential_average_factor = module.momentum @@ -595,7 +842,8 @@ def forward(self, inputs, parameters=None): lin_weights = next(param_gen) lin_bias = next(param_gen) method_pile.append(module.forward) - module.forward = partial(F.linear, weight=lin_weights, bias=lin_bias) + module.forward = partial( + F.linear, weight=lin_weights, bias=lin_bias) elif next(module.parameters(), None) is None: # Pass over modules that do not contain parameters @@ -605,7 +853,8 @@ def forward(self, inputs, parameters=None): pass else: # Warn for other containers - warnings.warn(f"Patching for module {module.__class__} is not implemented.") + warnings.warn( + f"Patching for module {module.__class__} is not implemented.") output = self.net(inputs) @@ -622,13 +871,15 @@ def forward(self, inputs, parameters=None): class MedianPool2d(nn.Module): - """Median pool (usable as median filter when stride=1) module. - Args: - kernel_size: size of pooling kernel, int or 2-tuple - stride: pool stride, int or 2-tuple - padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad - same: override padding and enforce same padding, boolean """ + Initialize the MedianPool2d module. + + Args: + kernel_size: Size of the pooling kernel, can be an integer or a 2-tuple. + stride: Pooling stride, can be an integer or a 2-tuple. + padding: Pooling padding, can be an integer or a 4-tuple (left, right, top, bottom). + same: If True, override padding and enforce "same" padding. If False, use the specified padding. + """ def __init__(self, kernel_size=3, stride=1, padding=0, same=True): """Initialize with kernel_size, stride, padding.""" @@ -639,6 +890,15 @@ def __init__(self, kernel_size=3, stride=1, padding=0, same=True): self.same = same def _padding(self, x): + """ + Calculate the padding required based on the 'same' attribute and input size. + + Args: + x: Input tensor. + + Returns: + Tuple (pl, pr, pt, pb): Padding values for left, right, top, and bottom. + """ if self.same: ih, iw = x.size()[2:] if ih % self.stride[0] == 0: @@ -659,44 +919,88 @@ def _padding(self, x): return padding def forward(self, x): + """ + Perform median pooling on the input tensor. + + Args: + x: Input tensor. + + Returns: + torch.Tensor: Output tensor after median pooling. + """ # using existing pytorch functions and tensor ops so that we get autograd, # would likely be more efficient to implement from scratch at C/Cuda level x = F.pad(x, self._padding(x), mode="reflect") - x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1]) + x = x.unfold(2, self.k[0], self.stride[0]).unfold( + 3, self.k[1], self.stride[1]) x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0] return x class InceptionScore(torch.nn.Module): - """Class that manages and returns the inception score of images.""" + """Class that manages and returns the inception score of images. + + Args: + batch_size (int): Batch size for calculating the Inception Score. 
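MetaMonkey's live patching boils down to: temporarily replace a module's forward with a functional call that takes externally supplied parameters, run the network, then restore the original forward. A minimal sketch for a single Linear layer (the real class walks every supported module type):

# Sketch: run a network with injected parameters via a patched forward.
import torch
import torch.nn.functional as F
from functools import partial

net = torch.nn.Sequential(torch.nn.Linear(4, 2))
x = torch.randn(1, 4)

external_w = torch.zeros(2, 4)   # parameters we want to "inject"
external_b = torch.ones(2)

lin = net[0]
original_forward = lin.forward
lin.forward = partial(F.linear, weight=external_w, bias=external_b)
out = net(x)                     # runs with the injected parameters
lin.forward = original_forward   # restore the module

print(out)                       # all ones: 0 * x + 1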
+ setup (dict): A dictionary containing device and dtype setup for the model. + + Attributes: + preprocessing (torch.nn.Module): Preprocessing module to resize images to (299, 299). + model (torch.nn.Module): Inception V3 model used for scoring. + batch_size (int): Batch size for scoring. + + Note: + The input image batch should have dimensions BCHW and should be normalized. + B should be divisible by self.batch_size. + """ def __init__(self, batch_size=32, setup=dict(device=torch.device("cpu"), dtype=torch.float)): """Initialize with setup and target inception batch size.""" super().__init__() - self.preprocessing = torch.nn.Upsample(size=(299, 299), mode="bilinear", align_corners=False) - self.model = torchvision.models.inception_v3(pretrained=True).to(**setup) + self.preprocessing = torch.nn.Upsample( + size=(299, 299), mode="bilinear", align_corners=False) + self.model = torchvision.models.inception_v3( + pretrained=True).to(**setup) self.model.eval() self.batch_size = batch_size def forward(self, image_batch): - """Image batch should have dimensions BCHW and should be normalized. - B should be divisible by self.batch_size. + """Calculate the Inception Score for an image batch. + + Args: + image_batch (torch.Tensor): Input image batch with dimensions BCHW. + + Returns: + torch.Tensor: Inception Score. """ B, C, H, W = image_batch.shape batches = B // self.batch_size scores = [] for batch in range(batches): - input = self.preprocessing(image_batch[batch * self.batch_size : (batch + 1) * self.batch_size]) + input = self.preprocessing( + image_batch[batch * self.batch_size: (batch + 1) * self.batch_size]) scores.append(self.model(input)) # pylint: disable=E1102 prob_yx = torch.nn.functional.softmax(torch.cat(scores, 0), dim=1) - entropy = torch.where(prob_yx > 0, -prob_yx * prob_yx.log(), torch.zeros_like(prob_yx)) + entropy = torch.where(prob_yx > 0, -prob_yx * + prob_yx.log(), torch.zeros_like(prob_yx)) return entropy.sum() def psnr(img_batch, ref_batch, batched=False, factor=1.0): - """Standard PSNR.""" + """Calculate the Peak Signal-to-Noise Ratio (PSNR) between two image batches. + + Args: + img_batch (torch.Tensor): Input image batch. + ref_batch (torch.Tensor): Reference image batch. + batched (bool): If True, compute PSNR for the entire batch. If False, compute individual PSNRs. + factor (float): Scaling factor for PSNR computation. + + Returns: + float or torch.Tensor: PSNR value(s). + """ def get_psnr(img_in, img_ref): + mse = ((img_in - img_ref) ** 2).mean() if mse > 0 and torch.isfinite(mse): return 10 * torch.log10(factor ** 2 / mse) @@ -711,14 +1015,22 @@ def get_psnr(img_in, img_ref): [B, C, m, n] = img_batch.shape psnrs = [] for sample in range(B): - psnrs.append(get_psnr(img_batch.detach()[sample, :, :, :], ref_batch[sample, :, :, :])) + psnrs.append(get_psnr(img_batch.detach()[ + sample, :, :, :], ref_batch[sample, :, :, :])) psnr = torch.stack(psnrs, dim=0).mean() return psnr.item() def total_variation(x): - """Anisotropic TV.""" + """"Calculate the Anisotropic Total Variation (TV) of an image. + + Args: + x (torch.Tensor): Input image. + + Returns: + torch.Tensor: Total Variation value. 
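MedianPool2d gets a differentiable median filter out of stock tensor ops: unfold each k x k neighborhood into a trailing dimension and take its median. The same trick in isolation (stride 1, reflect padding, k = 3):

# Sketch: median filtering via unfold, as in MedianPool2d.forward.
import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 8, 8)
x_pad = F.pad(x, (1, 1, 1, 1), mode="reflect")
patches = x_pad.unfold(2, 3, 1).unfold(3, 3, 1)   # (1, 1, 8, 8, 3, 3)
patches = patches.contiguous().view(1, 1, 8, 8, -1)
filtered = patches.median(dim=-1)[0]              # (1, 1, 8, 8)
print(filtered.shape)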
+ """ dx = torch.mean(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])) dy = torch.mean(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])) return dx + dy diff --git a/python/fedml/core/security/attack/label_flipping_attack.py b/python/fedml/core/security/attack/label_flipping_attack.py index 690f0f0748..d604da29ca 100644 --- a/python/fedml/core/security/attack/label_flipping_attack.py +++ b/python/fedml/core/security/attack/label_flipping_attack.py @@ -19,6 +19,12 @@ class LabelFlippingAttack(BaseAttackMethod): def __init__(self, args): + """ + Initialize the Label Flipping Attack. + + Args: + args: An object containing attack configuration parameters. + """ self.original_class_list = args.original_class_list self.target_class_list = args.target_class_list self.batch_size = args.batch_size @@ -43,9 +49,21 @@ def __init__(self, args): self.counter = 0 def get_ite_num(self): + """ + Get the current iteration number. + + Returns: + int: The current iteration number. + """ return math.floor(self.counter / self.client_num_per_round) # ite num starts from 0 def is_to_poison_data(self): + """ + Check if data poisoning should be performed for the current iteration. + + Returns: + bool: True if data poisoning should be performed, False otherwise. + """ self.counter += 1 if self.get_ite_num() < self.poison_start_round_id or self.get_ite_num() > self.poison_end_round_id: return False @@ -55,11 +73,26 @@ def is_to_poison_data(self): return rand < self.ratio_of_poisoned_client def print_dataset(self, dataset): + """ + Print information about the given dataset. + + Args: + dataset: The dataset to print information about. + """ print("---------------print dataset------------") for batch_idx, (data, target) in enumerate(dataset): print(f"{batch_idx} ----- {target}") def poison_data(self, local_dataset): + """ + Poison the local dataset by flipping labels. + + Args: + local_dataset: The local dataset to poison. + + Returns: + DataLoader: The poisoned data loader. + """ get_client_data_stat(local_dataset) # print("=======================1 end ") # self.print_dataset(local_dataset) @@ -83,7 +116,7 @@ def poison_data(self, local_dataset): total_counter += item[1] # print(f"total counter = {total_counter}") - ####################### below are correct ###############################3 + # below are correct ###############################3 tmp_y = replace_original_class_with_target_class( data_labels=tmp_local_dataset_y, @@ -94,4 +127,4 @@ def poison_data(self, local_dataset): poisoned_data = DataLoader(dataset, batch_size=self.batch_size) get_client_data_stat(poisoned_data) - return poisoned_data \ No newline at end of file + return poisoned_data diff --git a/python/fedml/core/security/attack/lazy_worker.py b/python/fedml/core/security/attack/lazy_worker.py index 6aa32a4fce..9a80b5caa4 100644 --- a/python/fedml/core/security/attack/lazy_worker.py +++ b/python/fedml/core/security/attack/lazy_worker.py @@ -12,6 +12,12 @@ class LazyWorkerAttack(BaseAttackMethod): def __init__(self, config): + """ + Initialize the Lazy Worker Attack. + + Args: + config: An object containing attack configuration parameters. + """ self.lazy_worker_num = config.lazy_worker_num self.attack_mode = ( config.attack_mode @@ -39,6 +45,16 @@ def attack_model( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Perform the Lazy Worker Attack on the global model. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): List of client gradients. 
+ extra_auxiliary_info (Any): Additional auxiliary information. + + Returns: + List[Tuple[float, OrderedDict]]: Updated list of client gradients with the attack. + """ if self.round == 1: self.client_cache = [grad for (_, grad) in raw_client_grad_list] return raw_client_grad_list @@ -74,10 +90,23 @@ def attack_model( else: # client previous_model_params = self.client_cache[i] previous_model_params = mask_func(previous_model_params) - new_model_list.append((local_sample_number, previous_model_params)) + new_model_list.append( + (local_sample_number, previous_model_params)) return new_model_list def _add_a_mask_on_clients(self, model_list, lazy_worker_idxs, mask_func: Callable): + """ + Perform the Lazy Worker Attack on the client models. + + Args: + model_list (List[Tuple[float, OrderedDict]]): List of client models. + lazy_worker_idxs (List[int]): List of lazy worker indices. + mask_func (Callable): Masking function. + + Returns: + List[Tuple[float, OrderedDict]]: Updated list of client models with the attack. + + """ new_model_list = [] for i in range(0, len(model_list)): if i not in lazy_worker_idxs: @@ -88,10 +117,23 @@ def _add_a_mask_on_clients(self, model_list, lazy_worker_idxs, mask_func: Callab local_sample_number, _ = model_list[i] previous_model_params = self.client_cache[i] previous_model_params = mask_func(previous_model_params) - new_model_list.append((local_sample_number, previous_model_params)) + new_model_list.append( + (local_sample_number, previous_model_params)) return new_model_list def _add_a_mask_on_global(self, model_list, lazy_worker_idxs, mask_func: Callable): + """ + Perform the Lazy Worker Attack on the global model. + + Args: + model_list (List[Tuple[float, OrderedDict]]): List of client models. + lazy_worker_idxs (List[int]): List of lazy worker indices. + mask_func (Callable): Masking function. + + Returns: + List[Tuple[float, OrderedDict]]: Updated list of client models with the attack. + + """ new_model_list = [] for i in range(0, len(model_list)): if i not in lazy_worker_idxs: @@ -100,10 +142,17 @@ def _add_a_mask_on_global(self, model_list, lazy_worker_idxs, mask_func: Callabl local_sample_number, _ = model_list[i] previous_model_params = self.client_cache[i] previous_model_params = mask_func(previous_model_params) - new_model_list.append((local_sample_number, previous_model_params)) + new_model_list.append( + (local_sample_number, previous_model_params)) return new_model_list def random_mask(self, previous_model_params): + """ + Add a random mask in [-1, 1]. + + Args: + previous_model_params (OrderedDict): Previous model parameters. + """ # add a random mask in [-1, 1] for k in previous_model_params.keys(): if is_weight_param(k): @@ -120,6 +169,12 @@ def random_mask(self, previous_model_params): return previous_model_params def gaussian_mask(self, previous_model_params): + """ + Add a gaussian mask. + + Args: + previous_model_params (OrderedDict): Previous model parameters. + """ # add a gaussian mask for k in previous_model_params.keys(): if is_weight_param(k): @@ -131,6 +186,12 @@ def gaussian_mask(self, previous_model_params): return previous_model_params def uniform_mask(self, previous_model): + """ + Randomly generate a uniform mask. + + Args: + previous_model (OrderedDict): Previous model parameters. 
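The masks differ only in the noise they add to last round's cached parameters before replaying them. A condensed sketch of the random and Gaussian variants (the repository applies them only to weight-like entries via is_weight_param; this sketch perturbs everything for brevity):

# Sketch: lazy-worker masks over a cached parameter dict.
from collections import OrderedDict
import torch

cached = OrderedDict(weight=torch.ones(2, 2), bias=torch.zeros(2))

def random_mask(params):
    # noise uniform in [-1, 1]
    return OrderedDict((k, v + (2 * torch.rand_like(v) - 1))
                       for k, v in params.items())

def gaussian_mask(params):
    return OrderedDict((k, v + torch.randn_like(v))
                       for k, v in params.items())

print(random_mask(cached)["weight"])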
+ """ # randomly generate a uniform mask unif_param = random.uniform(-1, 1) print(f"unif_mode_param = {unif_param}") @@ -147,5 +208,11 @@ def uniform_mask(self, previous_model): return previous_model def no_mask(self, previous_model_params): + """ + Directly return the model in the last round. + + Args: + previous_model_params (OrderedDict): Previous model parameters. + """ # directly return the model in the last round return previous_model_params diff --git a/python/fedml/core/security/attack/model_replacement_backdoor_attack.py b/python/fedml/core/security/attack/model_replacement_backdoor_attack.py index be4495d188..5c84bb3401 100644 --- a/python/fedml/core/security/attack/model_replacement_backdoor_attack.py +++ b/python/fedml/core/security/attack/model_replacement_backdoor_attack.py @@ -24,6 +24,12 @@ class ModelReplacementBackdoorAttack(BaseAttackMethod): def __init__(self, args): + """ + Initialize the Model Replacement Backdoor Attack. + + Args: + args: An object containing attack parameters. + """ if hasattr(args, "malicious_client_id") and isinstance(args.malicious_client_id, int): # assume only 1 malicious client self.malicious_client_id = args.malicious_client_id @@ -46,6 +52,16 @@ def attack_model( raw_client_grad_list: List[Tuple[float, OrderedDict]], extra_auxiliary_info: Any = None, ): + """ + Attack the global model by replacing the model of a selected malicious client. + + Args: + raw_client_grad_list (List[Tuple[float, OrderedDict]]): List of client gradients. + extra_auxiliary_info (Any): Additional auxiliary information. + + Returns: + List[Tuple[float, OrderedDict]]: Updated list of client gradients with the model replacement attack. + """ participant_num = len(raw_client_grad_list) if self.attack_training_rounds is not None and self.training_round not in self.attack_training_rounds: return raw_client_grad_list @@ -71,6 +87,16 @@ def attack_model( return raw_client_grad_list def compute_gamma(self, global_model, original_client_model): + """ + Compute the scaling factor gamma for model replacement. + + Args: + global_model (OrderedDict): Global model parameters. + original_client_model (OrderedDict): Model parameters of the malicious client. + + Returns: + float: Scaling factor gamma. + """ # total_client_num / η, η: global learning rate; # when η = total_client_num/participant_num, the model is fully replaced by the average of the local models malicious_client_model_vec = vectorize_weight(original_client_model) diff --git a/python/fedml/core/security/attack/revealing_labels_from_gradients_attack.py b/python/fedml/core/security/attack/revealing_labels_from_gradients_attack.py index 4594fc99e1..95caa3ec44 100644 --- a/python/fedml/core/security/attack/revealing_labels_from_gradients_attack.py +++ b/python/fedml/core/security/attack/revealing_labels_from_gradients_attack.py @@ -22,10 +22,27 @@ class RevealingLabelsFromGradientsAttack(BaseAttackMethod): def __init__(self, batch_size, model_type): + """ + Initialize the Revealing Labels from Gradients Attack. + + Args: + batch_size (int): Batch size for the attack. + model_type (str): The type of the target model (e.g., "ResNet50"). + """ self.batch_size = batch_size self.model_type = model_type def reconstruct_data(self, a_gradient: dict, extra_auxiliary_info: Any = None): + """ + Reconstruct data labels using gradients information. + + Args: + a_gradient (dict): A dictionary containing gradients information. + extra_auxiliary_info (Any): Additional auxiliary information (e.g., ground truth labels). 
+ + Returns: + None + """ vec_local_weight = utils.vectorize_weight(a_gradient) print(vec_local_weight) @@ -37,12 +54,33 @@ def reconstruct_data(self, a_gradient: dict, extra_auxiliary_info: Any = None): return def _attack_on_gradients(self, gt_labels, v): + """ + Attack on gradients to infer labels. + + Args: + gt_labels (set): Ground truth labels. + v: Gradients information. + + Returns: + None + """ grads = np.sign(v) _, pred_labels = self._infer_labels(grads, gt_k=self.batch_size, epsilon=1e-10) print("In gt, not in pr:", [i for i in gt_labels if i not in pred_labels]) print("In pr, not in gt:", [i for i in pred_labels if i not in gt_labels]) def _infer_labels(self, grads, gt_k=None, epsilon=1e-8): + """ + Infer labels from gradients. + + Args: + grads: Gradients information. + gt_k: Number of ground truth labels to consider. + epsilon: A small value to avoid numerical instability. + + Returns: + Tuple[int, list]: Tuple containing the number of predicted labels and the list of inferred labels. + """ m, n = np.shape(grads) B, s, C = np.linalg.svd(grads, full_matrices=False) pred_k = np.linalg.matrix_rank(grads) @@ -91,6 +129,20 @@ def _infer_labels(self, grads, gt_k=None, epsilon=1e-8): @staticmethod def _solve_perceptron(X, y, fit_intercept=True, max_iter=1000, tol=1e-3, eta0=1.0): + """ + Solve the perceptron problem. + + Args: + X: Input data. + y: Target labels. + fit_intercept: Whether to fit an intercept. + max_iter: Maximum number of iterations. + tol: Tolerance for stopping criterion. + eta0: Learning rate. + + Returns: + bool: True if the perceptron problem is successfully solved, False otherwise. + """ from sklearn.linear_model import Perceptron clf = Perceptron( @@ -105,6 +157,17 @@ def _solve_perceptron(X, y, fit_intercept=True, max_iter=1000, tol=1e-3, eta0=1. @staticmethod def solve_lp(grads, b, c): + """ + Solve a linear programming problem. + + Args: + grads: Gradients information. + b: Target vector. + c: Coefficients matrix. + + Returns: + bool: True if the linear programming problem is successfully solved, False otherwise. + """ # from cvxopt import matrix, solvers np.solvers.options["show_progress"] = False diff --git a/python/fedml/core/security/common/attack_defense_data_loader.py b/python/fedml/core/security/common/attack_defense_data_loader.py index c01328b748..31e188617c 100644 --- a/python/fedml/core/security/common/attack_defense_data_loader.py +++ b/python/fedml/core/security/common/attack_defense_data_loader.py @@ -10,6 +10,19 @@ class AttackDefenseDataLoader: def load_cifar10_data( cls, client_num, batch_size, data_dir="../../../../../data/cifar10", partition_method="homo", partition_alpha=None ): + """ + Load CIFAR-10 dataset and partition it among clients. + + Args: + client_num (int): The number of clients to partition the dataset for. + batch_size (int): The batch size for DataLoader objects. + data_dir (str): The directory where the CIFAR-10 dataset is located. + partition_method (str): The method for partitioning the dataset among clients. + partition_alpha (float): The alpha parameter for partitioning (used when partition_method is "hetero"). + + Returns: + dict: A dictionary containing DataLoader objects for each client. + """ return load_partition_data_cifar10( "cifar10", data_dir=data_dir, @@ -24,13 +37,14 @@ def get_data_loader_from_data(cls, batch_size, X, Y, **kwargs): """ Get a data loader created from a given set of data. 
- :param batch_size: batch size of data loader - :type batch_size: int - :param X: data features - :type X: numpy.Array() - :param Y: data labels - :type Y: numpy.Array() - :return: torch.utils.data.DataLoader + Args: + batch_size (int): Batch size of the DataLoader. + X (numpy.ndarray): Data features. + Y (numpy.ndarray): Data labels. + **kwargs: Additional arguments for DataLoader. + + Returns: + torch.utils.data.DataLoader: DataLoader object for the provided data. """ X_torch = torch.from_numpy(X).float() @@ -48,9 +62,13 @@ def get_data_loader_from_data(cls, batch_size, X, Y, **kwargs): @classmethod def load_data_loader_from_file(cls, filename): """ - Loads DataLoader object from a file if available. + Load a DataLoader object from a file. + + Args: + filename (str): The name of the file containing the DataLoader object. - :param filename: string + Returns: + torch.utils.data.DataLoader: Loaded DataLoader object. """ print("Loading data loader from file: {}".format(filename)) diff --git a/python/fedml/core/security/common/bucket.py b/python/fedml/core/security/common/bucket.py index ac07019aeb..0f400887a3 100644 --- a/python/fedml/core/security/common/bucket.py +++ b/python/fedml/core/security/common/bucket.py @@ -5,6 +5,19 @@ class Bucket: @classmethod def bucketization(cls, client_grad_list, batch_size): + """ + Perform bucketization of client gradients. + + Args: + client_grad_list (list): A list of tuples containing client gradients, where each tuple consists of + the number of samples and a dictionary of gradient values. + batch_size (int): The desired batch size for bucketization. + + Returns: + list: A list of batched client gradients, where each batch is represented as a tuple containing + the total number of samples and a dictionary of batched gradient values. + + """ (num0, averaged_params) = client_grad_list[0] batch_grad_list = [] for batch_idx in range(0, math.ceil(len(client_grad_list) / batch_size)): diff --git a/python/fedml/core/security/common/net.py b/python/fedml/core/security/common/net.py index 4023ede220..5c8f7e554b 100644 --- a/python/fedml/core/security/common/net.py +++ b/python/fedml/core/security/common/net.py @@ -1,6 +1,18 @@ import torch.nn as nn class LeNet(nn.Module): + """ + LeNet-5 is a convolutional neural network architecture that was designed for handwritten and machine-printed character + recognition tasks. This implementation includes four convolutional layers and one fully connected layer. + + Args: + None + + Attributes: + body (nn.Sequential): The convolutional layers of the LeNet model. + fc (nn.Sequential): The fully connected layer of the LeNet model. + + """ def __init__(self): super(LeNet, self).__init__() act = nn.Sigmoid @@ -17,6 +29,16 @@ def __init__(self): self.fc = nn.Sequential(nn.Linear(768, 10)) def forward(self, x): + """ + Forward pass of the LeNet model. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width). + + Returns: + torch.Tensor: Output tensor of shape (batch_size, num_classes). + + """ out = self.body(x) out = out.view(out.size(0), -1) out = self.fc(out) diff --git a/python/fedml/core/security/common/utils.py b/python/fedml/core/security/common/utils.py index 9e12a17fdc..f753be2c35 100644 --- a/python/fedml/core/security/common/utils.py +++ b/python/fedml/core/security/common/utils.py @@ -6,6 +6,15 @@ def vectorize_weight(state_dict): + """ + Vectorizes the weight tensors in the given state_dict. + + Args: + state_dict (OrderedDict): The state_dict containing model weights. 
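What get_data_loader_from_data does, in miniature: wrap NumPy features and labels in tensors and hand them to a DataLoader. A sketch using TensorDataset (an assumption; the helper's own dataset wrapper may differ):

# Sketch: a DataLoader built from numpy arrays.
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

X = np.random.rand(100, 3, 32, 32).astype(np.float32)
Y = np.random.randint(0, 10, size=100)

dataset = TensorDataset(torch.from_numpy(X).float(),
                        torch.from_numpy(Y).long())
loader = DataLoader(dataset, batch_size=16, shuffle=True)
for data, target in loader:
    print(data.shape, target.shape)   # torch.Size([16, 3, 32, 32]) ...
    break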
+ + Returns: + torch.Tensor: A concatenated tensor of flattened weights. + """ weight_list = [] for (k, v) in state_dict.items(): if is_weight_param(k): @@ -14,27 +23,62 @@ def vectorize_weight(state_dict): def is_weight_param(k): + """ + Checks if a parameter key is a weight parameter. + + Args: + k (str): The parameter key. + + Returns: + bool: True if the key corresponds to a weight parameter, False otherwise. + """ return ( - "running_mean" not in k - and "running_var" not in k - and "num_batches_tracked" not in k + "running_mean" not in k + and "running_var" not in k + and "num_batches_tracked" not in k ) def compute_euclidean_distance(v1, v2, device='cpu'): + """ + Computes the Euclidean distance between two tensors. + + Args: + v1 (torch.Tensor): The first tensor. + v2 (torch.Tensor): The second tensor. + device (str): The device for computation (default is 'cpu'). + + Returns: + torch.Tensor: The Euclidean distance between the two tensors. + """ v1 = v1.to(device) v2 = v2.to(device) return (v1 - v2).norm() def compute_model_norm(model): + """ + Computes the norm of a model's weights. + + Args: + model: The model. + + Returns: + torch.Tensor: The norm of the model's weights. + """ return vectorize_weight(model).norm() def compute_middle_point(alphas, model_list): """ - alphas: weights of model_dict - model_dict: a model submitted by a user + Computes the weighted sum of model weights. + + Args: + alphas (list): List of weights. + model_list (list): List of model weights. + + Returns: + numpy.ndarray: The weighted sum of model weights. """ sum_batch = torch.zeros(model_list[0].shape) for a, a_batch_w in zip(alphas, model_list): @@ -88,6 +132,15 @@ def compute_geometric_median(weights, client_grads): def get_total_sample_num(model_list): + """ + Calculates the total number of samples across multiple clients. + + Args: + model_list (list): List of tuples containing local sample numbers and model parameters. + + Returns: + int: Total number of samples. + """ sample_num = 0 for i in range(len(model_list)): local_sample_num, local_model_params = model_list[i] @@ -96,6 +149,17 @@ def get_total_sample_num(model_list): def get_malicious_client_id_list(random_seed, client_num, malicious_client_num): + """ + Generates a list of malicious client IDs. + + Args: + random_seed (int): Random seed for reproducibility. + client_num (int): Total number of clients. + malicious_client_num (int): Number of malicious clients to generate. + + Returns: + list: List of malicious client IDs. + """ if client_num == malicious_client_num: client_indexes = [client_index for client_index in range(client_num)] else: @@ -103,7 +167,8 @@ def get_malicious_client_id_list(random_seed, client_num, malicious_client_num): np.random.seed( random_seed ) # make sure for each comparison, we are selecting the same clients each round - client_indexes = np.random.choice(range(client_num), num_clients, replace=False) + client_indexes = np.random.choice( + range(client_num), num_clients, replace=False) print("malicious client_indexes = %s" % str(client_indexes)) return client_indexes @@ -112,9 +177,15 @@ def replace_original_class_with_target_class( data_labels, original_class_list=None, target_class_list=None ): """ - :param targets: Target class IDs - :type targets: list - :return: new class IDs + Replaces original class labels in data_labels with corresponding target class labels. + + Args: + data_labels (list): List of class labels. + original_class_list (list): List of original class labels to be replaced. 
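vectorize_weight and compute_euclidean_distance combine into the basic distance primitive these defenses rely on: flatten every weight-like entry of a state_dict (skipping batch-norm statistics) and compare norms. A toy illustration:

# Sketch: flatten weight-like entries and measure distances.
import torch
import torch.nn as nn

def vectorize(state_dict):
    return torch.cat([v.flatten() for k, v in state_dict.items()
                      if "running_mean" not in k
                      and "running_var" not in k
                      and "num_batches_tracked" not in k])

m1, m2 = nn.Linear(4, 2), nn.Linear(4, 2)
v1, v2 = vectorize(m1.state_dict()), vectorize(m2.state_dict())
print((v1 - v2).norm().item())   # Euclidean distance between models
print(v1.norm().item())          # model norm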
+ target_class_list (list): List of target class labels to replace with. + + Returns: + list: Updated list of class labels. """ if ( len(original_class_list) == 0 @@ -141,12 +212,11 @@ def replace_original_class_with_target_class( def log_client_data_statistics(poisoned_client_ids, train_data_local_dict): """ - Logs all client data statistics. + Logs data distribution statistics for each client in the dataset. - :param poisoned_client_ids: list of malicious clients - :type poisoned_client_ids: list - :param train_data_local_dict: distributed dataset - :type train_data_local_dict: list(tuple) + Args: + poisoned_client_ids (list): List of malicious client IDs. + train_data_local_dict (list): Distributed dataset. """ for client_idx in range(len(train_data_local_dict)): if client_idx in poisoned_client_ids: @@ -163,6 +233,13 @@ def log_client_data_statistics(poisoned_client_ids, train_data_local_dict): def get_client_data_stat(local_dataset): + """ + Prints data distribution statistics for a local dataset. + + Args: + local_dataset (Iterable): Local dataset. + + """ print("-==========================") targets_set = {} for batch_idx, (data, targets) in enumerate(local_dataset): @@ -200,17 +277,51 @@ def get_client_data_stat(local_dataset): def cross_entropy_for_onehot(pred, target): + """ + Computes the cross-entropy loss between predicted and target one-hot encoded vectors. + + Args: + pred (torch.Tensor): Predicted logit values. + target (torch.Tensor): Target one-hot encoded vectors. + + Returns: + torch.Tensor: Cross-entropy loss. + + """ return torch.mean(torch.sum(-target * F.log_softmax(pred, dim=-1), 1)) def label_to_onehot(target, num_classes=100): + """ + Converts class labels to one-hot encoded vectors. + + Args: + target (torch.Tensor): Class labels. + num_classes (int, optional): Number of classes. Defaults to 100. + + Returns: + torch.Tensor: One-hot encoded vectors. + + """ target = torch.unsqueeze(target, 1) - onehot_target = torch.zeros(target.size(0), num_classes, device=target.device) + onehot_target = torch.zeros(target.size( + 0), num_classes, device=target.device) onehot_target.scatter_(1, target, 1) return onehot_target def trimmed_mean(model_list, trimmed_num): + """ + Trims the list of models by removing a specified number of models from both ends. + + Args: + model_list (list): List of model tuples containing local sample numbers and gradients. + trimmed_num (int): Number of models to trim from each end. + + Returns: + list: Trimmed list of models. + + """ temp_model_list = [] for i in range(0, len(model_list)): local_sample_num, client_grad = model_list[i] @@ -221,18 +332,42 @@ def trimmed_mean(model_list, trimmed_num): compute_a_score(local_sample_num), ) ) - temp_model_list.sort(key=lambda grad: grad[2]) # sort by coordinate-wise scores - temp_model_list = temp_model_list[trimmed_num: len(model_list) - trimmed_num] + # sort by coordinate-wise scores + temp_model_list.sort(key=lambda grad: grad[2]) + temp_model_list = temp_model_list[trimmed_num: len( + model_list) - trimmed_num] model_list = [(t[0], t[1]) for t in temp_model_list] return model_list def compute_a_score(local_sample_number): + """ + Compute a score for a client based on its local sample number. + + Args: + local_sample_number (int): Number of local samples for a client. + + Returns: + int: A score for the client. 
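+
+    Note:
+        ``trimmed_mean`` sorts client updates by this score before trimming, so
+        with the current placeholder implementation clients are effectively
+        ranked by their local sample count.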
+ + """ # todo: change to coordinate-wise score return local_sample_number def compute_krum_score(vec_grad_list, client_num_after_trim, p=2): + """ + Compute Krum scores for clients based on their gradients. + + Args: + vec_grad_list (list): List of gradient vectors for each client. + client_num_after_trim (int): Number of clients to consider. + p (int, optional): Power parameter for distance calculation. Defaults to 2. + + Returns: + list: List of Krum scores for each client. + + """ krum_scores = [] num_client = len(vec_grad_list) for i in range(0, num_client): @@ -252,6 +387,16 @@ def compute_krum_score(vec_grad_list, client_num_after_trim, p=2): def compute_gaussian_distribution(score_list): + """ + Compute the mean (mu) and standard deviation (sigma) of a list of scores. + + Args: + score_list (list): List of scores. + + Returns: + Tuple[float, float]: Mean (mu) and standard deviation (sigma). + + """ n = len(score_list) mu = sum(list(score_list)) / n temp = 0 @@ -263,4 +408,15 @@ def compute_gaussian_distribution(score_list): def sample_some_clients(client_num, sampled_client_num): - return random.sample(range(client_num), sampled_client_num) \ No newline at end of file + """ + Sample a specified number of clients from the total number of clients. + + Args: + client_num (int): Total number of clients. + sampled_client_num (int): Number of clients to sample. + + Returns: + list: List of sampled client indices. + + """ + return random.sample(range(client_num), sampled_client_num) diff --git a/python/fedml/cross_device/server_mnn/fedml_aggregator.py b/python/fedml/cross_device/server_mnn/fedml_aggregator.py index cf6c2c23c1..aafa5539b0 100644 --- a/python/fedml/cross_device/server_mnn/fedml_aggregator.py +++ b/python/fedml/cross_device/server_mnn/fedml_aggregator.py @@ -18,6 +18,19 @@ class FedMLAggregator(object): def __init__( self, test_dataloader, worker_num, device, args, aggregator, ): + """ + Initialize the FedMLAggregator. + + Args: + test_dataloader: DataLoader for the test dataset. + worker_num: Number of worker nodes (clients). + device: The device (e.g., CPU or GPU) to use for computations. + args: Arguments for configuration. + aggregator: The aggregator used for federated learning aggregation. + + Returns: + None + """ self.aggregator = aggregator self.args = args @@ -32,23 +45,67 @@ def __init__( self.flag_client_model_uploaded_dict[idx] = False def get_global_model_params(self): + """ + Test the global model on the server using the MNN (Mobile Neural Network) file format. + + Args: + mnn_file_path: The path to the MNN file containing the global model. + round_idx: The current round index. + report_metrics: A boolean indicating whether to report metrics (default is True). + + Returns: + None + """ return self.aggregator.get_model_params() # TODO: refactor MNN-related file processing def get_global_model_params_file(self): + """ + Get the file path of the global model parameters. + + Returns: + str: File path of the global model parameters. + """ return self.aggregator.get_model_params_file() def set_global_model_params(self, model_parameters): - logging.info("FedDebug. model_parameters = {}".format(model_parameters)) + """ + Set the global model parameters. + + Args: + model_parameters: Parameters of the global model. + + Returns: + None + """ + logging.info( + "FedDebug. 
model_parameters = {}".format(model_parameters)) self.aggregator.set_model_params(model_parameters) def add_local_trained_result(self, index, model_params, sample_num): + """ + Add the results of local model training for aggregation. + + Args: + index (int): Index of the local client. + model_params: Parameters of the locally trained model. + sample_num (int): Number of samples used for training. + + Returns: + None + """ logging.info("add_model. index = %d" % index) self.model_dict[index] = model_params self.sample_num_dict[index] = sample_num self.flag_client_model_uploaded_dict[index] = True def check_whether_all_receive(self): + """ + Check if all clients have uploaded their local models. + + Returns: + bool: True if all clients have uploaded their models, False otherwise. + """ logging.info("worker_num = {}".format(self.worker_num)) for idx in range(self.worker_num): if not self.flag_client_model_uploaded_dict[idx]: @@ -58,27 +115,48 @@ def check_whether_all_receive(self): return True def _test_individual_model_perf_before_agg(self, model_file_path, round_idx): - self.test_on_server_for_all_clients_mnn(model_file_path, round_idx, report_metrics=False) + """ + Test the performance of an individual model before aggregation. + + Args: + model_file_path (str): File path of the individual model. + round_idx (int): Index of the current federated learning round. + + Returns: + None + """ + self.test_on_server_for_all_clients_mnn( + model_file_path, round_idx, report_metrics=False) def aggregate(self): + """ + Aggregate local model updates to obtain the global model. + + Returns: + averaged_params: Averaged global model parameters. + """ logging.info("FedMLDebug. Individual model performance:") for idx in range(self.worker_num): - logging.info("self.model_dict[idx] = {}".format(self.model_dict[idx])) + logging.info("self.model_dict[idx] = {}".format( + self.model_dict[idx])) mnn_file_path = self.model_dict[idx] - self._test_individual_model_perf_before_agg(mnn_file_path, self.args.round_idx) + self._test_individual_model_perf_before_agg( + mnn_file_path, self.args.round_idx) start_time = time.time() model_list = [] training_num = 0 for idx in range(self.worker_num): - logging.info("self.model_dict[idx] = {}".format(self.model_dict[idx])) + logging.info("self.model_dict[idx] = {}".format( + self.model_dict[idx])) mnn_file_path = self.model_dict[idx] tensor_params_dict = read_mnn_as_tensor_dict(mnn_file_path) model_list.append((self.sample_num_dict[idx], tensor_params_dict)) training_num += self.sample_num_dict[idx] logging.info("training_num = {}".format(training_num)) - logging.info("len of self.model_dict[idx] = " + str(len(self.model_dict))) + logging.info( + "len of self.model_dict[idx] = " + str(len(self.model_dict))) # logging.info("################aggregate: %d" % len(model_list)) averaged_params = self.aggregator.aggregate(model_list) @@ -102,14 +180,17 @@ def data_silo_selection(self, round_idx, data_silo_num_in_total, client_num_in_t """ logging.info( - "data_silo_num_in_total = %d, client_num_in_total = %d" % (data_silo_num_in_total, client_num_in_total) + "data_silo_num_in_total = %d, client_num_in_total = %d" % ( + data_silo_num_in_total, client_num_in_total) ) assert data_silo_num_in_total >= client_num_in_total if client_num_in_total == data_silo_num_in_total: return [i for i in range(data_silo_num_in_total)] else: - np.random.seed(round_idx) # make sure for each comparison, we are selecting the same clients each round - data_silo_index_list = 
np.random.choice(range(data_silo_num_in_total), client_num_in_total, replace=False)
+            # make sure for each comparison, we are selecting the same clients each round
+            np.random.seed(round_idx)
+            data_silo_index_list = np.random.choice(
+                range(data_silo_num_in_total), client_num_in_total, replace=False)
             return data_silo_index_list
 
     def client_selection(self, round_idx, client_id_list_in_total, client_num_per_round):
@@ -126,21 +207,49 @@ def client_selection(self, round_idx, client_id_list_in_total, client_num_per_ro
         """
         if client_num_per_round == len(client_id_list_in_total) or len(client_id_list_in_total) == 1:  # for debugging
             return client_id_list_in_total
-        np.random.seed(round_idx)  # make sure for each comparison, we are selecting the same clients each round
-        client_id_list_in_this_round = np.random.choice(client_id_list_in_total, client_num_per_round, replace=False)
+        # make sure for each comparison, we are selecting the same clients each round
+        np.random.seed(round_idx)
+        client_id_list_in_this_round = np.random.choice(
+            client_id_list_in_total, client_num_per_round, replace=False)
         return client_id_list_in_this_round
 
     def client_sampling(self, round_idx, client_num_in_total, client_num_per_round):
+        """
+        Sample a subset of clients for the current round.
+
+        Args:
+            round_idx (int): Round index, starting from 0.
+            client_num_in_total (int): Total number of clients.
+            client_num_per_round (int): Number of clients to sample per round.
+
+        Returns:
+            client_indexes: List of selected client indexes.
+        """
+
        if client_num_in_total == client_num_per_round:
-            client_indexes = [client_index for client_index in range(client_num_in_total)]
+            client_indexes = [
+                client_index for client_index in range(client_num_in_total)]
         else:
             num_clients = min(client_num_per_round, client_num_in_total)
-            np.random.seed(round_idx)  # make sure for each comparison, we are selecting the same clients each round
-            client_indexes = np.random.choice(range(client_num_in_total), num_clients, replace=False)
+            # make sure for each comparison, we are selecting the same clients each round
+            np.random.seed(round_idx)
+            client_indexes = np.random.choice(
+                range(client_num_in_total), num_clients, replace=False)
         logging.info("client_indexes = %s" % str(client_indexes))
         return client_indexes
 
     def _test(self, test_data, device, args):
+        """
+        Evaluate the model on the given test data.
+
+        Args:
+            test_data: Test data.
+            device: Device on which to perform testing.
+            args: Additional arguments.
+
+        Returns:
+            metrics: Dictionary containing test metrics.
+        """
         model = self.model
 
         model.to(device)
@@ -166,6 +275,17 @@ def _test(self, test_data, device, args):
         return metrics
 
     def test(self, test_data, device, args):
+        """
+        Evaluate the model on the test dataset and collect accuracy and loss metrics.
+
+        Args:
+            test_data: Test data.
+            device: Device on which to perform testing.
+            args: Additional arguments.
+
+        Returns:
+            Tuple containing test accuracy, test loss, and additional metrics. 
+ """ # test data test_num_samples = [] test_tot_corrects = [] @@ -199,7 +319,8 @@ def test(self, test_data, device, args): def test_on_server_for_all_clients(self, round_idx, global_model_file=None): if round_idx % self.args.frequency_of_the_test == 0 or round_idx == self.args.comm_round - 1: - logging.info("################test_on_server_for_all_clients : {}".format(round_idx)) + logging.info( + "################test_on_server_for_all_clients : {}".format(round_idx)) self.aggregator.test_all( self.train_data_local_dict, self.test_data_local_dict, @@ -209,10 +330,13 @@ def test_on_server_for_all_clients(self, round_idx, global_model_file=None): if round_idx == self.args.comm_round - 1: # we allow to return four metrics, such as accuracy, AUC, loss, etc. - metric_result_in_current_round = self.aggregator.test(self.test_global, self.device, self.args) + metric_result_in_current_round = self.aggregator.test( + self.test_global, self.device, self.args) else: - metric_result_in_current_round = self.aggregator.test(self.val_global, self.device, self.args) - logging.info("metric_result_in_current_round = {}".format(metric_result_in_current_round)) + metric_result_in_current_round = self.aggregator.test( + self.val_global, self.device, self.args) + logging.info("metric_result_in_current_round = {}".format( + metric_result_in_current_round)) if round_idx == self.args.comm_round - 1: mlops.log({"round_idx": round_idx}) @@ -237,8 +361,10 @@ def test_on_server_for_all_clients_mnn(self, mnn_file_path, round_idx, report_me example = self.test_global.next() input_data = example[0] output_target = example[1] - data = input_data[0] # which input, model may have more than one inputs - label = output_target[0] # also, model may have more than one outputs + # which input, model may have more than one inputs + data = input_data[0] + # also, model may have more than one outputs + label = output_target[0] result = module.forward(data) predict = F.argmax(result, 1) @@ -250,13 +376,15 @@ def test_on_server_for_all_clients_mnn(self, mnn_file_path, round_idx, report_me target = F.one_hot(F.cast(label, F.int), 10, 1, 0) loss = nn.loss.cross_entropy(result, target) - logging.info(f"correct = {correct}, self.test_global.size = {self.test_global.size}") + logging.info( + f"correct = {correct}, self.test_global.size = {self.test_global.size}") test_accuracy = correct / self.test_global.size test_loss = loss.read() if report_metrics: logging.info("test acc = {}".format(test_accuracy)) - logging.info("test loss = {}, round loss {}".format(test_loss, round(float(np.round(test_loss, 4)), 4))) + logging.info("test loss = {}, round loss {}".format( + test_loss, round(float(np.round(test_loss, 4)), 4))) mlops.log( { @@ -268,5 +396,6 @@ def test_on_server_for_all_clients_mnn(self, mnn_file_path, round_idx, report_me if self.args.enable_wandb: wandb.log( - {"round idx": round_idx, "test acc": test_accuracy, "test loss": test_loss, } + {"round idx": round_idx, "test acc": test_accuracy, + "test loss": test_loss, } ) diff --git a/python/fedml/cross_device/server_mnn/fedml_server_manager.py b/python/fedml/cross_device/server_mnn/fedml_server_manager.py index 12b49ae68c..99ddf07579 100644 --- a/python/fedml/cross_device/server_mnn/fedml_server_manager.py +++ b/python/fedml/cross_device/server_mnn/fedml_server_manager.py @@ -12,6 +12,21 @@ class FedMLServerManager(FedMLCommManager): + """ + Federated Learning Server Manager. + + This class manages the server-side operations of federated learning. 
+ + Args: + args: Arguments for the federated learning process. + aggregator: Server aggregator for aggregating model updates. + comm: Communication backend for distributed training (default: None). + rank (int): The rank of the current worker (default: 0). + size (int): The total number of workers (default: 0). + backend (str): The communication backend (default: "MPI"). + is_preprocessed (bool): Flag indicating if data is preprocessed (default: False). + preprocessed_client_lists: List of preprocessed client data (default: None). + """ ONLINE_STATUS_FLAG = "ONLINE" RUN_FINISHED_STATUS_FLAG = "FINISHED" @@ -37,10 +52,12 @@ def __init__( self.global_model_file_path = self.args.global_model_file_path self.model_file_cache_folder = self.args.model_file_cache_folder logging.info( - "self.global_model_file_path = {}".format(self.global_model_file_path) + "self.global_model_file_path = {}".format( + self.global_model_file_path) ) logging.info( - "self.model_file_cache_folder = {}".format(self.model_file_cache_folder) + "self.model_file_cache_folder = {}".format( + self.model_file_cache_folder) ) self.client_online_mapping = {} @@ -56,6 +73,14 @@ def run(self): super().run() def start_train(self): + """ + Start the federated training process. + + This method initiates federated training by sending start training messages to all clients. + + Returns: + None + """ start_train_json = { "edges": [ { @@ -148,7 +173,8 @@ def start_train(self): "timestamp": "1651635148138", } for client_id in self.client_real_ids: - logging.info("com_manager_status - client_id = {}".format(client_id)) + logging.info( + "com_manager_status - client_id = {}".format(client_id)) self.send_message_json( "flserver_agent/" + str(client_id) + "/start_train", json.dumps(start_train_json), @@ -162,6 +188,13 @@ def send_init_msg(self): MNN (file) -> numpy -> pytorch -> aggregation -> numpy -> MNN (the same file) S2C - send the model to clients send MNN file + + Initialize and send model to clients. + + This method sends the initial model to clients to start the federated learning process. + + Returns: + """ global_model_url = None global_model_key = None @@ -174,16 +207,26 @@ def send_init_msg(self): self.data_silo_index_list[client_idx_in_this_round], global_model_url, global_model_key ) - logging.info(f"global_model_url = {global_model_url}, global_model_key = {global_model_key}") + logging.info( + f"global_model_url = {global_model_url}, global_model_key = {global_model_key}") client_idx_in_this_round += 1 - mlops.event("server.wait", event_started=True, event_value=str(self.args.round_idx)) + mlops.event("server.wait", event_started=True, + event_value=str(self.args.round_idx)) # Todo: for serving the cross-device model, # how to transform it to pytorch and upload the model network to ModelOps # mlops.log_training_model_net_info(self.aggregator.aggregator.model) def register_message_receive_handlers(self): + """ + Register message receive handlers. + + This method registers message handlers for processing incoming messages from clients. + + Returns: + None + """ print("register_message_receive_handlers------") self.register_message_receive_handler( MyMessage.MSG_TYPE_C2S_CLIENT_STATUS, @@ -199,9 +242,20 @@ def register_message_receive_handlers(self): ) def process_online_status(self, client_status, msg_params): + """ + Process online status message from clients. + + Args: + client_status (str): The status message from clients. + msg_params: Parameters of the received message. 
+ + Returns: + None + """ self.client_online_mapping[str(msg_params.get_sender_id())] = True - logging.info("self.client_online_mapping = {}".format(self.client_online_mapping)) + logging.info("self.client_online_mapping = {}".format( + self.client_online_mapping)) all_client_is_online = True for client_id in self.client_id_list_in_this_round: @@ -210,17 +264,29 @@ def process_online_status(self, client_status, msg_params): break logging.info( - "sender_id = %d, all_client_is_online = %s" % (msg_params.get_sender_id(), str(all_client_is_online)) + "sender_id = %d, all_client_is_online = %s" % ( + msg_params.get_sender_id(), str(all_client_is_online)) ) if all_client_is_online: - mlops.log_aggregation_status(MyMessage.MSG_MLOPS_SERVER_STATUS_RUNNING) + mlops.log_aggregation_status( + MyMessage.MSG_MLOPS_SERVER_STATUS_RUNNING) # send initialization message to all clients to start training self.send_init_msg() self.is_initialized = True def process_finished_status(self, client_status, msg_params): + """ + Process finished status message from clients. + + Args: + client_status (str): The status message from clients. + msg_params: Parameters of the received message. + + Returns: + None + """ self.client_finished_mapping[str(msg_params.get_sender_id())] = True all_client_is_finished = True @@ -230,7 +296,8 @@ def process_finished_status(self, client_status, msg_params): break logging.info( - "sender_id = %d, all_client_is_finished = %s" % (msg_params.get_sender_id(), str(all_client_is_finished)) + "sender_id = %d, all_client_is_finished = %s" % ( + msg_params.get_sender_id(), str(all_client_is_finished)) ) if all_client_is_finished: @@ -239,6 +306,15 @@ def process_finished_status(self, client_status, msg_params): self.finish() def handle_message_client_status_update(self, msg_params): + """ + Handle client status update message. + + Args: + msg_params: Parameters of the received message. + + Returns: + None + """ client_status = msg_params.get(MyMessage.MSG_ARG_KEY_CLIENT_STATUS) if client_status == FedMLServerManager.ONLINE_STATUS_FLAG: self.process_online_status(client_status, msg_params) @@ -246,6 +322,15 @@ def handle_message_client_status_update(self, msg_params): self.process_finished_status(client_status, msg_params) def handle_message_connection_ready(self, msg_params): + """ + Handle connection ready message. + + Args: + msg_params: Parameters of the received message. + + Returns: + None + """ if not self.is_initialized: self.client_id_list_in_this_round = self.aggregator.client_selection( self.args.round_idx, self.client_real_ids, self.args.client_num_per_round @@ -270,22 +355,34 @@ def handle_message_connection_ready(self, msg_params): self.send_message_check_client_status( client_id, self.data_silo_index_list[client_idx_in_this_round], ) - logging.info("Connection ready for client: " + str(client_id)) + logging.info( + "Connection ready for client: " + str(client_id)) except Exception as e: logging.info("Connection not ready for client: {}".format( str(client_id), traceback.format_exc())) client_idx_in_this_round += 1 def handle_message_receive_model_from_client(self, msg_params): + """ + Handle received model from client. + + Args: + msg_params: Parameters of the received message. 
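+
+        Note:
+            The handler reads the sender id, model parameters, and local
+            sample count from ``msg_params``, registers them with the
+            aggregator, and triggers aggregation once
+            ``check_whether_all_receive()`` returns True.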
+ + Returns: + None + """ sender_id = msg_params.get(MyMessage.MSG_ARG_KEY_SENDER) - mlops.event("comm_c2s", event_started=False, event_value=str(self.args.round_idx), event_edge_id=sender_id) + mlops.event("comm_c2s", event_started=False, event_value=str( + self.args.round_idx), event_edge_id=sender_id) model_params = msg_params.get(MyMessage.MSG_ARG_KEY_MODEL_PARAMS) local_sample_number = msg_params.get(MyMessage.MSG_ARG_KEY_NUM_SAMPLES) self.aggregator.add_local_trained_result( - self.client_real_ids.index(sender_id), model_params, local_sample_number + self.client_real_ids.index( + sender_id), model_params, local_sample_number ) b_all_received = self.aggregator.check_whether_all_receive() logging.info("b_all_received = %s " % str(b_all_received)) @@ -298,23 +395,26 @@ def handle_message_receive_model_from_client(self, msg_params): ) logging.info("=================================================") - mlops.event("server.wait", event_started=False, event_value=str(self.args.round_idx)) + mlops.event("server.wait", event_started=False, + event_value=str(self.args.round_idx)) - mlops.event("server.agg_and_eval", event_started=True, event_value=str(self.args.round_idx)) + mlops.event("server.agg_and_eval", event_started=True, + event_value=str(self.args.round_idx)) global_model_params = self.aggregator.aggregate() - + # self.aggregator.test_on_server_for_all_clients( # self.args.round_idx, self.global_model_file_path # ) - - write_tensor_dict_to_mnn(self.global_model_file_path, global_model_params) + + write_tensor_dict_to_mnn( + self.global_model_file_path, global_model_params) self.aggregator.test_on_server_for_all_clients_mnn( self.global_model_file_path, self.args.round_idx ) - - mlops.event("server.agg_and_eval", event_started=False, event_value=str(self.args.round_idx)) + mlops.event("server.agg_and_eval", event_started=False, + event_value=str(self.args.round_idx)) # send round info to the MQTT backend mlops.log_round_info(self.round_num, self.args.round_idx) @@ -333,8 +433,8 @@ def handle_message_receive_model_from_client(self, msg_params): global_model_key = None logging.info("round idx {}, client_num_in_total {}, data_silo_index_list length {}," "client_id_list_in_this_round length {}.".format( - self.args.round_idx, self.args.client_num_in_total, - len(self.data_silo_index_list), len(self.client_id_list_in_this_round))) + self.args.round_idx, self.args.client_num_in_total, + len(self.data_silo_index_list), len(self.client_id_list_in_this_round))) for receiver_id in self.client_id_list_in_this_round: global_model_url, global_model_key = self.send_message_sync_model_to_client( receiver_id, @@ -350,11 +450,21 @@ def handle_message_receive_model_from_client(self, msg_params): self.args.round_idx, model_url=global_model_url, ) - logging.info("\n\n==========end {}-th round training===========\n".format(self.args.round_idx)) + logging.info( + "\n\n==========end {}-th round training===========\n".format(self.args.round_idx)) if self.args.round_idx < self.round_num: - mlops.event("server.wait", event_started=True, event_value=str(self.args.round_idx)) + mlops.event("server.wait", event_started=True, + event_value=str(self.args.round_idx)) def cleanup(self): + """ + Clean up and send finish message to clients. + + This method sends a finish message to all clients to indicate the completion of the federated learning round. 
+ + Returns: + None + """ client_idx_in_this_round = 0 for client_id in self.client_id_list_in_this_round: self.send_message_finish( @@ -363,25 +473,55 @@ def cleanup(self): client_idx_in_this_round += 1 def send_message_finish(self, receive_id, datasilo_index): - message = Message(MyMessage.MSG_TYPE_S2C_FINISH, self.get_sender_id(), receive_id) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) + """ + Send finish message to a client. + + Args: + receive_id: The ID of the client to receive the finish message. + datasilo_index: The data silo index associated with the client. + + Returns: + None + """ + message = Message(MyMessage.MSG_TYPE_S2C_FINISH, + self.get_sender_id(), receive_id) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) logging.info( "finish from send id {} to receive id {}.".format(message.get_sender_id(), message.get_receiver_id())) - logging.info(" ====================send cleanup message to {}====================".format(str(datasilo_index))) + logging.info(" ====================send cleanup message to {}====================".format( + str(datasilo_index))) def send_message_init_config(self, receive_id, global_model_params, client_index, global_model_url, global_model_key): + """ + Send initialization configuration message to a client. + + Args: + receive_id: The ID of the client to receive the message. + global_model_params: The global model parameters to be sent. + client_index: The client's index. + global_model_url: URL for global model parameters (if available). + global_model_key: Key for global model parameters (if available). + + Returns: + Tuple: A tuple containing the global model URL and key after sending the message. + """ message = Message( MyMessage.MSG_TYPE_S2C_INIT_CONFIG, self.get_sender_id(), receive_id ) if global_model_url is not None: - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL, global_model_url) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL, global_model_url) if global_model_key is not None: - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY, global_model_key) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY, global_model_key) logging.info("global_model_params = {}".format(global_model_params)) - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(client_index)) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, "AndroidClient") self.send_message(message) @@ -390,28 +530,58 @@ def send_message_init_config(self, receive_id, global_model_params, client_index return global_model_url, global_model_key def send_message_check_client_status(self, receive_id, datasilo_index): + """ + Send message to check client status. + + Args: + receive_id: The ID of the client to receive the message. + datasilo_index: The data silo index associated with the client. 
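+
+        Note:
+            The message is sent with type ``MSG_TYPE_S2C_CHECK_CLIENT_STATUS``
+            and carries ``datasilo_index`` under ``MSG_ARG_KEY_CLIENT_INDEX``.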
+ + Returns: + None + """ + message = Message( MyMessage.MSG_TYPE_S2C_CHECK_CLIENT_STATUS, self.get_sender_id(), receive_id ) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(datasilo_index)) self.send_message(message) def send_message_sync_model_to_client( self, receive_id, global_model_params, data_silo_index, global_model_url=None, global_model_key=None ): - logging.info("send_message_sync_model_to_client. receive_id = %d" % receive_id) + """ + Send model synchronization message to a client. + + Args: + receive_id: The ID of the client to receive the model synchronization message. + global_model_params: The global model parameters to be synchronized. + data_silo_index: The data silo index associated with the client. + global_model_url: URL for global model parameters (if available). + global_model_key: Key for global model parameters (if available). + + Returns: + Tuple: A tuple containing the global model URL and key after sending the message. + """ + logging.info( + "send_message_sync_model_to_client. receive_id = %d" % receive_id) message = Message( MyMessage.MSG_TYPE_S2C_SYNC_MODEL_TO_CLIENT, self.get_sender_id(), receive_id, ) - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS, global_model_params) if global_model_url is not None: - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL, global_model_url) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_URL, global_model_url) if global_model_key is not None: - message.add_params(MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY, global_model_key) - message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(data_silo_index)) + message.add_params( + MyMessage.MSG_ARG_KEY_MODEL_PARAMS_KEY, global_model_key) + message.add_params( + MyMessage.MSG_ARG_KEY_CLIENT_INDEX, str(data_silo_index)) message.add_params(MyMessage.MSG_ARG_KEY_CLIENT_OS, "AndroidClient") self.send_message(message) diff --git a/python/fedml/cross_device/server_mnn/server_mnn_api.py b/python/fedml/cross_device/server_mnn/server_mnn_api.py index e0ca0786ee..ce33707b4d 100644 --- a/python/fedml/cross_device/server_mnn/server_mnn_api.py +++ b/python/fedml/cross_device/server_mnn/server_mnn_api.py @@ -6,27 +6,68 @@ def fedavg_cross_device(args, process_id, worker_number, comm, device, test_dataloader, model, server_aggregator=None): - logging.info("test_data_global.iter_number = {}".format(test_dataloader.iter_number)) + """ + Federated Averaging across Multiple Devices (Cross-Device Aggregation). + + This function performs federated averaging across multiple devices using cross-device aggregation. + + Args: + args: Arguments for the federated learning process. + process_id (int): The process ID of the current worker. + worker_number (int): The total number of workers. + comm: Communication backend for distributed training. + device: The device (e.g., CPU or GPU) to perform computations. + test_dataloader: DataLoader for the test dataset. + model: The federated learning model. + server_aggregator: Server aggregator for aggregating model updates (default: None). 
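+
+    Example (an illustrative call sketch; ``args``, ``test_loader`` and
+    ``model`` are placeholders assumed to be prepared by the caller):
+        >>> fedavg_cross_device(args, process_id=0, worker_number=1, comm=None,
+        ...                     device="cpu", test_dataloader=test_loader,
+        ...                     model=model)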
+ + Returns: + None + """ + logging.info("test_data_global.iter_number = {}".format( + test_dataloader.iter_number)) if process_id == 0: - init_server(args, device, comm, process_id, worker_number, model, test_dataloader, server_aggregator) + init_server(args, device, comm, process_id, worker_number, + model, test_dataloader, server_aggregator) def init_server(args, device, comm, rank, size, model, test_dataloader, aggregator): + """ + Initialize the Federated Learning Server. + + This function initializes the federated learning server for aggregation. + + Args: + args: Arguments for the federated learning process. + device: The device (e.g., CPU or GPU) to perform computations. + comm: Communication backend for distributed training. + rank (int): The rank of the current worker. + size (int): The total number of workers. + model: The federated learning model. + test_dataloader: DataLoader for the test dataset. + aggregator: Server aggregator for aggregating model updates. + + Returns: + None + """ if aggregator is None: aggregator = create_server_aggregator(model, args) aggregator.set_id(-1) td_id = id(test_dataloader) logging.info("test_dataloader = {}".format(td_id)) - logging.info("test_data_global.iter_number = {}".format(test_dataloader.iter_number)) + logging.info("test_data_global.iter_number = {}".format( + test_dataloader.iter_number)) worker_num = size - aggregator = FedMLAggregator(test_dataloader, worker_num, device, args, aggregator) + aggregator = FedMLAggregator( + test_dataloader, worker_num, device, args, aggregator) - # start the distributed training + # Start the distributed training backend = args.backend - server_manager = FedMLServerManager(args, aggregator, comm, rank, size, backend) + server_manager = FedMLServerManager( + args, aggregator, comm, rank, size, backend) if not args.using_mlops: server_manager.start_train() server_manager.run() From e4b35ed381f050b0cb61a42fadd443cd583f146e Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Fri, 22 Sep 2023 18:10:24 +0530 Subject: [PATCH 66/70] push --- python/fedml/core/dp/common/utils.py | 62 ++++ .../core/dp/fedml_differential_privacy.py | 43 ++- python/fedml/core/dp/frames/NbAFL.py | 75 ++++- .../fedml/core/dp/frames/base_dp_solution.py | 92 +++++- python/fedml/core/dp/frames/cdp.py | 41 ++- python/fedml/core/dp/frames/dp_clip.py | 102 +++++- python/fedml/core/dp/frames/ldp.py | 29 +- .../fedml/core/dp/mechanisms/dp_mechanism.py | 73 ++++- python/fedml/core/dp/mechanisms/gaussian.py | 54 ++- python/fedml/core/dp/mechanisms/laplace.py | 38 ++- .../dp/test/test_fed_privacy_mechanism.py | 18 + python/fedml/core/mlops/mlops_configs.py | 100 +++++- python/fedml/core/mlops/mlops_device_perfs.py | 78 ++++- python/fedml/core/mlops/mlops_job_perfs.py | 60 +++- python/fedml/core/mlops/mlops_metrics.py | 307 ++++++++++++++++-- .../fedml/core/mlops/mlops_profiler_event.py | 55 ++++ python/fedml/core/mlops/mlops_runtime_log.py | 60 +++- .../core/mlops/mlops_runtime_log_daemon.py | 208 ++++++++++-- python/fedml/core/mlops/mlops_status.py | 85 +++++ python/fedml/core/mlops/mlops_utils.py | 39 ++- python/fedml/core/mlops/stats_impl.py | 88 ++++- python/fedml/core/mlops/system_stats.py | 25 ++ 22 files changed, 1594 insertions(+), 138 deletions(-) diff --git a/python/fedml/core/dp/common/utils.py b/python/fedml/core/dp/common/utils.py index 91558cb6ea..46b37f6048 100644 --- a/python/fedml/core/dp/common/utils.py +++ b/python/fedml/core/dp/common/utils.py @@ -7,6 +7,20 @@ def check_bounds(lower, upper): + """ + Check if the provided lower 
and upper bounds are valid.
+
+    Args:
+        lower (Real): The lower bound.
+        upper (Real): The upper bound.
+
+    Returns:
+        Tuple[Real, Real]: A tuple containing the validated lower and upper bounds.
+
+    Raises:
+        TypeError: If lower or upper is not a numeric type.
+        ValueError: If the lower bound is greater than the upper bound.
+    """
     if not isinstance(lower, Real) or not isinstance(upper, Real):
         raise TypeError("Bounds must be numeric")
     if lower > upper:
@@ -15,18 +29,54 @@ def check_bounds(lower, upper):
 
 
 def check_numeric_value(value):
+    """
+    Check if the provided value is a numeric type.
+
+    Args:
+        value (Real): The value to be checked.
+
+    Returns:
+        bool: True if the value is numeric; a TypeError is raised otherwise.
+
+    Raises:
+        TypeError: If the value is not a numeric type.
+    """
     if not isinstance(value, Real):
         raise TypeError("Value to be randomised must be a number")
     return True
 
 
 def check_integer_value(value):
+    """
+    Check if the provided value is an integer.
+
+    Args:
+        value (Integral): The value to be checked.
+
+    Returns:
+        bool: True if the value is an integer; a TypeError is raised otherwise.
+
+    Raises:
+        TypeError: If the value is not an integer.
+    """
     if not isinstance(value, Integral):
         raise TypeError("Value to be randomised must be an integer")
     return True
 
 
 def check_epsilon_delta(epsilon, delta, allow_zero=False):
+    """
+    Check if the provided epsilon and delta values are valid for differential privacy.
+
+    Args:
+        epsilon (Real): Epsilon value.
+        delta (Real): Delta value.
+        allow_zero (bool, optional): Whether to allow epsilon and delta to be zero. Default is False.
+
+    Raises:
+        TypeError: If epsilon or delta is not a numeric type.
+        ValueError: If epsilon is negative, delta is outside [0, 1] range, or both epsilon and delta are zero.
+    """
     if not isinstance(epsilon, Real) or not isinstance(delta, Real):
         raise TypeError("Epsilon and delta must be numeric")
     if epsilon < 0:
@@ -38,6 +88,18 @@ def check_epsilon_delta(epsilon, delta, allow_zero=False):
 
 
 def check_params(epsilon, delta, sensitivity):
+    """
+    Check the validity of epsilon, delta, and sensitivity parameters for differential privacy.
+
+    Args:
+        epsilon (Real): Epsilon value.
+        delta (Real): Delta value.
+        sensitivity (Real): Sensitivity value.
+
+    Raises:
+        TypeError: If epsilon, delta, or sensitivity is not a numeric type.
+        ValueError: If epsilon is negative, delta is outside [0, 1] range, or sensitivity is negative.
+    """
     check_epsilon_delta(epsilon, delta, allow_zero=False)
     if not isinstance(sensitivity, Real):
         raise TypeError("Sensitivity must be numeric")
diff --git a/python/fedml/core/dp/fedml_differential_privacy.py b/python/fedml/core/dp/fedml_differential_privacy.py
index de76817a48..5b81b5470a 100644
--- a/python/fedml/core/dp/fedml_differential_privacy.py
+++ b/python/fedml/core/dp/fedml_differential_privacy.py
@@ -11,6 +11,33 @@ class FedMLDifferentialPrivacy:
+    """
+    A class for managing Differential Privacy in Federated Learning.
+
+    Attributes:
+        enable_rdp_accountant (bool): Flag indicating if RDP accountant is enabled.
+        max_grad_norm (float): Maximum gradient norm for clipping.
+        dp_solution_type (str): Type of differential privacy solution (e.g., 'gaussian', 'laplace').
+        dp_solution: An instance of the differential privacy solution.
+        dp_accountant: An instance of the differential privacy accountant.
+        is_enabled (bool): Flag indicating if differential privacy is enabled.
+        privacy_engine: The privacy engine used for differential privacy.
+        current_round (int): Current federated learning round. 
+ accountant: An accountant for tracking privacy budget consumption. + delta (float): Delta value for differential privacy. + + Methods: + init(args): Initialize the differential privacy settings based on command-line arguments. + is_dp_enabled(): Check if differential privacy is enabled. + is_local_dp_enabled(): Check if local differential privacy is enabled. + is_global_dp_enabled(): Check if global differential privacy is enabled. + is_clipping(): Check if gradient clipping is enabled. + to_compute_params_in_aggregation_enabled(): Check if computing parameters in aggregation is enabled. + global_clip(raw_client_model_or_grad_list): Apply global gradient clipping. + add_local_noise(local_grad): Add local noise to gradients. + add_global_noise(global_model): Add global noise to the global model. + set_params_for_dp(raw_client_model_or_grad_list): Set parameters for differential privacy. + """ _dp_instance = None @staticmethod @@ -20,6 +47,12 @@ def get_instance(): return FedMLDifferentialPrivacy._dp_instance def __init__(self): + """ + Initialize differential privacy settings based on command-line arguments. + + Args: + args (argparse.Namespace): Parsed command-line arguments. + """ self.enable_rdp_accountant = False self.max_grad_norm = None self.dp_solution_type = None @@ -33,7 +66,8 @@ def __init__(self): def init(self, args): if hasattr(args, "enable_dp") and args.enable_dp: - logging.info(".......init dp......." + args.dp_solution_type + "-" + args.dp_solution_type) + logging.info(".......init dp......." + + args.dp_solution_type + "-" + args.dp_solution_type) self.is_enabled = True self.dp_solution_type = args.dp_solution_type.strip() if hasattr(args, "max_grad_norm"): @@ -67,6 +101,12 @@ def init(self, args): self.is_enabled = False def is_dp_enabled(self): + """ + Check if differential privacy is enabled. + + Returns: + bool: True if differential privacy is enabled, False otherwise. + """ return self.is_enabled def is_local_dp_enabled(self): @@ -101,4 +141,3 @@ def set_params_for_dp(self, raw_client_model_or_grad_list: List[Tuple[float, Ord if self.dp_solution is None: raise Exception("dp solution is not initialized!") self.dp_solution.set_params_for_dp(raw_client_model_or_grad_list) - diff --git a/python/fedml/core/dp/frames/NbAFL.py b/python/fedml/core/dp/frames/NbAFL.py index 8c7d3c3f33..83c01eb816 100644 --- a/python/fedml/core/dp/frames/NbAFL.py +++ b/python/fedml/core/dp/frames/NbAFL.py @@ -12,7 +12,33 @@ class NbAFL_DP(BaseDPFrame): + """ + Non-Blocking Asynchronous Federated Learning with Differential Privacy Mechanism. + + Attributes: + args: A namespace containing the configuration arguments for the mechanism. + big_C_clipping (float): A clipping threshold for bounding model weights. + total_round_num (int): The total number of communication rounds. + small_c_constant (float): A constant used in the mechanism. + client_num_per_round (int): The number of clients participating in each round. + client_num_in_total (int): The total number of clients. + epsilon (float): The privacy parameter epsilon. + m (int): The minimum size of local datasets. + + Methods: + __init__(self, args): Initialize the NbAFL_DP mechanism. + add_local_noise(self, local_grad: OrderedDict): Add local noise to the gradients. + add_global_noise(self, global_model: OrderedDict): Add global noise to the global model. + set_params_for_dp(self, raw_client_model_or_grad_list: List[Tuple[float, OrderedDict]]): Set parameters for DP. + """ + def __init__(self, args): + """ + Initialize the NbAFL_DP mechanism. 
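+
+        Example (illustrative; only fields that this class reads from ``args``
+        are shown, with placeholder values):
+            >>> args.epsilon, args.delta = 0.5, 1e-5
+            >>> args.C, args.comm_round = 1.0, 100
+            >>> args.client_num_per_round, args.client_num_in_total = 10, 100
+            >>> dp_frame = NbAFL_DP(args)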
+ + Args: + args: A namespace containing the configuration arguments for the mechanism. + """ super().__init__(args) self.set_ldp( DPMechanism( @@ -22,39 +48,64 @@ def __init__(self, args): ) ) """ - In the experiments, the authors choosed C by taking the median of the norms of the unclipped parameters. - This is not practical in reality. The server can not obtain unclipped plaintext parameters. It can only - get noised clipped parameters. So here we set C as a parameter that indicated by users. + In the experiments, the authors chose C by taking the median of the norms of the unclipped parameters. + This is not practical in reality. The server cannot obtain unclipped plaintext parameters. It can only + get noised clipped parameters. So here we set C as a parameter indicated by users. """ self.big_C_clipping = args.C # C: a clipping threshold for bounding w_i self.total_round_num = args.comm_round # T in the paper self.small_c_constant = np.sqrt( - 2 * math.log(1.25 / args.delta)) # the author indicated c>= sqrt(2ln(1.25/delta) - self.client_num_per_round = args.client_num_per_round # L in the paper - self.client_num_in_total = args.client_num_in_total # N in the paper + 2 * math.log(1.25 / args.delta)) # the author indicated c >= sqrt(2ln(1.25/delta) + self.client_num_per_round = args.client_num_per_round # L in the paper + self.client_num_in_total = args.client_num_in_total # N in the paper self.epsilon = args.epsilon # 0 < epsilon < 1 - """ The author said ''m is the minimum size of the local datasets''. + """ The author said ''m is the minimum size of the local datasets''. In their paper, clients did not sample local data for training; In our setting, we set m to the minimum sample num of each round.""" self.m = 0 # the minimum size of the local datasets def add_local_noise(self, local_grad: OrderedDict): + """ + Add local noise to the gradients. + + Args: + local_grad (OrderedDict): Local gradients. + + Returns: + OrderedDict: Local gradients with added noise. + """ for k in local_grad.keys(): # Clip weight local_grad[k] = local_grad[k] / torch.max(torch.ones(size=local_grad[k].shape), torch.abs(local_grad[k]) / self.big_C_clipping) return super().add_local_noise(local_grad=local_grad) def add_global_noise(self, global_model: OrderedDict): + """ + Add global noise to the global model. + + Args: + global_model (OrderedDict): Global model parameters. + + Returns: + OrderedDict: Global model parameters with added noise. + """ if self.total_round_num > np.sqrt(self.client_num_in_total) * self.client_num_per_round: - scale_d = 2 * self.small_c_constant * self.big_C_clipping * np.sqrt(np.power(self.total_round_num, 2) - - np.power(self.client_num_per_round, - 2) * self.client_num_in_total) / ( - self.m * self.client_num_in_total * self.epsilon) + scale_d = 2 * self.small_c_constant * self.big_C_clipping * np.sqrt( + np.power(self.total_round_num, 2) - + np.power(self.client_num_per_round, 2) * self.client_num_in_total) / ( + self.m * self.client_num_in_total * self.epsilon) for k in global_model.keys(): - global_model[k] = Gaussian.compute_noise_using_sigma(scale_d, global_model[k].shape) + global_model[k] = Gaussian.compute_noise_using_sigma( + scale_d, global_model[k].shape) return global_model def set_params_for_dp(self, raw_client_model_or_grad_list: List[Tuple[float, OrderedDict]]): + """ + Set parameters for Differential Privacy. + + Args: + raw_client_model_or_grad_list (List[Tuple[float, OrderedDict]]): List of tuples containing sample numbers and gradients/models. 
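+
+        Example (illustrative): given ``[(100, grads_a), (40, grads_b)]``,
+        ``self.m`` is set to 40, the smallest local sample count.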
+ """ smallest_sample_num, _ = raw_client_model_or_grad_list[0] for (sample_num, _) in raw_client_model_or_grad_list: if smallest_sample_num > sample_num: diff --git a/python/fedml/core/dp/frames/base_dp_solution.py b/python/fedml/core/dp/frames/base_dp_solution.py index e0647eb903..de3f91d097 100644 --- a/python/fedml/core/dp/frames/base_dp_solution.py +++ b/python/fedml/core/dp/frames/base_dp_solution.py @@ -6,7 +6,34 @@ class BaseDPFrame(ABC): + """ + Abstract base class for Differential Privacy mechanisms. + + Attributes: + cdp: A DPMechanism instance for global differential privacy. + ldp: A DPMechanism instance for local differential privacy. + args: A namespace containing the configuration arguments for the mechanism. + is_rdp_accountant_enabled: A boolean indicating whether RDP accountant is enabled. + max_grad_norm: Maximum gradient norm for gradient clipping. + + Methods: + __init__(self, args=None): Initialize the BaseDPFrame instance. + set_cdp(self, dp_mechanism: DPMechanism): Set the global differential privacy mechanism. + set_ldp(self, dp_mechanism: DPMechanism): Set the local differential privacy mechanism. + add_local_noise(self, local_grad: OrderedDict): Add local noise to local gradients. + add_global_noise(self, global_model: OrderedDict): Add global noise to global model parameters. + set_params_for_dp(self, raw_client_model_or_grad_list): Set parameters for differential privacy mechanism. + get_rdp_accountant_val(self): Get the differential privacy parameter for RDP accountant. + global_clip(self, raw_client_model_or_grad_list): Apply gradient clipping to global gradients. + """ + def __init__(self, args=None): + """ + Initialize the BaseDPFrame instance. + + Args: + args: A namespace containing the configuration arguments for the mechanism. + """ self.cdp = None self.ldp = None self.args = args @@ -17,21 +44,66 @@ def __init__(self, args=None): self.max_grad_norm = None def set_cdp(self, dp_mechanism: DPMechanism): + """ + Set the global differential privacy mechanism. + + Args: + dp_mechanism (DPMechanism): A DPMechanism instance for global differential privacy. + """ self.cdp = dp_mechanism def set_ldp(self, dp_mechanism: DPMechanism): + """ + Set the local differential privacy mechanism. + + Args: + dp_mechanism (DPMechanism): A DPMechanism instance for local differential privacy. + """ self.ldp = dp_mechanism + @abstractmethod def add_local_noise(self, local_grad: OrderedDict): - return self.ldp.add_noise(grad=local_grad) + """ + Add local noise to local gradients. + + Args: + local_grad (OrderedDict): Local gradients. + Returns: + OrderedDict: Local gradients with added noise. + """ + pass + + @abstractmethod def add_global_noise(self, global_model: OrderedDict): - return self.cdp.add_noise(grad=global_model) + """ + Add global noise to global model parameters. + + Args: + global_model (OrderedDict): Global model parameters. - def set_params_for_dp(self, raw_client_model_or_grad_list: List[Tuple[float, OrderedDict]]): + Returns: + OrderedDict: Global model parameters with added noise. + """ + pass + + @abstractmethod + def set_params_for_dp(self, raw_client_model_or_grad_list): + """ + Set parameters for differential privacy mechanism. + + Args: + raw_client_model_or_grad_list: List of raw client models or gradients. + """ pass def get_rdp_accountant_val(self): + """ + Get the differential privacy parameter for RDP accountant. + + Returns: + float: Differential privacy parameter. 
+ """ if self.cdp is not None: dp_param = self.cdp.get_rdp_scale() elif self.ldp is not None: @@ -40,7 +112,16 @@ def get_rdp_accountant_val(self): raise Exception("can not create rdp accountant") return dp_param - def global_clip(self, raw_client_model_or_grad_list: List[Tuple[float, OrderedDict]]): + def global_clip(self, raw_client_model_or_grad_list): + """ + Apply gradient clipping to global gradients. + + Args: + raw_client_model_or_grad_list: List of raw client models or gradients. + + Returns: + List: List of clipped client models or gradients. + """ if self.max_grad_norm is None: return raw_client_model_or_grad_list new_grad_list = [] @@ -54,6 +135,3 @@ def global_clip(self, raw_client_model_or_grad_list: List[Tuple[float, OrderedDi local_grad[k].mul_(clip_coef_clamped) new_grad_list.append((num, local_grad)) return new_grad_list - - - diff --git a/python/fedml/core/dp/frames/cdp.py b/python/fedml/core/dp/frames/cdp.py index ccc65b5822..58bfa7ca18 100644 --- a/python/fedml/core/dp/frames/cdp.py +++ b/python/fedml/core/dp/frames/cdp.py @@ -6,16 +6,49 @@ class GlobalDP(BaseDPFrame): + """ + Differential Privacy mechanism with global noise. + + Attributes: + args: A namespace containing the configuration arguments for the mechanism. + enable_rdp_accountant: A boolean indicating whether RDP accountant is enabled. + is_rdp_accountant_enabled: A boolean indicating whether RDP accountant is enabled. + sample_rate: Sample rate for RDP accountant. + accountant: RDP accountant for privacy analysis. + + Methods: + __init__(self, args): Initialize the GlobalDP mechanism. + add_global_noise(self, global_model: OrderedDict): Add global noise to the global model parameters. + """ + def __init__(self, args): + """ + Initialize the GlobalDP mechanism. + + Args: + args: A namespace containing the configuration arguments for the mechanism. + """ super().__init__(args) - self.set_cdp(DPMechanism(args.mechanism_type, args.epsilon, args.delta, args.sensitivity)) + self.set_cdp(DPMechanism(args.mechanism_type, + args.epsilon, args.delta, args.sensitivity)) self.enable_rdp_accountant = False if hasattr(args, "enable_rdp_accountant") and args.enable_rdp_accountant: self.is_rdp_accountant_enabled = True self.sample_rate = args.client_num_per_round / args.client_num_in_total - self.accountant = RDP_Accountant(alpha=args.rdp_alpha, dp_mechanism=args.mechanism_type, args=args) + self.accountant = RDP_Accountant( + alpha=args.rdp_alpha, dp_mechanism=args.mechanism_type, args=args) def add_global_noise(self, global_model: OrderedDict): + """ + Add global noise to the global model parameters. + + Args: + global_model (OrderedDict): Global model parameters. + + Returns: + OrderedDict: Global model parameters with added global noise. + """ if self.is_rdp_accountant_enabled: - self.accountant.step(noise_multiplier=self.cdp.get_rdp_scale(), sample_rate=self.sample_rate) # todo: ask??? - return super().add_global_noise(global_model=global_model) \ No newline at end of file + self.accountant.step( + noise_multiplier=self.cdp.get_rdp_scale(), sample_rate=self.sample_rate) + return super().add_global_noise(global_model=global_model) diff --git a/python/fedml/core/dp/frames/dp_clip.py b/python/fedml/core/dp/frames/dp_clip.py index b6e7d02b65..eb44037e53 100644 --- a/python/fedml/core/dp/frames/dp_clip.py +++ b/python/fedml/core/dp/frames/dp_clip.py @@ -13,49 +13,119 @@ """ class DP_Clip(BaseDPFrame): + """ + Differential Privacy mechanism with gradient clipping. 
+ + Attributes: + args: A namespace containing the configuration arguments for the mechanism. + + Methods: + __init__(self, args): Initialize the DP_Clip mechanism. + clip_local_update(self, local_grad, norm_type: float = 2.0): Clip local gradients. + add_local_noise(self, local_grad: OrderedDict, extra_auxiliary_info: Any = None): Add local noise to gradients. + add_global_noise(self, global_model: OrderedDict): Add global noise to the global model parameters. + get_global_params(self): Get global parameters. + compute_noise(self, size, qw): Compute noise. + add_noise(self, w_global, qw): Add noise to global parameters. + """ + def __init__(self, args): + """ + Initialize the DP_Clip mechanism. + + Args: + args: A namespace containing the configuration arguments for the mechanism. + """ super().__init__(args) self.clipping_norm = args.clipping_norm self.train_data_num_in_total = args.train_data_num_in_total self._scale = args.clipping_norm * args.noise_multiplier def clip_local_update(self, local_grad, norm_type: float = 2.0): - total_norm = torch.norm(torch.stack([torch.norm(local_grad[k], norm_type) for k in local_grad.keys()]), norm_type) + """ + Clip local gradients. + + Args: + local_grad (OrderedDict): Local gradients. + norm_type (float): Type of norm to compute (default is 2.0). + + Returns: + OrderedDict: Clipped local gradients. + """ + total_norm = torch.norm(torch.stack( + [torch.norm(local_grad[k], norm_type) for k in local_grad.keys()]), norm_type) clip_coef = self.clipping_norm / (total_norm + 1e-6) clip_coef_clamped = torch.clamp(clip_coef, max=1.0) for k in local_grad.keys(): local_grad[k].mul_(clip_coef_clamped) return local_grad - def add_local_noise(self, local_grad: OrderedDict, extra_auxiliary_info: Any = None,): + def add_local_noise(self, local_grad: OrderedDict, extra_auxiliary_info: Any = None): + """ + Add local noise to gradients. + + Args: + local_grad (OrderedDict): Local gradients. + extra_auxiliary_info (Any): Extra auxiliary information (not used). + + Returns: + OrderedDict: Local gradients with added noise. + """ global_model_params = extra_auxiliary_info for k in global_model_params.keys(): local_grad[k] = local_grad[k] - global_model_params[k] return self.clip_local_update(local_grad, self.clipping_norm) def add_global_noise(self, global_model: OrderedDict): - qw = self.train_data_num_in_total * (self.args.client_num_per_round / self.args.client_num_in_total) - for k in global_model.keys(): - global_model[k] = global_model[k] / qw - w_global = self.add_noise( - global_model, qw - ) - for k in w_global.keys(): - w_global[k] = w_global[k] + global_model[k] + """ + Add global noise to the global model parameters (not implemented). + + Args: + global_model (OrderedDict): Global model parameters. + + Raises: + NotImplementedError: This method is not implemented. + """ + raise NotImplementedError( + "add_global_noise method is not implemented.") def get_global_params(self): - pass + """ + Get global parameters (not implemented). + + Raises: + NotImplementedError: This method is not implemented. + """ + raise NotImplementedError( + "get_global_params method is not implemented.") def compute_noise(self, size, qw): + """ + Compute noise for differential privacy. + + Args: + size: Size of the noise. + qw: Noise scaling factor. + + Returns: + torch.Tensor: Noise tensor. + """ self._scale = self._scale / qw return torch.normal(mean=0, std=self._scale, size=size) def add_noise(self, w_global, qw): + """ + Add noise to global parameters for differential privacy. 
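+
+        Note:
+            This delegates to ``compute_noise``, which rescales ``self._scale``
+            in place (``self._scale = self._scale / qw``) on each call, so the
+            noise magnitude shrinks across repeated invocations.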
+ + Args: + w_global (OrderedDict): Global model parameters. + qw: Noise scaling factor. + + Returns: + OrderedDict: Global model parameters with added noise. + """ new_params = OrderedDict() for k in w_global.keys(): - new_params[k] = self.compute_noise(w_global[k].shape, qw) + w_global[k] + new_params[k] = self.compute_noise( + w_global[k].shape, qw) + w_global[k] return new_params - - - - diff --git a/python/fedml/core/dp/frames/ldp.py b/python/fedml/core/dp/frames/ldp.py index 94f4443431..c26cb8e131 100644 --- a/python/fedml/core/dp/frames/ldp.py +++ b/python/fedml/core/dp/frames/ldp.py @@ -5,9 +5,36 @@ class LocalDP(BaseDPFrame): + """ + Local Differential Privacy mechanism. + + Attributes: + args: A namespace containing the configuration arguments for the mechanism. + + Methods: + __init__(self, args): Initialize the LocalDP mechanism. + add_local_noise(self, local_grad: OrderedDict): Add local noise to the gradients. + """ + def __init__(self, args): + """ + Initialize the LocalDP mechanism. + + Args: + args: A namespace containing the configuration arguments for the mechanism. + """ super().__init__(args) - self.set_ldp(DPMechanism(args.mechanism_type, args.epsilon, args.delta, args.sensitivity)) + self.set_ldp(DPMechanism(args.mechanism_type, + args.epsilon, args.delta, args.sensitivity)) def add_local_noise(self, local_grad: OrderedDict): + """ + Add local noise to the gradients. + + Args: + local_grad (OrderedDict): Local gradients. + + Returns: + OrderedDict: Local gradients with added noise. + """ return super().add_local_noise(local_grad=local_grad) diff --git a/python/fedml/core/dp/mechanisms/dp_mechanism.py b/python/fedml/core/dp/mechanisms/dp_mechanism.py index ba64fe3eb6..b3cbfcc366 100644 --- a/python/fedml/core/dp/mechanisms/dp_mechanism.py +++ b/python/fedml/core/dp/mechanisms/dp_mechanism.py @@ -1,3 +1,5 @@ +from .gaussian import Gaussian +from .laplace import Laplace from fedml.core.dp.mechanisms import Gaussian, Laplace import torch from typing import Union, Iterable @@ -9,7 +11,33 @@ class DPMechanism: + """ + A class representing a Differential Privacy Mechanism. + + Attributes: + mechanism_type (str): The type of differential privacy mechanism ('laplace' or 'gaussian'). + epsilon (float): The privacy parameter epsilon. + delta (float): The privacy parameter delta. + sensitivity (float, optional): The sensitivity of the mechanism (default is 1). + + Methods: + __init__(self, mechanism_type, epsilon, delta, sensitivity=1): Initialize the DP mechanism. + add_noise(self, grad): Add noise to a gradient. + _compute_new_grad(self, grad): Compute a new gradient by adding noise. + add_a_noise_to_local_data(self, local_data): Add noise to local data. + get_rdp_scale(self): Get the RDP (Rényi Differential Privacy) scale of the mechanism. + """ + def __init__(self, mechanism_type, epsilon, delta, sensitivity=1): + """ + Initialize the Differential Privacy Mechanism. + + Args: + mechanism_type (str): The type of differential privacy mechanism ('laplace' or 'gaussian'). + epsilon (float): The privacy parameter epsilon. + delta (float): The privacy parameter delta. + sensitivity (float, optional): The sensitivity of the mechanism (default is 1). + """ mechanism_type = mechanism_type.lower() if mechanism_type == "laplace": self.dp = Laplace( @@ -21,28 +49,57 @@ def __init__(self, mechanism_type, epsilon, delta, sensitivity=1): raise NotImplementedError("DP mechanism not implemented!") def add_noise(self, grad): + """ + Add noise to a gradient. 
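+        Example (illustrative sketch; ``args`` is assumed to provide
+        ``clipping_norm``, ``noise_multiplier`` and ``train_data_num_in_total``):
+            >>> dp = DP_Clip(args)
+            >>> noisy_global = dp.add_noise(w_global, qw=1.0)
+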
+ + Args: + grad (OrderedDict): The gradient to which noise will be added. + + Returns: + OrderedDict: A new gradient with added noise. + """ new_grad = OrderedDict() for k in grad.keys(): new_grad[k] = self._compute_new_grad(grad[k]) return new_grad def _compute_new_grad(self, grad): + """ + Compute a new gradient by adding noise. + + Args: + grad (torch.Tensor): The gradient tensor. + + Returns: + torch.Tensor: A new gradient tensor with added noise. + """ noise = self.dp.compute_noise(grad.shape) return noise + grad def add_a_noise_to_local_data(self, local_data): + """ + Add noise to local data. + + Args: + local_data (list of tuples): Local data where each tuple represents a data point. + + Returns: + list of tuples: Local data with added noise. + """ new_data = [] for i in range(len(local_data)): - list = [] + data_tuple = [] for x in local_data[i]: - y = self._compute_new_grad(x) - list.append(y) - new_data.append(tuple(list)) + noisy_data = self._compute_new_grad(x) + data_tuple.append(noisy_data) + new_data.append(tuple(data_tuple)) return new_data def get_rdp_scale(self): - return self.dp.get_rdp_scale() - - - + """ + Get the RDP (Rényi Differential Privacy) scale of the mechanism. + Returns: + float: The RDP scale of the mechanism. + """ + return self.dp.get_rdp_scale() diff --git a/python/fedml/core/dp/mechanisms/gaussian.py b/python/fedml/core/dp/mechanisms/gaussian.py index 93b0ad56d5..3074ec6070 100644 --- a/python/fedml/core/dp/mechanisms/gaussian.py +++ b/python/fedml/core/dp/mechanisms/gaussian.py @@ -5,7 +5,30 @@ class Gaussian(BaseDPMechanism): + """ + The Gaussian mechanism in differential privacy. + + Attributes: + epsilon (float): The privacy parameter epsilon. + delta (float): The privacy parameter delta (default is 0.0). + sensitivity (float): The sensitivity of the mechanism (default is 1). + + Methods: + __init__(self, epsilon, delta=0.0, sensitivity=1): Initialize the Gaussian mechanism. + compute_noise(self, size): Generate Gaussian noise. + compute_noise_using_sigma(cls, sigma, size): Generate Gaussian noise with a given standard deviation. + get_rdp_scale(self): Get the RDP (Rényi Differential Privacy) scale of the mechanism. + """ + def __init__(self, epsilon, delta=0.0, sensitivity=1): + """ + Initialize the Gaussian mechanism. + + Args: + epsilon (float): The privacy parameter epsilon. + delta (float, optional): The privacy parameter delta (default is 0.0). + sensitivity (float, optional): The sensitivity of the mechanism (default is 1). + """ check_params(epsilon, delta, sensitivity) if epsilon == 0 or delta == 0: raise ValueError("Neither Epsilon nor Delta can be zero") @@ -15,19 +38,44 @@ def __init__(self, epsilon, delta=0.0, sensitivity=1): ) self.scale = ( - np.sqrt(2 * np.log(1.25 / float(delta))) - * float(sensitivity) - / float(epsilon) + np.sqrt(2 * np.log(1.25 / float(delta))) + * float(sensitivity) + / float(epsilon) ) @classmethod def compute_noise_using_sigma(cls, sigma, size): + """ + Generate Gaussian noise with a given standard deviation. + + Args: + sigma (float): The standard deviation of the Gaussian noise. + size (int or tuple): The size of the noise vector. + + Returns: + torch.Tensor: A tensor containing Gaussian noise. + """ if not isinstance(sigma, float): raise ValueError("sigma should be a float") return torch.normal(mean=0, std=sigma, size=size) def compute_noise(self, size): + """ + Generate Gaussian noise. + + Args: + size (int or tuple): The size of the noise vector. 
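+        Example (illustrative only; the key and shape are placeholders):
+            >>> mech = DPMechanism("gaussian", epsilon=0.5, delta=1e-5)
+            >>> noisy = mech.add_noise(OrderedDict(w=torch.zeros(2, 2)))
+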
+ + Returns: + torch.Tensor: A tensor containing Gaussian noise. + """ return torch.normal(mean=0, std=self.scale, size=size) def get_rdp_scale(self): + """ + Get the RDP (Rényi Differential Privacy) scale of the mechanism. + + Returns: + float: The RDP scale of the mechanism. + """ return self.scale diff --git a/python/fedml/core/dp/mechanisms/laplace.py b/python/fedml/core/dp/mechanisms/laplace.py index cc4fabd95f..4b304eab0e 100644 --- a/python/fedml/core/dp/mechanisms/laplace.py +++ b/python/fedml/core/dp/mechanisms/laplace.py @@ -7,15 +7,49 @@ class Laplace(BaseDPMechanism): """ The classical Laplace mechanism in differential privacy. + + Attributes: + epsilon (float): The privacy parameter epsilon. + delta (float): The privacy parameter delta (default is 0.0). + sensitivity (float): The sensitivity of the mechanism (default is 1). + + Methods: + __init__(self, epsilon, delta=0.0, sensitivity=1): Initialize the Laplace mechanism. + compute_noise(self, size): Generate Laplace noise. + get_rdp_scale(self): Get the RDP (Rényi Differential Privacy) scale of the mechanism. """ def __init__(self, epsilon, delta=0.0, sensitivity=1): + """ + Initialize the Laplace mechanism. + + Args: + epsilon (float): The privacy parameter epsilon. + delta (float, optional): The privacy parameter delta (default is 0.0). + sensitivity (float, optional): The sensitivity of the mechanism (default is 1). + """ check_params(epsilon, delta, sensitivity) - self.scale = float(sensitivity) / (float(epsilon) - np.log(1 - float(delta))) + self.scale = float(sensitivity) / \ + (float(epsilon) - np.log(1 - float(delta))) self.sensitivity = sensitivity def compute_noise(self, size): + """ + Generate Laplace noise. + + Args: + size (int or tuple): The size of the noise vector. + + Returns: + torch.Tensor: A tensor containing Laplace noise. + """ return torch.tensor(np.random.laplace(loc=0.0, scale=self.scale, size=size)) def get_rdp_scale(self): - return self.scale/self.sensitivity + """ + Get the RDP (Rényi Differential Privacy) scale of the mechanism. + + Returns: + float: The RDP scale of the mechanism. + """ + return self.scale / self.sensitivity diff --git a/python/fedml/core/dp/test/test_fed_privacy_mechanism.py b/python/fedml/core/dp/test/test_fed_privacy_mechanism.py index 99d252dffb..19aa39697e 100644 --- a/python/fedml/core/dp/test/test_fed_privacy_mechanism.py +++ b/python/fedml/core/dp/test/test_fed_privacy_mechanism.py @@ -13,6 +13,12 @@ def add_gaussian_args(): + """ + Define and parse command-line arguments for Gaussian differential privacy mechanism. + + Returns: + argparse.Namespace: Parsed command-line arguments. + """ parser = argparse.ArgumentParser(description="FedML") parser.add_argument( "--yaml_config_file", @@ -34,6 +40,12 @@ def add_gaussian_args(): def add_laplace_args(): + """ + Define and parse command-line arguments for Laplace differential privacy mechanism. + + Returns: + argparse.Namespace: Parsed command-line arguments. + """ parser = argparse.ArgumentParser(description="FedML") parser.add_argument( "--yaml_config_file", @@ -53,6 +65,9 @@ def add_laplace_args(): def test_FedMLDifferentialPrivacy_gaussian(): + """ + Test the FedMLDifferentialPrivacy class with the Gaussian mechanism. 
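+        Note:
+            Noise is drawn i.i.d. from N(0, scale^2), where ``scale`` is set in
+            ``__init__`` as sqrt(2 * ln(1.25 / delta)) * sensitivity / epsilon.
+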
+ """ print("----------- test_FedMLDifferentialPrivacy - gaussian mechanism -----------") FedMLDifferentialPrivacy.get_instance().init(add_gaussian_args()) print(f"grad = {a_local_w}") @@ -60,6 +75,9 @@ def test_FedMLDifferentialPrivacy_gaussian(): def test_FedMLDifferentialPrivacy_laplace(): + """ + Test the FedMLDifferentialPrivacy class with the Laplace mechanism. + """ print("----------- test_FedMLDifferentialPrivacy - laplace mechanism -----------") FedMLDifferentialPrivacy.get_instance().init(add_laplace_args()) print(f"grad = {a_local_w}") diff --git a/python/fedml/core/mlops/mlops_configs.py b/python/fedml/core/mlops/mlops_configs.py index b0ea899ab1..a43851689e 100644 --- a/python/fedml/core/mlops/mlops_configs.py +++ b/python/fedml/core/mlops/mlops_configs.py @@ -39,6 +39,13 @@ def get_instance(args): return MLOpsConfigs._config_instance def get_request_params(self): + """ + Get the request parameters for fetching configurations. + + Returns: + str: The URL for configuration retrieval. + str: The path to the certificate file, if applicable. + """ url = "https://open.fedml.ai/fedmlOpsServer/configs/fetch" config_version = "release" if ( @@ -55,7 +62,8 @@ def get_request_params(self): url = "https://open-dev.fedml.ai/fedmlOpsServer/configs/fetch" elif self.args.config_version == "local": if hasattr(self.args, "local_server") and self.args.local_server is not None: - url = "http://{}:9000/fedmlOpsServer/configs/fetch".format(self.args.local_server) + url = "http://{}:9000/fedmlOpsServer/configs/fetch".format( + self.args.local_server) else: url = "http://localhost:9000/fedmlOpsServer/configs/fetch" @@ -78,7 +86,8 @@ def get_request_params_with_version(self, version): url = "https://open-dev.fedml.ai/fedmlOpsServer/configs/fetch" elif version == "local": if hasattr(self.args, "local_server") and self.args.local_server is not None: - url = "http://{}:9000/fedmlOpsServer/configs/fetch".format(self.args.local_server) + url = "http://{}:9000/fedmlOpsServer/configs/fetch".format( + self.args.local_server) else: url = "http://localhost:9000/fedmlOpsServer/configs/fetch" @@ -93,6 +102,12 @@ def get_request_params_with_version(self, version): @staticmethod def get_root_ca_path(): + """ + Get the file path to the root CA certificate. + + Returns: + str: The file path to the root CA certificate. + """ cur_source_dir = os.path.dirname(__file__) cert_path = os.path.join( cur_source_dir, "ssl", "open-root-ca.crt" @@ -101,6 +116,14 @@ def get_root_ca_path(): @staticmethod def install_root_ca_file(): + """ + Install the root CA certificate file. + + This method appends the root CA certificate to the CA file used by the requests library. + + Raises: + FileNotFoundError: If the root CA certificate file is not found. + """ ca_file = certifi.where() open_root_ca_path = MLOpsConfigs.get_root_ca_path() with open(open_root_ca_path, 'rb') as infile: @@ -109,6 +132,16 @@ def install_root_ca_file(): outfile.write(open_root_ca_file) def fetch_configs(self): + """ + Fetch device configurations. + + Returns: + dict: MQTT configuration. + dict: S3 configuration. + + Raises: + Exception: If fetching device configurations fails. 
+ """ url, cert_path = self.get_request_params() json_params = {"config_name": ["mqtt_config", "s3_config", "ml_ops_config"], "device_send_time": int(time.time() * 1000)} @@ -126,7 +159,8 @@ def fetch_configs(self): ) else: response = requests.post( - url, json=json_params, headers={"content-type": "application/json", "Connection": "close"} + url, json=json_params, headers={ + "content-type": "application/json", "Connection": "close"} ) status_code = response.json().get("code") @@ -140,6 +174,16 @@ def fetch_configs(self): return mqtt_config, s3_config def fetch_web3_configs(self): + """ + Fetch MQTT, Web3, and ML Ops configurations. + + Returns: + dict: MQTT configuration. + dict: Web3 configuration. + + Raises: + Exception: If fetching device configurations fails. + """ url, cert_path = self.get_request_params() json_params = {"config_name": ["mqtt_config", "web3_config", "ml_ops_config"], "device_send_time": int(time.time() * 1000)} @@ -157,7 +201,8 @@ def fetch_web3_configs(self): ) else: response = requests.post( - url, json=json_params, headers={"content-type": "application/json", "Connection": "close"} + url, json=json_params, headers={ + "content-type": "application/json", "Connection": "close"} ) status_code = response.json().get("code") @@ -171,6 +216,17 @@ def fetch_web3_configs(self): return mqtt_config, web3_config def fetch_thetastore_configs(self): + """ + Fetch MQTT, ThetaStore, and ML Ops configurations. + + Returns: + dict: MQTT configuration. + dict: ThetaStore configuration. + + Raises: + Exception: If fetching device configurations fails. + """ + url, cert_path = self.get_request_params() json_params = {"config_name": ["mqtt_config", "thetastore_config", "ml_ops_config"], "device_send_time": int(time.time() * 1000)} @@ -188,7 +244,8 @@ def fetch_thetastore_configs(self): ) else: response = requests.post( - url, json=json_params, headers={"content-type": "application/json", "Connection": "close"} + url, json=json_params, headers={ + "content-type": "application/json", "Connection": "close"} ) status_code = response.json().get("code") @@ -202,6 +259,18 @@ def fetch_thetastore_configs(self): return mqtt_config, thetastore_config def fetch_all_configs(self): + """ + Fetch all configurations including MQTT, S3, ML Ops, and Docker configurations. + + Returns: + dict: MQTT configuration. + dict: S3 configuration. + dict: ML Ops configuration. + dict: Docker configuration. + + Raises: + Exception: If fetching device configurations fails. + """ url, cert_path = self.get_request_params() json_params = { "config_name": ["mqtt_config", "s3_config", "ml_ops_config", "docker_config"], @@ -221,7 +290,8 @@ def fetch_all_configs(self): ) else: response = requests.post( - url, json=json_params, headers={"content-type": "application/json", "Connection": "close"} + url, json=json_params, headers={ + "content-type": "application/json", "Connection": "close"} ) status_code = response.json().get("code") @@ -238,6 +308,21 @@ def fetch_all_configs(self): @staticmethod def fetch_all_configs_with_version(version): + """ + Fetch all configurations with a specific version. + + Args: + version (str): The version to fetch configurations for. + + Returns: + dict: MQTT configuration. + dict: S3 configuration. + dict: ML Ops configuration. + dict: Docker configuration. + + Raises: + Exception: If fetching device configurations fails. 
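+        # Request sketch: POST {"config_name": [...], "device_send_time": <ms>}
+        # to the fetch URL; the JSON reply carries a "code" status field plus the
+        # requested mqtt_config and s3_config sections.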
+ """ url = "https://open{}.fedml.ai/fedmlOpsServer/configs/fetch".format( "" if version == "release" else "-"+version) cert_path = None @@ -265,7 +350,8 @@ def fetch_all_configs_with_version(version): ) else: response = requests.post( - url, json=json_params, headers={"content-type": "application/json", "Connection": "close"} + url, json=json_params, headers={ + "content-type": "application/json", "Connection": "close"} ) status_code = response.json().get("code") diff --git a/python/fedml/core/mlops/mlops_device_perfs.py b/python/fedml/core/mlops/mlops_device_perfs.py index dc747f338a..00ebc4360f 100644 --- a/python/fedml/core/mlops/mlops_device_perfs.py +++ b/python/fedml/core/mlops/mlops_device_perfs.py @@ -18,6 +18,12 @@ class MLOpsDevicePerfStats(object): + """ + Class for reporting device performance statistics to MLOps. + + This class handles the reporting of device performance statistics to MLOps using MQTT. + """ + def __init__(self): self.device_realtime_stats_process = None self.device_realtime_stats_event = None @@ -28,23 +34,61 @@ def __init__(self): self.is_client = True def report_device_realtime_stats(self, sys_args): + """ + Report device real-time statistics to MLOps. + + Args: + sys_args: The system arguments passed to the device. + + This method sets up and starts a process to report real-time device statistics to MLOps. + + Returns: + None + """ + self.setup_realtime_stats_process(sys_args) def stop_device_realtime_stats(self): + """ + Stop reporting device real-time statistics. + + This method sets the event to stop reporting device real-time statistics. + + Returns: + None + """ if self.device_realtime_stats_event is not None: self.device_realtime_stats_event.set() def should_stop_device_realtime_stats(self): + """ + Check if reporting of device real-time statistics should stop. + + Returns: + bool: True if reporting should stop, False otherwise. + """ if self.device_realtime_stats_event is not None and self.device_realtime_stats_event.is_set(): return True return False def setup_realtime_stats_process(self, sys_args): + """ + Set up the process for reporting real-time device statistics. + + Args: + sys_args: The system arguments passed to the device. + + This method sets up the process for reporting real-time device statistics to MLOps. + + Returns: + None + """ perf_stats = MLOpsDevicePerfStats() perf_stats.args = sys_args perf_stats.edge_id = getattr(sys_args, "edge_id", None) - perf_stats.edge_id = getattr(sys_args, "client_id", None) if perf_stats.edge_id is None else perf_stats.edge_id + perf_stats.edge_id = getattr( + sys_args, "client_id", None) if perf_stats.edge_id is None else perf_stats.edge_id perf_stats.edge_id = 0 if perf_stats.edge_id is None else perf_stats.edge_id perf_stats.device_id = getattr(sys_args, "device_id", 0) perf_stats.run_id = getattr(sys_args, "run_id", 0) @@ -60,8 +104,17 @@ def setup_realtime_stats_process(self, sys_args): self.device_realtime_stats_process.start() def report_device_realtime_stats_entry(self, sys_event): - print(f"Report device realtime stats, process id {os.getpid()}") + """ + Entry point for reporting real-time device statistics. + + Args: + sys_event: The system event used to control reporting. + This method is the entry point for reporting real-time device statistics to MLOps. 
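+        # e.g. version "release" -> https://open.fedml.ai/fedmlOpsServer/configs/fetch,
+        #      version "dev"     -> https://open-dev.fedml.ai/fedmlOpsServer/configs/fetch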
+ + Returns: + None + """ self.device_realtime_stats_event = sys_event mqtt_mgr = MqttManager( self.args.mqtt_config_path["BROKER_HOST"], @@ -69,7 +122,8 @@ def report_device_realtime_stats_entry(self, sys_event): self.args.mqtt_config_path["MQTT_USER"], self.args.mqtt_config_path["MQTT_PWD"], 180, - "FedML_Metrics_DevicePerf_{}_{}_{}".format(str(self.args.device_id), str(self.edge_id), str(uuid.uuid4())) + "FedML_Metrics_DevicePerf_{}_{}_{}".format( + str(self.args.device_id), str(self.edge_id), str(uuid.uuid4())) ) mqtt_mgr.connect() mqtt_mgr.loop_start() @@ -80,9 +134,11 @@ def report_device_realtime_stats_entry(self, sys_event): # Notify MLOps with system information. while not self.should_stop_device_realtime_stats(): try: - MLOpsDevicePerfStats.report_gpu_device_info(self.edge_id, mqtt_mgr=mqtt_mgr) + MLOpsDevicePerfStats.report_gpu_device_info( + self.edge_id, mqtt_mgr=mqtt_mgr) except Exception as e: - logging.debug("exception when reporting device pref: {}.".format(traceback.format_exc())) + logging.debug("exception when reporting device pref: {}.".format( + traceback.format_exc())) pass time.sleep(10) @@ -97,6 +153,18 @@ def report_device_realtime_stats_entry(self, sys_event): @staticmethod def report_gpu_device_info(edge_id, mqtt_mgr=None): + """ + Report GPU device information to MLOps. + + Args: + edge_id: The ID of the edge device. + mqtt_mgr: The MQTT manager for communication. + + This method reports GPU device information to MLOps using MQTT. + + Returns: + None + """ total_mem, free_mem, total_disk_size, free_disk_size, cup_utilization, cpu_cores, gpu_cores_total, \ gpu_cores_available, sent_bytes, recv_bytes, gpu_available_ids = sys_utils.get_sys_realtime_stats() diff --git a/python/fedml/core/mlops/mlops_job_perfs.py b/python/fedml/core/mlops/mlops_job_perfs.py index e7405bc762..6eb3a9c659 100644 --- a/python/fedml/core/mlops/mlops_job_perfs.py +++ b/python/fedml/core/mlops/mlops_job_perfs.py @@ -15,6 +15,9 @@ class MLOpsJobPerfStats(object): def __init__(self): + """ + Initialize MLOpsJobPerfStats object. + """ self.job_stats_process = None self.job_stats_event = None self.args = None @@ -25,11 +28,28 @@ def __init__(self): self.job_stats_obj_map = dict() def add_job(self, job_id, process_id): + """ + Add a job to be tracked for performance statistics. + + Args: + job_id (str): The ID of the job. + process_id (int): The process ID of the job. + """ self.job_process_id_map[job_id] = process_id @staticmethod def report_system_metric(run_id, edge_id, metric_json=None, mqtt_mgr=None, sys_stats_obj=None): + """ + Report system performance metrics to MLOps. + + Args: + run_id (int): The ID of the run. + edge_id (int): The ID of the edge device. + metric_json (dict, optional): The system performance metrics in JSON format. + mqtt_mgr (MqttManager, optional): The MQTT manager for communication. + sys_stats_obj (SysStats, optional): The SysStats object for collecting system stats. + """ # if not self.comm_sanity_check(): # return topic_name = "fl_client/mlops/system_performance" @@ -91,23 +111,40 @@ def report_system_metric(run_id, edge_id, metric_json=None, mqtt_mgr.send_message_json(topic_name, message_json) def stop_job_stats(self): + """ + Stop tracking job performance statistics. + """ + if self.job_stats_event is not None: self.job_stats_event.set() def should_stop_job_stats(self): + """ + Check if job performance statistics tracking should be stopped. + + Returns: + bool: True if job performance statistics tracking should be stopped, otherwise False. 
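+        Note:
+            Runs inside the spawned stats process: it opens its own MQTT
+            connection and reports roughly every 10 seconds until the stop
+            event is set.
+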
+ """ if self.job_stats_event is not None and self.job_stats_event.is_set(): return True return False def setup_job_stats_process(self, sys_args): + """ + Set up the process for tracking job performance statistics. + + Args: + sys_args (object): The system arguments. + """ if self.job_stats_process is not None and psutil.pid_exists(self.job_stats_process.pid): return perf_stats = MLOpsJobPerfStats() perf_stats.args = sys_args perf_stats.edge_id = getattr(sys_args, "edge_id", None) - perf_stats.edge_id = getattr(sys_args, "client_id", None) if perf_stats.edge_id is None else perf_stats.edge_id + perf_stats.edge_id = getattr( + sys_args, "client_id", None) if perf_stats.edge_id is None else perf_stats.edge_id perf_stats.edge_id = 0 if perf_stats.edge_id is None else perf_stats.edge_id perf_stats.device_id = getattr(sys_args, "device_id", 0) perf_stats.run_id = getattr(sys_args, "run_id", 0) @@ -122,11 +159,21 @@ def setup_job_stats_process(self, sys_args): self.job_stats_process.start() def report_job_stats(self, sys_args): + """ + Report job performance statistics. + + Args: + sys_args (object): The system arguments. + """ self.setup_job_stats_process(sys_args) def report_job_stats_entry(self, sys_event): - print(f"Report job realtime stats, process id {os.getpid()}") + """ + Report job performance statistics entry point. + Args: + sys_event (multiprocessing.Event): The system event for signaling the process. + """ self.job_stats_event = sys_event mqtt_mgr = MqttManager( self.args.mqtt_config_path["BROKER_HOST"], @@ -134,7 +181,8 @@ def report_job_stats_entry(self, sys_event): self.args.mqtt_config_path["MQTT_USER"], self.args.mqtt_config_path["MQTT_PWD"], 180, - "FedML_Metrics_JobPerf_{}_{}_{}".format(str(self.device_id), str(self.edge_id), str(uuid.uuid4())) + "FedML_Metrics_JobPerf_{}_{}_{}".format( + str(self.device_id), str(self.edge_id), str(uuid.uuid4())) ) mqtt_mgr.connect() mqtt_mgr.loop_start() @@ -144,13 +192,15 @@ def report_job_stats_entry(self, sys_event): for job_id, process_id in self.job_process_id_map.items(): try: if self.job_stats_obj_map.get(job_id, None) is None: - self.job_stats_obj_map[job_id] = SysStats(process_id=process_id) + self.job_stats_obj_map[job_id] = SysStats( + process_id=process_id) MLOpsJobPerfStats.report_system_metric(job_id, self.edge_id, mqtt_mgr=mqtt_mgr, sys_stats_obj=self.job_stats_obj_map[job_id]) except Exception as e: - logging.debug("exception when reporting job pref: {}.".format(traceback.format_exc())) + logging.debug("exception when reporting job pref: {}.".format( + traceback.format_exc())) pass time.sleep(10) diff --git a/python/fedml/core/mlops/mlops_metrics.py b/python/fedml/core/mlops/mlops_metrics.py index 111fc77c01..662c0234e0 100644 --- a/python/fedml/core/mlops/mlops_metrics.py +++ b/python/fedml/core/mlops/mlops_metrics.py @@ -13,6 +13,17 @@ class MLOpsMetrics(object): def __new__(cls, *args, **kw): + """ + Create a singleton instance of MLOpsMetrics. + + Args: + cls: The class. + *args: Variable-length argument list. + **kw: Keyword arguments. + + Returns: + MLOpsMetrics: The MLOpsMetrics instance. + """ if not hasattr(cls, "_instance"): orig = super(MLOpsMetrics, cls) cls._instance = orig.__new__(cls, *args, **kw) @@ -20,9 +31,16 @@ def __new__(cls, *args, **kw): return cls._instance def __init__(self): + """ + Initialize the MLOpsMetrics object. + """ + pass def init(self): + """ + Initialize the MLOpsMetrics object attributes. 
+ """ self.messenger = None self.args = None self.run_id = None @@ -35,6 +53,13 @@ def init(self): self.device_perfs = MLOpsDevicePerfStats() def set_messenger(self, msg_messenger, args=None): + """ + Set the messenger for communication. + + Args: + msg_messenger: The message messenger. + args: The system arguments. + """ self.messenger = msg_messenger if args is not None: self.args = args @@ -62,6 +87,12 @@ def set_messenger(self, msg_messenger, args=None): self.server_agent_id = self.edge_id def comm_sanity_check(self): + """ + Check if communication is set up properly. + + Returns: + bool: True if communication is set up, otherwise False. + """ if self.messenger is None: logging.info("self.messenger is Null") return False @@ -69,6 +100,16 @@ def comm_sanity_check(self): return True def report_client_training_status(self, edge_id, status, running_json=None, is_from_model=False, in_run_id=None): + """ + Report client training status to various components. + + Args: + edge_id: The ID of the edge device. + status: The status of the training. + running_json: The running JSON information. + is_from_model: Whether the report is from the model. + in_run_id: The run ID. + """ run_id = 0 if self.run_id is not None: run_id = self.run_id @@ -84,14 +125,20 @@ def report_client_training_status(self, edge_id, status, running_json=None, is_f if is_from_model: from ...computing.scheduler.model_scheduler.device_client_data_interface import FedMLClientDataInterface - FedMLClientDataInterface.get_instance().save_job(run_id, edge_id, status, running_json) + FedMLClientDataInterface.get_instance().save_job( + run_id, edge_id, status, running_json) else: from ...computing.scheduler.slave.client_data_interface import FedMLClientDataInterface - FedMLClientDataInterface.get_instance().save_job(run_id, edge_id, status, running_json) + FedMLClientDataInterface.get_instance().save_job( + run_id, edge_id, status, running_json) def report_client_device_status_to_web_ui(self, edge_id, status): """ - this is used for notifying the client device status to MLOps Frontend + Report the client device status to MLOps Frontend. + + Args: + edge_id (int): The ID of the edge device. + status (str): The status of the client device. """ if status == ClientConstants.MSG_MLOPS_CLIENT_STATUS_IDLE: return @@ -100,9 +147,11 @@ def report_client_device_status_to_web_ui(self, edge_id, status): if self.run_id is not None: run_id = self.run_id topic_name = "fl_client/mlops/status" - msg = {"edge_id": edge_id, "run_id": run_id, "status": status, "version": "v1.0"} + msg = {"edge_id": edge_id, "run_id": run_id, + "status": status, "version": "v1.0"} message_json = json.dumps(msg) - logging.info("report_client_device_status. message_json = %s" % message_json) + logging.info( + "report_client_device_status. message_json = %s" % message_json) MLOpsStatus.get_instance().set_client_status(edge_id, status) self.messenger.send_message_json(topic_name, message_json) @@ -111,7 +160,11 @@ def common_report_client_training_status(self, edge_id, status): # logging.info("comm_sanity_check at report_client_training_status.") # return """ - this is used for notifying the client status to MLOps (both FedML CLI and backend can consume it) + Common method for reporting client training status to MLOps. + + Args: + edge_id (int): The ID of the edge device. + status (str): The status of the client device. 
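+        # Identity fields stay None until set_messenger() injects the runtime
+        # args; the perf reporters below are created eagerly.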
""" run_id = 0 if self.run_id is not None: @@ -127,7 +180,12 @@ def broadcast_client_training_status(self, edge_id, status, is_from_model=False) # if not self.comm_sanity_check(): # return """ - this is used for broadcasting the client status to MLOps (backend can consume it) + Broadcast client training status to MLOps. + + Args: + edge_id (int): The ID of the edge device. + status (str): The status of the client device. + is_from_model (bool): Whether the report is from the model. """ run_id = 0 if self.run_id is not None: @@ -147,7 +205,11 @@ def common_broadcast_client_training_status(self, edge_id, status): # if not self.comm_sanity_check(): # return """ - this is used for broadcasting the client status to MLOps (backend can consume it) + Common method for broadcasting client training status to MLOps. + + Args: + edge_id (int): The ID of the edge device. + status (str): The status of the client device. """ run_id = 0 if self.run_id is not None: @@ -155,22 +217,43 @@ def common_broadcast_client_training_status(self, edge_id, status): topic_name = "fl_run/fl_client/mlops/status" msg = {"edge_id": edge_id, "run_id": run_id, "status": status} message_json = json.dumps(msg) - logging.info("report_client_training_status. message_json = %s" % message_json) + logging.info( + "report_client_training_status. message_json = %s" % message_json) self.messenger.send_message_json(topic_name, message_json) def client_send_exit_train_msg(self, run_id, edge_id, status, msg=None): - topic_exit_train_with_exception = "flserver_agent/" + str(run_id) + "/client_exit_train_with_exception" - msg = {"run_id": run_id, "edge_id": edge_id, "status": status, "msg": msg if msg is not None else ""} + """ + Send an exit train message for a client. + + Args: + run_id (int): The ID of the training run. + edge_id (int): The ID of the edge device. + status (str): The status of the client. + msg (str, optional): Additional message (default is None). + """ + topic_exit_train_with_exception = "flserver_agent/" + \ + str(run_id) + "/client_exit_train_with_exception" + msg = {"run_id": run_id, "edge_id": edge_id, + "status": status, "msg": msg if msg is not None else ""} message_json = json.dumps(msg) logging.info("client_send_exit_train_msg.") - self.messenger.send_message_json(topic_exit_train_with_exception, message_json) + self.messenger.send_message_json( + topic_exit_train_with_exception, message_json) def report_client_id_status(self, run_id, edge_id, status, running_json=None, is_from_model=False, server_id="0"): # if not self.comm_sanity_check(): # return """ - this is used for communication between client agent (FedML cli module) and client + Report client ID status to MLOps. + + Args: + run_id (int): The ID of the training run. + edge_id (int): The ID of the edge device. + status (str): The status of the client. + running_json: JSON information about the running state (default is None). + is_from_model (bool): Whether the report is from the model (default is False). + server_id (str): The ID of the server (default is "0"). 
""" self.common_report_client_id_status(run_id, edge_id, status, server_id) @@ -178,24 +261,43 @@ def report_client_id_status(self, run_id, edge_id, status, running_json=None, if is_from_model: from ...computing.scheduler.model_scheduler.device_client_data_interface import FedMLClientDataInterface - FedMLClientDataInterface.get_instance().save_job(run_id, edge_id, status, running_json) + FedMLClientDataInterface.get_instance().save_job( + run_id, edge_id, status, running_json) else: from ...computing.scheduler.slave.client_data_interface import FedMLClientDataInterface - FedMLClientDataInterface.get_instance().save_job(run_id, edge_id, status, running_json) + FedMLClientDataInterface.get_instance().save_job( + run_id, edge_id, status, running_json) def common_report_client_id_status(self, run_id, edge_id, status, server_id="0"): # if not self.comm_sanity_check(): # return """ - this is used for communication between client agent (FedML cli module) and client + Common method for reporting client ID status to MLOps. + + Args: + run_id (int): The ID of the training run. + edge_id (int): The ID of the edge device. + status (str): The status of the client device. + server_id (str): The ID of the server (default is "0"). """ topic_name = "fl_client/flclient_agent_" + str(edge_id) + "/status" - msg = {"run_id": run_id, "edge_id": edge_id, "status": status, "server_id": server_id} + msg = {"run_id": run_id, "edge_id": edge_id, + "status": status, "server_id": server_id} message_json = json.dumps(msg) # logging.info("report_client_id_status. message_json = %s" % message_json) self.messenger.send_message_json(topic_name, message_json) def report_server_training_status(self, run_id, status, role=None, running_json=None, is_from_model=False): + """ + Report server training status to MLOps. + + Args: + run_id (int): The ID of the training run. + status (str): The status of the server. + role (str, optional): The role of the server (default is None). + running_json: JSON information about the running state (default is None). + is_from_model (bool): Whether the report is from the model (default is False). + """ # if not self.comm_sanity_check(): # return self.common_report_server_training_status(run_id, status, role) @@ -204,14 +306,21 @@ def report_server_training_status(self, run_id, status, role=None, running_json= if is_from_model: from ...computing.scheduler.model_scheduler.device_server_data_interface import FedMLServerDataInterface - FedMLServerDataInterface.get_instance().save_job(run_id, self.edge_id, status, running_json) + FedMLServerDataInterface.get_instance().save_job( + run_id, self.edge_id, status, running_json) else: from ...computing.scheduler.master.server_data_interface import FedMLServerDataInterface - FedMLServerDataInterface.get_instance().save_job(run_id, self.edge_id, status, running_json) + FedMLServerDataInterface.get_instance().save_job( + run_id, self.edge_id, status, running_json) def report_server_device_status_to_web_ui(self, run_id, status, role=None): """ - this is used for notifying the server device status to MLOps Frontend + Report server device status to MLOps Frontend. + + Args: + run_id (int): The ID of the training run. + status (str): The status of the server device. + role (str, optional): The role of the server (default is None). 
""" if status == ServerConstants.MSG_MLOPS_DEVICE_STATUS_IDLE: return @@ -232,6 +341,14 @@ def report_server_device_status_to_web_ui(self, run_id, status, role=None): self.messenger.send_message_json(topic_name, message_json) def common_report_server_training_status(self, run_id, status, role=None): + """ + Common method for reporting server training status to MLOps. + + Args: + run_id (int): The ID of the training run. + status (str): The status of the server. + role (str, optional): The role of the server (default is None). + """ # if not self.comm_sanity_check(): # return topic_name = "fl_run/fl_server/mlops/status" @@ -250,6 +367,16 @@ def common_report_server_training_status(self, run_id, status, role=None): self.report_server_id_status(run_id, status) def broadcast_server_training_status(self, run_id, status, role=None, is_from_model=False, edge_id=None): + """ + Broadcast server training status to MLOps. + + Args: + run_id (int): The ID of the training run. + status (str): The status of the server. + role (str, optional): The role of the server (default is None). + is_from_model (bool): Whether the report is from the model (default is False). + edge_id (int, optional): The ID of the edge device (default is None). + """ if self.messenger is None: return topic_name = "fl_run/fl_server/mlops/status" @@ -275,37 +402,71 @@ def broadcast_server_training_status(self, run_id, status, role=None, is_from_mo FedMLServerDataInterface.get_instance().save_job(run_id, self.edge_id, status) def report_server_id_status(self, run_id, status, edge_id=None, server_id=None, server_agent_id=None): + """ + Report server ID status to MLOps. + + Args: + run_id (int): The ID of the training run. + status (str): The status of the server. + edge_id (int, optional): The ID of the edge device (default is None). + server_id (str, optional): The ID of the server (default is None). + server_agent_id (int, optional): The ID of the server agent (default is None). + """ # if not self.comm_sanity_check(): # return topic_name = "fl_server/flserver_agent_" + str(server_agent_id if server_agent_id is not None else self.server_agent_id) + "/status" - msg = {"run_id": run_id, "edge_id": edge_id if edge_id is not None else self.edge_id, "status": status} + msg = {"run_id": run_id, + "edge_id": edge_id if edge_id is not None else self.edge_id, "status": status} if server_id is not None: msg["server_id"] = server_id message_json = json.dumps(msg) # logging.info("report_server_id_status server id {}".format(server_agent_id)) - logging.info("report_server_id_status. message_json = %s" % message_json) + logging.info("report_server_id_status. message_json = %s" % + message_json) self.messenger.send_message_json(topic_name, message_json) self.report_server_device_status_to_web_ui(run_id, status) def report_client_training_metric(self, metric_json): + """ + Report client training metrics to MLOps. + + Args: + metric_json (dict): JSON containing client training metrics. + """ + # if not self.comm_sanity_check(): # return topic_name = "fl_client/mlops/training_metrics" - logging.info("report_client_training_metric. message_json = %s" % metric_json) + logging.info( + "report_client_training_metric. message_json = %s" % metric_json) message_json = json.dumps(metric_json) self.messenger.send_message_json(topic_name, message_json) def report_server_training_metric(self, metric_json): + """ + Report server training metrics to MLOps. + + Args: + metric_json (dict): JSON containing server training metrics. 
+    """
         # if not self.comm_sanity_check():
         #     return
         topic_name = "fl_server/mlops/training_progress_and_eval"
-        logging.info("report_server_training_metric. message_json = %s" % metric_json)
+        logging.info(
+            "report_server_training_metric. message_json = %s" % metric_json)
         message_json = json.dumps(metric_json)
         self.messenger.send_message_json(topic_name, message_json)
 
     def report_server_training_round_info(self, round_info):
+        """
+        Report server training round information to MLOps.
+
+        Args:
+            round_info (dict): JSON containing server training round information.
+        """
         # if not self.comm_sanity_check():
         #     return
         topic_name = "fl_server/mlops/training_roundx"
@@ -313,6 +474,12 @@ def report_server_training_round_info(self, round_info):
         self.messenger.send_message_json(topic_name, message_json)
 
     def report_client_model_info(self, model_info_json):
+        """
+        Report client model information to MLOps.
+
+        Args:
+            model_info_json (dict): JSON containing client model information.
+        """
         # if not self.comm_sanity_check():
         #     return
         topic_name = "fl_server/mlops/client_model"
@@ -320,6 +487,13 @@ def report_client_model_info(self, model_info_json):
         self.messenger.send_message_json(topic_name, message_json)
 
     def report_aggregated_model_info(self, model_info_json):
+        """
+        Report aggregated model information to MLOps.
+
+        Args:
+            model_info_json (dict): JSON containing aggregated model information.
+        """
         # if not self.comm_sanity_check():
         #     return
         topic_name = "fl_server/mlops/global_aggregated_model"
@@ -327,6 +501,12 @@ def report_aggregated_model_info(self, model_info_json):
         self.messenger.send_message_json(topic_name, message_json)
 
     def report_training_model_net_info(self, model_net_info_json):
+        """
+        Report training model network information to MLOps.
+
+        Args:
+            model_net_info_json (dict): JSON containing training model network information.
+        """
         # if not self.comm_sanity_check():
         #     return
         topic_name = "fl_server/mlops/training_model_net"
@@ -334,6 +514,12 @@ def report_training_model_net_info(self, model_net_info_json):
         self.messenger.send_message_json(topic_name, message_json)
 
     def report_llm_record(self, metric_json):
+        """
+        Report a large language model (LLM) input-output record to MLOps.
+
+        Args:
+            metric_json (dict): JSON containing the large language model input-output record.
+        """
         # if not self.comm_sanity_check():
         #     return
         topic_name = "model_serving/mlops/llm_input_output_record"
@@ -345,8 +531,17 @@ def report_edge_job_computing_cost(self, job_id, edge_id,
                                        computing_started_time, computing_ended_time,
                                        user_id, api_key):
         """
-        this is used for reporting the computing cost of a job running on an edge to MLOps
+        Report the computing cost of a job running on an edge to MLOps.
+
+        Args:
+            job_id (str): The ID of the job.
+            edge_id (str): The ID of the edge device.
+            computing_started_time (float): The timestamp when computing started.
+            computing_ended_time (float): The timestamp when computing ended.
+            user_id (str): The user ID.
+            api_key (str): The API key.
         """
         topic_name = "ml_client/mlops/job_computing_cost"
         duration = computing_ended_time - computing_started_time
         if duration < 0:
@@ -360,6 +555,12 @@ def report_edge_job_computing_cost(self, job_id, edge_id,
         # logging.info("report_job_computing_cost. message_json = %s" % message_json)
 
     def report_logs_updated(self, run_id):
+        """
+        Report that runtime logs have been updated to MLOps.
+
+        Args:
+            run_id (int): The ID of the training run.
+        """
         # if not self.comm_sanity_check():
         #     return
         topic_name = "mlops/runtime_logs/" + str(run_id)
@@ -372,6 +573,20 @@ def report_artifact_info(self, job_id, edge_id, artifact_name, artifact_type,
                              artifact_local_path, artifact_url,
                              artifact_ext_info, artifact_desc,
                              timestamp):
+        """
+        Report artifact information to MLOps.
+
+        Args:
+            job_id (str): The ID of the job associated with the artifact.
+            edge_id (str): The ID of the edge device where the artifact is generated.
+            artifact_name (str): The name of the artifact.
+            artifact_type (str): The type of the artifact.
+            artifact_local_path (str): The local path to the artifact.
+            artifact_url (str): The URL of the artifact.
+            artifact_ext_info (dict): Additional information about the artifact.
+            artifact_desc (str): A description of the artifact.
+            timestamp (float): The timestamp when the artifact was generated.
+        """
         topic_name = "launch_device/mlops/artifacts"
         artifact_info_json = {
             "job_id": job_id,
@@ -388,31 +603,69 @@ def report_artifact_info(self, job_id, edge_id, artifact_name, artifact_type,
         self.messenger.send_message_json(topic_name, message_json)
 
     def report_sys_perf(self, sys_args, mqtt_config):
+        """
+        Report system performance metrics to MLOps.
+
+        Args:
+            sys_args (object): System arguments object containing performance metrics.
+            mqtt_config (str): Path to the MQTT configuration.
+        """
         setattr(sys_args, "mqtt_config_path", mqtt_config)
         run_id = getattr(sys_args, "run_id", 0)
         self.fl_job_perf.add_job(run_id, os.getpid())
         self.fl_job_perf.report_job_stats(sys_args)
 
     def stop_sys_perf(self):
+        """
+        Stop reporting system performance metrics to MLOps.
+        """
         self.fl_job_perf.stop_job_stats()
 
     def report_job_perf(self, sys_args, mqtt_config, job_process_id):
+        """
+        Report job performance metrics to MLOps.
+
+        Args:
+            sys_args (object): System arguments object containing job performance metrics.
+            mqtt_config (str): Path to the MQTT configuration.
+            job_process_id (int): The process ID of the job.
+        """
         setattr(sys_args, "mqtt_config_path", mqtt_config)
         run_id = getattr(sys_args, "run_id", 0)
         self.job_perfs.add_job(run_id, job_process_id)
         self.job_perfs.report_job_stats(sys_args)
 
     def stop_job_perf(self):
+        """
+        Stop reporting job performance metrics to MLOps.
+        """
         self.job_perfs.stop_job_stats()
 
     def report_device_realtime_perf(self, sys_args, mqtt_config, is_client=True):
+        """
+        Report real-time device performance metrics to MLOps.
+
+        Args:
+            sys_args (object): System arguments object containing real-time device performance metrics.
+            mqtt_config (str): Path to the MQTT configuration.
+            is_client (bool, optional): Whether the reporting process runs on a client device (default is True).
+        """
         setattr(sys_args, "mqtt_config_path", mqtt_config)
         self.device_perfs.is_client = is_client
         self.device_perfs.report_device_realtime_stats(sys_args)
 
     def stop_device_realtime_perf(self):
+        """
+        Stop reporting real-time device performance metrics to MLOps.
+        """
         self.device_perfs.stop_device_realtime_stats()
 
     def report_json_message(self, topic, payload):
-        self.messenger.send_message_json(topic, payload)
+        """
+        Report a JSON message to a specified topic.
+
+        Args:
+            topic (str): The MQTT topic to publish the message to.
+            payload (dict): The JSON payload to be sent.
+        """
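+        # Thin pass-through: the caller supplies both the MQTT topic and the
+        # fully-formed payload; no validation is performed here.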
+ """ event_topic = "mlops/events" event_msg = {} if event_type == MLOpsProfilerEvent.EVENT_TYPE_STARTED: diff --git a/python/fedml/core/mlops/mlops_runtime_log.py b/python/fedml/core/mlops/mlops_runtime_log.py index 7ebcc43ade..93b2cea126 100644 --- a/python/fedml/core/mlops/mlops_runtime_log.py +++ b/python/fedml/core/mlops/mlops_runtime_log.py @@ -33,7 +33,8 @@ def handle_exception(exc_type, exc_value, exc_traceback): sys.__excepthook__(exc_type, exc_value, exc_traceback) return - logging.error("Uncaught exception", exc_info=(exc_type, exc_value, exc_traceback)) + logging.error("Uncaught exception", exc_info=( + exc_type, exc_value, exc_traceback)) if MLOpsRuntimeLog._log_sdk_instance is not None and \ hasattr(MLOpsRuntimeLog._log_sdk_instance, "args") and \ @@ -48,6 +49,23 @@ def handle_exception(exc_type, exc_value, exc_traceback): mlops.send_exit_train_msg() def __init__(self, args): + """ + Initialize the MLOpsRuntimeLog. + + Args: + args: Input arguments. + + Attributes: + logger: Logger instance for logging. + args: Input arguments. + should_write_log_file: Boolean indicating whether log files should be written. + log_file_dir: Directory where log files are stored. + log_file: File handle for the log file. + run_id: The ID of the current run. + edge_id: The ID of the edge device (server or client). + origin_log_file_path: Path to the original log file. + + """ self.logger = None self.args = args if hasattr(args, "using_mlops"): @@ -92,13 +110,31 @@ def __init__(self, args): @staticmethod def get_instance(args): + """ + Get an instance of the MLOpsRuntimeLog. + + Args: + args: Input arguments. + + Returns: + MLOpsRuntimeLog: An instance of the log handler. + + """ if MLOpsRuntimeLog._log_sdk_instance is None: MLOpsRuntimeLog._log_sdk_instance = MLOpsRuntimeLog(args) return MLOpsRuntimeLog._log_sdk_instance def init_logs(self, show_stdout_log=True): - log_file_path, program_prefix = MLOpsRuntimeLog.build_log_file_path(self.args) + """ + Initialize logging. + + Args: + show_stdout_log (bool): Flag to control whether to show log messages on stdout. + + """ + log_file_path, program_prefix = MLOpsRuntimeLog.build_log_file_path( + self.args) logging.raiseExceptions = True self.logger = logging.getLogger(log_file_path) @@ -118,7 +154,8 @@ def formatTime(self, record, datefmt=None): if self.ntp_offset is None: self.ntp_offset = 0.0 - log_ntp_time = int((log_time * 1000 + self.ntp_offset) / 1000.0) + log_ntp_time = int( + (log_time * 1000 + self.ntp_offset) / 1000.0) ct = self.converter(log_ntp_time) if datefmt: s = ct.strftime(datefmt) @@ -156,6 +193,17 @@ def formatTime(self, record, datefmt=None): @staticmethod def build_log_file_path(in_args): + """ + Build the log file path based on input arguments. + + Args: + in_args: Input arguments. + + Returns: + str: Log file path. + str: Program prefix. 
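+        # All profiler events share the "mlops/events" topic; event_type only
+        # changes how the message body is populated (started vs. ended).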
+ + """ if in_args.role == "server": if hasattr(in_args, "server_id"): edge_id = in_args.server_id @@ -182,7 +230,8 @@ def build_log_file_path(in_args): edge_id = in_args.edge_id else: edge_id = 0 - program_prefix = "FedML-Client @device-id-{edge}".format(edge=edge_id) + program_prefix = "FedML-Client @device-id-{edge}".format( + edge=edge_id) if not os.path.exists(in_args.log_file_dir): os.makedirs(in_args.log_file_dir, exist_ok=True) @@ -196,7 +245,8 @@ def build_log_file_path(in_args): if __name__ == "__main__": - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--log_file_dir", "-log", help="log file dir") parser.add_argument("--run_id", "-ri", type=str, help='run id') diff --git a/python/fedml/core/mlops/mlops_runtime_log_daemon.py b/python/fedml/core/mlops/mlops_runtime_log_daemon.py index 61f60e7b64..dc4fa1c9ea 100644 --- a/python/fedml/core/mlops/mlops_runtime_log_daemon.py +++ b/python/fedml/core/mlops/mlops_runtime_log_daemon.py @@ -19,13 +19,25 @@ class MLOpsRuntimeLogProcessor: FEDML_RUN_LOG_STATUS_DIR = "run_log_status" def __init__(self, using_mlops, log_run_id, log_device_id, log_file_dir, log_server_url, in_args=None): + """ + Initialize the MLOpsRuntimeLogProcessor. + + Args: + using_mlops: Whether MLOps is being used. + log_run_id: The ID of the log run. + log_device_id: The ID of the log device. + log_file_dir: The directory where log files are stored. + log_server_url: The URL of the log server. + in_args: Input arguments (system configuration). + """ self.args = in_args self.is_log_reporting = False self.log_reporting_status_file = os.path.join(log_file_dir, MLOpsRuntimeLogProcessor.FEDML_RUN_LOG_STATUS_DIR, MLOpsRuntimeLogProcessor.FEDML_LOG_REPORTING_STATUS_FILE_NAME + "-" + str(log_run_id) + ".conf") - os.makedirs(os.path.join(log_file_dir, MLOpsRuntimeLogProcessor.FEDML_RUN_LOG_STATUS_DIR), exist_ok=True) + os.makedirs(os.path.join( + log_file_dir, MLOpsRuntimeLogProcessor.FEDML_RUN_LOG_STATUS_DIR), exist_ok=True) self.logger = None self.should_upload_log_file = using_mlops self.log_file_dir = log_file_dir @@ -53,12 +65,28 @@ def __init__(self, using_mlops, log_run_id, log_device_id, log_file_dir, log_ser self.log_process_event = None def set_log_source(self, source): + """ + Set the source of the log. + + Args: + source: The source of the log. + """ self.log_source = source if source is not None: self.log_source = str(self.log_source).replace(' ', '') @staticmethod def build_log_file_path(in_args): + """ + Build the log file path based on input arguments. + + Args: + in_args: Input arguments (system configuration). + + Returns: + log_file_path: The path to the log file. + program_prefix: The prefix for the program's log. 
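+        Example:
+            A client with edge id 5 yields the program prefix
+            "FedML-Client @device-id-5".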
+ """ if in_args.rank == 0: if hasattr(in_args, "server_id"): log_device_id = in_args.server_id @@ -67,7 +95,8 @@ def build_log_file_path(in_args): log_device_id = in_args.edge_id else: log_device_id = 0 - program_prefix = "FedML-Server({}) @device-id-{}".format(in_args.rank, log_device_id) + program_prefix = "FedML-Server({}) @device-id-{}".format( + in_args.rank, log_device_id) else: if hasattr(in_args, "client_id"): log_device_id = in_args.client_id @@ -82,7 +111,8 @@ def build_log_file_path(in_args): log_device_id = in_args.edge_id else: log_device_id = 0 - program_prefix = "FedML-Client({}) @device-id-{}".format(in_args.rank, log_device_id) + program_prefix = "FedML-Client({}) @device-id-{}".format( + in_args.rank, log_device_id) if not os.path.exists(in_args.log_file_dir): os.makedirs(in_args.log_file_dir, exist_ok=True) @@ -95,6 +125,13 @@ def build_log_file_path(in_args): return log_file_path, program_prefix def log_upload(self, run_id, device_id): + """ + Upload logs to the log server. + + Args: + run_id: The ID of the run. + device_id: The ID of the device. + """ # read log data from local log file log_lines = self.log_read() if log_lines is None or len(log_lines) <= 0: @@ -131,7 +168,8 @@ def log_upload(self, run_id, device_id): prev_line_prefix_list[2]) if not str(log_lines[index]).startswith('[FedML-'): - log_line = "{} {}".format(prev_line_prefix, log_lines[index]) + log_line = "{} {}".format( + prev_line_prefix, log_lines[index]) log_lines[index] = log_line index += 1 @@ -146,7 +184,8 @@ def log_upload(self, run_id, device_id): for log_index in range(len(upload_lines)): log_line = str(upload_lines[log_index]) if log_line.find(' [ERROR] ') != -1: - err_line_dict = {"errLine": self.log_uploaded_line_index + log_index, "errMsg": log_line} + err_line_dict = { + "errLine": self.log_uploaded_line_index + log_index, "errMsg": log_line} err_list.append(err_line_dict) log_upload_request = { @@ -165,10 +204,12 @@ def log_upload(self, run_id, device_id): if self.log_source is not None and self.log_source != "": log_upload_request["source"] = self.log_source - log_headers = {'Content-Type': 'application/json', 'Connection': 'close'} + log_headers = {'Content-Type': 'application/json', + 'Connection': 'close'} # send log data to the log server - _, cert_path = MLOpsConfigs.get_instance(self.args).get_request_params() + _, cert_path = MLOpsConfigs.get_instance( + self.args).get_request_params() if cert_path is not None: try: requests.session().verify = cert_path @@ -187,7 +228,8 @@ def log_upload(self, run_id, device_id): # logging.info(f"FedMLDebug POST log to server run_id {run_id}, device_id {device_id}. response.status_code: {response.status_code}") else: # logging.info(f"FedMLDebug POST log to server. run_id {run_id}, device_id {device_id}") - response = requests.post(self.log_server_url, headers=log_headers, json=log_upload_request) + response = requests.post( + self.log_server_url, headers=log_headers, json=log_upload_request) # logging.info(f"FedMLDebug POST log to server. run_id {run_id}, device_id {device_id}. response.status_code: {response.status_code}") if response.status_code != 200: pass @@ -201,6 +243,15 @@ def log_upload(self, run_id, device_id): @staticmethod def should_ignore_log_line(log_line): + """ + Determine whether to ignore a log line. + + Args: + log_line: The log line to check. + + Returns: + True if the log line should be ignored, False otherwise. 
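+        # e.g. rank 0 with server_id 17 -> "FedML-Server(0) @device-id-17";
+        # any non-zero rank is treated as a client.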
+ """ # if str is empty, then continue, will move it later if str(log_line) == '' or str(log_line) == '\n': return True @@ -215,8 +266,12 @@ def should_ignore_log_line(log_line): return False def log_process(self, process_event): - print(f"Log uploading process id {os.getpid()}, run id {self.run_id}, edge id {self.device_id}") + """ + Continuously upload log data to the log server. + Args: + process_event: Event object to control the log processing loop. + """ self.log_process_event = process_event while not self.should_stop(): @@ -230,6 +285,9 @@ def log_process(self, process_event): print("FedDebug log_process STOPPED") def log_relocation(self): + """ + Relocate the log file pointer to the last uploaded log line. + """ log_line_count = self.log_line_index self.log_uploaded_line_index = self.log_line_index while log_line_count > 0: @@ -246,6 +304,9 @@ def log_relocation(self): self.log_line_index = 0 def log_open(self): + """ + Open the log file for reading. + """ try: shutil.copyfile(self.origin_log_file_path, self.log_file_path) if self.log_file is None: @@ -255,6 +316,13 @@ def log_open(self): pass def log_read(self): + """ + Read log data from the log file. + + Returns: + log_lines: A list of log lines read from the file. + """ + self.log_open() if self.log_file is None: @@ -274,6 +342,13 @@ def log_read(self): @staticmethod def __generate_yaml_doc(log_config_object, yaml_file): + """ + Generate a YAML document from a configuration object and save it to a file. + + Args: + log_config_object: The configuration object to serialize. + yaml_file: The path to the YAML file to save. + """ try: file = open(yaml_file, "w", encoding="utf-8") yaml.dump(log_config_object, file) @@ -283,7 +358,15 @@ def __generate_yaml_doc(log_config_object, yaml_file): @staticmethod def __load_yaml_config(yaml_path): - """Helper function to load a yaml config file""" + """ + Load a YAML configuration file. + + Args: + yaml_path: The path to the YAML configuration file. + + Returns: + config_data: The loaded configuration data. + """ with open(yaml_path, "r") as stream: try: return yaml.safe_load(stream) @@ -291,23 +374,50 @@ def __load_yaml_config(yaml_path): raise ValueError("Yaml error - check yaml file") def save_log_config(self): + """ + Save the log configuration to a YAML file, including the log line index. + + This method saves the log line index to the log configuration YAML file + for resuming log processing where it left off. + + Raises: + Exception: If there is an error while saving the configuration. + """ try: - log_config_key = "log_config_{}_{}".format(self.run_id, self.device_id) + log_config_key = "log_config_{}_{}".format( + self.run_id, self.device_id) self.log_config[log_config_key] = dict() self.log_config[log_config_key]["log_line_index"] = self.log_line_index - MLOpsRuntimeLogProcessor.__generate_yaml_doc(self.log_config, self.log_config_file) + MLOpsRuntimeLogProcessor.__generate_yaml_doc( + self.log_config, self.log_config_file) except Exception as e: pass def load_log_config(self): + """ + Load the log configuration from a YAML file. + + This method loads the log configuration, including the log line index, + from the log configuration YAML file. + + Raises: + Exception: If there is an error while loading the configuration. 
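+
+        Example:
+            The YAML layout this reads (key names mirror ``save_log_config``
+            above; the values are illustrative):
+
+                log_config_9998_1:
+                    log_line_index: 42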
+ """ try: - log_config_key = "log_config_{}_{}".format(self.run_id, self.device_id) + log_config_key = "log_config_{}_{}".format( + self.run_id, self.device_id) self.log_config = self.__load_yaml_config(self.log_config_file) self.log_line_index = self.log_config[log_config_key]["log_line_index"] except Exception as e: pass def should_stop(self): + """ + Check if the log processing should stop. + + Returns: + bool: True if the log processing should stop; False otherwise. + """ if self.log_process_event is not None and self.log_process_event.is_set(): return True @@ -326,6 +436,22 @@ def __new__(cls, *args, **kwargs): return MLOpsRuntimeLogDaemon._instance def __init__(self, in_args): + """ + Initialize the MLOpsRuntimeLogDaemon. + + Args: + in_args: Input arguments passed to the daemon. + + Attributes: + args: Input arguments. + edge_id: The ID of the edge device (server or client). + log_server_url: The URL for the log server. + log_file_dir: Directory where log files are stored. + log_child_process_list: List to keep track of child log processing processes. + log_child_process: Reference to the child log processing process. + log_process_event: Event to control log processing. + + """ self.args = in_args if in_args.role == "server": @@ -366,16 +492,43 @@ def __init__(self, in_args): @staticmethod def get_instance(args): + """ + Get an instance of the MLOpsRuntimeLogDaemon. + + Args: + args: Input arguments. + + Returns: + MLOpsRuntimeLogDaemon: An instance of the log daemon. + + """ if MLOpsRuntimeLogDaemon._log_sdk_instance is None: - MLOpsRuntimeLogDaemon._log_sdk_instance = MLOpsRuntimeLogDaemon(args) + MLOpsRuntimeLogDaemon._log_sdk_instance = MLOpsRuntimeLogDaemon( + args) MLOpsRuntimeLogDaemon._log_sdk_instance.log_source = None return MLOpsRuntimeLogDaemon._log_sdk_instance def set_log_source(self, source): + """ + Set the source of log messages. + + Args: + source (str): The source of log messages. + + """ self.log_source = source def start_log_processor(self, log_run_id, log_device_id): + """ + Start a log processor for a specific run and device. + + Args: + log_run_id: The ID of the log run. + log_device_id: The ID of the log device. + + """ + log_processor = MLOpsRuntimeLogProcessor(self.args.using_mlops, log_run_id, log_device_id, self.log_file_dir, self.log_server_url, @@ -391,11 +544,21 @@ def start_log_processor(self, log_run_id, log_device_id): if self.log_child_process is not None: self.log_child_process.start() try: - self.log_child_process_list.index((self.log_child_process, log_run_id, log_device_id)) + self.log_child_process_list.index( + (self.log_child_process, log_run_id, log_device_id)) except ValueError as ex: - self.log_child_process_list.append((self.log_child_process, log_run_id, log_device_id)) + self.log_child_process_list.append( + (self.log_child_process, log_run_id, log_device_id)) def stop_log_processor(self, log_run_id, log_device_id): + """ + Stop a log processor for a specific run and device. + + Args: + log_run_id: The ID of the log run. + log_device_id: The ID of the log device. + + """ if log_run_id is None or log_device_id is None: return @@ -409,6 +572,10 @@ def stop_log_processor(self, log_run_id, log_device_id): break def stop_all_log_processor(self): + """ + Stop all running log processors. 
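+
+        Example:
+            A sketch of the daemon lifecycle (run and device IDs are
+            illustrative):
+
+                daemon = MLOpsRuntimeLogDaemon.get_instance(args)
+                daemon.start_log_processor(log_run_id=9998, log_device_id=1)
+                ...
+                daemon.stop_all_log_processor()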
+ + """ for (log_child_process, _, _) in self.log_child_process_list: if self.log_process_event is not None: self.log_process_event.set() @@ -417,11 +584,13 @@ def stop_all_log_processor(self): if __name__ == "__main__": - parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--log_file_dir", "-log", help="log file dir") parser.add_argument("--rank", "-r", type=str, default="1") parser.add_argument("--client_id_list", "-cil", type=str, default="[]") - parser.add_argument("--log_server_url", "-lsu", type=str, default="http://") + parser.add_argument("--log_server_url", "-lsu", + type=str, default="http://") args = parser.parse_args() setattr(args, "using_mlops", True) @@ -429,7 +598,8 @@ def stop_all_log_processor(self): run_id = 9998 device_id = 1 - MLOpsRuntimeLogDaemon.get_instance(args).start_log_processor(run_id, device_id) + MLOpsRuntimeLogDaemon.get_instance( + args).start_log_processor(run_id, device_id) while True: time.sleep(1) diff --git a/python/fedml/core/mlops/mlops_status.py b/python/fedml/core/mlops/mlops_status.py index 1b166aca91..0146da384c 100644 --- a/python/fedml/core/mlops/mlops_status.py +++ b/python/fedml/core/mlops/mlops_status.py @@ -5,6 +5,21 @@ class MLOpsStatus(Singleton): _status_instance = None def __init__(self): + """ + Initialize an instance of MLOpsStatus. + + This class is a Singleton and should not be instantiated directly. + Use the `get_instance` method to obtain the Singleton instance. + + Attributes: + messenger: Messenger object for communication. + run_id: The ID of the current run. + edge_id: The ID of the edge device. + client_agent_status: A dictionary to store client agent status. + server_agent_status: A dictionary to store server agent status. + client_status: A dictionary to store client status. + server_status: A dictionary to store server status. + """ self.messenger = None self.run_id = None self.edge_id = None @@ -15,31 +30,101 @@ def __init__(self): @staticmethod def get_instance(): + """ + Get the Singleton instance of MLOpsStatus. + + Returns: + MLOpsStatus: The Singleton instance of MLOpsStatus. + """ if MLOpsStatus._status_instance is None: MLOpsStatus._status_instance = MLOpsStatus() return MLOpsStatus._status_instance def set_client_agent_status(self, edge_id, status): + """ + Set the status of a client agent. + + Args: + edge_id (int): The ID of the edge device. + status (str): The status of the client agent. + """ self.client_agent_status[edge_id] = status def set_server_agent_status(self, edge_id, status): + """ + Set the status of a server agent. + + Args: + edge_id (int): The ID of the edge device. + status (str): The status of the server agent. + """ self.server_agent_status[edge_id] = status def set_client_status(self, edge_id, status): + """ + Set the status of a client. + + Args: + edge_id (int): The ID of the edge device. + status (str): The status of the client. + """ self.client_status[edge_id] = status def set_server_status(self, edge_id, status): + """ + Set the status of a server. + + Args: + edge_id (int): The ID of the edge device. + status (str): The status of the server. + """ self.server_status[edge_id] = status def get_client_agent_status(self, edge_id): + """ + Get the status of a client agent. + + Args: + edge_id (int): The ID of the edge device. + + Returns: + str or None: The status of the client agent, or None if not found. 
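+
+        Example:
+            The status string is illustrative:
+
+                status = MLOpsStatus.get_instance()
+                status.set_client_agent_status(3, "IDLE")
+                status.get_client_agent_status(3)   # "IDLE"
+                status.get_client_agent_status(99)  # None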
+ """ return self.client_agent_status.get(edge_id, None) def get_server_agent_status(self, edge_id): + """ + Get the status of a server agent. + + Args: + edge_id (int): The ID of the edge device. + + Returns: + str or None: The status of the server agent, or None if not found. + """ return self.server_agent_status.get(edge_id, None) def get_client_status(self, edge_id): + """ + Get the status of a client. + + Args: + edge_id (int): The ID of the edge device. + + Returns: + str or None: The status of the client, or None if not found. + """ return self.client_status.get(edge_id, None) def get_server_status(self, edge_id): + """ + Get the status of a server. + + Args: + edge_id (int): The ID of the edge device. + + Returns: + str or None: The status of the server, or None if not found. + """ return self.server_status.get(edge_id, None) diff --git a/python/fedml/core/mlops/mlops_utils.py b/python/fedml/core/mlops/mlops_utils.py index e8d63088bf..a59d39aa00 100644 --- a/python/fedml/core/mlops/mlops_utils.py +++ b/python/fedml/core/mlops/mlops_utils.py @@ -4,11 +4,23 @@ class MLOpsUtils: + """ + Class for MLOps utilities. + """ _ntp_offset = None BYTES_TO_GB = 1 / (1024 * 1024 * 1024) @staticmethod def calc_ntp_from_config(mlops_config): + """ + Calculate NTP time offset from MLOps configuration. + + Args: + mlops_config (dict): MLOps configuration containing NTP response data. + + Returns: + None: If the necessary NTP response data is missing or invalid. + """ if mlops_config is None: return @@ -25,7 +37,8 @@ def calc_ntp_from_config(mlops_config): return # calculate the time offset(int) - ntp_time = (server_recv_time + server_send_time + device_recv_time - device_send_time) // 2 + ntp_time = (server_recv_time + server_send_time + + device_recv_time - device_send_time) // 2 ntp_offset = ntp_time - device_recv_time # set the time offset @@ -33,20 +46,44 @@ def calc_ntp_from_config(mlops_config): @staticmethod def set_ntp_offset(ntp_offset): + """ + Set the NTP time offset. + + Args: + ntp_offset (int): The NTP time offset. + """ MLOpsUtils._ntp_offset = ntp_offset @staticmethod def get_ntp_time(): + """ + Get the current time adjusted by the NTP offset. + + Returns: + int: The NTP-adjusted current time in milliseconds. + """ if MLOpsUtils._ntp_offset is not None: return int(time.time() * 1000) + MLOpsUtils._ntp_offset return int(time.time() * 1000) @staticmethod def get_ntp_offset(): + """ + Get the current NTP time offset. + + Returns: + int: The NTP time offset. + """ return MLOpsUtils._ntp_offset @staticmethod def write_log_trace(log_trace): + """ + Write a log trace to a file in the "fedml_log" directory. + + Args: + log_trace (str): The log trace to write. + """ log_trace_dir = os.path.join(expanduser("~"), "fedml_log") if not os.path.exists(log_trace_dir): os.makedirs(log_trace_dir, exist_ok=True) diff --git a/python/fedml/core/mlops/stats_impl.py b/python/fedml/core/mlops/stats_impl.py index 51e59e48b9..f1ab974609 100644 --- a/python/fedml/core/mlops/stats_impl.py +++ b/python/fedml/core/mlops/stats_impl.py @@ -28,6 +28,16 @@ def gpu_in_use_by_this_process(gpu_handle: GPUHandle, pid: int) -> bool: + """ + Check if a GPU is in use by a specified process. + + Args: + gpu_handle (GPUHandle): Handle to the GPU to check. + pid (int): The process ID of the target process. + + Returns: + bool: True if the GPU is in use by the specified process; False otherwise. 
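+
+    Example:
+        A sketch; assumes pynvml has been initialized and at least one GPU
+        is present:
+
+            pynvml.nvmlInit()
+            handle = pynvml.nvmlDeviceGetHandleByIndex(0)
+            gpu_in_use_by_this_process(handle, pid=os.getpid())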
+ """ if not psutil: return False @@ -67,6 +77,16 @@ class WandbSystemStats: gpu_count: int def __init__(self, settings: SettingsStatic, interface: InterfaceQueue) -> None: + """ + Initialize the WandbSystemStats instance. + + Args: + settings (SettingsStatic): Settings for system stats tracking. + interface (InterfaceQueue): Interface for publishing stats. + + Raises: + Exception: An exception is raised if GPU initialization fails. + """ try: pynvml.nvmlInit() self.gpu_count = pynvml.nvmlDeviceGetCount() @@ -82,7 +102,8 @@ def __init__(self, settings: SettingsStatic, interface: InterfaceQueue) -> None: self._telem = telemetry.TelemetryRecord() if psutil: net = psutil.net_io_counters() - self.network_init = {"sent": net.bytes_sent, "recv": net.bytes_recv} + self.network_init = { + "sent": net.bytes_sent, "recv": net.bytes_recv} else: wandb.termlog( "psutil not installed, only GPU stats will be reported. Install with pip install psutil" @@ -105,6 +126,9 @@ def __init__(self, settings: SettingsStatic, interface: InterfaceQueue) -> None: wandb.termlog("Error initializing IPUProfiler: " + str(e)) def start(self) -> None: + """ + Start the system stats tracking thread. + """ if self._thread is None: self._shutdown = False self._thread = threading.Thread(target=self._thread_body) @@ -117,23 +141,42 @@ def start(self) -> None: @property def proc(self) -> psutil.Process: + """ + Get the process associated with the current PID. + + Returns: + psutil.Process: A process object for the current PID. + """ return psutil.Process(pid=self._pid) @property def sample_rate_seconds(self) -> float: - """Sample system stats every this many seconds, defaults to 2, min is 0.5""" + """ + Get the system stats sampling rate in seconds. + + Returns: + float: The system stats sampling rate in seconds. + """ sample_rate = self._settings._stats_sample_rate_seconds # TODO: handle self._api.dynamic_settings["system_sample_seconds"] return max(0.5, sample_rate) @property def samples_to_average(self) -> int: - """The number of samples to average before pushing, defaults to 15 valid range (2:30)""" + """ + Get the number of samples to average before pushing. + + Returns: + int: The number of samples to average. + """ samples = self._settings._stats_samples_to_average # TODO: handle self._api.dynamic_settings["system_samples"] return min(30, max(2, samples)) def _thread_body(self) -> None: + """ + Body of the system stats tracking thread. + """ while True: stats = self.stats() for stat, value in stats.items(): @@ -154,6 +197,9 @@ def _thread_body(self) -> None: return def shutdown(self) -> None: + """ + Shutdown the system stats tracking thread. + """ self._shutdown = True try: if self._thread is not None: @@ -164,6 +210,9 @@ def shutdown(self) -> None: self._tpu_profiler.stop() def flush(self) -> None: + """ + Flush and publish system stats. + """ stats = self.stats() for stat, value in stats.items(): # TODO: a bit hacky, we assume all numbers should be averaged. 
If you want @@ -189,7 +238,8 @@ def stats(self) -> StatsDict: temp = pynvml.nvmlDeviceGetTemperature( handle, pynvml.NVML_TEMPERATURE_GPU ) - in_use_by_us = gpu_in_use_by_this_process(handle, pid=self._pid) + in_use_by_us = gpu_in_use_by_this_process( + handle, pid=self._pid) stats["gpu.{}.{}".format(i, "gpu")] = utilz.gpu stats["gpu.{}.{}".format(i, "memory")] = utilz.memory @@ -200,7 +250,8 @@ def stats(self) -> StatsDict: if in_use_by_us: stats["gpu.process.{}.{}".format(i, "gpu")] = utilz.gpu - stats["gpu.process.{}.{}".format(i, "memory")] = utilz.memory + stats["gpu.process.{}.{}".format( + i, "memory")] = utilz.memory stats["gpu.process.{}.{}".format(i, "memoryAllocated")] = ( memory.used / float(memory.total) ) * 100 @@ -208,17 +259,23 @@ def stats(self) -> StatsDict: # Some GPUs don't provide information about power usage try: - power_watts = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0 + power_watts = pynvml.nvmlDeviceGetPowerUsage( + handle) / 1000.0 power_capacity_watts = ( - pynvml.nvmlDeviceGetEnforcedPowerLimit(handle) / 1000.0 + pynvml.nvmlDeviceGetEnforcedPowerLimit( + handle) / 1000.0 ) - power_usage = (power_watts / power_capacity_watts) * 100 + power_usage = ( + power_watts / power_capacity_watts) * 100 - stats["gpu.{}.{}".format(i, "powerWatts")] = power_watts - stats["gpu.{}.{}".format(i, "powerPercent")] = power_usage + stats["gpu.{}.{}".format( + i, "powerWatts")] = power_watts + stats["gpu.{}.{}".format( + i, "powerPercent")] = power_usage if in_use_by_us: - stats["gpu.process.{}.{}".format(i, "powerWatts")] = power_watts + stats["gpu.process.{}.{}".format( + i, "powerWatts")] = power_watts stats[ "gpu.process.{}.{}".format(i, "powerPercent") ] = power_usage @@ -238,9 +295,11 @@ def stats(self) -> StatsDict: and self.gpu_count == 0 ): try: - out = subprocess.check_output([util.apple_gpu_stats_binary(), "--json"]) + out = subprocess.check_output( + [util.apple_gpu_stats_binary(), "--json"]) m1_stats = json.loads(out.split(b"\n")[0]) - stats["gpu.0.memory"] = m1_stats["mem_used"] / float(m1_stats["utilization"]/100) + stats["gpu.0.memory"] = m1_stats["mem_used"] / \ + float(m1_stats["utilization"]/100) stats["gpu.0.gpu"] = m1_stats["utilization"] stats["gpu.0.memoryAllocated"] = m1_stats["mem_used"] stats["gpu.0.temp"] = m1_stats["temperature"] @@ -274,7 +333,8 @@ def stats(self) -> StatsDict: stats["disk"] = psutil.disk_usage("/").percent stats["proc.memory.availableMB"] = sysmem.available / 1048576.0 try: - stats["proc.memory.rssMB"] = self.proc.memory_info().rss / 1048576.0 + stats["proc.memory.rssMB"] = self.proc.memory_info().rss / \ + 1048576.0 stats["proc.memory.percent"] = self.proc.memory_percent() stats["proc.cpu.threads"] = self.proc.num_threads() except psutil.NoSuchProcess: diff --git a/python/fedml/core/mlops/system_stats.py b/python/fedml/core/mlops/system_stats.py index bdbd9e7f55..8e82182b7f 100755 --- a/python/fedml/core/mlops/system_stats.py +++ b/python/fedml/core/mlops/system_stats.py @@ -6,6 +6,28 @@ class SysStats: def __init__(self, process_id=None): + """ + Initialize the SysStats object. + + Args: + process_id (int): Optional process ID. Defaults to None. + + Attributes: + sys_stats_impl (WandbSystemStats): Instance of WandbSystemStats for collecting system statistics. + gpu_time_spent_accessing_memory (float): GPU time spent accessing memory. + gpu_power_usage (float): GPU power usage. + gpu_temp (float): GPU temperature. + gpu_memory_allocated (float): GPU memory allocated. + gpu_utilization (float): GPU utilization. 
+            network_traffic (float): Network traffic.
+            disk_utilization (float): Disk utilization.
+            process_cpu_threads_in_use (int): Number of CPU threads in use by the process.
+            process_memory_available (float): Available process memory.
+            process_memory_in_use (float): Process memory in use.
+            process_memory_in_use_size (float): Process memory in use (size).
+            system_memory_utilization (float): System memory utilization.
+            cpu_utilization (float): CPU utilization.
+        """
         settings = SettingsStatic(d={"_stats_pid": os.getpid() if process_id is None else process_id})
         self.sys_stats_impl = WandbSystemStats(settings=settings, interface=None)
         self.gpu_time_spent_accessing_memory = 0.0
@@ -23,6 +45,9 @@ def __init__(self, process_id=None):
         self.cpu_utilization = 0.0
 
     def produce_info(self):
+        """
+        Collect system statistics and update attributes.
+        """
         stats = self.sys_stats_impl.stats()
 
         self.cpu_utilization = stats.get("cpu", 0.0)

From 048d7a95960e5ac7afc3b6ef6dea563bc5efc377 Mon Sep 17 00:00:00 2001
From: rajveer43
Date: Sat, 23 Sep 2023 11:57:33 +0530
Subject: [PATCH 67/70] add docstrings

---
 .../communication/grpc/grpc_server.py         |  32 +++
 .../communication/grpc/ip_config_utils.py     |   9 +
 .../core/distributed/communication/message.py | 100 +++++++
 .../mqtt_thetastore_comm_manager.py           | 197 +++++++++++--
 .../communication/s3/remote_storage.py        | 263 ++++++++++++++----
 .../communication/s3/remote_storage_mnn.py    |  90 +++++-
 .../distributed/communication/s3/utils.py     |  65 ++++-
 .../communication/trpc/trpc_comm_manager.py   | 100 ++++++-
 .../communication/trpc/trpc_server.py         |  32 +++
 .../distributed/communication/trpc/utils.py   |  22 +-
 .../core/distributed/communication/utils.py   |  32 +++
 .../core/distributed/crypto/crypto_api.py     |  37 ++-
 .../theta_storage/theta_storage.py            |  77 +++--
 .../web3_storage/web3_storage.py              |  35 ++-
 .../core/distributed/fedml_comm_manager.py    | 150 +++++++++-
 .../core/distributed/flow/fedml_executor.py   | 112 +++++++-
 .../fedml/core/distributed/flow/fedml_flow.py | 198 +++++++++++++
 .../core/distributed/flow/test_fedml_flow.py  |  72 +++++
 .../topology/asymmetric_topology_manager.py   |  54 ++++
 .../topology/symmetric_topology_manager.py    |  53 ++++
 20 files changed, 1589 insertions(+), 141 deletions(-)

diff --git a/python/fedml/core/distributed/communication/grpc/grpc_server.py b/python/fedml/core/distributed/communication/grpc/grpc_server.py
index de169295aa..67d182cc29 100644
--- a/python/fedml/core/distributed/communication/grpc/grpc_server.py
+++ b/python/fedml/core/distributed/communication/grpc/grpc_server.py
@@ -10,6 +10,18 @@
 
 class GRPCCOMMServicer(grpc_comm_manager_pb2_grpc.gRPCCommManagerServicer):
     def __init__(self, host, port, client_num, client_id):
+        """
+        Initializes the gRPC Communication Servicer.
+
+        Args:
+            host (str): The IP address of the server.
+            port (int): The port number.
+            client_num (int): The number of clients.
+            client_id (int): The client ID.
+
+        Returns:
+            None
+        """
         # host is the ip address of server
         self.host = host
         self.port = port
@@ -24,6 +36,16 @@ def __init__(self, host, port, client_num, client_id):
         self.message_q = queue.Queue()
 
     def sendMessage(self, request, context):
+        """
+        Handles the gRPC sendMessage request.
+
+        Args:
+            request (grpc_comm_manager_pb2.CommRequest): The request message.
+            context (grpc.ServicerContext): The context of the request.
+
+        Returns:
+            grpc_comm_manager_pb2.CommResponse: The response message.
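+
+        Note:
+            A minimal wiring sketch for this servicer; the helper name is
+            assumed to follow the standard ``add_<Service>Servicer_to_server``
+            convention of generated gRPC code:
+
+                server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
+                grpc_comm_manager_pb2_grpc.add_gRPCCommManagerServicer_to_server(
+                    GRPCCOMMServicer("0.0.0.0", 8890, client_num=2, client_id=0),
+                    server)
+                server.add_insecure_port("0.0.0.0:8890")
+                server.start()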
+ """ context_ip = context.peer().split(":")[1] logging.info( "client_{} got something from client_{} from ip address {}".format( @@ -39,4 +61,14 @@ def sendMessage(self, request, context): return response def handleReceiveMessage(self, request, context): + """ + Handles the gRPC handleReceiveMessage request. + + Args: + request (grpc_comm_manager_pb2.CommRequest): The request message. + context (grpc.ServicerContext): The context of the request. + + Returns: + None + """ pass diff --git a/python/fedml/core/distributed/communication/grpc/ip_config_utils.py b/python/fedml/core/distributed/communication/grpc/ip_config_utils.py index 1ebedfd73a..77df94701a 100644 --- a/python/fedml/core/distributed/communication/grpc/ip_config_utils.py +++ b/python/fedml/core/distributed/communication/grpc/ip_config_utils.py @@ -2,6 +2,15 @@ def build_ip_table(path): + """ + Builds an IP table from a CSV file. + + Args: + path (str): The path to the CSV file containing receiver IDs and IP addresses. + + Returns: + dict: A dictionary mapping receiver IDs to IP addresses. + """ ip_config = dict() with open(path, newline="") as csv_file: csv_reader = csv.reader(csv_file) diff --git a/python/fedml/core/distributed/communication/message.py b/python/fedml/core/distributed/communication/message.py index 7d465461e5..df2c2a66a0 100644 --- a/python/fedml/core/distributed/communication/message.py +++ b/python/fedml/core/distributed/communication/message.py @@ -3,6 +3,9 @@ class Message(object): + """ + A class for representing and working with messages in a communication system. + """ MSG_ARG_KEY_OPERATION = "operation" MSG_ARG_KEY_TYPE = "msg_type" @@ -19,6 +22,14 @@ class Message(object): MSG_ARG_KEY_MODEL_PARAMS_KEY = "model_params_key" def __init__(self, type="default", sender_id=0, receiver_id=0): + """ + Initialize a Message instance. + + Args: + type (str): The type of the message. + sender_id (int): The ID of the sender. + receiver_id (int): The ID of the receiver. + """ self.type = str(type) self.sender_id = sender_id self.receiver_id = receiver_id @@ -28,56 +39,145 @@ def __init__(self, type="default", sender_id=0, receiver_id=0): self.msg_params[Message.MSG_ARG_KEY_RECEIVER] = receiver_id def init(self, msg_params): + """ + Initialize the message with the provided message parameters. + + Args: + msg_params (dict): A dictionary of message parameters. + """ self.msg_params = msg_params def init_from_json_string(self, json_string): + """ + Initialize the message from a JSON string. + + Args: + json_string (str): A JSON string representing the message. + """ self.msg_params = json.loads(json_string) self.type = self.msg_params[Message.MSG_ARG_KEY_TYPE] self.sender_id = self.msg_params[Message.MSG_ARG_KEY_SENDER] self.receiver_id = self.msg_params[Message.MSG_ARG_KEY_RECEIVER] def init_from_json_object(self, json_object): + """ + Initialize the message from a JSON object. + + Args: + json_object (dict): A JSON object representing the message. + """ self.msg_params = json_object self.type = self.msg_params[Message.MSG_ARG_KEY_TYPE] self.sender_id = self.msg_params[Message.MSG_ARG_KEY_SENDER] self.receiver_id = self.msg_params[Message.MSG_ARG_KEY_RECEIVER] def get_sender_id(self): + """ + Get the ID of the sender. + + Returns: + int: The sender's ID. + """ return self.sender_id def get_receiver_id(self): + """ + Get the ID of the receiver. + + Returns: + int: The receiver's ID. + """ return self.receiver_id def add_params(self, key, value): + """ + Add a parameter to the message. 
+ + Args: + key (str): The key of the parameter. + value (any): The value of the parameter. + """ self.msg_params[key] = value def get_params(self): + """ + Get all the parameters of the message. + + Returns: + dict: A dictionary of message parameters. + """ return self.msg_params def add(self, key, value): + """ + Add a parameter to the message (alias for add_params). + + Args: + key (str): The key of the parameter. + value (any): The value of the parameter. + """ self.msg_params[key] = value def get(self, key): + """ + Get the value of a parameter by its key. + + Args: + key (str): The key of the parameter. + + Returns: + any: The value of the parameter or None if not found. + """ if key not in self.msg_params.keys(): return None return self.msg_params[key] def get_type(self): + """ + Get the type of the message. + + Returns: + str: The type of the message. + """ return self.msg_params[Message.MSG_ARG_KEY_TYPE] def to_string(self): + """ + Convert the message to a string representation. + + Returns: + dict: A dictionary representing the message. + """ return self.msg_params def to_json(self): + """ + Serialize the message to a JSON string. + + Returns: + str: A JSON string representing the message. + """ json_string = json.dumps(self.msg_params) print("json string size = " + str(sys.getsizeof(json_string))) return json_string def get_content(self): + """ + Get a human-readable representation of the message. + + Returns: + str: A string representing the message content. + """ print_dict = self.msg_params.copy() msg_str = str(self.__to_msg_type_string()) + ": " + str(print_dict) return msg_str def __to_msg_type_string(self): + """ + Get a string representation of the message type. + + Returns: + str: A string representing the message type. + """ type = self.msg_params[Message.MSG_ARG_KEY_TYPE] return type diff --git a/python/fedml/core/distributed/communication/mqtt_thetastore/mqtt_thetastore_comm_manager.py b/python/fedml/core/distributed/communication/mqtt_thetastore/mqtt_thetastore_comm_manager.py index b1e2cfa3a4..d21a92afda 100755 --- a/python/fedml/core/distributed/communication/mqtt_thetastore/mqtt_thetastore_comm_manager.py +++ b/python/fedml/core/distributed/communication/mqtt_thetastore/mqtt_thetastore_comm_manager.py @@ -28,6 +28,20 @@ def __init__( client_num=0, args=None ): + """ + Initializes an MQTT-based ThetaStore Communication Manager. + + Args: + config_path (str): The path to the MQTT configuration file. + thetastore_config_path (str): The path to the ThetaStore configuration file. + topic (str, optional): The MQTT topic. Defaults to "fedml". + client_rank (int, optional): The client rank. Defaults to 0. + client_num (int, optional): The number of clients. Defaults to 0. + args (object, optional): Additional arguments. 
+ + Returns: + None + """ self.broker_port = None self.broker_host = None self.mqtt_user = None @@ -44,7 +58,8 @@ def __init__( self.client_real_ids = [] if args.client_id_list is not None: logging.info( - "MqttThetastoreCommManager args client_id_list: " + str(args.client_id_list) + "MqttThetastoreCommManager args client_id_list: " + + str(args.client_id_list) ) self.client_real_ids = json.loads(args.client_id_list) @@ -91,7 +106,8 @@ def __init__( if args.rank == 0: self.top_active_msg = CommunicationConstants.SERVER_TOP_ACTIVE_MSG self.topic_last_will_msg = CommunicationConstants.SERVER_TOP_LAST_WILL_MSG - self.last_will_msg = json.dumps({"ID": self.edge_id, "status": CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE}) + self.last_will_msg = json.dumps( + {"ID": self.edge_id, "status": CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE}) self.mqtt_mgr = MqttManager(self.broker_host, self.broker_port, self.mqtt_user, self.mqtt_pwd, self.keepalive_time, self._client_id, self.topic_last_will_msg, @@ -104,6 +120,12 @@ def __init__( @property def client_id(self): + """ + Runs the MQTT message loop forever. + + Returns: + None + """ return self._client_id @property @@ -115,6 +137,14 @@ def run_loop_forever(self): def on_connected(self, mqtt_client_object): """ + Callback function when MQTT client is connected. + + Args: + mqtt_client_object (MqttManager): The MQTT client object. + + Returns: + None + [server] sending message topic (publish): serverID_clientID receiving message topic (subscribe): clientID @@ -135,7 +165,8 @@ def on_connected(self, mqtt_client_object): # logging.info("self.client_real_ids = {}".format(self.client_real_ids)) for client_rank in range(0, self.client_num): - real_topic = self._topic + str(self.client_real_ids[client_rank]) + real_topic = self._topic + \ + str(self.client_real_ids[client_rank]) result, mid = mqtt_client_object.subscribe(real_topic, qos=2) # logging.info( @@ -146,7 +177,8 @@ def on_connected(self, mqtt_client_object): self._notify_connection_ready() else: # client - real_topic = self._topic + str(self.server_id) + "_" + str(self.client_real_ids[0]) + real_topic = self._topic + \ + str(self.server_id) + "_" + str(self.client_real_ids[0]) result, mid = mqtt_client_object.subscribe(real_topic, qos=2) self._notify_connection_ready() @@ -158,12 +190,39 @@ def on_connected(self, mqtt_client_object): self.is_connected = True def on_disconnected(self, mqtt_client_object): + """ + Callback function when MQTT client is disconnected. + + Args: + mqtt_client_object (MqttManager): The MQTT client object. + + Returns: + None + """ self.is_connected = False def add_observer(self, observer: Observer): + """ + Adds an observer to the communication manager. + + Args: + observer (Observer): The observer to be added. + + Returns: + None + """ self._observers.append(observer) def remove_observer(self, observer: Observer): + """ + Removes an observer from the communication manager. + + Args: + observer (Observer): The observer to be removed. 
+ + Returns: + None + """ self._observers.remove(observer) def _notify_connection_ready(self): @@ -185,7 +244,8 @@ def _on_message_impl(self, msg): payload_obj = json.loads(json_payload) sender_id = payload_obj.get(Message.MSG_ARG_KEY_SENDER, "") receiver_id = payload_obj.get(Message.MSG_ARG_KEY_RECEIVER, "") - thetastore_key_str = payload_obj.get(Message.MSG_ARG_KEY_MODEL_PARAMS, "") + thetastore_key_str = payload_obj.get( + Message.MSG_ARG_KEY_MODEL_PARAMS, "") thetastore_key_str = str(thetastore_key_str).strip(" ") if thetastore_key_str != "": @@ -195,10 +255,12 @@ def _on_message_impl(self, msg): model_params = self.theta_storage.read_model(thetastore_key_str) Context().add("received_model_cid", thetastore_key_str) - logging.info("Received model cid {}".format(Context().get("received_model_cid"))) + logging.info("Received model cid {}".format( + Context().get("received_model_cid"))) logging.info( - "mqtt_thetastore.on_message: model params length %d" % len(model_params) + "mqtt_thetastore.on_message: model params length %d" % len( + model_params) ) # replace the thetastore object key with raw model params @@ -213,6 +275,14 @@ def _on_message(self, msg): def send_message(self, msg: Message): """ + Sends a message using MQTT. + + Args: + msg (Message): The message to be sent. + + Returns: + None + [server] sending message topic (publish): fedml_runid_serverID_clientID receiving message topic (subscribe): fedml_runid_clientID @@ -227,16 +297,20 @@ def send_message(self, msg: Message): if self.client_id == 0: # topic = "fedml" + "_" + "run_id" + "_0" + "_" + "client_id" topic = self._topic + str(self.server_id) + "_" + str(receiver_id) - logging.info("mqtt_thetastore.send_message: msg topic = %s" % str(topic)) + logging.info( + "mqtt_thetastore.send_message: msg topic = %s" % str(topic)) payload = msg.get_params() - model_params_obj = payload.get(Message.MSG_ARG_KEY_MODEL_PARAMS, "") + model_params_obj = payload.get( + Message.MSG_ARG_KEY_MODEL_PARAMS, "") if model_params_obj != "": # thetastore logging.info("mqtt_thetastore.send_message: to python client.") - message_key = model_url = self.theta_storage.write_model(model_params_obj) + message_key = model_url = self.theta_storage.write_model( + model_params_obj) Context().add("sent_model_cid", model_url) - logging.info("Sent model cid {}".format(Context().get("sent_model_cid"))) + logging.info("Sent model cid {}".format( + Context().get("sent_model_cid"))) logging.info( "mqtt_thetastore.send_message: thetastore+MQTT msg sent, thetastore message key = %s" % message_key @@ -261,12 +335,15 @@ def send_message(self, msg: Message): topic = self._topic + str(msg.get_sender_id()) payload = msg.get_params() - model_params_obj = payload.get(Message.MSG_ARG_KEY_MODEL_PARAMS, "") + model_params_obj = payload.get( + Message.MSG_ARG_KEY_MODEL_PARAMS, "") if model_params_obj != "": # thetastore - message_key = model_url = self.theta_storage.write_model(model_params_obj) + message_key = model_url = self.theta_storage.write_model( + model_params_obj) Context().add("sent_model_cid", model_url) - logging.info("Sent model cid {}".format(Context().get("sent_model_cid"))) + logging.info("Sent model cid {}".format( + Context().get("sent_model_cid"))) logging.info( "mqtt_thetastore.send_message: thetastore+MQTT msg sent, message_key = %s" % message_key @@ -286,20 +363,52 @@ def send_message(self, msg: Message): self.mqtt_mgr.send_message(topic, json.dumps(payload)) def send_message_json(self, topic_name, json_message): + """ + Sends a JSON message using MQTT. 
+ + Args: + topic_name (str): The MQTT topic name. + json_message (str): The JSON message to be sent. + + Returns: + None + """ self.mqtt_mgr.send_message_json(topic_name, json_message) def handle_receive_message(self): + """ + Handles the reception of messages. + + Returns: + None + """ start_listening_time = time.time() MLOpsProfilerEvent.log_to_wandb({"ListenStart": start_listening_time}) self.run_loop_forever() - MLOpsProfilerEvent.log_to_wandb({"TotalTime": time.time() - start_listening_time}) + MLOpsProfilerEvent.log_to_wandb( + {"TotalTime": time.time() - start_listening_time}) def stop_receive_message(self): + """ + Stops the reception of messages and disconnects the MQTT client. + + Returns: + None + """ logging.info("mqtt_thetastore.stop_receive_message: stopping...") self.mqtt_mgr.loop_stop() self.mqtt_mgr.disconnect() def set_config_from_file(self, config_file_path): + """ + Sets the MQTT configuration from a file. + + Args: + config_file_path (str): The path to the MQTT configuration file. + + Returns: + None + """ try: with open(config_file_path, "r") as f: config = yaml.load(f, Loader=yaml.FullLoader) @@ -315,6 +424,15 @@ def set_config_from_file(self, config_file_path): pass def set_config_from_objects(self, mqtt_config): + """ + Sets the MQTT configuration from an object. + + Args: + mqtt_config (dict): The MQTT configuration. + + Returns: + None + """ self.broker_host = mqtt_config["BROKER_HOST"] self.broker_port = mqtt_config["BROKER_PORT"] self.mqtt_user = None @@ -325,21 +443,49 @@ def set_config_from_objects(self, mqtt_config): self.mqtt_pwd = mqtt_config["MQTT_PWD"] def callback_client_last_will_msg(self, topic, payload): + """ + Callback function for processing client last will messages. + + Args: + topic (str): The MQTT topic. + payload (str): The message payload. + + Returns: + None + """ msg = json.loads(payload) edge_id = msg.get("ID", None) - status = msg.get("status", CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE) + status = msg.get( + "status", CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE) if edge_id is not None and status == CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE: if self.client_active_list.get(edge_id, None) is not None: self.client_active_list.pop(edge_id) def callback_client_active_msg(self, topic, payload): + """ + Callback function for processing client active status messages. + + Args: + topic (str): The MQTT topic. + payload (str): The message payload. + + Returns: + None + """ msg = json.loads(payload) edge_id = msg.get("ID", None) - status = msg.get("status", CommunicationConstants.MSG_CLIENT_STATUS_IDLE) + status = msg.get( + "status", CommunicationConstants.MSG_CLIENT_STATUS_IDLE) if edge_id is not None: self.client_active_list[edge_id] = status def subscribe_client_status_message(self): + """ + Subscribes to client status messages. + + Returns: + None + """ # Setup MQTT message listener to the last will message form the client. self.mqtt_mgr.add_message_listener(CommunicationConstants.CLIENT_TOP_LAST_WILL_MSG, self.callback_client_last_will_msg) @@ -349,11 +495,26 @@ def subscribe_client_status_message(self): self.callback_client_active_msg) def get_client_status(self, client_id): + """ + Gets the status of a specific client. + + Args: + client_id (int): The client ID. + + Returns: + str: The status of the client. + """ return self.client_active_list.get(client_id, CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE) def get_client_list_status(self): + """ + Gets the status of all clients. 
+
+        Returns:
+            dict: A dictionary of client statuses.
+        """
         return self.client_active_list
 
 
 if __name__ == "__main__":
-    pass
\ No newline at end of file
+    pass
diff --git a/python/fedml/core/distributed/communication/s3/remote_storage.py b/python/fedml/core/distributed/communication/s3/remote_storage.py
index f9e3416b34..22fc82b780 100644
--- a/python/fedml/core/distributed/communication/s3/remote_storage.py
+++ b/python/fedml/core/distributed/communication/s3/remote_storage.py
@@ -26,6 +26,15 @@
 
 class S3Storage:
     def __init__(self, s3_config_path):
+        """
+        Initializes an S3Storage instance with S3 configuration.
+
+        Args:
+            s3_config_path (str): The path to the S3 configuration file.
+
+        Returns:
+            None
+        """
         self.bucket_name = None
         self.cn_region_name = None
         self.cn_s3_sak = None
@@ -49,6 +58,16 @@ def __init__(self, s3_config_path):
         )
 
     def write_model(self, message_key, model):
+        """
+        Writes a machine learning model to S3 storage.
+
+        Args:
+            message_key (str): The key to identify the stored model in S3.
+            model: The machine learning model to be stored.
+
+        Returns:
+            str: The URL of the stored model in S3.
+        """
         global aws_s3_client
         pickle_dump_start_time = time.time()
         MLOpsProfilerEvent.log_to_wandb(
@@ -62,19 +81,23 @@ def write_model(self, message_key, model):
         model_file_size = len(model_to_send)
         model_file_transfered = 0
         prev_progress = 0
+
         def upload_model_progress(bytes_transferred):
             nonlocal model_file_transfered
             nonlocal model_file_size
-            nonlocal prev_progress # since the callback is stateless, we need to keep the previous progress
+            # since the callback is stateless, we need to keep the previous progress
+            nonlocal prev_progress
             model_file_transfered += bytes_transferred
             uploaded_kb = format(model_file_transfered / 1024, '.2f')
-            progress = (model_file_transfered / model_file_size * 100) if model_file_size != 0 else 0
+            progress = (model_file_transfered / model_file_size *
+                        100) if model_file_size != 0 else 0
             progress_format_int = int(progress)
             # print the process every 5%
             if progress_format_int % 5 == 0 and progress_format_int != prev_progress:
-                logging.info("model uploaded to S3 size {} KB, progress {}%".format(uploaded_kb, progress_format_int))
+                logging.info("model uploaded to S3 size {} KB, progress {}%".format(
+                    uploaded_kb, progress_format_int))
                 prev_progress = progress_format_int
-
+
         aws_s3_client.upload_fileobj(
             Fileobj=io.BytesIO(model_to_send), Bucket=self.bucket_name, Key=message_key,
             Callback=upload_model_progress,
@@ -90,6 +113,16 @@ def upload_model_progress(bytes_transferred):
         return model_url
 
     def write_model_net(self, message_key, model, dummy_input_tensor, local_model_cache_path):
+        """
+        Writes a machine learning model's network to S3 storage.
+
+        Args:
+            message_key (str): The key to identify the stored model in S3.
+            model: The machine learning model to be stored.
+            dummy_input_tensor: A dummy input tensor used when serializing the model network.
+            local_model_cache_path (str): A local cache directory used during serialization.
+
+        Returns:
+            str: The URL of the stored model in S3.
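+
+        Example:
+            A sketch; the key, cache path, and dummy input are illustrative:
+
+                storage = S3Storage("config/s3_config.yaml")
+                url = storage.write_model_net(
+                    "run42/net", model,
+                    dummy_input_tensor=torch.randn(1, 3, 224, 224),
+                    local_model_cache_path="/tmp/fedml_cache")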
+ """ global aws_s3_client pickle_dump_start_time = time.time() MLOpsProfilerEvent.log_to_wandb( @@ -117,21 +150,25 @@ def write_model_net(self, message_key, model, dummy_input_tensor, local_model_ca model_to_send.seek(0, 0) net_file_transfered = 0 prev_progress = 0 + def upload_model_net_progress(bytes_transferred): nonlocal net_file_transfered nonlocal net_file_size - nonlocal prev_progress # since the callback is stateless, we need to keep the previous progress + # since the callback is stateless, we need to keep the previous progress + nonlocal prev_progress net_file_transfered += bytes_transferred uploaded_kb = format(net_file_transfered / 1024, '.2f') - progress = (net_file_transfered / net_file_size * 100) if net_file_size != 0 else 0 + progress = (net_file_transfered / net_file_size * + 100) if net_file_size != 0 else 0 progress_format_int = int(progress) # print the process every 5% if progress_format_int % 5 == 0 and progress_format_int != prev_progress: - logging.info("model net uploaded to S3 size {} KB, progress {}%".format(uploaded_kb, progress_format_int)) + logging.info("model net uploaded to S3 size {} KB, progress {}%".format( + uploaded_kb, progress_format_int)) prev_progress = progress_format_int aws_s3_client.upload_fileobj( Fileobj=model_to_send, Bucket=self.bucket_name, Key=message_key, - Callback= upload_model_net_progress, + Callback=upload_model_net_progress, ) MLOpsProfilerEvent.log_to_wandb( {"Comm/send_delay": time.time() - s3_upload_start_time} @@ -144,6 +181,18 @@ def upload_model_net_progress(bytes_transferred): return model_url def write_model_input(self, message_key, input_size, input_type, local_model_cache_path): + """ + Writes model input information to S3 storage. + + Args: + message_key (str): The key to identify the stored input information in S3. + input_size: The size of the model input. + input_type: The type of the model input. + local_model_cache_path (str): The local cache path for input information storage. + + Returns: + str: The URL of the stored input information in S3. + """ global aws_s3_client if not os.path.exists(local_model_cache_path): @@ -157,7 +206,8 @@ def write_model_input(self, message_key, input_size, input_type, local_model_cac json.dump(model_input_dict, f) with open(model_input_path, 'rb') as f: - aws_s3_client.upload_fileobj(f, Bucket=self.bucket_name, Key=message_key) + aws_s3_client.upload_fileobj( + f, Bucket=self.bucket_name, Key=message_key) model_input_url = aws_s3_client.generate_presigned_url("get_object", ExpiresIn=60 * 60 * 24 * 5, @@ -165,6 +215,16 @@ def write_model_input(self, message_key, input_size, input_type, local_model_cac return model_input_url def write_model_web(self, message_key, model): + """ + Writes a machine learning model to S3 storage in web format. + + Args: + message_key (str): The key to identify the stored model in S3. + model: The machine learning model to be stored. + + Returns: + str: The URL of the stored model in S3. + """ global aws_s3_client pickle_dump_start_time = time.time() MLOpsProfilerEvent.log_to_wandb( @@ -189,6 +249,15 @@ def write_model_web(self, message_key, model): return model_url def read_model(self, message_key): + """ + Reads a machine learning model from S3 storage. + + Args: + message_key (str): The key to identify the stored model in S3. + + Returns: + model: The machine learning model retrieved from S3. 
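+
+        Example:
+            The key is illustrative; it must match the key used by
+            ``write_model``:
+
+                params = storage.read_model("run42/client1/model")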
+ """ global aws_s3_client message_handler_start_time = time.time() @@ -200,7 +269,8 @@ def read_model(self, message_key): os.makedirs(cache_dir) except Exception as e: pass - temp_base_file_path = os.path.join(cache_dir, str(os.getpid()) + "@" + str(uuid.uuid4())) + temp_base_file_path = os.path.join( + cache_dir, str(os.getpid()) + "@" + str(uuid.uuid4())) if not os.path.exists(temp_base_file_path): try: os.makedirs(temp_base_file_path) @@ -211,22 +281,25 @@ def read_model(self, message_key): logging.info("temp_file_path = {}".format(temp_file_path)) model_file_transfered = 0 prev_progress = 0 + def read_model_progress(bytes_transferred): nonlocal model_file_transfered nonlocal object_size nonlocal prev_progress model_file_transfered += bytes_transferred readed_kb = format(model_file_transfered / 1024, '.2f') - progress = (model_file_transfered / object_size * 100) if object_size != 0 else 0 + progress = (model_file_transfered / object_size * + 100) if object_size != 0 else 0 progress_format_int = int(progress) # print the process every 5% if progress_format_int % 5 == 0 and progress_format_int != prev_progress: - logging.info("model readed from S3 size {} KB, progress {}%".format(readed_kb, progress_format_int)) + logging.info("model readed from S3 size {} KB, progress {}%".format( + readed_kb, progress_format_int)) prev_progress = progress_format_int with open(temp_file_path, 'wb') as f: aws_s3_client.download_fileobj(Bucket=self.bucket_name, Key=message_key, Fileobj=f, - Callback=read_model_progress) + Callback=read_model_progress) MLOpsProfilerEvent.log_to_wandb( {"Comm/recieve_delay_s3": time.time() - message_handler_start_time} ) @@ -242,6 +315,16 @@ def read_model_progress(bytes_transferred): return model def read_model_net(self, message_key, local_model_cache_path): + """ + Reads a machine learning model in net format from S3 storage. + + Args: + message_key (str): The key to identify the stored model in S3. + local_model_cache_path (str): The local cache path for model storage. + + Returns: + model: The machine learning model retrieved from S3. 
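+
+        Example:
+            Paths are illustrative; pairs with ``write_model_net``:
+
+                net = storage.read_model_net("run42/net", "/tmp/fedml_cache")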
+ """ global aws_s3_client message_handler_start_time = time.time() @@ -259,21 +342,25 @@ def read_model_net(self, message_key, local_model_cache_path): logging.info("temp_file_path = {}".format(temp_file_path)) model_file_transfered = 0 prev_progress = 0 + def read_model_net_progress(bytes_transferred): nonlocal model_file_transfered nonlocal object_size - nonlocal prev_progress # since the callback is stateless, we need to keep the previous progress + # since the callback is stateless, we need to keep the previous progress + nonlocal prev_progress model_file_transfered += bytes_transferred readed_kb = format(model_file_transfered / 1024, '.2f') - progress = (model_file_transfered / object_size * 100) if object_size != 0 else 0 + progress = (model_file_transfered / object_size * + 100) if object_size != 0 else 0 progress_format_int = int(progress) # print the process every 5% if progress_format_int % 5 == 0 and progress_format_int != prev_progress: - logging.info("model net readed from S3 size {} KB, progress {}%".format(readed_kb, progress_format_int)) + logging.info("model net readed from S3 size {} KB, progress {}%".format( + readed_kb, progress_format_int)) prev_progress = progress_format_int with open(temp_file_path, 'wb') as f: aws_s3_client.download_fileobj(Bucket=self.bucket_name, Key=message_key, Fileobj=f, - Callback=read_model_net_progress) + Callback=read_model_net_progress) MLOpsProfilerEvent.log_to_wandb( {"Comm/recieve_delay_s3": time.time() - message_handler_start_time} ) @@ -291,6 +378,17 @@ def read_model_net_progress(bytes_transferred): return model def read_model_input(self, message_key, local_model_cache_path): + """ + Reads model input information from S3 storage. + + Args: + message_key (str): The key to identify the stored input information in S3. + local_model_cache_path (str): The local cache path for input information storage. + + Returns: + input_size: The size of the model input. + input_type: The type of the model input. + """ global aws_s3_client temp_base_file_path = local_model_cache_path @@ -304,7 +402,8 @@ def read_model_input(self, message_key, local_model_cache_path): os.remove(temp_file_path) logging.info("temp_file_path = {}".format(temp_file_path)) with open(temp_file_path, 'wb') as f: - aws_s3_client.download_fileobj(Bucket=self.bucket_name, Key=message_key, Fileobj=f) + aws_s3_client.download_fileobj( + Bucket=self.bucket_name, Key=message_key, Fileobj=f) with open(temp_file_path, 'r') as f: model_input_dict = json.load(f) @@ -316,9 +415,21 @@ def read_model_input(self, message_key, local_model_cache_path): # TODO: added python torch model to align the Tensorflow parameters from browser def read_model_web(self, message_key, py_model: nn.Module): + """ + Reads a machine learning model in web format from S3 storage. + + Args: + message_key (str): The key to identify the stored model in S3. + py_model (nn.Module): The PyTorch model to align Tensorflow parameters from the browser. + + Returns: + model: The machine learning model retrieved from S3. 
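+
+        Example:
+            A sketch; ``WebNet`` is an illustrative nn.Module matching the
+            browser-side model:
+
+                model = storage.read_model_web("run42/web_client_3", py_model=WebNet())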
+ """ + global aws_s3_client message_handler_start_time = time.time() - obj = aws_s3_client.get_object(Bucket=self.bucket_name, Key=message_key) + obj = aws_s3_client.get_object( + Bucket=self.bucket_name, Key=message_key) model_json = obj["Body"].read() if type(model_json) == list: model = load_params_from_tf(py_model, model_json) @@ -368,16 +479,21 @@ def read_model_web(self, message_key, py_model: nn.Module): def upload_file(self, src_local_path, message_key): """ - upload file - :param src_local_path: - :param message_key: - :return: + Uploads a file to S3 storage. + + Args: + src_local_path (str): The local path to the file to be uploaded. + message_key (str): The key to identify the stored file in S3. + + Returns: + str: The URL of the uploaded file. """ try: with open(src_local_path, "rb") as f: global aws_s3_client aws_s3_client.upload_fileobj( - f, self.bucket_name, message_key, ExtraArgs={"ACL": "public-read"} + f, self.bucket_name, message_key, ExtraArgs={ + "ACL": "public-read"} ) model_url = aws_s3_client.generate_presigned_url( @@ -398,10 +514,14 @@ def upload_file(self, src_local_path, message_key): def download_file(self, message_key, path_local): """ - download file - :param message_key: s3 key - :param path_local: local path - :return: + Downloads a file from S3 storage to the local filesystem. + + Args: + message_key (str): The key to identify the file in S3. + path_local (str): The local path where the file should be saved. + + Returns: + None """ retry = 0 while retry < 3: @@ -410,7 +530,8 @@ def download_file(self, message_key, path_local): ) try: global aws_s3_client - aws_s3_client.download_file(self.bucket_name, message_key, path_local) + aws_s3_client.download_file( + self.bucket_name, message_key, path_local) file_size = os.path.getsize(path_local) logging.info( f"Downloading completed. | size: {round(file_size / 1048576, 2)} MB" @@ -425,12 +546,16 @@ def download_file(self, message_key, path_local): def upload_file_with_progress(self, src_local_path, dest_s3_path, out_progress_to_err=True, progress_desc=None): """ - upload file - :param out_progress_to_err: - :param progress_desc: - :param src_local_path: - :param dest_s3_path: - :return: + Uploads a file to S3 storage with progress tracking. + + Args: + src_local_path (str): The local path to the file to be uploaded. + dest_s3_path (str): The key to identify the stored file in S3. + out_progress_to_err (bool): Whether to output progress to stderr. + progress_desc (str): A description for the progress tracking. + + Returns: + str: The URL of the uploaded file. 
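+
+        Example:
+            Paths are illustrative:
+
+                url = storage.upload_file_with_progress(
+                    "/tmp/package.zip", "packages/run42.zip",
+                    out_progress_to_err=True, progress_desc="Uploading package")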
""" file_uploaded_url = "" progress_desc_text = "Uploading Package to AWS S3" @@ -447,8 +572,10 @@ def upload_file_with_progress(self, src_local_path, dest_s3_path, file=sys.stderr if out_progress_to_err else sys.stdout, desc=progress_desc_text) as pbar: aws_s3_client.upload_fileobj( - f, self.bucket_name, dest_s3_path, ExtraArgs={"ACL": "public-read"}, - Callback=lambda bytes_transferred: pbar.update(bytes_transferred), + f, self.bucket_name, dest_s3_path, ExtraArgs={ + "ACL": "public-read"}, + Callback=lambda bytes_transferred: pbar.update( + bytes_transferred), ) file_uploaded_url = aws_s3_client.generate_presigned_url( @@ -469,12 +596,16 @@ def upload_file_with_progress(self, src_local_path, dest_s3_path, def download_file_with_progress(self, path_s3, path_local, out_progress_to_err=True, progress_desc=None): """ - download file - :param out_progress_to_err: - :param progress_desc: - :param path_s3: s3 key - :param path_local: local path - :return: + Downloads a file from S3 storage to the local filesystem with progress tracking. + + Args: + path_s3 (str): The key to identify the file in S3. + path_local (str): The local path where the file should be saved. + out_progress_to_err (bool): Whether to output progress to stderr. + progress_desc (str): A description for the progress tracking. + + Returns: + None """ retry = 0 progress_desc_text = "Downloading Package from AWS S3" @@ -487,7 +618,8 @@ def download_file_with_progress(self, path_s3, path_local, try: global aws_s3_client kwargs = {"Bucket": self.bucket_name, "Key": path_s3} - object_size = aws_s3_client.head_object(**kwargs)["ContentLength"] + object_size = aws_s3_client.head_object( + **kwargs)["ContentLength"] with tqdm.tqdm(total=object_size, unit="B", unit_scale=True, file=sys.stderr if out_progress_to_err else sys.stdout, desc=progress_desc_text) as pbar: @@ -504,10 +636,14 @@ def download_file_with_progress(self, path_s3, path_local, def test_s3_base_cmds(self, message_key, message_body): """ - test_s3_base_cmds - :param file_key: s3 message key - :param file_key: s3 message body - :return: + Tests basic S3 commands by uploading and downloading a message. + + Args: + message_key (str): The key to identify the stored message in S3. + message_body: The message body to be stored and retrieved. + + Returns: + bool: True if the test is successful, False otherwise. """ retry = 0 while retry < 3: @@ -517,7 +653,8 @@ def test_s3_base_cmds(self, message_key, message_body): aws_s3_client.put_object( Body=message_pkl, Bucket=self.bucket_name, Key=message_key, ACL="public-read", ) - obj = aws_s3_client.get_object(Bucket=self.bucket_name, Key=message_key) + obj = aws_s3_client.get_object( + Bucket=self.bucket_name, Key=message_key) message_pkl_downloaded = obj["Body"].read() message_downloaded = pickle.loads(message_pkl_downloaded) if str(message_body) == str(message_downloaded): @@ -534,15 +671,28 @@ def test_s3_base_cmds(self, message_key, message_body): def delete_s3_zip(self, path_s3): """ - delete s3 object - :param path_s3: s3 key - :return: + Deletes an object from S3 storage. + + Args: + path_s3 (str): The key to identify the object in S3. + + Returns: + None """ global aws_s3_client aws_s3_client.delete_object(Bucket=self.bucket_name, Key=path_s3) logging.info(f"Delete s3 file Successful. | path_s3 = {path_s3}") def set_config_from_file(self, config_file_path): + """ + Sets the S3 configuration from a file. + + Args: + config_file_path (str): The path to the configuration file. 
+
+        Returns:
+            None
+        """
         try:
             with open(config_file_path, "r") as f:
                 config = yaml.load(f, Loader=yaml.FullLoader)
@@ -554,6 +704,15 @@ def set_config_from_file(self, config_file_path):
             pass
 
     def set_config_from_objects(self, s3_config):
+        """
+        Sets the S3 configuration from a dictionary of S3 configuration values.
+
+        Args:
+            s3_config (dict): A dictionary containing S3 configuration values.
+
+        Returns:
+            None
+        """
         self.cn_s3_aki = s3_config["CN_S3_AKI"]
         self.cn_s3_sak = s3_config["CN_S3_SAK"]
         self.cn_region_name = s3_config["CN_REGION_NAME"]
diff --git a/python/fedml/core/distributed/communication/s3/remote_storage_mnn.py b/python/fedml/core/distributed/communication/s3/remote_storage_mnn.py
index f6b0b17a9f..9f4068fe68 100644
--- a/python/fedml/core/distributed/communication/s3/remote_storage_mnn.py
+++ b/python/fedml/core/distributed/communication/s3/remote_storage_mnn.py
@@ -13,39 +13,113 @@ def __init__(self, s3_config_path):
 
     def upload_model_file(self, message_key, model_file_path):
         """
-        this is used for Mobile Platform (MNN)
-        :param message_key:
-        :param model_file_path:
-        :return:
+        Uploads a model file to S3 storage for Mobile Platform (MNN).
+
+        Args:
+            message_key (str): The key to identify the uploaded model in S3.
+            model_file_path (str): The local file path of the model to be uploaded.
+
+        Returns:
+            str: The URL of the uploaded file, as returned by the underlying
+            S3 storage.
         """
         return self.s3_storage.upload_file(model_file_path, message_key)
 
     def download_model_file(self, message_key, model_file_path):
         """
-        this is used for Mobile Platform (MNN)
-        :param message_key:
-        :param model_file_path:
-        :return:
+        Downloads a model file from S3 storage for Mobile Platform (MNN).
+
+        Args:
+            message_key (str): The key identifying the model to be downloaded from S3.
+            model_file_path (str): The local file path where the downloaded model will be saved.
+
+        Returns:
+            None
         """
         self.s3_storage.download_file(message_key, model_file_path)
 
     def write_model(self, message_key, model):
+        """
+        Writes a model object to S3 storage.
+
+        Args:
+            message_key (str): The key to identify the stored model in S3.
+            model: The model object to be stored.
+
+        Returns:
+            None
+        """
         self.s3_storage.write_model(message_key, model)
 
     def read_model(self, message_key):
+        """
+        Reads a model object from S3 storage.
+
+        Args:
+            message_key (str): The key identifying the model to be read from S3.
+
+        Returns:
+            object: The model object read from S3.
+        """
         return self.s3_storage.read_model(message_key)
 
     def upload_file(self, src_local_path, dest_s3_path):
+        """
+        Uploads a file from the local system to S3 storage.
+
+        Args:
+            src_local_path (str): The local file path of the file to be uploaded.
+            dest_s3_path (str): The S3 destination path for the uploaded file.
+
+        Returns:
+            str: The URL of the uploaded file, as returned by the underlying
+            S3 storage.
+        """
         return self.s3_storage.upload_file(src_local_path, dest_s3_path)
 
     def download_file(self, path_s3, path_local):
+        """
+        Downloads a file from S3 storage to the local system.
+
+        Args:
+            path_s3 (str): The S3 path of the file to be downloaded.
+            path_local (str): The local file path where the downloaded file will be saved.
+
+        Returns:
+            None
+        """
         self.s3_storage.download_file(path_s3, path_local)
 
     def delete_s3_zip(self, path_s3):
+        """
+        Deletes a ZIP file from S3 storage.
+
+        Args:
+            path_s3 (str): The S3 path of the ZIP file to be deleted.
+
+        Returns:
+            None
+        """
         self.s3_storage.delete_s3_zip(path_s3)
 
     def set_config_from_file(self, config_file_path):
+        """
+        Sets the S3 configuration from a configuration file.
+
+        Args:
+            config_file_path (str): The path to the S3 configuration file.
+
+        Returns:
+            None
+        """
         self.s3_storage.set_config_from_file(config_file_path)
 
     def set_config_from_objects(self, s3_config):
+        """
+        Sets the S3 configuration from configuration objects.
+
+        Args:
+            s3_config: Configuration objects for S3 storage.
+
+        Returns:
+            None
+        """
         self.s3_storage.set_config_from_objects(s3_config)
diff --git a/python/fedml/core/distributed/communication/s3/utils.py b/python/fedml/core/distributed/communication/s3/utils.py
index a92d5b8aaa..00be45480b 100644
--- a/python/fedml/core/distributed/communication/s3/utils.py
+++ b/python/fedml/core/distributed/communication/s3/utils.py
@@ -5,19 +5,23 @@
 
 def load_params_from_tf(py_model:nn.Module, tf_model:list):
     """
-    Load and update the parameters from tensorflow.js to pytorch nn.Module
+    Load and update the parameters from TensorFlow.js to PyTorch nn.Module.
 
     Args:
-        py_model: An nn.Moudule network structure from pytorch
-        tf_module: A list read from JSON file which stored the meta data of tensorflow.js model
-        (length is number of layers, and has two keys in each layer, 'model' and 'params' respectively)
+        py_model (nn.Module): A PyTorch neural network structure.
+        tf_model (list): A list read from a JSON file containing metadata for the TensorFlow.js model.
 
     Returns:
-        An updated nn.Module network structure
+        nn.Module: An updated PyTorch neural network structure.
 
     Raises:
-        Exception: Certain layer structure is not aligned
-        KeyError: Model layer is not aligned
+        Exception: If certain layer structures do not align between PyTorch and TensorFlow.js.
+        KeyError: If a model layer is not aligned.
+
+    This function loads and updates the parameters from a TensorFlow.js model to a PyTorch nn.Module.
+    It compares layer names between the two models and assigns the TensorFlow.js parameters to the
+    corresponding layers in the PyTorch model.
+
     """
     state_dict = py_model.state_dict()
     py_layers = list(state_dict.keys())
@@ -41,6 +45,22 @@ def load_params_from_tf(py_model:nn.Module, tf_model:list):
         raise TypeError("The model structure of pytorch and tensorflow.js is not aligned! Cannot transfer parameters accordingly.")
 
 def process_state_dict(state_dict):
+    """
+    Process a PyTorch state dictionary to convert it into a plain Python dictionary.
+
+    Args:
+        state_dict (dict): A PyTorch state dictionary containing model parameters.
+
+    Returns:
+        dict: A Python dictionary where keys are parameter names and values are
+        nested Python lists representing the parameter values.
+
+    This function takes a PyTorch state dictionary, which typically contains the
+    parameters of a neural network model, and converts it into a Python dictionary.
+    Each key in the resulting dictionary corresponds to a parameter's name, and the
+    corresponding value is a nested Python list (the tensor converted via tolist())
+    containing the parameter's values.
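+
+    Example (illustrative):
+        >>> import torch
+        >>> model = torch.nn.Linear(2, 1)
+        >>> params = process_state_dict(model.state_dict())
+        >>> sorted(params.keys())
+        ['bias', 'weight']
+        >>> type(params["weight"])
+        <class 'list'>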
+ + """ lr_py = {} for key, value in state_dict.items(): lr_py[key] = value.cpu().detach().numpy().tolist() @@ -48,12 +68,31 @@ def process_state_dict(state_dict): class LogisticRegression(torch.nn.Module): - def __init__(self, input_dim, output_dim): - super(LogisticRegression, self).__init__() - self.linear = torch.nn.Linear(input_dim, output_dim) - def forward(self, x): - outputs = torch.sigmoid(self.linear(x)) - return outputs + def __init__(self, input_dim, output_dim): + """ + Initialize a logistic regression model. + + Args: + input_dim (int): The input dimension. + output_dim (int): The output dimension. + + """ + super(LogisticRegression, self).__init__() + self.linear = torch.nn.Linear(input_dim, output_dim) + + def forward(self, x): + """ + Forward pass of the logistic regression model. + + Args: + x (Tensor): Input tensor. + + Returns: + Tensor: Output tensor after applying the sigmoid function. + + """ + outputs = torch.sigmoid(self.linear(x)) + return outputs class CNN_WEB(nn.Module): diff --git a/python/fedml/core/distributed/communication/trpc/trpc_comm_manager.py b/python/fedml/core/distributed/communication/trpc/trpc_comm_manager.py index 0cbe0bccf1..4f91223213 100644 --- a/python/fedml/core/distributed/communication/trpc/trpc_comm_manager.py +++ b/python/fedml/core/distributed/communication/trpc/trpc_comm_manager.py @@ -20,6 +20,19 @@ class TRPCCommManager(BaseCommunicationManager): def __init__(self, trpc_master_config_path, process_id=0, world_size=0, args=None): + """ + Initialize a TRPC communication manager. + + Args: + trpc_master_config_path (str): Path to the TRPC master configuration file. + process_id (int): The ID of the current process. + world_size (int): The total number of processes in the world. + args (Optional): Additional arguments. + + Returns: + None + """ + logging.info("using TRPC backend") with open(trpc_master_config_path, newline="") as csv_file: csv_reader = csv.reader(csv_file) @@ -40,19 +53,33 @@ def __init__(self, trpc_master_config_path, process_id=0, world_size=0, args=Non logging.info(f"Worker rank {process_id} initializing RPC") - self.trpc_servicer = TRPCCOMMServicer(master_address, master_port, self.world_size, process_id) + self.trpc_servicer = TRPCCOMMServicer( + master_address, master_port, self.world_size, process_id) logging.info(os.getcwd()) os.environ["MASTER_ADDR"] = self.master_address os.environ["MASTER_PORT"] = self.master_port - self._init_torch_rpc_tp(master_address, master_port, process_id, self.world_size) + self._init_torch_rpc_tp( + master_address, master_port, process_id, self.world_size) self.is_running = True logging.info("server started. master address: " + str(master_address)) def _init_torch_rpc_tp( self, master_addr, master_port, worker_idx, worker_num, ): + """ + Initialize the Torch RPC using TensorPipe backend. + + Args: + master_addr (str): The address of the RPC master. + master_port (str): The port of the RPC master. + worker_idx (int): The index of the current worker. + worker_num (int): The total number of workers. + + Returns: + None + """ # https://github.com/pytorch/pytorch/issues/55615 # [BC-Breaking][RFC] Retire ProcessGroup Backend for RPC #55615 str_init_method = "tcp://" + str(master_addr) + ":" + str(master_port) @@ -73,6 +100,15 @@ def _init_torch_rpc_tp( logging.info("_init_torch_rpc_tp finished.") def send_message(self, msg: Message): + """ + Send a message to the specified receiver. + + Args: + msg (Message): The message to be sent. 
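+
+        Example (illustrative; the Message(type, sender_id, receiver_id)
+        construction pattern is taken from other call sites in this codebase):
+            >>> msg = Message("test_message", 0, 1)
+            >>> comm_manager.send_message(msg)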
+ + Returns: + None + """ receiver_id = msg.get_receiver_id() logging.info("sending message to {}".format(receiver_id)) @@ -82,21 +118,52 @@ def send_message(self, msg: Message): rpc.rpc_sync( WORKER_NAME.format(receiver_id), TRPCCOMMServicer.sendMessage, args=(self.process_id, msg), ) - MLOpsProfilerEvent.log_to_wandb({"Comm/send_delay": time.time() - tick}) + MLOpsProfilerEvent.log_to_wandb( + {"Comm/send_delay": time.time() - tick}) logging.debug("sent") def add_observer(self, observer: Observer): + """ + Add an observer to the communication manager. + + Args: + observer (Observer): The observer to be added. + + Returns: + None + """ self._observers.append(observer) def remove_observer(self, observer: Observer): + """ + Remove an observer from the communication manager. + + Args: + observer (Observer): The observer to be removed. + + Returns: + None + """ self._observers.remove(observer) def handle_receive_message(self): + """ + Handle receiving messages in a separate thread. + + Returns: + None + """ thread = threading.Thread(target=self.message_handling_subroutine) thread.start() self._notify_connection_ready() def message_handling_subroutine(self): + """ + Subroutine for handling received messages. + + Returns: + None + """ start_listening_time = time.time() MLOpsProfilerEvent.log_to_wandb({"ListenStart": start_listening_time}) while self.is_running: @@ -105,21 +172,44 @@ def message_handling_subroutine(self): message_handler_start_time = time.time() msg = self.trpc_servicer.message_q.get() self.notify(msg) - MLOpsProfilerEvent.log_to_wandb({"BusyTime": time.time() - message_handler_start_time}) + MLOpsProfilerEvent.log_to_wandb( + {"BusyTime": time.time() - message_handler_start_time}) lock.release() - MLOpsProfilerEvent.log_to_wandb({"TotalTime": time.time() - start_listening_time}) + MLOpsProfilerEvent.log_to_wandb( + {"TotalTime": time.time() - start_listening_time}) return def stop_receive_message(self): + """ + Stop receiving messages and shutdown the communication manager. + + Returns: + None + """ rpc.shutdown() self.is_running = False def notify(self, message: Message): + """ + Notify observers about a received message. + + Args: + message (Message): The received message. + + Returns: + None + """ msg_type = message.get_type() for observer in self._observers: observer.receive_message(msg_type, message) def _notify_connection_ready(self): + """ + Notify observers that the connection is ready. + + Returns: + None + """ msg_params = Message() msg_params.sender_id = self.rank msg_params.receiver_id = self.rank diff --git a/python/fedml/core/distributed/communication/trpc/trpc_server.py b/python/fedml/core/distributed/communication/trpc/trpc_server.py index 96d0969a38..ef41649f18 100644 --- a/python/fedml/core/distributed/communication/trpc/trpc_server.py +++ b/python/fedml/core/distributed/communication/trpc/trpc_server.py @@ -9,6 +9,18 @@ class TRPCCOMMServicer: _instance = None def __new__(cls, master_address, master_port, client_num, client_id): + """ + Create a new instance of the TRPCCOMMServicer class if it does not exist, otherwise return the existing instance. + + Args: + master_address (str): The address of the RPC master. + master_port (str): The port of the RPC master. + client_num (int): The total number of clients. + client_id (int): The ID of the current client. + + Returns: + TRPCCOMMServicer: An instance of the TRPCCOMMServicer class. 
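+
+        Example (illustrative; because the servicer is a singleton, repeated
+        construction is expected to return the same object):
+            >>> a = TRPCCOMMServicer("127.0.0.1", "29500", 2, 0)
+            >>> b = TRPCCOMMServicer("127.0.0.1", "29500", 2, 0)
+            >>> a is b
+            True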
+ """ cls.master_address = None cls.master_port = None cls.client_num = None @@ -31,6 +43,16 @@ def __new__(cls, master_address, master_port, client_num, client_id): return cls._instance def receiveMessage(self, client_id, message): + """ + Receive a message from another client. + + Args: + client_id (int): The ID of the client sending the message. + message (Message): The received message. + + Returns: + str: A response indicating that the message was received. + """ logging.info( "client_{} got something from client_{}".format( self.client_id, @@ -51,4 +73,14 @@ def receiveMessage(self, client_id, message): @classmethod def sendMessage(cls, clint_id, message): + """ + Send a message to another client. + + Args: + clint_id (int): The ID of the target client. + message (Message): The message to be sent. + + Returns: + None + """ cls._instance.receiveMessage(clint_id, message) \ No newline at end of file diff --git a/python/fedml/core/distributed/communication/trpc/utils.py b/python/fedml/core/distributed/communication/trpc/utils.py index 636750edf2..f8c0b86c85 100644 --- a/python/fedml/core/distributed/communication/trpc/utils.py +++ b/python/fedml/core/distributed/communication/trpc/utils.py @@ -5,8 +5,26 @@ # Generate Device Map for Cuda RPC def set_device_map(options, worker_idx, device_list): + """ + Set the device mapping for PyTorch RPC communication between workers. + + Args: + options (rpc.TensorPipeRpcBackendOptions): The RPC backend options to configure. + worker_idx (int): The index of the current worker. + device_list (list of str): A list of device identifiers for all workers. + + Example: + Suppose you have two workers with GPUs, and `device_list` is ['cuda:0', 'cuda:1']. + If `worker_idx` is 0, this function will set the device mapping for worker 0 as follows: + {WORKER_NAME.format(1): 'cuda:1'} to communicate with worker 1 using 'cuda:1'. + + Returns: + None + """ local_device = device_list[worker_idx] for index, remote_device in enumerate(device_list): - logging.warn(f"Setting device map for client {index} as {remote_device}") + logging.warn( + f"Setting device map for client {index} as {remote_device}") if index != worker_idx: - options.set_device_map(WORKER_NAME.format(index), {local_device: remote_device}) \ No newline at end of file + options.set_device_map(WORKER_NAME.format( + index), {local_device: remote_device}) diff --git a/python/fedml/core/distributed/communication/utils.py b/python/fedml/core/distributed/communication/utils.py index 8bc610309b..521b5cff1d 100755 --- a/python/fedml/core/distributed/communication/utils.py +++ b/python/fedml/core/distributed/communication/utils.py @@ -3,6 +3,14 @@ def log_communication_tick(sender, receiver, timestamp=None): + """ + Log a benchmark tick event from sender to receiver. + + Args: + sender (str): Sender's identifier. + receiver (str): Receiver's identifier. + timestamp (float): Timestamp for the event (default is current time). + """ logging.info( "--Benchmark tick from {} to {} at {}".format( sender, receiver, timestamp or time() @@ -11,6 +19,14 @@ def log_communication_tick(sender, receiver, timestamp=None): def log_communication_tock(sender, receiver, timestamp=None): + """ + Log a benchmark tock event from sender to receiver. + + Args: + sender (str): Sender's identifier. + receiver (str): Receiver's identifier. + timestamp (float): Timestamp for the event (default is current time). 
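+
+    Example (illustrative; tick and tock calls are typically paired to
+    bracket a single transmission):
+        >>> log_communication_tick("client_1", "server_0")
+        >>> log_communication_tock("client_1", "server_0")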
+ """ logging.info( "--Benchmark tock from {} to {} at {}".format( sender, receiver, timestamp or time() @@ -19,6 +35,14 @@ def log_communication_tock(sender, receiver, timestamp=None): def log_round_start(client_idx, round_number, timestamp=None): + """ + Log the start of a benchmark round for a client. + + Args: + client_idx (int): Client's index or identifier. + round_number (int): Round number. + timestamp (float): Timestamp for the event (default is current time). + """ logging.info( "--Benchmark start round {} for {} at {}".format( round_number, client_idx, timestamp or time() @@ -27,6 +51,14 @@ def log_round_start(client_idx, round_number, timestamp=None): def log_round_end(client_idx, round_number, timestamp=None): + """ + Log the end of a benchmark round for a client. + + Args: + client_idx (int): Client's index or identifier. + round_number (int): Round number. + timestamp (float): Timestamp for the event (default is current time). + """ logging.info( "--Benchmark end round {} for {} at {}".format( round_number, client_idx, timestamp or time() diff --git a/python/fedml/core/distributed/crypto/crypto_api.py b/python/fedml/core/distributed/crypto/crypto_api.py index e7283aca57..284a39921b 100644 --- a/python/fedml/core/distributed/crypto/crypto_api.py +++ b/python/fedml/core/distributed/crypto/crypto_api.py @@ -6,31 +6,42 @@ def export_public_key(private_key_hex: str) -> bytes: - """Export public key for contract join request. + """ + Export the public key for a contract join request. Args: - private_key: hex string representing private key + private_key_hex (str): Hex string representing the private key. Returns: - 32 bytes representing public key + bytes: 32 bytes representing the public key. """ def _hex_to_bytes(hex: str) -> bytes: + """ + Convert a hex string to bytes. + + Args: + hex (str): Hex string. + + Returns: + bytes: Bytes representation of the hex string. + """ return bytes.fromhex(hex[2:] if hex[:2] == "0x" else hex) return bytes(PrivateKey(_hex_to_bytes(private_key_hex)).public_key) def encrypt_nacl(public_key: bytes, data: bytes) -> bytes: - """Encryption function using NaCl box compatible with MetaMask + """ + Encrypt data using NaCl box compatible with MetaMask. For implementation used in MetaMask look into: https://github.com/MetaMask/eth-sig-util Args: - public_key: public key of recipient - data: message data + public_key (bytes): Public key of the recipient. + data (bytes): Message data to be encrypted. Returns: - encrypted data + bytes: Encrypted data. """ emph_key = PrivateKey.generate() enc_box = Box(emph_key, PublicKey(public_key)) @@ -57,7 +68,17 @@ def decrypt_nacl(private_key: bytes, data: bytes) -> bytes: def get_current_secret(secret: bytes, entry_key_turn: int, key_turn: int) -> bytes: - """Calculate shared secret at current state.""" + """ + Calculate the shared secret at the current state. + + Args: + secret (bytes): Initial secret. + entry_key_turn (int): Entry key turn. + key_turn (int): Key turn. + + Returns: + bytes: The calculated shared secret. 
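+
+    Example (illustrative; each key turn applies one round of SHA-256):
+        >>> import hashlib
+        >>> s0 = b"initial secret"
+        >>> s2 = get_current_secret(s0, 0, 2)
+        >>> s2 == hashlib.sha256(hashlib.sha256(s0).digest()).digest()
+        True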
+ """ for _ in range(entry_key_turn, key_turn): secret = hashlib.sha256(secret).digest() return secret diff --git a/python/fedml/core/distributed/distributed_storage/theta_storage/theta_storage.py b/python/fedml/core/distributed/distributed_storage/theta_storage/theta_storage.py index 2add92cd7c..ff894a9e9e 100644 --- a/python/fedml/core/distributed/distributed_storage/theta_storage/theta_storage.py +++ b/python/fedml/core/distributed/distributed_storage/theta_storage/theta_storage.py @@ -14,19 +14,45 @@ class ThetaStorage: - def __init__( - self, thetasotre_config): + def __init__(self, thetasotre_config): + """ + Initialize a ThetaStorage instance. + + Args: + thetasotre_config (dict): Configuration parameters for ThetaStore. + + Attributes: + ipfs_config (dict): ThetaStore configuration dictionary. + store_home_dir (str): Home directory for ThetaStore. + ipfs_upload_uri (str): URI for uploading files to ThetaStore. + ipfs_download_uri (str): URI for downloading files from ThetaStore. + + """ self.ipfs_config = thetasotre_config - self.store_home_dir = thetasotre_config.get("store_home_dir", "~/edge-store-playground") + self.store_home_dir = thetasotre_config.get( + "store_home_dir", "~/edge-store-playground") if str(self.store_home_dir).startswith("~"): home_dir = expanduser("~") - new_store_dir = str(self.store_home_dir).replace('\\', os.sep).replace('/', os.sep) + new_store_dir = str(self.store_home_dir).replace( + '\\', os.sep).replace('/', os.sep) strip_dir = new_store_dir.lstrip('~').lstrip(os.sep) self.store_home_dir = os.path.join(home_dir, strip_dir) - self.ipfs_upload_uri = thetasotre_config.get("upload_uri", "http://localhost:19888/rpc") - self.ipfs_download_uri = thetasotre_config.get("download_uri", "http://localhost:19888/rpc") + self.ipfs_upload_uri = thetasotre_config.get( + "upload_uri", "http://localhost:19888/rpc") + self.ipfs_download_uri = thetasotre_config.get( + "download_uri", "http://localhost:19888/rpc") def write_model(self, model): + """ + Serialize and upload a machine learning model to ThetaStore. + + Args: + model: The machine learning model to be uploaded. + + Returns: + str: The IPFS key where the model is stored. + + """ pickle_dump_start_time = time.time() model_pkl = pickle.dumps(model) secret_key = Context().get("ipfs_secret_key") @@ -43,7 +69,17 @@ def write_model(self, model): ) return model_url - def read_model(self, message_key): + def read_model(self, message_key): + """ + Download and deserialize a machine learning model from ThetaStore. + + Args: + message_key: The ThetaStore key of the model to be retrieved. + + Returns: + model: The deserialized machine learning model. + + """ message_handler_start_time = time.time() model_pkl, _ = self.storage_ipfs_download_file(message_key) secret_key = Context().get("ipfs_secret_key") @@ -61,13 +97,15 @@ def read_model(self, message_key): return model def storage_ipfs_upload_file(self, file_obj): - """Upload file to IPFS using web3.storage. + """ + Upload a file to ThetaStore using Theta's RPC. Args: - file_obj: file-like object in byte mode + file_obj: A file-like object in byte mode. Returns: - Response: (Successful, cid or error message) + tuple: A tuple containing a boolean indicating success, and either the ThetaStore key or an error message. 
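+
+        Example (illustrative sketch; theta_storage is an assumed instance of
+        this class and file_bytes is a bytes object):
+            >>> ok, key_or_error = theta_storage.storage_ipfs_upload_file(file_bytes)
+            >>> if ok:
+            ...     print("stored under key:", key_or_error)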
+ """ # Request: upload a file # curl -X POST -H 'Content-Type: application/json' --data '{"jsonrpc":"2.0","method":"edgestore.PutFile","params":[{"path": "theta-edge-store-demos/demos/image/data/smiley_explorer.png"}],"id":1}' http://localhost:19888/rpc @@ -89,10 +127,10 @@ def storage_ipfs_upload_file(self, file_obj): with open(file_path, "wb") as file_handle: file_handle.write(file_obj) - request_data = {"jsonrpc":"2.0", - "method":"edgestore.PutFile", - "params":[{"path": file_path}], - "id":1} + request_data = {"jsonrpc": "2.0", + "method": "edgestore.PutFile", + "params": [{"path": file_path}], + "id": 1} res = httpx.post( self.ipfs_upload_uri, headers={"Content-Type": "application/json"}, @@ -133,10 +171,10 @@ def storage_ipfs_download_file(self, ipfs_cid, output_path=None): # } # } - request_data = {"jsonrpc":"2.0", - "method":"edgestore.GetFile", - "params":[{"key": ipfs_cid}], - "id":1} + request_data = {"jsonrpc": "2.0", + "method": "edgestore.GetFile", + "params": [{"key": ipfs_cid}], + "id": 1} res = httpx.post( self.ipfs_download_uri, headers={"Content-Type": "application/json"}, @@ -154,7 +192,8 @@ def storage_ipfs_download_file(self, ipfs_cid, output_path=None): if download_path is None: return False, "Failed to download file(path is none)." else: - download_path = os.path.join(self.store_home_dir, download_path) + download_path = os.path.join( + self.store_home_dir, download_path) output_file_obj = None file_content = None diff --git a/python/fedml/core/distributed/distributed_storage/web3_storage/web3_storage.py b/python/fedml/core/distributed/distributed_storage/web3_storage/web3_storage.py index f5f5b1a299..a86d622fd6 100644 --- a/python/fedml/core/distributed/distributed_storage/web3_storage/web3_storage.py +++ b/python/fedml/core/distributed/distributed_storage/web3_storage/web3_storage.py @@ -10,13 +10,34 @@ class Web3Storage: - def __init__( - self, ipfs_config): + def __init__(self, ipfs_config): + """ + Initialize a Web3Storage instance. + + Args: + ipfs_config (dict): Configuration parameters for IPFS. + + Attributes: + ipfs_config (dict): IPFS configuration dictionary. + ipfs_upload_uri (str): URI for uploading files to IPFS. + ipfs_download_uri (str): URI for downloading files from IPFS. + + """ self.ipfs_config = ipfs_config self.ipfs_upload_uri = ipfs_config.get("upload_uri", "https://api.web3.storage/upload") self.ipfs_download_uri = ipfs_config.get("download_uri", "ipfs.w3s.link2") def write_model(self, model): + """ + Serialize and upload a machine learning model to IPFS. + + Args: + model: The machine learning model to be uploaded. + + Returns: + str: The IPFS URL where the model is stored. + + """ pickle_dump_start_time = time.time() model_pkl = pickle.dumps(model) secret_key = Context().get("ipfs_secret_key") @@ -34,6 +55,16 @@ def write_model(self, model): return model_url def read_model(self, message_key): + """ + Download and deserialize a machine learning model from IPFS. + + Args: + message_key: The IPFS key of the model to be retrieved. + + Returns: + model: The deserialized machine learning model. 
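+
+        Example (illustrative round trip; it assumes the value returned by
+        write_model can be used as the read key):
+            >>> key = web3_storage.write_model(model)
+            >>> restored = web3_storage.read_model(key)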
+ + """ message_handler_start_time = time.time() model_pkl, _ = self.storage_ipfs_download_file(message_key) secret_key = Context().get("ipfs_secret_key") diff --git a/python/fedml/core/distributed/fedml_comm_manager.py b/python/fedml/core/distributed/fedml_comm_manager.py index 5959f175ac..9a40f7398b 100644 --- a/python/fedml/core/distributed/fedml_comm_manager.py +++ b/python/fedml/core/distributed/fedml_comm_manager.py @@ -9,7 +9,55 @@ class FedMLCommManager(Observer): + """ + Communication manager for Federated Machine Learning (FedML). + + Args: + args: Command-line arguments. + comm: The communication backend. + rank: The rank of the current node. + size: The total number of nodes in the communication group. + backend: The communication backend used (e.g., "MPI", "MQTT", "MQTT_S3"). + + Attributes: + args: Command-line arguments. + size: The total number of nodes in the communication group. + rank: The rank of the current node. + backend: The communication backend used. + comm: The communication object. + com_manager: The communication manager. + message_handler_dict: A dictionary to register message handlers. + + Methods: + register_comm_manager(comm_manager): Register a communication manager. + run(): Start the communication manager. + get_sender_id(): Get the sender's ID. + receive_message(msg_type, msg_params): Receive a message and handle it. + send_message(message): Send a message. + send_message_json(topic_name, json_message): Send a JSON message. + register_message_receive_handlers(): Register message receive handlers. + register_message_receive_handler(msg_type, handler_callback_func): Register a message receive handler. + finish(): Finish the communication manager. + get_training_mqtt_s3_config(): Get MQTT and S3 configurations for training. + get_training_mqtt_web3_config(): Get MQTT and Web3 configurations for training. + get_training_mqtt_thetastore_config(): Get MQTT and Thetastore configurations for training. + _init_manager(): Initialize the communication manager based on the selected backend. + """ + def __init__(self, args, comm=None, rank=0, size=0, backend="MPI"): + """ + Initialize the FedMLCommManager. + + Args: + args: Command-line arguments. + comm: The communication backend. + rank: The rank of the current node. + size: The total number of nodes in the communication group. + backend: The communication backend used (e.g., "MPI", "MQTT", "MQTT_S3"). + + Returns: + None + """ self.args = args self.size = size self.rank = int(rank) @@ -20,21 +68,54 @@ def __init__(self, args, comm=None, rank=0, size=0, backend="MPI"): self._init_manager() def register_comm_manager(self, comm_manager: BaseCommunicationManager): + """ + Register a communication manager. + + Args: + comm_manager (BaseCommunicationManager): The communication manager to register. + + Returns: + None + """ self.com_manager = comm_manager def run(self): + """ + Start the communication manager. + + Returns: + None + """ self.register_message_receive_handlers() logging.info("running") self.com_manager.handle_receive_message() logging.info("finished...") def get_sender_id(self): + """ + Get the sender's ID. + + Returns: + int: The sender's ID (rank). + + """ return self.rank def receive_message(self, msg_type, msg_params) -> None: + """ + Receive a message and handle it. + + Args: + msg_type (str): The type of the received message. + msg_params: Parameters associated with the received message. 
+ + Returns: + None + """ if msg_params.get_sender_id() == msg_params.get_receiver_id(): - logging.info("communication backend is alive (loop_forever, sender 0 to receiver 0)") + logging.info( + "communication backend is alive (loop_forever, sender 0 to receiver 0)") else: logging.info( "receive_message. msg_type = %s, sender_id = %d, receiver_id = %d" @@ -51,19 +132,64 @@ def receive_message(self, msg_type, msg_params) -> None: ) def send_message(self, message): + """ + Send a message. + + Args: + message: The message to send. + + Returns: + None + """ self.com_manager.send_message(message) def send_message_json(self, topic_name, json_message): + """ + Send a JSON message. + + Args: + topic_name (str): The name of the message topic. + json_message: The JSON message to send. + + Returns: + None + """ self.com_manager.send_message_json(topic_name, json_message) @abstractmethod def register_message_receive_handlers(self) -> None: + """ + Register message receive handlers. + + This method should be implemented in derived classes. + + Returns: + None + """ pass def register_message_receive_handler(self, msg_type, handler_callback_func): + """ + Register a message receive handler. + + Args: + msg_type (str): The type of the message to handle. + handler_callback_func: The callback function to handle the message. + + Returns: + None + """ self.message_handler_dict[msg_type] = handler_callback_func def finish(self): + """ + Finish the communication manager. + + Depending on the backend used, this method may perform specific actions to terminate the communication. + + Returns: + None + """ logging.info("__finish") if self.backend == "MPI": from mpi4py import MPI @@ -81,6 +207,13 @@ def finish(self): self.com_manager.stop_receive_message() def get_training_mqtt_s3_config(self): + """ + Get MQTT and S3 configurations for training. + + Returns: + tuple: A tuple containing MQTT configuration and S3 configuration. 
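+
+        Note:
+            Customized configurations supplied through args take precedence;
+            any piece that is still missing is fetched from the MLOps cloud
+            configuration service.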
+ + """ mqtt_config = None s3_config = None if hasattr(self.args, "customized_training_mqtt_config") and self.args.customized_training_mqtt_config != "": @@ -88,7 +221,8 @@ def get_training_mqtt_s3_config(self): if hasattr(self.args, "customized_training_s3_config") and self.args.customized_training_s3_config != "": s3_config = self.args.customized_training_s3_config if mqtt_config is None or s3_config is None: - mqtt_config_from_cloud, s3_config_from_cloud = MLOpsConfigs.get_instance(self.args).fetch_configs() + mqtt_config_from_cloud, s3_config_from_cloud = MLOpsConfigs.get_instance( + self.args).fetch_configs() if mqtt_config is None: mqtt_config = mqtt_config_from_cloud if s3_config is None: @@ -104,7 +238,8 @@ def get_training_mqtt_web3_config(self): if hasattr(self.args, "customized_training_web3_config") and self.args.customized_training_web3_config != "": web3_config = self.args.customized_training_web3_config if mqtt_config is None or web3_config is None: - mqtt_config_from_cloud, web3_config_from_cloud = MLOpsConfigs.get_instance(self.args).fetch_web3_configs() + mqtt_config_from_cloud, web3_config_from_cloud = MLOpsConfigs.get_instance( + self.args).fetch_web3_configs() if mqtt_config is None: mqtt_config = mqtt_config_from_cloud if web3_config is None: @@ -120,7 +255,8 @@ def get_training_mqtt_thetastore_config(self): if hasattr(self.args, "customized_training_thetastore_config") and self.args.customized_training_thetastore_config != "": thetastore_config = self.args.customized_training_thetastore_config if mqtt_config is None or thetastore_config is None: - mqtt_config_from_cloud, thetastore_config_from_cloud = MLOpsConfigs.get_instance(self.args).fetch_thetastore_configs() + mqtt_config_from_cloud, thetastore_config_from_cloud = MLOpsConfigs.get_instance( + self.args).fetch_thetastore_configs() if mqtt_config is None: mqtt_config = mqtt_config_from_cloud if thetastore_config is None: @@ -133,7 +269,8 @@ def _init_manager(self): if self.backend == "MPI": from .communication.mpi.com_manager import MpiCommunicationManager - self.com_manager = MpiCommunicationManager(self.comm, self.rank, self.size) + self.com_manager = MpiCommunicationManager( + self.comm, self.rank, self.size) elif self.backend == "MQTT_S3": from .communication.mqtt_s3.mqtt_s3_multi_clients_comm_manager import MqttS3MultiClientsCommManager @@ -202,7 +339,8 @@ def _init_manager(self): ) else: if self.com_manager is None: - raise Exception("no such backend: {}. Please check the comm_backend spelling.".format(self.backend)) + raise Exception( + "no such backend: {}. Please check the comm_backend spelling.".format(self.backend)) else: logging.info("using self-defined communication backend") diff --git a/python/fedml/core/distributed/flow/fedml_executor.py b/python/fedml/core/distributed/flow/fedml_executor.py index 36b44cb4e2..9b7f45c34b 100644 --- a/python/fedml/core/distributed/flow/fedml_executor.py +++ b/python/fedml/core/distributed/flow/fedml_executor.py @@ -2,32 +2,128 @@ class FedMLExecutor(abc.ABC): + """ + Abstract base class for Federated Machine Learning Executors. + + This class defines the basic structure and methods for a FedML executor. + + Args: + id (str): Identifier for the executor. + neighbor_id_list (List[str]): List of neighbor executor IDs. + + Attributes: + id (str): Identifier for the executor. + neighbor_id_list (List[str]): List of neighbor executor IDs. + params (Any): Parameters associated with the executor. + context (Any): Context or environment information. 
+
+    Methods:
+        get_context():
+            Get the context or environment information associated with the executor.
+
+        set_context(context):
+            Set the context or environment information for the executor.
+
+        get_params():
+            Get the parameters associated with the executor.
+
+        set_params(params):
+            Set the parameters for the executor.
+
+        set_id(id):
+            Set the identifier for the executor.
+
+        set_neighbor_id_list(neighbor_id_list):
+            Set the list of neighbor executor IDs.
+
+        get_id():
+            Get the identifier of the executor.
+
+        get_neighbor_id_list():
+            Get the list of neighbor executor IDs.
+    """
+
     def __init__(self, id, neighbor_id_list):
+        """
+        Initialize a FedMLExecutor.
+
+        Args:
+            id (str): Identifier for the executor.
+            neighbor_id_list (List[str]): List of neighbor executor IDs.
+        """
         self.id = id
         self.neighbor_id_list = neighbor_id_list
 
         self.params = None
         self.context = None
 
     def get_context(self):
+        """
+        Get the context or environment information associated with the executor.
+
+        Returns:
+            Any: The context or environment information.
+        """
         return self.context
 
     def set_context(self, context):
+        """
+        Set the context or environment information for the executor.
+
+        Args:
+            context (Any): The context or environment information.
+        """
         self.context = context
 
     def get_params(self):
+        """
+        Get the parameters associated with the executor.
+
+        Returns:
+            Any: The parameters.
+        """
         return self.params
 
     def set_params(self, params):
+        """
+        Set the parameters for the executor.
+
+        Args:
+            params (Any): The parameters.
+        """
         self.params = params
 
     def set_id(self, id):
+        """
+        Set the identifier for the executor.
+
+        Args:
+            id (str): The identifier.
+        """
         self.id = id
 
     def set_neighbor_id_list(self, neighbor_id_list):
+        """
+        Set the list of neighbor executor IDs.
+
+        Args:
+            neighbor_id_list (List[str]): List of neighbor executor IDs.
+        """
         self.neighbor_id_list = neighbor_id_list
 
     def get_id(self):
+        """
+        Get the identifier of the executor.
+
+        Returns:
+            str: The identifier.
+        """
         return self.id
 
     def get_neighbor_id_list(self):
+        """
+        Get the list of neighbor executor IDs.
+
+        Returns:
+            List[str]: List of neighbor executor IDs.
+        """
         return self.neighbor_id_list
diff --git a/python/fedml/core/distributed/flow/fedml_flow.py b/python/fedml/core/distributed/flow/fedml_flow.py
index 0bf7dabb5f..1ab2ba5e5f 100644
--- a/python/fedml/core/distributed/flow/fedml_flow.py
+++ b/python/fedml/core/distributed/flow/fedml_flow.py
@@ -18,6 +18,46 @@
 
 class FedMLAlgorithmFlow(FedMLCommManager):
+    """
+    Base class for defining the flow of a federated machine learning algorithm.
+
+    Args:
+        args: Arguments for initializing the algorithm flow.
+        executor (FedMLExecutor): An instance of a FedMLExecutor class to execute tasks within the flow.
+
+    Attributes:
+        ONCE (str): Flow tag indicating that the flow should run once.
+        FINISH (str): Flow tag indicating the end of the flow.
+        executor (FedMLExecutor): An instance of a FedMLExecutor class.
+        executor_cls_name (str): Name of the executor class.
+        flow_index (int): Index to keep track of flow sequences.
+ flow_sequence_original (list): List to store the original flow sequence. + flow_sequence_current_map (dict): Mapping of current flow sequences. + flow_sequence_next_map (dict): Mapping of next flow sequences. + flow_sequence_executed (list): List to store executed flow sequences. + neighbor_node_online_map (dict): Mapping of neighbor node online status. + is_all_neighbor_connected (bool): Flag to indicate if all neighbor nodes are connected. + + Methods: + register_message_receive_handlers(): Register message receive handlers for different message types. + add_flow(flow_name, executor_task, flow_tag=ONCE): Add a flow to the algorithm. + run(): Start running the algorithm flow. + build(): Build the flow sequence and prepare for execution. + _on_ready_to_run_flow(): Handle when the algorithm is ready to run. + _handle_message_received(msg_params): Handle received messages within the flow. + _execute_flow(flow_params, flow_name, executor_task, executor_task_cls_name, flow_tag): Execute a flow task. + __direct_to_next_flow(flow_name, flow_tag): Get the details of the next flow in the sequence. + _send_msg(flow_name, params): Send a message to other nodes. + _handle_flow_finish(msg_params): Handle the finish of the algorithm flow. + __shutdown(): Shutdown the algorithm flow. + _pass_message_locally(flow_name, params): Pass a message to a locally executed flow. + _handle_connection_ready(msg_params): Handle the readiness of the algorithm to run. + _handle_neighbor_report_node_status(msg_params): Handle neighbor nodes reporting their online status. + _handle_neighbor_check_node_status(msg_params): Handle checking of neighbor node status. + _send_message_to_check_neighbor_node_status(receiver_id): Send a message to check neighbor node status. + _send_message_to_report_node_status(receiver_id): Send a message to report node status. + _get_class_that_defined_method(meth): Get the class that defined a method. + """ ONCE = "FLOW_TAG_ONCE" FINISH = "FLOW_TAG_FINISH" @@ -39,6 +79,14 @@ def __init__(self, args, executor: FedMLExecutor): self.is_all_neighbor_connected = False def register_message_receive_handlers(self) -> None: + """ + Register message receive handlers for various message types. + + This method registers message handlers for messages related to the algorithm flow. + + Returns: + None + """ self.register_message_receive_handler(MSG_TYPE_CONNECTION_IS_READY, self._handle_connection_ready) self.register_message_receive_handler( MSG_TYPE_NEIGHBOR_CHECK_NODE_STATUS, self._handle_neighbor_check_node_status, @@ -64,6 +112,17 @@ def register_message_receive_handlers(self) -> None: self.register_message_receive_handler(flow_name, self._handle_message_received) def add_flow(self, flow_name, executor_task: Callable, flow_tag=ONCE): + """ + Add a flow to the algorithm's flow sequence. + + Args: + flow_name (str): Name of the flow. + executor_task (Callable): Callable function representing the task to be executed in the flow. + flow_tag (str): Tag indicating the type of flow (ONCE or FINISH). + + Returns: + None + """ logging.info("flow_name = {}, executor_task = {}".format(flow_name, executor_task)) executor_task_cls_name = self._get_class_that_defined_method(executor_task) @@ -72,9 +131,21 @@ def add_flow(self, flow_name, executor_task: Callable, flow_tag=ONCE): self.flow_index += 1 def run(self): + """ + Start running the algorithm flow. + + Returns: + None + """ super().run() def build(self): + """ + Build the flow sequence and prepare for execution. 
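+
+        Example (illustrative sketch; executor methods such as
+        init_global_model and local_training are assumed to be defined on the
+        registered executors):
+            flow = FedMLAlgorithmFlow(args, executor)
+            flow.add_flow("init_global_model", server.init_global_model)
+            flow.add_flow("local_training", client.local_training)
+            flow.add_flow("server_aggregate", server.server_aggregate,
+                          flow_tag=FedMLAlgorithmFlow.FINISH)
+            flow.build()
+            flow.run()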
+ + Returns: + None + """ logging.info("self.flow_sequence = {}".format(self.flow_sequence_original)) (flow_name, executor_task, executor_task_cls_name, flow_tag,) = self.flow_sequence_original[ len(self.flow_sequence_original) - 1 @@ -114,6 +185,12 @@ def build(self): logging.info("self.flow_sequence_next_map = {}".format(self.flow_sequence_next_map)) def _on_ready_to_run_flow(self): + """ + Handle when the algorithm is ready to run. + + Returns: + None + """ logging.info("#######_on_ready_to_run_flow#######") ( flow_name_current, @@ -127,6 +204,15 @@ def _on_ready_to_run_flow(self): ) def _handle_message_received(self, msg_params): + """ + Handle received messages within the flow. + + Args: + msg_params (Params): Parameters received in the message. + + Returns: + None + """ flow_name = msg_params.get_type() flow_params = Params() @@ -141,6 +227,19 @@ def _handle_message_received(self, msg_params): self._execute_flow(flow_params, flow_name_next, executor_task_next, executor_task_cls_name_next, flow_tag_next) def _execute_flow(self, flow_params, flow_name, executor_task, executor_task_cls_name, flow_tag): + """ + Execute a flow task. + + Args: + flow_params (Params): Parameters for the flow. + flow_name (str): Name of the flow. + executor_task (Callable): Callable function representing the task to be executed. + executor_task_cls_name (str): Name of the executor task's class. + flow_tag (str): Tag indicating the type of flow (ONCE or FINISH). + + Returns: + None + """ logging.info( "\n\n###########_execute_flow (START). flow_name = {}, executor_task name = {}() #######".format( flow_name, executor_task.__name__ @@ -183,6 +282,16 @@ def _execute_flow(self, flow_params, flow_name, executor_task, executor_task_cls self._send_msg(flow_name, params) def __direct_to_next_flow(self, flow_name, flow_tag): + """ + Determine the next flow to execute based on the current flow. + + Args: + flow_name (str): Name of the current flow. + flow_tag (str): Tag indicating the type of flow (ONCE or FINISH). + + Returns: + Tuple: A tuple containing the name, executor task, executor task class name, and flow tag of the next flow. + """ ( flow_name_next, executor_task_next, @@ -197,6 +306,16 @@ def __direct_to_next_flow(self, flow_name, flow_tag): ) def _send_msg(self, flow_name, params: Params): + """ + Send a message to one or more receivers. + + Args: + flow_name (str): Name of the flow. + params (Params): Parameters to be included in the message. + + Returns: + None + """ sender_id = params.get(PARAMS_KEY_SENDER_ID) receiver_id = params.get(PARAMS_KEY_RECEIVER_ID) logging.info("sender_id = {}, receiver_id = {}".format(sender_id, receiver_id)) @@ -211,9 +330,24 @@ def _send_msg(self, flow_name, params: Params): self.send_message(message) def _handle_flow_finish(self, msg_params): + """ + Handle the completion of the algorithm flow. + + Args: + msg_params (Params): Parameters received in the completion message. + + Returns: + None + """ self.__shutdown() def __shutdown(self): + """ + Shutdown the algorithm flow and terminate communication. + + Returns: + None + """ for rid in self.executor.get_neighbor_id_list(): message = Message(MSG_TYPE_FLOW_FINISH, self.executor.get_id(), rid) self.send_message(message) @@ -221,6 +355,16 @@ def __shutdown(self): self.finish() def _pass_message_locally(self, flow_name, params: Params): + """ + Pass a message locally to be handled within the algorithm. + + Args: + flow_name (str): Name of the flow. + params (Params): Parameters to be included in the message. 
+ + Returns: + None + """ sender_id = params.get(PARAMS_KEY_SENDER_ID) receiver_id = params.get(PARAMS_KEY_RECEIVER_ID) logging.info("sender_id = {}, receiver_id = {}".format(sender_id, receiver_id)) @@ -235,6 +379,15 @@ def _pass_message_locally(self, flow_name, params: Params): self._handle_message_received(message) def _handle_connection_ready(self, msg_params): + """ + Handle the readiness of connections with neighbors. + + Args: + msg_params (Params): Parameters received indicating connection readiness. + + Returns: + None + """ if self.is_all_neighbor_connected: return logging.info("_handle_connection_ready") @@ -243,6 +396,15 @@ def _handle_connection_ready(self, msg_params): self._send_message_to_report_node_status(receiver_id) def _handle_neighbor_report_node_status(self, msg_params): + """ + Handle the reporting of neighbor node statuses. + + Args: + msg_params (Params): Parameters received with neighbor node status information. + + Returns: + None + """ sender_id = msg_params.get_sender_id() logging.info( "_handle_neighbor_report_node_status. node_id = {}, neighbor_id = {} is online".format( @@ -262,10 +424,28 @@ def _handle_neighbor_report_node_status(self, msg_params): self._on_ready_to_run_flow() def _handle_neighbor_check_node_status(self, msg_params): + """ + Handle a message to check the status of a neighbor node. + + Args: + msg_params (Params): Parameters received in the check node status message. + + Returns: + None + """ sender_id = msg_params.get_sender_id() self._send_message_to_report_node_status(sender_id) def _send_message_to_check_neighbor_node_status(self, receiver_id): + """ + Send a message to check the status of a neighbor node. + + Args: + receiver_id (int): ID of the receiver neighbor node. + + Returns: + None + """ message = Message(MSG_TYPE_NEIGHBOR_CHECK_NODE_STATUS, self.executor.get_id(), receiver_id) logging.info( "_send_message_to_check_neighbor_node_status. node_id = {}, neighbor_id = {} is online".format( @@ -275,10 +455,28 @@ def _send_message_to_check_neighbor_node_status(self, receiver_id): self.send_message(message) def _send_message_to_report_node_status(self, receiver_id): + """ + Send a message to report the node status to a neighbor node. + + Args: + receiver_id (int): ID of the receiver neighbor node. + + Returns: + None + """ message = Message(MSG_TYPE_NEIGHBOR_REPORT_NODE_STATUS, self.executor.get_id(), receiver_id) self.send_message(message) def _get_class_that_defined_method(self, meth): + """ + Get the name of the class that defines a method. + + Args: + meth (method/function): The method or function to determine the defining class. + + Returns: + str: The name of the defining class. + """ if inspect.ismethod(meth): for cls in inspect.getmro(meth.__self__.__class__): if cls.__dict__.get(meth.__name__) is meth: diff --git a/python/fedml/core/distributed/flow/test_fedml_flow.py b/python/fedml/core/distributed/flow/test_fedml_flow.py index f1e0ea86de..8e777a24e9 100644 --- a/python/fedml/core/distributed/flow/test_fedml_flow.py +++ b/python/fedml/core/distributed/flow/test_fedml_flow.py @@ -7,6 +7,16 @@ class Client(FedMLExecutor): def __init__(self, args): + """ + Initialize the Client object. + + Args: + args: Command-line arguments or configuration settings. + + Returns: + None + """ + self.args = args id = args.rank neighbor_id_list = [0] @@ -17,17 +27,40 @@ def __init__(self, args): self.model = None def init(self, device, dataset, model): + """ + Initialize the client with device, dataset, and model. 
+ + Args: + device: The device (e.g., CPU or GPU) for training. + dataset: The dataset used for training. + model: The machine learning model used for training. + + Returns: + None + """ self.device = device self.dataset = dataset self.model = model def local_training(self): + """ + Perform local training on the client. + + Returns: + Params: Parameters containing model updates or other relevant information. + """ logging.info("local_training start") params = self.get_params() model_params = params.get(Params.KEY_MODEL_PARAMS) return params def handle_init_global_model(self): + """ + Handle the initialization of the global model on the client. + + Returns: + Params: Parameters containing the model parameters. + """ received_params = self.get_params() model_params = received_params.get(Params.KEY_MODEL_PARAMS) @@ -38,6 +71,15 @@ def handle_init_global_model(self): class Server(FedMLExecutor): def __init__(self, args): + """ + Initialize the Server object. + + Args: + args: Command-line arguments or configuration settings. + + Returns: + None + """ self.args = args id = args.rank neighbor_id_list = [1, 2] @@ -53,17 +95,41 @@ def __init__(self, args): self.client_num = 2 def init(self, device, dataset, model): + """ + Initialize the server with device, dataset, and model. + + Args: + device: The device (e.g., CPU or GPU) for server operations. + dataset: The dataset used for server operations. + model: The machine learning model used for server operations. + + Returns: + None + """ + self.device = device self.dataset = dataset self.model = model def init_global_model(self): + """ + Initialize the global model on the server. + + Returns: + Params: Parameters containing the initial model parameters. + """ logging.info("init_global_model") params = Params() params.add(Params.KEY_MODEL_PARAMS, self.model.state_dict()) return params def server_aggregate(self): + """ + Perform server-side aggregation of client updates. + + Returns: + Params: Parameters containing the aggregated model updates. + """ logging.info("server_aggregate") params = self.get_params() model_params = params.get(Params.KEY_MODEL_PARAMS) @@ -77,6 +143,12 @@ def server_aggregate(self): return params def final_eval(self): + """ + Perform final evaluation or operations on the server. + + Returns: + None + """ logging.info("final_eval") diff --git a/python/fedml/core/distributed/topology/asymmetric_topology_manager.py b/python/fedml/core/distributed/topology/asymmetric_topology_manager.py index c85737608a..2f0abcfab3 100644 --- a/python/fedml/core/distributed/topology/asymmetric_topology_manager.py +++ b/python/fedml/core/distributed/topology/asymmetric_topology_manager.py @@ -15,12 +15,29 @@ class AsymmetricTopologyManager(BaseTopologyManager): """ def __init__(self, n, undirected_neighbor_num=3, out_directed_neighbor=3): + """ + Initialize the AsymmetricTopologyManager. + + Args: + n (int): Number of nodes in the topology. + undirected_neighbor_num (int): Number of undirected (symmetric) neighbors for each node. + out_directed_neighbor (int): Number of out (asymmetric) neighbors for each node. + + Returns: + None + """ self.n = n self.undirected_neighbor_num = undirected_neighbor_num self.out_directed_neighbor = out_directed_neighbor self.topology = [] def generate_topology(self): + """ + Generate the topology based on the specified parameters. 
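+
+        Example (illustrative):
+            >>> tpmgr = AsymmetricTopologyManager(
+            ...     8, undirected_neighbor_num=4, out_directed_neighbor=2)
+            >>> tpmgr.generate_topology()
+            >>> out_neighbors = tpmgr.get_out_neighbor_idx_list(0)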
+ + Returns: + None + """ # randomly add some links for each node (symmetric) k = self.undirected_neighbor_num # print("neighbors = " + str(k)) @@ -81,6 +98,15 @@ def generate_topology(self): self.topology = topology_ring def get_in_neighbor_weights(self, node_index): + """ + Get the weights of incoming neighbors for a given node. + + Args: + node_index (int): Index of the node. + + Returns: + List[float]: List of weights for incoming neighbors. + """ if node_index >= self.n: return [] in_neighbor_weights = [] @@ -89,11 +115,29 @@ def get_in_neighbor_weights(self, node_index): return in_neighbor_weights def get_out_neighbor_weights(self, node_index): + """ + Get the weights of outgoing neighbors for a given node. + + Args: + node_index (int): Index of the node. + + Returns: + List[float]: List of weights for outgoing neighbors. + """ if node_index >= self.n: return [] return self.topology[node_index] def get_in_neighbor_idx_list(self, node_index): + """ + Get the indices of incoming neighbors for a given node. + + Args: + node_index (int): Index of the node. + + Returns: + List[int]: List of indices for incoming neighbors. + """ neighbor_in_idx_list = [] neighbor_weights = self.get_in_neighbor_weights(node_index) for idx, neighbor_w in enumerate(neighbor_weights): @@ -102,6 +146,16 @@ def get_in_neighbor_idx_list(self, node_index): return neighbor_in_idx_list def get_out_neighbor_idx_list(self, node_index): + """ + Get the indices of outgoing neighbors for a given node. + + Args: + node_index (int): Index of the node. + + Returns: + List[int]: List of indices for outgoing neighbors. + """ + neighbor_out_idx_list = [] neighbor_weights = self.get_out_neighbor_weights(node_index) for idx, neighbor_w in enumerate(neighbor_weights): diff --git a/python/fedml/core/distributed/topology/symmetric_topology_manager.py b/python/fedml/core/distributed/topology/symmetric_topology_manager.py index 07d90525e4..9f5326ab08 100644 --- a/python/fedml/core/distributed/topology/symmetric_topology_manager.py +++ b/python/fedml/core/distributed/topology/symmetric_topology_manager.py @@ -14,11 +14,28 @@ class SymmetricTopologyManager(BaseTopologyManager): """ def __init__(self, n, neighbor_num=2): + """ + Initialize the SymmetricTopologyManager. + + Args: + n (int): Number of nodes in the topology. + neighbor_num (int): Number of neighbors for each node. + + Returns: + None + """ self.n = n self.neighbor_num = neighbor_num self.topology = [] def generate_topology(self): + """ + Generate the symmetric topology based on the specified parameters. + + Returns: + None + """ + # first generate a ring topology topology_ring = np.array( nx.to_numpy_matrix(nx.watts_strogatz_graph(self.n, 2, 0)), dtype=np.float32 @@ -56,16 +73,43 @@ def generate_topology(self): self.topology = topology_symmetric def get_in_neighbor_weights(self, node_index): + """ + Get the weights of incoming neighbors for a given node. + + Args: + node_index (int): Index of the node. + + Returns: + List[float]: List of weights for incoming neighbors. + """ if node_index >= self.n: return [] return self.topology[node_index] def get_out_neighbor_weights(self, node_index): + """ + Get the weights of outgoing neighbors for a given node. + + Args: + node_index (int): Index of the node. + + Returns: + List[float]: List of weights for outgoing neighbors. + """ if node_index >= self.n: return [] return self.topology[node_index] def get_in_neighbor_idx_list(self, node_index): + """ + Get the indices of incoming neighbors for a given node. 
+ + Args: + node_index (int): Index of the node. + + Returns: + List[int]: List of indices for incoming neighbors. + """ neighbor_in_idx_list = [] neighbor_weights = self.get_in_neighbor_weights(node_index) for idx, neighbor_w in enumerate(neighbor_weights): @@ -74,6 +118,15 @@ def get_in_neighbor_idx_list(self, node_index): return neighbor_in_idx_list def get_out_neighbor_idx_list(self, node_index): + """ + Get the indices of outgoing neighbors for a given node. + + Args: + node_index (int): Index of the node. + + Returns: + List[int]: List of indices for outgoing neighbors. + """ neighbor_out_idx_list = [] neighbor_weights = self.get_out_neighbor_weights(node_index) for idx, neighbor_w in enumerate(neighbor_weights): From 1e9c0a06f8e02cea08a7e1314b9ba20593018bac Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Sun, 24 Sep 2023 09:30:36 +0530 Subject: [PATCH 68/70] thread --- .../communication/grpc/grpc_comm_manager.py | 133 ++++++-- .../grpc/grpc_comm_manager_pb2_grpc.py | 90 +++++- .../communication/mpi/com_manager.py | 64 +++- .../communication/mpi/mpi_receive_thread.py | 43 ++- .../communication/mpi/mpi_send_thread.py | 45 ++- .../communication/mqtt/mqtt_manager.py | 48 ++- .../mqtt_s3_multi_clients_comm_manager.py | 288 ++++++++++++++++-- .../mqtt_s3_mnn/mqtt_s3_comm_manager.py | 252 +++++++++++++-- 8 files changed, 875 insertions(+), 88 deletions(-) diff --git a/python/fedml/core/distributed/communication/grpc/grpc_comm_manager.py b/python/fedml/core/distributed/communication/grpc/grpc_comm_manager.py index 6eb9fe613e..3931120f49 100644 --- a/python/fedml/core/distributed/communication/grpc/grpc_comm_manager.py +++ b/python/fedml/core/distributed/communication/grpc/grpc_comm_manager.py @@ -1,3 +1,12 @@ +import csv +import logging +from ...communication.grpc.grpc_server import GRPCCOMMServicer +import time +from fedml.core.mlops.mlops_profiler_event import MLOpsProfilerEvent +from ..constants import CommunicationConstants +from ...communication.observer import Observer +from ...communication.message import Message +from ...communication.base_com_manager import BaseCommunicationManager import os import pickle import threading @@ -10,21 +19,8 @@ lock = threading.Lock() -from ...communication.base_com_manager import BaseCommunicationManager -from ...communication.message import Message -from ...communication.observer import Observer -from ..constants import CommunicationConstants - -from fedml.core.mlops.mlops_profiler_event import MLOpsProfilerEvent - -import time # Check Service or serve? -from ...communication.grpc.grpc_server import GRPCCOMMServicer - -import logging - -import csv class GRPCCommManager(BaseCommunicationManager): @@ -37,6 +33,17 @@ def __init__( client_id=0, client_num=0, ): + """ + Initialize the GRPCCommManager. + + Args: + host (str): The IP address of the server. + port (int): The port number to listen on. + ip_config_path (str): The path to the IP configuration file. + topic (str, optional): The communication topic. Default is "fedml". + client_id (int, optional): The client's ID. Default is 0. + client_num (int, optional): The number of clients. Default is 0. 
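+
+        Example (illustrative; the CSV file maps receiver IDs to IP addresses
+        and its name here is an assumption):
+            >>> comm = GRPCCommManager(
+            ...     "127.0.0.1", 8890, "grpc_ipconfig.csv",
+            ...     topic="fedml", client_id=0, client_num=2)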
+ """ # host is the ip address of server self.host = host self.port = str(port) @@ -61,7 +68,8 @@ def __init__( futures.ThreadPoolExecutor(max_workers=client_num), options=self.opts, ) - self.grpc_servicer = GRPCCOMMServicer(host, port, client_num, client_id) + self.grpc_servicer = GRPCCOMMServicer( + host, port, client_num, client_id) grpc_comm_manager_pb2_grpc.add_gRPCCommManagerServicer_to_server( self.grpc_servicer, self.grpc_server ) @@ -76,13 +84,23 @@ def __init__( logging.info("grpc server started. Listening on port " + str(port)) def send_message(self, msg: Message): + """ + Send a message using gRPC to a specified receiver. + + Args: + msg (Message): The message to send. + + Returns: + None + """ logging.info("msg = {}".format(msg)) # payload = msg.to_json() logging.info("pickle.dumps(msg) START") pickle_dump_start_time = time.time() msg_pkl = pickle.dumps(msg) - MLOpsProfilerEvent.log_to_wandb({"PickleDumpsTime": time.time() - pickle_dump_start_time}) + MLOpsProfilerEvent.log_to_wandb( + {"PickleDumpsTime": time.time() - pickle_dump_start_time}) logging.info("pickle.dumps(msg) END") receiver_id = msg.get_receiver_id() @@ -103,27 +121,62 @@ def send_message(self, msg: Message): tick = time.time() stub.sendMessage(request) - MLOpsProfilerEvent.log_to_wandb({"Comm/send_delay": time.time() - tick}) + MLOpsProfilerEvent.log_to_wandb( + {"Comm/send_delay": time.time() - tick}) logging.debug("sent successfully") channel.close() def add_observer(self, observer: Observer): + """ + Add an observer to the communication manager. + + Args: + observer (Observer): The observer to add. + + Returns: + None + """ self._observers.append(observer) def remove_observer(self, observer: Observer): + """ + Remove an observer from the communication manager. + + Args: + observer (Observer): The observer to remove. + + Returns: + None + """ self._observers.remove(observer) def handle_receive_message(self): + """ + Start handling received messages. + + This method initiates the process of receiving and handling messages. + + Returns: + None + """ self._notify_connection_ready() self.message_handling_subroutine() # Cannont run message_handling_subroutine in new thread # Related https://stackoverflow.com/a/70705165 - + # thread = threading.Thread(target=self.message_handling_subroutine) # thread.start() def message_handling_subroutine(self): + """ + Message handling subroutine. + + This method continuously processes received messages. 
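+
+        Each queued payload is unpickled back into a Message and dispatched
+        to every registered observer while a global lock is held.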
+ + Returns: + None + """ start_listening_time = time.time() MLOpsProfilerEvent.log_to_wandb({"ListenStart": start_listening_time}) while self.is_running: @@ -134,29 +187,58 @@ def message_handling_subroutine(self): logging.info("unpickle START") unpickle_start_time = time.time() msg = pickle.loads(msg_pkl) - MLOpsProfilerEvent.log_to_wandb({"UnpickleTime": time.time() - unpickle_start_time}) + MLOpsProfilerEvent.log_to_wandb( + {"UnpickleTime": time.time() - unpickle_start_time}) logging.info("unpickle END") msg_type = msg.get_type() for observer in self._observers: _message_handler_start_time = time.time() observer.receive_message(msg_type, msg) - MLOpsProfilerEvent.log_to_wandb({"MessageHandlerTime": time.time() - _message_handler_start_time}) - MLOpsProfilerEvent.log_to_wandb({"BusyTime": time.time() - busy_time_start_time}) + MLOpsProfilerEvent.log_to_wandb( + {"MessageHandlerTime": time.time() - _message_handler_start_time}) + MLOpsProfilerEvent.log_to_wandb( + {"BusyTime": time.time() - busy_time_start_time}) lock.release() time.sleep(0.0001) - MLOpsProfilerEvent.log_to_wandb({"TotalTime": time.time() - start_listening_time}) + MLOpsProfilerEvent.log_to_wandb( + {"TotalTime": time.time() - start_listening_time}) return def stop_receive_message(self): + """ + Stop receiving and processing messages. + + This method stops the communication manager. + + Returns: + None + """ self.grpc_server.stop(None) self.is_running = False def notify(self, message: Message): + """ + Notify observers with a message. + + Args: + message (Message): The message to notify observers with. + + Returns: + None + """ msg_type = message.get_type() for observer in self._observers: observer.receive_message(msg_type, message) def _notify_connection_ready(self): + """ + Notify observers that the connection is ready. + + This method notifies observers that the communication connection is ready. + + Returns: + None + """ msg_params = Message() msg_params.sender_id = self.rank msg_params.receiver_id = self.rank @@ -165,6 +247,15 @@ def _notify_connection_ready(self): observer.receive_message(msg_type, msg_params) def _build_ip_table(self, path): + """ + Build an IP configuration table from a CSV file. + + Args: + path (str): The path to the CSV file containing IP configuration data. + + Returns: + dict: A dictionary mapping receiver IDs to their corresponding IP addresses. + """ ip_config = dict() with open(path, newline="") as csv_file: csv_reader = csv.reader(csv_file) diff --git a/python/fedml/core/distributed/communication/grpc/grpc_comm_manager_pb2_grpc.py b/python/fedml/core/distributed/communication/grpc/grpc_comm_manager_pb2_grpc.py index ec24e39df6..063167a020 100644 --- a/python/fedml/core/distributed/communication/grpc/grpc_comm_manager_pb2_grpc.py +++ b/python/fedml/core/distributed/communication/grpc/grpc_comm_manager_pb2_grpc.py @@ -6,10 +6,15 @@ class gRPCCommManagerStub(object): - """Missing associated documentation comment in .proto file.""" + """ + gRPC Communication Manager Stub. + + This class provides a client-side stub for interacting with the gRPC communication manager service. + """ def __init__(self, channel): - """Constructor. + """ + Initialize the gRPCCommManagerStub. Args: channel: A grpc.Channel. @@ -27,22 +32,53 @@ def __init__(self, channel): class gRPCCommManagerServicer(object): - """Missing associated documentation comment in .proto file.""" + """ + gRPC Communication Manager Servicer. + + This class defines the gRPC service methods for the communication manager. 
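+
+    Concrete servicers are expected to override these methods; the base
+    implementations below raise NotImplementedError.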
+ """ def sendMessage(self, request, context): - """Missing associated documentation comment in .proto file.""" + """ + Handle the sendMessage gRPC service method. + + Args: + request: The request message. + context: The gRPC context. + + Raises: + NotImplementedError: This method is not implemented. + """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") def handleReceiveMessage(self, request, context): - """Missing associated documentation comment in .proto file.""" + """ + Handle the handleReceiveMessage gRPC service method. + + Args: + request: The request message. + context: The gRPC context. + + Raises: + NotImplementedError: This method is not implemented. + """ context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") def add_gRPCCommManagerServicer_to_server(servicer, server): + """ + Add a gRPC Communication Manager Servicer to a gRPC server. + + This function registers the gRPC service methods provided by the servicer to the gRPC server. + + Args: + servicer: The gRPC Communication Manager Servicer instance. + server: The gRPC server instance to which the servicer will be added. + """ rpc_method_handlers = { "sendMessage": grpc.unary_unary_rpc_method_handler( servicer.sendMessage, @@ -63,7 +99,13 @@ def add_gRPCCommManagerServicer_to_server(servicer, server): # This class is part of an EXPERIMENTAL API. class gRPCCommManager(object): - """Missing associated documentation comment in .proto file.""" + """ + gRPC Communication Manager. + + This class provides static methods for making gRPC calls to the Communication Manager service. + + Note: This class is part of an experimental API. + """ @staticmethod def sendMessage( @@ -78,6 +120,24 @@ def sendMessage( timeout=None, metadata=None, ): + """ + Send a gRPC sendMessage request. + + Args: + request: The request message. + target: The target server to send the request. + options: Additional gRPC options. + channel_credentials: Channel credentials. + call_credentials: Call credentials. + insecure: Whether to use an insecure channel. + compression: Compression method to use. + wait_for_ready: Wait for the server to become ready. + timeout: Request timeout. + metadata: Request metadata. + + Returns: + grpc.Call: A gRPC call instance. + """ return grpc.experimental.unary_unary( request, target, @@ -107,6 +167,24 @@ def handleReceiveMessage( timeout=None, metadata=None, ): + """ + Send a gRPC handleReceiveMessage request. + + Args: + request: The request message. + target: The target server to send the request. + options: Additional gRPC options. + channel_credentials: Channel credentials. + call_credentials: Call credentials. + insecure: Whether to use an insecure channel. + compression: Compression method to use. + wait_for_ready: Wait for the server to become ready. + timeout: Request timeout. + metadata: Request metadata. + + Returns: + grpc.Call: A gRPC call instance. + """ return grpc.experimental.unary_unary( request, target, diff --git a/python/fedml/core/distributed/communication/mpi/com_manager.py b/python/fedml/core/distributed/communication/mpi/com_manager.py index 030b8793ad..e02666ea65 100644 --- a/python/fedml/core/distributed/communication/mpi/com_manager.py +++ b/python/fedml/core/distributed/communication/mpi/com_manager.py @@ -12,7 +12,21 @@ class MpiCommunicationManager(BaseCommunicationManager): + """ + MPI Communication Manager. 
+ + This class manages communication using MPI (Message Passing Interface) for federated learning. + """ + def __init__(self, comm, rank, size): + """ + Initialize the MPI Communication Manager. + + Args: + comm: The MPI communicator. + rank: The rank of the current process. + size: The total number of processes in the communicator. + """ self.comm = comm self.rank = rank self.size = size @@ -39,6 +53,12 @@ def __init__(self, comm, rank, size): # assert False def init_server_communication(self): + """ + Initialize server-side communication components. + + Returns: + Tuple: A tuple containing server send and receive queues. + """ server_send_queue = queue.Queue(0) # self.server_send_thread = MPISendThread( # self.comm, self.rank, self.size, "ServerSendThread", server_send_queue @@ -54,6 +74,12 @@ def init_server_communication(self): return server_send_queue, server_receive_queue def init_client_communication(self): + """ + Initialize client-side communication components. + + Returns: + Tuple: A tuple containing client send and receive queues. + """ # SEND client_send_queue = queue.Queue(0) # self.client_send_thread = MPISendThread( @@ -75,19 +101,43 @@ def init_client_communication(self): # self.q_sender.put(msg) def send_message(self, msg: Message): + """ + Send a message using MPI. + + Args: + msg: The message to be sent. + """ # self.q_sender.put(msg) dest_id = msg.get(Message.MSG_ARG_KEY_RECEIVER) tick = time.time() self.comm.send(msg, dest=dest_id) - MLOpsProfilerEvent.log_to_wandb({"Comm/send_delay": time.time() - tick}) + MLOpsProfilerEvent.log_to_wandb( + {"Comm/send_delay": time.time() - tick}) def add_observer(self, observer: Observer): + """ + Add an observer to the list of observers. + + Args: + observer: The observer to be added. + """ self._observers.append(observer) def remove_observer(self, observer: Observer): + """ + Remove an observer from the list of observers. + + Args: + observer: The observer to be removed. + """ self._observers.remove(observer) def handle_receive_message(self): + """ + Handle receiving messages using MPI. + + This function continuously listens for incoming messages and notifies observers when a message is received. + """ self.is_running = True # the first message after connection, aligned the protocol with MQTT + S3 self._notify_connection_ready() @@ -108,6 +158,9 @@ def handle_receive_message(self): logging.info("!!!!!!handle_receive_message stopped!!!") def stop_receive_message(self): + """ + Stop receiving messages and threads. + """ self.is_running = False # self.__stop_thread(self.server_send_thread) self.__stop_thread(self.server_receive_thread) @@ -117,11 +170,20 @@ def stop_receive_message(self): self.__stop_thread(self.client_collective_thread) def notify(self, msg_params): + """ + Notify observers with the received message. + + Args: + msg_params: The received message. + """ msg_type = msg_params.get_type() for observer in self._observers: observer.receive_message(msg_type, msg_params) def _notify_connection_ready(self): + """ + Notify observers that the connection is ready. 
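+
+        A synthetic Message whose sender and receiver are both this rank is
+        delivered to observers, aligning the protocol with MQTT + S3.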
+ """ msg_params = Message() msg_params.sender_id = self.rank msg_params.receiver_id = self.rank diff --git a/python/fedml/core/distributed/communication/mpi/mpi_receive_thread.py b/python/fedml/core/distributed/communication/mpi/mpi_receive_thread.py index b10f4d52ff..63ca19fbe0 100644 --- a/python/fedml/core/distributed/communication/mpi/mpi_receive_thread.py +++ b/python/fedml/core/distributed/communication/mpi/mpi_receive_thread.py @@ -7,7 +7,23 @@ class MPIReceiveThread(threading.Thread): + """ + MPI Receive Thread. + + This thread is responsible for receiving messages using MPI. + """ + def __init__(self, comm, rank, size, name, q): + """ + Initialize the MPI Receive Thread. + + Args: + comm: The MPI communicator. + rank: The rank of the current process. + size: The total number of processes in the communicator. + name: The name of the thread. + q: The message queue to store received messages. + """ super(MPIReceiveThread, self).__init__() self._stop_event = threading.Event() self.comm = comm @@ -17,28 +33,44 @@ def __init__(self, comm, rank, size, name, q): self.q = q def run(self): + """ + Run the MPI Receive Thread. + + This method continuously listens for incoming messages and puts them into the message queue. + """ logging.debug( "Starting Thread:" + self.name + ". Process ID = " + str(self.rank) ) while True: try: msg = self.comm.recv() - # Ugly delete comments - # msg_str = self.comm.recv() - # msg = Message() - # msg.init(msg_str) self.q.put(msg) except Exception: traceback.print_exc() raise Exception("MPI failed!") def stop(self): + """ + Stop the MPI Receive Thread. + """ self._stop_event.set() def stopped(self): + """ + Check if the MPI Receive Thread is stopped. + + Returns: + bool: True if the thread is stopped, False otherwise. + """ return self._stop_event.is_set() def get_id(self): + """ + Get the ID of the thread. + + Returns: + int: The ID of the thread. + """ # returns id of the respective thread if hasattr(self, "_thread_id"): return self._thread_id @@ -47,6 +79,9 @@ def get_id(self): return id def raise_exception(self): + """ + Raise an exception in the MPI Receive Thread to stop it. + """ thread_id = self.get_id() res = ctypes.pythonapi.PyThreadState_SetAsyncExc( thread_id, ctypes.py_object(SystemExit) diff --git a/python/fedml/core/distributed/communication/mpi/mpi_send_thread.py b/python/fedml/core/distributed/communication/mpi/mpi_send_thread.py index 39ebb5599a..6d67938fd4 100644 --- a/python/fedml/core/distributed/communication/mpi/mpi_send_thread.py +++ b/python/fedml/core/distributed/communication/mpi/mpi_send_thread.py @@ -1,6 +1,3 @@ -# Ugly delete file - - import ctypes import logging import threading @@ -11,7 +8,23 @@ class MPISendThread(threading.Thread): + """ + MPI Send Thread. + + This thread is responsible for sending messages using MPI. + """ + def __init__(self, comm, rank, size, name, q): + """ + Initialize the MPI Send Thread. + + Args: + comm: The MPI communicator. + rank: The rank of the current process. + size: The total number of processes in the communicator. + name: The name of the thread. + q: The message queue to get messages to send. + """ super(MPISendThread, self).__init__() self._stop_event = threading.Event() self.comm = comm @@ -21,7 +34,13 @@ def __init__(self, comm, rank, size, name, q): self.q = q def run(self): - logging.debug("Starting " + self.name + ". Process ID = " + str(self.rank)) + """ + Run the MPI Send Thread. + + This method continuously checks the message queue and sends messages to the specified destination. 
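+
+        The queue is polled in a busy loop; a failed send is surfaced as an
+        MPI failure.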
+ """ + logging.debug("Starting " + self.name + + ". Process ID = " + str(self.rank)) while True: try: if not self.q.empty(): @@ -35,12 +54,27 @@ def run(self): raise Exception("MPI failed!") def stop(self): + """ + Stop the MPI Send Thread. + """ self._stop_event.set() def stopped(self): + """ + Check if the MPI Send Thread is stopped. + + Returns: + bool: True if the thread is stopped, False otherwise. + """ return self._stop_event.is_set() def get_id(self): + """ + Get the ID of the thread. + + Returns: + int: The ID of the thread. + """ # returns id of the respective thread if hasattr(self, "_thread_id"): return self._thread_id @@ -49,6 +83,9 @@ def get_id(self): return id def raise_exception(self): + """ + Raise an exception in the MPI Send Thread to stop it. + """ thread_id = self.get_id() res = ctypes.pythonapi.PyThreadState_SetAsyncExc( thread_id, ctypes.py_object(SystemExit) diff --git a/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py b/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py index 93a8463fbd..18c1064165 100644 --- a/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py +++ b/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py @@ -13,6 +13,19 @@ class MqttManager(object): def __init__(self, host, port, user, pwd, keepalive_time, client_id, last_will_topic=None, last_will_msg=None): + """ + MQTT Manager for handling MQTT connections, sending, and receiving messages. + + Args: + host (str): MQTT broker host. + port (int): MQTT broker port. + user (str): MQTT username. + pwd (str): MQTT password. + keepalive_time (int): Keepalive time for the MQTT connection. + client_id (str): Client ID for the MQTT client. + last_will_topic (str, optional): Last will topic for the MQTT client. + last_will_msg (str, optional): Last will message for the MQTT client. + """ self._client = None self.mqtt_connection_id = None self._host = host @@ -44,6 +57,7 @@ def __del__(self): self._client = None def init_connect(self): + self.mqtt_connection_id = "{}_{}".format(self._client_id, "ID") self._client = mqtt.Client(client_id=self.mqtt_connection_id, clean_session=False) self._client.connected_flag = False @@ -90,9 +104,20 @@ def loop_forever(self): self._client.loop_forever(retry_first_connection=True) def send_message(self, topic, message, publish_single_message=False): - # logging.info( - # f"FedMLDebug - Send: topic ({topic}), message ({message})" - # ) + """ + Send an MQTT message. + + Args: + topic (str): The MQTT topic to which the message will be sent. + message (str): The message to send. + publish_single_message (bool, optional): If True, publish as a single message; otherwise, use MQTT publish. + + Returns: + bool: True if the message was successfully sent, False otherwise. + """ + logging.info( + f"FedMLDebug - Send: topic ({topic}), message ({message})" + ) self.check_connection() mqtt_send_start_time = time.time() @@ -110,9 +135,20 @@ def send_message(self, topic, message, publish_single_message=False): return True def send_message_json(self, topic, message, publish_single_message=False): - # logging.info( - # f"FedMLDebug - Send: topic ({topic}), message ({message})" - # ) + """ + Send an MQTT message as JSON. + + Args: + topic (str): The MQTT topic to which the message will be sent. + message (str): The message to send as JSON. + publish_single_message (bool, optional): If True, publish as a single message; otherwise, use MQTT publish. + + Returns: + bool: True if the message was successfully sent, False otherwise. 
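+
+        Example (illustrative sketch; topic and payload are assumptions):
+            mgr.send_message_json("fedml_0_1", json.dumps({"msg_type": 1}))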
+ """ + logging.info( + f"FedMLDebug - Send: topic ({topic}), message ({message})" + ) self.check_connection() if publish_single_message: diff --git a/python/fedml/core/distributed/communication/mqtt_s3/mqtt_s3_multi_clients_comm_manager.py b/python/fedml/core/distributed/communication/mqtt_s3/mqtt_s3_multi_clients_comm_manager.py index 50f0908ad2..454fef71c8 100755 --- a/python/fedml/core/distributed/communication/mqtt_s3/mqtt_s3_multi_clients_comm_manager.py +++ b/python/fedml/core/distributed/communication/mqtt_s3/mqtt_s3_multi_clients_comm_manager.py @@ -18,6 +18,44 @@ class MqttS3MultiClientsCommManager(BaseCommunicationManager): + """ + MQTT communication manager for multi-client federated learning. + + This class provides an MQTT-based communication manager for multi-client federated learning scenarios. + It supports communication between a central server and multiple client devices. + + Args: + config_path (str): Path to the MQTT configuration file. + s3_config_path (str): Path to the S3 storage configuration file. + topic (str): The MQTT topic prefix. + client_rank (int): The rank or ID of the client. + client_num (int): The total number of clients. + args (object): Additional configuration arguments. + + Attributes: + client_id (str): The unique ID of the MQTT client. + topic (str): The MQTT topic. + is_connected (bool): Indicates if the MQTT client is connected to the broker. + client_active_list (dict): A dictionary to store the status of connected clients. + + Methods: + run_loop_forever(): Run the MQTT loop forever to handle incoming messages. + on_connected(mqtt_client_object): MQTT on_connected callback. + on_disconnected(mqtt_client_object): MQTT on_disconnected callback. + add_observer(observer): Add an observer to receive messages. + remove_observer(observer): Remove an observer. + send_message(msg, wait_for_publish=False): Send a message using MQTT. + send_message_json(topic_name, json_message): Send a JSON message using MQTT. + handle_receive_message(): Start handling received messages by running the MQTT loop. + stop_receive_message(): Stop receiving messages and disconnect from MQTT. + set_config_from_file(config_file_path): Load MQTT configuration from a file. + set_config_from_objects(mqtt_config): Set MQTT configuration from objects. + callback_client_last_will_msg(topic, payload): Callback for client last will message. + callback_client_active_msg(topic, payload): Callback for client active message. + subscribe_client_status_message(): Subscribe to client status messages. + get_client_status(client_id): Get the status of a specific client. + get_client_list_status(): Get the status of all clients. + """ def __init__( self, @@ -28,6 +66,17 @@ def __init__( client_num=0, args=None ): + """ + Initialize the MQTT communication manager. + + Args: + config_path (str): Path to the MQTT configuration file. + s3_config_path (str): Path to the S3 storage configuration file. + topic (str): The MQTT topic prefix. + client_rank (int): The rank or ID of the client. + client_num (int): The total number of clients. + args (object): Additional configuration arguments. 
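+
+    Example (illustrative sketch; configs, topic, and ranks are assumptions):
+        comm = MqttS3MultiClientsCommManager(mqtt_config, s3_config,
+                                             topic="fedml", client_rank=1,
+                                             client_num=2, args=args)
+        comm.add_observer(observer)
+        comm.handle_receive_message()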
+ """ self.args = args self.broker_port = None self.broker_host = None @@ -51,7 +100,8 @@ def __init__( self.client_real_ids = [] if args.client_id_list is not None: logging.info( - "MqttS3CommManager args client_id_list: " + str(args.client_id_list) + "MqttS3CommManager args client_id_list: " + + str(args.client_id_list) ) self.client_real_ids = json.loads(args.client_id_list) @@ -82,7 +132,8 @@ def __init__( self._observers: List[Observer] = [] - self._client_id = "FedML_CS_{}_{}_{}".format(str(args.run_id), str(self.edge_id), str(uuid.uuid4())) + self._client_id = "FedML_CS_{}_{}_{}".format( + str(args.run_id), str(self.edge_id), str(uuid.uuid4())) self.client_num = client_num logging.info("mqtt_s3.init: client_num = %d" % client_num) @@ -95,7 +146,8 @@ def __init__( if args.rank == 0: self.top_active_msg = CommunicationConstants.SERVER_TOP_ACTIVE_MSG self.topic_last_will_msg = CommunicationConstants.SERVER_TOP_LAST_WILL_MSG - self.last_will_msg = json.dumps({"ID": self.edge_id, "status": CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE}) + self.last_will_msg = json.dumps( + {"ID": self.edge_id, "status": CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE}) self.mqtt_mgr = MqttManager( config_path["BROKER_HOST"], config_path["BROKER_PORT"], @@ -114,25 +166,39 @@ def __init__( @property def client_id(self): + """ + Get the client ID. + + Returns: + str: The client ID. + """ return self._client_id @property def topic(self): + """ + Get the MQTT topic. + + Returns: + str: The MQTT topic. + """ return self._topic def run_loop_forever(self): + """ + Run the MQTT loop forever to handle incoming messages. + """ self.mqtt_mgr.loop_forever() def on_connected(self, mqtt_client_object): """ - [server] - sending message topic (publish): serverID_clientID - receiving message topic (subscribe): clientID + MQTT on_connected callback. - [client] - sending message topic (publish): clientID - receiving message topic (subscribe): serverID_clientID + This method is called when the MQTT client is connected to the broker. It handles + subscription to topics based on whether the current instance is a server or client. + Args: + mqtt_client_object: The MQTT client object. """ self.mqtt_mgr.add_message_passthrough_listener(self._on_message) @@ -143,7 +209,8 @@ def on_connected(self, mqtt_client_object): # logging.info("self.client_real_ids = {}".format(self.client_real_ids)) for client_rank in range(0, self.client_num): - real_topic = self._topic + str(self.client_real_ids[client_rank]) + real_topic = self._topic + \ + str(self.client_real_ids[client_rank]) result, mid = mqtt_client_object.subscribe(real_topic, qos=2) # logging.info( @@ -155,7 +222,8 @@ def on_connected(self, mqtt_client_object): self._notify_connection_ready() else: # client - real_topic = self._topic + str(self.server_id) + "_" + str(self.client_real_ids[0]) + real_topic = self._topic + \ + str(self.server_id) + "_" + str(self.client_real_ids[0]) result, mid = mqtt_client_object.subscribe(real_topic, qos=2) self._notify_connection_ready() @@ -167,21 +235,56 @@ def on_connected(self, mqtt_client_object): self.is_connected = True def on_disconnected(self, mqtt_client_object): + """ + MQTT on_disconnected callback. + + This method is called when the MQTT client is disconnected from the broker. + + Args: + mqtt_client_object: The MQTT client object. + """ self.is_connected = False def add_observer(self, observer: Observer): + """ + Add an observer to receive messages. + + Args: + observer (Observer): The observer to add. 
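+
+        Observers are called back through receive_message(msg_type, msg_params)
+        for every decoded message.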
+ """ self._observers.append(observer) def remove_observer(self, observer: Observer): + """ + Remove an observer. + + Args: + observer (Observer): The observer to remove. + """ self._observers.remove(observer) def _notify_connection_ready(self): + """ + Notify observers that the connection is ready. + """ msg_params = Message() msg_type = CommunicationConstants.MSG_TYPE_CONNECTION_IS_READY for observer in self._observers: observer.receive_message(msg_type, msg_params) def _notify(self, msg_obj): + """ + Notify registered observers with a received message. + + This method parses the incoming message, extracts its type, and notifies all registered + observers with the message type and parameters. + + Args: + msg_obj (dict): The received message object. + + Returns: + None + """ msg_params = Message() msg_params.init_from_json_object(msg_obj) msg_type = msg_params.get_type() @@ -190,6 +293,18 @@ def _notify(self, msg_obj): observer.receive_message(msg_type, msg_params) def _on_message_impl(self, msg): + """ + Handle incoming MQTT messages. + + This method is called when an MQTT message is received. It parses the message payload, + processes it, and notifies observers with the received message. + + Args: + msg (paho.mqtt.client.MQTTMessage): The received MQTT message. + + Returns: + None + """ json_payload = str(msg.payload, encoding="utf-8") payload_obj = json.loads(json_payload) logging.info( @@ -218,16 +333,19 @@ def _on_message_impl(self, msg): elif self.dataSetType == 'cifar10': py_model = CNN_WEB() - model_params = self.s3_storage.read_model_web(s3_key_str, py_model) + model_params = self.s3_storage.read_model_web( + s3_key_str, py_model) else: model_params = self.s3_storage.read_model(s3_key_str) if not hasattr(self.args, "fa_task"): logging.info( - "mqtt_s3.on_message: model params length %d" % len(model_params) + "mqtt_s3.on_message: model params length %d" % len( + model_params) ) - model_url = payload_obj.get(Message.MSG_ARG_KEY_MODEL_PARAMS_URL, "") + model_url = payload_obj.get( + Message.MSG_ARG_KEY_MODEL_PARAMS_URL, "") logging.info("mqtt_s3.on_message: model url {}".format(model_url)) # replace the S3 object key with raw model params @@ -239,6 +357,19 @@ def _on_message_impl(self, msg): self._notify(payload_obj) def _on_message(self, msg): + """ + Send a message using MQTT. + + This method sends a message to the specified recipient using MQTT. The topic for publishing + the message is determined based on whether the current instance is a server or client. + + Args: + msg (Message): The message to be sent. + wait_for_publish (bool): Whether to wait for the message to be published. + + Returns: + bool: True if the message was sent successfully, False otherwise. 
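+
+        Note: this wrapper itself only forwards the raw MQTT message to
+        _on_message_impl.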
+ """ self._on_message_impl(msg) def send_message(self, msg: Message, wait_for_publish=False): @@ -259,7 +390,8 @@ def send_message(self, msg: Message, wait_for_publish=False): logging.info("mqtt_s3.send_message: msg topic = %s" % str(topic)) payload = msg.get_params() - model_params_obj = payload.get(Message.MSG_ARG_KEY_MODEL_PARAMS, "") + model_params_obj = payload.get( + Message.MSG_ARG_KEY_MODEL_PARAMS, "") model_url = payload.get(Message.MSG_ARG_KEY_MODEL_PARAMS_URL, "") model_key = payload.get(Message.MSG_ARG_KEY_MODEL_PARAMS_KEY, "") if model_params_obj != "": @@ -267,9 +399,11 @@ def send_message(self, msg: Message, wait_for_publish=False): if model_url == "": model_key = topic + "_" + str(uuid.uuid4()) if self.isBrowser: - model_url = self.s3_storage.write_model_web(model_key, model_params_obj) + model_url = self.s3_storage.write_model_web( + model_key, model_params_obj) else: - model_url = self.s3_storage.write_model(model_key, model_params_obj) + model_url = self.s3_storage.write_model( + model_key, model_params_obj) logging.info( "mqtt_s3.send_message: S3+MQTT msg sent, s3 message key = %s" @@ -290,14 +424,16 @@ def send_message(self, msg: Message, wait_for_publish=False): message_key = topic + "_" + str(uuid.uuid4()) payload = msg.get_params() - model_params_obj = payload.get(Message.MSG_ARG_KEY_MODEL_PARAMS, "") + model_params_obj = payload.get( + Message.MSG_ARG_KEY_MODEL_PARAMS, "") if model_params_obj != "": # S3 logging.info( "mqtt_s3.send_message: S3+MQTT msg sent, message_key = %s" % message_key ) - model_url = self.s3_storage.write_model(message_key, model_params_obj) + model_url = self.s3_storage.write_model( + message_key, model_params_obj) model_params_key_url = { "key": message_key, "url": model_url, @@ -319,20 +455,60 @@ def send_message(self, msg: Message, wait_for_publish=False): return True def send_message_json(self, topic_name, json_message): + """ + Send a JSON message to a specified MQTT topic. + + Args: + topic_name (str): The MQTT topic to which the message will be sent. + json_message (str): The JSON-formatted message to send. + + Returns: + bool: True if the message was sent successfully, False otherwise. + """ return self.mqtt_mgr.send_message_json(topic_name, json_message) def handle_receive_message(self): + """ + Start listening for incoming MQTT messages and handle them. + + This method initiates the process of receiving and handling MQTT messages. + It runs a loop to continuously listen for messages and processes them until stopped. + + Returns: + None + """ start_listening_time = time.time() MLOpsProfilerEvent.log_to_wandb({"ListenStart": start_listening_time}) self.run_loop_forever() - MLOpsProfilerEvent.log_to_wandb({"TotalTime": time.time() - start_listening_time}) + MLOpsProfilerEvent.log_to_wandb( + {"TotalTime": time.time() - start_listening_time}) def stop_receive_message(self): + """ + Stop listening for incoming MQTT messages and disconnect from the MQTT broker. + + This method stops the MQTT message listening loop and disconnects from the MQTT broker. + + Returns: + None + """ logging.info("mqtt_s3.stop_receive_message: stopping...") self.mqtt_mgr.loop_stop() self.mqtt_mgr.disconnect() def set_config_from_file(self, config_file_path): + """ + Load MQTT configuration settings from a YAML file. + + This method reads MQTT configuration settings, including the broker host, port, username, and + password, from a YAML file and updates the instance variables accordingly. 
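+
+        MQTT_USER and MQTT_PWD are treated as optional entries.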
+ + Args: + config_file_path (str): The path to the YAML configuration file. + + Returns: + None + """ try: with open(config_file_path, "r") as f: config = yaml.load(f, Loader=yaml.FullLoader) @@ -348,6 +524,18 @@ def set_config_from_file(self, config_file_path): pass def set_config_from_objects(self, mqtt_config): + """ + Set MQTT configuration settings from a dictionary. + + This method sets the MQTT configuration settings, including the broker host, port, username, + and password, from a dictionary object. + + Args: + mqtt_config (dict): A dictionary containing MQTT configuration settings. + + Returns: + None + """ self.broker_host = mqtt_config["BROKER_HOST"] self.broker_port = mqtt_config["BROKER_PORT"] self.mqtt_user = None @@ -358,21 +546,58 @@ def set_config_from_objects(self, mqtt_config): self.mqtt_pwd = mqtt_config["MQTT_PWD"] def callback_client_last_will_msg(self, topic, payload): + """ + Handle the last will message from a client. + + This method processes the last will message received from a client and updates the client's + status accordingly. + + Args: + topic (str): The MQTT topic on which the last will message was received. + payload (str): The payload of the last will message. + + Returns: + None + """ msg = json.loads(payload) edge_id = msg.get("ID", None) - status = msg.get("status", CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE) + status = msg.get( + "status", CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE) if edge_id is not None and status == CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE: if self.client_active_list.get(edge_id, None) is not None: self.client_active_list.pop(edge_id) def callback_client_active_msg(self, topic, payload): + """ + Handle the active status message from a client. + + This method processes the active status message received from a client and updates the client's + status in the active list. + + Args: + topic (str): The MQTT topic on which the active status message was received. + payload (str): The payload of the active status message. + + Returns: + None + """ msg = json.loads(payload) edge_id = msg.get("ID", None) - status = msg.get("status", CommunicationConstants.MSG_CLIENT_STATUS_IDLE) + status = msg.get( + "status", CommunicationConstants.MSG_CLIENT_STATUS_IDLE) if edge_id is not None: self.client_active_list[edge_id] = status def subscribe_client_status_message(self): + """ + Subscribe to client status messages. + + This method sets up MQTT message listeners to handle both last will messages and active status + messages from clients. + + Returns: + None + """ # Setup MQTT message listener to the last will message form the client. self.mqtt_mgr.add_message_listener(self.topic_last_will_msg, self.callback_client_last_will_msg) @@ -382,7 +607,26 @@ def subscribe_client_status_message(self): self.callback_client_active_msg) def get_client_status(self, client_id): + """ + Get the status of a specific client. + + This method retrieves the status of a client based on its ID from the client active list. + + Args: + client_id (str): The ID of the client. + + Returns: + str: The status of the client, e.g., 'offline' or 'idle'. + """ return self.client_active_list.get(client_id, CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE) def get_client_list_status(self): + """ + Get the status of all connected clients. + + This method returns the entire client active list, containing the statuses of all connected clients. + + Returns: + dict: A dictionary mapping client IDs to their statuses. 
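+
+        Example of a possible return value (IDs and status strings are
+        assumptions): {1: "ONLINE", 2: "OFFLINE"}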
+ """ return self.client_active_list diff --git a/python/fedml/core/distributed/communication/mqtt_s3_mnn/mqtt_s3_comm_manager.py b/python/fedml/core/distributed/communication/mqtt_s3_mnn/mqtt_s3_comm_manager.py index 9361bb2bf5..6c5aa9b188 100755 --- a/python/fedml/core/distributed/communication/mqtt_s3_mnn/mqtt_s3_comm_manager.py +++ b/python/fedml/core/distributed/communication/mqtt_s3_mnn/mqtt_s3_comm_manager.py @@ -17,6 +17,67 @@ class MqttS3MNNCommManager(BaseCommunicationManager): + """ + MQTT-S3-based Communication Manager for Federated Learning. + + This communication manager uses MQTT-S3 for message communication and S3 for model storage. + + Args: + config_path (str): Path to the configuration file. + s3_config_path (str): Path to the S3 configuration file. + topic (str, optional): MQTT topic. Default is "fedml". + client_id (int, optional): Client ID. Default is 0. + client_num (int, optional): Number of clients. Default is 0. + args (Namespace, optional): Command-line arguments. + bind_port (int, optional): Port to bind. Default is 0. + + Attributes: + mqtt_pwd (str): MQTT password. + mqtt_user (str): MQTT username. + broker_port (int): MQTT broker port. + broker_host (str): MQTT broker host. + keepalive_time (int): MQTT keepalive time. + args (Namespace): Command-line arguments. + rank (int): Client rank. + _topic (str): MQTT topic. + s3_storage (S3MNNStorage): S3 storage. + client_real_ids (list): List of real client IDs. + group_server_id_list (str): Group server ID list. + edge_id (int): Edge ID. + server_id (int): Server ID. + _observers (list): List of observers. + _client_id (str): Client ID. + client_num (int): Number of clients. + client_active_list (dict): Dictionary to track client activity status. + top_active_msg (str): Top-level active message topic. + topic_last_will_msg (str): Topic for last will message. + last_will_msg (str): Last will message. + mqtt_mgr (MqttManager): MQTT manager. + + Methods: + run_loop_forever(self): Run the MQTT loop indefinitely. + __del__(self): Destructor to stop the MQTT loop and disconnect. + on_connected(self, mqtt_client_object): MQTT on_connected callback. + on_disconnected(self, mqtt_client_object): MQTT on_disconnected callback. + add_observer(self, observer: Observer): Add an observer to receive messages. + remove_observer(self, observer: Observer): Remove an observer. + _notify(self, msg_obj): Notify observers with a message. + _on_message_impl(self, msg): Handle incoming MQTT messages. + _on_message(self, msg): Wrapper for handling incoming MQTT messages. + send_message(self, msg: Message): Send a message using MQTT. + send_message_json(self, topic_name, json_message): Send a JSON message using MQTT. + handle_receive_message(self): Start handling received messages. + stop_receive_message(self): Stop receiving messages and disconnect from MQTT. + set_config_from_file(self, config_file_path): Load MQTT configuration from a file. + set_config_from_objects(self, mqtt_config): Set MQTT configuration from objects. + _notify_connection_ready(self): Notify observers that the connection is ready. + callback_client_last_will_msg(self, topic, payload): Callback for client last will message. + callback_client_active_msg(self, topic, payload): Callback for client active message. + subscribe_client_status_message(self): Subscribe to client status messages. + get_client_status(self, client_id): Get the status of a specific client. + get_client_list_status(self): Get the status of all clients. 
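+
+    Example (illustrative sketch; arguments are assumptions):
+        comm = MqttS3MNNCommManager(mqtt_config, s3_config, topic="fedml",
+                                    client_id=1, client_num=2, args=args)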
+ """ + def __init__( self, config_path, @@ -27,6 +88,18 @@ def __init__( args=None, bind_port=0, ): + """ + Initialize the MqttS3MNNCommManager. + + Args: + config_path (str): Path to the configuration file. + s3_config_path (str): Path to the S3 configuration file. + topic (str, optional): MQTT topic. Default is "fedml". + client_id (int, optional): Client ID. Default is 0. + client_num (int, optional): Number of clients. Default is 0. + args (Namespace, optional): Command-line arguments. + bind_port (int, optional): Port to bind. Default is 0. + """ self.mqtt_pwd = None self.mqtt_user = None self.broker_port = None @@ -39,7 +112,8 @@ def __init__( self.s3_storage = S3MNNStorage(s3_config_path) self.client_real_ids = [] logging.info( - "MqttS3CommManager args client_id_list: " + str(args.client_id_list) + "MqttS3CommManager args client_id_list: " + + str(args.client_id_list) ) if args.client_id_list is not None: self.client_real_ids = json.loads(args.client_id_list) @@ -70,7 +144,8 @@ def __init__( self.edge_id = 0 self._observers: List[Observer] = [] - self._client_id = "FedML_CS_{}_{}_{}".format(str(args.run_id), str(self.edge_id), str(uuid.uuid4())) + self._client_id = "FedML_CS_{}_{}_{}".format( + str(args.run_id), str(self.edge_id), str(uuid.uuid4())) self.client_num = client_num logging.info("mqtt_s3.init: client_num = %d" % client_num) @@ -83,7 +158,8 @@ def __init__( if args.rank == 0: self.top_active_msg = CommunicationConstants.SERVER_TOP_ACTIVE_MSG self.topic_last_will_msg = CommunicationConstants.SERVER_TOP_LAST_WILL_MSG - self.last_will_msg = json.dumps({"ID": self.edge_id, "status": CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE}) + self.last_will_msg = json.dumps( + {"ID": self.edge_id, "status": CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE}) self.mqtt_mgr = MqttManager( config_path["BROKER_HOST"], config_path["BROKER_PORT"], @@ -99,30 +175,47 @@ def __init__( self.mqtt_mgr.connect() def run_loop_forever(self): + """ + Run the MQTT loop forever to handle incoming messages. + """ self.mqtt_mgr.loop_forever() def __del__(self): + """ + Destructor to stop the MQTT loop and disconnect from the broker. + """ self.mqtt_mgr.loop_stop() self.mqtt_mgr.disconnect() @property def client_id(self): + """ + Get the client ID. + + Returns: + str: The client ID. + """ return self._client_id @property def topic(self): + """ + Get the MQTT topic. + + Returns: + str: The MQTT topic. + """ return self._topic def on_connected(self, mqtt_client_object): """ - [server] - sending message topic (publish): serverID_clientID - receiving message topic (subscribe): clientID + MQTT on_connected callback. - [client] - sending message topic (publish): clientID - receiving message topic (subscribe): serverID_clientID + This method is called when the MQTT client is connected to the broker. It handles + subscription to topics based on whether the current instance is a server or client. + Args: + mqtt_client_object: The MQTT client object. 
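+
+        Note:
+            The server subscribes to one topic per client (topic + client_id),
+            while a client subscribes to topic + server_id + "_" + client_id.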
""" self.mqtt_mgr.add_message_passthrough_listener(self._on_message) @@ -132,7 +225,8 @@ def on_connected(self, mqtt_client_object): self.subscribe_client_status_message() for client_ID in range(1, self.client_num + 1): - real_topic = self._topic + str(self.client_real_ids[client_ID - 1]) + real_topic = self._topic + \ + str(self.client_real_ids[client_ID - 1]) result, mid = mqtt_client_object.subscribe(real_topic, qos=2) logging.info( @@ -143,7 +237,8 @@ def on_connected(self, mqtt_client_object): self._notify_connection_ready() else: # client - real_topic = self._topic + str(self.server_id) + "_" + str(self.client_real_ids[0]) + real_topic = self._topic + \ + str(self.server_id) + "_" + str(self.client_real_ids[0]) result, mid = mqtt_client_object.subscribe(real_topic, qos=2) logging.info( @@ -153,15 +248,42 @@ def on_connected(self, mqtt_client_object): self._notify_connection_ready() def on_disconnected(self, mqtt_client_object): + """ + MQTT on_connected callback. + + This method is called when the MQTT client is connected to the broker. It handles + subscription to topics based on whether the current instance is a server or client. + + Args: + mqtt_client_object: The MQTT client object. + """ pass def add_observer(self, observer: Observer): + """ + Add an observer to receive messages. + + Args: + observer (Observer): The observer to add. + """ self._observers.append(observer) def remove_observer(self, observer: Observer): + """ + Remove an observer. + + Args: + observer (Observer): The observer to remove. + """ self._observers.remove(observer) def _notify(self, msg_obj): + """ + Notify observers with a message object. + + Args: + msg_obj: The message object to notify observers with. + """ msg_params = Message() msg_params.init_from_json_object(msg_obj) msg_type = msg_params.get_type() @@ -170,6 +292,15 @@ def _notify(self, msg_obj): observer.receive_message(msg_type, msg_params) def _on_message_impl(self, msg): + """ + Handle incoming MQTT messages. + + This method processes incoming MQTT messages, including downloading model files from S3 + if needed. + + Args: + msg: The incoming MQTT message. + """ json_payload = str(msg.payload, encoding="utf-8") payload_obj = json.loads(json_payload) logging.info("mqtt_s3.on_message: payload_obj %s" % payload_obj) @@ -182,7 +313,8 @@ def _on_message_impl(self, msg): model_file_path = self.args.model_file_cache_folder + "/" + s3_key_str self.s3_storage.download_model_file(s3_key_str, model_file_path) - logging.info("mqtt_s3.on_message: downloaded model file {}".format(model_file_path)) + logging.info( + "mqtt_s3.on_message: downloaded model file {}".format(model_file_path)) # replace the S3 object key with raw model params payload_obj[Message.MSG_ARG_KEY_MODEL_PARAMS] = model_file_path @@ -193,22 +325,30 @@ def _on_message_impl(self, msg): self._notify(payload_obj) def _on_message(self, msg): + """ + Wrapper for handling incoming MQTT messages. + + This method wraps the _on_message_impl method and handles exceptions. + + Args: + msg: The incoming MQTT message. + """ try: self._on_message_impl(msg) except Exception as e: - logging.error("mqtt_s3.on_message exception: {}".format(traceback.format_exc())) + logging.error("mqtt_s3.on_message exception: {}".format( + traceback.format_exc())) def send_message(self, msg: Message): """ - [server] - sending message topic (publish): fedml_runid_serverID_clientID - receiving message topic (subscribe): fedml_runid_clientID + Send a message using MQTT. 
- [client] - sending message topic (publish): fedml_runid_clientID - receiving message topic (subscribe): fedml_runid_serverID_clientID + This method sends a message using MQTT, including handling S3 storage if required. + Args: + msg (Message): The message to send. """ + if self.rank == 0: # server receiver_id = msg.get_receiver_id() @@ -218,14 +358,16 @@ def send_message(self, msg: Message): logging.info("mqtt_s3.send_message: msg topic = %s" % str(topic)) payload = msg.get_params() - model_params_obj = payload.get(Message.MSG_ARG_KEY_MODEL_PARAMS, "") + model_params_obj = payload.get( + Message.MSG_ARG_KEY_MODEL_PARAMS, "") model_url = payload.get(Message.MSG_ARG_KEY_MODEL_PARAMS_URL, "") model_key = payload.get(Message.MSG_ARG_KEY_MODEL_PARAMS_KEY, "") if model_params_obj != "": # S3 if model_url == "": model_key = topic + "_" + str(uuid.uuid4()) - model_url = self.s3_storage.upload_model_file(model_key, model_params_obj) + model_url = self.s3_storage.upload_model_file( + model_key, model_params_obj) logging.info( "mqtt_s3.send_message: S3+MQTT msg sent, s3 message key = %s" @@ -244,17 +386,36 @@ def send_message(self, msg: Message): raise Exception("This is only used for the server") def send_message_json(self, topic_name, json_message): + """ + Send a JSON message using MQTT. + + Args: + topic_name (str): The topic to send the message to. + json_message: The JSON message to send. + """ self.mqtt_mgr.send_message_json(topic_name, json_message) def handle_receive_message(self): + """ + Start handling received messages by running the MQTT loop. + """ self.run_loop_forever() def stop_receive_message(self): + """ + Stop receiving messages and disconnect from MQTT. + """ logging.info("mqtt_s3.stop_receive_message: stopping...") self.mqtt_mgr.loop_stop() self.mqtt_mgr.disconnect() def set_config_from_file(self, config_file_path): + """ + Load MQTT configuration from a file. + + Args: + config_file_path (str): Path to the configuration file. + """ try: with open(config_file_path, "r") as f: config = yaml.load(f, Loader=yaml.FullLoader) @@ -270,6 +431,12 @@ def set_config_from_file(self, config_file_path): pass def set_config_from_objects(self, mqtt_config): + """ + Set MQTT configuration from objects. + + Args: + mqtt_config: MQTT configuration object. + """ self.broker_host = mqtt_config["BROKER_HOST"] self.broker_port = mqtt_config["BROKER_PORT"] self.mqtt_user = None @@ -280,27 +447,49 @@ def set_config_from_objects(self, mqtt_config): self.mqtt_pwd = mqtt_config["MQTT_PWD"] def _notify_connection_ready(self): + """ + Notify observers that the connection is ready. + """ msg_params = Message() msg_type = CommunicationConstants.MSG_TYPE_CONNECTION_IS_READY for observer in self._observers: observer.receive_message(msg_type, msg_params) def callback_client_last_will_msg(self, topic, payload): + """ + Callback for client last will message. + + Args: + topic (str): MQTT topic. + payload: The payload of the message. + """ msg = json.loads(payload) edge_id = msg.get("ID", None) - status = msg.get("status", CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE) + status = msg.get( + "status", CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE) if edge_id is not None and status == CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE: if self.client_active_list.get(edge_id, None) is not None: self.client_active_list.pop(edge_id) def callback_client_active_msg(self, topic, payload): + """ + Callback for client active message. + + Args: + topic (str): MQTT topic. + payload: The payload of the message. 
+ """ msg = json.loads(payload) edge_id = msg.get("ID", None) - status = msg.get("status", CommunicationConstants.MSG_CLIENT_STATUS_IDLE) + status = msg.get( + "status", CommunicationConstants.MSG_CLIENT_STATUS_IDLE) if edge_id is not None: self.client_active_list[edge_id] = status def subscribe_client_status_message(self): + """ + Subscribe to client status messages. + """ # Setup MQTT message listener to the last will message form the client. self.mqtt_mgr.add_message_listener(CommunicationConstants.CLIENT_TOP_LAST_WILL_MSG, self.callback_client_last_will_msg) @@ -310,7 +499,22 @@ def subscribe_client_status_message(self): self.callback_client_active_msg) def get_client_status(self, client_id): + """ + Get the status of a specific client. + + Args: + client_id: The ID of the client. + + Returns: + str: The status of the client. + """ return self.client_active_list.get(client_id, CommunicationConstants.MSG_CLIENT_STATUS_OFFLINE) def get_client_list_status(self): - return self.client_active_list \ No newline at end of file + """ + Get the status of all clients. + + Returns: + dict: A dictionary containing the status of all clients. + """ + return self.client_active_list From feaaab92af0bccaa64361581be7f4bfbd313a942 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Tue, 26 Sep 2023 22:03:23 +0530 Subject: [PATCH 69/70] Update mqtt_manager.py --- .../communication/mqtt/mqtt_manager.py | 188 ++++++++++++++++++ 1 file changed, 188 insertions(+) diff --git a/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py b/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py index 18c1064165..c4e1492e95 100644 --- a/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py +++ b/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py @@ -164,6 +164,15 @@ def send_message_json(self, topic, message, publish_single_message=False): return True def on_connect(self, client, userdata, flags, rc): + """ + Callback function for the MQTT on_connect event. + + Args: + client: The MQTT client instance. + userdata: User data. + flags: Connection flags. + rc: Return code from the MQTT broker. + """ if rc == 0: client.connected_flag = True client.bad_conn_flag = False @@ -200,16 +209,43 @@ def on_connect(self, client, userdata, flags, rc): self.mqtt_connection_id, rc)) def is_connected(self): + """ + Check if the MQTT client is connected. + + Returns: + bool: True if the client is connected, False otherwise. + """ return self._client.is_connected() def subscribe_will_set_msg(self, client): + """ + Subscribe to the last will message topic and set a callback. + + Args: + client: The MQTT client instance. + """ self.add_message_listener(self.last_will_topic, self.callback_will_set_msg) client.subscribe(self.last_will_topic, qos=2) def callback_will_set_msg(self, topic, payload): + """ + Callback function for handling the last will message. + + Args: + topic (str): The MQTT topic. + payload (str): The message payload. + """ logging.info(f"MQTT client will be disconnected, id: {self._client_id}, topic: {topic}, payload: {payload}") def on_message(self, client, userdata, msg): + """ + Callback function for the MQTT on_message event. + + Args: + client: The MQTT client instance. + userdata: User data. + msg: The received MQTT message. 
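+
+        Retained messages are handled separately; other payloads are passed
+        to passthrough listeners and to any listener registered for the topic.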
+ """ # logging.info("on_message: msg.topic {}, msg.retain {}".format(msg.topic, msg.retain)) if msg.retain: @@ -230,98 +266,250 @@ def on_message(self, client, userdata, msg): MLOpsProfilerEvent.log_to_wandb({"BusyTime": time.time() - message_handler_start_time}) def on_publish(self, client, obj, mid): + """ + Callback function for the MQTT on_publish event. + + Args: + client: The MQTT client instance. + obj: Object. + mid: Message ID. + """ self.callback_published_listener(client) def on_disconnect(self, client, userdata, rc): + """ + Callback function for the MQTT on_disconnect event. + + Args: + client: The MQTT client instance. + userdata: User data. + rc: Return code from the MQTT broker. + """ client.connected_flag = False client.bad_conn_flag = True self.callback_disconnected_listener(client) def _on_subscribe(self, client, userdata, mid, granted_qos): + """ + Callback function for the MQTT on_subscribe event. + + Args: + client: The MQTT client instance. + userdata: User data. + mid: Message ID. + granted_qos: Granted QoS levels. + """ self.callback_subscribed_listener(client) def _on_log(self, client, userdata, level, buf): + """ + Callback function for MQTT logging. + + Args: + client: The MQTT client instance. + userdata: User data. + level: Logging level. + buf: Log message buffer. + """ logging.info("mqtt log {}, client id {}.".format(buf, self.mqtt_connection_id)) def add_message_listener(self, topic, listener): + """ + Add a message listener to handle messages received on a specific topic. + + Args: + topic (str): The MQTT topic to listen to. + listener (callable): The callback function to handle the received messages. + """ self._listeners[topic] = listener def remove_message_listener(self, topic): + """ + Remove a message listener for a specific topic. + + Args: + topic (str): The MQTT topic to remove the listener from. + """ try: del self._listeners[topic] except Exception as e: pass def add_message_passthrough_listener(self, listener): + """ + Add a message passthrough listener to handle all incoming messages. + + Args: + listener (callable): The callback function to handle incoming messages. + """ + # if not callable(listener): + # raise Exception("listener must be callable!") + # self.__message_passthrough_listener = listener self.remove_message_passthrough_listener(listener) self._passthrough_listeners.append(listener) def remove_message_passthrough_listener(self, listener): + """ + Remove a message passthrough listener. + + Args: + listener (callable): The passthrough listener to remove. + """ + # if hasattr(self,'__message_passthrough_listener') and \ + # self.__message_passthrough_listener is not None: + # self.__message_passthrough_listener = None + # if isinstance(listener,(list)): + # for l in listener: + # self._passthrough_listeners.remove(l) + # else: + # self._passthrough_listeners.remove(listener) try: self._passthrough_listeners.remove(listener) except Exception as e: pass def add_connected_listener(self, listener): + """ + Add a listener to handle the MQTT client's connection event. + + Args: + listener (callable): The callback function to handle the connection event. + """ self._connected_listeners.append(listener) def remove_connected_listener(self, listener): + """ + Remove a connected listener. + + Args: + listener (callable): The connected listener to remove. + """ try: self._connected_listeners.remove(listener) except Exception as e: pass def callback_connected_listener(self, client): + """ + Callback function for handling connected listeners. 
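+
+        Each registered listener that is callable is invoked with the client
+        instance.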
+ + Args: + client: The MQTT client instance. + """ for listener in self._connected_listeners: if listener is not None and callable(listener): listener(client) def add_disconnected_listener(self, listener): + """ + Add a listener to handle the MQTT client's disconnection event. + + Args: + listener (callable): The callback function to handle the disconnection event. + """ self._disconnected_listeners.append(listener) def remove_disconnected_listener(self, listener): + """ + Remove a disconnected listener. + + Args: + listener (callable): The disconnected listener to remove. + """ try: self._disconnected_listeners.remove(listener) except Exception as e: pass def callback_disconnected_listener(self, client): + """ + Callback function for handling disconnected listeners. + + Args: + client: The MQTT client instance. + """ for listener in self._disconnected_listeners: if listener is not None and callable(listener): listener(client) def add_subscribed_listener(self, listener): + """ + Add a listener to handle the MQTT client's subscription event. + + Args: + listener (callable): The callback function to handle the subscription event. + """ self._subscribed_listeners.append(listener) def remove_subscribed_listener(self, listener): + """ + Remove a subscribed listener. + + Args: + listener (callable): The subscribed listener to remove. + """ try: self._subscribed_listeners.remove(listener) except Exception as e: pass def callback_subscribed_listener(self, client): + """ + Callback function for handling subscribed listeners. + + Args: + client: The MQTT client instance. + """ for listener in self._subscribed_listeners: if listener is not None and callable(listener): listener(client) def add_published_listener(self, listener): + """ + Add a listener to handle the MQTT client's message publishing event. + + Args: + listener (callable): The callback function to handle the publishing event. + """ self._published_listeners.append(listener) def remove_published_listener(self, listener): + """ + Remove a published listener. + + Args: + listener (callable): The published listener to remove. + """ try: self._published_listeners.remove(listener) except Exception as e: pass def callback_published_listener(self, client): + """ + Callback function for handling published listeners. + + Args: + client: The MQTT client instance. + """ for listener in self._published_listeners: if listener is not None and callable(listener): listener(client) def subscribe_msg(self, topic): + """ + Subscribe to an MQTT topic with a QoS level of 2. + + Args: + topic (str): The MQTT topic to subscribe to. + """ self._client.subscribe(topic, qos=2) def check_connection(self): + """ + Check the MQTT client's connection status and wait for a connection if not connected. + Raises an exception if the connection fails. 
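+
+        The connection flags are polled up to roughly 30 times before the
+        failure is raised.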
+ """ count = 0 while not self._client.connected_flag and self._client.bad_conn_flag: if count >= 30: From b96cc6708f8176bde014f416946687d6e0cb1585 Mon Sep 17 00:00:00 2001 From: rajveer43 Date: Wed, 27 Sep 2023 12:01:31 +0530 Subject: [PATCH 70/70] Update mqtt_manager.py --- .../distributed/communication/mqtt/mqtt_manager.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py b/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py index c4e1492e95..d092389bd4 100644 --- a/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py +++ b/python/fedml/core/distributed/communication/mqtt/mqtt_manager.py @@ -530,12 +530,26 @@ def check_connection(self): def test_msg_callback(topic, payload): + """ + Callback function to handle received MQTT messages for testing purposes. + + Args: + topic (str): The MQTT topic on which the message was received. + payload (str): The payload of the received message. + """ global received_msg_count received_msg_count += 1 logging.info("Received the topic: {}, message: {}, count {}.".format(topic, payload, received_msg_count)) def test_last_will_callback(topic, payload): + """ + Callback function to handle last will messages received for testing purposes. + + Args: + topic (str): The MQTT topic on which the last will message was received. + payload (str): The payload of the received last will message. + """ logging.info("Received the topic: {}, message: {}.".format(topic, payload))
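+
+
+# Illustrative usage sketch (broker address, credentials, and topic are
+# assumptions):
+#
+#   mgr = MqttManager("127.0.0.1", 1883, None, None, 180, "test_client_0")
+#   mgr.add_message_listener("fedml/test", test_msg_callback)
+#   mgr.connect()
+#   mgr.loop_forever()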