PyTorch implementations of some federated learning methods based on sharpness-aware minimization.
This repository is built on PFLlib and FL-Simulator; thanks to their clear and well-organized code architectures.
Some optimizations were made to reduce runtime and GPU memory. For example, running FedAvg on CIFAR-10 with 100 clients and 10% active clients per round on a single NVIDIA 4090 GPU costs:
- Using a two-layer CNN: only about 2.72 s per round and 0.93 GB of GPU memory.
- Using ResNet18: only about 20 s per round and 1.86 GB of GPU memory.
- FedAvg — Communication-Efficient Learning of Deep Networks from Decentralized Data, AISTATS 2017
- FedDyn — Federated Learning Based on Dynamic Regularization, ICLR 2021
- FedSAM — Generalized Federated Learning via Sharpness Aware Minimization, ICML 2022
- FedSpeed — FedSpeed: Larger Local Interval, Less Communication Round, and Higher Generalization Accuracy, ICLR 2023
- FedSMOO — Dynamic Regularized Sharpness Aware Minimization in Federated Learning: Approaching Global Consistency and Smooth Landscape, ICML 2023
- FedLESAM — Locally Estimated Global Perturbations are Better than Local Perturbations for Federated Sharpness-aware Minimization, ICML 2024
- FedGMT — One Arrow, Two Hawks: Sharpness-aware Minimization for Federated Learning via Global Model Trajectory, ICML 2025
We show some results on CIFAR-10 with 100 clients and 10% active clients per round after 500 communication rounds. The corresponding hyperparameters are stated below.
CIFAR-10, CNN backbone:

| Method | IID | Dir-1.0 | Dir-0.1 | Dir-0.01 | Time / round |
|---|---|---|---|---|---|
| FedAvg | 77.71 | 75.96 | 71.68 | 63.27 | 2.72s |
| FedDyn | 77.94 | 78.08 | 76.71 | 73.06 | 3.01s |
| FedSAM | 80.68 | 78.33 | 72.27 | 63.80 | 7.64s |
| FedSpeed | 81.63 | 81.66 | 77.90 | 74.24 | 8.48s |
| FedSMOO | 81.24 | 80.98 | 78.28 | 75.28 | 9.39s |
| FedLESAM-D | 77.71 | 78.69 | 76.85 | 72.71 | 3.25s |
| FedGMT | 81.62 | 81.92 | 79.45 | 76.36 | 3.21s |

CIFAR-10, ResNet18 backbone:

| Method | IID | Dir-1.0 | Dir-0.1 | Dir-0.01 | Time / round |
|---|---|---|---|---|---|
| FedAvg | 76.74 | 73.73 | 64.34 | 50.41 | 20.10s |
| FedDyn | 78.88 | 77.89 | 74.66 | 69.41 | 20.29s |
| FedSAM | 77.71 | 74.62 | 64.38 | 48.42 | 40.38s |
| FedSpeed | 79.58 | 79.54 | 75.66 | 69.31 | 42.11s |
| FedSMOO | 79.50 | 79.35 | 75.22 | 69.60 | 43.07s |
| FedLESAM-D | 78.45 | 79.56 | 74.82 | 69.34 | 22.32s |
| FedGMT | 80.99 | 80.10 | 75.89 | 70.28 | 23.72s |
Common training hyperparameters
In the above experiments, we employ SGD with a learning rate of 0.01, momentum of 0.9, weight decay of 1e-5, a batch size of 50, and 5 local epochs.
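As a minimal sketch of these settings (the linear layer is only a stand-in for the CNN/ResNet18 backbones):

```python
import torch

# Local optimizer with the common training hyperparameters listed above;
# the Linear module is just a placeholder for the actual backbone.
model = torch.nn.Linear(32 * 32 * 3, 10)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01,            # learning rate
    momentum=0.9,       # SGD momentum
    weight_decay=1e-5,  # weight decay
)
# Each selected client then runs 5 local epochs with batch size 50 per round.
```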
Selection of some key hyperparameters
| Method | SAM perturbation | Penalty coefficient | Others |
|---|---|---|---|
| FedAvg | - | - | - |
| FedDyn | - | 10 | - |
| FedSAM | {0.001, 0.01, 0.1} | - | - |
| FedSpeed | {0.001, 0.01, 0.1} | 10 | - |
| FedSMOO | {0.001, 0.01, 0.1} | 10 | - |
| FedLESAM-D | {0.001, 0.01, 0.1} | 10 | - |
| FedGMT | - | 10 | EMA coefficient α: {0.95, 0.995, 0.998}; sharpness strength γ: {0.5, 1.0, 2.0} |
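For intuition, the SAM perturbation radius listed above (often denoted ρ) enters a two-step local update: first ascend to a perturbed point inside an L2 ball of radius ρ, then descend with the gradient taken there. The sketch below is a generic SAM step, not the exact implementation of the methods in this repository:

```python
import torch

def sam_step(model, loss_fn, x, y, optimizer, rho=0.01):
    """Generic SAM update sketch (illustrative only): perturb the weights
    along the normalized gradient direction, recompute gradients at the
    perturbed point, restore the weights, and take the descent step."""
    # 1) Gradients at the current weights.
    loss_fn(model(x), y).backward()

    # 2) Perturb each parameter by rho * g / ||g||.
    params = [p for p in model.parameters() if p.grad is not None]
    grad_norm = torch.norm(torch.stack([p.grad.norm(2) for p in params]))
    eps = []
    with torch.no_grad():
        for p in params:
            e = rho * p.grad / (grad_norm + 1e-12)
            p.add_(e)
            eps.append(e)
    optimizer.zero_grad()

    # 3) Gradients at the perturbed weights, then restore the originals.
    loss_fn(model(x), y).backward()
    with torch.no_grad():
        for p, e in zip(params, eps):
            p.sub_(e)

    # 4) Descend from the original weights using the perturbed-point gradients.
    optimizer.step()
    optimizer.zero_grad()
```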
Example commands to run FedGMT on CIFAR-10 are given below.
Please install the required packages. The code runs with Python 3.7; install the dependencies in a virtual environment via:
```bash
pip install -r requirements.txt
```
- `./dataset`
  - `utils/`: code for the heterogeneous partition strategy.
  - `generate_dataset.py`: generates clients' local datasets.
- `./system`
  - `main.py`: configurations of the methods.
  - `./flcore`
    - `./clients/clientxxx.py`: the code on the client side.
    - `./servers/serverxxx.py`: the code on the server side.
    - `./trainmodel/models.py`: the code for the backbones.
  - `./utils`
    - `mem_utils.py`: records the GPU memory usage.
    - `data_utils.py`: reads the datasets.
    - `xxx_utils.py`: algorithm-specific utilities.
To generate pathological non-IID data:
```bash
cd ./dataset
python generate_dataset.py --shard -data cifar10 -nc 100 -shard_per_user 2  # Path(2)
```
To generate Dirichlet non-IID data:
```bash
cd ./dataset
python generate_dataset.py --LDA -data cifar10 -nc 100 -noniid 0.1 -if 0.5  # Dir(0.1) with long tail
```
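For intuition, Dirichlet (LDA) partitioning draws, for every class, a vector of client proportions from Dir(α) and splits that class's samples accordingly; smaller α gives more skewed local datasets. The sketch below only illustrates the idea (it is not the code in `./dataset/utils` and ignores the long-tail imbalance factor `-if`):

```python
import numpy as np

def dirichlet_partition(labels, num_clients=100, alpha=0.1, seed=1):
    """Illustrative Dirichlet label partition: for each class, split its
    sample indices across clients with proportions drawn from Dir(alpha)."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]

    for c in np.unique(labels):
        idx_c = np.where(labels == c)[0]
        rng.shuffle(idx_c)
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        # Turn the proportions into split points over this class's indices.
        cuts = (np.cumsum(proportions)[:-1] * len(idx_c)).astype(int)
        for client_id, part in enumerate(np.split(idx_c, cuts)):
            client_indices[client_id].extend(part.tolist())

    return client_indices
```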
To train FedAvg on CIFAR-10:
```bash
cd ./system
python main.py -algo FedAvg -data cifar10 -dev cuda --seed 1 -lr 0.01 -gr 500 -lbs 50 -le 5 -jr 0.1 -nc 100
```
To train FedGMT on CIFAR-10 with the hyperparameters above:
```bash
cd ./system
python main.py -algo FedGMT -data cifar10 -dev cuda --seed 1 -lr 0.01 -gr 500 -lbs 50 -le 5 -jr 0.1 -nc 100 -ga 1.0 -al 0.95 -tau 3.0 -be 10
```
To add a new algorithm, extend the base classes Server and Client, which are defined in `./system/flcore/servers/serverbase.py` and `./system/flcore/clients/clientbase.py`, respectively.
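A minimal skeleton might look like the sketch below. Only the Server and Client base classes come from the repository; the constructor signatures, hook names, and the "FedNew" naming are assumptions that should be replaced with the actual API found in `serverbase.py`, `clientbase.py`, and an existing `serverxxx.py`/`clientxxx.py` pair:

```python
# Hypothetical skeleton for a new method "FedNew"; copy the real constructor
# signatures and helper calls from the base classes and an existing method.
from flcore.clients.clientbase import Client
from flcore.servers.serverbase import Server


class clientNew(Client):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def train(self):  # assumed hook name
        # Local update rule of the new algorithm, e.g. a SAM-style
        # perturbation before each SGD step.
        ...


class serverNew(Server):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def train(self):  # assumed hook name
        # Per-round loop: sample clients, broadcast the global model,
        # run local training, and aggregate the returned models.
        ...
```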
If this codebase helps you, please cite our paper:
```bibtex
@inproceedings{li2025one,
  title={One Arrow, Two Hawks: Sharpness-aware Minimization for Federated Learning via Global Model Trajectory},
  author={Li, Yuhang and Liu, Tong and Cui, Yangguang and Hu, Ming and Li, Xiaoqiang},
  booktitle={International Conference on Machine Learning (ICML)},
  year={2025}
}
```