1. Introduction: Why Deep Learning for Causality?
So far, we've covered classical methods (matching, regression, IV) and tree-based causal ML (causal forests, meta-learners). But what if:
- Your covariates are high-dimensional (images, text, complex behavioral data)?
- Treatment effects are highly nonlinear in the covariates?
- You want to explicitly balance treatment groups in a learned representation?
Deep learning offers a powerful framework for learning flexible, data-driven representations that can capture complex patterns while enforcing balance between treatment groups.
Core Insight:
Neural networks can learn shared representations of covariates that are informative for outcomes but balanced across treatment groups, reducing selection bias while maintaining predictive power.
2. Representation Learning for Causal Inference
2.1 Key Idea: Balanced Representations
The goal is to learn a low-dimensional representation Φ(X) of high-dimensional covariates X such that:
- Predictive: Φ(X) captures information relevant for predicting outcomes Y(0) and Y(1)
- Balanced: The distribution of Φ(X) is similar across treatment groups (treated vs control)
- Compact: Φ(X) has lower dimension than raw X, facilitating estimation
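If the representation is both predictive and balanced, we can estimate treatment effects by comparing predicted outcomes within similar regions of the representation space, with much less bias from covariate imbalance. This still relies on unconfoundedness: balancing Φ(X) cannot correct for confounders that are missing from X.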
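One simple way to check whether features (raw X or a learned Φ(X)) are balanced is the standardized mean difference between treatment groups. The sketch below is a diagnostic added for illustration, not part of any specific method in these notes; the function name `standardized_mean_difference` and the simulated data are assumptions.

```python
import torch

def standardized_mean_difference(features, w):
    """Per-dimension standardized mean difference between treated and control rows."""
    treated = features[w == 1]
    control = features[w == 0]
    mean_diff = treated.mean(dim=0) - control.mean(dim=0)
    # Pool the two group variances so the difference is on a unit-free scale
    pooled_std = torch.sqrt((treated.var(dim=0) + control.var(dim=0)) / 2)
    return mean_diff / pooled_std.clamp_min(1e-8)

# Hypothetical data: treatment is more likely when the first covariate is large
torch.manual_seed(0)
x = torch.randn(2000, 3)
w = torch.bernoulli(torch.sigmoid(2.0 * x[:, 0]))

smd = standardized_mean_difference(x, w)
print(smd)  # the first entry should be much larger in magnitude than the other two
```

Applying the same function to Φ(X) before and after training gives a rough picture of how much imbalance the representation removes.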
2.2 Notation and Setup
Recall our standard notation:
- Xi: High-dimensional covariates (customer features, behavioral data, etc.)
- Wi: Binary treatment indicator (1 = received promo, 0 = no promo)
- Yi: Observed outcome (purchase or not)
- Φ(Xi): Learned representation (output of shared layers in neural network)
We'll train neural networks with architectures designed to:
- Map X → Φ(X) via shared representation layers
- Predict Y(0) and Y(1) from Φ(X) via separate "heads"
- Enforce balance by penalizing differences in Φ(X) distributions across treatment groups
3. TARNet: Treatment-Agnostic Representation Network
3.1 Architecture
TARNet (Shalit et al., 2017) is a foundational architecture for causal effect estimation with neural networks.
Architecture Components:
- Shared Representation Network: Maps X → Φ(X) using several hidden layers
- Treatment-Specific Heads: Two separate output networks:
  - μ0(Φ(X)) predicts E[Y(0) | X]
  - μ1(Φ(X)) predicts E[Y(1) | X]
During training, we observe only one potential outcome per unit:
- If Wi = 1, we see Yi = Yi(1), so we train μ1 to predict Yi
- If Wi = 0, we see Yi = Yi(0), so we train μ0 to predict Yi
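The loss function is the mean squared error on the factual (observed) outcomes:
$$\mathcal{L}_{\text{factual}} = \frac{1}{n}\sum_{i=1}^{n}\Big(Y_i - \big[W_i\,\mu_1(\Phi(X_i)) + (1 - W_i)\,\mu_0(\Phi(X_i))\big]\Big)^2$$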
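To make the factual selection concrete, here is a small worked example with made-up head outputs for four units. It shows that each unit's error flows only into the head matching its observed treatment; all numbers are hypothetical.

```python
import torch

# Hypothetical head outputs for four units
mu0 = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)  # predictions of Y(0)
mu1 = torch.tensor([1.5, 2.5, 2.0, 6.0], requires_grad=True)  # predictions of Y(1)
w = torch.tensor([1.0, 0.0, 1.0, 0.0])                        # observed treatment
y = torch.tensor([2.0, 2.0, 2.0, 3.0])                        # observed outcome

# Pick the head that matches the observed treatment for each unit
factual_pred = w * mu1 + (1 - w) * mu0      # [1.5, 2.0, 2.0, 4.0]
factual_loss = torch.mean((y - factual_pred) ** 2)
factual_loss.backward()

print(factual_loss.item())  # 0.3125
print(mu1.grad)             # zero for control units (indices 1 and 3)
print(mu0.grad)             # zero for treated units (indices 0 and 2)
```

Because the counterfactual head never receives a gradient for a given unit, its predictions there depend entirely on what the shared representation generalizes from the other group.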
3.2 PyTorch Implementation
import torch
import torch.nn as nn
import torch.optim as optim

class TARNet(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, repr_dim=64):
        super().__init__()
        # Shared representation network
        self.repr_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, repr_dim),
            nn.ReLU()
        )
        # Treatment-specific heads
        self.head_0 = nn.Sequential(
            nn.Linear(repr_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )
        self.head_1 = nn.Sequential(
            nn.Linear(repr_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x, w):
        # Compute shared representation
        phi = self.repr_net(x)
        # Compute predictions for both potential outcomes
        # (squeeze only the last dim so a batch of size 1 keeps its batch axis)
        y0_pred = self.head_0(phi).squeeze(-1)
        y1_pred = self.head_1(phi).squeeze(-1)
        # Select the factual prediction based on observed treatment
        y_pred = w * y1_pred + (1 - w) * y0_pred
        return y_pred, y0_pred, y1_pred

    def predict_ite(self, x):
        """Predict individual treatment effect"""
        with torch.no_grad():
            phi = self.repr_net(x)
            y0_pred = self.head_0(phi).squeeze(-1)
            y1_pred = self.head_1(phi).squeeze(-1)
            return y1_pred - y0_pred

# Training loop (full-batch gradient descent for simplicity)
def train_tarnet(model, X_train, W_train, Y_train, epochs=100, lr=1e-3):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    # as_tensor accepts NumPy arrays as well as existing tensors
    X = torch.as_tensor(X_train, dtype=torch.float32)
    W = torch.as_tensor(W_train, dtype=torch.float32)
    Y = torch.as_tensor(Y_train, dtype=torch.float32)
    for epoch in range(epochs):
        optimizer.zero_grad()
        y_pred, y0_pred, y1_pred = model(X, W)
        # Factual loss (only for observed outcomes)
        loss = criterion(y_pred, Y)
        loss.backward()
        optimizer.step()
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")
    return model

# Example usage
input_dim = 50  # e.g., 50 customer features
model = TARNet(input_dim=input_dim)
# Simulate data (replace with real data). Y_train is pure noise here,
# so the estimated effects below only demonstrate the API.
X_train = torch.randn(1000, input_dim)
W_train = torch.bernoulli(torch.ones(1000) * 0.5)
Y_train = torch.randn(1000)
trained_model = train_tarnet(model, X_train, W_train, Y_train)
# Predict ITE for new customers
X_test = torch.randn(100, input_dim)
ite = trained_model.predict_ite(X_test)
print(f"Average ITE: {ite.mean().item():.3f}")
4. CFR: Counterfactual Regression
4.1 Adding a Balancing Penalty
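TARNet learns predictive representations but doesn't explicitly enforce balance. Counterfactual Regression (CFR) adds a regularization term that penalizes imbalance in the learned representation:
$$\mathcal{L}_{\text{CFR}} = \mathcal{L}_{\text{factual}} + \alpha \cdot \operatorname{disc}\big(\{\Phi(X_i)\}_{i:W_i=1},\ \{\Phi(X_i)\}_{i:W_i=0}\big)$$
where L_factual is the factual prediction loss from Section 3, disc(·) is a discrepancy measure between distributions (e.g., Wasserstein distance, MMD), and α ≥ 0 controls the strength of the balancing penalty.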
4.2 Integral Probability Metrics (IPM)
Common choices for disc(·) include:
- Wasserstein Distance: Optimal transport cost between distributions
- Maximum Mean Discrepancy (MMD): Distance in a reproducing kernel Hilbert space
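For MMD with an RBF kernel, the squared discrepancy is:
$$\mathrm{MMD}^2(P, Q) = \mathbb{E}\big[k(\Phi, \Phi')\big] + \mathbb{E}\big[k(\Psi, \Psi')\big] - 2\,\mathbb{E}\big[k(\Phi, \Psi)\big]$$
where Φ, Φ' ~ P (treated) and Ψ, Ψ' ~ Q (control) are independent draws, and k(u, v) = exp(-‖u - v‖² / (2σ²)) is an RBF kernel with bandwidth σ.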
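As a quick sanity check of how an RBF-kernel MMD behaves, the sketch below estimates it on synthetic samples: once for two samples from the same distribution, and once when the second sample is shifted. The function name `rbf_mmd2`, the sample sizes, and the shift are illustrative assumptions; the estimator is the simple biased (V-statistic) version.

```python
import torch

def rbf_mmd2(a, b, sigma=1.0):
    """Biased (V-statistic) estimate of squared MMD with an RBF kernel."""
    def kernel(u, v):
        return torch.exp(-torch.cdist(u, v, p=2) ** 2 / (2 * sigma ** 2))
    return kernel(a, a).mean() + kernel(b, b).mean() - 2 * kernel(a, b).mean()

torch.manual_seed(0)
p_samples = torch.randn(500, 2)          # stand-in for treated representations
q_same = torch.randn(500, 2)             # control drawn from the same distribution
q_shifted = torch.randn(500, 2) + 2.0    # control with a shifted mean

mmd_same = rbf_mmd2(p_samples, q_same).item()
mmd_shifted = rbf_mmd2(p_samples, q_shifted).item()
print(f"same distribution:    {mmd_same:.4f}")
print(f"shifted distribution: {mmd_shifted:.4f}")
```

The estimate is close to zero when the groups match and grows with the shift. The result depends on the bandwidth σ, so a fixed σ = 1 may need adjusting to the scale of the representation.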
4.3 Implementation
def compute_mmd(phi_treated, phi_control, sigma=1.0):
    """Biased estimate of squared Maximum Mean Discrepancy with an RBF kernel"""
    def rbf_kernel(x, y, sigma):
        dist = torch.cdist(x, y, p=2)
        return torch.exp(-dist**2 / (2 * sigma**2))
    k_tt = rbf_kernel(phi_treated, phi_treated, sigma).mean()
    k_cc = rbf_kernel(phi_control, phi_control, sigma).mean()
    k_tc = rbf_kernel(phi_treated, phi_control, sigma).mean()
    mmd = k_tt + k_cc - 2 * k_tc
    return mmd

class CFR(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, repr_dim=64):
        super().__init__()
        # Same architecture as TARNet
        self.repr_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, repr_dim),
            nn.ReLU()
        )
        self.head_0 = nn.Sequential(
            nn.Linear(repr_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )
        self.head_1 = nn.Sequential(
            nn.Linear(repr_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x, w):
        phi = self.repr_net(x)
        y0_pred = self.head_0(phi).squeeze(-1)
        y1_pred = self.head_1(phi).squeeze(-1)
        y_pred = w * y1_pred + (1 - w) * y0_pred
        # Also return phi so the training loop can penalize imbalance
        return y_pred, y0_pred, y1_pred, phi

def train_cfr(model, X_train, W_train, Y_train, alpha=1.0, epochs=100, lr=1e-3):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    X = torch.as_tensor(X_train, dtype=torch.float32)
    W = torch.as_tensor(W_train, dtype=torch.float32)
    Y = torch.as_tensor(Y_train, dtype=torch.float32)
    for epoch in range(epochs):
        optimizer.zero_grad()
        y_pred, y0_pred, y1_pred, phi = model(X, W)
        # Factual loss
        loss_factual = criterion(y_pred, Y)
        # Balancing penalty (MMD)
        phi_treated = phi[W == 1]
        phi_control = phi[W == 0]
        if len(phi_treated) > 0 and len(phi_control) > 0:
            loss_balance = compute_mmd(phi_treated, phi_control)
        else:
            # One group is empty, so there is nothing to balance
            loss_balance = torch.tensor(0.0)
        # Total loss
        loss = loss_factual + alpha * loss_balance
        loss.backward()
        optimizer.step()
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}, "
                  f"Factual: {loss_factual.item():.4f}, Balance: {loss_balance.item():.4f}")
    return model
5. Adversarial Learning for Balance
5.1 GAN Framework for Balancing
An alternative to explicit distance metrics is to use adversarial training à la GANs.
Adversarial Framework:
- Generator (Representation Network): Learns Φ(X) to fool the discriminator
- Discriminator: Tries to predict W from Φ(X)
- Outcome Predictors: Predict Y(0) and Y(1) from Φ(X)
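If Φ(X) is perfectly balanced, the discriminator can't distinguish treated from control units, so its accuracy falls to chance level (50% when the two groups are the same size).
The representation network and outcome heads minimize:
$$\mathcal{L} = \mathcal{L}_{\text{outcome}} - \lambda\,\mathcal{L}_{\text{discriminator}}$$
while the discriminator itself is trained to minimize L_discriminator. The negative sign on L_discriminator means the representation network is incentivized to produce Φ(X) that makes the discriminator perform poorly (i.e., can't predict treatment).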
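The sign flip in the adversarial objective can also be implemented with a gradient reversal layer, a technique from the domain-adaptation literature. This is offered as an alternative sketch, not as the approach these notes prescribe. The layer is the identity in the forward pass and multiplies gradients by -λ in the backward pass, so a single optimizer can minimize the outcome loss plus the discriminator loss while the representation still receives the reversed signal. The names `GradientReversal`, `grad_reverse`, and `lambd` are illustrative.

```python
import torch

class GradientReversal(torch.autograd.Function):
    """Identity in the forward pass; scales gradients by -lambd in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # No gradient is needed for lambd, hence the trailing None
        return -ctx.lambd * grad_output, None

def grad_reverse(x, lambd=1.0):
    return GradientReversal.apply(x, lambd)

# Tiny check: the gradient of sum(x) is +1 per entry, so after reversal
# with lambd=0.5 every entry of phi.grad is -0.5
phi = torch.ones(3, requires_grad=True)
out = grad_reverse(phi, lambd=0.5).sum()
out.backward()
print(phi.grad)
```

In a model, the discriminator would be applied to `grad_reverse(phi, lambd)` instead of `phi`: the discriminator's own weights then learn to predict treatment, while the layers that produce `phi` are pushed in the opposite direction.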
5.2 Implementation
class AdversarialCausalNet(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, repr_dim=64):
        super().__init__()
        # Representation network
        self.repr_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, repr_dim),
            nn.ReLU()
        )
        # Outcome heads
        self.head_0 = nn.Linear(repr_dim, 1)
        self.head_1 = nn.Linear(repr_dim, 1)
        # Discriminator (tries to predict treatment from representation)
        self.discriminator = nn.Sequential(
            nn.Linear(repr_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

    def forward(self, x, w=None):
        phi = self.repr_net(x)
        y0_pred = self.head_0(phi).squeeze(-1)
        y1_pred = self.head_1(phi).squeeze(-1)
        if w is not None:
            y_pred = w * y1_pred + (1 - w) * y0_pred
        else:
            y_pred = None
        w_pred = self.discriminator(phi).squeeze(-1)
        return y_pred, y0_pred, y1_pred, w_pred, phi

def train_adversarial(model, X_train, W_train, Y_train,
                      lambda_adv=0.5, epochs=100, lr=1e-3):
    # Two optimizers: the discriminator minimizes its loss, while the
    # representation and outcome heads try to increase it. A single optimizer
    # over all parameters would also train the discriminator to fail.
    main_params = (list(model.repr_net.parameters())
                   + list(model.head_0.parameters())
                   + list(model.head_1.parameters()))
    opt_main = optim.Adam(main_params, lr=lr)
    opt_disc = optim.Adam(model.discriminator.parameters(), lr=lr)
    criterion_outcome = nn.MSELoss()
    criterion_disc = nn.BCELoss()
    X = torch.as_tensor(X_train, dtype=torch.float32)
    W = torch.as_tensor(W_train, dtype=torch.float32)
    Y = torch.as_tensor(Y_train, dtype=torch.float32)
    for epoch in range(epochs):
        # Step 1: train the discriminator on a detached representation
        opt_disc.zero_grad()
        phi_fixed = model.repr_net(X).detach()
        loss_disc = criterion_disc(model.discriminator(phi_fixed).squeeze(-1), W)
        loss_disc.backward()
        opt_disc.step()
        # Step 2: train representation and heads to predict outcomes
        # while fooling the (just updated) discriminator
        opt_main.zero_grad()
        y_pred, y0_pred, y1_pred, w_pred, phi = model(X, W)
        # Outcome loss
        loss_outcome = criterion_outcome(y_pred, Y)
        # Discriminator loss as seen by the representation (we want it to be high)
        loss_adv = criterion_disc(w_pred, W)
        # Total loss (negative sign encourages representations that fool discriminator)
        loss = loss_outcome - lambda_adv * loss_adv
        loss.backward()
        opt_main.step()
        if (epoch + 1) % 20 == 0:
            disc_acc = ((w_pred > 0.5).float() == W).float().mean()
            print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}, "
                  f"Outcome: {loss_outcome.item():.4f}, Disc Acc: {disc_acc.item():.3f}")
    return model
6. CEVAE: Causal Effect VAE
6.1 VAE Background
Variational Autoencoders (VAEs) learn latent representations Z by encoding X → Z and decoding Z → X, optimizing a variational lower bound (ELBO).
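CEVAE (Louizos et al., 2017) extends VAEs to the causal inference setting by treating Z as a latent confounder and modeling the generative process:
Z ~ p(Z) [latent confounder, standard normal prior]
W ~ p(W | Z) [treatment assignment]
Y ~ p(Y | W, Z) [outcome]
X ~ p(X | Z) [observed covariates, noisy proxies of Z]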
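The generative process can be simulated by ancestral sampling, which also yields data with known ground-truth effects for testing. The sketch below is a toy simulator under assumed functional forms (linear proxies, logistic propensity, additive effect). None of these choices come from the CEVAE paper, and the function name `sample_cevae_style_data` is hypothetical.

```python
import torch

def sample_cevae_style_data(n, x_dim=5, z_dim=2, seed=0):
    """Draw (X, W, Y) from a toy latent-confounder model and return the true ITEs."""
    g = torch.Generator().manual_seed(seed)
    z = torch.randn(n, z_dim, generator=g)                  # Z ~ p(Z)
    loadings = torch.randn(z_dim, x_dim, generator=g)
    x = z @ loadings + 0.5 * torch.randn(n, x_dim, generator=g)  # X ~ p(X | Z)
    propensity = torch.sigmoid(z[:, 0])
    w = torch.bernoulli(propensity, generator=g)            # W ~ p(W | Z)
    tau = 1.0 + z[:, 1]                                     # true individual effect
    y0 = z[:, 0] + 0.1 * torch.randn(n, generator=g)
    y1 = y0 + tau
    y = w * y1 + (1 - w) * y0                               # Y ~ p(Y | W, Z)
    return x, w, y, tau

x, w, y, tau = sample_cevae_style_data(5000)
naive = y[w == 1].mean() - y[w == 0].mean()
print(f"true ATE:  {tau.mean().item():.3f}")
print(f"naive ATE: {naive.item():.3f}")  # biased upward: z[:, 0] drives both W and Y
```

Because the first latent dimension raises both the treatment probability and the outcome, the naive difference in means overstates the effect, which is the bias a latent-variable model such as CEVAE aims to remove.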
6.2 CEVAE Model Structure
Model Components:
- Inference Network (Encoder): q(Z | X, W, Y) approximates posterior over latent Z
- Generative Model (Decoder):
  - p(X | Z): Reconstructs covariates from latent
  - p(W | Z): Models treatment propensity
  - p(Y | W, Z): Models outcome
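The ELBO objective is:
$$\mathrm{ELBO} = \mathbb{E}_{q(Z \mid X, W, Y)}\big[\log p(X \mid Z) + \log p(W \mid Z) + \log p(Y \mid W, Z)\big] - \mathrm{KL}\big(q(Z \mid X, W, Y)\,\|\,p(Z)\big)$$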
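A single-sample estimate of the negative ELBO can be written as a loss function. The sketch below assumes unit-variance Gaussian likelihoods for X and Y (so the log-likelihoods reduce to squared errors, with constants dropped), a Bernoulli likelihood for W, and a standard normal prior on Z. The function name `negative_elbo` and its argument order are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def negative_elbo(x, w, y, x_recon, w_pred, y_pred, mu, logvar):
    """Negative ELBO per unit, averaged over the batch (constants dropped)."""
    # -log p(x | z) and -log p(y | w, z) under unit-variance Gaussians
    recon_x = 0.5 * ((x - x_recon) ** 2).sum(dim=1)
    recon_y = 0.5 * (y - y_pred) ** 2
    # -log p(w | z) under a Bernoulli likelihood (w_pred is a probability)
    recon_w = F.binary_cross_entropy(w_pred, w, reduction="none")
    # KL(q(z | x, w, y) || N(0, I)) in closed form for a diagonal Gaussian
    kl = -0.5 * (1 + logvar - mu ** 2 - logvar.exp()).sum(dim=1)
    return (recon_x + recon_w + recon_y + kl).mean()

# Check on a hypothetical batch: perfect reconstructions, an uninformative
# propensity of 0.5, and a posterior equal to the prior leave only the
# Bernoulli term, log(2) ≈ 0.6931
torch.manual_seed(0)
x = torch.randn(4, 3)
w = torch.tensor([1.0, 0.0, 1.0, 0.0])
y = torch.randn(4)
loss = negative_elbo(x, w, y, x, torch.full((4,), 0.5), y,
                     torch.zeros(4, 2), torch.zeros(4, 2))
print(f"{loss.item():.4f}")
```

For binary outcomes such as purchase or no purchase, the squared-error term for Y would be replaced by a Bernoulli log-likelihood in the same way as for W.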
6.3 Implementation Sketch
class CEVAE(nn.Module):
    def __init__(self, input_dim, latent_dim=20, hidden_dim=128):
        super().__init__()
        # Encoder q(z | x, w, y)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim + 2, hidden_dim),  # +2 for W and Y
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder p(x | z)
        self.decoder_x = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )
        # p(w | z) - propensity model
        self.decoder_w = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )
        # p(y | w, z) - outcome model
        self.decoder_y = nn.Sequential(
            nn.Linear(latent_dim + 1, hidden_dim // 2),  # +1 for W
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )

    def encode(self, x, w, y):
        inputs = torch.cat([x, w.unsqueeze(1), y.unsqueeze(1)], dim=1)
        h = self.encoder(inputs)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        # Sample z = mu + std * eps so gradients flow through mu and logvar
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, w):
        x_recon = self.decoder_x(z)
        w_pred = self.decoder_w(z).squeeze(-1)
        y_pred = self.decoder_y(torch.cat([z, w.unsqueeze(1)], dim=1)).squeeze(-1)
        return x_recon, w_pred, y_pred

    def forward(self, x, w, y):
        mu, logvar = self.encode(x, w, y)
        z = self.reparameterize(mu, logvar)
        x_recon, w_pred, y_pred = self.decode(z, w)
        return x_recon, w_pred, y_pred, mu, logvar

    def predict_ite(self, x, w, y):
        """Estimate ITE by intervening on treatment"""
        # This sketch needs the factual w and y to infer z. The original CEVAE
        # paper adds auxiliary networks so that new units can be scored from
        # x alone; those are omitted here.
        with torch.no_grad():
            mu, _ = self.encode(x, w, y)
            # Predict under control
            w0 = torch.zeros_like(w)
            _, _, y0_pred = self.decode(mu, w0)
            # Predict under treatment
            w1 = torch.ones_like(w)
            _, _, y1_pred = self.decode(mu, w1)
            return y1_pred - y0_pred

# Note: Full training loop would optimize ELBO with reconstruction + KL terms
# This is a simplified sketch for illustration
7. Practical Considerations
- Hyperparameter Tuning: Balancing penalties (α, λ) require careful tuning via cross-validation
- Network Architecture: Experiment with depth, width, activation functions, and dropout
- Balance-Accuracy Tradeoff: Too much balancing can hurt predictive accuracy; monitor both
- Evaluation: Use held-out data with known ground truth (simulations) or surrogate metrics
- Scalability: Neural methods scale well to large datasets and high-dimensional X
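When simulated data provides the true individual effects, two commonly reported metrics are the root PEHE (precision in estimation of heterogeneous effects) and the absolute error of the average treatment effect. The sketch below uses made-up numbers, and the function names `pehe` and `ate_error` are illustrative.

```python
import torch

def pehe(tau_hat, tau_true):
    """Root mean squared error between estimated and true individual effects."""
    return torch.sqrt(torch.mean((tau_hat - tau_true) ** 2))

def ate_error(tau_hat, tau_true):
    """Absolute error of the estimated average treatment effect."""
    return torch.abs(tau_hat.mean() - tau_true.mean())

# Hypothetical true and estimated effects for four units
tau_true = torch.tensor([1.0, 2.0, 3.0, 4.0])
tau_hat = torch.tensor([2.0, 1.0, 4.0, 3.0])

print(pehe(tau_hat, tau_true).item())       # 1.0
print(ate_error(tau_hat, tau_true).item())  # 0.0
```

The example shows why both are worth monitoring: the individual errors cancel in the average, so the ATE looks perfect even though every individual estimate is off by one.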
8. Key Takeaways
- Deep learning enables flexible, nonparametric causal effect estimation with high-dimensional data
- TARNet learns treatment-agnostic representations with separate outcome heads
- CFR adds explicit balancing via IPM (MMD, Wasserstein) to reduce selection bias
- Adversarial methods use GAN-style training to learn balanced representations
- CEVAE models the full causal generative process using variational inference
- These methods are powerful but require careful tuning and validation
9. Next Week Preview
Module 4, Week 2: Deep IV & Causal Discovery
We'll extend deep learning to instrumental variable settings (DeepIV) and explore causal structure learning: how to discover causal graphs from data using neural approaches, constraint-based methods, and score-based optimization.