-
Notifications
You must be signed in to change notification settings - Fork 16
Feature(wxh): Add FedAMP algo and fix bugs. #25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| from easydict import EasyDict | ||
|
|
||
| lamda = 1 | ||
| alphaK = 10000 | ||
| sigma = 100 | ||
| decay_rate = 0.1 | ||
| decay_frequency = 30 | ||
| noniid = 'dirichlet' | ||
| alpha = 0.1 | ||
| seed = 2 | ||
|
|
||
|
|
||
| exp_args = dict( | ||
| data=dict( | ||
| dataset='cifar10', | ||
| data_path='./data/CIFAR10', | ||
| sample_method=dict(name=noniid, alpha=alpha, train_num=500, test_num=100) | ||
| ), | ||
| learn=dict( | ||
| device='cuda:0', | ||
| local_eps=5, | ||
| global_eps=300, | ||
| batch_size=100, | ||
| optimizer=dict(name='sgd', lr=0.1, momentum=0.9), | ||
| finetune_parameters=dict(name='all'), | ||
| test_place=['after_aggregation'], | ||
| lamda=lamda, # regularization weight for FedAMP | ||
| alphaK=alphaK, # lambda/sqrt(GLOABL-ITRATION) according to the paper | ||
| sigma=sigma, # hyperparameter in function A | ||
| decay_rate=decay_rate, # decay rate of alphaK in FedAMP | ||
| decay_frequency=decay_frequency, # decay frequency of alphaK in FedAMP | ||
| ), | ||
| model=dict( | ||
| name='resnet8', | ||
| input_channel=3, | ||
| class_number=10, | ||
| ), | ||
| client=dict(name='fedamp_client', client_num=40), | ||
| server=dict(name='base_server'), | ||
| group=dict( | ||
| name='fedamp_group', | ||
| ), | ||
| other=dict(test_freq=1, logging_path=f'./logging/cifar10_fedamp_resnet/{noniid}_{alpha}/{lamda}_{alphaK}_{sigma}/{seed}') | ||
| ) | ||
|
|
||
| exp_args = EasyDict(exp_args) | ||
|
|
||
| if __name__ == '__main__': | ||
| from fling.pipeline import personalized_model_pipeline | ||
|
|
||
| personalized_model_pipeline(exp_args, seed=seed) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| import copy | ||
| import numpy as np | ||
| from typing import Tuple | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
|
|
||
| from fling.utils.registry_utils import CLIENT_REGISTRY | ||
| from .base_client import BaseClient | ||
| from fling.utils.utils import weight_flatten | ||
|
|
||
| @CLIENT_REGISTRY.register('fedamp_client') | ||
| class FedAMPClient(BaseClient): | ||
| """ | ||
| Overview: | ||
| This class is the base implementation of client in 'Bold but Cautious: Unlocking the Potential of Personalized | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct this document, not FedCAC. |
||
| Federated Learning through Cautiously Aggressive Collaboration' (FedCAC). | ||
| """ | ||
|
|
||
| def __init__(self, args, client_id, train_dataset, test_dataset=None): | ||
| """ | ||
| Initializing train dataset, test dataset(for personalized settings). | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct this document, the purpose is to get a copy of local model. |
||
| """ | ||
| super(FedAMPClient, self).__init__(args, client_id, train_dataset, test_dataset) | ||
| self.client_u = copy.deepcopy(self.model) | ||
|
|
||
| def FedAMP_Loss_client(self): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can use the |
||
| params = weight_flatten(self.model) | ||
| params_ = weight_flatten(self.client_u) | ||
| sub = params - params_ | ||
| result = self.args.learn.lamda / (2 * self.args.learn.alphaK) * torch.dot(sub, sub) | ||
| return result | ||
|
|
||
| def train_step(self, batch_data, criterion, monitor, optimizer): | ||
| self.client_u.to(self.device) | ||
|
|
||
| batch_x, batch_y = batch_data['x'], batch_data['y'] | ||
| o = self.model(batch_x) | ||
| loss = criterion(o, batch_y) | ||
| loss = loss + self.FedAMP_Loss_client() | ||
| y_pred = torch.argmax(o, dim=-1) | ||
|
|
||
| monitor.append( | ||
| { | ||
| 'train_acc': torch.mean((y_pred == batch_y).float()).item(), | ||
| 'train_loss': loss.item() | ||
| }, | ||
| weight=batch_y.shape[0] | ||
| ) | ||
| optimizer.zero_grad() | ||
| loss.backward() | ||
| optimizer.step() | ||
|
|
||
| self.client_u.to('cpu') | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| from .base_group import ParameterServerGroup | ||
| from .build_group import get_group | ||
| from .fedcac_group import FedCACServerGroup | ||
| from .fedamp_group import FedAMPServerGroup |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,105 @@ | ||
| import time | ||
| import copy | ||
| import torch | ||
|
|
||
| from fling.utils.compress_utils import * | ||
| from fling.utils.registry_utils import GROUP_REGISTRY | ||
| from fling.utils import Logger | ||
| from fling.component.group import ParameterServerGroup | ||
| from fling.utils.utils import weight_flatten | ||
|
|
||
| @GROUP_REGISTRY.register('fedamp_group') | ||
| class FedAMPServerGroup(ParameterServerGroup): | ||
| r""" | ||
| Overview: | ||
| Implementation of the group in FedAMP. | ||
| """ | ||
|
|
||
| def __init__(self, args: dict, logger: Logger): | ||
| super(FedAMPServerGroup, self).__init__(args, logger) | ||
| # FedAMP auguments | ||
| self.alphaK = args.learn.alphaK | ||
| self.sigma = args.learn.sigma | ||
| self.lamda = args.learn.lamda | ||
| self.client_ws = [None for i in range(self.args.client.client_num)] # maintain all clients' personalized models | ||
| self.client_us = [None for i in range(self.args.client.client_num)] # aggregated model for each client | ||
|
|
||
| def sync(self) -> None: | ||
| r""" | ||
| Overview: | ||
| Send customized global models to each client | ||
| """ | ||
| if self.client_us[0] is None: | ||
| super().sync() # Called during system initialization | ||
|
|
||
| else: | ||
| for idx, client in enumerate(self.clients): | ||
| client.client_u = copy.deepcopy(self.client_us[idx]) | ||
|
|
||
| def initialize(self) -> None: | ||
| super().initialize() | ||
| self.client_ws = [copy.deepcopy(self.clients[i].model) for i in range(self.args.client.client_num)] | ||
| self.client_us = [copy.deepcopy(self.clients[i].model) for i in range(self.args.client.client_num)] | ||
|
|
||
|
|
||
| def receive_models(self): | ||
| r""" | ||
| Overview: | ||
| Receive personalized models from each client | ||
| """ | ||
| for idx, client in enumerate(self.clients): | ||
| self.client_ws[idx] = copy.deepcopy(client.model) | ||
|
|
||
| def aggregate(self, train_round: int) -> int: | ||
| r""" | ||
| Overview: | ||
| Aggregate customized global models and send to each client | ||
| """ | ||
| # recieve models from clients | ||
| self.receive_models() | ||
| for i in range(self.args.client.client_num): | ||
| self.client_ws[i].to(self.args.learn.device) | ||
| self.client_us[i].to(self.args.learn.device) | ||
|
|
||
| # aggregate models | ||
| weights = [weight_flatten(mw) for mw in self.client_ws] | ||
| for i, mu in enumerate(self.client_us): # calculate u for each client | ||
| for param in mu.parameters(): # set zero for each parameter | ||
| param.data = torch.zeros_like(param.data) | ||
|
|
||
| coef = torch.zeros(self.args.client.client_num) | ||
| for j, mw in enumerate(self.client_ws): | ||
| if i == j: continue | ||
| sub = weights[i] - weights[j] | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rewrite it using fling.utils.get_model_difference |
||
| sub = torch.dot(sub, sub) | ||
| coef[j] = self.args.learn.alphaK * self.e(sub) | ||
| coef[i] = 1 - torch.sum(coef) | ||
|
|
||
| for j, mw in enumerate(self.client_ws): | ||
| for param, param_j in zip(mu.parameters(), mw.parameters()): | ||
| param.data += coef[j] * param_j | ||
|
|
||
| for i in range(self.args.client.client_num): | ||
| self.client_ws[i].to('cpu') | ||
| self.client_us[i].to('cpu') | ||
|
|
||
| # send to all clients | ||
| self.sync() | ||
|
|
||
| # perform alphaK decay | ||
| if train_round % self.args.learn.decay_frequency == 0 and train_round != 0: | ||
| self.args.learn.alphaK *= self.args.learn.decay_rate | ||
|
|
||
| # calculate communication cost | ||
| trans_cost = 0 | ||
| state_dict = self.clients[0].model.state_dict() | ||
|
kxzxvbk marked this conversation as resolved.
|
||
| for k in self.clients[0].fed_keys: | ||
| trans_cost += self.args.client.client_num * state_dict[k].numel() | ||
| # 1B = 32bit | ||
| return 4 * trans_cost | ||
| def e(self, x): | ||
| r""" | ||
| Overview: | ||
| The derivative of attention-inducing function in FedAMP | ||
| """ | ||
| return torch.exp(-x/self.sigma)/self.sigma | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| import numpy as np | ||
| import os | ||
| import time | ||
| import torch | ||
| from typing import Iterable, Dict, List | ||
| from prettytable import PrettyTable | ||
|
|
||
|
|
@@ -14,6 +15,13 @@ def client_sampling(client_ids: Iterable, sample_rate: float) -> List: | |
| ) | ||
| return participated_clients | ||
|
|
||
| def weight_flatten(model) -> torch.Tensor: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suppose that this function can be removed. |
||
| params = [] | ||
| for u in model.parameters(): | ||
| params.append(u.view(-1)) | ||
| params = torch.cat(params) | ||
|
|
||
| return params | ||
|
|
||
| class Logger(SummaryWriter): | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering whether this client is identical to
FedProxClient? What's the differences?