---
output: github_document
---
```{r setup, include=FALSE}
knitr::opts_chunk$set(
echo = TRUE, message = FALSE, warning = FALSE,
fig.path = "man/figures/README-", fig.width = 7, fig.height = 4.5,
collapse = TRUE, comment = "#>"
)
library(forestBalance)
library(grf)
library(Matrix)
```
# forestBalance
**Forest Kernel Energy Balancing for Causal Inference**
`forestBalance` estimates average treatment effects (ATE) by combining
multivariate random forests with kernel energy balancing. A joint forest model
of covariates, treatment, and outcome defines a proximity kernel that characterizes
the confounding structure: observations are considered similar when they are alike
in terms of the variables that drive both treatment and outcome.
This kernel is then used in kernel energy balancing, which computes weights that minimize the distributional distance between the treatment and control groups. Because the kernel is learned jointly from treatment and outcome, the resulting weights prioritize balance on the variables that actually confound the effect.
By construction, these balancing weights target the joint
distribution of the confounders specifically, not each covariate in isolation.
The method is described in:
> De, S. and Huling, J.D. (2025). *Data adaptive covariate balancing for causal effect
> estimation for high dimensional data.* arXiv:2512.18069.
## Installation
```{r install, eval=FALSE}
# Install from GitHub
devtools::install_github("jaredhuling/forestBalance")
```
## Quick start
```{r quickstart}
library(forestBalance)
# Simulate observational data with nonlinear confounding (true ATE = 0)
set.seed(123)
dat <- simulate_data(n = 500, p = 10, ate = 0)
# Estimate ATE with forest kernel energy balancing
fit <- forest_balance(dat$X, dat$A, dat$Y)
fit
```
## How it works
The method proceeds in three steps:
1. **Joint forest model:** A `grf::multi_regression_forest` is fit on
covariates $X$ with the bivariate response $(A, Y)$. Because the forest
splits on both treatment and outcome, the resulting tree structure captures
confounding relationships.
2. **Proximity kernel:** The $n \times n$ kernel matrix $K(i,j)$ is defined as
   the proportion of trees in which observations $i$ and $j$ fall in the same
   leaf node. This is computed efficiently via a single sparse matrix cross-product.
3. **Kernel energy balancing:** Balancing weights are obtained in closed form
by solving a linear system derived from the kernel energy distance objective.
The weights make the treated and control distributions similar with respect
to the forest-defined similarity measure.
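Step 2 can be sketched directly with the `Matrix` package. This is an illustrative toy version rather than the package implementation (the leaf IDs below are made up): each tree's leaf assignments are one-hot encoded into a sparse membership matrix, and a single cross-product then counts, for every pair of observations, how many trees place them in the same leaf.

```{r kernel-sketch, eval=FALSE}
library(Matrix)

# Hypothetical leaf-ID matrix: 6 observations, one column per tree
leaf_ids <- matrix(c(1, 1, 2, 2, 3, 3,
                     1, 2, 1, 2, 1, 2,
                     1, 1, 1, 2, 2, 2), ncol = 3)

# One-hot encode leaf membership within each tree, stacked column-wise
one_hot <- do.call(cbind, lapply(seq_len(ncol(leaf_ids)), function(t) {
  sparseMatrix(i = seq_len(nrow(leaf_ids)), j = leaf_ids[, t], x = 1)
}))

# Proximity kernel: fraction of trees in which i and j share a leaf
K <- tcrossprod(one_hot) / ncol(leaf_ids)
K[1, 2]  # obs 1 and 2 share a leaf in trees 1 and 3, so 2/3
```

In the package, `get_leaf_node_matrix()` and `leaf_node_kernel()` play these two roles at scale, which is why the kernel construction costs only one sparse cross-product regardless of the number of trees.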
## Detailed example
### Simulating data
`simulate_data()` generates observational data with nonlinear confounding
through a Beta density link:
```{r simulate}
set.seed(123)
dat <- simulate_data(n = 800, p = 10, ate = 0)
# Naive (unweighted) estimate is biased
naive_ate <- mean(dat$Y[dat$A == 1]) - mean(dat$Y[dat$A == 0])
c("Naive ATE" = round(naive_ate, 4), "True ATE" = 0)
```
### Fitting the model
```{r fit}
fit <- forest_balance(dat$X, dat$A, dat$Y, num.trees = 1000)
```
### Print and summary
`print()` gives a concise overview:
```{r print}
fit
```
`summary()` provides a full covariate balance comparison (unweighted vs
weighted) with flagged imbalances:
```{r summary}
summary(fit)
```
### Balance on nonlinear transformations
Since confounding operates through nonlinear functions of $X_1$ and $X_5$, we
can check balance on transformations of the covariates:
```{r summary-trans}
X <- dat$X
X.nl <- cbind(
X[,1]^2, X[,2]^2, X[,5]^2,
X[,1] * X[,2], X[,1] * X[,5],
dbeta(X[,1], 2, 4), dbeta(X[,5], 2, 4)
)
colnames(X.nl) <- c("X1^2", "X2^2", "X5^2", "X1*X2", "X1*X5",
"Beta(X1)", "Beta(X5)")
summary(fit, X.trans = X.nl)
```
### Standalone balance diagnostics
`compute_balance()` can be used independently with any set of weights:
```{r standalone-balance}
# Inverse propensity weights (using true propensity scores)
ipw <- ifelse(dat$A == 1, 1 / dat$propensity, 1 / (1 - dat$propensity))
bal_forest <- compute_balance(dat$X, dat$A, fit$weights)
bal_ipw <- compute_balance(dat$X, dat$A, ipw)
c("Forest balance" = round(bal_forest$max_smd, 4),
"IPW" = round(bal_ipw$max_smd, 4))
```
### Lower-level interface
For more control, the pipeline can be run step by step:
```{r lower-level}
library(grf)
# 1. Fit the joint forest
forest <- multi_regression_forest(dat$X, scale(cbind(dat$A, dat$Y)),
num.trees = 500)
# 2. Extract leaf node matrix and build kernel
leaf_mat <- get_leaf_node_matrix(forest, dat$X)
K <- leaf_node_kernel(leaf_mat)
c("observations" = nrow(leaf_mat), "trees" = ncol(leaf_mat))
c("kernel % nonzero" = round(100 * length(K@x) / prod(dim(K)), 1))
# 3. Compute balancing weights
bal <- kernel_balance(dat$A, K)
ate <- weighted.mean(dat$Y[dat$A == 1], bal$weights[dat$A == 1]) -
weighted.mean(dat$Y[dat$A == 0], bal$weights[dat$A == 0])
c("ATE estimate" = round(ate, 4))
```
## Simulation study
A small simulation comparing the forest balance estimator against the naive
(unadjusted) difference in means:
```{r simulation, cache=TRUE}
set.seed(1)
nreps <- 100
results <- matrix(NA, nreps, 2, dimnames = list(NULL, c("Naive", "Forest")))
for (r in seq_len(nreps)) {
dat <- simulate_data(n = 500, p = 10, ate = 0)
fit <- forest_balance(dat$X, dat$A, dat$Y, num.trees = 500)
results[r, "Naive"] <- mean(dat$Y[dat$A == 1]) - mean(dat$Y[dat$A == 0])
results[r, "Forest"] <- fit$ate
}
```
```{r sim-results, echo=FALSE}
stbl <- data.frame(
Method = colnames(results),
Bias = round(colMeans(results), 4),
SD = round(apply(results, 2, sd), 4),
RMSE = round(sqrt(colMeans(results)^2 + apply(results, 2, var)), 4)
)
knitr::kable(stbl, row.names = FALSE)
```
```{r sim-plot, echo=FALSE, fig.height=3.5}
par(mar = c(4, 0.5, 2, 0.5))
boxplot(results, horizontal = TRUE, col = c("firebrick", "steelblue"),
main = "ATE estimates across 100 replications (true ATE = 0)",
xlab = "Estimated ATE")
abline(v = 0, lty = 2, col = "grey40")
```
## Key functions
| Function | Description |
|---|---|
| `forest_balance()` | High-level: fit forest, build kernel, compute weights, return ATE |
| `simulate_data()` | Simulate observational data with nonlinear confounding |
| `compute_balance()` | Covariate balance diagnostics (SMD, ESS, energy distance) |
| `get_leaf_node_matrix()` | Fast vectorized leaf node extraction from grf forests |
| `leaf_node_kernel()` | Sparse proximity kernel from leaf node matrix |
| `forest_kernel()` | Convenience: forest object to kernel in one call |
| `kernel_balance()` | Closed-form kernel energy balancing weights |