diff --git a/README.md b/README.md index 6ee255d..6efaf40 100644 --- a/README.md +++ b/README.md @@ -81,11 +81,13 @@ For other algorithms and datasets, users can refer to `argzoo/` or customize you **FedBN:** [FedBN: Federated Learning on Non-IID Features via Local Batch Normalization](https://arxiv.org/pdf/2102.07623.pdf) +**FedAMP:** [Personalized Cross-Silo Federated Learning on Non-IID Data](https://ojs.aaai.org/index.php/AAAI/article/view/16960) + **FedRoD:** [On Bridging Generic and Personalized Federated Learning for Image Classification](https://openreview.net/pdf?id=I1hQbx10Kxn) **pFedSD:** [Personalized Edge Intelligence via Federated Self-Knowledge Distillation](https://ieeexplore.ieee.org/abstract/document/9964434) -**FedCAC:** [Bold but Cautious: Unlocking the Potential of Personalized Federated Learning through Cautiously Aggressive Collaboration]() +**FedCAC:** [Bold but Cautious: Unlocking the Potential of Personalized Federated Learning through Cautiously Aggressive Collaboration](https://openaccess.thecvf.com/content/ICCV2023/papers/Wu_Bold_but_Cautious_Unlocking_the_Potential_of_Personalized_Federated_Learning_ICCV_2023_paper.pdf) ## Feedback and Contribution diff --git a/argzoo/cifar10/cifar10_fedamp_resnet_config.py b/argzoo/cifar10/cifar10_fedamp_resnet_config.py new file mode 100644 index 0000000..63fde50 --- /dev/null +++ b/argzoo/cifar10/cifar10_fedamp_resnet_config.py @@ -0,0 +1,51 @@ +from easydict import EasyDict + +lamda = 1 +alphaK = 10000 +sigma = 100 +decay_rate = 0.1 +decay_frequency = 30 +noniid = 'dirichlet' +alpha = 0.1 +seed = 2 + + +exp_args = dict( + data=dict( + dataset='cifar10', + data_path='./data/CIFAR10', + sample_method=dict(name=noniid, alpha=alpha, train_num=500, test_num=100) + ), + learn=dict( + device='cuda:0', + local_eps=5, + global_eps=300, + batch_size=100, + optimizer=dict(name='sgd', lr=0.1, momentum=0.9), + finetune_parameters=dict(name='all'), + test_place=['after_aggregation'], + lamda=lamda, # 
regularization weight for FedAMP + alphaK=alphaK, # lambda/sqrt(GLOBAL-ITERATION) according to the paper + sigma=sigma, # hyperparameter in function A + decay_rate=decay_rate, # decay rate of alphaK in FedAMP + decay_frequency=decay_frequency, # decay frequency of alphaK in FedAMP + ), + model=dict( + name='resnet8', + input_channel=3, + class_number=10, + ), + client=dict(name='fedamp_client', client_num=40), + server=dict(name='base_server'), + group=dict( + name='fedamp_group', + ), + other=dict(test_freq=1, logging_path=f'./logging/cifar10_fedamp_resnet/{noniid}_{alpha}/{lamda}_{alphaK}_{sigma}/{seed}') +) + +exp_args = EasyDict(exp_args) + +if __name__ == '__main__': + from fling.pipeline import personalized_model_pipeline + + personalized_model_pipeline(exp_args, seed=seed) diff --git a/argzoo/default_config.py b/argzoo/default_config.py index 77da2d2..ece1722 100644 --- a/argzoo/default_config.py +++ b/argzoo/default_config.py @@ -37,8 +37,6 @@ name='sgd', # Learning rate of the optimizer. lr=0.02, - # Momentum of the SGD optimizer. - momentum=0.9 ), # Learning rate scheduler. For each global epoch, use a dynamic learning rate. 
scheduler=dict( diff --git a/fling/component/client/__init__.py b/fling/component/client/__init__.py index a7e0736..125f94d 100644 --- a/fling/component/client/__init__.py +++ b/fling/component/client/__init__.py @@ -5,3 +5,4 @@ from .fedcac_client import FedCACClient from .fedrod_client import FedRoDClient from .fedprox_client import FedProxClient +from .fedamp_client import FedAMPClient diff --git a/fling/component/client/fedamp_client.py b/fling/component/client/fedamp_client.py new file mode 100644 index 0000000..29b4e29 --- /dev/null +++ b/fling/component/client/fedamp_client.py @@ -0,0 +1,54 @@ +import copy +import numpy as np +from typing import Tuple + +import torch +import torch.nn as nn + +from fling.utils.registry_utils import CLIENT_REGISTRY +from .base_client import BaseClient +from fling.utils.utils import weight_flatten + +@CLIENT_REGISTRY.register('fedamp_client') +class FedAMPClient(BaseClient): + """ + Overview: + This class is the base implementation of client in 'Personalized Cross-Silo Federated Learning on Non-IID + Data' (FedAMP). + """ + + def __init__(self, args, client_id, train_dataset, test_dataset=None): + """ + Initializing train dataset, test dataset(for personalized settings). 
+ """ + super(FedAMPClient, self).__init__(args, client_id, train_dataset, test_dataset) + self.client_u = copy.deepcopy(self.model) + + def FedAMP_Loss_client(self): + params = weight_flatten(self.model) + params_ = weight_flatten(self.client_u) + sub = params - params_ + result = self.args.learn.lamda / (2 * self.args.learn.alphaK) * torch.dot(sub, sub) + return result + + def train_step(self, batch_data, criterion, monitor, optimizer): + self.client_u.to(self.device) + + batch_x, batch_y = batch_data['x'], batch_data['y'] + o = self.model(batch_x) + loss = criterion(o, batch_y) + loss = loss + self.FedAMP_Loss_client() + y_pred = torch.argmax(o, dim=-1) + + monitor.append( + { + 'train_acc': torch.mean((y_pred == batch_y).float()).item(), + 'train_loss': loss.item() + }, + weight=batch_y.shape[0] + ) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + self.client_u.to('cpu') \ No newline at end of file diff --git a/fling/component/group/__init__.py b/fling/component/group/__init__.py index 46c1d4b..6a71d11 100644 --- a/fling/component/group/__init__.py +++ b/fling/component/group/__init__.py @@ -1,3 +1,4 @@ from .base_group import ParameterServerGroup from .build_group import get_group from .fedcac_group import FedCACServerGroup +from .fedamp_group import FedAMPServerGroup \ No newline at end of file diff --git a/fling/component/group/fedamp_group.py b/fling/component/group/fedamp_group.py new file mode 100644 index 0000000..e2a5f4a --- /dev/null +++ b/fling/component/group/fedamp_group.py @@ -0,0 +1,105 @@ +import time +import copy +import torch + +from fling.utils.compress_utils import * +from fling.utils.registry_utils import GROUP_REGISTRY +from fling.utils import Logger +from fling.component.group import ParameterServerGroup +from fling.utils.utils import weight_flatten + +@GROUP_REGISTRY.register('fedamp_group') +class FedAMPServerGroup(ParameterServerGroup): + r""" + Overview: + Implementation of the group in FedAMP. 
+ """ + + def __init__(self, args: dict, logger: Logger): + super(FedAMPServerGroup, self).__init__(args, logger) + # FedAMP arguments + self.alphaK = args.learn.alphaK + self.sigma = args.learn.sigma + self.lamda = args.learn.lamda + self.client_ws = [None for i in range(self.args.client.client_num)] # maintain all clients' personalized models + self.client_us = [None for i in range(self.args.client.client_num)] # aggregated model for each client + + def sync(self) -> None: + r""" + Overview: + Send customized global models to each client + """ + if self.client_us[0] is None: + super().sync() # Called during system initialization + + else: + for idx, client in enumerate(self.clients): + client.client_u = copy.deepcopy(self.client_us[idx]) + + def initialize(self) -> None: + super().initialize() + self.client_ws = [copy.deepcopy(self.clients[i].model) for i in range(self.args.client.client_num)] + self.client_us = [copy.deepcopy(self.clients[i].model) for i in range(self.args.client.client_num)] + + + def receive_models(self): + r""" + Overview: + Receive personalized models from each client + """ + for idx, client in enumerate(self.clients): + self.client_ws[idx] = copy.deepcopy(client.model) + + def aggregate(self, train_round: int) -> int: + r""" + Overview: + Aggregate customized global models and send to each client + """ + # receive models from clients + self.receive_models() + for i in range(self.args.client.client_num): + self.client_ws[i].to(self.args.learn.device) + self.client_us[i].to(self.args.learn.device) + + # aggregate models + weights = [weight_flatten(mw) for mw in self.client_ws] + for i, mu in enumerate(self.client_us): # calculate u for each client + for param in mu.parameters(): # set zero for each parameter + param.data = torch.zeros_like(param.data) + + coef = torch.zeros(self.args.client.client_num) + for j, mw in enumerate(self.client_ws): + if i == j: continue + sub = weights[i] - weights[j] + sub = torch.dot(sub, sub) + coef[j] = 
self.args.learn.alphaK * self.e(sub) + coef[i] = 1 - torch.sum(coef) + + for j, mw in enumerate(self.client_ws): + for param, param_j in zip(mu.parameters(), mw.parameters()): + param.data += coef[j] * param_j + + for i in range(self.args.client.client_num): + self.client_ws[i].to('cpu') + self.client_us[i].to('cpu') + + # send to all clients + self.sync() + + # perform alphaK decay + if train_round % self.args.learn.decay_frequency == 0 and train_round != 0: + self.args.learn.alphaK *= self.args.learn.decay_rate + + # calculate communication cost + trans_cost = 0 + state_dict = self.clients[0].model.state_dict() + for k in self.clients[0].fed_keys: + trans_cost += self.args.client.client_num * state_dict[k].numel() + # convert element count to bytes, assuming 32-bit (4-byte) elements + return 4 * trans_cost + def e(self, x): + r""" + Overview: + The derivative of attention-inducing function in FedAMP + """ + return torch.exp(-x/self.sigma)/self.sigma diff --git a/fling/utils/utils.py b/fling/utils/utils.py index 01c7351..1a0417f 100644 --- a/fling/utils/utils.py +++ b/fling/utils/utils.py @@ -1,6 +1,7 @@ import numpy as np import os import time +import torch from typing import Iterable, Dict, List from prettytable import PrettyTable @@ -14,6 +15,13 @@ def client_sampling(client_ids: Iterable, sample_rate: float) -> List: ) return participated_clients +def weight_flatten(model) -> torch.Tensor: + params = [] + for u in model.parameters(): + params.append(u.view(-1)) + params = torch.cat(params) + + return params class Logger(SummaryWriter):