from torchvision.datasets import MNIST, CIFAR10, CIFAR100
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from .data_manager import DataManager
class BasicDataManager(DataManager):
    r"""A basic data manager for partitioning the data. Currently three
    rules of partitioning are supported:

    - iid:
        same label distribution among clients. sample_balance determines
        the quota of samples for each client, drawn from a log-normal
        distribution.
    - dir:
        a Dirichlet distribution with concentration parameter given by
        label_balance determines the label distribution of each client.
        sample_balance determines the quota of samples for each client,
        drawn from a log-normal distribution.
    - exclusive:
        samples corresponding to each label are randomly split among
        k clients where k = total_sample_size * label_balance.
        sample_balance determines the way this split happens (quota).
        This rule is also known as "shards splitting".

    Args:
        root (str): root dir of the dataset to partition
        dataset (str): name of the dataset
        num_partitions (int): number of partitions or clients
        rule (str): rule of partitioning
        sample_balance (float): balance of the number of samples among
            clients
        label_balance (float): balance of the labels on each client
        local_test_portion (float): portion of each client's data held
            out as a local test set
        seed (int): random seed of partitioning
        save_path (str, optional): path to save partitioned indices.
    """
def __init__(
self,
root,
dataset='mnist',
num_partitions=500,
rule='iid',
sample_balance=0.,
label_balance=1.,
local_test_portion=0.,
seed=10,
save_path=None,
*args,
**kwargs,
):
"""A basic data manager for partitioning the data. Currecntly three
rules of partitioning are supported:
- iid:
same label distribution among clients. sample balance determines
quota of each client samples from a lognorm distribution.
- dir:
Dirichlete distribution with concentration parameter given by
label_balance determines label balance of each client.
sample balance determines quota of each client samples from a
lognorm distribution.
- exclusive:
samples corresponding to each label are randomly splitted to
k clients where k = total_sample_size * label_balance.
sample_balance determines the way this split happens (quota).
This rule also is know as "shards splitting".
Args:
root (str): root dir of the dataset to partition
dataset (str): name of the dataset
num_clients (int): number of partitions or clients
rule (str): rule of partitioning
sample_balance (float): balance of number of samples among clients
label_balance (float): balance of the labels on each clietns
local_test_portion (float): portion of local test set from trian
seed (int): random seed of partitioning
save_path (str, optional): path to save partitioned indices.
"""
self.dataset_name = dataset
self.num_partitions = num_partitions
self.rule = rule
self.sample_balance = sample_balance
self.label_balance = label_balance
self.local_test_portion = local_test_portion
        # super().__init__ must be called last because it invokes the
        # abstract methods (make_datasets, partition_local_data, ...)
        # which rely on the attributes set above
        super(BasicDataManager, self).__init__(
            root,
            seed,
            save_path=save_path,
        )
    def make_datasets(self, root, global_transforms=None):
        if self.dataset_name == 'mnist':
            local_dset = MNIST(root, download=True, train=True, transform=None)
            # the global dataset is the held-out test split with the global
            # test transform, mirroring the CIFAR branch below
            global_dset = MNIST(root,
                                download=True,
                                train=False,
                                transform=global_transforms['test'])
        elif self.dataset_name in ('cifar10', 'cifar100'):
            dst_class = CIFAR10 if self.dataset_name == 'cifar10' else CIFAR100
            local_dset = dst_class(root=root,
                                   download=True,
                                   train=True,
                                   transform=None)
            global_dset = dst_class(root=root,
                                    download=True,
                                    train=False,
                                    transform=global_transforms['test'])
        else:
            raise NotImplementedError
        return local_dset, global_dset
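
    # A minimal sketch of the `global_transforms` mapping this method expects
    # (hypothetical values; only the 'test' key is used above):
    #
    #     from torchvision import transforms
    #     global_transforms = dict(
    #         test=transforms.Compose([transforms.ToTensor()]),
    #     )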
    def partition_local_data(self, dataset):
n = self.num_partitions
targets = np.array(dataset.targets)
all_sample_count = len(targets)
num_classes = len(np.unique(targets))
# the special case of exclusive rule:
        if self.rule == 'exclusive':
            # TODO: implement the shards-splitting rule
            raise NotImplementedError(
                "the 'exclusive' (shards splitting) rule is not "
                "implemented yet")
# # get number of samples per label
# label_counts = [(targets==i).sum() for i in range(num_classes)]
# for label, label_count in enumerate(label_counts):
# # randomly select k clients
# # determine the quota for each client from a lognorm
# # reassign the
# *********************************************************
        # determine the sample quota of each client
        sample_per_client = all_sample_count // n
        if self.sample_balance != 0:
            # draw unnormalized quotas from a log-normal distribution; a
            # larger sample_balance (the sigma) means stronger imbalance
            client_quota = np.random.lognormal(mean=np.log(sample_per_client),
                                               sigma=self.sample_balance,
                                               size=n)
            # rescale so the quotas sum to the total number of samples
            client_quota = (client_quota / np.sum(client_quota) *
                            all_sample_count).astype(int)
            # casting to int may leave a remainder; absorb it into the
            # first client whose quota is large enough
            diff = np.sum(client_quota) - all_sample_count
            if diff != 0:
                for clnt_i in range(n):
                    if client_quota[clnt_i] > diff:
                        client_quota[clnt_i] -= diff
                        break
        else:
            client_quota = np.ones(n, dtype=int) * sample_per_client
        indices = [
            np.zeros(client_quota[client], dtype=int) for client in range(n)
        ]
# *********************************************************
if self.rule == 'dir':
# Dirichlet partitioning rule
cls_priors = np.random.dirichlet(alpha=[self.label_balance] *
num_classes,
size=n)
prior_cumsum = np.cumsum(cls_priors, axis=1)
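            # prior_cumsum[c] is the CDF of client c's class prior; below, a
            # uniform draw u is mapped to the first class whose cumulative
            # mass reaches u (inverse-CDF sampling)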
idx_list = [np.where(targets == i)[0] for i in range(num_classes)]
cls_amount = np.array(
[len(idx_list[i]) for i in range(num_classes)])
            print('partitioning')
pbar = tqdm(total=np.sum(client_quota))
while np.sum(client_quota) != 0:
curr_clnt = np.random.randint(n)
# If current node is full resample a client
if client_quota[curr_clnt] <= 0:
continue
client_quota[curr_clnt] -= 1
curr_prior = prior_cumsum[curr_clnt]
                # exclude the classes that have run out of examples
curr_prior[cls_amount <= 0] = -1
# scale the prior up so the positive values sum to
# 1 again
cpp = curr_prior[curr_prior > 0]
cpp /= cpp.sum()
curr_prior[curr_prior > 0] = cpp
while True:
                    if (curr_prior > 0).sum() < 1:
                        # all remaining classes are exhausted for this client
                        raise RuntimeError(
                            'partitioning failed: no class left to draw '
                            'from; choose another seed')
if (curr_prior > 0).sum() == 1:
cls_label = curr_prior.argmax()
else:
uu = np.random.uniform()
cls_label = np.argmax(uu <= curr_prior)
# Redraw class label if out of that class
if cls_amount[cls_label] <= 0:
continue
cls_amount[cls_label] -= 1
indices[curr_clnt][client_quota[curr_clnt]] = idx_list[
cls_label][cls_amount[cls_label]]
break
pbar.update(1)
pbar.close()
# *********************************************************
elif self.rule == 'iid':
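            # assign contiguous blocks of indices; this is (near-)iid in the
            # labels as long as the dataset ordering is not sorted by class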
clnt_quota_cum_sum = np.concatenate(([0], np.cumsum(client_quota)))
for client_index in range(n):
indices[client_index] = np.arange(
clnt_quota_cum_sum[client_index],
clnt_quota_cum_sum[client_index + 1])
else:
raise NotImplementedError
ts_portion = self.local_test_portion
if ts_portion > 0:
new_indices = dict(train=[], test=[])
for client_indices in indices:
train_idxs, test_idxs = train_test_split(client_indices,
test_size=ts_portion)
new_indices['train'].append(train_idxs)
new_indices['test'].append(test_idxs)
else:
new_indices = dict(train=indices)
return new_indices
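
    # Shape of the returned mapping (for reference):
    #     {'train': [per-client index arrays]}            # no local test set
    #     {'train': [...], 'test': [...]}                 # with local test set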
    def get_identifiers(self):
identifiers = [self.dataset_name, str(self.num_partitions), self.rule]
if self.rule == 'dir':
identifiers.append(str(self.label_balance))
if self.sample_balance == 0:
identifiers.append('balanced')
else:
identifiers.append('unbalanced')
if self.local_test_portion > 0:
identifiers.append('ts_{}'.format(self.local_test_portion))
return identifiers
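

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). It assumes DataManager.__init__
# drives make_datasets/partition_local_data and that './data' is writable.
# Because of the relative import at the top, run this as a module inside the
# package (e.g. `python -m <package>.<this_module>`), not as a loose script.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    dm = BasicDataManager(
        root='./data',
        dataset='mnist',
        num_partitions=100,
        rule='dir',
        sample_balance=0.,    # equal sample quotas
        label_balance=0.3,    # mildly heterogeneous labels
        seed=10,
    )
    print(dm.get_identifiers())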