"""Load Adult dataset and manage cross validation."""
import folktables
from folktables import ACSDataSource
import pandas as pd
from sklearn.model_selection import StratifiedKFold
import numpy as np
[docs]
def load(k=0, sensitive=None):
    """Load the folktables Adult data and return one cross validation split.

    The 2018 1-Year person survey for California is fetched through
    folktables (downloaded if necessary), filtered with
    ``folktables.adult_filter`` and cut into 5 stratified, shuffled folds
    using a fixed random seed.

    The returned dataframes contain the feature columns only. ``PINCP`` is
    one of those columns; the binary label ``PINCP > 50000`` is used to
    stratify the folds and is not returned.

    :param k: (Optional default=0) Cross validation step in {0,1,2,3,4}.
    :type k: int
    :param sensitive: (Optional default=None) List of sensitive attributes to
        include in the features. The sensitive attributes are "sex" and
        "race". ``None`` is treated as an empty list.
    :type sensitive: list of str or None
    :return: Train and test split dataframes in a dictionary with the keys
        "train" and "test".
    :rtype: dict
    :raises ValueError: If ``k`` is not a valid fold index.
    :raises KeyError: If ``sensitive`` contains a name other than "sex" or
        "race".
    """
    n_splits = 5

    # Check the arguments before touching the data source so that a bad call
    # fails immediately instead of after the download and preprocessing.
    if k not in range(n_splits):
        raise ValueError(
            f"k must be one of {list(range(n_splits))}, got {k!r}"
        )
    if sensitive is None:
        sensitive = []

    features = [
        'AGEP',
        'COW',
        'SCHL',
        'MAR',
        'OCCP',
        'POBP',
        'RELP',
        'WKHP',
        'PINCP',
    ]
    human_to_census = {"race": "RAC1P", "sex": "SEX"}
    # An unknown attribute name raises KeyError here, before any data is fetched.
    features.extend(human_to_census[name] for name in sensitive)

    data_source = ACSDataSource(survey_year='2018', horizon='1-Year', survey='person')
    ca_data = data_source.get_data(states=["CA"], download=True)

    acs_income = folktables.BasicProblem(
        features=features,
        target='PINCP',
        target_transform=lambda x: x > 50000,
        group=['SEX', 'RAC1P'],
        preprocess=folktables.adult_filter,
        # ``nan`` has to be given by keyword: the second positional parameter
        # of np.nan_to_num is ``copy``, so ``np.nan_to_num(x, -1)`` replaced
        # NaN with 0.0 instead of -1.
        postprocess=lambda x: np.nan_to_num(x, nan=-1),
    )
    # The third returned value (group attributes) is not needed here.
    ca_features, ca_labels, _ = acs_income.df_to_pandas(ca_data)

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=1234)
    # Stop generating folds as soon as the requested one is reached; k was
    # validated above, so a match always exists.
    train_idx, test_idx = next(
        split
        for fold, split in enumerate(skf.split(ca_features, ca_labels))
        if fold == k
    )

    return {
        "train": ca_features.iloc[train_idx],
        "test": ca_features.iloc[test_idx],
    }