Module 4, Week 1: Neural Networks for Causal Effects

Article 8 of 13 · 22 min read

📊 Running Example: Promotional Discount Campaign

Continuing with our e-commerce example: Does offering a 20% discount promo code increase customer purchases?

In this module, we'll explore how deep learning can learn rich representations that balance treatment groups while capturing complex patterns—enabling flexible, nonparametric causal estimation with neural networks.

1. Introduction: Why Deep Learning for Causality?

So far, we've covered classical methods (matching, regression, IV) and tree-based causal ML (causal forests, meta-learners). But what if:

  • Your covariates are high-dimensional (images, text, complex behavioral data)?
  • Treatment effects are highly nonlinear in the covariates?
  • You want to explicitly balance treatment groups in a learned representation?

Deep learning offers a powerful framework for learning flexible, data-driven representations that can capture complex patterns while enforcing balance between treatment groups.

Core Insight:

Neural networks can learn shared representations of covariates that are informative for outcomes but balanced across treatment groups, reducing selection bias while maintaining predictive power.

2. Representation Learning for Causal Inference

2.1 Key Idea: Balanced Representations

The goal is to learn a low-dimensional representation Φ(X) of high-dimensional covariates X such that:

  1. Predictive: Φ(X) captures information relevant for predicting outcomes Y(0) and Y(1)
  2. Balanced: The distribution of Φ(X) is similar across treatment groups (treated vs control)
  3. Compact: Φ(X) has lower dimension than raw X, facilitating estimation

If we achieve these goals, we can estimate treatment effects by comparing outcomes within similar regions of the representation space, with far less concern about covariate imbalance.
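
One practical check of the balance property is to compare treated and control units dimension-by-dimension in the learned representation. Below is a minimal diagnostic sketch (the helper representation_smd is illustrative, not part of any architecture in this article); standardized mean differences near zero indicate balance, and values above roughly 0.1 are a common warning sign.

import torch

def representation_smd(phi, w):
    """Standardized mean difference of each dimension of Phi(X)
    between treated (w == 1) and control (w == 0) units."""
    phi_t, phi_c = phi[w == 1], phi[w == 0]
    pooled_sd = torch.sqrt(0.5 * (phi_t.var(dim=0) + phi_c.var(dim=0)))
    return (phi_t.mean(dim=0) - phi_c.mean(dim=0)).abs() / (pooled_sd + 1e-8)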

2.2 Notation and Setup

Recall our standard notation:

  • Xi: High-dimensional covariates (customer features, behavioral data, etc.)
  • Wi: Binary treatment indicator (1 = received promo, 0 = no promo)
  • Yi: Observed outcome (purchase or not)
  • Φ(Xi): Learned representation (output of shared layers in neural network)

We'll train neural networks with architectures designed to:

  • Map X → Φ(X) via shared representation layers
  • Predict Y(0) and Y(1) from Φ(X) via separate "heads"
  • Enforce balance by penalizing differences in Φ(X) distributions across treatment groups

3. TARNet: Treatment-Agnostic Representation Network

3.1 Architecture

TARNet (Shalit et al., 2017) is a foundational architecture for causal effect estimation with neural networks.

Architecture Components:

  1. Shared Representation Network: Maps X → Φ(X) using several hidden layers
  2. Treatment-Specific Heads: Two separate output networks:
    • μ0(Φ(X)) predicts E[Y(0) | X]
    • μ1(Φ(X)) predicts E[Y(1) | X]

During training, we observe only one potential outcome per unit:

  • If Wi = 1, we see Yi = Yi(1), so we train μ1 to predict Yi
  • If Wi = 0, we see Yi = Yi(0), so we train μ0 to predict Yi

The loss function is:

L = Σi [Wi · (Yi - μ1(Φ(Xi)))² + (1 - Wi) · (Yi - μ0(Φ(Xi)))²]

3.2 PyTorch Implementation

import torch
import torch.nn as nn
import torch.optim as optim

class TARNet(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, repr_dim=64):
        super(TARNet, self).__init__()

        # Shared representation network
        self.repr_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, repr_dim),
            nn.ReLU()
        )

        # Treatment-specific heads
        self.head_0 = nn.Sequential(
            nn.Linear(repr_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )

        self.head_1 = nn.Sequential(
            nn.Linear(repr_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x, w):
        # Compute shared representation
        phi = self.repr_net(x)

        # Compute predictions for both potential outcomes
        y0_pred = self.head_0(phi).squeeze()
        y1_pred = self.head_1(phi).squeeze()

        # Return observed outcome based on treatment
        y_pred = w * y1_pred + (1 - w) * y0_pred

        return y_pred, y0_pred, y1_pred

    def predict_ite(self, x):
        """Predict individual treatment effect"""
        with torch.no_grad():
            phi = self.repr_net(x)
            y0_pred = self.head_0(phi).squeeze()
            y1_pred = self.head_1(phi).squeeze()
            return y1_pred - y0_pred

# Training loop
def train_tarnet(model, X_train, W_train, Y_train, epochs=100, lr=1e-3):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    X = torch.as_tensor(X_train, dtype=torch.float32)
    W = torch.as_tensor(W_train, dtype=torch.float32)
    Y = torch.as_tensor(Y_train, dtype=torch.float32)

    for epoch in range(epochs):
        optimizer.zero_grad()

        y_pred, y0_pred, y1_pred = model(X, W)

        # Factual loss (only for observed outcomes)
        loss = criterion(y_pred, Y)

        loss.backward()
        optimizer.step()

        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")

    return model

# Example usage
input_dim = 50  # e.g., 50 customer features
model = TARNet(input_dim=input_dim)

# Simulate data with a known treatment effect (replace with real data)
X_train = torch.randn(1000, input_dim)
W_train = torch.bernoulli(torch.ones(1000) * 0.5)
Y_train = 2.0 * W_train + X_train[:, 0] + 0.5 * torch.randn(1000)  # true ATE = 2

trained_model = train_tarnet(model, X_train, W_train, Y_train)

# Predict ITE for new customers
X_test = torch.randn(100, input_dim)
ite = trained_model.predict_ite(X_test)
print(f"Average ITE: {ite.mean().item():.3f}")

4. CFR: Counterfactual Regression

4.1 Adding a Balancing Penalty

TARNet learns predictive representations but doesn't explicitly enforce balance. Counterfactual Regression (CFR) adds a regularization term that penalizes imbalance in the learned representation:

L_CFR = L_factual + α · disc(Φ(X_treated), Φ(X_control))

Where disc(·) is a discrepancy measure between distributions (e.g., Wasserstein distance, MMD).

4.2 Integral Probability Metrics (IPM)

Common choices for disc(·) include:

  • Wasserstein Distance: Optimal transport cost between distributions (see the sketch after this list)
  • Maximum Mean Discrepancy (MMD): Distance in a reproducing kernel Hilbert space
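
The MMD option is implemented in Section 4.3 below. Exact Wasserstein distances are costly to compute inside a training loop, so approximations are used in practice; one simple option (shown here as an illustrative substitute, not the approximation used in the original CFR paper) is the sliced Wasserstein distance, which averages 1-D transport costs over random projections:

def sliced_wasserstein(phi_treated, phi_control, n_proj=50):
    """Approximate Wasserstein-1 distance via random 1-D projections."""
    d = phi_treated.shape[1]
    proj = torch.randn(d, n_proj)
    proj = proj / proj.norm(dim=0, keepdim=True)  # unit-norm directions
    # Subsample the larger group so sorted samples align one-to-one
    n = min(len(phi_treated), len(phi_control))
    pt = (phi_treated[torch.randperm(len(phi_treated))[:n]] @ proj).sort(dim=0).values
    pc = (phi_control[torch.randperm(len(phi_control))[:n]] @ proj).sort(dim=0).values
    # In 1-D, Wasserstein-1 is the mean absolute gap between sorted samples
    return (pt - pc).abs().mean()

A function like this can be swapped in for compute_mmd in the CFR training loop below.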

For MMD with RBF kernel:

MMD²(P, Q) = E[k(Φ, Φ')] - 2E[k(Φ, Ψ)] + E[k(Ψ, Ψ')]

Where Φ ~ P (treated), Ψ ~ Q (control), and k(·,·) is an RBF kernel.

4.3 Implementation

def compute_mmd(phi_treated, phi_control, sigma=1.0):
    """Compute Maximum Mean Discrepancy with RBF kernel"""
    def rbf_kernel(x, y, sigma):
        dist = torch.cdist(x, y, p=2)
        return torch.exp(-dist**2 / (2 * sigma**2))

    k_tt = rbf_kernel(phi_treated, phi_treated, sigma).mean()
    k_cc = rbf_kernel(phi_control, phi_control, sigma).mean()
    k_tc = rbf_kernel(phi_treated, phi_control, sigma).mean()

    mmd = k_tt + k_cc - 2 * k_tc
    return mmd

class CFR(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, repr_dim=64):
        super(CFR, self).__init__()

        # Same architecture as TARNet
        self.repr_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, repr_dim),
            nn.ReLU()
        )

        self.head_0 = nn.Sequential(
            nn.Linear(repr_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )

        self.head_1 = nn.Sequential(
            nn.Linear(repr_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x, w):
        phi = self.repr_net(x)
        y0_pred = self.head_0(phi).squeeze()
        y1_pred = self.head_1(phi).squeeze()
        y_pred = w * y1_pred + (1 - w) * y0_pred
        return y_pred, y0_pred, y1_pred, phi

def train_cfr(model, X_train, W_train, Y_train, alpha=1.0, epochs=100, lr=1e-3):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    X = torch.as_tensor(X_train, dtype=torch.float32)
    W = torch.as_tensor(W_train, dtype=torch.float32)
    Y = torch.as_tensor(Y_train, dtype=torch.float32)

    for epoch in range(epochs):
        optimizer.zero_grad()

        y_pred, y0_pred, y1_pred, phi = model(X, W)

        # Factual loss
        loss_factual = criterion(y_pred, Y)

        # Balancing penalty (MMD)
        phi_treated = phi[W == 1]
        phi_control = phi[W == 0]

        if len(phi_treated) > 0 and len(phi_control) > 0:
            loss_balance = compute_mmd(phi_treated, phi_control)
        else:
            loss_balance = torch.tensor(0.0)

        # Total loss
        loss = loss_factual + alpha * loss_balance

        loss.backward()
        optimizer.step()

        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}, "
                  f"Factual: {loss_factual.item():.4f}, Balance: {loss_balance.item():.4f}")

    return model

5. Adversarial Learning for Balance

5.1 GAN Framework for Balancing

An alternative to explicit distance metrics is to use adversarial training à la GANs.

Adversarial Framework:

  1. Generator (Representation Network): Learns Φ(X) to fool the discriminator
  2. Discriminator: Tries to predict W from Φ(X)
  3. Outcome Predictors: Predict Y(0) and Y(1) from Φ(X)

If Φ(X) is perfectly balanced, the discriminator can't distinguish treated from control → accuracy = 50%.

L = L_outcome - λ · L_discriminator

The negative sign on L_discriminator rewards representations Φ(X) on which the discriminator performs poorly (i.e., it cannot predict treatment). Note that this is a minimax game: the discriminator itself must still be trained to predict W, via alternating updates with a separate optimizer or a gradient reversal layer. A single joint update would also teach the discriminator to fail on purpose; the implementation below therefore alternates between the two updates.

5.2 Implementation

class AdversarialCausalNet(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, repr_dim=64):
        super(AdversarialCausalNet, self).__init__()

        # Representation network
        self.repr_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, repr_dim),
            nn.ReLU()
        )

        # Outcome heads
        self.head_0 = nn.Linear(repr_dim, 1)
        self.head_1 = nn.Linear(repr_dim, 1)

        # Discriminator (tries to predict treatment from representation)
        self.discriminator = nn.Sequential(
            nn.Linear(repr_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

    def forward(self, x, w=None):
        phi = self.repr_net(x)
        y0_pred = self.head_0(phi).squeeze()
        y1_pred = self.head_1(phi).squeeze()

        if w is not None:
            y_pred = w * y1_pred + (1 - w) * y0_pred
        else:
            y_pred = None

        w_pred = self.discriminator(phi).squeeze()

        return y_pred, y0_pred, y1_pred, w_pred, phi

def train_adversarial(model, X_train, W_train, Y_train,
                      lambda_adv=0.5, epochs=100, lr=1e-3):
    # Two optimizers for the minimax game: the discriminator learns to
    # predict W, while the representation and heads learn to fool it.
    # (A single joint optimizer would also train the discriminator to
    # fail, defeating the purpose of the adversary.)
    gen_params = (list(model.repr_net.parameters()) +
                  list(model.head_0.parameters()) +
                  list(model.head_1.parameters()))
    opt_gen = optim.Adam(gen_params, lr=lr)
    opt_disc = optim.Adam(model.discriminator.parameters(), lr=lr)
    criterion_outcome = nn.MSELoss()
    criterion_disc = nn.BCELoss()

    X = torch.as_tensor(X_train, dtype=torch.float32)
    W = torch.as_tensor(W_train, dtype=torch.float32)
    Y = torch.as_tensor(Y_train, dtype=torch.float32)

    for epoch in range(epochs):
        # Step 1: update the discriminator to predict treatment from Φ(X)
        opt_disc.zero_grad()
        _, _, _, w_pred, _ = model(X, W)
        loss_disc = criterion_disc(w_pred, W)
        loss_disc.backward()
        opt_disc.step()

        # Step 2: update representation + outcome heads; the negative
        # sign rewards representations the discriminator fails on
        opt_gen.zero_grad()
        y_pred, y0_pred, y1_pred, w_pred, phi = model(X, W)
        loss_outcome = criterion_outcome(y_pred, Y)
        loss = loss_outcome - lambda_adv * criterion_disc(w_pred, W)
        loss.backward()
        opt_gen.step()

        if (epoch + 1) % 20 == 0:
            disc_acc = ((w_pred > 0.5).float() == W).float().mean()
            print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}, "
                  f"Outcome: {loss_outcome.item():.4f}, Disc Acc: {disc_acc.item():.3f}")

    return model

6. CEVAE: Causal Effect VAE

6.1 VAE Background

Variational Autoencoders (VAEs) learn latent representations Z by encoding X → Z and decoding Z → X, optimizing a variational lower bound (ELBO).

CEVAE (Louizos et al., 2017) extends VAEs to the causal inference setting by modeling the generative process:

Z ~ p(Z) [latent confounders]
W ~ p(W | Z) [treatment assignment]
Y ~ p(Y | W, Z) [outcome]
X ~ p(X | Z) [observed covariates]
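
To make this generative story concrete, here is a toy simulation of the process (all functional forms and coefficients are arbitrary choices for illustration):

# Toy simulation of the CEVAE generative process
n, latent_dim, x_dim = 1000, 5, 20
Z = torch.randn(n, latent_dim)                      # latent confounders
W = torch.bernoulli(torch.sigmoid(Z.sum(dim=1)))    # W ~ p(W | Z)
Y = 2.0 * W + Z.sum(dim=1) + 0.1 * torch.randn(n)   # Y ~ p(Y | W, Z), true effect 2
X = Z @ torch.randn(latent_dim, x_dim)              # X ~ p(X | Z), noisy proxies
X = X + 0.1 * torch.randn(n, x_dim)

Because Z drives both treatment and outcome, a naive difference in means between treated and control units is biased; CEVAE attempts to recover Z from the observed proxies X.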

6.2 CEVAE Model Structure

Model Components:

  • Inference Network (Encoder): q(Z | X, W, Y) approximates posterior over latent Z
  • Generative Model (Decoder):
    • p(X | Z): Reconstructs covariates from latent
    • p(W | Z): Models treatment propensity
    • p(Y | W, Z): Models outcome

The ELBO objective is:

L = E_q(Z|X,W,Y)[log p(X|Z) + log p(W|Z) + log p(Y|W,Z)] - KL(q(Z|X,W,Y) || p(Z))

6.3 Implementation Sketch

class CEVAE(nn.Module):
    def __init__(self, input_dim, latent_dim=20, hidden_dim=128):
        super(CEVAE, self).__init__()

        # Encoder q(z | x, w, y)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim + 2, hidden_dim),  # +2 for W and Y
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder p(x | z)
        self.decoder_x = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )

        # p(w | z) - propensity model
        self.decoder_w = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

        # p(y | w, z) - outcome model
        self.decoder_y = nn.Sequential(
            nn.Linear(latent_dim + 1, hidden_dim // 2),  # +1 for W
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 1)
        )

    def encode(self, x, w, y):
        inputs = torch.cat([x, w.unsqueeze(1), y.unsqueeze(1)], dim=1)
        h = self.encoder(inputs)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, w):
        x_recon = self.decoder_x(z)
        w_pred = self.decoder_w(z).squeeze()
        y_pred = self.decoder_y(torch.cat([z, w.unsqueeze(1)], dim=1)).squeeze()
        return x_recon, w_pred, y_pred

    def forward(self, x, w, y):
        mu, logvar = self.encode(x, w, y)
        z = self.reparameterize(mu, logvar)
        x_recon, w_pred, y_pred = self.decode(z, w)
        return x_recon, w_pred, y_pred, mu, logvar

    def predict_ite(self, x, w, y):
        """Estimate ITE by intervening on treatment"""
        with torch.no_grad():
            mu, _ = self.encode(x, w, y)

            # Predict under control
            w0 = torch.zeros_like(w)
            _, _, y0_pred = self.decode(mu, w0)

            # Predict under treatment
            w1 = torch.ones_like(w)
            _, _, y1_pred = self.decode(mu, w1)

            return y1_pred - y0_pred

# Note: Full training loop would optimize ELBO with reconstruction + KL terms
# This is a simplified sketch for illustration
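
For completeness, a minimal negative-ELBO loss matching the sketch above could look as follows (assuming Gaussian likelihoods for X and Y and a Bernoulli likelihood for W; this is a simplification of the full CEVAE objective, which also uses auxiliary networks for inference at test time):

import torch.nn.functional as F

def cevae_loss(model, x, w, y):
    """Negative ELBO: reconstruction terms for X, W, Y plus the
    closed-form KL between q(z|x,w,y) and the standard normal prior."""
    x_recon, w_pred, y_pred, mu, logvar = model(x, w, y)
    loss_x = F.mse_loss(x_recon, x)             # ∝ -log p(x|z) for Gaussian p
    loss_w = F.binary_cross_entropy(w_pred, w)  # -log p(w|z)
    loss_y = F.mse_loss(y_pred, y)              # ∝ -log p(y|w,z) for Gaussian p
    kl = -0.5 * torch.mean(
        torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1))
    return loss_x + loss_w + loss_y + kl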

7. Practical Considerations

  • Hyperparameter Tuning: Balancing penalties (α, λ) require careful tuning via cross-validation
  • Network Architecture: Experiment with depth, width, activation functions, and dropout
  • Balance-Accuracy Tradeoff: Too much balancing can hurt predictive accuracy; monitor both
  • Evaluation: Use held-out data with known ground truth (simulations) or surrogate metrics; see the PEHE sketch after this list
  • Scalability: Neural methods scale well to large datasets and high-dimensional X
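
When simulated ground truth is available, a standard metric is the (root) PEHE, Precision in Estimation of Heterogeneous Effects. A minimal sketch, assuming the true per-unit effects are known:

def rpehe(ite_true, ite_pred):
    """Root PEHE: RMSE between true and estimated individual
    treatment effects (computable only in simulations)."""
    return torch.sqrt(torch.mean((ite_true - ite_pred) ** 2))

# e.g., for the TARNet simulation in Section 3.2, the true ITE is 2.0
# for every unit: rpehe(torch.full_like(ite, 2.0), ite)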

8. Key Takeaways

  • Deep learning enables flexible, nonparametric causal effect estimation with high-dimensional data
  • TARNet learns treatment-agnostic representations with separate outcome heads
  • CFR adds explicit balancing via IPM (MMD, Wasserstein) to reduce selection bias
  • Adversarial methods use GANs to learn balanced representations
  • CEVAE models the full causal generative process using variational inference
  • These methods are powerful but require careful tuning and validation

9. Next Week Preview

Module 4, Week 2: Deep IV & Causal Discovery

We'll extend deep learning to instrumental variable settings (DeepIV) and explore causal structure learning—how to discover causal graphs from data using neural approaches, constraint-based methods, and score-based optimization.