make_partially_linear_simple_dataset
extensions.synthetic_data.make_partially_linear_simple_dataset(n_obs=1000, n_covariates=10, n_confounders=5, dim_heterogeneity=2, binary_treatment=True, seed=None)
Generate a dataset from a partially linear data generating process with a simple one- or two-dimensional CATE function. The outcome is continuous, and the treatment can be binary or continuous. The dataset is generated using the make_heterogeneous_data function from the doubleml package.
The general form of the data generating process is, in the case of dim_heterogeneity=1:
\[ Y_i= \tau (x_0) D_i + g(\mathbf{X_i})+\epsilon_i \] \[ D_i=f(\mathbf{X_i})+\eta_i \]
or, in the case of dim_heterogeneity=2:
\[ Y_i= \tau (x_0,x_1) D_i + g(\mathbf{X_i})+\epsilon_i \] \[ D_i=f(\mathbf{X_i})+\eta_i \]
where \(Y_i\) is the outcome, \(D_i\) is the treatment, \(\mathbf{X_i}\) are the covariates, \(\epsilon_i\) and \(\eta_i\) are the error terms, \(\tau\) is the CATE function (which depends only on the covariate \(x_0\), or on \(x_0\) and \(x_1\)), \(g\) is the outcome function, and \(f\) is the treatment function.
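The two equations can be simulated directly. The sketch below is a minimal NumPy illustration, not the package's implementation. The uniform covariates, the linear forms of g and f, the specific τ, and the thresholding of the latent index f(X) + η used to produce a binary treatment are all assumptions chosen for readability, and simulate_partially_linear is a hypothetical name.

```python
import numpy as np


def simulate_partially_linear(n_obs=1000, n_covariates=10, n_confounders=5,
                              dim_heterogeneity=2, binary_treatment=True, seed=None):
    """Illustrative partially linear DGP (hypothetical; NOT the library's code)."""
    if dim_heterogeneity not in (1, 2):
        raise ValueError("dim_heterogeneity must be 1 or 2")
    if n_covariates < dim_heterogeneity or n_confounders > n_covariates:
        raise ValueError("need dim_heterogeneity <= n_covariates and "
                         "n_confounders <= n_covariates")
    rng = np.random.default_rng(seed)
    # Covariates drawn uniformly on [0, 1] (an assumption for this sketch).
    X = rng.uniform(0.0, 1.0, size=(n_obs, n_covariates))
    # Confounders: the first n_confounders columns enter both g and f.
    Xc = X[:, :n_confounders]
    g = Xc @ rng.uniform(-1.0, 1.0, size=n_confounders)  # outcome function g(X)
    f = Xc @ rng.uniform(-1.0, 1.0, size=n_confounders)  # treatment function f(X)
    # Assumed CATE: a function of x_0 only, or of x_0 and x_1.
    x_het = X[:, 1] if dim_heterogeneity == 2 else X[:, 0]
    tau = np.exp(2.0 * X[:, 0]) + 3.0 * np.sin(4.0 * x_het)
    eta = rng.normal(size=n_obs)
    latent = f + eta                                     # D_i = f(X_i) + eta_i
    # Binary case: threshold the latent index (assumed link, for illustration).
    D = (latent > 0).astype(float) if binary_treatment else latent
    eps = rng.normal(size=n_obs)
    Y = tau * D + g + eps                                # Y_i = tau D_i + g(X_i) + eps_i
    # Return outcome, treatment, covariates, per-row CATEs, and their sample mean.
    return Y, D, X, tau, float(tau.mean())
```

Because τ multiplies D linearly while g and f absorb the confounding, this structure is what makes the dataset "partially linear"; only τ's arguments change between the one- and two-dimensional cases.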
Parameters
Name | Type | Description | Default |
---|---|---|---|
n_obs | int | The number of observations to generate. | 1000 |
n_covariates | int | The number of covariates to generate. | 10 |
n_confounders | int | The number of covariates that are confounders. | 5 |
dim_heterogeneity | int | The dimension of the treatment effect heterogeneity. Must be 1 or 2. | 2 |
binary_treatment | bool | Whether the treatment is binary (True) or continuous (False). | True |
seed | int or None | The seed for the random number generator. | None |
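The constraints in the parameter table can be checked up front before any data is drawn. This is a hypothetical helper, not part of the package; the requirements that n_obs and n_covariates be positive and that n_confounders not exceed n_covariates are assumptions, while the restriction of dim_heterogeneity to 1 or 2 comes from the table.

```python
import numbers


def validate_dataset_args(n_obs=1000, n_covariates=10, n_confounders=5,
                          dim_heterogeneity=2, binary_treatment=True, seed=None):
    """Hypothetical argument checks mirroring the parameter table."""
    def is_int(value):
        # bool is a subclass of int in Python, so exclude it explicitly.
        return isinstance(value, numbers.Integral) and not isinstance(value, bool)

    for name, value in (("n_obs", n_obs), ("n_covariates", n_covariates),
                        ("n_confounders", n_confounders),
                        ("dim_heterogeneity", dim_heterogeneity)):
        if not is_int(value):
            raise TypeError(f"{name} must be an int, got {value!r}")
    if n_obs <= 0 or n_covariates <= 0:
        raise ValueError("n_obs and n_covariates must be positive")
    # Assumption: confounders are a subset of the generated covariates.
    if not 0 <= n_confounders <= n_covariates:
        raise ValueError("n_confounders must be between 0 and n_covariates")
    # Stated in the table: the heterogeneity dimension can only be 1 or 2.
    if dim_heterogeneity not in (1, 2):
        raise ValueError("dim_heterogeneity must be 1 or 2")
    if not isinstance(binary_treatment, bool):
        raise TypeError("binary_treatment must be a bool")
    if seed is not None and not is_int(seed):
        raise TypeError("seed must be an int or None")
```

Raising TypeError for wrong types and ValueError for out-of-range values keeps the two failure modes distinguishable for callers.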
Returns
Type | Description |
---|---|
pd.DataFrame | The generated dataset, where y is the outcome, d is the treatment, and X denotes the covariates. |
np.ndarray | The true conditional average treatment effects. |
float | The true average treatment effect. |
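Because the true CATEs and ATE are returned alongside the data, an estimator's output can be scored against ground truth. The helper below is a hypothetical sketch: cate_hat stands for whatever per-observation CATE estimates a model produces, and using the mean of cate_hat as the ATE estimate is an assumption of this example.

```python
import numpy as np


def score_against_truth(cate_hat, true_cates, true_ate):
    """Hypothetical helper: compare estimated CATEs with the simulated truth."""
    cate_hat = np.asarray(cate_hat, dtype=float).ravel()
    true_cates = np.asarray(true_cates, dtype=float).ravel()
    if cate_hat.shape != true_cates.shape:
        raise ValueError("cate_hat and true_cates must have the same length")
    # Root mean squared error of the per-observation effects.
    rmse = float(np.sqrt(np.mean((cate_hat - true_cates) ** 2)))
    # Bias of the implied ATE (mean of estimated CATEs) relative to the true ATE.
    ate_bias = float(np.mean(cate_hat) - true_ate)
    return {"cate_rmse": rmse, "ate_bias": ate_bias}
```

For example, score_against_truth([1.0, 2.0, 3.0], [1.0, 2.0, 5.0], 2.5) gives an ATE bias of -0.5 and a CATE RMSE of sqrt(4/3).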
Examples
>>> df, true_cates, true_ate = make_partially_linear_simple_dataset(n_obs=1000, n_covariates=10, n_confounders=5, dim_heterogeneity=2, binary_treatment=True, seed=1)