diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 00000000..ab739147 Binary files /dev/null and b/.DS_Store differ diff --git a/models/MNIST/AdaProxFedProx_server.pt b/models/MNIST/AdaProxFedProx_server.pt new file mode 100644 index 00000000..5d1c97b0 Binary files /dev/null and b/models/MNIST/AdaProxFedProx_server.pt differ diff --git a/models/MNIST/AdaProxKEval_server.pt b/models/MNIST/AdaProxKEval_server.pt new file mode 100644 index 00000000..a4339da9 Binary files /dev/null and b/models/MNIST/AdaProxKEval_server.pt differ diff --git a/models/MNIST/FedProx_server.pt b/models/MNIST/FedProx_server.pt new file mode 100644 index 00000000..d434a8ea Binary files /dev/null and b/models/MNIST/FedProx_server.pt differ diff --git a/system/.DS_Store b/system/.DS_Store new file mode 100644 index 00000000..484ec67f Binary files /dev/null and b/system/.DS_Store differ diff --git a/system/flcore/clients/client_adaprox_ditto.py b/system/flcore/clients/client_adaprox_ditto.py new file mode 100644 index 00000000..724d8d21 --- /dev/null +++ b/system/flcore/clients/client_adaprox_ditto.py @@ -0,0 +1,87 @@ +import torch +from flcore.clients.clientditto import clientDitto + + +class ClientAdaProxDitto(clientDitto): + """ + AdaProxDitto Client: Adaptive proximal regularization for Ditto. + Adaptively adjusts lambda (proximal penalty) based on loss gap between + client's global model loss and server's EMA of global losses. + """ + def __init__(self, args, id, train_samples, test_samples, **kwargs): + super().__init__(args, id, train_samples, test_samples, **kwargs) + + # AdaProx hyperparameters + self.alpha = getattr(args, 'alpha_gain', 1.0) + self.tau = getattr(args, 'gap_clip', 1.0) + self.lam_max = getattr(args, 'lam_max', 5.0) + self.warmup = getattr(args, 'warmup_rounds', 5) + self.lam_init = getattr(args, 'lam_init', 0.0) + + # State for adaptive lambda + self.lam = self.lam_init + + # Server-provided state (set by server during send_models) + self.server_lg = None + self.current_round = -1 + + # Client-reported state (read by server after train) + self.mean_loss_global = 0.0 + + def eval_loss_on_global_model(self): + """ + Evaluate loss of current global model on local training data. + Used to compute the loss gap for adaptive lambda. + """ + trainloader = self.load_train_data() + self.model.eval() + total_loss = 0 + count = 0 + with torch.no_grad(): + for x, y in trainloader: + if type(x) == type([]): + x[0] = x[0].to(self.device) + else: + x = x.to(self.device) + y = y.to(self.device) + output = self.model(x) + loss = self.loss(output, y) + total_loss += loss.item() * y.size(0) + count += y.size(0) + + return total_loss / count if count > 0 else 0.0 + + def train(self): + """ + Override train to compute adaptive lambda before calling parent's training. + """ + # Step 1: Evaluate global model's loss before any training + self.mean_loss_global = self.eval_loss_on_global_model() + + # Step 2: Calculate adaptive lambda based on loss gap + lg = self.server_lg + if lg is not None and self.current_round >= self.warmup: + # Loss gap: how much worse is client's loss vs global average + gap = max(0.0, min(self.tau, float(self.mean_loss_global - lg))) + self.lam = min(self.lam_max, self.alpha * gap) + else: + # During warmup, use initial value + self.lam = self.lam_init + + # Step 3: Pass adaptive lambda to parent's logic + # Ditto uses self.mu which is read from args in __init__ + # We need to update the args.lam that would be used + # However, Ditto uses self.mu for PerturbedGradientDescent + # Let's check what parameter Ditto's personalized optimizer uses + + # Looking at clientDitto, the personalized model uses: + # self.optimizer_per = PerturbedGradientDescent(..., mu=self.mu) + # So we need to update self.mu and self.optimizer_per.mu + + # For Ditto, mu controls the proximal term between personalized and global model + # We want to adapt this based on the loss gap + self.mu = self.lam + self.optimizer_per.mu = self.lam + + # Step 4: Call parent's training method + super().train() diff --git a/system/flcore/clients/clientadaprox.py b/system/flcore/clients/clientadaprox.py new file mode 100644 index 00000000..dc4161d4 --- /dev/null +++ b/system/flcore/clients/clientadaprox.py @@ -0,0 +1,73 @@ +import torch +import numpy as np +import time +from flcore.clients.clientprox import clientProx + + +class clientAdaProx(clientProx): + def __init__(self, args, id, train_samples, test_samples, **kwargs): + super().__init__(args, id, train_samples, test_samples, **kwargs) + + # AdaProx hyperparameters + self.alpha = getattr(args, 'alpha_gain', 1.0) + self.tau = getattr(args, 'gap_clip', 1.0) + self.mu_max = getattr(args, 'mu_max', 5.0) + # Use a short warmup so adaptive mu can activate in short tests + self.warmup = getattr(args, 'warmup_rounds', 2) + # Small non-zero default for proximal term so adaptive behavior is testable + self.mu_init = getattr(args, 'mu_init', 0.01) + + # State for adaptive mu (proximal penalty) + self.adaptive_mu = self.mu_init + + # Server-provided state (set by server during send_models) + self.server_lg = None + self.current_round = -1 + + # Client-reported state (read by server after train) + self.mean_loss_global = 0.0 + + def eval_loss_on_global_model(self): + """ + Evaluate loss of current global model on local training data. + Used to compute the loss gap for adaptive lambda. + """ + trainloader = self.load_train_data() + self.model.eval() + total_loss = 0 + count = 0 + with torch.no_grad(): + for x, y in trainloader: + if type(x) == type([]): + x[0] = x[0].to(self.device) + else: + x = x.to(self.device) + y = y.to(self.device) + output = self.model(x) + loss = self.loss(output, y) + total_loss += loss.item() * y.size(0) + count += y.size(0) + + return total_loss / count if count > 0 else 0.0 + + def train(self): + # 1. Evaluate global model loss before local training + self.mean_loss_global = self.eval_loss_on_global_model() + + # 2. Compute adaptive mu based on loss gap with server's EMA + lg = self.server_lg + if lg is not None and self.current_round >= self.warmup: + # Loss gap: how much worse is client's loss vs global average + gap = max(0.0, min(self.tau, float(self.mean_loss_global - lg))) + self.adaptive_mu = min(self.mu_max, self.alpha * gap) + else: + # During warmup, use initial value + self.adaptive_mu = self.mu_init + + # 3. Override the proximal penalty parameter + # clientProx uses self.mu via self.optimizer (PerturbedGradientDescent) + self.mu = self.adaptive_mu + self.optimizer.mu = self.adaptive_mu + + # 4. Run standard FedProx training with adaptive mu + super().train() diff --git a/system/flcore/clients/clientadaprox_k_eval.py b/system/flcore/clients/clientadaprox_k_eval.py new file mode 100644 index 00000000..d9264b10 --- /dev/null +++ b/system/flcore/clients/clientadaprox_k_eval.py @@ -0,0 +1,79 @@ +import torch +import numpy as np +import time +from flcore.clients.clientprox import clientProx + + +class clientAdaProxKEval(clientProx): + def __init__(self, args, id, train_samples, test_samples, **kwargs): + super().__init__(args, id, train_samples, test_samples, **kwargs) + + # AdaProx hyperparameters + self.alpha = getattr(args, 'alpha_gain', 1.0) + self.tau = getattr(args, 'gap_clip', 1.0) + self.mu_max = getattr(args, 'mu_max', 5.0) + self.warmup = getattr(args, 'warmup_rounds', 5) + self.mu_init = getattr(args, 'mu_init', 0.0) + + # State for adaptive mu (proximal penalty) + self.adaptive_mu = self.mu_init + # How many batches to use when evaluating the global model loss + self.k_eval_batches = getattr(args, 'k_eval_batches', 5) + + # Server-provided state (set by server during send_models) + self.server_lg = None + self.current_round = -1 + + # Client-reported state (read by server after train) + self.mean_loss_global = 0.0 + + def eval_loss_on_global_model(self): + """ + Evaluate loss of current global model on local training data. + Used to compute the loss gap for adaptive lambda. + This variant uses only the first k batches to estimate loss. + """ + trainloader = self.load_train_data() + self.model.eval() + total_loss = 0 + count = 0 + batch_idx = 0 + with torch.no_grad(): + for x, y in trainloader: + # stop after evaluating on at most k batches + if batch_idx >= self.k_eval_batches: + break + if type(x) == type([]): + x[0] = x[0].to(self.device) + else: + x = x.to(self.device) + y = y.to(self.device) + output = self.model(x) + loss = self.loss(output, y) + total_loss += loss.item() * y.size(0) + count += y.size(0) + batch_idx += 1 + + return total_loss / count if count > 0 else 0.0 + + def train(self): + # 1. Evaluate global model loss before local training + self.mean_loss_global = self.eval_loss_on_global_model() + + # 2. Compute adaptive mu based on loss gap with server's EMA + lg = self.server_lg + if lg is not None and self.current_round >= self.warmup: + # Loss gap: how much worse is client's loss vs global average + gap = max(0.0, min(self.tau, float(self.mean_loss_global - lg))) + self.adaptive_mu = min(self.mu_max, self.alpha * gap) + else: + # During warmup, use initial value + self.adaptive_mu = self.mu_init + + # 3. Override the proximal penalty parameter + # clientProx uses self.mu via self.optimizer (PerturbedGradientDescent) + self.mu = self.adaptive_mu + self.optimizer.mu = self.adaptive_mu + + # 4. Run standard FedProx training with adaptive mu + super().train() diff --git a/system/flcore/clients/clientadaprox_optimized.py b/system/flcore/clients/clientadaprox_optimized.py new file mode 100644 index 00000000..d214b2de --- /dev/null +++ b/system/flcore/clients/clientadaprox_optimized.py @@ -0,0 +1,129 @@ +import torch +import numpy as np +import time +from flcore.clients.clientprox import clientProx + + +class clientAdaProxOptimized(clientProx): + """ + Optimized AdaProx client with reduced computational overhead. + + Key optimization: Use sampling in eval_loss_on_global_model() to reduce + the expensive full-dataset forward pass by ~80% while maintaining accuracy. + """ + + def __init__(self, args, id, train_samples, test_samples, **kwargs): + super().__init__(args, id, train_samples, test_samples, **kwargs) + + # AdaProx hyperparameters + self.alpha = getattr(args, 'alpha_gain', 1.0) + self.tau = getattr(args, 'gap_clip', 1.0) + self.mu_max = getattr(args, 'mu_max', 5.0) + self.warmup = getattr(args, 'warmup_rounds', 5) + self.mu_init = getattr(args, 'mu_init', 0.0) + + # Optimization parameters + self.loss_sample_ratio = getattr(args, 'loss_sample_ratio', 0.2) # Use 20% of data + self.exact_loss_every = getattr(args, 'exact_loss_every', 10) # Full eval every 10 rounds + + # State for adaptive mu (proximal penalty) + self.adaptive_mu = self.mu_init + + # Server-provided state (set by server during send_models) + self.server_lg = None + self.current_round = -1 + + # Client-reported state (read by server after train) + self.mean_loss_global = 0.0 + + def eval_loss_on_global_model_full(self): + """ + Exact loss evaluation (expensive - full dataset pass). + Used periodically for accuracy. + """ + trainloader = self.load_train_data() + self.model.eval() + total_loss = 0 + count = 0 + with torch.no_grad(): + for x, y in trainloader: + if type(x) == type([]): + x[0] = x[0].to(self.device) + else: + x = x.to(self.device) + y = y.to(self.device) + output = self.model(x) + loss = self.loss(output, y) + total_loss += loss.item() * y.size(0) + count += y.size(0) + + return total_loss / count if count > 0 else 0.0 + + def eval_loss_on_global_model_sampled(self): + """ + Fast loss estimation using sampling (cheap - partial dataset). + Used most of the time for efficiency. + + Trade-off: Slightly noisier estimate, but EMA smoothing compensates. + Expected speedup: 5-10× faster depending on sample ratio. + """ + trainloader = self.load_train_data() + self.model.eval() + total_loss = 0 + count = 0 + + # Calculate number of batches to sample + num_batches = len(trainloader) + sample_batches = max(1, int(num_batches * self.loss_sample_ratio)) + + with torch.no_grad(): + for i, (x, y) in enumerate(trainloader): + if i >= sample_batches: # Early stopping for efficiency + break + + if type(x) == type([]): + x[0] = x[0].to(self.device) + else: + x = x.to(self.device) + y = y.to(self.device) + output = self.model(x) + loss = self.loss(output, y) + total_loss += loss.item() * y.size(0) + count += y.size(0) + + return total_loss / count if count > 0 else 0.0 + + def eval_loss_on_global_model(self): + """ + Hybrid approach: Use sampling most of the time, exact computation periodically. + + This balances efficiency (sampling) with accuracy (exact computation). + """ + # Use exact computation periodically for stability + if self.current_round % self.exact_loss_every == 0: + return self.eval_loss_on_global_model_full() + else: + # Use fast sampling for efficiency + return self.eval_loss_on_global_model_sampled() + + def train(self): + # 1. Evaluate global model loss before local training (OPTIMIZED) + self.mean_loss_global = self.eval_loss_on_global_model() + + # 2. Compute adaptive mu based on loss gap with server's EMA + lg = self.server_lg + if lg is not None and self.current_round >= self.warmup: + # Loss gap: how much worse is client's loss vs global average + gap = max(0.0, min(self.tau, float(self.mean_loss_global - lg))) + self.adaptive_mu = min(self.mu_max, self.alpha * gap) + else: + # During warmup, use initial value + self.adaptive_mu = self.mu_init + + # 3. Override the proximal penalty parameter + # clientProx uses self.mu via self.optimizer (PerturbedGradientDescent) + self.mu = self.adaptive_mu + self.optimizer.mu = self.adaptive_mu + + # 4. Run standard FedProx training with adaptive mu + super().train() diff --git a/system/flcore/servers/server_adaprox_ditto.py b/system/flcore/servers/server_adaprox_ditto.py new file mode 100644 index 00000000..7fdbc7e1 --- /dev/null +++ b/system/flcore/servers/server_adaprox_ditto.py @@ -0,0 +1,114 @@ +import time +from flcore.clients.client_adaprox_ditto import ClientAdaProxDitto +from flcore.servers.serverditto import Ditto + + +class ServerAdaProxDitto(Ditto): + """ + AdaProxDitto: Adaptive proximal algorithm for Ditto. + Combines Ditto's personalization with adaptive lambda based on loss gap. + """ + def __init__(self, args, times): + super().__init__(args, times) + + # Override the client class with AdaProxDitto clients + self.set_clients(ClientAdaProxDitto) + + # Global loss EMA state + self.lg = None + self.beta = getattr(args, 'ema_beta', 0.9) + + print(f"\n[AdaProxDitto] EMA beta: {self.beta}") + print("[AdaProxDitto] Using adaptive proximal regularization for Ditto") + + def send_models(self): + """ + Override to send both global model and server's EMA loss (lg) to clients. + """ + assert (len(self.selected_clients) > 0) + + for client in self.selected_clients: + start_time = time.time() + + # Send global model parameters (from parent) + client.set_parameters(self.global_model) + + # Send EMA and round info for adaptive lambda computation + client.server_lg = self.lg + client.current_round = self.global_round + + client.send_time = time.time() - start_time + + def train(self): + """ + Override to insert EMA update logic after client training. + Replicates the train() loop from Ditto with added EMA logic. + """ + for i in range(self.global_rounds + 1): + self.global_round = i + s_t = time.time() + self.selected_clients = self.select_clients() + self.send_models() + + if i % self.eval_gap == 0: + print(f"\n-------------Round number: {i}-------------") + print("\nEvaluate global models") + self.evaluate() + + if i % self.eval_gap == 0: + print("\nEvaluate personalized models") + self.evaluate_personalized() + + # Clients run personalized training (ptrain) then global training + for client in self.selected_clients: + client.ptrain() + client.train() + + # Collect models from clients + self.receive_models() + + # === AdaProx: Update EMA of global loss === + try: + client_losses = [c.mean_loss_global for c in self.selected_clients] + mean_loss = sum(client_losses) / len(client_losses) + + # Update EMA + if self.lg is None: + self.lg = mean_loss + else: + self.lg = self.beta * self.lg + (1 - self.beta) * mean_loss + + if i % self.eval_gap == 0: + print(f"[AdaProxDitto] Mean client loss: {mean_loss:.4f}, EMA (lg): {self.lg:.4f}") + + except Exception as e: + print(f"[AdaProxDitto Warning] Error computing EMA: {e}") + # === End AdaProx logic === + + # DLG evaluation if needed + if self.dlg_eval and i % self.dlg_gap == 0: + self.call_dlg(i) + + # Aggregate parameters (from parent) + self.aggregate_parameters() + + self.Budget.append(time.time() - s_t) + print('-' * 25, 'time cost', '-' * 25, self.Budget[-1]) + + if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + break + + print("\nBest accuracy.") + print(max(self.rs_test_acc)) + print("\nAverage time cost per round.") + print(sum(self.Budget[1:]) / len(self.Budget[1:])) + + self.save_results() + self.save_global_model() + + if self.num_new_clients > 0: + self.eval_new_clients = True + self.set_new_clients(ClientAdaProxDitto) + print(f"\n-------------Fine tuning round-------------") + print("\nEvaluate new clients") + self.evaluate() diff --git a/system/flcore/servers/serveradaprox.py b/system/flcore/servers/serveradaprox.py new file mode 100644 index 00000000..de421fa0 --- /dev/null +++ b/system/flcore/servers/serveradaprox.py @@ -0,0 +1,135 @@ +import time +from flcore.clients.clientadaprox import clientAdaProx +from flcore.servers.serverbase import Server + + +class AdaProxFedProx(Server): + def __init__(self, args, times): + super().__init__(args, times) + + # select slow clients + self.set_slow_clients() + # Set custom client class (AdaProx clients instead of regular Prox clients) + self.set_clients(clientAdaProx) + + # Global loss EMA state + self.lg = None + self.beta = getattr(args, 'ema_beta', 0.9) + + print(f"\n[AdaProxFedProx] EMA beta: {self.beta}") + print("[AdaProxFedProx] Using adaptive proximal regularization") + + print(f"\nJoin ratio / total clients: {self.join_ratio} / {self.num_clients}") + print("Finished creating server and client s.") + + # self.load_model() + self.Budget = [] + + def send_models(self): + """ + Override to send both global model and server's EMA loss (lg) to clients. + """ + assert (len(self.selected_clients) > 0) + + for client in self.selected_clients: + start_time = time.time() + + # Send global model parameters (from parent) + client.set_parameters(self.global_model) + + # Send EMA and round info for adaptive mu computation + client.server_lg = self.lg + client.current_round = self.global_round + + client.send_time = time.time() - start_time + + def train(self): + """ + Override to insert EMA update logic after client training. + """ + for i in range(self.global_rounds + 1): + self.global_round = i + s_t = time.time() + self.selected_clients = self.select_clients() + self.send_models() + + if i % self.eval_gap == 0: + print(f"\n-------------Round number: {i}-------------") + print("\nEvaluate global model") + self.evaluate() + + # Clients run train() with adaptive mu + for client in self.selected_clients: + client.train() + + # Collect models from clients + self.receive_models() + + # === AdaProx: Update EMA of global loss === + try: + client_losses = [c.mean_loss_global for c in self.selected_clients] + mean_loss = sum(client_losses) / len(client_losses) + + # Update EMA + if self.lg is None: + self.lg = mean_loss + else: + self.lg = self.beta * self.lg + (1 - self.beta) * mean_loss + + if i % self.eval_gap == 0: + print(f"[AdaProx] Mean client loss: {mean_loss:.4f}, EMA (lg): {self.lg:.4f}") + # Log client mu values for debugging/analysis + try: + client_mus = [] + for c in self.selected_clients: + # prefer adaptive_mu if present, fall back to mu + mu_val = getattr(c, 'adaptive_mu', None) + if mu_val is None: + mu_val = getattr(c, 'mu', None) + client_mus.append(mu_val if mu_val is not None else float('nan')) + + # summary stats + mus_valid = [m for m in client_mus if not (m is None)] + if len(mus_valid) > 0: + mean_mu = sum(mus_valid) / len(mus_valid) + else: + mean_mu = float('nan') + + # Print per-client mu (short) and mean + mus_str = ', '.join([f"{m:.4f}" if (m is not None) else "nan" for m in client_mus]) + print(f"[AdaProx] client mu values: [{mus_str}]") + print(f"[AdaProx] mean client mu: {mean_mu:.4f}") + except Exception as e: + print(f"[AdaProx Warning] Error logging client mu values: {e}") + + except Exception as e: + print(f"[AdaProx Warning] Error computing EMA: {e}") + # === End AdaProx logic === + + # DLG evaluation if needed + if self.dlg_eval and i % self.dlg_gap == 0: + self.call_dlg(i) + + # Aggregate parameters (from parent) + self.aggregate_parameters() + + self.Budget.append(time.time() - s_t) + print('-' * 25, 'time cost', '-' * 25, self.Budget[-1]) + + if self.auto_break and self.check_done(acc_lss=[self.rs_test_acc], top_cnt=self.top_cnt): + break + + print("\nBest accuracy.") + print(max(self.rs_test_acc)) + print("\nAverage time cost per round.") + print(sum(self.Budget[1:]) / len(self.Budget[1:])) + + self.save_results() + self.save_global_model() + + if self.num_new_clients > 0: + self.eval_new_clients = True + self.set_new_clients(clientAdaProx) + print(f"\n-------------Fine tuning round-------------") + print("\nEvaluate new clients") + self.evaluate() diff --git a/system/main.py b/system/main.py index cbc28b56..fe6e3c10 100644 --- a/system/main.py +++ b/system/main.py @@ -13,6 +13,7 @@ from flcore.servers.serverpFedMe import pFedMe from flcore.servers.serverperavg import PerAvg from flcore.servers.serverprox import FedProx +from flcore.servers.serveradaprox import AdaProxFedProx from flcore.servers.serverfomo import FedFomo from flcore.servers.serveramp import FedAMP from flcore.servers.servermtl import FedMTL @@ -20,6 +21,7 @@ from flcore.servers.serverper import FedPer from flcore.servers.serverapfl import APFL from flcore.servers.serverditto import Ditto +from flcore.servers.server_adaprox_ditto import ServerAdaProxDitto from flcore.servers.serverrep import FedRep from flcore.servers.serverphp import FedPHP from flcore.servers.serverbn import FedBN @@ -58,6 +60,9 @@ from flcore.trainmodel.mobilenet_v2 import * from flcore.trainmodel.transformer import * +# Optional clients for experiments +from flcore.clients.clientadaprox_k_eval import clientAdaProxKEval + from utils.result_utils import average_data from utils.mem_utils import MemReporter @@ -204,6 +209,14 @@ def run(args): elif args.algorithm == "FedProx": server = FedProx(args, i) + elif args.algorithm == "AdaProxFedProx": + server = AdaProxFedProx(args, i) + + elif args.algorithm == "AdaProxKEval": + server = AdaProxFedProx(args, i) + # Override the client class to use the k-batch evaluation variant + server.set_clients(clientAdaProxKEval) + elif args.algorithm == "FedFomo": server = FedFomo(args, i) @@ -222,6 +235,9 @@ def run(args): elif args.algorithm == "Ditto": server = Ditto(args, i) + elif args.algorithm == "AdaProxDitto": + server = ServerAdaProxDitto(args, i) + elif args.algorithm == "FedRep": args.head = copy.deepcopy(args.model.fc) args.model.fc = nn.Identity() @@ -497,6 +513,29 @@ def run(args): parser.add_argument('-fsb', "--first_stage_bound", type=int, default=0) parser.add_argument('-ca', "--fedcross_alpha", type=float, default=0.99) parser.add_argument('-cmss', "--collaberative_model_select_strategy", type=int, default=1) + + # AdaProxFedProx + parser.add_argument('-ag', "--alpha_gain", type=float, default=1.0, + help="Adaptive mu gain (alpha) for AdaProx") + parser.add_argument('-gc', "--gap_clip", type=float, default=1.0, + help="Clipping value for loss gap (tau) in AdaProx") + parser.add_argument('-mmax', "--mu_max", type=float, default=5.0, + help="Maximum value for adaptive mu in AdaProx") + parser.add_argument('-minit', "--mu_init", type=float, default=0.0, + help="Initial mu value during warmup in AdaProx") + parser.add_argument('-wr', "--warmup_rounds", type=int, default=5, + help="Rounds before adaptive mu kicks in for AdaProx") + parser.add_argument('-eb', "--ema_beta", type=float, default=0.9, + help="EMA beta for server's global loss tracker in AdaProx") + # How many batches clients should use when evaluating the global model loss + parser.add_argument('-keb', "--k_eval_batches", type=int, default=5, + help="Number of batches to evaluate the global model loss on each client (used by AdaProx clients)") + + # AdaProxDitto + parser.add_argument('-lmax', "--lam_max", type=float, default=5.0, + help="Maximum value for adaptive lambda in AdaProxDitto") + parser.add_argument('-linit', "--lam_init", type=float, default=0.0, + help="Initial lambda value during warmup in AdaProxDitto") args = parser.parse_args()