#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import copy
import datetime
import json
import os
import random
import shutil
import sys
import time
import warnings
from hps import *
import math
import torchvision.models as models
import numpy as np
from tqdm import tqdm
import pdb
from helpers.datasets import partition_data
from helpers.utils import get_dataset, mean_average_weights, DatasetSplit, KLDiv, setup_seed, test, \
    federated_average_weights
from loop_df_fl import get_model, LocalUpdate, Ensemble
from models.nets import CNNCifar, CNNMnist, CNNCifar100, CNNPACS
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from torch.utils.data.dataset import random_split
from models.resnet import resnet18
from models.vit import deit_tiny_patch16_224
from warmup_config import warmup_config
# import wandb
from commandline_config import Config
warnings.filterwarnings('ignore')
upsample = torch.nn.Upsample(mode='nearest', scale_factor=7)
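# preset_config holds the default experiment settings. Config (from commandline_config)
# builds the runtime configuration from this dict; the package is assumed to allow
# overriding individual keys on the command line (e.g. something like
# `python FedELMY.py --dataset cifar10 --seed 2`), but the exact flag syntax is that of
# commandline_config and is not guaranteed here.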
preset_config = {
    "warmup_epochs": -1,  # Number of warm-up epochs before hyperparameter-varied training; -1 uses the default from warmup_config
    "shots": 1,  # Number of shots
    "num_users": 10,  # Number of users: K
    "num_classes": 10,  # Number of classes
    "num_models": 5,  # Number of models per user for the model pool
    "frac": 1,  # The fraction of clients: C
    "local_ep": -1,  # The number of local epochs: E
    "max_hp_count": 9999,  # Maximum number of hyperparameter configurations to try
    "local_bs": 128,  # Local batch size: B
    "lr": 0.01,  # Learning rate
    "image_size": -1,
    "validation_ratio": 0.1,  # Validation dataset ratio
    "momentum": 0.9,  # SGD momentum (default: 0.5)
    "weight_decay": 1e-4,  # SGD weight decay (default: 1e-4)
    "optimizer": "-1",
    "record_distances": 0,
    "note": "",
    "dataset": "cifar10",  # Name of dataset
    "random_position": "inside",  # Position of randomization
    "iid": 0,  # Whether the data split is IID; set to 0 for non-IID
    "mu": 1,  # \mu for FedProx
    "optimization_method": "none",
    "alpha": 1,  # alpha for the regularization term
    "beta": 1,  # beta for the regularization term
    "order": 1,  # Order of domain shift tasks
    "save_every_model": 0,
    "adv": 1,  # Scaling factor for adversarial loss
    "bn": 1,  # Scaling factor for BN regularization
    "oh": 1,  # Scaling factor for one-hot loss (cross entropy)
    "act": 0,  # Scaling factor for activation loss used in DAFL
    "save_dir": "run/synthesis",
    "partition": "dirichlet",  # Partition type
    "betas": 0.3,  # Dirichlet concentration for the split; smaller values give a more unbalanced partition
    "lr_g": 1e-3,  # Initial learning rate for generation
    "T": 20,
    "g_steps": 30,  # Number of iterations for generation
    "batch_size": 256,  # Batch size
    "nz": 256,  # Dimension of the generator noise vector
    "synthesis_batch_size": 256,
    "seed": 1,  # Seed for initializing training
    "epochs": 50,  # Total number of training epochs
    "type": "pretrain",
    "model": "cnn",  # Model name
    "other": "",
    "device": "cuda:0",  # GPU Device ID
    "id": "0"  # Experiment ID
}
if __name__ == '__main__':
    config = Config(preset_config, name='Federated Learning Experiments')
    print(config)
    if not torch.cuda.is_available():
        config.device = "cpu"
        print("CUDA is not available, using CPU.")
    setup_seed(config.seed)
    # pdb.set_trace()
    # BUILD MODEL
    start_time = time.time()
    global_model = get_model(config)
    init_weights = copy.deepcopy(global_model.state_dict())
    bst_acc = -1
    description = "inference acc={:.4f}% loss={:.2f}, best_acc = {:.2f}%"
    global_model.train()
    fedavg_accs = []
    client_accs = []
    if config.id == "0":
        id = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")
    else:
        id = config.id
    print("id: {}".format(id))
    time.sleep(3)
    if config.shots == 1:
        hps = hyperparameters_one_shot[config.dataset]
    else:
        hps = hyperparameters[config.dataset]
    fedavg_model_weights = []
    saved_model_weights_pool = []
    # ===============================================
    model_weights_pool = []
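    # Each iteration of the loop below is one "shot": every client trains
    # config.num_models local models that are added to a shared model pool,
    # which is later mean-averaged into the one-shot global model.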
    for i in range(config.shots):
        local_weights = []
        user_avg_weights = []
        users = []
        saved_datasets = []
        acc_list = []
        max_accs = []
        best_model_weights = []
        val_accs = []
        client_losses = []
        if i == 0:
            # Client 0 warmup
            if config.warmup_epochs != -1:
                warmup_epochs = config.warmup_epochs
            else:
                warmup_epochs = warmup_config[config.dataset][config.model][0]
            hyperparameter = hps[0]
            train_dataset, val_dataset, test_dataset, user_groups, val_user_groups, training_data_cls_counts = partition_data(
                config.dataset, config.partition, beta=config.betas, num_users=config.num_users,
                transform=hyperparameter["transform"], order=config.order)
            test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256,
                                                      shuffle=False, num_workers=4)
            local_model = LocalUpdate(args=config, dataset=train_dataset, val_dataset=val_dataset,
                                      idxs=user_groups[0], val_idxs=val_user_groups[0], test_loader=test_loader)
            training_set, valid_set = local_model.get_datasets()
            global_model.load_state_dict(init_weights)
            print("Start Warm Up")
            warmup_weights, local_acc_list, best_epoch, max_val_acc, local_loss_list = local_model.update_weights(
                copy.deepcopy(global_model), config.device, hyperparameter, local_ep=warmup_epochs, optimize=False, args=config)
            t_warmup_weights = copy.deepcopy(warmup_weights)
            model_weights_pool.append(t_warmup_weights)
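        # The warm-up weights from client 0 seed the model pool; each client below then
        # trains config.num_models models, passing the current pool to update_weights_model_pool.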
        for idx in range(config.num_users):
            train_dataset, val_dataset, test_dataset, user_groups, val_user_groups, training_data_cls_counts = partition_data(
                config.dataset, config.partition, beta=config.betas, num_users=config.num_users,
                transform=hyperparameter["transform"], order=config.order)
            test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256,
                                                      shuffle=False, num_workers=4)
            for m in range(config.num_models):
                print("Now training model {} for user {}".format(m, idx))
                local_model = LocalUpdate(args=config, dataset=train_dataset, val_dataset=val_dataset,
                                          idxs=user_groups[idx], val_idxs=val_user_groups[idx], test_loader=test_loader)
                model_weights_pool, local_acc_list, best_epoch, max_val_acc, local_loss_list = local_model.update_weights_model_pool(
                    copy.deepcopy(global_model), config.device, hyperparameter, model_weights_pool, random_position=config.random_position, args=config)
            saved_model_weights_pool.extend(model_weights_pool)
            model_weights_pool = [mean_average_weights(model_weights_pool)]
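    # All locally trained models (saved_model_weights_pool) are checkpointed; the one-shot
    # global model is the mean average of the current pool, and an ensemble built from the
    # pooled weights is evaluated for comparison.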
    if not os.path.exists('checkpoints'):
        os.makedirs('checkpoints')
    torch.save(saved_model_weights_pool,
               'checkpoints/{}_{}clients_{}_{}_{}_{}.pkl'.format(config.dataset, config.num_users, config.betas,
                                                                 config.partition, config.model, id))
    global_weights = mean_average_weights(model_weights_pool)
    global_model.load_state_dict(global_weights)
    print("One-Shot MeanAvg Accuracy:")
    meanavg_test_acc, meanavg_test_loss = test(global_model, test_loader, config.device)
    model_list = []
    for i in range(len(model_weights_pool)):
        net = copy.deepcopy(global_model)
        net.load_state_dict(model_weights_pool[i])
        model_list.append(net)
    ensemble_model = Ensemble(model_list)
    print("Ensemble Accuracy:")
    ensemble_test_acc, ensemble_test_loss = test(ensemble_model, test_loader, config.device)
    max_acc = []
    for sub in acc_list:
        max_acc.append(max(sub))
    last_acc = []
    for sub in acc_list:
        last_acc.append(sub[-1])
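    # Collect the run configuration and evaluation results into a JSON-serializable record.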
    output = {
        "id": id,
        "seed": config.seed,
        "shots": config.shots,
        "time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "dataset": config.dataset,
        "model": config.model,
        "num_users": config.num_users,
        "betas": config.betas,
        "partition": config.partition,
        "warmup_config": warmup_config[config.dataset],
        "file": os.path.basename(__file__),
        "meanavg_test_acc": meanavg_test_acc,
        "meanavg_test_loss": meanavg_test_loss,
        "ensemble_test_acc": ensemble_test_acc,
        "ensemble_test_loss": ensemble_test_loss,
        "hyperparameters": hps,
        "data_cls_counts": str(training_data_cls_counts),
        "args": str(config),
        "local_bs": config.local_bs,
        "validation_ratio": config.validation_ratio,
        "note": config.note,
        "order": config.order,
        "client_losses": client_losses,
        "val_accs": val_accs,
        "alpha": config.alpha,
        "beta": config.beta,
    }
    if not os.path.exists('results'):
        os.makedirs('results')
    json.dump(output, open(
        'results/{}_{}clients_{}_{}_{}.json'.format(config.dataset, config.num_users, config.betas, config.model, id), 'w'))
    output = json.dumps(output, indent=4)
    print(output)
    print("One-Shot MeanAvg Accuracy:")
    meanavg_test_acc, meanavg_test_loss = test(global_model, test_loader, config.device)
    end_time = time.time()
    print("Total Time Cost: {:.2f}s".format(end_time - start_time))
    # ===============================================