From 19939ac6b5b55f394358946cd1346cc1130b5d60 Mon Sep 17 00:00:00 2001 From: fanlai0990 Date: Thu, 11 Aug 2022 18:10:34 -0500 Subject: [PATCH 01/12] Fix async --- README.md | 2 +- benchmark/configs/async_fl/async_fl.yml | 25 +++-- examples/async_fl/async_aggregator.py | 124 +++++++++++++----------- examples/async_fl/async_client.py | 64 ++++++++++++ examples/async_fl/async_executor.py | 80 ++++++++++++--- fedscale/core/aggregation/aggregator.py | 5 +- fedscale/core/execution/client.py | 3 +- fedscale/core/execution/executor.py | 12 +-- fedscale/core/resource_manager.py | 5 + fedscale/dataloaders/divide_data.py | 5 +- 10 files changed, 227 insertions(+), 98 deletions(-) create mode 100644 examples/async_fl/async_client.py diff --git a/README.md b/README.md index 69459fad..f8f03f29 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ Now that you have FedScale installed, you can start exploring FedScale following ***We are adding more datasets! Please contribute!*** -FedScale consists of 20+ large-scale, heterogeneous FL datasets covering computer vision (CV), natural language processing (NLP), and miscellaneous tasks. +FedScale consists of 20+ large-scale, heterogeneous FL datasets and 70+ various [models](./fedscale/utils/models/cv_models/README.md), covering computer vision (CV), natural language processing (NLP), and miscellaneous tasks. Each one is associated with its training, validation, and testing datasets. Please go to the `./benchmark/dataset` directory and follow the dataset [README](./benchmark/dataset/README.md) for more details. diff --git a/benchmark/configs/async_fl/async_fl.yml b/benchmark/configs/async_fl/async_fl.yml index 6c625d7e..195b90fc 100644 --- a/benchmark/configs/async_fl/async_fl.yml +++ b/benchmark/configs/async_fl/async_fl.yml @@ -2,13 +2,14 @@ # ========== Cluster configuration ========== # ip address of the parameter server (need 1 GPU process) -ps_ip: localhost +ps_ip: 10.0.0.1 # ip address of each worker:# of available gpus process on each gpu in this node # Note that if we collocate ps and worker on same GPU, then we need to decrease this number of available processes on that GPU by 1 # E.g., master node has 4 available processes, then 1 for the ps, and worker should be set to: worker:3 worker_ips: - - localhost:[2] + - 10.0.0.1:[4] + - 10.0.0.2:[4] exp_path: $FEDSCALE_HOME/fedscale/core @@ -31,27 +32,23 @@ setup_commands: job_conf: - job_name: femnist # Generate logs under this folder: log_path/job_name/time_stamp - log_path: $FEDSCALE_HOME/benchmark # Path of log files - - num_participants: 800 # Number of participants per round, we use K=100 in our paper, large K will be much slower + - num_participants: 20 # Number of participants per round, we use K=100 in our paper, large K will be much slower - data_set: femnist # Dataset: openImg, google_speech, stackoverflow - data_dir: $FEDSCALE_HOME/benchmark/dataset/data/femnist # Path of the dataset - data_map_file: $FEDSCALE_HOME/benchmark/dataset/data/femnist/client_data_mapping/train.csv # Allocation of data to each client, turn to iid setting if not provided - - device_conf_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_device_capacity # Path of the client trace - - device_avail_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_behave_trace + #- device_conf_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_device_capacity # Path of the client trace + #- device_avail_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_behave_trace - model: 
shufflenet_v2_x2_0 # Models: e.g., shufflenet_v2_x2_0, mobilenet_v2, resnet34, albert-base-v2 - - eval_interval: 20 # How many rounds to run a testing on the testing set - - rounds: 500 # Number of rounds to run this training. We use 1000 in our paper, while it may converge w/ ~400 rounds + - eval_interval: 5 # How many rounds to run a testing on the testing set + - rounds: 3000 # Number of rounds to run this training. We use 1000 in our paper, while it may converge w/ ~400 rounds - filter_less: 21 # Remove clients w/ less than 21 samples - num_loaders: 2 - local_steps: 20 - learning_rate: 0.05 - batch_size: 20 - test_bsz: 20 - - use_cuda: False + - use_cuda: True - decay_round: 50 - overcommitment: 1.0 - - async_buffer: 10 - - arrival_interval: 3 - - - - + - async_buffer: 20 + - arrival_interval: 3 \ No newline at end of file diff --git a/examples/async_fl/async_aggregator.py b/examples/async_fl/async_aggregator.py index db442ab4..3b97d6a4 100644 --- a/examples/async_fl/async_aggregator.py +++ b/examples/async_fl/async_aggregator.py @@ -28,7 +28,9 @@ def __init__(self, args): self.round_stamp = [0] self.client_model_version = {} self.virtual_client_clock = {} - self.round_lock = threading.Lock() + self.weight_tensor_type = {} + # We need to keep the test model for specific round to avoid async mismatch + self.test_model = None def tictak_client_tasks(self, sampled_clients, num_clients_to_collect): @@ -81,11 +83,19 @@ def tictak_client_tasks(self, sampled_clients, num_clients_to_collect): return (sampled_clients, sampled_clients, completed_client_clock, 1, completionTimes) + def save_last_param(self): + """ Save the last model parameters + """ + self.last_gradient_weights = [ + p.data.clone() for p in self.model.parameters()] + self.model_weights = copy.deepcopy(self.model.state_dict()) + self.weight_tensor_type = {p: self.model_weights[p].data.dtype \ + for p in self.model_weights} + def aggregate_client_weights(self, results): """May aggregate client updates on the fly""" """ - [FedAvg] "Communication-Efficient Learning of Deep Networks from Decentralized Data". - H. Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, Blaise Aguera y Arcas. AISTATS, 2017 + "PAPAYA: PRACTICAL, PRIVATE, AND SCALABLE FEDERATED LEARNING". MLSys, 2022 """ # Start to take the average of updates, and we do not keep updates to save memory # Importance of each update is 1/#_of_participants * staleness @@ -95,54 +105,29 @@ def aggregate_client_weights(self, results): importance = 1. 
/ math.sqrt(1 + client_staleness) for p in results['update_weight']: + # Different to core/executor, update_weight here is (train_model_weight - untrained) param_weight = results['update_weight'][p] + if isinstance(param_weight, list): param_weight = np.asarray(param_weight, dtype=np.float32) param_weight = torch.from_numpy( param_weight).to(device=self.device) - if self.model_in_update == 1: - self.model_weights[p].data = param_weight * importance - else: + if self.model_weights[p].data.dtype in ( + torch.float, torch.double, torch.half, + torch.bfloat16, torch.chalf, torch.cfloat, torch.cdouble + ): + # Only assign importance to floats (trainable variables) self.model_weights[p].data += param_weight * importance + else: + # Non-floats (e.g., batches), no need to aggregate but need to track + self.model_weights[p].data += param_weight if self.model_in_update == self.async_buffer_size: + logging.info("Calibrating tensor type") for p in self.model_weights: - d_type = self.model_weights[p].data.dtype - - self.model_weights[p].data = ( - self.model_weights[p] / float(self.async_buffer_size)).to(dtype=d_type) - - def aggregate_client_group_weights(self, results): - """Streaming weight aggregation. Similar to aggregate_client_weights, - but each key corresponds to a group of weights (e.g., for Tensorflow)""" - - client_staleness = self.round - \ - self.client_model_version[results['clientId']] - importance = 1. / math.sqrt(1 + client_staleness) - - for p_g in results['update_weight']: - param_weights = results['update_weight'][p_g] - for idx, param_weight in enumerate(param_weights): - if isinstance(param_weight, list): - param_weight = np.asarray(param_weight, dtype=np.float32) - param_weight = torch.from_numpy( - param_weight).to(device=self.device) - - if self.model_in_update == 1: - self.model_weights[p_g][idx].data = param_weight * importance - else: - self.model_weights[p_g][idx].data += param_weight * importance - - if self.model_in_update == self.async_buffer_size: - for p in self.model_weights: - for idx in range(len(self.model_weights[p])): - d_type = self.model_weights[p][idx].data.dtype - - self.model_weights[p][idx].data = ( - self.model_weights[p][idx].data / - float(self.async_buffer_size) - ).to(dtype=d_type) + d_type = self.weight_tensor_type[p] + self.model_weights[p].data = (self.model_weights[p].data/float(self.async_buffer_size)).to(dtype=d_type) def round_completion_handler(self): self.global_virtual_clock = self.round_stamp[-1] @@ -173,7 +158,7 @@ def round_completion_handler(self): self.sampled_participants, len(self.sampled_participants)) logging.info(f"{len(clientsToRun)} clients with constant arrival following the order: {clientsToRun}") - + logging.info(f"====Register {len(clientsToRun)} to queue") # Issue requests to the resource manager; Tasks ordered by the completion time self.resource_manager.register_tasks(clientsToRun) self.virtual_client_clock.update(virtual_client_clock) @@ -192,10 +177,12 @@ def round_completion_handler(self): self.test_result_accumulator = [] self.stats_util_accumulator = [] self.client_training_results = [] + self.loss_accumulator = [] if self.round >= self.args.rounds: self.broadcast_aggregator_events(commons.SHUT_DOWN) elif self.round % self.args.eval_interval == 0: + self.test_model = copy.deepcopy(self.model) self.broadcast_aggregator_events(commons.UPDATE_MODEL) self.broadcast_aggregator_events(commons.MODEL_TEST) else: @@ -206,7 +193,30 @@ def find_latest_model(self, start_time): for i, time_stamp in 
enumerate(reversed(self.round_stamp)): if start_time >= time_stamp: return len(self.round_stamp) - i - return None + return 1 + + def get_test_config(self, client_id): + """FL model testing on clients, developers can further define personalized client config here. + + Args: + client_id (int): The client id. + + Returns: + dictionary: The testing config for new task. + + """ + # Get the straggler round-id + client_tasks = self.resource_manager.client_run_queue + current_pending_length = min( + self.resource_manager.client_run_queue_idx, len(client_tasks)-1) + + current_pending_clients = client_tasks[current_pending_length:] + straggler_round = 1e10 + for client in current_pending_clients: + straggler_round = min( + self.find_latest_model(self.client_start_time[client]), straggler_round) + + return {'client_id': client_id, 'straggler_round': straggler_round, 'test_model': self.test_model} def get_client_conf(self, clientId): """Training configurations that will be applied on clients""" @@ -214,7 +224,7 @@ def get_client_conf(self, clientId): model_id = self.find_latest_model(start_time) self.client_model_version[clientId] = model_id end_time = self.client_round_duration[clientId] + start_time - logging.info(f"Client {clientId} train on model {model_id} during {start_time}-{end_time}") + logging.info(f"Client {clientId} train on model {model_id} during {int(start_time)}-{int(end_time)}") conf = { 'learning_rate': self.args.learning_rate, @@ -227,17 +237,17 @@ def create_client_task(self, executorId): next_clientId = self.resource_manager.get_next_task(executorId) train_config = None - # NOTE: model = None then the executor will load the global model broadcasted in UPDATE_MODEL - model = None + model_version = None if next_clientId != None: config = self.get_client_conf(next_clientId) + model_version = self.find_latest_model(self.client_start_time[next_clientId]) train_config = {'client_id': next_clientId, 'task_config': config} - return train_config, model + return train_config, model_version def CLIENT_EXECUTE_COMPLETION(self, request, context): """FL clients complete the execution task. - + Args: request (CompleteRequest): Complete request info from executor. @@ -249,19 +259,12 @@ def CLIENT_EXECUTE_COMPLETION(self, request, context): executor_id, client_id, event = request.executor_id, request.client_id, request.event execution_status, execution_msg = request.status, request.msg meta_result, data_result = request.meta_result, request.data_result - + if event == commons.CLIENT_TRAIN: # Training results may be uploaded in CLIENT_EXECUTE_RESULT request later, # so we need to specify whether to ask client to do so (in case of straggler/timeout in real FL). 
if execution_status is False: logging.error(f"Executor {executor_id} fails to run client {client_id}, due to {execution_msg}") - - if self.resource_manager.has_next_task(executor_id): - # NOTE: we do not pop the train immediately in simulation mode, - # since the executor may run multiple clients - if commons.CLIENT_TRAIN not in self.individual_client_events[executor_id]: - self.individual_client_events[executor_id].append( - commons.CLIENT_TRAIN) elif event in (commons.MODEL_TEST, commons.UPLOAD_MODEL): self.add_event_handler( @@ -269,6 +272,13 @@ def CLIENT_EXECUTE_COMPLETION(self, request, context): else: logging.error(f"Received undefined event {event} from client {client_id}") + if self.resource_manager.has_next_task(executor_id): + # NOTE: we do not pop the train immediately in simulation mode, + # since the executor may run multiple clients + if commons.CLIENT_TRAIN not in self.individual_client_events[executor_id]: + self.individual_client_events[executor_id].append( + commons.CLIENT_TRAIN) + return self.CLIENT_PING(request, context) def log_train_result(self, avg_loss): @@ -304,7 +314,7 @@ def event_monitor(self): if current_event == commons.UPLOAD_MODEL: self.client_completion_handler( self.deserialize_response(data)) - if len(self.stats_util_accumulator) == self.async_buffer_size: + if self.model_in_update == self.async_buffer_size: clientID = self.deserialize_response(data)['clientId'] self.round_stamp.append( self.client_round_duration[clientID] + self.client_start_time[clientID]) diff --git a/examples/async_fl/async_client.py b/examples/async_fl/async_client.py new file mode 100644 index 00000000..016704b9 --- /dev/null +++ b/examples/async_fl/async_client.py @@ -0,0 +1,64 @@ +import copy +import logging +import math + +import torch +from torch.autograd import Variable + +from fedscale.core.execution.client import Client +from fedscale.core.execution.optimizers import ClientOptimizer +from fedscale.dataloaders.nlp import mask_tokens + + +class Client(Client): + """Basic client component in Federated Learning""" + + def train(self, client_data, model, conf): + + clientId = conf.clientId + logging.info(f"Start to train (CLIENT: {clientId}) ...") + tokenizer, device = conf.tokenizer, conf.device + + model = model.to(device=device) + model.train() + + trained_unique_samples = min( + len(client_data.dataset), conf.local_steps * conf.batch_size) + + self.global_model = None + if conf.gradient_policy == 'fed-prox': + # could be move to optimizer + self.global_model = [param.data.clone() for param in model.parameters()] + + prev_model_dict = copy.deepcopy(model.state_dict()) + optimizer = self.get_optimizer(model, conf) + criterion = self.get_criterion(conf) + error_type = None + + # TODO: One may hope to run fixed number of epochs, instead of iterations + while self.completed_steps < conf.local_steps: + try: + self.train_step(client_data, conf, model, optimizer, criterion) + except Exception as ex: + error_type = ex + break + + state_dicts = model.state_dict() + # In async, we need the delta_weight only + model_param = {p: (state_dicts[p] - prev_model_dict[p]).data.cpu().numpy() + for p in state_dicts} + results = {'clientId': clientId, 'moving_loss': self.epoch_train_loss, + 'trained_size': self.completed_steps*conf.batch_size, + 'success': self.completed_steps == conf.batch_size} + results['utility'] = math.sqrt( + self.loss_squre)*float(trained_unique_samples) + + if error_type is None: + logging.info(f"Training of (CLIENT: {clientId}) completes, {results}") + else: + 
logging.info(f"Training of (CLIENT: {clientId}) failed as {error_type}") + + results['update_weight'] = model_param + results['wall_duration'] = 0 + + return results diff --git a/examples/async_fl/async_executor.py b/examples/async_fl/async_executor.py index cfb2b1e2..03165a5f 100644 --- a/examples/async_fl/async_executor.py +++ b/examples/async_fl/async_executor.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import copy import pickle import fedscale.core.channels.job_api_pb2 as job_api_pb2 @@ -7,6 +8,9 @@ from fedscale.core.logger.execution import * from fedscale.core import commons +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from async_client import Client as CustomizedClient + class AsyncExecutor(Executor): """Each executor takes certain resource to run real training. Each run simulates the execution of an individual client""" @@ -14,35 +18,35 @@ class AsyncExecutor(Executor): def __init__(self, args): super().__init__(args) self.temp_model_path_version = lambda round: os.path.join( - logDir, 'model_' + str(round) + '.pth.tar') + logDir, f'model_{self.this_rank}_{round}.pth.tar') def update_model_handler(self, model): """Update the model copy on this executor""" - self.model = model self.round += 1 # Dump latest model to disk with open(self.temp_model_path_version(self.round), 'wb') as model_out: - logging.info(f"Received latest model saved at {self.temp_model_path_version(self.round)}") - pickle.dump(self.model, model_out) + logging.info( + f"Received latest model saved at {self.temp_model_path_version(self.round)}" + ) + pickle.dump(model, model_out) def load_global_model(self, round=None): # load last global model - if round == -1: - with open(self.temp_model_path, 'rb') as model_in: - model = pickle.load(model_in) - else: - round = min(round, self.round) if round is not None else self.round - with open(self.temp_model_path_version(round), 'rb') as model_in: - model = pickle.load(model_in) + logging.info(f"====Load global model with version {round}") + round = min(round, self.round) if round is not None else self.round + with open(self.temp_model_path_version(round), 'rb') as model_in: + model = pickle.load(model_in) return model + def get_client_trainer(self, conf): + return CustomizedClient(conf) + def training_handler(self, clientId, conf, model=None): """Train model given client ids""" # Here model is model_id - client_model = self.load_global_model(-1) if model is None \ - else self.load_global_model(model) + client_model = self.load_global_model(model) conf.clientId, conf.device = clientId, self.device conf.tokenizer = tokenizer @@ -63,8 +67,50 @@ def training_handler(self, clientId, conf, model=None): return train_res + def testing_handler(self, args, config=None): + + evalStart = time.time() + device = self.device + model = config['test_model'] + if self.task == 'rl': + client = RLClient(args) + test_res = client.test(args, self.this_rank, model, device=device) + _, _, _, testResults = test_res + else: + data_loader = select_dataset(self.this_rank, self.testing_sets, + batch_size=args.test_bsz, args=args, + isTest=True, collate_fn=self.collate_fn + ) + + if self.task == 'voice': + criterion = CTCLoss(reduction='mean').to(device=device) + else: + criterion = torch.nn.CrossEntropyLoss().to(device=device) + + if self.args.engine == commons.PYTORCH: + test_res = test_model(self.this_rank, model, data_loader, + device=device, criterion=criterion, tokenizer=tokenizer) + else: + raise Exception(f"Need customized implementation for model testing in {self.args.engine} 
engine") + + test_loss, acc, acc_5, testResults = test_res + logging.info("After aggregation round {}, CumulTime {}, eval_time {}, test_loss {}, test_accuracy {:.2f}%, test_5_accuracy {:.2f}% \n" + .format(self.round, round(time.time() - self.start_run_time, 4), round(time.time() - evalStart, 4), test_loss, acc*100., acc_5*100.)) + + gc.collect() + + return testResults + def check_model_version(self, model_id): - return os.path.exists(self.temp_model_path_version(round)) + return os.path.exists(self.temp_model_path_version(model_id)) + + def remove_stale_models(self, straggler_round): + """Remove useless models kept for async execution in the past""" + logging.info(f"Current straggler round is {straggler_round}") + for r in range(min(straggler_round-1, self.round)): + if self.check_model_version(r): + logging.info(f"Executor {self.this_rank} removes stale model version {r}") + os.remove(self.temp_model_path_version(r)) def event_monitor(self): """Activate event handler once receiving new message @@ -99,7 +145,9 @@ def event_monitor(self): future_call.add_done_callback(lambda _response: self.dispatch_worker_events(_response.result())) elif current_event == commons.MODEL_TEST: - self.Test(self.deserialize_response(request.meta)) + test_configs = self.deserialize_response(request.meta) + self.remove_stale_models(test_configs['straggler_round']) + self.Test(test_configs) elif current_event == commons.UPDATE_MODEL: broadcast_config = self.deserialize_response(request.data) @@ -112,7 +160,7 @@ def event_monitor(self): elif current_event == commons.DUMMY_EVENT: pass else: - time.sleep(10) + time.sleep(1) self.client_ping() if __name__ == "__main__": diff --git a/fedscale/core/aggregation/aggregator.py b/fedscale/core/aggregation/aggregator.py index 1593e350..900b5c58 100755 --- a/fedscale/core/aggregation/aggregator.py +++ b/fedscale/core/aggregation/aggregator.py @@ -467,9 +467,11 @@ def save_last_param(self): if self.args.engine == commons.TENSORFLOW: self.last_gradient_weights = [ layer.get_weights() for layer in self.model.layers] + self.model_weights = copy.deepcopy(self.model.state_dict()) else: self.last_gradient_weights = [ p.data.clone() for p in self.model.parameters()] + self.model_weights = copy.deepcopy(self.model.state_dict()) def round_weight_handler(self, last_model): """Update model when the round completes @@ -552,6 +554,7 @@ def round_completion_handler(self): self.test_result_accumulator = [] self.stats_util_accumulator = [] self.client_training_results = [] + self.loss_accumulator = [] if self.round >= self.args.rounds: self.broadcast_aggregator_events(commons.SHUT_DOWN) @@ -785,7 +788,7 @@ def CLIENT_PING(self, request, context): # while multiple client_id may use the same executor_id (VMs) in simulations executor_id, client_id = request.executor_id, request.client_id response_data = response_msg = commons.DUMMY_RESPONSE - + if len(self.individual_client_events[executor_id]) == 0: # send dummy response current_event = commons.DUMMY_EVENT diff --git a/fedscale/core/execution/client.py b/fedscale/core/execution/client.py index 986d28a0..03575974 100644 --- a/fedscale/core/execution/client.py +++ b/fedscale/core/execution/client.py @@ -60,7 +60,8 @@ def train(self, client_data, model, conf): model_param = {p: state_dicts[p].data.cpu().numpy() for p in state_dicts} results = {'clientId': clientId, 'moving_loss': self.epoch_train_loss, - 'trained_size': self.completed_steps*conf.batch_size, 'success': self.completed_steps > 0} + 'trained_size': 
self.completed_steps*conf.batch_size, + 'success': self.completed_steps == conf.batch_size} results['utility'] = math.sqrt( self.loss_squre)*float(trained_unique_samples) diff --git a/fedscale/core/execution/executor.py b/fedscale/core/execution/executor.py index 66272d67..d77bfc12 100755 --- a/fedscale/core/execution/executor.py +++ b/fedscale/core/execution/executor.py @@ -33,7 +33,7 @@ def __init__(self, args): self.executor_id = str(self.this_rank) # ======== model and data ======== - self.model = self.training_sets = self.test_dataset = None + self.training_sets = self.test_dataset = None self.temp_model_path = os.path.join( logDir, 'model_'+str(args.this_rank)+'.pth.tar') @@ -134,7 +134,6 @@ def run(self): """Start running the executor by setting up execution and communication environment, and monitoring the grpc message. """ self.setup_env() - self.model = self.init_model() self.training_sets, self.testing_sets = self.init_data() self.setup_communication() self.event_monitor() @@ -220,7 +219,7 @@ def Test(self, config): config (dictionary): The client testing config. """ - test_res = self.testing_handler(args=self.args) + test_res = self.testing_handler(args=self.args, config=config) test_res = {'executorId': self.this_rank, 'results': test_res} # Report execution completion information @@ -255,12 +254,11 @@ def update_model_handler(self, model): config (PyTorch or TensorFlow model): The broadcasted global model """ - self.model = model self.round += 1 # Dump latest model to disk with open(self.temp_model_path, 'wb') as model_out: - pickle.dump(self.model, model_out) + pickle.dump(model, model_out) def load_global_model(self): """ Load last global model @@ -335,12 +333,12 @@ def training_handler(self, clientId, conf, model=None): return train_res - def testing_handler(self, args): + def testing_handler(self, args, config=None): """Test model Args: args (dictionary): Variable arguments for fedscale runtime config. defaults to the setup in arg_parser.py - + config (dictionary): Variable arguments from coordinator. 
Returns: dictionary: The test result diff --git a/fedscale/core/resource_manager.py b/fedscale/core/resource_manager.py index 5cb92b4c..ef25f465 100644 --- a/fedscale/core/resource_manager.py +++ b/fedscale/core/resource_manager.py @@ -18,6 +18,11 @@ def register_tasks(self, clientsToRun): self.client_run_queue = clientsToRun.copy() self.client_run_queue_idx = 0 + def get_remaining(self) -> int: + """Number of tasks left in the queue + """ + return self.get_task_length() + def get_task_length(self) -> int: """Number of tasks left in the queue diff --git a/fedscale/dataloaders/divide_data.py b/fedscale/dataloaders/divide_data.py index fc633418..43c90c0b 100755 --- a/fedscale/dataloaders/divide_data.py +++ b/fedscale/dataloaders/divide_data.py @@ -132,7 +132,10 @@ def select_dataset(rank, partition, batch_size, args, isTest=False, collate_fn=N """Load data given client Id""" partition = partition.use(rank - 1, isTest) dropLast = False if isTest else True - num_loaders = min(int(len(partition)/args.batch_size/2), args.num_loaders) + if isTest: + num_loaders = 0 + else: + num_loaders = min(int(len(partition)/args.batch_size/2), args.num_loaders) if num_loaders == 0: time_out = 0 else: From dea2ceda1c1551fc3341f908292ca619f4e5d029 Mon Sep 17 00:00:00 2001 From: fanlai0990 Date: Wed, 17 Aug 2022 22:52:54 -0500 Subject: [PATCH 02/12] Async and Model Zoo --- benchmark/configs/async_fl/async_fl.yml | 18 ++-- benchmark/configs/cifar_cpu/cifar_cpu.yml | 5 +- benchmark/configs/femnist/conf.yml | 20 ++-- examples/async_fl/async_aggregator.py | 112 ++++++++++------------ examples/async_fl/async_client.py | 1 + examples/async_fl/async_executor.py | 10 +- examples/async_fl/resource_manager.py | 12 +-- fedscale/core/aggregation/aggregator.py | 28 +++--- fedscale/core/config_parser.py | 5 +- fedscale/core/execution/client.py | 2 +- fedscale/core/execution/executor.py | 4 +- fedscale/core/fllibs.py | 3 +- fedscale/core/resource_manager.py | 5 - fedscale/dataloaders/utils_data.py | 2 +- fedscale/utils/models/cv_models/README.md | 6 +- 15 files changed, 113 insertions(+), 120 deletions(-) diff --git a/benchmark/configs/async_fl/async_fl.yml b/benchmark/configs/async_fl/async_fl.yml index 195b90fc..73bbc9fb 100644 --- a/benchmark/configs/async_fl/async_fl.yml +++ b/benchmark/configs/async_fl/async_fl.yml @@ -29,16 +29,23 @@ setup_commands: # ========== Additional job configuration ========== # Default parameters are specified in config_parser.py, wherein more description of the parameter can be found +# NOTE: We are supporting and improving the following implementation (Async FL) in FedScale: + # - "PAPAYA: Practical, Private, and Scalable Federated Learning", MLSys, 2022 + # - "Federated Learning with Buffered Asynchronous Aggregation", AISTATS, 2022 + +# We appreciate you to contribute and/or report bugs. Thank you! 
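+# In this buffered-async setup (a rough summary, not an exhaustive spec): the aggregator
+# produces a new global model version once `async_buffer` client updates have been
+# aggregated, while sampled clients are assumed to start at a roughly constant rate
+# controlled by `arrival_interval` (in simulated seconds); see the job_conf entries below.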
+ job_conf: - - job_name: femnist # Generate logs under this folder: log_path/job_name/time_stamp - - log_path: $FEDSCALE_HOME/benchmark # Path of log files - - num_participants: 20 # Number of participants per round, we use K=100 in our paper, large K will be much slower - - data_set: femnist # Dataset: openImg, google_speech, stackoverflow + - job_name: femnist # Generate logs under this folder: log_path/job_name/time_stamp + - log_path: $FEDSCALE_HOME/benchmark # Path of log files + - async_buffer: 50 # Number of updates need to be aggregated before generating new model version + - data_set: femnist # Dataset: openImg, google_speech, stackoverflow - data_dir: $FEDSCALE_HOME/benchmark/dataset/data/femnist # Path of the dataset - data_map_file: $FEDSCALE_HOME/benchmark/dataset/data/femnist/client_data_mapping/train.csv # Allocation of data to each client, turn to iid setting if not provided #- device_conf_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_device_capacity # Path of the client trace #- device_avail_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_behave_trace - - model: shufflenet_v2_x2_0 # Models: e.g., shufflenet_v2_x2_0, mobilenet_v2, resnet34, albert-base-v2 + - model: shufflenet_v2_x2_0 # NOTE: Please refer to our model zoo README and use models for these small image (e.g., 32x32x3) inputs + # - model_zoo: fedscale-zoo - eval_interval: 5 # How many rounds to run a testing on the testing set - rounds: 3000 # Number of rounds to run this training. We use 1000 in our paper, while it may converge w/ ~400 rounds - filter_less: 21 # Remove clients w/ less than 21 samples @@ -50,5 +57,4 @@ job_conf: - use_cuda: True - decay_round: 50 - overcommitment: 1.0 - - async_buffer: 20 - arrival_interval: 3 \ No newline at end of file diff --git a/benchmark/configs/cifar_cpu/cifar_cpu.yml b/benchmark/configs/cifar_cpu/cifar_cpu.yml index efa3fbbd..5d1c4c4a 100644 --- a/benchmark/configs/cifar_cpu/cifar_cpu.yml +++ b/benchmark/configs/cifar_cpu/cifar_cpu.yml @@ -34,13 +34,14 @@ job_conf: - num_participants: 4 # Number of participants per round, we use K=100 in our paper, large K will be much slower - data_set: cifar10 # Dataset: openImg, google_speech, stackoverflow - data_dir: $FEDSCALE_HOME/benchmark/dataset/data/ # Path of the dataset - - model: shufflenet_v2_x2_0 # Models: e.g., shufflenet_v2_x2_0, mobilenet_v2, resnet34, albert-base-v2# - gradient_policy: yogi # {"fed-yogi", "fed-prox", "fed-avg"}, "fed-avg" by default + - model: resnet56_cifar10 # NOTE: Please refer to our model zoo README and use models for these small image (e.g., 32x32x3) inputs + - model_zoo: fedscale-zoo # Default zoo (torchcv) uses the pytorchvision zoo, which can not support small images well - eval_interval: 5 # How many rounds to run a testing on the testing set - rounds: 600 # Number of rounds to run this training. 
We use 1000 in our paper, while it may converge w/ ~400 rounds - filter_less: 0 # Remove clients w/ less than 21 samples - num_loaders: 2 - local_steps: 20 - - learning_rate: 0.001 + - learning_rate: 0.05 - batch_size: 32 - test_bsz: 32 - use_cuda: False diff --git a/benchmark/configs/femnist/conf.yml b/benchmark/configs/femnist/conf.yml index e5960c3b..ad974c78 100644 --- a/benchmark/configs/femnist/conf.yml +++ b/benchmark/configs/femnist/conf.yml @@ -2,13 +2,14 @@ # ========== Cluster configuration ========== # ip address of the parameter server (need 1 GPU process) -ps_ip: localhost +ps_ip: 10.0.0.1 # ip address of each worker:# of available gpus process on each gpu in this node # Note that if we collocate ps and worker on same GPU, then we need to decrease this number of available processes on that GPU by 1 # E.g., master node has 4 available processes, then 1 for the ps, and worker should be set to: worker:3 -worker_ips: - - localhost:[2] +worker_ips: + - 10.0.0.1:[4] + - 10.0.0.2:[4] exp_path: $FEDSCALE_HOME/fedscale/core @@ -31,24 +32,21 @@ setup_commands: job_conf: - job_name: femnist # Generate logs under this folder: log_path/job_name/time_stamp - log_path: $FEDSCALE_HOME/benchmark # Path of log files - - num_participants: 20 # Number of participants per round, we use K=100 in our paper, large K will be much slower + - num_participants: 20 # Number of participants per round, we use K=100 in our paper, large K will be much slower - data_set: femnist # Dataset: openImg, google_speech, stackoverflow - data_dir: $FEDSCALE_HOME/benchmark/dataset/data/femnist # Path of the dataset - data_map_file: $FEDSCALE_HOME/benchmark/dataset/data/femnist/client_data_mapping/train.csv # Allocation of data to each client, turn to iid setting if not provided - device_conf_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_device_capacity # Path of the client trace - device_avail_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_behave_trace - - model: shufflenet_v2_x2_0 # Models: e.g., shufflenet_v2_x2_0, mobilenet_v2, resnet34, albert-base-v2 - - gradient_policy: yogi # {"fed-yogi", "fed-prox", "fed-avg"}, "fed-avg" by default - - eval_interval: 30 # How many rounds to run a testing on the testing set + - model: resnet56_cifar10 # NOTE: Please refer to our model zoo README and use models for these small image (e.g., 32x32x3) inputs + - model_zoo: fedscale-zoo + - eval_interval: 10 # How many rounds to run a testing on the testing set - rounds: 5000 # Number of rounds to run this training. We use 1000 in our paper, while it may converge w/ ~400 rounds - filter_less: 21 # Remove clients w/ less than 21 samples - num_loaders: 2 - - yogi_eta: 3e-3 - - yogi_tau: 1e-8 - local_steps: 20 - learning_rate: 0.05 - batch_size: 20 - test_bsz: 20 - - malicious_factor: 4 - - use_cuda: False + - use_cuda: True diff --git a/examples/async_fl/async_aggregator.py b/examples/async_fl/async_aggregator.py index 3b97d6a4..8cdb09f1 100644 --- a/examples/async_fl/async_aggregator.py +++ b/examples/async_fl/async_aggregator.py @@ -15,6 +15,11 @@ MAX_MESSAGE_LENGTH = 1 * 1024 * 1024 * 1024 # 1GB +# NOTE: We are supporting and improving the following implementation (Async FL) in FedScale: + # - "PAPAYA: Practical, Private, and Scalable Federated Learning", MLSys, 2022 + # - "Federated Learning with Buffered Asynchronous Aggregation", AISTATS, 2022 + +# We appreciate you to contribute and/or report bugs. Thank you! 
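+# Rough sketch of the buffered-async update implemented in aggregate_client_weights()
+# below: each client i returns a delta (its trained weights minus the model version it
+# started from), and once `async_buffer` deltas have been buffered the server applies
+#
+#     w_new = w_old + sum_i(importance_i * delta_i) / sum_i(importance_i)
+#
+# where importance_i is a staleness discount (e.g., 1 / sqrt(1 + staleness_i), with
+# staleness_i = current_round - the round whose model version client i trained on).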
class AsyncAggregator(Aggregator): """This centralized aggregator collects training/testing feedbacks from executors""" @@ -29,8 +34,11 @@ def __init__(self, args): self.client_model_version = {} self.virtual_client_clock = {} self.weight_tensor_type = {} + # We need to keep the test model for specific round to avoid async mismatch self.test_model = None + self.aggregate_update = {} + self.importance_sum = 0 def tictak_client_tasks(self, sampled_clients, num_clients_to_collect): @@ -98,11 +106,14 @@ def aggregate_client_weights(self, results): "PAPAYA: PRACTICAL, PRIVATE, AND SCALABLE FEDERATED LEARNING". MLSys, 2022 """ # Start to take the average of updates, and we do not keep updates to save memory - # Importance of each update is 1/#_of_participants * staleness - # importance = 1./self.tasks_round * staleness - client_staleness = self.round - \ - self.client_model_version[results['clientId']] - importance = 1. / math.sqrt(1 + client_staleness) + # Importance of each update is 1/staleness + client_staleness = self.round - self.client_model_version[results['clientId']] + importance = 1. #/ (math.sqrt(1 + client_staleness)) + + new_round_aggregation = (self.model_in_update == 1) + if new_round_aggregation: + self.importance_sum = 0 + self.importance_sum += importance for p in results['update_weight']: # Different to core/executor, update_weight here is (train_model_weight - untrained) @@ -113,21 +124,27 @@ def aggregate_client_weights(self, results): param_weight = torch.from_numpy( param_weight).to(device=self.device) - if self.model_weights[p].data.dtype in ( - torch.float, torch.double, torch.half, - torch.bfloat16, torch.chalf, torch.cfloat, torch.cdouble - ): + # if self.model_weights[p].data.dtype in ( + # torch.float, torch.double, torch.half, + # torch.bfloat16, torch.chalf, torch.cfloat, torch.cdouble + # ): # Only assign importance to floats (trainable variables) - self.model_weights[p].data += param_weight * importance + if new_round_aggregation: + self.aggregate_update[p] = param_weight * importance else: - # Non-floats (e.g., batches), no need to aggregate but need to track - self.model_weights[p].data += param_weight + self.aggregate_update[p] += param_weight * importance + + # self.model_weights[p].data += param_weight * importance + # else: + # # Non-floats (e.g., num_batches), no need to aggregate but need to track + # self.aggregate_update[p] = param_weight if self.model_in_update == self.async_buffer_size: - logging.info("Calibrating tensor type") for p in self.model_weights: d_type = self.weight_tensor_type[p] - self.model_weights[p].data = (self.model_weights[p].data/float(self.async_buffer_size)).to(dtype=d_type) + self.model_weights[p].data = ( + self.model_weights[p].data + self.aggregate_update[p]/self.importance_sum + ).to(dtype=d_type) def round_completion_handler(self): self.global_virtual_clock = self.round_stamp[-1] @@ -142,18 +159,19 @@ def round_completion_handler(self): avg_loss = sum(self.loss_accumulator) / \ max(1, len(self.loss_accumulator)) - logging.info(f"Wall clock: {round(self.global_virtual_clock)} s, round: {self.round}, Remaining participants: " + - f"{self.resource_manager.get_remaining()}, Succeed participants: " + - f"{len(self.stats_util_accumulator)}, Training loss: {avg_loss}") + logging.info(f"Wall clock: {round(self.global_virtual_clock)} s, round: {self.round}, asyn running participants: " + + f"{self.resource_manager.get_task_length()}, aggregating {len(self.stats_util_accumulator)} participants, " + + f"training loss: {avg_loss}") # 
dump round completion information to tensorboard if len(self.loss_accumulator): self.log_train_result(avg_loss) # update select participants + # NOTE: we simulate async, while have to sync every 20 rounds to avoid large division to trace if self.resource_manager.get_task_length() < self.async_buffer_size: self.sampled_participants = self.select_participants( - select_num_participants=self.args.num_participants, overcommitment=self.args.overcommitment) + select_num_participants=self.async_buffer_size*2, overcommitment=self.args.overcommitment) (clientsToRun, clientsStartTime, virtual_client_clock) = self.tictak_client_tasks( self.sampled_participants, len(self.sampled_participants)) @@ -216,19 +234,14 @@ def get_test_config(self, client_id): straggler_round = min( self.find_latest_model(self.client_start_time[client]), straggler_round) - return {'client_id': client_id, 'straggler_round': straggler_round, 'test_model': self.test_model} + return {'client_id': client_id, + 'straggler_round': straggler_round, + 'test_model': self.test_model} def get_client_conf(self, clientId): """Training configurations that will be applied on clients""" - start_time = self.client_start_time[clientId] - model_id = self.find_latest_model(start_time) - self.client_model_version[clientId] = model_id - end_time = self.client_round_duration[clientId] + start_time - logging.info(f"Client {clientId} train on model {model_id} during {int(start_time)}-{int(end_time)}") - conf = { 'learning_rate': self.args.learning_rate, - 'model': model_id # none indicates we are using the global model } return conf @@ -237,49 +250,22 @@ def create_client_task(self, executorId): next_clientId = self.resource_manager.get_next_task(executorId) train_config = None - model_version = None + model = None if next_clientId != None: config = self.get_client_conf(next_clientId) - model_version = self.find_latest_model(self.client_start_time[next_clientId]) - train_config = {'client_id': next_clientId, 'task_config': config} - return train_config, model_version - - def CLIENT_EXECUTE_COMPLETION(self, request, context): - """FL clients complete the execution task. - - Args: - request (CompleteRequest): Complete request info from executor. - - Returns: - ServerResponse: Server response to job completion request + start_time = self.client_start_time[next_clientId] + model_id = self.find_latest_model(start_time) + self.client_model_version[next_clientId] = model_id + end_time = self.client_round_duration[next_clientId] + start_time - """ - - executor_id, client_id, event = request.executor_id, request.client_id, request.event - execution_status, execution_msg = request.status, request.msg - meta_result, data_result = request.meta_result, request.data_result + # The executor has already received the model, thus transfering id is enough + model = model_id + train_config = {'client_id': next_clientId, 'task_config': config} - if event == commons.CLIENT_TRAIN: - # Training results may be uploaded in CLIENT_EXECUTE_RESULT request later, - # so we need to specify whether to ask client to do so (in case of straggler/timeout in real FL). 
- if execution_status is False: - logging.error(f"Executor {executor_id} fails to run client {client_id}, due to {execution_msg}") + logging.info(f"Client {next_clientId} train on model {model_id} during {int(start_time)}-{int(end_time)}") - elif event in (commons.MODEL_TEST, commons.UPLOAD_MODEL): - self.add_event_handler( - executor_id, event, meta_result, data_result) - else: - logging.error(f"Received undefined event {event} from client {client_id}") - - if self.resource_manager.has_next_task(executor_id): - # NOTE: we do not pop the train immediately in simulation mode, - # since the executor may run multiple clients - if commons.CLIENT_TRAIN not in self.individual_client_events[executor_id]: - self.individual_client_events[executor_id].append( - commons.CLIENT_TRAIN) - - return self.CLIENT_PING(request, context) + return train_config, model def log_train_result(self, avg_loss): """Result will be post on TensorBoard""" diff --git a/examples/async_fl/async_client.py b/examples/async_fl/async_client.py index 016704b9..cdf5212f 100644 --- a/examples/async_fl/async_client.py +++ b/examples/async_fl/async_client.py @@ -1,6 +1,7 @@ import copy import logging import math +import pickle import torch from torch.autograd import Variable diff --git a/examples/async_fl/async_executor.py b/examples/async_fl/async_executor.py index 03165a5f..2b0ece57 100644 --- a/examples/async_fl/async_executor.py +++ b/examples/async_fl/async_executor.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -import copy import pickle import fedscale.core.channels.job_api_pb2 as job_api_pb2 @@ -107,10 +106,11 @@ def check_model_version(self, model_id): def remove_stale_models(self, straggler_round): """Remove useless models kept for async execution in the past""" logging.info(f"Current straggler round is {straggler_round}") - for r in range(min(straggler_round-1, self.round)): - if self.check_model_version(r): - logging.info(f"Executor {self.this_rank} removes stale model version {r}") - os.remove(self.temp_model_path_version(r)) + stale_version = straggler_round-1 + while self.check_model_version(stale_version): + logging.info(f"Executor {self.this_rank} removes stale model version {stale_version}") + os.remove(self.temp_model_path_version(stale_version)) + stale_version -= 1 def event_monitor(self): """Activate event handler once receiving new message diff --git a/examples/async_fl/resource_manager.py b/examples/async_fl/resource_manager.py index 4fbf2776..b2f8d258 100644 --- a/examples/async_fl/resource_manager.py +++ b/examples/async_fl/resource_manager.py @@ -12,17 +12,15 @@ def __init__(self, experiment_mode): self.experiment_mode = experiment_mode self.update_lock = threading.Lock() - def get_remaining(self): - return len(self.client_run_queue) + def get_task_length(self): + self.update_lock.acquire() + remaining_task_num: int = len(self.client_run_queue) + self.update_lock.release() + return remaining_task_num def register_tasks(self, clientsToRun): self.client_run_queue += clientsToRun.copy() - def remove_client_task(self, client_id): - assert (client_id in self.client_run_queue, - f"client task {client_id} is not in task queue") - pass - def has_next_task(self, client_id=None): exist_next_task = False if self.experiment_mode == commons.SIMULATION_MODE: diff --git a/fedscale/core/aggregation/aggregator.py b/fedscale/core/aggregation/aggregator.py index 900b5c58..3fc39fdc 100755 --- a/fedscale/core/aggregation/aggregator.py +++ b/fedscale/core/aggregation/aggregator.py @@ -473,6 +473,13 @@ def save_last_param(self): 
p.data.clone() for p in self.model.parameters()] self.model_weights = copy.deepcopy(self.model.state_dict()) + def update_default_task_config(self): + """Update the default task configuration after each round + """ + if self.round % self.args.decay_round == 0: + self.args.learning_rate = max( + self.args.learning_rate*self.args.decay_factor, self.args.min_learning_rate) + def round_weight_handler(self, last_model): """Update model when the round completes @@ -499,10 +506,6 @@ def round_completion_handler(self): self.global_virtual_clock += self.round_duration self.round += 1 - if self.round % self.args.decay_round == 0: - self.args.learning_rate = max( - self.args.learning_rate*self.args.decay_factor, self.args.min_learning_rate) - # handle the global update w/ current and last self.round_weight_handler(self.last_gradient_weights) @@ -555,7 +558,8 @@ def round_completion_handler(self): self.stats_util_accumulator = [] self.client_training_results = [] self.loss_accumulator = [] - + self.update_default_task_config() + if self.round >= self.args.rounds: self.broadcast_aggregator_events(commons.SHUT_DOWN) elif self.round % self.args.eval_interval == 0: @@ -678,7 +682,6 @@ def get_client_conf(self, clientId): """ conf = { 'learning_rate': self.args.learning_rate, - 'model': None # none indicates we are using the global model } return conf @@ -841,17 +844,20 @@ def CLIENT_EXECUTE_COMPLETION(self, request, context): # so we need to specify whether to ask client to do so (in case of straggler/timeout in real FL). if execution_status is False: logging.error(f"Executor {executor_id} fails to run client {client_id}, due to {execution_msg}") - if self.resource_manager.has_next_task(executor_id): - # NOTE: we do not pop the train immediately in simulation mode, - # since the executor may run multiple clients - self.individual_client_events[executor_id].append( - commons.CLIENT_TRAIN) elif event in (commons.MODEL_TEST, commons.UPLOAD_MODEL): self.add_event_handler( executor_id, event, meta_result, data_result) else: logging.error(f"Received undefined event {event} from client {client_id}") + + if self.resource_manager.has_next_task(executor_id): + # NOTE: we do not pop the train immediately in simulation mode, + # since the executor may run multiple clients + if commons.CLIENT_TRAIN not in self.individual_client_events[executor_id]: + self.individual_client_events[executor_id].append( + commons.CLIENT_TRAIN) + return self.CLIENT_PING(request, context) def event_monitor(self): diff --git a/fedscale/core/config_parser.py b/fedscale/core/config_parser.py index 70cfe9e2..7d8516e8 100644 --- a/fedscale/core/config_parser.py +++ b/fedscale/core/config_parser.py @@ -174,9 +174,8 @@ parser.add_argument("--n_states", type=int, default=4, help="state number") -# for speech parser.add_argument("--num_classes", type=int, default=35, - help="For number of classes in speech") + help="For number of classes of the dataset") # for voice @@ -231,7 +230,7 @@ 'resnet': 0.135/0.0554, } -args.num_class = datasetCategories.get(args.data_set, 10) +args.num_class = datasetCategories.get(args.data_set, args.num_classes) for model_name in model_factor: if model_name in args.model: args.clock_factor = args.clock_factor * model_factor[model_name] diff --git a/fedscale/core/execution/client.py b/fedscale/core/execution/client.py index 37784f96..f965e2dc 100644 --- a/fedscale/core/execution/client.py +++ b/fedscale/core/execution/client.py @@ -1,6 +1,6 @@ import logging import math - +import pickle import torch from torch.autograd 
import Variable diff --git a/fedscale/core/execution/executor.py b/fedscale/core/execution/executor.py index d77bfc12..ab51a573 100755 --- a/fedscale/core/execution/executor.py +++ b/fedscale/core/execution/executor.py @@ -193,8 +193,8 @@ def Train(self, config): client_id, train_config = config['client_id'], config['task_config'] model = None - if 'model' in train_config and train_config['model'] is not None: - model = train_config['model'] + if 'model' in config and config['model'] is not None: + model = config['model'] client_conf = self.override_conf(train_config) train_res = self.training_handler( diff --git a/fedscale/core/fllibs.py b/fedscale/core/fllibs.py index 6dcf6e3b..d0b90ad9 100644 --- a/fedscale/core/fllibs.py +++ b/fedscale/core/fllibs.py @@ -209,7 +209,8 @@ def init_model(): else: if args.model_zoo == "fedscale-zoo": if args.task == "cv": - model = get_cv_model() + model = get_cv_model(name=args.model, + num_classes=outputClass[args.data_set]) else: raise NameError(f"Model zoo {args.model_zoo} does not exist") elif args.model_zoo == "torchcv": diff --git a/fedscale/core/resource_manager.py b/fedscale/core/resource_manager.py index ef25f465..5cb92b4c 100644 --- a/fedscale/core/resource_manager.py +++ b/fedscale/core/resource_manager.py @@ -18,11 +18,6 @@ def register_tasks(self, clientsToRun): self.client_run_queue = clientsToRun.copy() self.client_run_queue_idx = 0 - def get_remaining(self) -> int: - """Number of tasks left in the queue - """ - return self.get_task_length() - def get_task_length(self) -> int: """Number of tasks left in the queue diff --git a/fedscale/dataloaders/utils_data.py b/fedscale/dataloaders/utils_data.py index c356912c..12f3fddf 100755 --- a/fedscale/dataloaders/utils_data.py +++ b/fedscale/dataloaders/utils_data.py @@ -157,4 +157,4 @@ def get_data_transform(data: str): print('Data must be {} or {} !'.format('mnist', 'cifar')) sys.exit(-1) - return train_transform, test_transform + return train_transform, test_transform \ No newline at end of file diff --git a/fedscale/utils/models/cv_models/README.md b/fedscale/utils/models/cv_models/README.md index bd198322..4b89307e 100644 --- a/fedscale/utils/models/cv_models/README.md +++ b/fedscale/utils/models/cv_models/README.md @@ -1,6 +1,8 @@ # Computer vision models -This folder contains 70+ computer vision models from [Imgclsmob](https://github.com/osmr/imgclsmob/blob/master/pytorch/README.md). We borrow their implementations, and change model APIs (e.g., num_classes), and integrate them into FedScale benchmarking. +This folder contains 70+ computer vision models. Some are from [Imgclsmob](https://github.com/osmr/imgclsmob/blob/master/pytorch/README.md). We reimplement some of them, add new APIs (e.g., num_classes), and support them in FedScale benchmarking. **Please acknowledge to [Imgclsmob](https://github.com/osmr/imgclsmob) if you use any of the model herein**. -The full list of supported models are available [here](https://github.com/SymbioticLab/FedScale/blob/master/fedscale/utils/models/model_provider.py#L179). Note that for small images (e.g., FMNIST), we suggest using models with ```-cifar``` suffix, as they have smaller kernels and strides. \ No newline at end of file +The full list of supported models are available [here](https://github.com/SymbioticLab/FedScale/blob/master/fedscale/utils/models/model_provider.py#L179). Note that for small images (e.g., FMNIST), we suggest using models with ```-cifar``` suffix, as they have smaller kernels and strides. 
Meanwhile, please ignore the suffix ``-cifar10`` or ``-cifar100`` as their model num_classes will be automatically overrided by the ``--num_classes`` setting of the dataset. + +**We are adding new models, and appreciate if you can consider contributing yours! Please feel free to report bugs.** \ No newline at end of file From d3fc5a75ebce532f2d2de2f1dcc7a961591dda4f Mon Sep 17 00:00:00 2001 From: fanlai0990 Date: Wed, 17 Aug 2022 22:57:04 -0500 Subject: [PATCH 03/12] merge branch --- benchmark/configs/async_fl/async_fl.yml | 10 ++--- examples/async_fl/async_aggregator.py | 4 +- fedscale/dataloaders/utils_data.py | 58 ++++++------------------- 3 files changed, 21 insertions(+), 51 deletions(-) diff --git a/benchmark/configs/async_fl/async_fl.yml b/benchmark/configs/async_fl/async_fl.yml index 73bbc9fb..0e267679 100644 --- a/benchmark/configs/async_fl/async_fl.yml +++ b/benchmark/configs/async_fl/async_fl.yml @@ -38,14 +38,14 @@ setup_commands: job_conf: - job_name: femnist # Generate logs under this folder: log_path/job_name/time_stamp - log_path: $FEDSCALE_HOME/benchmark # Path of log files - - async_buffer: 50 # Number of updates need to be aggregated before generating new model version + - async_buffer: 20 # Number of updates need to be aggregated before generating new model version - data_set: femnist # Dataset: openImg, google_speech, stackoverflow - data_dir: $FEDSCALE_HOME/benchmark/dataset/data/femnist # Path of the dataset - data_map_file: $FEDSCALE_HOME/benchmark/dataset/data/femnist/client_data_mapping/train.csv # Allocation of data to each client, turn to iid setting if not provided - #- device_conf_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_device_capacity # Path of the client trace - #- device_avail_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_behave_trace - - model: shufflenet_v2_x2_0 # NOTE: Please refer to our model zoo README and use models for these small image (e.g., 32x32x3) inputs - # - model_zoo: fedscale-zoo + - device_conf_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_device_capacity # Path of the client trace + - device_avail_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_behave_trace + - model: resnet56_cifar100 # NOTE: Please refer to our model zoo README and use models for these small image (e.g., 32x32x3) inputs + - model_zoo: fedscale-zoo - eval_interval: 5 # How many rounds to run a testing on the testing set - rounds: 3000 # Number of rounds to run this training. We use 1000 in our paper, while it may converge w/ ~400 rounds - filter_less: 21 # Remove clients w/ less than 21 samples diff --git a/examples/async_fl/async_aggregator.py b/examples/async_fl/async_aggregator.py index 8cdb09f1..6e27916e 100644 --- a/examples/async_fl/async_aggregator.py +++ b/examples/async_fl/async_aggregator.py @@ -108,7 +108,7 @@ def aggregate_client_weights(self, results): # Start to take the average of updates, and we do not keep updates to save memory # Importance of each update is 1/staleness client_staleness = self.round - self.client_model_version[results['clientId']] - importance = 1. 
#/ (math.sqrt(1 + client_staleness)) + importance = 1./(math.sqrt(1 + client_staleness)) new_round_aggregation = (self.model_in_update == 1) if new_round_aggregation: @@ -171,7 +171,7 @@ def round_completion_handler(self): # NOTE: we simulate async, while have to sync every 20 rounds to avoid large division to trace if self.resource_manager.get_task_length() < self.async_buffer_size: self.sampled_participants = self.select_participants( - select_num_participants=self.async_buffer_size*2, overcommitment=self.args.overcommitment) + select_num_participants=self.async_buffer_size*20, overcommitment=self.args.overcommitment) (clientsToRun, clientsStartTime, virtual_client_clock) = self.tictak_client_tasks( self.sampled_participants, len(self.sampled_participants)) diff --git a/fedscale/dataloaders/utils_data.py b/fedscale/dataloaders/utils_data.py index 12f3fddf..6de3978a 100755 --- a/fedscale/dataloaders/utils_data.py +++ b/fedscale/dataloaders/utils_data.py @@ -1,57 +1,48 @@ # -*- coding: utf-8 -*- import sys - +import logging from torchvision import transforms def get_data_transform(data: str): if data == 'mnist': train_transform = transforms.Compose([ - # transforms.Grayscale(num_output_channels=1), - transforms.Resize((28, 28)), + transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) + transforms.Normalize((0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081)) ]) test_transform = transforms.Compose([ - # transforms.Grayscale(num_output_channels=1), - transforms.Resize((28, 28)), + transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) + transforms.Normalize((0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081)) ]) elif data == 'cifar': train_transform = transforms.Compose([ - # input arguments: length&width of a figure transforms.RandomCrop(32, padding=4), - # transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), # convert PIL image or numpy.ndarray to tensor - #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) test_transform = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.ToTensor(), - #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) elif data == 'imagenet': normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) train_transform = transforms.Compose([ - # transforms.RandomCrop(32, padding=4), # input arguments: length&width of a figure transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), - normalize, # convert PIL image or numpy.ndarray to tensor - #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + normalize, ]) test_transform = transforms.Compose([ @@ -59,48 +50,33 @@ def get_data_transform(data: str): transforms.RandomResizedCrop(224), transforms.ToTensor(), normalize, - #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) elif data == 'emnist': train_transform = transforms.Compose([ + 
transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), - transforms.RandomGrayscale(), transforms.ToTensor(), - # transforms.Resize(224), # input arguments: length&width of a figure - # transforms.RandomResizedCrop(224), - # transforms.RandomHorizontalFlip(), - # transforms.ToTensor(), # convert PIL image or numpy.ndarray to tensor - #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + transforms.Normalize((0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081)) ]) test_transform = transforms.Compose([ + transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), - transforms.RandomGrayscale(), transforms.ToTensor(), - # transforms.Resize(224), - # transforms.ToTensor(), - #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + transforms.Normalize((0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081)) ]) elif data == 'openImg': train_transform = transforms.Compose([ - # transforms.RandomResizedCrop(224), transforms.Resize((256, 256)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), - #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) test_transform = transforms.Compose([ transforms.Resize((256, 256)), - # transforms.RandomResizedCrop((128,128)), - # transforms.CenterCrop(224), transforms.ToTensor(), - #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) @@ -111,13 +87,8 @@ def get_data_transform(data: str): transforms.Resize((299, 299)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), - # transforms.Resize(224), # input arguments: length&width of a figure - # transforms.RandomResizedCrop(224), - # transforms.RandomHorizontalFlip(), - # transforms.ToTensor(), # convert PIL image or numpy.ndarray to tensor transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) test_transform = transforms.Compose([ @@ -128,7 +99,6 @@ def get_data_transform(data: str): # transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), - # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) elif data == 'inaturalist': train_transform = transforms.Compose([ @@ -154,7 +124,7 @@ def get_data_transform(data: str): transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) else: - print('Data must be {} or {} !'.format('mnist', 'cifar')) + logging.info(f"Does not find data transform for {data}") sys.exit(-1) - return train_transform, test_transform \ No newline at end of file + return train_transform, test_transform From 6bda3c32792d87f22e5eb2e239d0e0754bf46f0d Mon Sep 17 00:00:00 2001 From: AmberLJC Date: Sun, 21 Aug 2022 09:35:28 -0400 Subject: [PATCH 04/12] debug async timestamp; job scheduling; start ronudgit push --- benchmark/configs/async_fl/async_fl.yml | 12 +-- examples/async_fl/async_aggregator.py | 119 +++++++++++++++++++++--- examples/async_fl/async_client.py | 6 +- examples/async_fl/async_executor.py | 5 +- fedscale/core/config_parser.py | 1 + 5 files changed, 120 insertions(+), 23 deletions(-) diff --git a/benchmark/configs/async_fl/async_fl.yml b/benchmark/configs/async_fl/async_fl.yml index 0e267679..69016aff 100644 --- a/benchmark/configs/async_fl/async_fl.yml +++ 
b/benchmark/configs/async_fl/async_fl.yml @@ -8,8 +8,7 @@ ps_ip: 10.0.0.1 # Note that if we collocate ps and worker on same GPU, then we need to decrease this number of available processes on that GPU by 1 # E.g., master node has 4 available processes, then 1 for the ps, and worker should be set to: worker:3 worker_ips: - - 10.0.0.1:[4] - - 10.0.0.2:[4] + - 10.0.0.1:[5] exp_path: $FEDSCALE_HOME/fedscale/core @@ -38,7 +37,6 @@ setup_commands: job_conf: - job_name: femnist # Generate logs under this folder: log_path/job_name/time_stamp - log_path: $FEDSCALE_HOME/benchmark # Path of log files - - async_buffer: 20 # Number of updates need to be aggregated before generating new model version - data_set: femnist # Dataset: openImg, google_speech, stackoverflow - data_dir: $FEDSCALE_HOME/benchmark/dataset/data/femnist # Path of the dataset - data_map_file: $FEDSCALE_HOME/benchmark/dataset/data/femnist/client_data_mapping/train.csv # Allocation of data to each client, turn to iid setting if not provided @@ -47,14 +45,16 @@ job_conf: - model: resnet56_cifar100 # NOTE: Please refer to our model zoo README and use models for these small image (e.g., 32x32x3) inputs - model_zoo: fedscale-zoo - eval_interval: 5 # How many rounds to run a testing on the testing set - - rounds: 3000 # Number of rounds to run this training. We use 1000 in our paper, while it may converge w/ ~400 rounds + - rounds: 1000 # Number of rounds to run this training. We use 1000 in our paper, while it may converge w/ ~400 rounds - filter_less: 21 # Remove clients w/ less than 21 samples - num_loaders: 2 - - local_steps: 20 + - local_steps: 5 - learning_rate: 0.05 - batch_size: 20 - test_bsz: 20 - use_cuda: True - decay_round: 50 - overcommitment: 1.0 - - arrival_interval: 3 \ No newline at end of file + - arrival_interval: 2 + - max_concurrency: 50 + - async_buffer: 20 # Number of updates need to be aggregated before generating new model version diff --git a/examples/async_fl/async_aggregator.py b/examples/async_fl/async_aggregator.py index 6e27916e..97340fb1 100644 --- a/examples/async_fl/async_aggregator.py +++ b/examples/async_fl/async_aggregator.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- - +import collections import os import sys @@ -29,7 +29,7 @@ def __init__(self, args): self.resource_manager = ResourceManager(self.experiment_mode) self.async_buffer_size = args.async_buffer self.client_round_duration = {} - self.client_start_time = {} + self.client_start_time = collections.defaultdict(list) self.round_stamp = [0] self.client_model_version = {} self.virtual_client_clock = {} @@ -39,6 +39,7 @@ def __init__(self, args): self.test_model = None self.aggregate_update = {} self.importance_sum = 0 + self.client_end = [] def tictak_client_tasks(self, sampled_clients, num_clients_to_collect): @@ -69,7 +70,7 @@ def tictak_client_tasks(self, sampled_clients, num_clients_to_collect): sampledClientsReal.append(client_to_run) completed_client_clock[client_to_run] = exe_cost startTimes.append(start_time) - self.client_start_time[client_to_run] = start_time + self.client_start_time[client_to_run].append(start_time) self.client_round_duration[client_to_run] = roundDuration endTimes.append(roundDuration + start_time) @@ -147,7 +148,9 @@ def aggregate_client_weights(self, results): ).to(dtype=d_type) def round_completion_handler(self): + self.global_virtual_clock = self.round_stamp[-1] + self.round += 1 if self.round % self.args.decay_round == 0: @@ -169,7 +172,8 @@ def round_completion_handler(self): # update select participants # NOTE: we 
simulate async, while have to sync every 20 rounds to avoid large division to trace - if self.resource_manager.get_task_length() < self.async_buffer_size: + if self.resource_manager.get_task_length() < self.async_buffer_size*2: + self.sampled_participants = self.select_participants( select_num_participants=self.async_buffer_size*20, overcommitment=self.args.overcommitment) (clientsToRun, clientsStartTime, virtual_client_clock) = self.tictak_client_tasks( @@ -196,6 +200,7 @@ def round_completion_handler(self): self.stats_util_accumulator = [] self.client_training_results = [] self.loss_accumulator = [] + # self.round_stamp.append(self.global_virtual_clock) if self.round >= self.args.rounds: self.broadcast_aggregator_events(commons.SHUT_DOWN) @@ -232,7 +237,7 @@ def get_test_config(self, client_id): straggler_round = 1e10 for client in current_pending_clients: straggler_round = min( - self.find_latest_model(self.client_start_time[client]), straggler_round) + self.find_latest_model(self.client_start_time[client][0]), straggler_round) return {'client_id': client_id, 'straggler_round': straggler_round, @@ -254,15 +259,14 @@ def create_client_task(self, executorId): if next_clientId != None: config = self.get_client_conf(next_clientId) - start_time = self.client_start_time[next_clientId] + start_time = self.client_start_time[next_clientId][0] model_id = self.find_latest_model(start_time) self.client_model_version[next_clientId] = model_id end_time = self.client_round_duration[next_clientId] + start_time - # The executor has already received the model, thus transfering id is enough + # The executor has already received the model, thus transferring id is enough model = model_id - train_config = {'client_id': next_clientId, 'task_config': config} - + train_config = {'client_id': next_clientId, 'task_config': config, 'end_time': end_time} logging.info(f"Client {next_clientId} train on model {model_id} during {int(start_time)}-{int(end_time)}") return train_config, model @@ -275,6 +279,91 @@ def log_train_result(self, avg_loss): self.log_writer.add_scalar( 'FAR/round_duration (min)', self.round_duration / 60., self.round) + def client_completion_handler(self, results): + """We may need to keep all updates from clients, + if so, we need to append results to the cache + + Args: + results (dictionary): client's training result + + """ + # Format: + # -results = {'clientId':clientId, 'update_weight': model_param, 'moving_loss': round_train_loss, + # 'trained_size': count, 'wall_duration': time_cost, 'success': is_success 'utility': utility} + if self.client_round_duration[results['clientId']] + self.client_start_time[results['clientId']][0] < self.round_stamp[-1]: + # Ignore tasks that are issued earlier but finish late + self.client_start_time[results['clientId']].pop(0) + logging.info(f"Warning: Ignore late-response client {results['clientId']}") + return + + # [ASYNC] New checkin clients ID would overlap with previous unfinished clients + logging.info(f"Client {results['clientId']} completes from {self.client_start_time[results['clientId']][0]} to {self.client_start_time[results['clientId']][0]+self.client_round_duration[results['clientId']]}") + + self.client_end.append( self.client_round_duration[results['clientId']] + self.client_start_time[results['clientId']].pop(0) ) + + if self.args.gradient_policy in ['q-fedavg']: + self.client_training_results.append(results) + # Feed metrics to client sampler + self.stats_util_accumulator.append(results['utility']) + 
self.loss_accumulator.append(results['moving_loss']) + + self.client_manager.register_feedback(results['clientId'], results['utility'], + auxi=math.sqrt( + results['moving_loss']), + time_stamp=self.round, + duration=self.virtual_client_clock[results['clientId']]['computation'] + + self.virtual_client_clock[results['clientId']]['communication'] + ) + + # ================== Aggregate weights ====================== + self.update_lock.acquire() + + self.model_in_update += 1 + if self.using_group_params == True: + self.aggregate_client_group_weights(results) + else: + self.aggregate_client_weights(results) + + self.update_lock.release() + + def CLIENT_EXECUTE_COMPLETION(self, request, context): + """FL clients complete the execution task. + + Args: + request (CompleteRequest): Complete request info from executor. + + Returns: + ServerResponse: Server response to job completion request + + """ + + executor_id, client_id, event = request.executor_id, request.client_id, request.event + execution_status, execution_msg = request.status, request.msg + meta_result, data_result = request.meta_result, request.data_result + # logging.info(f"$$$$$$$$ ({executor_id}) CLIENT_EXECUTE_COMPLETION client {client_id} with event {event}") + + if event == commons.CLIENT_TRAIN: + # Training results may be uploaded in CLIENT_EXECUTE_RESULT request later, + # so we need to specify whether to ask client to do so (in case of straggler/timeout in real FL). + if execution_status is False: + logging.error(f"Executor {executor_id} fails to run client {client_id}, due to {execution_msg}") + + elif event in (commons.MODEL_TEST, commons.UPLOAD_MODEL): + self.add_event_handler( + executor_id, event, meta_result, data_result) + else: + logging.error(f"Received undefined event {event} from client {client_id}") + + # [ASYNC] Different from sync, only schedule tasks once previous training finish + if self.resource_manager.has_next_task(executor_id): + # NOTE: we do not pop the train immediately in simulation mode, + # since the executor may run multiple clients + if event in (commons.MODEL_TEST, commons.UPLOAD_MODEL): + self.individual_client_events[executor_id].append( + commons.CLIENT_TRAIN) + + return self.CLIENT_PING(request, context) + def event_monitor(self): logging.info("Start monitoring events ...") @@ -287,7 +376,9 @@ def event_monitor(self): self.dispatch_client_events(current_event) elif current_event == commons.START_ROUND: - self.dispatch_client_events(commons.CLIENT_TRAIN) + # [ASYNC] Only dispatch CLIENT_TRAIN on the first round + if self.round == 1: + self.dispatch_client_events(commons.CLIENT_TRAIN) elif current_event == commons.SHUT_DOWN: self.dispatch_client_events(commons.SHUT_DOWN) @@ -300,10 +391,14 @@ def event_monitor(self): if current_event == commons.UPLOAD_MODEL: self.client_completion_handler( self.deserialize_response(data)) + if self.model_in_update == self.async_buffer_size: clientID = self.deserialize_response(data)['clientId'] - self.round_stamp.append( - self.client_round_duration[clientID] + self.client_start_time[clientID]) + logging.info( + f"last client {clientID} at round {self.round} ") + + self.round_stamp.append(max(self.client_end)) + self.client_end = [] self.round_completion_handler() elif current_event == commons.MODEL_TEST: diff --git a/examples/async_fl/async_client.py b/examples/async_fl/async_client.py index cdf5212f..b41fd8f1 100644 --- a/examples/async_fl/async_client.py +++ b/examples/async_fl/async_client.py @@ -54,9 +54,9 @@ def train(self, client_data, model, conf): 
results['utility'] = math.sqrt( self.loss_squre)*float(trained_unique_samples) - if error_type is None: - logging.info(f"Training of (CLIENT: {clientId}) completes, {results}") - else: + if error_type is not None: + # logging.info(f"Training of (CLIENT: {clientId}) completes, {results}") + # else: logging.info(f"Training of (CLIENT: {clientId}) failed as {error_type}") results['update_weight'] = model_param diff --git a/examples/async_fl/async_executor.py b/examples/async_fl/async_executor.py index 2b0ece57..854119cf 100644 --- a/examples/async_fl/async_executor.py +++ b/examples/async_fl/async_executor.py @@ -32,7 +32,7 @@ def update_model_handler(self, model): def load_global_model(self, round=None): # load last global model - logging.info(f"====Load global model with version {round}") + # logging.info(f"====Load global model with version {round}") round = min(round, self.round) if round is not None else self.round with open(self.temp_model_path_version(round), 'rb') as model_in: model = pickle.load(model_in) @@ -70,7 +70,7 @@ def testing_handler(self, args, config=None): evalStart = time.time() device = self.device - model = config['test_model'] + model = self.load_global_model()# config['test_model'] if self.task == 'rl': client = RLClient(args) test_res = client.test(args, self.this_rank, model, device=device) @@ -129,6 +129,7 @@ def event_monitor(self): if train_model is not None and not self.check_model_version(train_model): # The executor may have not received the model due to async grpc self.event_queue.append(request) + logging.error(f"Warning: Not receive model {train_model} for client {train_config['client_id'] }") time.sleep(1) continue diff --git a/fedscale/core/config_parser.py b/fedscale/core/config_parser.py index 7d8516e8..7a0cd765 100644 --- a/fedscale/core/config_parser.py +++ b/fedscale/core/config_parser.py @@ -101,6 +101,7 @@ parser.add_argument('--malicious_factor', type=int, default=1e15) # for asynchronous FL buffer size +parser.add_argument('--max_concurrency', type=int, default=100) parser.add_argument('--async_buffer', type=int, default=10) parser.add_argument( '--checkin_period', type=int, default=50, help='number of rounds to sample async clients' From b682992879c7e296487575f3a0d4eed4ba852665 Mon Sep 17 00:00:00 2001 From: AmberLJC Date: Tue, 23 Aug 2022 09:44:29 -0400 Subject: [PATCH 05/12] rollbackdata loader+concurrency --- benchmark/configs/async_fl/async_fl.yml | 15 ++--- benchmark/configs/cifar_cpu/cifar_cpu.yml | 4 +- benchmark/configs/femnist/conf.yml | 4 +- examples/async_fl/async_aggregator.py | 69 ++++++++++++++--------- examples/async_fl/async_executor.py | 5 +- fedscale/core/aggregation/aggregator.py | 6 +- fedscale/core/config_parser.py | 1 + fedscale/dataloaders/utils_data.py | 58 ++++++++++++++----- 8 files changed, 104 insertions(+), 58 deletions(-) diff --git a/benchmark/configs/async_fl/async_fl.yml b/benchmark/configs/async_fl/async_fl.yml index 69016aff..3c618b7d 100644 --- a/benchmark/configs/async_fl/async_fl.yml +++ b/benchmark/configs/async_fl/async_fl.yml @@ -2,13 +2,13 @@ # ========== Cluster configuration ========== # ip address of the parameter server (need 1 GPU process) -ps_ip: 10.0.0.1 +ps_ip: localhost # ip address of each worker:# of available gpus process on each gpu in this node # Note that if we collocate ps and worker on same GPU, then we need to decrease this number of available processes on that GPU by 1 # E.g., master node has 4 available processes, then 1 for the ps, and worker should be set to: worker:3 worker_ips: - 
- 10.0.0.1:[5] + - localhost:[4] exp_path: $FEDSCALE_HOME/fedscale/core @@ -42,8 +42,8 @@ job_conf: - data_map_file: $FEDSCALE_HOME/benchmark/dataset/data/femnist/client_data_mapping/train.csv # Allocation of data to each client, turn to iid setting if not provided - device_conf_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_device_capacity # Path of the client trace - device_avail_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_behave_trace - - model: resnet56_cifar100 # NOTE: Please refer to our model zoo README and use models for these small image (e.g., 32x32x3) inputs - - model_zoo: fedscale-zoo + - model: resnet18 # NOTE: Please refer to our model zoo README and use models for these small image (e.g., 32x32x3) inputs +# - model_zoo: fedscale-zoo - eval_interval: 5 # How many rounds to run a testing on the testing set - rounds: 1000 # Number of rounds to run this training. We use 1000 in our paper, while it may converge w/ ~400 rounds - filter_less: 21 # Remove clients w/ less than 21 samples @@ -52,9 +52,10 @@ job_conf: - learning_rate: 0.05 - batch_size: 20 - test_bsz: 20 + - ps_port: 12345 - use_cuda: True - - decay_round: 50 - overcommitment: 1.0 - - arrival_interval: 2 + - arrival_interval: 10 + - max_staleness: 0 - max_concurrency: 50 - - async_buffer: 20 # Number of updates need to be aggregated before generating new model version + - async_buffer: 50 # Number of updates need to be aggregated before generating new model version diff --git a/benchmark/configs/cifar_cpu/cifar_cpu.yml b/benchmark/configs/cifar_cpu/cifar_cpu.yml index 5d1c4c4a..7c31ce06 100644 --- a/benchmark/configs/cifar_cpu/cifar_cpu.yml +++ b/benchmark/configs/cifar_cpu/cifar_cpu.yml @@ -34,8 +34,8 @@ job_conf: - num_participants: 4 # Number of participants per round, we use K=100 in our paper, large K will be much slower - data_set: cifar10 # Dataset: openImg, google_speech, stackoverflow - data_dir: $FEDSCALE_HOME/benchmark/dataset/data/ # Path of the dataset - - model: resnet56_cifar10 # NOTE: Please refer to our model zoo README and use models for these small image (e.g., 32x32x3) inputs - - model_zoo: fedscale-zoo # Default zoo (torchcv) uses the pytorchvision zoo, which can not support small images well + - model: shufflenet_v2_x2_0 # NOTE: Please refer to our model zoo README and use models for these small image (e.g., 32x32x3) inputs +# - model_zoo: fedscale-zoo # Default zoo (torchcv) uses the pytorchvision zoo, which can not support small images well - eval_interval: 5 # How many rounds to run a testing on the testing set - rounds: 600 # Number of rounds to run this training. 
We use 1000 in our paper, while it may converge w/ ~400 rounds - filter_less: 0 # Remove clients w/ less than 21 samples diff --git a/benchmark/configs/femnist/conf.yml b/benchmark/configs/femnist/conf.yml index ad974c78..9f374acc 100644 --- a/benchmark/configs/femnist/conf.yml +++ b/benchmark/configs/femnist/conf.yml @@ -38,8 +38,8 @@ job_conf: - data_map_file: $FEDSCALE_HOME/benchmark/dataset/data/femnist/client_data_mapping/train.csv # Allocation of data to each client, turn to iid setting if not provided - device_conf_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_device_capacity # Path of the client trace - device_avail_file: $FEDSCALE_HOME/benchmark/dataset/data/device_info/client_behave_trace - - model: resnet56_cifar10 # NOTE: Please refer to our model zoo README and use models for these small image (e.g., 32x32x3) inputs - - model_zoo: fedscale-zoo + - model: resnet18 # NOTE: Please refer to our model zoo README and use models for these small image (e.g., 32x32x3) inputs +# - model_zoo: fedscale-zoo - eval_interval: 10 # How many rounds to run a testing on the testing set - rounds: 5000 # Number of rounds to run this training. We use 1000 in our paper, while it may converge w/ ~400 rounds - filter_less: 21 # Remove clients w/ less than 21 samples diff --git a/examples/async_fl/async_aggregator.py b/examples/async_fl/async_aggregator.py index 97340fb1..e495ed0d 100644 --- a/examples/async_fl/async_aggregator.py +++ b/examples/async_fl/async_aggregator.py @@ -28,10 +28,11 @@ def __init__(self, args): Aggregator.__init__(self, args) self.resource_manager = ResourceManager(self.experiment_mode) self.async_buffer_size = args.async_buffer + self.max_concurrency = args.max_concurrency self.client_round_duration = {} self.client_start_time = collections.defaultdict(list) self.round_stamp = [0] - self.client_model_version = {} + self.client_model_version = collections.defaultdict(list) self.virtual_client_clock = {} self.weight_tensor_type = {} @@ -40,6 +41,8 @@ def __init__(self, args): self.aggregate_update = {} self.importance_sum = 0 self.client_end = [] + self.round_staleness = [] + self.model_concurrency = collections.defaultdict(int) def tictak_client_tasks(self, sampled_clients, num_clients_to_collect): @@ -108,8 +111,10 @@ def aggregate_client_weights(self, results): """ # Start to take the average of updates, and we do not keep updates to save memory # Importance of each update is 1/staleness - client_staleness = self.round - self.client_model_version[results['clientId']] + client_staleness = self.round - self.client_model_version[results['clientId']].pop(0) + importance = 1./(math.sqrt(1 + client_staleness)) + self.round_staleness.append(client_staleness) new_round_aggregation = (self.model_in_update == 1) if new_round_aggregation: @@ -134,25 +139,21 @@ def aggregate_client_weights(self, results): self.aggregate_update[p] = param_weight * importance else: self.aggregate_update[p] += param_weight * importance - - # self.model_weights[p].data += param_weight * importance - # else: - # # Non-floats (e.g., num_batches), no need to aggregate but need to track - # self.aggregate_update[p] = param_weight if self.model_in_update == self.async_buffer_size: for p in self.model_weights: d_type = self.weight_tensor_type[p] self.model_weights[p].data = ( - self.model_weights[p].data + self.aggregate_update[p]/self.importance_sum + self.model_weights[p].data + self.aggregate_update[p] / float(self.importance_sum) # self.model_in_update ).to(dtype=d_type) def 
round_completion_handler(self): + self.round += 1 + logging.info(f"Round {self.round} average staleness {np.mean(self.round_staleness)}") + self.round_staleness = [] self.global_virtual_clock = self.round_stamp[-1] - self.round += 1 - if self.round % self.args.decay_round == 0: self.args.learning_rate = max( self.args.learning_rate * self.args.decay_factor, self.args.min_learning_rate) @@ -172,10 +173,10 @@ def round_completion_handler(self): # update select participants # NOTE: we simulate async, while have to sync every 20 rounds to avoid large division to trace - if self.resource_manager.get_task_length() < self.async_buffer_size*2: + if self.resource_manager.get_task_length() < self.async_buffer_size: self.sampled_participants = self.select_participants( - select_num_participants=self.async_buffer_size*20, overcommitment=self.args.overcommitment) + select_num_participants=self.async_buffer_size*5, overcommitment=self.args.overcommitment) (clientsToRun, clientsStartTime, virtual_client_clock) = self.tictak_client_tasks( self.sampled_participants, len(self.sampled_participants)) @@ -253,21 +254,30 @@ def get_client_conf(self, clientId): def create_client_task(self, executorId): """Issue a new client training task to the executor""" - next_clientId = self.resource_manager.get_next_task(executorId) train_config = None model = None - - if next_clientId != None: - config = self.get_client_conf(next_clientId) - start_time = self.client_start_time[next_clientId][0] - model_id = self.find_latest_model(start_time) - self.client_model_version[next_clientId] = model_id - end_time = self.client_round_duration[next_clientId] + start_time - - # The executor has already received the model, thus transferring id is enough - model = model_id - train_config = {'client_id': next_clientId, 'task_config': config, 'end_time': end_time} - logging.info(f"Client {next_clientId} train on model {model_id} during {int(start_time)}-{int(end_time)}") + while True: + next_clientId = self.resource_manager.get_next_task(executorId) + if next_clientId != None: + config = self.get_client_conf(next_clientId) + start_time = self.client_start_time[next_clientId][0] + end_time = self.client_round_duration[next_clientId] + start_time + model_id = self.find_latest_model(start_time) + if end_time < self.round_stamp[-1] or self.model_concurrency[model_id] > self.max_concurrency + self.async_buffer_size: + self.client_start_time[next_clientId].pop(0) + continue + + self.client_model_version[next_clientId].append(model_id) + + # The executor has already received the model, thus transferring id is enough + model = model_id + train_config = {'client_id': next_clientId, 'task_config': config, 'end_time': end_time} + logging.info( + f"Client {next_clientId} train on model {model_id} during {int(start_time)}-{int(end_time)}") + self.model_concurrency[model_id] += 1 + break + else: + break return train_config, model @@ -290,11 +300,17 @@ def client_completion_handler(self, results): # Format: # -results = {'clientId':clientId, 'update_weight': model_param, 'moving_loss': round_train_loss, # 'trained_size': count, 'wall_duration': time_cost, 'success': is_success 'utility': utility} + + # [Async] some clients are scheduled earlier, which should be aggregated in previous round but receive the result late if self.client_round_duration[results['clientId']] + self.client_start_time[results['clientId']][0] < self.round_stamp[-1]: # Ignore tasks that are issued earlier but finish late self.client_start_time[results['clientId']].pop(0) 
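        # [Editor's note: illustrative aside, not part of the patch] Worked
        # example of the staleness rule this commit introduces a few lines
        # below: if the server is at round r and a client trained against
        # model version r - 3, its staleness is 3; the result is kept as long
        # as 3 <= max_staleness and is later scaled in aggregate_client_weights
        # by importance = 1 / sqrt(1 + 3) = 0.5, while any result whose
        # staleness exceeds max_staleness is discarded before aggregation.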
logging.info(f"Warning: Ignore late-response client {results['clientId']}") return + if self.round - self.client_model_version[results['clientId']][0] > self.args.max_staleness: + logging.info(f"Warning: Ignore stale client {results['clientId']} with {self.round - self.client_model_version[results['clientId']][0]}") + self.client_model_version[results['clientId']].pop(0) + return # [ASYNC] New checkin clients ID would overlap with previous unfinished clients logging.info(f"Client {results['clientId']} completes from {self.client_start_time[results['clientId']][0]} to {self.client_start_time[results['clientId']][0]+self.client_round_duration[results['clientId']]}") @@ -340,7 +356,6 @@ def CLIENT_EXECUTE_COMPLETION(self, request, context): executor_id, client_id, event = request.executor_id, request.client_id, request.event execution_status, execution_msg = request.status, request.msg meta_result, data_result = request.meta_result, request.data_result - # logging.info(f"$$$$$$$$ ({executor_id}) CLIENT_EXECUTE_COMPLETION client {client_id} with event {event}") if event == commons.CLIENT_TRAIN: # Training results may be uploaded in CLIENT_EXECUTE_RESULT request later, @@ -396,7 +411,7 @@ def event_monitor(self): clientID = self.deserialize_response(data)['clientId'] logging.info( f"last client {clientID} at round {self.round} ") - + # [ASYNC] handle different completion order self.round_stamp.append(max(self.client_end)) self.client_end = [] self.round_completion_handler() diff --git a/examples/async_fl/async_executor.py b/examples/async_fl/async_executor.py index 854119cf..c472ef72 100644 --- a/examples/async_fl/async_executor.py +++ b/examples/async_fl/async_executor.py @@ -70,7 +70,7 @@ def testing_handler(self, args, config=None): evalStart = time.time() device = self.device - model = self.load_global_model()# config['test_model'] + model = self.load_global_model() # config['test_model'] if self.task == 'rl': client = RLClient(args) test_res = client.test(args, self.this_rank, model, device=device) @@ -128,6 +128,7 @@ def event_monitor(self): train_model = self.deserialize_response(request.data) if train_model is not None and not self.check_model_version(train_model): # The executor may have not received the model due to async grpc + # TODO: server will lose track of scheduled but not executed task and remove the model self.event_queue.append(request) logging.error(f"Warning: Not receive model {train_model} for client {train_config['client_id'] }") time.sleep(1) @@ -147,7 +148,7 @@ def event_monitor(self): elif current_event == commons.MODEL_TEST: test_configs = self.deserialize_response(request.meta) - self.remove_stale_models(test_configs['straggler_round']) + # self.remove_stale_models(test_configs['straggler_round']) self.Test(test_configs) elif current_event == commons.UPDATE_MODEL: diff --git a/fedscale/core/aggregation/aggregator.py b/fedscale/core/aggregation/aggregator.py index 3fc39fdc..26b07d15 100755 --- a/fedscale/core/aggregation/aggregator.py +++ b/fedscale/core/aggregation/aggregator.py @@ -384,8 +384,7 @@ def client_completion_handler(self, results): results['moving_loss']), time_stamp=self.round, duration=self.virtual_client_clock[results['clientId']]['computation'] + - self.virtual_client_clock[results['clientId'] - ]['communication'] + self.virtual_client_clock[results['clientId']]['communication'] ) # ================== Aggregate weights ====================== @@ -850,7 +849,8 @@ def CLIENT_EXECUTE_COMPLETION(self, request, context): executor_id, event, meta_result, 
data_result) else: logging.error(f"Received undefined event {event} from client {client_id}") - + + # TODO: whether we should schedule tasks when client_ping or client_complete if self.resource_manager.has_next_task(executor_id): # NOTE: we do not pop the train immediately in simulation mode, # since the executor may run multiple clients diff --git a/fedscale/core/config_parser.py b/fedscale/core/config_parser.py index 7a0cd765..2c69b3d6 100644 --- a/fedscale/core/config_parser.py +++ b/fedscale/core/config_parser.py @@ -103,6 +103,7 @@ # for asynchronous FL buffer size parser.add_argument('--max_concurrency', type=int, default=100) parser.add_argument('--async_buffer', type=int, default=10) +parser.add_argument('--max_staleness', type=int, default=5) parser.add_argument( '--checkin_period', type=int, default=50, help='number of rounds to sample async clients' ) diff --git a/fedscale/dataloaders/utils_data.py b/fedscale/dataloaders/utils_data.py index 6de3978a..7f7db93f 100755 --- a/fedscale/dataloaders/utils_data.py +++ b/fedscale/dataloaders/utils_data.py @@ -1,48 +1,55 @@ # -*- coding: utf-8 -*- import sys -import logging from torchvision import transforms - def get_data_transform(data: str): if data == 'mnist': train_transform = transforms.Compose([ - transforms.RandomCrop(32, padding=4), + # transforms.Grayscale(num_output_channels=1), + transforms.Resize((28, 28)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), - transforms.Normalize((0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081)) + transforms.Normalize((0.1307,), (0.3081,)) ]) test_transform = transforms.Compose([ - transforms.RandomCrop(32, padding=4), + # transforms.Grayscale(num_output_channels=1), + transforms.Resize((28, 28)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), - transforms.Normalize((0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081)) + transforms.Normalize((0.1307,), (0.3081,)) ]) elif data == 'cifar': train_transform = transforms.Compose([ + # input arguments: length&width of a figure transforms.RandomCrop(32, padding=4), + # transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), # convert PIL image or numpy.ndarray to tensor - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) test_transform = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) elif data == 'imagenet': normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) train_transform = transforms.Compose([ + # transforms.RandomCrop(32, padding=4), # input arguments: length&width of a figure transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), - normalize, + normalize, # convert PIL image or numpy.ndarray to tensor + #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) test_transform = transforms.Compose([ @@ -50,33 +57,48 @@ def get_data_transform(data: str): transforms.RandomResizedCrop(224), transforms.ToTensor(), normalize, + #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + # transforms.Normalize((0.5, 0.5, 
0.5), (0.5, 0.5, 0.5)), ]) elif data == 'emnist': train_transform = transforms.Compose([ - transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), + transforms.RandomGrayscale(), transforms.ToTensor(), - transforms.Normalize((0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081)) + # transforms.Resize(224), # input arguments: length&width of a figure + # transforms.RandomResizedCrop(224), + # transforms.RandomHorizontalFlip(), + # transforms.ToTensor(), # convert PIL image or numpy.ndarray to tensor + #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) test_transform = transforms.Compose([ - transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), + transforms.RandomGrayscale(), transforms.ToTensor(), - transforms.Normalize((0.1307, 0.1307, 0.1307), (0.3081, 0.3081, 0.3081)) + # transforms.Resize(224), + # transforms.ToTensor(), + #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) elif data == 'openImg': train_transform = transforms.Compose([ + # transforms.RandomResizedCrop(224), transforms.Resize((256, 256)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), + #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) test_transform = transforms.Compose([ transforms.Resize((256, 256)), + # transforms.RandomResizedCrop((128,128)), + # transforms.CenterCrop(224), transforms.ToTensor(), + #transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) @@ -87,8 +109,13 @@ def get_data_transform(data: str): transforms.Resize((299, 299)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), + # transforms.Resize(224), # input arguments: length&width of a figure + # transforms.RandomResizedCrop(224), + # transforms.RandomHorizontalFlip(), + # transforms.ToTensor(), # convert PIL image or numpy.ndarray to tensor transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) test_transform = transforms.Compose([ @@ -99,6 +126,7 @@ def get_data_transform(data: str): # transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) elif data == 'inaturalist': train_transform = transforms.Compose([ @@ -124,7 +152,7 @@ def get_data_transform(data: str): transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]) else: - logging.info(f"Does not find data transform for {data}") + print('Data must be {} or {} !'.format('mnist', 'cifar')) sys.exit(-1) - return train_transform, test_transform + return train_transform, test_transform \ No newline at end of file From 4dae934ce6496d6e83c1f4cf230941628e69e880 Mon Sep 17 00:00:00 2001 From: AmberLJC Date: Wed, 24 Aug 2022 11:55:07 -0400 Subject: [PATCH 06/12] tune concurrency --- benchmark/configs/async_fl/async_fl.yml | 2 +- examples/async_fl/async_aggregator.py | 30 +++++++++++++++++++------ examples/async_fl/async_executor.py | 5 +++-- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/benchmark/configs/async_fl/async_fl.yml b/benchmark/configs/async_fl/async_fl.yml index 3c618b7d..49b347e2 100644 --- a/benchmark/configs/async_fl/async_fl.yml +++ b/benchmark/configs/async_fl/async_fl.yml @@ 
-56,6 +56,6 @@ job_conf: - use_cuda: True - overcommitment: 1.0 - arrival_interval: 10 - - max_staleness: 0 + - max_staleness: 3 - max_concurrency: 50 - async_buffer: 50 # Number of updates need to be aggregated before generating new model version diff --git a/examples/async_fl/async_aggregator.py b/examples/async_fl/async_aggregator.py index e495ed0d..b8f842b5 100644 --- a/examples/async_fl/async_aggregator.py +++ b/examples/async_fl/async_aggregator.py @@ -42,7 +42,7 @@ def __init__(self, args): self.importance_sum = 0 self.client_end = [] self.round_staleness = [] - self.model_concurrency = collections.defaultdict(int) + # self.model_concurrency = collections.defaultdict(int) def tictak_client_tasks(self, sampled_clients, num_clients_to_collect): @@ -57,6 +57,10 @@ def tictak_client_tasks(self, sampled_clients, num_clients_to_collect): start_time = self.global_virtual_clock constant_checkin_period = self.args.arrival_interval # 1. remove dummy clients that are not available to the end of training + concurreny_count = 0 + + end_list = [] + end_j = 0 for client_to_run in sampled_clients: client_cfg = self.client_conf.get(client_to_run, self.args) exe_cost = self.client_manager.getCompletionTime(client_to_run, @@ -69,13 +73,23 @@ def tictak_client_tasks(self, sampled_clients, num_clients_to_collect): exe_cost['communication'] # if the client is not active by the time of collection, we consider it is lost in this round start_time += constant_checkin_period - if self.client_manager.isClientActive(client_to_run, roundDuration + start_time): + end_time = roundDuration + start_time + end_list.append( end_time ) + while start_time > end_list[end_j] : + concurreny_count -= 1 + end_j += 1 + if concurreny_count > self.max_concurrency: + end_list.pop() + continue + + if self.client_manager.isClientActive(client_to_run, end_time): + concurreny_count += 1 sampledClientsReal.append(client_to_run) completed_client_clock[client_to_run] = exe_cost startTimes.append(start_time) self.client_start_time[client_to_run].append(start_time) self.client_round_duration[client_to_run] = roundDuration - endTimes.append(roundDuration + start_time) + endTimes.append(end_time) num_clients_to_collect = min( num_clients_to_collect, len(sampledClientsReal)) @@ -173,10 +187,10 @@ def round_completion_handler(self): # update select participants # NOTE: we simulate async, while have to sync every 20 rounds to avoid large division to trace - if self.resource_manager.get_task_length() < self.async_buffer_size: + if self.resource_manager.get_task_length() < self.async_buffer_size*2: self.sampled_participants = self.select_participants( - select_num_participants=self.async_buffer_size*5, overcommitment=self.args.overcommitment) + select_num_participants=self.async_buffer_size*10, overcommitment=self.args.overcommitment) (clientsToRun, clientsStartTime, virtual_client_clock) = self.tictak_client_tasks( self.sampled_participants, len(self.sampled_participants)) @@ -263,7 +277,7 @@ def create_client_task(self, executorId): start_time = self.client_start_time[next_clientId][0] end_time = self.client_round_duration[next_clientId] + start_time model_id = self.find_latest_model(start_time) - if end_time < self.round_stamp[-1] or self.model_concurrency[model_id] > self.max_concurrency + self.async_buffer_size: + if end_time < self.round_stamp[-1]: # or self.model_concurrency[model_id] > self.max_concurrency + self.async_buffer_size: self.client_start_time[next_clientId].pop(0) continue @@ -274,7 +288,7 @@ def create_client_task(self, 
executorId): train_config = {'client_id': next_clientId, 'task_config': config, 'end_time': end_time} logging.info( f"Client {next_clientId} train on model {model_id} during {int(start_time)}-{int(end_time)}") - self.model_concurrency[model_id] += 1 + #self.model_concurrency[model_id] += 1 break else: break @@ -305,11 +319,13 @@ def client_completion_handler(self, results): if self.client_round_duration[results['clientId']] + self.client_start_time[results['clientId']][0] < self.round_stamp[-1]: # Ignore tasks that are issued earlier but finish late self.client_start_time[results['clientId']].pop(0) + self.client_model_version[results['clientId']].pop(0) logging.info(f"Warning: Ignore late-response client {results['clientId']}") return if self.round - self.client_model_version[results['clientId']][0] > self.args.max_staleness: logging.info(f"Warning: Ignore stale client {results['clientId']} with {self.round - self.client_model_version[results['clientId']][0]}") self.client_model_version[results['clientId']].pop(0) + self.client_start_time[results['clientId']].pop(0) return # [ASYNC] New checkin clients ID would overlap with previous unfinished clients diff --git a/examples/async_fl/async_executor.py b/examples/async_fl/async_executor.py index c472ef72..07d12b7d 100644 --- a/examples/async_fl/async_executor.py +++ b/examples/async_fl/async_executor.py @@ -129,8 +129,9 @@ def event_monitor(self): if train_model is not None and not self.check_model_version(train_model): # The executor may have not received the model due to async grpc # TODO: server will lose track of scheduled but not executed task and remove the model - self.event_queue.append(request) logging.error(f"Warning: Not receive model {train_model} for client {train_config['client_id'] }") + if self.round - train_model <= self.args.max_staleness: + self.event_queue.append(request) time.sleep(1) continue @@ -148,7 +149,7 @@ def event_monitor(self): elif current_event == commons.MODEL_TEST: test_configs = self.deserialize_response(request.meta) - # self.remove_stale_models(test_configs['straggler_round']) + self.remove_stale_models(test_configs['straggler_round']) self.Test(test_configs) elif current_event == commons.UPDATE_MODEL: From e8bfa5cdb2662bd752e0053196cad1a65000cf2a Mon Sep 17 00:00:00 2001 From: AmberLJC Date: Wed, 24 Aug 2022 11:55:55 -0400 Subject: [PATCH 07/12] tweak --- examples/async_fl/async_executor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/async_fl/async_executor.py b/examples/async_fl/async_executor.py index 07d12b7d..40578890 100644 --- a/examples/async_fl/async_executor.py +++ b/examples/async_fl/async_executor.py @@ -70,7 +70,7 @@ def testing_handler(self, args, config=None): evalStart = time.time() device = self.device - model = self.load_global_model() # config['test_model'] + model = config['test_model'] if self.task == 'rl': client = RLClient(args) test_res = client.test(args, self.this_rank, model, device=device) From 1086e0edfac5cea8e754cb85a85860e64640e95c Mon Sep 17 00:00:00 2001 From: AmberLJC Date: Sat, 27 Aug 2022 22:22:12 -0400 Subject: [PATCH 08/12] aggregate late-response result --- benchmark/configs/async_fl/async_fl.yml | 6 +++--- benchmark/configs/femnist/conf.yml | 7 +++---- examples/async_fl/async_aggregator.py | 7 ------- 3 files changed, 6 insertions(+), 14 deletions(-) diff --git a/benchmark/configs/async_fl/async_fl.yml b/benchmark/configs/async_fl/async_fl.yml index 49b347e2..7d6eda50 100644 --- a/benchmark/configs/async_fl/async_fl.yml +++ 
b/benchmark/configs/async_fl/async_fl.yml @@ -35,7 +35,7 @@ setup_commands: # We appreciate you to contribute and/or report bugs. Thank you! job_conf: - - job_name: femnist # Generate logs under this folder: log_path/job_name/time_stamp + - job_name: async_femnist # Generate logs under this folder: log_path/job_name/time_stamp - log_path: $FEDSCALE_HOME/benchmark # Path of log files - data_set: femnist # Dataset: openImg, google_speech, stackoverflow - data_dir: $FEDSCALE_HOME/benchmark/dataset/data/femnist # Path of the dataset @@ -55,7 +55,7 @@ job_conf: - ps_port: 12345 - use_cuda: True - overcommitment: 1.0 - - arrival_interval: 10 + - arrival_interval: 1 - max_staleness: 3 - - max_concurrency: 50 + - max_concurrency: 100 - async_buffer: 50 # Number of updates need to be aggregated before generating new model version diff --git a/benchmark/configs/femnist/conf.yml b/benchmark/configs/femnist/conf.yml index 9f374acc..fb3cf746 100644 --- a/benchmark/configs/femnist/conf.yml +++ b/benchmark/configs/femnist/conf.yml @@ -9,7 +9,6 @@ ps_ip: 10.0.0.1 # E.g., master node has 4 available processes, then 1 for the ps, and worker should be set to: worker:3 worker_ips: - 10.0.0.1:[4] - - 10.0.0.2:[4] exp_path: $FEDSCALE_HOME/fedscale/core @@ -32,7 +31,7 @@ setup_commands: job_conf: - job_name: femnist # Generate logs under this folder: log_path/job_name/time_stamp - log_path: $FEDSCALE_HOME/benchmark # Path of log files - - num_participants: 20 # Number of participants per round, we use K=100 in our paper, large K will be much slower + - num_participants: 50 # Number of participants per round, we use K=100 in our paper, large K will be much slower - data_set: femnist # Dataset: openImg, google_speech, stackoverflow - data_dir: $FEDSCALE_HOME/benchmark/dataset/data/femnist # Path of the dataset - data_map_file: $FEDSCALE_HOME/benchmark/dataset/data/femnist/client_data_mapping/train.csv # Allocation of data to each client, turn to iid setting if not provided @@ -41,10 +40,10 @@ job_conf: - model: resnet18 # NOTE: Please refer to our model zoo README and use models for these small image (e.g., 32x32x3) inputs # - model_zoo: fedscale-zoo - eval_interval: 10 # How many rounds to run a testing on the testing set - - rounds: 5000 # Number of rounds to run this training. We use 1000 in our paper, while it may converge w/ ~400 rounds + - rounds: 1000 # Number of rounds to run this training. 
We use 1000 in our paper, while it may converge w/ ~400 rounds - filter_less: 21 # Remove clients w/ less than 21 samples - num_loaders: 2 - - local_steps: 20 + - local_steps: 5 - learning_rate: 0.05 - batch_size: 20 - test_bsz: 20 diff --git a/examples/async_fl/async_aggregator.py b/examples/async_fl/async_aggregator.py index b8f842b5..3a292850 100644 --- a/examples/async_fl/async_aggregator.py +++ b/examples/async_fl/async_aggregator.py @@ -315,13 +315,6 @@ def client_completion_handler(self, results): # -results = {'clientId':clientId, 'update_weight': model_param, 'moving_loss': round_train_loss, # 'trained_size': count, 'wall_duration': time_cost, 'success': is_success 'utility': utility} - # [Async] some clients are scheduled earlier, which should be aggregated in previous round but receive the result late - if self.client_round_duration[results['clientId']] + self.client_start_time[results['clientId']][0] < self.round_stamp[-1]: - # Ignore tasks that are issued earlier but finish late - self.client_start_time[results['clientId']].pop(0) - self.client_model_version[results['clientId']].pop(0) - logging.info(f"Warning: Ignore late-response client {results['clientId']}") - return if self.round - self.client_model_version[results['clientId']][0] > self.args.max_staleness: logging.info(f"Warning: Ignore stale client {results['clientId']} with {self.round - self.client_model_version[results['clientId']][0]}") self.client_model_version[results['clientId']].pop(0) From 61df1af7917a00d1e318d0cad377658aa997a0bb Mon Sep 17 00:00:00 2001 From: fanlai Date: Thu, 1 Sep 2022 23:24:48 -0400 Subject: [PATCH 09/12] [Async] Correcting Batch Execution --- examples/async_fl/async_aggregator.py | 104 +++++++++++++----------- examples/async_fl/async_client.py | 15 ++-- examples/async_fl/async_executor.py | 2 +- fedscale/core/aggregation/aggregator.py | 3 +- fedscale/core/execution/client.py | 10 +-- 5 files changed, 71 insertions(+), 63 deletions(-) diff --git a/examples/async_fl/async_aggregator.py b/examples/async_fl/async_aggregator.py index 3a292850..9082b10c 100644 --- a/examples/async_fl/async_aggregator.py +++ b/examples/async_fl/async_aggregator.py @@ -42,6 +42,7 @@ def __init__(self, args): self.importance_sum = 0 self.client_end = [] self.round_staleness = [] + self.round_tasks_issued = 0 # self.model_concurrency = collections.defaultdict(int) def tictak_client_tasks(self, sampled_clients, num_clients_to_collect): @@ -64,18 +65,16 @@ def tictak_client_tasks(self, sampled_clients, num_clients_to_collect): for client_to_run in sampled_clients: client_cfg = self.client_conf.get(client_to_run, self.args) exe_cost = self.client_manager.getCompletionTime(client_to_run, - batch_size=client_cfg.batch_size, - upload_step=client_cfg.local_steps, - upload_size=self.model_update_size, - download_size=self.model_update_size) + batch_size=client_cfg.batch_size, upload_step=client_cfg.local_steps, + upload_size=self.model_update_size, download_size=self.model_update_size) roundDuration = exe_cost['computation'] + \ exe_cost['communication'] # if the client is not active by the time of collection, we consider it is lost in this round start_time += constant_checkin_period end_time = roundDuration + start_time - end_list.append( end_time ) - while start_time > end_list[end_j] : + end_list.append(end_time) + while start_time > end_list[end_j]: concurreny_count -= 1 end_j += 1 if concurreny_count > self.max_concurrency: @@ -144,11 +143,6 @@ def aggregate_client_weights(self, results): param_weight = 
torch.from_numpy( param_weight).to(device=self.device) - # if self.model_weights[p].data.dtype in ( - # torch.float, torch.double, torch.half, - # torch.bfloat16, torch.chalf, torch.cfloat, torch.cdouble - # ): - # Only assign importance to floats (trainable variables) if new_round_aggregation: self.aggregate_update[p] = param_weight * importance else: @@ -186,8 +180,8 @@ def round_completion_handler(self): self.log_train_result(avg_loss) # update select participants - # NOTE: we simulate async, while have to sync every 20 rounds to avoid large division to trace - if self.resource_manager.get_task_length() < self.async_buffer_size*2: + # NOTE: we simulate async, while have to sync every 10 rounds to avoid large division to trace + if self.resource_manager.get_task_length() < self.async_buffer_size * 2: self.sampled_participants = self.select_participants( select_num_participants=self.async_buffer_size*10, overcommitment=self.args.overcommitment) @@ -195,7 +189,7 @@ def round_completion_handler(self): self.sampled_participants, len(self.sampled_participants)) logging.info(f"{len(clientsToRun)} clients with constant arrival following the order: {clientsToRun}") - logging.info(f"====Register {len(clientsToRun)} to queue") + # Issue requests to the resource manager; Tasks ordered by the completion time self.resource_manager.register_tasks(clientsToRun) self.virtual_client_clock.update(virtual_client_clock) @@ -270,28 +264,38 @@ def create_client_task(self, executorId): train_config = None model = None - while True: - next_clientId = self.resource_manager.get_next_task(executorId) - if next_clientId != None: - config = self.get_client_conf(next_clientId) - start_time = self.client_start_time[next_clientId][0] - end_time = self.client_round_duration[next_clientId] + start_time - model_id = self.find_latest_model(start_time) - if end_time < self.round_stamp[-1]: # or self.model_concurrency[model_id] > self.max_concurrency + self.async_buffer_size: - self.client_start_time[next_clientId].pop(0) - continue - self.client_model_version[next_clientId].append(model_id) - - # The executor has already received the model, thus transferring id is enough - model = model_id - train_config = {'client_id': next_clientId, 'task_config': config, 'end_time': end_time} - logging.info( - f"Client {next_clientId} train on model {model_id} during {int(start_time)}-{int(end_time)}") - #self.model_concurrency[model_id] += 1 - break + # NOTE: in batch execution simulation (i.e., multiple executors), we need to stall task scheduling + # to ensure clients in current async_buffer_size completes ahead of other tasks + with self.update_lock: + logging.info(f"====self.round_tasks_issued is {self.round_tasks_issued}, {self.async_buffer_size}") + if self.round_tasks_issued < self.async_buffer_size: + while True: + next_clientId = self.resource_manager.get_next_task(executorId) + if next_clientId != None: + config = self.get_client_conf(next_clientId) + start_time = self.client_start_time[next_clientId][0] + end_time = self.client_round_duration[next_clientId] + start_time + model_id = self.find_latest_model(start_time) + if end_time < self.round_stamp[-1]: # or self.model_concurrency[model_id] > self.max_concurrency + self.async_buffer_size: + self.client_start_time[next_clientId].pop(0) + continue + + self.client_model_version[next_clientId].append(model_id) + + # The executor has already received the model, thus sending id is enough + model = model_id + train_config = {'client_id': next_clientId, 'task_config': config, 
'end_time': end_time} + logging.info( + f"Client {next_clientId} train on model {model_id} during {int(start_time)}-{int(end_time)}") + #self.model_concurrency[model_id] += 1 + self.round_tasks_issued += 1 + break + else: + break else: - break + # We should insert the train request back, since we pop it earlier + self.individual_client_events[executorId].append(commons.CLIENT_TRAIN) return train_config, model @@ -319,12 +323,15 @@ def client_completion_handler(self, results): logging.info(f"Warning: Ignore stale client {results['clientId']} with {self.round - self.client_model_version[results['clientId']][0]}") self.client_model_version[results['clientId']].pop(0) self.client_start_time[results['clientId']].pop(0) + with self.update_lock: self.round_tasks_issued -= 1 return # [ASYNC] New checkin clients ID would overlap with previous unfinished clients - logging.info(f"Client {results['clientId']} completes from {self.client_start_time[results['clientId']][0]} to {self.client_start_time[results['clientId']][0]+self.client_round_duration[results['clientId']]}") + logging.info( + f"Client {results['clientId']} completes from {self.client_start_time[results['clientId']][0]} " + + f"to {self.client_start_time[results['clientId']][0]+self.client_round_duration[results['clientId']]}") - self.client_end.append( self.client_round_duration[results['clientId']] + self.client_start_time[results['clientId']].pop(0) ) + self.client_end.append(self.client_round_duration[results['clientId']] + self.client_start_time[results['clientId']].pop(0)) if self.args.gradient_policy in ['q-fedavg']: self.client_training_results.append(results) @@ -341,15 +348,12 @@ def client_completion_handler(self, results): ) # ================== Aggregate weights ====================== - self.update_lock.acquire() - - self.model_in_update += 1 - if self.using_group_params == True: - self.aggregate_client_group_weights(results) - else: - self.aggregate_client_weights(results) - - self.update_lock.release() + with self.update_lock: + self.model_in_update += 1 + if self.using_group_params == True: + self.aggregate_client_group_weights(results) + else: + self.aggregate_client_weights(results) def CLIENT_EXECUTE_COMPLETION(self, request, context): """FL clients complete the execution task. 
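[Editor's note] The hunks above are the heart of the buffered, staleness-aware aggregation these patches converge on: results older than max_staleness are dropped, surviving updates are scaled by 1/sqrt(1 + staleness), and a new model version is produced once async_buffer results have been folded in. The sketch below is a minimal, self-contained illustration of that flow only; Update, BufferedAsyncAggregator, and the plain-float "delta" values are hypothetical stand-ins rather than FedScale APIs, and tensors, locks, and executor bookkeeping are deliberately omitted.

import math
from dataclasses import dataclass, field

@dataclass
class Update:
    client_id: int
    model_version: int      # global model version the client trained against
    delta: dict             # param name -> float update (toy stand-in for tensors)

@dataclass
class BufferedAsyncAggregator:
    global_weights: dict
    buffer_size: int                     # plays the role of async_buffer in the config
    max_staleness: int = 5
    round: int = 1
    _acc: dict = field(default_factory=dict)
    _importance_sum: float = 0.0
    _in_buffer: int = 0

    def on_update(self, update: Update, virtual_clock: float, round_stamp: list) -> None:
        staleness = self.round - update.model_version
        if staleness > self.max_staleness:
            return                       # too stale: discard before aggregation
        importance = 1.0 / math.sqrt(1 + staleness)
        self._importance_sum += importance
        for name, value in update.delta.items():
            self._acc[name] = self._acc.get(name, 0.0) + importance * value
        self._in_buffer += 1
        if self._in_buffer == self.buffer_size:
            # Flush: apply the importance-normalized buffer, stamp the virtual
            # time of the last contribution, and start a new model version.
            for name in self.global_weights:
                self.global_weights[name] += self._acc[name] / self._importance_sum
            round_stamp.append(virtual_clock)
            self._acc, self._importance_sum, self._in_buffer = {}, 0.0, 0
            self.round += 1

# Usage: two fresh updates (staleness 0) flush a buffer of size 2.
agg = BufferedAsyncAggregator(global_weights={"w": 0.0}, buffer_size=2)
stamps = [0]
agg.on_update(Update(client_id=1, model_version=1, delta={"w": 0.5}), 120.0, stamps)
agg.on_update(Update(client_id=2, model_version=1, delta={"w": 1.0}), 130.0, stamps)
print(agg.global_weights, agg.round, stamps)   # {'w': 0.75} 2 [0, 130.0]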
@@ -378,11 +382,12 @@ def CLIENT_EXECUTE_COMPLETION(self, request, context): else: logging.error(f"Received undefined event {event} from client {client_id}") - # [ASYNC] Different from sync, only schedule tasks once previous training finish + # # [ASYNC] Different from sync that only schedule tasks once previous training finish if self.resource_manager.has_next_task(executor_id): # NOTE: we do not pop the train immediately in simulation mode, # since the executor may run multiple clients - if event in (commons.MODEL_TEST, commons.UPLOAD_MODEL): + if commons.CLIENT_TRAIN not in self.individual_client_events[executor_id]: + # if event in (commons.MODEL_TEST, commons.UPLOAD_MODEL): self.individual_client_events[executor_id].append( commons.CLIENT_TRAIN) @@ -400,7 +405,7 @@ def event_monitor(self): self.dispatch_client_events(current_event) elif current_event == commons.START_ROUND: - # [ASYNC] Only dispatch CLIENT_TRAIN on the first round + # [ASYNC] Only dispatch CLIENT_TRAIN in the first round if self.round == 1: self.dispatch_client_events(commons.CLIENT_TRAIN) @@ -419,11 +424,12 @@ def event_monitor(self): if self.model_in_update == self.async_buffer_size: clientID = self.deserialize_response(data)['clientId'] logging.info( - f"last client {clientID} at round {self.round} ") + f"last client {clientID} in round {self.round} ") # [ASYNC] handle different completion order self.round_stamp.append(max(self.client_end)) self.client_end = [] self.round_completion_handler() + with self.update_lock: self.round_tasks_issued = 0 elif current_event == commons.MODEL_TEST: self.testing_completion_handler( diff --git a/examples/async_fl/async_client.py b/examples/async_fl/async_client.py index b41fd8f1..67787dd9 100644 --- a/examples/async_fl/async_client.py +++ b/examples/async_fl/async_client.py @@ -36,7 +36,8 @@ def train(self, client_data, model, conf): criterion = self.get_criterion(conf) error_type = None - # TODO: One may hope to run fixed number of epochs, instead of iterations + # NOTE: One may hope to run fixed number of epochs, instead of iterations + # then replace the following with "while self.completed_steps < conf.local_steps * len(client_data)" while self.completed_steps < conf.local_steps: try: self.train_step(client_data, conf, model, optimizer, criterion) @@ -50,15 +51,15 @@ def train(self, client_data, model, conf): for p in state_dicts} results = {'clientId': clientId, 'moving_loss': self.epoch_train_loss, 'trained_size': self.completed_steps*conf.batch_size, - 'success': self.completed_steps == conf.batch_size} - results['utility'] = math.sqrt( - self.loss_squre)*float(trained_unique_samples) + 'success': self.completed_steps == conf.local_steps} - if error_type is not None: - # logging.info(f"Training of (CLIENT: {clientId}) completes, {results}") - # else: + if error_type is None: + logging.info(f"Training of (CLIENT: {clientId}) completes, {results}") + else: logging.info(f"Training of (CLIENT: {clientId}) failed as {error_type}") + results['utility'] = math.sqrt( + self.loss_squre)*float(trained_unique_samples) results['update_weight'] = model_param results['wall_duration'] = 0 diff --git a/examples/async_fl/async_executor.py b/examples/async_fl/async_executor.py index 40578890..991881af 100644 --- a/examples/async_fl/async_executor.py +++ b/examples/async_fl/async_executor.py @@ -123,6 +123,7 @@ def event_monitor(self): request = self.event_queue.popleft() current_event = request.event + logging.info(f"====Poping event {current_event}") if current_event == 
commons.CLIENT_TRAIN: train_config = self.deserialize_response(request.meta) train_model = self.deserialize_response(request.data) @@ -155,7 +156,6 @@ def event_monitor(self): elif current_event == commons.UPDATE_MODEL: broadcast_config = self.deserialize_response(request.data) self.UpdateModel(broadcast_config) - time.sleep(5) elif current_event == commons.SHUT_DOWN: self.Stop() diff --git a/fedscale/core/aggregation/aggregator.py b/fedscale/core/aggregation/aggregator.py index 26b07d15..631af74c 100755 --- a/fedscale/core/aggregation/aggregator.py +++ b/fedscale/core/aggregation/aggregator.py @@ -796,11 +796,12 @@ def CLIENT_PING(self, request, context): current_event = commons.DUMMY_EVENT response_data = response_msg = commons.DUMMY_RESPONSE else: + logging.info(f"====event queue {executor_id}, {self.individual_client_events[executor_id]}") current_event = self.individual_client_events[executor_id].popleft( ) if current_event == commons.CLIENT_TRAIN: response_msg, response_data = self.create_client_task( - client_id) + executor_id) if response_msg is None: current_event = commons.DUMMY_EVENT if self.experiment_mode != commons.SIMULATION_MODE: diff --git a/fedscale/core/execution/client.py b/fedscale/core/execution/client.py index f965e2dc..43f6ac13 100644 --- a/fedscale/core/execution/client.py +++ b/fedscale/core/execution/client.py @@ -47,9 +47,9 @@ def train(self, client_data, model, conf): criterion = self.get_criterion(conf) error_type = None - # NOTE: If one may hope to run fixed number of epochs, instead of iterations, use `while self.completed_steps < conf.local_steps * len(client_data)` instead + # NOTE: If one may hope to run fixed number of epochs, instead of iterations, + # use `while self.completed_steps < conf.local_steps * len(client_data)` instead while self.completed_steps < conf.local_steps: - try: self.train_step(client_data, conf, model, optimizer, criterion) except Exception as ex: @@ -61,15 +61,15 @@ def train(self, client_data, model, conf): for p in state_dicts} results = {'clientId': clientId, 'moving_loss': self.epoch_train_loss, 'trained_size': self.completed_steps*conf.batch_size, - 'success': self.completed_steps == conf.batch_size} - results['utility'] = math.sqrt( - self.loss_squre)*float(trained_unique_samples) + 'success': self.completed_steps == conf.local_steps} if error_type is None: logging.info(f"Training of (CLIENT: {clientId}) completes, {results}") else: logging.info(f"Training of (CLIENT: {clientId}) failed as {error_type}") + results['utility'] = math.sqrt( + self.loss_squre)*float(trained_unique_samples) results['update_weight'] = model_param results['wall_duration'] = 0 From 7e52e762001e70527bab551e1abda668e0f30fa1 Mon Sep 17 00:00:00 2001 From: AmberLJC Date: Fri, 2 Sep 2022 15:32:42 -0400 Subject: [PATCH 10/12] create client task; make sync simulation --- benchmark/configs/async_fl/async_fl.yml | 6 +-- examples/async_fl/async_aggregator.py | 49 +++++++++---------------- 2 files changed, 21 insertions(+), 34 deletions(-) diff --git a/benchmark/configs/async_fl/async_fl.yml b/benchmark/configs/async_fl/async_fl.yml index 7d6eda50..b05fc462 100644 --- a/benchmark/configs/async_fl/async_fl.yml +++ b/benchmark/configs/async_fl/async_fl.yml @@ -8,7 +8,7 @@ ps_ip: localhost # Note that if we collocate ps and worker on same GPU, then we need to decrease this number of available processes on that GPU by 1 # E.g., master node has 4 available processes, then 1 for the ps, and worker should be set to: worker:3 worker_ips: - - localhost:[4] + - 
localhost:[2,2,2,2] exp_path: $FEDSCALE_HOME/fedscale/core @@ -52,10 +52,10 @@ job_conf: - learning_rate: 0.05 - batch_size: 20 - test_bsz: 20 - - ps_port: 12345 + - ps_port: 12342 - use_cuda: True - overcommitment: 1.0 - arrival_interval: 1 - - max_staleness: 3 + - max_staleness: 2 - max_concurrency: 100 - async_buffer: 50 # Number of updates need to be aggregated before generating new model version diff --git a/examples/async_fl/async_aggregator.py b/examples/async_fl/async_aggregator.py index 9082b10c..079a990c 100644 --- a/examples/async_fl/async_aggregator.py +++ b/examples/async_fl/async_aggregator.py @@ -268,34 +268,23 @@ def create_client_task(self, executorId): # NOTE: in batch execution simulation (i.e., multiple executors), we need to stall task scheduling # to ensure clients in current async_buffer_size completes ahead of other tasks with self.update_lock: - logging.info(f"====self.round_tasks_issued is {self.round_tasks_issued}, {self.async_buffer_size}") + # logging.info(f"====self.round_tasks_issued ({executorId}) is {self.round_tasks_issued}, {self.async_buffer_size}") if self.round_tasks_issued < self.async_buffer_size: - while True: - next_clientId = self.resource_manager.get_next_task(executorId) - if next_clientId != None: - config = self.get_client_conf(next_clientId) - start_time = self.client_start_time[next_clientId][0] - end_time = self.client_round_duration[next_clientId] + start_time - model_id = self.find_latest_model(start_time) - if end_time < self.round_stamp[-1]: # or self.model_concurrency[model_id] > self.max_concurrency + self.async_buffer_size: - self.client_start_time[next_clientId].pop(0) - continue - - self.client_model_version[next_clientId].append(model_id) - - # The executor has already received the model, thus sending id is enough - model = model_id - train_config = {'client_id': next_clientId, 'task_config': config, 'end_time': end_time} - logging.info( - f"Client {next_clientId} train on model {model_id} during {int(start_time)}-{int(end_time)}") - #self.model_concurrency[model_id] += 1 - self.round_tasks_issued += 1 - break - else: - break - else: - # We should insert the train request back, since we pop it earlier - self.individual_client_events[executorId].append(commons.CLIENT_TRAIN) + next_clientId = self.resource_manager.get_next_task(executorId) + config = self.get_client_conf(next_clientId) + start_time = self.client_start_time[next_clientId][0] + end_time = self.client_round_duration[next_clientId] + start_time + model_id = self.find_latest_model(start_time) + + self.client_model_version[next_clientId].append(model_id) + + # The executor has already received the model, thus sending id is enough + model = model_id + train_config = {'client_id': next_clientId, 'task_config': config, 'end_time': end_time} + logging.info( + f"Client {next_clientId} train on model {model_id} during {int(start_time)}-{int(end_time)}") + + self.round_tasks_issued += 1 return train_config, model @@ -382,7 +371,7 @@ def CLIENT_EXECUTE_COMPLETION(self, request, context): else: logging.error(f"Received undefined event {event} from client {client_id}") - # # [ASYNC] Different from sync that only schedule tasks once previous training finish + # [ASYNC] Different from sync that only schedule tasks once previous training finish if self.resource_manager.has_next_task(executor_id): # NOTE: we do not pop the train immediately in simulation mode, # since the executor may run multiple clients @@ -405,9 +394,7 @@ def event_monitor(self): 
self.dispatch_client_events(current_event) elif current_event == commons.START_ROUND: - # [ASYNC] Only dispatch CLIENT_TRAIN in the first round - if self.round == 1: - self.dispatch_client_events(commons.CLIENT_TRAIN) + self.dispatch_client_events(commons.CLIENT_TRAIN) elif current_event == commons.SHUT_DOWN: self.dispatch_client_events(commons.SHUT_DOWN) From 94b225847e0fffdfae3e93ac30057e405ffa5b7a Mon Sep 17 00:00:00 2001 From: AmberLJC Date: Fri, 2 Sep 2022 16:10:58 -0400 Subject: [PATCH 11/12] create client task; make sync simulation --- benchmark/configs/async_fl/async_fl.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmark/configs/async_fl/async_fl.yml b/benchmark/configs/async_fl/async_fl.yml index b05fc462..79be4434 100644 --- a/benchmark/configs/async_fl/async_fl.yml +++ b/benchmark/configs/async_fl/async_fl.yml @@ -55,7 +55,7 @@ job_conf: - ps_port: 12342 - use_cuda: True - overcommitment: 1.0 - - arrival_interval: 1 + - arrival_interval: 5 - max_staleness: 2 - max_concurrency: 100 - async_buffer: 50 # Number of updates need to be aggregated before generating new model version From 0448ede3b9dddf3f144c91e3e3cf7065a74dbf30 Mon Sep 17 00:00:00 2001 From: AmberLJC Date: Sat, 3 Sep 2022 19:39:33 -0400 Subject: [PATCH 12/12] stale result --- benchmark/configs/async_fl/async_fl.yml | 2 +- examples/async_fl/async_aggregator.py | 107 ++++++++++++++++++++---- fedscale/core/aggregation/aggregator.py | 4 +- 3 files changed, 94 insertions(+), 19 deletions(-) diff --git a/benchmark/configs/async_fl/async_fl.yml b/benchmark/configs/async_fl/async_fl.yml index 79be4434..e81d5eb0 100644 --- a/benchmark/configs/async_fl/async_fl.yml +++ b/benchmark/configs/async_fl/async_fl.yml @@ -56,6 +56,6 @@ job_conf: - use_cuda: True - overcommitment: 1.0 - arrival_interval: 5 - - max_staleness: 2 + - max_staleness: 5 - max_concurrency: 100 - async_buffer: 50 # Number of updates need to be aggregated before generating new model version diff --git a/examples/async_fl/async_aggregator.py b/examples/async_fl/async_aggregator.py index 079a990c..c6273bb3 100644 --- a/examples/async_fl/async_aggregator.py +++ b/examples/async_fl/async_aggregator.py @@ -45,6 +45,25 @@ def __init__(self, args): self.round_tasks_issued = 0 # self.model_concurrency = collections.defaultdict(int) + def run(self): + """Start running the aggregator server by setting up execution + and communication environment, and monitoring the grpc message. + """ + self.setup_env() + self.init_control_communication() + self.queue_lock = [threading.Lock() for _ in range(len(self.executors))] + self.init_data_communication() + + self.init_model() + self.save_last_param() + self.model_update_size = sys.getsizeof( + pickle.dumps(self.model)) / 1024.0 * 8. # kbits + self.client_profiles = self.load_client_profile( + file_path=self.args.device_conf_file) + + self.event_monitor() + + def tictak_client_tasks(self, sampled_clients, num_clients_to_collect): if self.experiment_mode == commons.SIMULATION_MODE: @@ -229,13 +248,13 @@ def find_latest_model(self, start_time): def get_test_config(self, client_id): """FL model testing on clients, developers can further define personalized client config here. - + Args: client_id (int): The client id. - + Returns: dictionary: The testing config for new task. 
- + """ # Get the straggler round-id client_tasks = self.resource_manager.client_run_queue @@ -248,8 +267,8 @@ def get_test_config(self, client_id): straggler_round = min( self.find_latest_model(self.client_start_time[client][0]), straggler_round) - return {'client_id': client_id, - 'straggler_round': straggler_round, + return {'client_id': client_id, + 'straggler_round': straggler_round, 'test_model': self.test_model} def get_client_conf(self, clientId): @@ -265,10 +284,10 @@ def create_client_task(self, executorId): train_config = None model = None - # NOTE: in batch execution simulation (i.e., multiple executors), we need to stall task scheduling + # NOTE: in batch execution simulation (i.e., multiple executors), we need to stall task scheduling # to ensure clients in current async_buffer_size completes ahead of other tasks with self.update_lock: - # logging.info(f"====self.round_tasks_issued ({executorId}) is {self.round_tasks_issued}, {self.async_buffer_size}") + logging.info(f"====self.round_tasks_issued ({executorId}) is {self.round_tasks_issued}, {self.async_buffer_size}") if self.round_tasks_issued < self.async_buffer_size: next_clientId = self.resource_manager.get_next_task(executorId) config = self.get_client_conf(next_clientId) @@ -286,6 +305,7 @@ def create_client_task(self, executorId): self.round_tasks_issued += 1 + return train_config, model def log_train_result(self, avg_loss): @@ -312,8 +332,10 @@ def client_completion_handler(self, results): logging.info(f"Warning: Ignore stale client {results['clientId']} with {self.round - self.client_model_version[results['clientId']][0]}") self.client_model_version[results['clientId']].pop(0) self.client_start_time[results['clientId']].pop(0) - with self.update_lock: self.round_tasks_issued -= 1 - return + with self.update_lock: + self.round_tasks_issued -= 1 + # self.individual_client_events['1'].append( commons.CLIENT_TRAIN) + return -1 # [ASYNC] New checkin clients ID would overlap with previous unfinished clients logging.info( @@ -344,6 +366,8 @@ def client_completion_handler(self, results): else: self.aggregate_client_weights(results) + return 0 + def CLIENT_EXECUTE_COMPLETION(self, request, context): """FL clients complete the execution task. @@ -372,16 +396,64 @@ def CLIENT_EXECUTE_COMPLETION(self, request, context): logging.error(f"Received undefined event {event} from client {client_id}") # [ASYNC] Different from sync that only schedule tasks once previous training finish - if self.resource_manager.has_next_task(executor_id): + if self.resource_manager.has_next_task(executor_id) and self.round_tasks_issued < self.async_buffer_size: # NOTE: we do not pop the train immediately in simulation mode, # since the executor may run multiple clients - if commons.CLIENT_TRAIN not in self.individual_client_events[executor_id]: + if commons.CLIENT_TRAIN not in self.individual_client_events[executor_id] : # if event in (commons.MODEL_TEST, commons.UPLOAD_MODEL): self.individual_client_events[executor_id].append( commons.CLIENT_TRAIN) return self.CLIENT_PING(request, context) + def CLIENT_PING(self, request, context): + """Handle client ping requests + + Args: + request (PingRequest): Ping request info from executor. 
+ + Returns: + ServerResponse: Server response to ping request + + """ + # NOTE: client_id = executor_id in deployment, + # while multiple client_id may use the same executor_id (VMs) in simulations + executor_id, client_id = request.executor_id, request.client_id + response_data = response_msg = commons.DUMMY_RESPONSE + with self.queue_lock[int(executor_id)-1]: + if len(self.individual_client_events[executor_id]) == 0: + # send dummy response + current_event = commons.DUMMY_EVENT + response_data = response_msg = commons.DUMMY_RESPONSE + else: + logging.info(f"====event queue {executor_id}, {self.individual_client_events[executor_id]}") + current_event = self.individual_client_events[executor_id].popleft() + if current_event == commons.CLIENT_TRAIN: + response_msg, response_data = self.create_client_task( + executor_id) + if response_msg is None: + current_event = commons.DUMMY_EVENT + if self.experiment_mode != commons.SIMULATION_MODE: + self.individual_client_events[executor_id].append( + commons.CLIENT_TRAIN) + elif current_event == commons.MODEL_TEST: + response_msg = self.get_test_config(client_id) + elif current_event == commons.UPDATE_MODEL: + response_data = self.get_global_model() + elif current_event == commons.SHUT_DOWN: + response_msg = self.get_shutdown_config(executor_id) + + response_msg, response_data = self.serialize_response( + response_msg), self.serialize_response(response_data) + # NOTE: in simulation mode, response data is pickle for faster (de)serialization + response = job_api_pb2.ServerResponse(event=current_event, + meta=response_msg, data=response_data) + if current_event != commons.DUMMY_EVENT: + logging.info(f"Issue EVENT ({current_event}) to EXECUTOR ({executor_id})") + + return response + + def event_monitor(self): logging.info("Start monitoring events ...") @@ -405,13 +477,16 @@ def event_monitor(self): client_id, current_event, meta, data = self.sever_events_queue.popleft() if current_event == commons.UPLOAD_MODEL: - self.client_completion_handler( + state = self.client_completion_handler( self.deserialize_response(data)) + logging.info( + f"Executor ({client_id}) finish client {self.deserialize_response(data)['clientId']} in round {self.round} [{self.model_in_update}/{ self.async_buffer_size}] ") + if state == -1 : + self.individual_client_events[client_id].append(commons.CLIENT_TRAIN) + + elif self.model_in_update == self.async_buffer_size: + # clientID = self.deserialize_response(data)['clientId'] - if self.model_in_update == self.async_buffer_size: - clientID = self.deserialize_response(data)['clientId'] - logging.info( - f"last client {clientID} in round {self.round} ") # [ASYNC] handle different completion order self.round_stamp.append(max(self.client_end)) self.client_end = [] diff --git a/fedscale/core/aggregation/aggregator.py b/fedscale/core/aggregation/aggregator.py index 631af74c..56414ea4 100755 --- a/fedscale/core/aggregation/aggregator.py +++ b/fedscale/core/aggregation/aggregator.py @@ -797,8 +797,7 @@ def CLIENT_PING(self, request, context): response_data = response_msg = commons.DUMMY_RESPONSE else: logging.info(f"====event queue {executor_id}, {self.individual_client_events[executor_id]}") - current_event = self.individual_client_events[executor_id].popleft( - ) + current_event = self.individual_client_events[executor_id].popleft() if current_event == commons.CLIENT_TRAIN: response_msg, response_data = self.create_client_task( executor_id) @@ -875,6 +874,7 @@ def event_monitor(self): self.dispatch_client_events(current_event) elif 
current_event == commons.START_ROUND: + self.dispatch_client_events(commons.CLIENT_TRAIN) elif current_event == commons.SHUT_DOWN:
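The final commit ties together per-executor event queues (guarded by `queue_lock`), the `round_tasks_issued` cap, and a stale-result path that frees the slot and pushes a fresh `CLIENT_TRAIN` back onto the executor's queue. The sketch below compresses that control flow into one class; `EventScheduler`, `issue_task`, and `handle_result` are hypothetical names standing in for the patched gRPC handlers (`CLIENT_PING`, `CLIENT_EXECUTE_COMPLETION`, `event_monitor`), not FedScale's actual API.

```python
# Illustrative control-loop sketch under assumed names (EventScheduler,
# issue_task, handle_result). It mirrors the idea in the patch: cap in-flight
# tasks at the async buffer size, and when a result is too stale, release the
# slot and hand the executor another CLIENT_TRAIN.
import collections
import threading

CLIENT_TRAIN, DUMMY_EVENT = "client_train", "dummy_event"


class EventScheduler:
    def __init__(self, executors, async_buffer=50, max_staleness=5):
        self.async_buffer = async_buffer
        self.max_staleness = max_staleness
        self.round = 0
        self.tasks_issued = 0
        self.update_lock = threading.Lock()
        # One event deque and one lock per executor, as in CLIENT_PING above.
        self.events = {e: collections.deque() for e in executors}
        self.queue_locks = {e: threading.Lock() for e in executors}

    def ping(self, executor_id):
        """Executor polls for its next event (the CLIENT_PING path)."""
        with self.queue_locks[executor_id]:
            if not self.events[executor_id]:
                return DUMMY_EVENT, None
            event = self.events[executor_id].popleft()
            if event == CLIENT_TRAIN:
                task = self.issue_task(executor_id)
                if task is None:
                    # Buffer is full: push the request back and answer with a dummy.
                    self.events[executor_id].append(CLIENT_TRAIN)
                    return DUMMY_EVENT, None
                return event, task
            return event, None

    def issue_task(self, executor_id):
        with self.update_lock:
            if self.tasks_issued >= self.async_buffer:
                return None
            self.tasks_issued += 1
            return {"executor": executor_id, "model_version": self.round}

    def handle_result(self, executor_id, model_version):
        """Drop results trained on a model older than max_staleness rounds."""
        if self.round - model_version > self.max_staleness:
            with self.update_lock:
                self.tasks_issued -= 1        # release the slot for a fresh task
            with self.queue_locks[executor_id]:
                self.events[executor_id].append(CLIENT_TRAIN)
            return False
        return True
```

Under this scheme no executor can pull more work once `async_buffer` tasks are in flight, which is what the task-stalling note in the patch relies on to keep a multi-executor simulation consistent with a single-executor schedule.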