set.seed(42)
library(bartcs)
The bartcs package finds confounders and estimates treatment effects with Bayesian Additive Regression Trees (BART).
This tutorial uses the Infant Health and Development Program (IHDP) dataset. The dataset includes 6 continuous and 19 binary covariates with a simulated outcome, a cognitive test score. It was first used by Hill (2011). My version of the dataset is the first realization generated by Louizos et al. (2017); you can find other versions in their GitHub repository.
data(ihdp, package = "bartcs")
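Before fitting, a quick base R look at the data may help (treatment and y_factual are the columns used below):
dim(ihdp)               # dataset dimensions
table(ihdp$treatment)   # treated vs. control counts
summary(ihdp$y_factual) # simulated cognitive test score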
fit <- single_bart(
  Y = ihdp$y_factual,
  trt = ihdp$treatment,
  X = ihdp[, 6:30],
  num_tree = 10,
  num_chain = 4,
  num_post_sample = 100,
  num_burn_in = 100,
  verbose = FALSE
)
fit
#> `bartcs` fit by `single_bart()`
#>
#> mean 2.5% 97.5%
#> ATE 3.989180 3.757940 4.194852
#> Y1 6.414511 6.209326 6.594922
#> Y0 2.425331 2.352195 2.512516
You get the mean and 95% credible interval of the average treatment effect (ATE) and the potential outcomes Y1 and Y0.
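These numbers are posterior summaries: each MCMC draw yields one ATE sample, and the interval is simply the 2.5% and 97.5% quantiles over the draws. A minimal sketch of that summary, using a hypothetical vector ate_draws as a stand-in for posterior samples (not an actual bartcs accessor):
# Stand-in posterior draws of the ATE (hypothetical values).
ate_draws <- rnorm(400, mean = 4, sd = 0.1)
c(mean = mean(ate_draws), quantile(ate_draws, c(0.025, 0.975)))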
Both separate_bart() and single_bart() fit multiple MCMC chains. summary() provides results along with the Gelman-Rubin statistic to check convergence.
summary(fit)
#> `bartcs` fit by `single_bart()`
#>
#> Treatment Value
#> Treated group : 1
#> Control group : 0
#>
#> Tree Parameters
#> Number of Tree : 10 Value of alpha : 0.95
#> Prob. of Grow : 0.28 Value of beta : 2
#> Prob. of Prune : 0.28 Value of nu : 3
#> Prob. of Change : 0.44 Value of q : 0.95
#>
#> Chain Parameters
#> Number of Chains : 4 Number of burn-in : 100
#> Number of Iter : 200 Number of thinning : 1
#> Number of Sample : 100
#>
#> Outcome Diagnostics
#> Gelman-Rubin : 0.9964517
#>
#> Outcome
#> estimand chain 2.5% 1Q mean median 3Q 97.5%
#> ATE 1 3.773543 3.911107 3.972227 3.970111 4.056683 4.151239
#> ATE 2 3.772706 3.941568 3.999966 4.008911 4.055016 4.193598
#> ATE 3 3.741889 3.859122 3.943162 3.959975 4.021592 4.118731
#> ATE 4 3.819920 3.971550 4.041364 4.045832 4.111067 4.252098
#> ATE agg 3.757940 3.914043 3.989180 3.990748 4.061248 4.194852
#> Y1 1 6.215452 6.328724 6.396984 6.404919 6.460931 6.543268
#> Y1 2 6.261225 6.359134 6.426559 6.434194 6.486718 6.590550
#> Y1 3 6.184771 6.320276 6.377308 6.382648 6.450492 6.545511
#> Y1 4 6.245013 6.393345 6.457192 6.465348 6.513756 6.672324
#> Y1 agg 6.209326 6.347214 6.414511 6.423418 6.480329 6.594922
#> Y0 1 2.342513 2.399295 2.424757 2.423876 2.451448 2.506433
#> Y0 2 2.355629 2.398670 2.426593 2.426521 2.448976 2.495499
#> Y0 3 2.357711 2.401557 2.434146 2.431961 2.460607 2.517344
#> Y0 4 2.346422 2.389331 2.415827 2.412804 2.441819 2.511036
#> Y0 agg 2.352195 2.397022 2.425331 2.424030 2.450444 2.512516
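The Gelman-Rubin statistic compares between-chain and within-chain variance; values close to 1, as here, indicate convergence. If you want to compute it yourself for any set of chains, the coda package provides gelman.diag(); a self-contained sketch on simulated chains (not the bartcs internals):
library(coda)

# Four toy chains standing in for posterior ATE samples.
chains <- do.call(mcmc.list, lapply(1:4, function(i) mcmc(rnorm(100, 4, 0.1))))
gelman.diag(chains)  # point estimate near 1 indicates convergence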
You can get the posterior inclusion probability for each variable.
plot(fit, method = "pip")
Since inclusion_plot() is a wrapper around ggcharts::bar_chart(), you can use its arguments for a better plot.
plot(fit, method = "pip", top_n = 10)
plot(fit, method = "pip", threshold = 0.5)
With trace_plot(), you can visually check the traces of effects or other parameters.
plot(fit, method = "trace")
plot(fit, method = "trace", "alpha")
count_omp_thread()
#> [1] 8
count_omp_thread() checks whether OpenMP is supported; you need more than one thread for multi-threading. Due to the overhead of multi-threading, parallelization will not be effective on small and moderate datasets. I recommend parallelization for datasets with at least 10,000 rows.
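That advice reduces to a one-line guard; a sketch using only the calls above:
# Parallelize only with spare OpenMP threads and enough rows.
use_parallel <- count_omp_thread() > 1 && nrow(ihdp) >= 1e4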
For comparison purposes, I will create a dataset with 40,000 rows by bootstrapping from the IHDP dataset. Then, for fast computation, I will set most parameters to 1.
idx  <- sample(nrow(ihdp), 4e4, TRUE)
ihdp <- ihdp[idx, ]
microbenchmark::microbenchmark(
  simple = single_bart(
    Y = ihdp$y_factual,
    trt = ihdp$treatment,
    X = ihdp[, 6:30],
    num_tree = 1,
    num_chain = 1,
    num_post_sample = 10,
    num_burn_in = 0,
    verbose = FALSE,
    parallel = FALSE
  ),
  parallel = single_bart(
    Y = ihdp$y_factual,
    trt = ihdp$treatment,
    X = ihdp[, 6:30],
    num_tree = 1,
    num_chain = 1,
    num_post_sample = 10,
    num_burn_in = 0,
    verbose = FALSE,
    parallel = TRUE
  ),
  times = 50
)
#> Warning in microbenchmark::microbenchmark(simple = single_bart(Y =
#> ihdp$y_factual, : less accurate nanosecond times to avoid potential integer
#> overflows
#> Unit: milliseconds
#> expr min lq mean median uq max neval
#> simple 48.22818 57.33953 63.56068 62.56124 66.49462 142.85273 50
#> parallel 45.68097 51.33405 55.00709 53.98286 56.76028 82.09377 50
The results show that parallelization gives better performance.
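From the mean timings above, the speed-up works out as follows; it is modest here, in line with the note that parallelization pays off mainly on large datasets:
# Relative speed-up implied by the mean timings above.
63.56068 / 55.00709  # ~1.16x faster with parallelization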