Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,13 @@ For other algorithms and datasets, users can refer to `argzoo/` or customize you

**FedBN:** [FedBN: Federated Learning on Non-IID Features via Local Batch Normalization](https://arxiv.org/pdf/2102.07623.pdf)

**FedAMP:** [Personalized Cross-Silo Federated Learning on Non-IID Data](https://ojs.aaai.org/index.php/AAAI/article/view/16960)

**FedRoD:** [On Bridging Generic and Personalized Federated Learning for Image Classification](https://openreview.net/pdf?id=I1hQbx10Kxn)

**pFedSD:** [Personalized Edge Intelligence via Federated Self-Knowledge Distillation](https://ieeexplore.ieee.org/abstract/document/9964434)

**FedCAC:** [Bold but Cautious: Unlocking the Potential of Personalized Federated Learning through Cautiously Aggressive Collaboration]()
**FedCAC:** [Bold but Cautious: Unlocking the Potential of Personalized Federated Learning through Cautiously Aggressive Collaboration](https://openaccess.thecvf.com/content/ICCV2023/papers/Wu_Bold_but_Cautious_Unlocking_the_Potential_of_Personalized_Federated_Learning_ICCV_2023_paper.pdf)

## Feedback and Contribution

Expand Down
51 changes: 51 additions & 0 deletions argzoo/cifar10/cifar10_fedamp_resnet_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from easydict import EasyDict

# FedAMP-specific hyper-parameters.
lamda = 1  # weight of the proximal regularization term
alphaK = 10000  # original note: lambda / sqrt(global iteration) according to the paper
sigma = 100  # hyper-parameter of the attention-inducing function
decay_rate = 0.1  # factor applied to alphaK at every decay
decay_frequency = 30  # number of global rounds between two alphaK decays

# Data partition and reproducibility settings.
noniid = 'dirichlet'
alpha = 0.1
seed = 2

# Full experiment description, wrapped so that entries can be read as attributes.
exp_args = EasyDict({
    'data': {
        'dataset': 'cifar10',
        'data_path': './data/CIFAR10',
        'sample_method': {'name': noniid, 'alpha': alpha, 'train_num': 500, 'test_num': 100},
    },
    'learn': {
        'device': 'cuda:0',
        'local_eps': 5,
        'global_eps': 300,
        'batch_size': 100,
        'optimizer': {'name': 'sgd', 'lr': 0.1, 'momentum': 0.9},
        'finetune_parameters': {'name': 'all'},
        'test_place': ['after_aggregation'],
        # FedAMP entries, read by the FedAMP client and group through ``args.learn``.
        'lamda': lamda,
        'alphaK': alphaK,
        'sigma': sigma,
        'decay_rate': decay_rate,
        'decay_frequency': decay_frequency,
    },
    'model': {
        'name': 'resnet8',
        'input_channel': 3,
        'class_number': 10,
    },
    'client': {'name': 'fedamp_client', 'client_num': 40},
    'server': {'name': 'base_server'},
    'group': {
        'name': 'fedamp_group',
    },
    'other': {
        'test_freq': 1,
        'logging_path': f'./logging/cifar10_fedamp_resnet/{noniid}_{alpha}/{lamda}_{alphaK}_{sigma}/{seed}',
    },
})

if __name__ == '__main__':
    # Imported lazily so that importing this config does not pull in the pipeline.
    from fling.pipeline import personalized_model_pipeline

    personalized_model_pipeline(exp_args, seed=seed)
2 changes: 0 additions & 2 deletions argzoo/default_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
name='sgd',
# Learning rate of the optimizer.
lr=0.02,
# Momentum of the SGD optimizer.
momentum=0.9
),
# Learning rate scheduler. For each global epoch, use a dynamic learning rate.
scheduler=dict(
Expand Down
1 change: 1 addition & 0 deletions fling/component/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from .fedcac_client import FedCACClient
from .fedrod_client import FedRoDClient
from .fedprox_client import FedProxClient
from .fedamp_client import FedAMPClient
54 changes: 54 additions & 0 deletions fling/component/client/fedamp_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import copy
import numpy as np
from typing import Tuple

import torch
import torch.nn as nn

from fling.utils.registry_utils import CLIENT_REGISTRY
from .base_client import BaseClient
from fling.utils.utils import weight_flatten

@CLIENT_REGISTRY.register('fedamp_client')
class FedAMPClient(BaseClient):
    r"""
    Overview:
        This class is the base implementation of client in 'Personalized Cross-Silo Federated Learning on
        Non-IID Data' (FedAMP).
        On top of the task loss, every training step adds a proximal penalty that pulls the local model
        towards ``client_u``, a reference model kept per client.
    """

    # NOTE(review): a reviewer asked whether this client is identical to FedProxClient (not visible in this
    # file). Here the proximal reference is the per-client ``client_u`` - confirm and document the difference.

    def __init__(self, args, client_id, train_dataset, test_dataset=None):
        r"""
        Overview:
            Initialize the client through ``BaseClient`` and keep a deep copy of the local model as the
            initial proximal reference ``client_u``.
        Arguments:
            - args: experiment configuration; ``args.learn.lamda`` and ``args.learn.alphaK`` are read when \
                computing the proximal penalty.
            - client_id: identifier of this client, forwarded to ``BaseClient``.
            - train_dataset: local training dataset, forwarded to ``BaseClient``.
            - test_dataset: optional local test dataset, forwarded to ``BaseClient``.
        """
        super(FedAMPClient, self).__init__(args, client_id, train_dataset, test_dataset)
        self.client_u = copy.deepcopy(self.model)

    def FedAMP_Loss_client(self):
        r"""
        Overview:
            Compute the proximal penalty ``lamda / (2 * alphaK) * ||w - u||^2``, where ``w`` are the flattened
            parameters of the local model and ``u`` those of ``client_u``.
        Returns:
            - penalty: 0-dim tensor that is differentiable w.r.t. the local model only.
        """
        # NOTE(review): a reviewer suggested that get_model_difference (fling/utils/torch_utils.py) could
        # simplify this computation.
        local_weights = weight_flatten(self.model)
        # The reference model is a constant of the local objective: detach it so that backward() neither
        # traverses it nor accumulates gradients on its parameters.
        reference_weights = weight_flatten(self.client_u).detach()
        diff = local_weights - reference_weights
        scale = self.args.learn.lamda / (2 * self.args.learn.alphaK)
        return scale * torch.dot(diff, diff)

    def train_step(self, batch_data, criterion, monitor, optimizer):
        r"""
        Overview:
            Run one optimization step on ``batch_data`` using the task loss plus the FedAMP proximal penalty.
        Arguments:
            - batch_data: mapping holding the inputs under ``'x'`` and the labels under ``'y'``.
            - criterion: callable computing the task loss from ``(output, label)``.
            - monitor: collector receiving the accuracy and loss of this step, weighted by the batch size.
            - optimizer: optimizer that updates the local model.
        """
        # The reference model is moved to the training device only for the duration of this step.
        self.client_u.to(self.device)

        batch_x, batch_y = batch_data['x'], batch_data['y']
        o = self.model(batch_x)
        loss = criterion(o, batch_y)
        loss = loss + self.FedAMP_Loss_client()
        y_pred = torch.argmax(o, dim=-1)

        # The reported loss includes the proximal penalty.
        monitor.append(
            {
                'train_acc': torch.mean((y_pred == batch_y).float()).item(),
                'train_loss': loss.item()
            },
            weight=batch_y.shape[0]
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        self.client_u.to('cpu')
1 change: 1 addition & 0 deletions fling/component/group/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .base_group import ParameterServerGroup
from .build_group import get_group
from .fedcac_group import FedCACServerGroup
from .fedamp_group import FedAMPServerGroup
105 changes: 105 additions & 0 deletions fling/component/group/fedamp_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import time
import copy
import torch

from fling.utils.compress_utils import *
from fling.utils.registry_utils import GROUP_REGISTRY
from fling.utils import Logger
from fling.component.group import ParameterServerGroup
from fling.utils.utils import weight_flatten

@GROUP_REGISTRY.register('fedamp_group')
class FedAMPServerGroup(ParameterServerGroup):
    r"""
    Overview:
        Implementation of the group in FedAMP.
        The group stores, for every client, a copy of its personalized model (``client_ws``) and an
        aggregated model (``client_us``) built as a weighted combination of all personalized models.
    """

    def __init__(self, args: dict, logger: Logger):
        r"""
        Overview:
            Initialize the group and read the FedAMP hyper-parameters from ``args.learn``.
        Arguments:
            - args: experiment configuration.
            - logger: logger forwarded to ``ParameterServerGroup``.
        """
        super(FedAMPServerGroup, self).__init__(args, logger)
        # FedAMP arguments
        self.alphaK = args.learn.alphaK
        self.sigma = args.learn.sigma
        self.lamda = args.learn.lamda
        client_num = self.args.client.client_num
        self.client_ws = [None] * client_num  # maintain all clients' personalized models
        self.client_us = [None] * client_num  # aggregated model for each client

    def sync(self) -> None:
        r"""
        Overview:
            Send customized global models to each client. While no aggregated model exists yet, fall back
            to the synchronization of the parent class.
        """
        if self.client_us[0] is None:
            super().sync()  # Called during system initialization
        else:
            for idx, client in enumerate(self.clients):
                client.client_u = copy.deepcopy(self.client_us[idx])

    def initialize(self) -> None:
        r"""
        Overview:
            Run the initialization of the parent class, then seed both ``client_ws`` and ``client_us``
            with deep copies of each client's current model.
        """
        super().initialize()
        client_num = self.args.client.client_num
        self.client_ws = [copy.deepcopy(self.clients[i].model) for i in range(client_num)]
        self.client_us = [copy.deepcopy(self.clients[i].model) for i in range(client_num)]

    def receive_models(self):
        r"""
        Overview:
            Receive personalized models from each client
        """
        for idx, client in enumerate(self.clients):
            self.client_ws[idx] = copy.deepcopy(client.model)

    def aggregate(self, train_round: int) -> int:
        r"""
        Overview:
            Aggregate customized global models and send to each client
        Arguments:
            - train_round: index of the current global round, used to schedule the decay of ``alphaK``.
        Returns:
            - trans_cost: communication cost, counting 4 bytes for every element of the ``fed_keys`` \
                entries of the model and multiplying by the number of clients.
        """
        client_num = self.args.client.client_num
        device = self.args.learn.device

        # receive models from clients
        self.receive_models()
        for i in range(client_num):
            self.client_ws[i].to(device)
            self.client_us[i].to(device)

        # aggregate models
        # The aggregation is plain parameter arithmetic: disable autograd so that no graph is recorded.
        with torch.no_grad():
            weights = [weight_flatten(mw) for mw in self.client_ws]
            for i, mu in enumerate(self.client_us):  # calculate u for each client
                # Keep the coefficients on the same device as the weights to avoid cross-device copies.
                coef = torch.zeros(client_num, device=weights[i].device)
                for j, other in enumerate(weights):
                    if i == j:
                        continue
                    # NOTE(review): a reviewer asked to rewrite this using fling.utils.get_model_difference.
                    diff = weights[i] - other
                    coef[j] = self.args.learn.alphaK * self.e(torch.dot(diff, diff))
                # The client's own model receives the remaining mass so that the coefficients sum to 1.
                coef[i] = 1 - torch.sum(coef)

                for param in mu.parameters():  # set zero for each parameter
                    param.data = torch.zeros_like(param.data)
                for j, mw in enumerate(self.client_ws):
                    for param, param_j in zip(mu.parameters(), mw.parameters()):
                        param.data += coef[j] * param_j

        for i in range(client_num):
            self.client_ws[i].to('cpu')
            self.client_us[i].to('cpu')

        # send to all clients
        self.sync()

        # perform alphaK decay
        if train_round % self.args.learn.decay_frequency == 0 and train_round != 0:
            self.args.learn.alphaK *= self.args.learn.decay_rate
            # Keep the value cached in __init__ consistent with the configuration.
            self.alphaK = self.args.learn.alphaK

        # calculate communication cost
        state_dict = self.clients[0].model.state_dict()
        trans_cost = sum(client_num * state_dict[k].numel() for k in self.clients[0].fed_keys)
        # every element is counted as 4 bytes (32 bit)
        return 4 * trans_cost

    def e(self, x):
        r"""
        Overview:
            The derivative of attention-inducing function in FedAMP
        Arguments:
            - x: tensor holding the squared distance between two flattened models.
        Returns:
            - value: ``exp(-x / sigma) / sigma``.
        """
        return torch.exp(-x / self.sigma) / self.sigma
8 changes: 8 additions & 0 deletions fling/utils/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import os
import time
import torch
from typing import Iterable, Dict, List
from prettytable import PrettyTable

Expand All @@ -14,6 +15,13 @@ def client_sampling(client_ids: Iterable, sample_rate: float) -> List:
)
return participated_clients

def weight_flatten(model) -> torch.Tensor:
    """
    Overview:
        Concatenate every parameter of ``model`` into a single 1-dim tensor, following the order of
        ``model.parameters()``.
    Arguments:
        - model: object exposing ``parameters()`` that yields tensors, e.g. a ``torch.nn.Module``.
    Returns:
        - params: 1-dim tensor holding all parameter values; an empty tensor if there are no parameters.
    """
    # NOTE(review): a reviewer suggested that this helper could be removed - confirm before adding callers.
    # ``reshape`` behaves like ``view`` for contiguous tensors and also supports non-contiguous parameters.
    flat = [u.reshape(-1) for u in model.parameters()]
    if not flat:
        # ``torch.cat`` rejects an empty list, so a parameter-less model maps to an empty tensor.
        return torch.empty(0)
    return torch.cat(flat)

class Logger(SummaryWriter):

Expand Down