Discover how to adapt random forests for causal inference to estimate heterogeneous treatment effects—understanding who benefits most from a treatment and by how much.
Random forests are powerful prediction machines, but naive application to treatment effect estimation fails badly. Causal Forests adapt the random forest algorithm to directly target conditional average treatment effects (CATE).
Instead of asking "what will Y be?", we ask "how would Y change if we changed treatment?"—and critically, "does this effect vary across individuals?"
Standard regression trees (CART) split on features to minimize prediction error. Causal trees split on features to maximize treatment effect heterogeneity.
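Concretely, where CART picks the split minimizing squared prediction error, a causal tree (Athey and Imbens, 2016) picks the split whose children show the most effect heterogeneity, e.g. maximizing

$$\sum_{\ell \in \{L,R\}} n_\ell \, \hat{\tau}_\ell^{\,2}, \qquad \hat{\tau}_\ell = \bar{Y}_{\ell,\text{treated}} - \bar{Y}_{\ell,\text{control}},$$

so the resulting leaves separate high-responders from low-responders rather than high outcomes from low outcomes.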
Using the same data for both constructing the tree (choosing splits) and estimating leaf values leads to overfitting and biased estimates. Honesty solves this.
Honest Tree Construction:
- Split the training sample into two disjoint halves.
- Use one half only to choose the splits (the tree structure); use the other half only to estimate the treatment effect in each leaf.
- Because no observation influences both the partition and its own leaf estimate, the leaf estimates are unbiased conditional on the tree structure.
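To make both ideas concrete, here is a minimal, self-contained sketch of a single honest split on one feature. The function name, quantile grid, and leaf-size guard are illustrative choices, not the grf implementation:

import numpy as np

def honest_stump(X, Y, W, feature, rng=None):
    """One honest causal split: choose the cut on half A, estimate on half B."""
    rng = rng if rng is not None else np.random.default_rng(0)
    n = len(Y)
    idx = rng.permutation(n)
    A, B = idx[: n // 2], idx[n // 2:]  # disjoint halves

    def leaf_tau(rows):
        # Difference-in-means effect inside a leaf (a real implementation
        # would also guard against empty treated/control cells)
        y, w = Y[rows], W[rows]
        return y[w == 1].mean() - y[w == 0].mean()

    # 1) Choose the cut on half A by maximizing effect heterogeneity:
    #    n_L * tau_L^2 + n_R * tau_R^2
    best_cut, best_gain = None, -np.inf
    for cut in np.quantile(X[A, feature], np.linspace(0.1, 0.9, 9)):
        L, R = A[X[A, feature] <= cut], A[X[A, feature] > cut]
        if min(len(L), len(R)) < 20:  # crude minimum leaf size
            continue
        gain = len(L) * leaf_tau(L) ** 2 + len(R) * leaf_tau(R) ** 2
        if gain > best_gain:
            best_cut, best_gain = cut, gain

    # 2) Honesty: re-estimate both leaf effects on the held-out half B,
    #    which the split selection never saw
    L_B, R_B = B[X[B, feature] <= best_cut], B[X[B, feature] > best_cut]
    return best_cut, leaf_tau(L_B), leaf_tau(R_B)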
Just as random forests improve on single trees, causal forests aggregate many honest causal trees to get better estimates with lower variance.
Causal Forest Algorithm:
1. Draw a random subsample of the training data (and a random subset of candidate features, as in standard random forests).
2. Grow an honest causal tree on that subsample.
3. Repeat for B trees.
For a new observation x, we find which leaf it falls into in each tree, then average treatment effects from those leaves:
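$$\hat{\tau}(x) = \frac{1}{B} \sum_{b=1}^{B} \hat{\tau}_b(x),$$

where $\hat{\tau}_b(x)$ is the honest effect estimate in the leaf of tree $b$ containing $x$. Equivalently, the forest defines similarity weights over the training observations, the view that GRF generalizes below.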
GRF extends causal forests to a unified framework for local moment estimation, enabling estimation of various causal quantities beyond simple treatment effects.
GRF Framework (Athey, Tibshirani, Wager 2019):
Instead of targeting a specific estimand, GRF solves local versions of moment equations:
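$$\mathbb{E}\big[\psi_{\theta(x),\,\nu(x)}(O_i) \,\big|\, X_i = x\big] = 0,$$

where $\psi$ is a problem-specific score function, $\theta(x)$ is the target parameter (e.g. a local treatment effect), $\nu(x)$ is an optional nuisance parameter, and the conditioning is operationalized by forest-derived similarity weights $\alpha_i(x)$. Different choices of $\psi$ recover conditional means, quantiles, instrumental-variables estimates, and treatment effects.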
Beyond estimating τ(x), we often want to know which features drive treatment effect heterogeneity. Variable importance measures help prioritize subgroup analyses.
- Split frequency: how often a variable is used for splitting, weighted by the improvement in heterogeneity each split yields.
- Permutation importance: randomly permute a variable and measure the decrease in the forest's ability to detect heterogeneity.
- Effect decomposition: decompose individual-level CATE predictions into feature contributions.
- Best linear projection: project τ̂(X) onto individual features to quantify linear relationships (see the sketch after this list).
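As a minimal sketch of that last measure, assuming tau_hat (forest CATE predictions), X, and feature_names as constructed in the case study below:

from statsmodels.api import OLS, add_constant

# Project predicted CATE onto the features; each coefficient is the best
# linear approximation of how τ(x) moves with that feature. (This is the
# descriptive version; grf's best_linear_projection in R additionally
# uses doubly robust scores.)
blp = OLS(tau_hat, add_constant(X)).fit()
for name, coef, p in zip(feature_names, blp.params[1:], blp.pvalues[1:]):
    print(f"{name}: {coef:+.4f} (p={p:.3f})")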
Let's implement causal forests. The reference implementation is the grf package in R; in Python, econml provides CausalForestDML, which we use below.
| Method | Pros | Cons |
|---|---|---|
| Causal Forest | • Non-parametric, no functional form • Valid CIs • Variable importance • Handles high-dim X | • Slower than meta-learners • Black box (less interpretable) • Needs larger samples |
| DML | • Flexible ML for nuisances • Valid inference • Efficient estimates | • Assumes parametric τ(x) • Requires correct model for θ |
| Meta-learners | • Simple to implement • Fast • Works with any base learner | • No built-in inference • Can be biased (esp. S/T-learner) |
| Matching | • Interpretable • No parametric assumptions | • Curse of dimensionality • Inefficient with many features |
Context: You're a data scientist at DoorDash. The company currently offers a blanket 20% discount to all customers who haven't ordered in 60 days. Marketing wants to optimize this strategy: instead of giving everyone the same discount, can we personalize discount amounts based on who benefits most?
Data: Historical A/B test with 100K lapsed customers:
- Treatment: randomized 20% discount offer (50/50 split)
- Outcome: whether the customer ordered within 30 days
- Covariates: 40 customer features (tenure, order history, app usage, etc.)
Business question: Should we give 20% off to all 2M lapsed customers (costs $X million)? Or can we target only high-CATE customers—those who respond strongly to discounts—to maximize ROI?
Your task: Use Causal Forests to estimate heterogeneous treatment effects τ(x), identify which customer segments benefit most from discounts, and design a personalized targeting policy that maximizes incremental orders per dollar spent.
The Heterogeneity Hypothesis:
Not all customers respond equally to discounts. Likely variation by:
- Recency: days since last order (retrievable vs. long-churned)
- Usage intensity: total orders to date
- Price sensitivity: average order value
- Channel and context: app vs. web usage, urban vs. suburban

Why Causal Forests are Perfect for This:
- Non-parametric: no functional form assumed for how effects vary across 40 features
- Valid confidence intervals, so targeting thresholds rest on inference, not point estimates alone
- Built-in variable importance to surface which customer traits drive the heterogeneity
- Handles high-dimensional X without hand-specifying interaction terms

Alternative Approaches (and why they're worse):
- Meta-learners (S/T/X): simple and fast, but no built-in inference and prone to bias
- DML: valid inference, but assumes a parametric form for τ(x)
- Matching: interpretable, but the curse of dimensionality bites with 40 features
Step 2a: Data Preparation
import pandas as pd
import numpy as np
from econml.dml import CausalForestDML
# Load A/B test data
df = pd.read_csv('lapsed_customers_experiment.csv')
# Treatment: binary (1 = got discount, 0 = control)
W = df['discount_received'].values
# Outcome: binary (1 = ordered, 0 = didn't order)
Y = df['ordered_within_30days'].values
# Covariates: 40 customer features (name the columns once so we can
# reuse them for importance rankings and trees later)
feature_names = ['tenure_days', 'total_orders', 'avg_order_value', 'days_since_last_order',
                 'favorite_cuisine_diversity', 'urban', 'app_user', 'support_tickets',
                 'previous_discounts_used', ...]  # 40 total
X = df[feature_names].values
print(f"Sample size: {len(Y)}")
print(f"Treatment: {W.sum()}/{len(W)} ({W.mean()*100:.1f}% treated)")
print(f"Outcome (control): {Y[W==0].mean()*100:.1f}%")
print(f"Outcome (treated): {Y[W==1].mean()*100:.1f}%")
print(f"Naive ATE: {(Y[W==1].mean() - Y[W==0].mean())*100:.1f} pp")
# Example output:
# Sample size: 100000
# Treatment: 50000/100000 (50.0% treated)
# Outcome (control): 5.2%
# Outcome (treated): 13.4%
# Naive ATE: 8.2 pp
Step 2b: Fit Causal Forest
# Fit causal forest with honest splitting (econml's CausalForestDML,
# imported above; with honest=True each tree's subsample is split in half,
# one half for structure and one for leaf estimates)
cf = CausalForestDML(
    discrete_treatment=True,   # binary discount indicator
    n_estimators=4000,         # More trees = more stable
    min_samples_leaf=50,       # Min obs per leaf (avoid overfitting)
    max_depth=None,            # Let trees grow naturally
    honest=True,               # CRITICAL: use honest trees
    inference=True,            # Standard errors via bootstrap of little bags
    random_state=42
)
# Fit on data (econml's signature is fit(Y, T, X=...); our treatment is W)
cf.fit(Y, W, X=X)
# Estimate CATE and standard errors for each observation
tau_hat = cf.effect(X)                       # Individual treatment effects τ(x_i)
tau_stderr = cf.effect_inference(X).stderr   # Pointwise standard errors
print(f"\nCaTE Statistics:")
print(f" Mean CATE: {tau_hat.mean()*100:.2f} pp")
print(f" Median CATE: {np.median(tau_hat)*100:.2f} pp")
print(f" Min CATE: {tau_hat.min()*100:.2f} pp")
print(f" Max CATE: {tau_hat.max()*100:.2f} pp")
print(f" Std CATE: {tau_hat.std()*100:.2f} pp")
# Example:
# Mean CATE: 8.1 pp (≈ ATE, good sign)
# Median CATE: 7.3 pp
# Min CATE: -2.1 pp (some customers negatively affected!)
# Max CATE: 24.5 pp (high responders)
# Std CATE: 5.8 pp (substantial heterogeneity!)
Key Parameters Explained:
- n_estimators=4000: more trees reduce Monte Carlo noise in both the point estimates and the standard errors.
- min_samples_leaf=50: larger leaves trade a little bias for much more stable within-leaf treated/control contrasts.
- honest=True: disjoint data for split selection vs. leaf estimation; this is what makes the confidence intervals valid.
- inference=True: computes standard errors, at extra computational cost.
Step 2c: Validate Causal Forest Fit
# Check 1: Mean CATE ≈ ATE from naive comparison
ate_naive = Y[W==1].mean() - Y[W==0].mean()
ate_cf = tau_hat.mean()
print(f"ATE (naive): {ate_naive*100:.2f} pp")
print(f"ATE (causal forest): {ate_cf*100:.2f} pp")
print(f"Difference: {abs(ate_naive - ate_cf)*100:.2f} pp")
# Should be close (< 1pp difference)
# Check 2: Fit quality. econml's DML estimators expose score(Y, T, X=...),
# the MSE of the final residual-on-residual regression; lower is better.
# (The R grf package reports an analogous debiased out-of-bag error.)
print(f"\nDML score (final-stage MSE): {cf.score(Y, W, X=X):.4f}")
# Compare against a constant-effect baseline to gauge the value of heterogeneity
# Check 3: Variable importance
var_importance = cf.feature_importances_
top_features = np.argsort(var_importance)[-10:][::-1]
print("\nTop 10 features driving heterogeneity:")
for idx in top_features:
print(f" {feature_names[idx]}: {var_importance[idx]:.4f}")
# Example output:
# 1. days_since_last_order: 0.18
# 2. total_orders: 0.14
# 3. avg_order_value: 0.11
# 4. tenure_days: 0.09
# 5. urban: 0.07
Analysis 1: Quantile Analysis—Distribution of Treatment Effects
# Split customers into quintiles by CATE
cate_quintiles = pd.qcut(tau_hat, q=5, labels=['Q1 (Lowest)', 'Q2', 'Q3', 'Q4', 'Q5 (Highest)'])
for q in ['Q1 (Lowest)', 'Q2', 'Q3', 'Q4', 'Q5 (Highest)']:
mask = (cate_quintiles == q)
print(f"\n{q}:")
print(f" CATE range: [{tau_hat[mask].min()*100:.1f}pp, {tau_hat[mask].max()*100:.1f}pp]")
print(f" Avg CATE: {tau_hat[mask].mean()*100:.1f}pp")
print(f" Share of total lift: {tau_hat[mask].sum()/tau_hat.sum()*100:.1f}%")
# Example output:
# Q5 (Highest): CATE range: [15.2pp, 24.5pp], Avg: 18.9pp
# → Share of total lift: 42% from top 20% of customers!
# Q1 (Lowest): CATE range: [-2.1pp, 2.8pp], Avg: 0.9pp
# → Almost no response—giving discounts is wasteful here
Key insight: Top quintile drives 42% of total incremental orders while being only 20% of customers → massive targeting opportunity!
Analysis 2: Best Linear Projection (BLP) Test for Heterogeneity
# Test whether the forest found real heterogeneity: regress a noisy but
# unbiased per-unit effect proxy on the predicted CATE. In a randomized
# experiment with treatment probability e, the IPW-transformed outcome
#   psi_i = Y_i * (W_i - e) / (e * (1 - e))
# satisfies E[psi | X] = τ(X), so if significant heterogeneity exists the
# slope on predicted CATE should be significantly different from 0
from statsmodels.api import OLS, add_constant
e = W.mean()                        # ≈ 0.5 in this 50/50 experiment
psi = Y * (W - e) / (e * (1 - e))   # per-unit treatment effect proxy
# BLP regression: psi ~ intercept + centered CATE
model = OLS(psi, add_constant(tau_hat - tau_hat.mean())).fit()
print("\nBest Linear Projection Test:")
print(f"  Intercept: {model.params[0]:.4f} (p={model.pvalues[0]:.4f})")  # ≈ ATE
print(f"  Slope: {model.params[1]:.4f} (p={model.pvalues[1]:.4f})")
# Interpretation:
#   Slope ≈ 1      → CATE predictions are well-calibrated
#   p-value < 0.05 → significant heterogeneity detected
#   (R² is not meaningful here: psi is individually very noisy, so even a
#    well-calibrated model explains little of its variance)
# Example:
#   Slope: 0.89 (p < 0.001) → significant heterogeneity!
Analysis 3: Partial Dependence Plots—Which Features Drive Heterogeneity?
import matplotlib.pyplot as plt
# Manual partial dependence for the CATE surface: vary one feature over a
# grid while holding the others at their observed values, then average the
# forest's predictions. (sklearn's partial_dependence utilities expect
# sklearn estimators, so we compute the curves directly.)
features_to_plot = ['days_since_last_order', 'total_orders',
                    'avg_order_value', 'tenure_days']
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
for idx, feature in enumerate(features_to_plot):
    ax = axes[idx // 2, idx % 2]
    col = feature_names.index(feature)
    grid = np.percentile(X[:, col], np.linspace(5, 95, 20))
    curve = []
    for v in grid:
        X_mod = X.copy()
        X_mod[:, col] = v
        curve.append(cf.effect(X_mod).mean())
    ax.plot(grid, np.array(curve) * 100)
    ax.set_xlabel(feature)
    ax.set_ylabel('CATE (pp)')
    ax.set_title(f'CATE vs {feature}')
    ax.grid(alpha=0.3)
plt.tight_layout()
# Example findings from PD plots:
# - days_since_last_order: CATE peaks at 60-90 days, drops for >120 days
# → These customers are "retrievable," very long-churned customers less so
# - total_orders: CATE highest for moderate users (10-30 orders),
# lower for very light (<5) and heavy (>50) users
# - avg_order_value: Negative relationship—high AOV customers don't need discount
# - tenure_days: U-shaped—new customers and very old customers respond less
Analysis 4: Policy Tree—Interpretable Segmentation
from sklearn.tree import DecisionTreeRegressor, plot_tree
# Fit decision tree on CATE to create interpretable rules
policy_tree = DecisionTreeRegressor(max_depth=3, min_samples_leaf=1000)
policy_tree.fit(X, tau_hat)
# Visualize
plt.figure(figsize=(20, 10))
plot_tree(policy_tree, feature_names=feature_names, filled=True)
plt.title('Policy Tree: Which customers have high CATE?')
# Example tree:
# Root: days_since_last_order < 90?
# Yes → total_orders < 20?
# Yes → CATE = 4.2pp (low responders)
# No → CATE = 16.8pp (HIGH responders)
# No → CATE = 2.1pp (long-churned, low response)
# Create targeting segments from the tree's leaves
segments = policy_tree.apply(X)
for seg in np.unique(segments):
    mask = (segments == seg)
    print(f"\nSegment {seg}: n={mask.sum()}")
    print(f"  Avg CATE: {tau_hat[mask].mean()*100:.1f}pp")
# Print the split rules behind each segment
from sklearn.tree import export_text
print(export_text(policy_tree, feature_names=feature_names))
Targeting Strategy: Threshold-Based
# Simulate different targeting policies: Give discount to customers with CATE > threshold
discount_cost = 5 # $5 cost per discount given
revenue_per_order = 8 # $8 revenue per incremental order
thresholds = np.linspace(0, 0.20, 21) # 0% to 20% CATE thresholds
results = []
for threshold in thresholds:
# Target customers with CATE > threshold
target_mask = (tau_hat > threshold)
n_targeted = target_mask.sum()
# Expected incremental orders
incremental_orders = tau_hat[target_mask].sum()
# Cost vs benefit
total_cost = n_targeted * discount_cost
total_revenue = incremental_orders * revenue_per_order
net_value = total_revenue - total_cost
roi = net_value / total_cost if total_cost > 0 else 0
results.append({
'threshold': threshold * 100,
'n_targeted': n_targeted,
'pct_targeted': n_targeted / len(tau_hat) * 100,
'incremental_orders': incremental_orders,
'total_cost': total_cost,
'total_revenue': total_revenue,
'net_value': net_value,
'roi': roi * 100
})
results_df = pd.DataFrame(results)
print(results_df.to_string(index=False))
Example Results:
| Threshold | % Targeted | Net Value | ROI |
|---|---|---|---|
| 0% | 100% | $310K | 31% |
| 5% | 65% | $420K | 65% |
| 10% | 35% | $490K ⭐ | 140% |
| 15% | 15% | $380K | 250% |
| 20% | 5% | $180K | 360% |
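These numbers are easier to scan plotted; a small optional view of results_df (matplotlib was imported in the partial-dependence step):

# Net value peaks at an interior threshold while ROI keeps rising
fig, ax1 = plt.subplots(figsize=(8, 5))
ax1.plot(results_df['threshold'], results_df['net_value'], marker='o')
ax1.set_xlabel('CATE threshold (pp)')
ax1.set_ylabel('Net value ($)')
ax2 = ax1.twinx()
ax2.plot(results_df['threshold'], results_df['roi'], linestyle='--', color='gray')
ax2.set_ylabel('ROI (%)')
plt.title('Policy simulation: net value and ROI by targeting threshold')
plt.tight_layout()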
Optimal policy: Target customers with CATE > 10% → 35% of customers, $490K net value (58% more than blanket policy!)
Recommendation to Leadership:
Current Policy (Blanket 20% off):
- Discounts all 2M lapsed customers: ~$10M in discount costs (at $5 per discount)
- In the policy simulation: $310K net value, 31% ROI
Proposed Policy (Personalized Targeting at CATE > 10%):
✓ Save $6.5M in discount costs
✓ Increase net value by $180K (58% improvement)
✓ Improve ROI from 31% → 140%
Validation 1: Rank-Weighted Average Treatment Effect (RATE)
# RATE-style check: do the customers the forest ranks higher actually show
# larger treatment effects? Weight each unit by its CATE rank and compare
# weighted treated/control means. (This is a simple proxy for the formal
# RATE of Yadlowsky et al.; grf's rank_average_treatment_effect implements
# the real thing in R.)
from scipy.stats import rankdata
ranks = rankdata(tau_hat)
w = ranks / ranks.max()   # weights increasing in predicted CATE
rate = (np.average(Y[W == 1], weights=w[W == 1])
        - np.average(Y[W == 0], weights=w[W == 0]))
print(f"Rank-weighted ATE: {rate*100:.2f} pp")
print(f"ATE (unweighted): {tau_hat.mean()*100:.2f} pp")
# If rank-weighted > unweighted → high-CATE customers have higher actual effects ✓
# If rank-weighted ≈ unweighted → no usable heterogeneity (CATE is noise)
# Example: 12.4pp vs 8.1pp → Validation success!
Validation 2: Out-of-Sample Policy Evaluation
# Hold out 20% of data for validation
from sklearn.model_selection import train_test_split
X_train, X_val, Y_train, Y_val, W_train, W_val = train_test_split(
X, Y, W, test_size=0.2, random_state=42
)
# Fit a fresh causal forest on the training split only
cf_oos = CausalForestDML(discrete_treatment=True, n_estimators=4000,
                         min_samples_leaf=50, honest=True,
                         inference=True, random_state=42)
cf_oos.fit(Y_train, W_train, X=X_train)
# Predict CATE on the validation set (out-of-sample!)
tau_hat_val = cf_oos.effect(X_val)
# Evaluate policy on validation set
threshold = 0.10
target_mask_val = (tau_hat_val > threshold)
# Actual difference-in-means among targeted units in the validation set
treated = (W_val == 1)
ate_val_targeted = (Y_val[treated & target_mask_val].mean() -
                    Y_val[~treated & target_mask_val].mean())
print(f"\nOut-of-Sample Validation:")
print(f" Predicted CATE (targeted group): {tau_hat_val[target_mask_val].mean()*100:.2f} pp")
print(f" Actual ATE (targeted group): {ate_val_targeted*100:.2f} pp")
print(f" Difference: {abs(tau_hat_val[target_mask_val].mean() - ate_val_targeted)*100:.2f} pp")
# Small difference (<2pp) → good out-of-sample performance ✓
Validation 3: A/B Test Confirmation (Recommended)
Before rolling out personalized targeting to all 2M customers, run a pilot A/B test:
- Arm A: current blanket policy (20% off for everyone)
- Arm B: CATE-based targeting (20% off only where τ̂(x) > 10pp)
- Compare incremental orders and net value per customer across arms
If pilot confirms CATE-based targeting outperforms blanket policy → roll out to full 2M customer base.
Q1: How do you prevent overfitting with so many features (40)?
Answer:
- Honest splitting: splits are chosen on data that never touches the leaf estimates, so noise-chasing splits don't bias the effect estimates.
- Regularization: min_samples_leaf=50 keeps within-leaf treated/control contrasts stable.
- Averaging over 4,000 subsampled trees washes out any single tree's overfit structure.
- Empirical checks: the BLP slope near 1 and the out-of-sample policy evaluation both confirm the CATEs generalize.
Q2: What if customers with high CATE would have ordered anyway (without discount)?
Answer:
This is a critical concern—we want CATE (treatment effect), not baseline propensity to order.
Why Causal Forest solves this:
- The splitting criterion targets the treated-vs-control contrast, not the outcome level: a segment that orders heavily with or without a discount has a high baseline but a CATE near zero.
- Honest leaf estimates are clean differences-in-means on held-out data, so a high predicted CATE reflects incremental orders, not propensity to order.
Sanity check: Plot baseline outcome (control group) vs. CATE. If uncorrelated → good. If correlated → may be confusing propensity with treatment effect (check model specification).
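A minimal sketch of that check; the baseline model choice here is illustrative:

from sklearn.ensemble import GradientBoostingClassifier

# Baseline order propensity, fit on control units only
base = GradientBoostingClassifier().fit(X[W == 0], Y[W == 0])
baseline_prob = base.predict_proba(X)[:, 1]

# Low correlation → the forest captures incremental response,
# not just who was likely to order anyway
plt.scatter(baseline_prob, tau_hat, s=2, alpha=0.3)
plt.xlabel('Baseline P(order | control)')
plt.ylabel('Predicted CATE')
print(f"corr(baseline, CATE) = {np.corrcoef(baseline_prob, tau_hat)[0, 1]:.3f}")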
Q3: Can we use CATE estimates for continuous treatments (e.g., discount amount)?
Answer:
Yes! Causal forests generalize to continuous treatments. Instead of binary (discount/no discount), estimate dose-response function τ(x, d) where d is discount level (0%, 10%, 20%, 30%).
Approach:
- Randomize discount levels (e.g., 0%, 10%, 20%, 30%) rather than a single on/off offer.
- Estimate the dose-response τ(x, d) with a forest that supports continuous or multi-valued treatments (GRF does; a sketch follows this answer).
- For each customer, pick the discount level maximizing expected incremental revenue minus discount cost.
This enables fully personalized discounts (e.g., 8% for customer A, 22% for customer B).
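A hedged sketch with econml's CausalForestDML and a continuous treatment; the discount_pct column is hypothetical, and the treatment_featurizer argument (available in recent econml versions) lets the dose-response curve bend rather than being linear in the dose:

from econml.dml import CausalForestDML
from sklearn.preprocessing import PolynomialFeatures

# Hypothetical continuous treatment: randomized discount level in [0, 0.3]
D = df['discount_pct'].values  # assumed column name

cf_dose = CausalForestDML(
    treatment_featurizer=PolynomialFeatures(degree=2, include_bias=False),
    n_estimators=2000, honest=True, random_state=42,
)
cf_dose.fit(Y, D, X=X)

# Effect of moving a customer from 0% to 10% vs. from 0% to 20% discount
tau_10 = cf_dose.effect(X, T0=0.0, T1=0.10)
tau_20 = cf_dose.effect(X, T0=0.0, T1=0.20)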
Q4: How do you communicate CATE findings to non-technical stakeholders?
Answer:
- Lead with segments, not coefficients: translate the policy tree into plain rules ("customers lapsed under 90 days with 20+ prior orders respond at ~17pp, vs. ~2-4pp elsewhere").
- Anchor on business numbers stakeholders already track: the top quintile drives 42% of the lift; targeting at CATE > 10pp moves net value from $310K to $490K and ROI from 31% to 140%.
- Present uncertainty as pilot risk, and position the confirmation A/B test as the go/no-go gate before the 2M-customer rollout.