import warnings
from collections import defaultdict
import graphviz
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.display import display
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor, export_text, plot_tree
warnings.filterwarnings("ignore")
sns.set_theme(style="whitegrid", context="notebook")
pd.set_option("display.float_format", "{:.3f}".format)
04. Causal Forests
Causal forests are one of the most important tree-based methods for estimating heterogeneous treatment effects. They extend the random forest idea from prediction to causal effect estimation. Each tree leaf contributes treatment-control contrasts among locally similar units.
This notebook builds an educational causal forest from scratch. Production work should still use libraries such as grf, EconML, or causalml. The scratch implementation exposes the mechanics of adaptive neighborhoods, honest splitting, leaf-level treatment effects, forest averaging, and validation by ranked groups.
Learning Goals
By the end of this notebook, you should be able to:
- Explain why a regression forest is not automatically a causal forest.
- Describe causal trees, honesty, adaptive neighborhoods, and forest weights.
- Build a simplified honest causal forest for a randomized experiment.
- Compare a single causal tree, transformed-outcome random forest, T-learner, and honest causal forest.
- Interpret forest predictions, tree-to-tree stability, and feature importance carefully.
- Translate a causal forest score into decile validation and treatment targeting curves.
Dataset and Experiment Setup
This notebook simulates accounts in a business-to-business customer-success setting. The unit of analysis is an account, and the covariates capture account scale, usage depth, support exposure, operational maturity, and other pre-treatment differences. The treatment is a success playbook or targeted support intervention. The outcome is a post-intervention business result such as retained value, renewal probability, or product adoption.
The purpose of the data is to create a realistic CATE problem where the intervention works better for some account profiles than others. A causal forest is useful only if it can find stable heterogeneity without chasing noise. The simulation therefore includes nonlinear response, baseline differences, treatment-effect variation, and nuisance variation. Because the true effect is known, the notebook can evaluate heterogeneity estimates directly. In real deployments, the same diagnostics would rely on out-of-sample validation, honest splits, calibration checks, and prospective experiments.
# Define reusable helpers so the later analytical cells stay focused on the causal argument.
def sigmoid(x):
    """Compute sigmoid for the causal forests example.
    Inputs:
    - `x`: logit-scale numeric input converted into a probability.
    Returns: a NumPy array or scalar of logistic probabilities."""
    return 1 / (1 + np.exp(-x))
def make_dag(edges, title=None, node_colors=None, rankdir="LR"):
    """Construct the dag object.
    Inputs:
    - `edges`: directed or undirected graph edges represented as node pairs.
    - `title`: plot or graph title displayed above the figure.
    - `node_colors`: optional mapping from graph nodes to fill colors.
    - `rankdir`: Graphviz layout direction, such as left-to-right or top-to-bottom.
    Returns: a Graphviz diagram for the causal structure in the section."""
    graph = graphviz.Digraph(format="svg")
    graph.attr(rankdir=rankdir, bgcolor="white", pad="0.2")
    graph.attr("node", shape="box", style="rounded,filled", color="#334155", fontname="Helvetica")
    graph.attr("edge", color="#475569", arrowsize="0.8")
    node_colors = node_colors or {}
    nodes = sorted({node for edge in edges for node in edge[:2]})
    for node in nodes:
        graph.node(node, fillcolor=node_colors.get(node, "#eef2ff"))
    for edge in edges:
        if len(edge) == 2:
            graph.edge(edge[0], edge[1])
        else:
            graph.edge(edge[0], edge[1], label=edge[2])
    if title:
        graph.attr(label=title, labelloc="t", fontsize="18", fontname="Helvetica-Bold")
    return graph
def simulate_success_playbook_experiment(n=9_000, seed=5404):
    """Simulate the success playbook experiment data-generating process.
    Inputs:
    - `n`: number of simulated units or records to generate for the teaching example.
    - `seed`: integer random seed that makes the simulation or model split reproducible.
    Returns: a pandas DataFrame with generated covariates, treatment or policy variables, outcomes, and teaching diagnostics for this scenario."""
    rng = np.random.default_rng(seed)
    risk_score = rng.beta(2.2, 2.7, n)
    engagement_score = rng.beta(2.5, 2.0, n)
    log_account_value = rng.normal(4.25, 0.60, n)
    account_value = np.exp(log_account_value)
    tenure_months = rng.gamma(2.1, 10.0, n)
    enterprise_plan = rng.binomial(1, sigmoid(-0.30 + 0.65 * (log_account_value - 4.25) + 0.45 * risk_score), n)
    integration_count = rng.poisson(np.clip(1.2 + 4.0 * engagement_score + 1.2 * enterprise_plan, 0.2, None))
    support_tickets = rng.poisson(np.clip(0.5 + 3.2 * risk_score - 0.8 * engagement_score + 0.5 * enterprise_plan, 0.1, None))
    treatment = rng.binomial(1, 0.5, n)
    mu0 = (
        80
        + 0.12 * account_value
        + 9.0 * engagement_score
        - 12.0 * risk_score
        + 4.0 * enterprise_plan
        + 0.08 * tenure_months
        + 1.3 * integration_count
        - 1.7 * support_tickets
        + 2.0 * np.sin(2 * engagement_score)
    )
    true_cate = (
        -3.0
        + 18.0 * risk_score * (1 - engagement_score)
        + 4.0 * enterprise_plan
        + 3.0 * (support_tickets >= 4).astype(float)
        - 4.0 * (engagement_score > 0.82).astype(float)
        + 2.0 * np.log1p(integration_count)
        - 2.0 * (tenure_months < 5).astype(float)
    )
    y0 = mu0 + rng.normal(0, 6.5, n)
    y1 = y0 + true_cate
    outcome = np.where(treatment == 1, y1, y0)
    return pd.DataFrame(
        {
            "account_id": np.arange(1, n + 1),
            "risk_score": risk_score,
            "engagement_score": engagement_score,
            "log_account_value": log_account_value,
            "account_value": account_value,
            "tenure_months": tenure_months,
            "enterprise_plan": enterprise_plan,
            "integration_count": integration_count,
            "support_tickets": support_tickets,
            "treatment": treatment,
            "outcome": outcome,
            "mu0": mu0,
            "true_cate": true_cate,
        }
    )
def difference_in_means(frame, outcome="outcome", treatment="treatment"):
    """Compute difference in means for the causal forests example.
    Inputs:
    - `frame`: pandas DataFrame containing the rows for this function and the causal variables named by the other arguments.
    - `outcome`: name of the observed outcome column whose causal contrast is being estimated.
    - `treatment`: name of the treatment-assignment column, usually coded as treated versus control.
    Returns: a tuple (estimate, standard error), or (nan, nan) when either arm has fewer than two rows."""
    treated = frame.loc[frame[treatment] == 1, outcome]
    control = frame.loc[frame[treatment] == 0, outcome]
    if len(treated) < 2 or len(control) < 2:
        return np.nan, np.nan
    estimate = treated.mean() - control.mean()
    se = np.sqrt(treated.var(ddof=1) / len(treated) + control.var(ddof=1) / len(control))
    return estimate, se
def transformed_outcome(frame, treatment_prob=0.5):
    """Compute the transformed outcome target for randomized CATE learning.
    Inputs:
    - `frame`: pandas DataFrame containing the rows for this function and the causal variables named by the other arguments.
    - `treatment_prob`: known treatment probability used to construct transformed outcomes or honest splitting targets.
    Returns: a pandas Series with inverse-propensity-scaled outcomes for treated and control units."""
    return frame["outcome"] * (frame["treatment"] - treatment_prob) / (treatment_prob * (1 - treatment_prob))
def top_fraction_mask(frame, score_col, fraction=0.20, largest=True):
    """Compute top fraction mask for the causal forests example.
    Inputs:
    - `frame`: pandas DataFrame containing the rows for this function and the causal variables named by the other arguments.
    - `score_col`: name of the model-score column used for ranking, targeting, or validation.
    - `fraction`: fraction of units selected by a top-ranked targeting rule.
    - `largest`: whether larger scores should be treated as better when selecting top-ranked units.
    Returns: a Boolean mask or treatment rule selecting the highest-scoring share of units."""
    n_select = max(1, int(np.floor(fraction * len(frame))))
    selected_index = frame[score_col].sort_values(ascending=not largest).head(n_select).index
    mask = pd.Series(False, index=frame.index)
    mask.loc[selected_index] = True
    return mask
def fit_honest_tree(sample, feature_cols, seed=1, max_depth=4, min_samples_leaf=120, treatment_prob=0.5):
    """Fit the honest tree model or estimator.
    Inputs:
    - `sample`: rows used for one honest tree; they are split internally into discovery and estimation halves.
    - `feature_cols`: list of pre-treatment covariate column names used as model features or effect modifiers.
    - `seed`: integer random seed that makes the simulation or model split reproducible.
    - `max_depth`: maximum depth of a tree in the honest-tree or forest estimator.
    - `min_samples_leaf`: minimum number of rows allowed in each tree leaf.
    - `treatment_prob`: known treatment probability used to construct transformed outcomes or honest splitting targets.
    Returns: a dictionary with the fitted discovery tree, honest leaf effects, standard errors, counts, estimation-row indices, and the estimation-sample ATE used as a fallback for empty or degenerate leaves."""
    discovery, estimation = train_test_split(
        sample,
        test_size=0.50,
        random_state=seed,
        stratify=sample["treatment"],
    )
    tree = DecisionTreeRegressor(
        max_depth=max_depth,
        min_samples_leaf=min_samples_leaf,
        random_state=seed,
    )
    tree.fit(discovery[feature_cols], transformed_outcome(discovery, treatment_prob=treatment_prob))
    estimation = estimation.copy()
    estimation["leaf_id"] = tree.apply(estimation[feature_cols])
    leaf_effects = {}
    leaf_se = {}
    leaf_counts = {}
    leaf_indices = {}
    for leaf_id, leaf_frame in estimation.groupby("leaf_id"):
        effect, se = difference_in_means(leaf_frame)
        leaf_effects[leaf_id] = effect
        leaf_se[leaf_id] = se
        leaf_counts[leaf_id] = len(leaf_frame)
        leaf_indices[leaf_id] = leaf_frame.index.to_numpy()
    global_effect, _ = difference_in_means(estimation)
    return {
        "tree": tree,
        "leaf_effects": leaf_effects,
        "leaf_se": leaf_se,
        "leaf_counts": leaf_counts,
        "leaf_indices": leaf_indices,
        "discovery_index": discovery.index.to_numpy(),
        "estimation_index": estimation.index.to_numpy(),
        "global_effect": global_effect,
    }
def predict_honest_tree(tree_bundle, X):
    """Generate honest tree predictions.
    Inputs:
    - `tree_bundle`: dictionary or object containing the fitted honest tree and its estimation metadata.
    - `X`: feature DataFrame whose rows are dropped through the tree.
    Returns: a NumPy array of honest leaf effects, one per row of `X`, using the global fallback for unseen or degenerate leaves."""
    leaves = tree_bundle["tree"].apply(X)
    fallback = tree_bundle["global_effect"]
    predictions = []
    for leaf in leaves:
        value = tree_bundle["leaf_effects"].get(leaf, fallback)
        if pd.isna(value):
            value = fallback
        predictions.append(value)
    return np.asarray(predictions)
def honest_leaf_table(tree_bundle, estimation_frame):
    """Compute honest leaf table for the causal forests example.
    Inputs:
    - `tree_bundle`: dictionary or object containing the fitted honest tree and its estimation metadata.
    - `estimation_frame`: DataFrame reserved for estimating leaf-level treatment effects after tree structure is learned.
    Returns: a pandas table used for reporting, plotting, or display."""
    rows = []
    for leaf_id, effect in tree_bundle["leaf_effects"].items():
        idx = tree_bundle["leaf_indices"][leaf_id]
        leaf_frame = estimation_frame.loc[idx]
        rows.append(
            {
                "leaf_id": leaf_id,
                "n": len(leaf_frame),
                "treatment_share": leaf_frame["treatment"].mean(),
                "estimated_effect": effect,
                "std_error": tree_bundle["leaf_se"].get(leaf_id, np.nan),
                "ci_low": effect - 1.96 * tree_bundle["leaf_se"].get(leaf_id, np.nan),
                "ci_high": effect + 1.96 * tree_bundle["leaf_se"].get(leaf_id, np.nan),
                "true_cate_mean": leaf_frame["true_cate"].mean(),
                "mean_risk": leaf_frame["risk_score"].mean(),
                "mean_engagement": leaf_frame["engagement_score"].mean(),
            }
        )
    return pd.DataFrame(rows).sort_values("estimated_effect").reset_index(drop=True)
def fit_honest_forest(train, feature_cols, n_trees=90, sample_fraction=0.65, seed=100, max_depth=5, min_samples_leaf=80):
    """Fit the honest forest model or estimator.
    Inputs:
    - `train`: training DataFrame used to learn preprocessing, nuisance, or effect-model quantities.
    - `feature_cols`: list of pre-treatment covariate column names used as model features or effect modifiers.
    - `n_trees`: number of honest trees to fit in the forest-style estimator.
    - `sample_fraction`: fraction of rows sampled for each tree in the honest forest.
    - `seed`: integer random seed that makes the simulation or model split reproducible.
    - `max_depth`: maximum depth of a tree in the honest-tree or forest estimator.
    - `min_samples_leaf`: minimum number of rows allowed in each tree leaf.
    Returns: a dictionary with the list of honest tree bundles, the feature names, and the split importance averaged across trees."""
    rng = np.random.default_rng(seed)
    bundles = []
    n_sample = int(sample_fraction * len(train))
    for tree_id in range(n_trees):
        sample_index = rng.choice(train.index.to_numpy(), size=n_sample, replace=False)
        sample = train.loc[sample_index]
        bundle = fit_honest_tree(
            sample,
            feature_cols,
            seed=seed + tree_id,
            max_depth=max_depth,
            min_samples_leaf=min_samples_leaf,
        )
        bundles.append(bundle)
    split_importance = np.vstack([bundle["tree"].feature_importances_ for bundle in bundles]).mean(axis=0)
    return {"bundles": bundles, "feature_cols": feature_cols, "split_importance": split_importance}
def predict_honest_forest(forest, X):
    """Generate honest forest predictions.
    Inputs:
    - `forest`: fitted honest forest or ensemble used for CATE prediction.
    - `X`: feature DataFrame to score; its index is reused for the output.
    Returns: a DataFrame with the forest CATE prediction and tree-to-tree spread (standard deviation, 10th and 90th percentiles) for each row of `X`."""
    tree_predictions = np.column_stack([predict_honest_tree(bundle, X) for bundle in forest["bundles"]])
    return pd.DataFrame(
        {
            "cate_hat": np.nanmean(tree_predictions, axis=1),
            "tree_sd": np.nanstd(tree_predictions, axis=1, ddof=1),
            "tree_q10": np.nanquantile(tree_predictions, 0.10, axis=1),
            "tree_q90": np.nanquantile(tree_predictions, 0.90, axis=1),
        },
        index=X.index,
    )
def forest_neighbor_weights(forest, x_row):
    """Compute forest proximity weights for one target row.
    Inputs:
    - `forest`: fitted honest forest or ensemble used for CATE prediction.
    - `x_row`: single-row feature DataFrame whose forest-neighbor weights are being inspected.
    Returns: a pandas Series of forest weights indexed by training row and sorted from largest to smallest; a row gains 1 / (n_trees * leaf size) each time it shares an estimation leaf with the target row."""
    weights = defaultdict(float)
    n_trees = len(forest["bundles"])
    feature_cols = forest["feature_cols"]
    for bundle in forest["bundles"]:
        leaf_id = bundle["tree"].apply(x_row[feature_cols])[0]
        neighbor_index = bundle["leaf_indices"].get(leaf_id, np.array([]))
        if len(neighbor_index) == 0:
            continue
        per_neighbor_weight = 1 / (n_trees * len(neighbor_index))
        for idx in neighbor_index:
            weights[idx] += per_neighbor_weight
    return pd.Series(weights).sort_values(ascending=False)
def model_metrics(frame, score_col, label, contact_fraction=0.20):
    """Compute model metrics for the causal forests example.
    Inputs:
    - `frame`: pandas DataFrame containing the rows for this function and the causal variables named by the other arguments.
    - `score_col`: name of the model-score column used for ranking, targeting, or validation.
    - `label`: human-readable label attached to the model, policy, or diagnostic result.
    - `contact_fraction`: share of units contacted under a targeting policy.
    Returns: a dictionary of CATE accuracy metrics and top-fraction targeting diagnostics for one score column."""
    selected = top_fraction_mask(frame, score_col, fraction=contact_fraction)
    targeted = frame.loc[selected]
    return {
        "model": label,
        "cate_rmse": np.sqrt(mean_squared_error(frame["true_cate"], frame[score_col])),
        "cate_correlation": np.corrcoef(frame["true_cate"], frame[score_col])[0, 1],
        "mean_predicted_cate": frame[score_col].mean(),
        "true_cate_top20": targeted["true_cate"].mean(),
        "negative_cate_share_top20": (targeted["true_cate"] < 0).mean(),
    }
def decile_validation_table(frame, score_col):
    """Compute decile validation table for the causal forests example.
    Inputs:
    - `frame`: pandas DataFrame containing the rows for this function and the causal variables named by the other arguments.
    - `score_col`: name of the model-score column used for ranking, targeting, or validation.
    Returns: a pandas table used for reporting, plotting, or display."""
    work = frame.sort_values(score_col, ascending=False).reset_index(drop=True).copy()
    work["score_decile"] = pd.qcut(
        np.arange(len(work)),
        q=10,
        labels=[f"D{i} highest" if i == 1 else f"D{i}" for i in range(1, 11)],
    )
    rows = []
    for decile, group in work.groupby("score_decile", observed=True):
        effect, se = difference_in_means(group)
        rows.append(
            {
                "decile": str(decile),
                "n": len(group),
                "treatment_share": group["treatment"].mean(),
                "observed_lift": effect,
                "ci_low": effect - 1.96 * se,
                "ci_high": effect + 1.96 * se,
                "predicted_cate_mean": group[score_col].mean(),
                "true_cate_mean": group["true_cate"].mean(),
                "tree_sd_mean": group["forest_tree_sd"].mean() if "forest_tree_sd" in group else np.nan,
            }
        )
    return pd.DataFrame(rows)
1. Regression Forests Versus Causal Forests
A random forest for prediction estimates an outcome surface:
\[ \hat{m}(x) \approx E[Y \mid X=x] \]
A causal forest estimates a treatment-effect surface:
\[ \hat{\tau}(x) \approx E[Y(1)-Y(0) \mid X=x] \]
Those are different targets. A feature that strongly predicts revenue may not modify the treatment effect. A feature that barely predicts revenue may be crucial for treatment responsiveness.
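A small simulation shows the difference between the two targets. The setup is hypothetical and separate from the playbook data. The feature x_prog shifts the outcome level by 30 but leaves the effect unchanged. The feature x_mod barely moves the outcome level but adds 2 to the effect. For each feature, the code compares the raw outcome gap between its two groups with the gap in subgroup treatment effects.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
x_prog = rng.binomial(1, 0.5, n)   # hypothetical prognostic feature
x_mod = rng.binomial(1, 0.5, n)    # hypothetical effect modifier
w = rng.binomial(1, 0.5, n)        # randomized treatment
tau = 2.0 + 2.0 * x_mod            # the effect depends only on x_mod
y = 100.0 + 30.0 * x_prog + 0.5 * x_mod + tau * w + rng.normal(0, 5, n)


def dim(mask):
    """Treated-minus-control difference in means within a subgroup."""
    return y[mask & (w == 1)].mean() - y[mask & (w == 0)].mean()


# Outcome gap: how strongly each feature predicts the outcome level.
outcome_gap_prog = y[x_prog == 1].mean() - y[x_prog == 0].mean()
outcome_gap_mod = y[x_mod == 1].mean() - y[x_mod == 0].mean()
# Effect gap: how much each feature changes the treatment effect.
effect_gap_prog = dim(x_prog == 1) - dim(x_prog == 0)
effect_gap_mod = dim(x_mod == 1) - dim(x_mod == 0)

print(f"outcome gap  x_prog={outcome_gap_prog:6.2f}  x_mod={outcome_gap_mod:6.2f}")
print(f"effect gap   x_prog={effect_gap_prog:6.2f}  x_mod={effect_gap_mod:6.2f}")
```

A regression forest would put its first split on x_prog, because that is where the outcome varies. A causal split criterion should prefer x_mod.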
Breiman (2001) introduced random forests as ensembles of randomized trees for prediction. Athey and Imbens (2016) adapted tree partitioning to heterogeneous causal effects using honest estimation. Wager and Athey (2018) developed causal forests for heterogeneous treatment-effect estimation. Athey, Tibshirani, and Wager (2019) generalized this idea by viewing forests as adaptive local weighting schemes for solving local moment equations.
make_dag(
edges=[
("Features", "TreeSplits"),
("TreeSplits", "LocalNeighborhoods"),
("Treatment", "LeafEffectEstimates"),
("Outcome", "LeafEffectEstimates"),
("LocalNeighborhoods", "LeafEffectEstimates"),
("LeafEffectEstimates", "ForestAverage"),
("ForestAverage", "CATEPrediction"),
("CATEPrediction", "TargetingPolicy"),
],
title="A causal forest averages many local treatment-control contrasts",
node_colors={
"Features": "#dbeafe",
"TreeSplits": "#fef3c7",
"LocalNeighborhoods": "#dcfce7",
"Treatment": "#fee2e2",
"Outcome": "#f1f5f9",
"LeafEffectEstimates": "#cffafe",
"ForestAverage": "#e0e7ff",
"CATEPrediction": "#ede9fe",
"TargetingPolicy": "#fce7f3",
},
)
The key forest intuition is adaptive nearest neighbors. For a target account with features \(x\), each tree places that account into a leaf. The estimation-sample accounts in the same leaf become local neighbors. The forest prediction averages treatment-effect estimates across many such neighborhoods.
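The sketch below shows the adaptive-neighbor idea. It assumes we already know which leaf each estimation-sample row and the target row occupy in each tree. The leaf IDs are made up for illustration. Each tree spreads \(1/B\) of the total weight evenly across the rows that share the target's leaf, so the weights sum to one whenever every tree's target leaf is non-empty.

```python
import numpy as np


def forest_weights(train_leaves, target_leaves):
    """Adaptive-neighbor weights alpha_i(x) for one target row.

    train_leaves: array of shape (n_rows, n_trees) with each row's leaf ID per tree.
    target_leaves: array of shape (n_trees,) with the target row's leaf ID per tree.
    """
    n_rows, n_trees = train_leaves.shape
    alpha = np.zeros(n_rows)
    for b in range(n_trees):
        same_leaf = train_leaves[:, b] == target_leaves[b]
        if same_leaf.any():
            # Tree b spreads 1/B of the total weight evenly over its leaf neighbors.
            alpha[same_leaf] += 1.0 / (n_trees * same_leaf.sum())
    return alpha


# Hypothetical leaf assignments: 6 training rows, 3 trees.
train_leaves = np.array([
    [1, 4, 7],
    [1, 4, 8],
    [2, 4, 7],
    [2, 5, 7],
    [1, 5, 8],
    [2, 5, 8],
])
target_leaves = np.array([1, 4, 7])
alpha = forest_weights(train_leaves, target_leaves)
print(np.round(alpha, 3))
print(alpha.sum())
```

Row 0 shares the target's leaf in every tree and gets the largest weight. Row 5 never does and gets zero. The notebook's forest_neighbor_weights helper applies the same rule to the estimation rows of each honest tree.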
2. Honesty
A causal tree can overstate heterogeneity if the same data are used both to find splits and to estimate leaf effects. Honesty addresses this by splitting the sample into two roles:
- Discovery sample: choose the tree structure.
- Estimation sample: estimate treatment effects within the discovered leaves.
This separation is important because treatment effects are never observed at the individual level. If we repeatedly search for subgroups and estimate them on the same data, we can mistake noise for heterogeneity.
In this educational implementation, each tree will:
- Draw a random subsample of the training data.
- Split that subsample into discovery and estimation halves.
- Fit a tree on a transformed outcome in the discovery half.
- Estimate treatment-control differences in the estimation half leaves.
- Predict a new account by dropping it through the tree and using the effect estimate in its leaf.
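A simulation shows what skipping honesty costs. In this hypothetical setup the true effect is zero for everyone. The code searches twenty random binary features for the subgroup with the largest treatment-control gap on the discovery half. Estimating that gap on the same rows gives a clearly positive "effect". Re-estimating the chosen subgroup on the held-out estimation half removes the bias. The selection rule is a deliberately crude stand-in for tree splitting, not the criterion the notebook's trees use.

```python
import numpy as np

rng = np.random.default_rng(7)
n, n_features, n_reps = 2_000, 20, 300


def dim(y, w):
    """Treated-minus-control difference in means."""
    return y[w == 1].mean() - y[w == 0].mean()


adaptive, honest = [], []
for _ in range(n_reps):
    X = rng.binomial(1, 0.5, size=(n, n_features))
    w = rng.binomial(1, 0.5, n)
    y = rng.normal(0, 1, n)              # true effect is zero everywhere
    discovery = np.arange(n) < n // 2    # rows are i.i.d., so this is a random half
    estimation = ~discovery
    # Search: pick the feature whose "== 1" subgroup shows the largest gap on discovery rows.
    gaps = []
    for j in range(n_features):
        mask = discovery & (X[:, j] == 1)
        gaps.append(dim(y[mask], w[mask]))
    best = int(np.argmax(gaps))
    adaptive.append(gaps[best])          # same rows used to search and to estimate
    group = estimation & (X[:, best] == 1)
    honest.append(dim(y[group], w[group]))  # fresh rows used to estimate

print(f"adaptive mean estimate: {np.mean(adaptive):.3f}")
print(f"honest mean estimate:   {np.mean(honest):.3f}")
```

The adaptive estimate averages well above zero even though no subgroup has any effect. The honest estimate of the same selected subgroups averages close to zero. Honest trees apply this discipline to every leaf.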
Mathematical Core: Honest Local Treatment Effects
A causal forest can be viewed as an adaptive local estimator. For a target account with features \(x\), the forest gives each training row a weight \(\alpha_i(x)\) based on how often row \(i\) falls in the same leaves as \(x\). The local treatment effect is then estimated as a weighted treated-control contrast:
\[ \hat{\tau}(x)= \frac{\sum_{i} \alpha_i(x) W_i Y_i}{\sum_{i} \alpha_i(x) W_i} - \frac{\sum_{i} \alpha_i(x)(1-W_i)Y_i}{\sum_{i} \alpha_i(x)(1-W_i)}. \]
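The weighted contrast translates directly into code. This sketch assumes the weights \(\alpha_i(x)\) are already computed, for example from shared leaf membership, and that both arms receive positive total weight. With equal weights it collapses to the ordinary difference in means, which makes a useful sanity check. The scratch forest later in this notebook uses a simpler cousin of this formula. It averages each tree's honest leaf contrast instead of pooling all weighted neighbors into a single contrast.

```python
import numpy as np


def weighted_contrast(alpha, w, y):
    """Forest-weighted treated mean minus forest-weighted control mean."""
    alpha, w, y = (np.asarray(a, dtype=float) for a in (alpha, w, y))
    treated_weight = np.sum(alpha * w)
    control_weight = np.sum(alpha * (1 - w))
    if treated_weight == 0 or control_weight == 0:
        raise ValueError("both arms need positive weight near x")
    treated_mean = np.sum(alpha * w * y) / treated_weight
    control_mean = np.sum(alpha * (1 - w) * y) / control_weight
    return treated_mean - control_mean


w = np.array([1, 0, 1, 0, 1, 0])
y = np.array([12.0, 9.0, 11.0, 8.0, 15.0, 10.0])

# Equal weights reproduce the plain difference in means: 38/3 - 27/3 = 11/3.
print(weighted_contrast(np.full(6, 1 / 6), w, y))
# Concentrating weight on the first four rows changes the local answer to 11.6 - 8.6 = 3.0.
print(weighted_contrast(np.array([0.3, 0.3, 0.2, 0.2, 0.0, 0.0]), w, y))
```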
Honesty separates two tasks. The discovery sample chooses splits that create useful neighborhoods. The estimation sample computes the leaf treatment effects. This matters because the individual effect \(Y_i(1)-Y_i(0)\) is never observed. If the same rows are used to search for subgroups and estimate subgroup effects, the largest-looking leaves can be partly search noise.
In randomized experiments with treatment probability \(e\), the transformed outcome:
\[ Y_i^* = Y_i\frac{W_i-e}{e(1-e)} \]
has conditional mean \(E[Y_i^*\mid X_i=x]=\tau(x)\). That identity explains why transformed-outcome trees can find candidate splits, while honest leaf estimates keep the final causal estimate tied to treatment-control comparisons.
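A Monte Carlo check of the identity follows. It assumes a hypothetical randomized design with \(e=0.5\), a constant effect of 5, and outcomes near 100 like the playbook example.

- The transformed outcome averages to the effect, but its spread is enormous, because with \(e=0.5\) it equals \(\pm 2Y_i\).
- Subtracting a constant from \(Y_i\) before transforming leaves the mean unchanged under randomization, since \(E[W_i-e \mid X_i]=0\), and shrinks the noise sharply.

The centering step is an illustrative variant. The notebook's transformed_outcome helper does not center.

```python
import numpy as np

rng = np.random.default_rng(11)
n, e, tau = 400_000, 0.5, 5.0
y0 = rng.normal(100.0, 6.5, n)
w = rng.binomial(1, e, n)
y = y0 + tau * w


def transform(outcome, w, e):
    """Transformed outcome Y * (W - e) / (e * (1 - e))."""
    return outcome * (w - e) / (e * (1 - e))


y_star = transform(y, w, e)
y_star_centered = transform(y - y.mean(), w, e)  # hypothetical centering step

print(f"mean Y*: {y_star.mean():.2f}  sd Y*: {y_star.std():.1f}")
print(f"mean centered Y*: {y_star_centered.mean():.2f}  sd: {y_star_centered.std():.1f}")
```

Both versions average close to 5. The uncentered transformed outcome has a standard deviation of roughly 200, compared with roughly 13 after centering.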
3. Running Example: Success Playbook Experiment
A SaaS company tests a high-touch success playbook. The playbook includes success-manager outreach, workflow review, and a tailored adoption plan. The outcome is next-quarter account value.
The experiment is randomized with treatment probability 0.5. That lets us focus on the causal forest mechanics without also solving confounding. In observational settings, where treatment is not randomized, causal forests are usually combined with nuisance models for the outcome and the propensity score.
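The sketch below shows one standard nuisance adjustment: a Robinson-style residual-on-residual estimate with cross-fitted nuisance models. It runs on hypothetical confounded data with a single constant effect. grf applies the same centering idea locally inside its forest. This global version, the data-generating process, and the use of RandomForestRegressor with min_samples_leaf=50 as the nuisance learner are all illustrative assumptions. The notebook's own forest does not need this step, because its experiment is randomized.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_predict

rng = np.random.default_rng(3)
n, tau = 5_000, 2.0
x = rng.normal(0, 1, n)
e_x = 1 / (1 + np.exp(-x))     # confounded propensity: higher x means more treatment
w = rng.binomial(1, e_x)
y = 3.0 * x + tau * w + rng.normal(0, 1, n)
X = x.reshape(-1, 1)

naive = y[w == 1].mean() - y[w == 0].mean()

# Cross-fitted nuisance models: out-of-fold predictions of E[Y|X] and E[W|X].
nuisance = RandomForestRegressor(n_estimators=100, min_samples_leaf=50, random_state=0)
m_hat = cross_val_predict(nuisance, X, y, cv=5)
e_hat = cross_val_predict(nuisance, X, w, cv=5)

# Residual-on-residual contrast (constant-effect version of local centering).
y_res, w_res = y - m_hat, w - e_hat
tau_hat = np.sum(w_res * y_res) / np.sum(w_res ** 2)

print(f"naive difference in means: {naive:.2f}")
print(f"residualized estimate:     {tau_hat:.2f}  (true effect {tau})")
```

The naive contrast absorbs the confounding that runs through x. Regressing outcome residuals on treatment residuals recovers an estimate close to the true effect of 2. Cross-fitting uses out-of-fold predictions, so no row's nuisance prediction comes from a model fit on that same row.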
# Generate the teaching data and keep the causal quantities needed for diagnostics.
df = simulate_success_playbook_experiment()
feature_cols = [
"risk_score",
"engagement_score",
"log_account_value",
"tenure_months",
"enterprise_plan",
"integration_count",
"support_tickets",
]
summary = pd.DataFrame(
{
"quantity": [
"Accounts",
"Treatment share",
"Outcome mean",
"True ATE",
"True CATE standard deviation",
"Share with negative true CATE",
],
"value": [
len(df),
df["treatment"].mean(),
df["outcome"].mean(),
df["true_cate"].mean(),
df["true_cate"].std(),
(df["true_cate"] < 0).mean(),
],
}
)
display(summary.round(3))
display(df.head())
| | quantity | value |
|---|---|---|
| 0 | Accounts | 9000.000 |
| 1 | Treatment share | 0.497 |
| 2 | Outcome mean | 99.703 |
| 3 | True ATE | 5.203 |
| 4 | True CATE standard deviation | 4.212 |
| 5 | Share with negative true CATE | 0.087 |
| | account_id | risk_score | engagement_score | log_account_value | account_value | tenure_months | enterprise_plan | integration_count | support_tickets | treatment | outcome | mu0 | true_cate |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0.350 | 0.391 | 3.914 | 50.109 | 44.304 | 1 | 4 | 0 | 0 | 99.505 | 99.486 | 8.057 |
| 1 | 2 | 0.717 | 0.395 | 4.886 | 132.472 | 11.013 | 1 | 1 | 6 | 1 | 92.603 | 88.245 | 13.198 |
| 2 | 3 | 0.259 | 0.200 | 5.742 | 311.680 | 58.479 | 1 | 3 | 3 | 0 | 114.522 | 124.341 | 7.509 |
| 3 | 4 | 0.790 | 0.849 | 5.076 | 160.122 | 19.194 | 0 | 6 | 4 | 0 | 100.677 | 101.893 | 2.044 |
| 4 | 5 | 0.299 | 0.322 | 4.510 | 90.959 | 32.640 | 1 | 4 | 2 | 0 | 100.047 | 99.845 | 7.863 |
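The summary shows why the average alone is not enough. The true ATE is about 5.2, but the true CATE has a standard deviation of about 4.2, and about 8.7% of accounts have a negative true effect. The forest must recover that spread from outcomes with noise standard deviation 6.5 without inventing heterogeneity. Honest splitting and forest averaging are the two safeguards.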
# Build the visualization for the estimates or diagnostics computed above.
fig, axes = plt.subplots(1, 3, figsize=(14, 4.2))
sns.histplot(df["true_cate"], bins=40, kde=True, color="#0f766e", ax=axes[0])
axes[0].axvline(df["true_cate"].mean(), color="#16a34a", linestyle="--", label="ATE")
axes[0].axvline(0, color="#334155", linestyle=":")
axes[0].set_title("True CATE distribution")
axes[0].set_xlabel("True treatment effect")
axes[0].legend()
sample_plot = df.sample(3000, random_state=12)
sns.scatterplot(
data=sample_plot,
x="risk_score",
y="true_cate",
hue="enterprise_plan",
alpha=0.45,
edgecolor=None,
palette=["#2563eb", "#dc2626"],
ax=axes[1],
)
axes[1].set_title("Risk and plan type modify effects")
axes[1].set_xlabel("Risk score")
axes[1].set_ylabel("True CATE")
axes[1].legend(title="Enterprise")
sns.scatterplot(
data=sample_plot,
x="engagement_score",
y="true_cate",
alpha=0.35,
edgecolor=None,
color="#7c3aed",
ax=axes[2],
)
axes[2].set_title("Very engaged accounts have less room to improve")
axes[2].set_xlabel("Engagement score")
axes[2].set_ylabel("True CATE")
plt.tight_layout()
plt.show()
The treatment effect is nonlinear. High-risk, low-engagement accounts tend to benefit most. Very engaged accounts have less room to improve and can have low or negative incremental value.
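The effect combines an interaction (risk times one minus engagement) with several threshold terms, so a single shallow tree can only approximate it with a few coarse steps. Honest splitting and averaging help stabilize local effect estimates that would be too fragile in a single tree.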
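We can check the claim that high-risk, low-engagement accounts benefit most directly against the simulation's effect formula. The function below copies the true_cate expression from simulate_success_playbook_experiment. The two account profiles use hypothetical values chosen for illustration.

```python
import numpy as np


def true_cate_formula(risk, engagement, enterprise, tickets, integrations, tenure):
    """Same effect formula as simulate_success_playbook_experiment, for one account."""
    return (
        -3.0
        + 18.0 * risk * (1 - engagement)
        + 4.0 * enterprise
        + 3.0 * float(tickets >= 4)
        - 4.0 * float(engagement > 0.82)
        + 2.0 * np.log1p(integrations)
        - 2.0 * float(tenure < 5)
    )


# Hypothetical high-risk, low-engagement enterprise account.
struggling = true_cate_formula(risk=0.9, engagement=0.1, enterprise=1, tickets=5, integrations=2, tenure=30)
# Hypothetical low-risk, highly engaged, new self-serve account.
thriving = true_cate_formula(risk=0.1, engagement=0.9, enterprise=0, tickets=0, integrations=5, tenure=3)
print(f"struggling account: {struggling:.2f}")
print(f"thriving account:   {thriving:.2f}")
```

The struggling account gains about 21 units from the playbook. The thriving account loses about 5, because the high-engagement and short-tenure penalties outweigh its integration bonus.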
4. Baseline ATE and Train-Test Split
The average effect is easy to estimate because treatment is randomized. The harder task is estimating how the effect changes across accounts.
ate_hat, ate_se = difference_in_means(df)
ate_table = pd.DataFrame(
{
"quantity": ["Estimated ATE", "95% CI lower", "95% CI upper", "True ATE"],
"value": [ate_hat, ate_hat - 1.96 * ate_se, ate_hat + 1.96 * ate_se, df["true_cate"].mean()],
}
)
display(ate_table.round(3))
train_df, test_df = train_test_split(df, test_size=0.35, random_state=41, stratify=df["treatment"])
split_table = pd.DataFrame(
{
"sample": ["Train", "Test"],
"n": [len(train_df), len(test_df)],
"treatment_share": [train_df["treatment"].mean(), test_df["treatment"].mean()],
"true_ate": [train_df["true_cate"].mean(), test_df["true_cate"].mean()],
}
)
display(split_table.round(3))
| | quantity | value |
|---|---|---|
| 0 | Estimated ATE | 5.385 |
| 1 | 95% CI lower | 4.903 |
| 2 | 95% CI upper | 5.866 |
| 3 | True ATE | 5.203 |
| | sample | n | treatment_share | true_ate |
|---|---|---|---|---|
| 0 | Train | 5850 | 0.497 | 5.212 |
| 1 | Test | 3150 | 0.497 | 5.186 |
5. A Single Honest Causal Tree
For randomized treatment with probability \(e=0.5\), a transformed outcome is:
\[ Y_i^* = Y_i\frac{W_i-e}{e(1-e)} \]
Its conditional expectation equals the CATE:
\[ E[Y_i^* \mid X_i=x] = \tau(x) \]
We use this transformed outcome only to choose splits in the discovery sample. The final leaf effects are estimated with treatment-control differences in the separate estimation sample.
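A small worked example shows why the final leaf effects come from treatment-control differences rather than from the mean transformed outcome. Take a hypothetical leaf of 200 accounts where 110 happen to be treated. Every treated account has outcome 105 and every control has outcome 100, so the true contrast is 5. The transformed-outcome mean also picks up the chance imbalance in treatment share, multiplied by the outcome level.

```python
import numpy as np

e = 0.5
w = np.array([1] * 110 + [0] * 90)
y = np.where(w == 1, 105.0, 100.0)

y_star = y * (w - e) / (e * (1 - e))   # +210 for treated, -200 for controls
leaf_mean_y_star = y_star.mean()
leaf_dim = y[w == 1].mean() - y[w == 0].mean()

print(leaf_mean_y_star)  # 25.5, i.e. (110 * 210 - 90 * 200) / 200
print(leaf_dim)          # 5.0
```

A five-point swing in treatment share (55% instead of 50%) turns a true effect of 5 into a transformed-outcome mean of 25.5. The leaf-level difference in means adjusts for the realized share automatically.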
single_tree = fit_honest_tree(
train_df,
feature_cols,
seed=705,
max_depth=4,
min_samples_leaf=160,
)
print(export_text(single_tree["tree"], feature_names=feature_cols, decimals=2))
fig, ax = plt.subplots(figsize=(18, 8))
plot_tree(
single_tree["tree"],
feature_names=feature_cols,
filled=True,
rounded=True,
impurity=False,
ax=ax,
)
ax.set_title("Discovery tree fit on transformed outcomes")
plt.tight_layout()
plt.show()
|--- tenure_months <= 7.59
| |--- tenure_months <= 5.42
| | |--- value: [-4.19]
| |--- tenure_months > 5.42
| | |--- value: [-46.88]
|--- tenure_months > 7.59
| |--- tenure_months <= 24.62
| | |--- tenure_months <= 18.02
| | | |--- tenure_months <= 15.99
| | | | |--- value: [10.32]
| | | |--- tenure_months > 15.99
| | | | |--- value: [-19.62]
| | |--- tenure_months > 18.02
| | | |--- tenure_months <= 20.29
| | | | |--- value: [70.82]
| | | |--- tenure_months > 20.29
| | | | |--- value: [17.52]
| |--- tenure_months > 24.62
| | |--- log_account_value <= 4.75
| | | |--- log_account_value <= 4.17
| | | | |--- value: [-13.75]
| | | |--- log_account_value > 4.17
| | | | |--- value: [28.73]
| | |--- log_account_value > 4.75
| | | |--- value: [-30.82]

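The leaf values printed above are discovery-sample means of the transformed outcome, not effect estimates, and they are very noisy. Values such as -46.88 and 70.82 lie far outside the range of true effects in this simulation. With \(e=0.5\), each \(Y_i^*\) equals \(\pm 2Y_i\), so a small chance imbalance in a leaf's treatment share moves the leaf mean by tens of units. The splits also fall on tenure_months and log_account_value rather than on risk_score and engagement_score, which are the strongest true modifiers. The honest leaf estimates in the next cell are the numbers to interpret. The forest is useful only when it finds stable neighborhoods with enough treated and control support.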
# Build the visualization for the estimates or diagnostics computed above.
estimation_frame = train_df.loc[single_tree["estimation_index"]].copy()
leaf_table = honest_leaf_table(single_tree, estimation_frame)
display(leaf_table.round(3))
fig, ax = plt.subplots(figsize=(9.5, 5))
y = np.arange(len(leaf_table))
ax.errorbar(
leaf_table["estimated_effect"],
y,
xerr=np.vstack([
leaf_table["estimated_effect"] - leaf_table["ci_low"],
leaf_table["ci_high"] - leaf_table["estimated_effect"],
]),
fmt="o",
color="#2563eb",
ecolor="#94a3b8",
capsize=4,
label="Honest leaf estimate",
)
ax.scatter(leaf_table["true_cate_mean"], y, color="#dc2626", marker="D", label="True leaf CATE mean")
ax.axvline(0, color="#334155", linestyle=":")
ax.set_yticks(y)
ax.set_yticklabels([f"Leaf {leaf}" for leaf in leaf_table["leaf_id"]])
ax.set_xlabel("Treatment effect")
ax.set_title("Honest effects in discovered leaves")
ax.legend(loc="best")
plt.tight_layout()
plt.show()
| | leaf_id | n | treatment_share | estimated_effect | std_error | ci_low | ci_high | true_cate_mean | mean_risk | mean_engagement |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2 | 252 | 0.484 | 2.650 | 1.430 | -0.153 | 5.453 | 3.119 | 0.417 | 0.539 |
| 1 | 11 | 338 | 0.497 | 4.755 | 1.193 | 2.416 | 7.094 | 5.773 | 0.469 | 0.554 |
| 2 | 16 | 182 | 0.549 | 4.850 | 1.743 | 1.433 | 8.266 | 5.954 | 0.444 | 0.543 |
| 3 | 3 | 179 | 0.508 | 5.081 | 1.727 | 1.696 | 8.466 | 5.466 | 0.439 | 0.562 |
| 4 | 14 | 415 | 0.501 | 5.884 | 0.896 | 4.127 | 7.641 | 4.799 | 0.457 | 0.564 |
| 5 | 15 | 325 | 0.468 | 6.178 | 1.009 | 4.200 | 8.157 | 5.440 | 0.431 | 0.552 |
| 6 | 8 | 182 | 0.478 | 6.357 | 1.642 | 3.139 | 9.576 | 5.368 | 0.457 | 0.557 |
| 7 | 7 | 872 | 0.495 | 7.462 | 0.777 | 5.940 | 8.985 | 5.340 | 0.442 | 0.558 |
| 8 | 10 | 180 | 0.528 | 7.575 | 1.818 | 4.011 | 11.139 | 5.435 | 0.419 | 0.500 |

A single tree is interpretable, but it is unstable. Small changes in the sample can change the splits. A forest reduces this instability by averaging many randomized honest trees.
6. Honest Causal Forest
The forest below repeats the honest-tree procedure many times. Each tree sees a random subsample, chooses splits using a discovery half, and estimates leaf effects using an estimation half.
For a target account \(x\), tree \(b\) gives a local estimate \(\hat{\tau}_b(x)\). The forest prediction averages over trees:
\[ \hat{\tau}_{forest}(x) = \frac{1}{B}\sum_{b=1}^B \hat{\tau}_b(x) \]
This is an educational implementation. Production causal forests use more sophisticated splitting criteria, nuisance adjustments, variance estimation, and optimized code.
# Organize the calculations for this section into readable intermediate steps.
honest_forest = fit_honest_forest(
train_df,
feature_cols,
n_trees=90,
sample_fraction=0.65,
seed=900,
max_depth=5,
min_samples_leaf=80,
)
forest_predictions = predict_honest_forest(honest_forest, test_df[feature_cols])
test_scored = test_df.copy()
test_scored["cate_honest_tree"] = predict_honest_tree(single_tree, test_scored[feature_cols])
test_scored["cate_honest_forest"] = forest_predictions["cate_hat"]
test_scored["forest_tree_sd"] = forest_predictions["tree_sd"]
test_scored["forest_tree_q10"] = forest_predictions["tree_q10"]
test_scored["forest_tree_q90"] = forest_predictions["tree_q90"]
display(
test_scored[
[
"account_id",
"risk_score",
"engagement_score",
"enterprise_plan",
"true_cate",
"cate_honest_forest",
"forest_tree_sd",
"forest_tree_q10",
"forest_tree_q90",
]
].head(10).round(3)
)
| | account_id | risk_score | engagement_score | enterprise_plan | true_cate | cate_honest_forest | forest_tree_sd | forest_tree_q10 | forest_tree_q90 |
|---|---|---|---|---|---|---|---|---|---|
| 6979 | 6980 | 0.110 | 0.664 | 1 | 4.882 | 4.080 | 2.359 | 0.783 | 6.970 |
| 5929 | 5930 | 0.572 | 0.200 | 0 | 8.007 | 5.985 | 2.193 | 3.248 | 8.523 |
| 1820 | 1821 | 0.880 | 0.476 | 0 | 9.690 | 7.880 | 2.896 | 4.442 | 11.701 |
| 2376 | 2377 | 0.726 | 0.320 | 0 | 7.273 | 6.762 | 3.337 | 2.539 | 11.347 |
| 5836 | 5837 | 0.534 | 0.827 | 1 | 2.552 | 4.899 | 2.098 | 2.151 | 7.518 |
| 848 | 849 | 0.148 | 0.657 | 1 | 4.685 | 5.254 | 2.645 | 1.671 | 8.595 |
| 1141 | 1142 | 0.167 | 0.811 | 1 | 4.341 | 4.279 | 2.629 | 0.782 | 6.898 |
| 3155 | 3156 | 0.526 | 0.816 | 1 | 6.903 | 5.166 | 2.003 | 2.527 | 7.445 |
| 8237 | 8238 | 0.227 | 0.640 | 1 | 2.669 | 4.889 | 2.689 | 1.846 | 8.455 |
| 3395 | 3396 | 0.321 | 0.456 | 1 | 6.915 | 6.233 | 2.259 | 3.083 | 9.048 |
The tree-to-tree standard deviation is a stability diagnostic, not a formal standard error. Trees are correlated because they are trained on overlapping data, so treating this as an independent sampling standard error would be too optimistic. Still, high tree-to-tree disagreement is useful for flagging uncertain predictions.
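The optimism of a naive standard error can be quantified. If each tree's prediction has variance \(\sigma^2\) and any two trees have correlation \(\rho\), the variance of their average is \(\rho\sigma^2 + (1-\rho)\sigma^2/B\), which stays near \(\rho\sigma^2\) however many trees are added. Within a single forest the shared component is fixed, so the tree-to-tree spread only reflects the \((1-\rho)\sigma^2\) part, and dividing it by \(\sqrt{B}\) misses the \(\rho\sigma^2\) term entirely. The simulation below checks this with illustrative values \(\sigma = 2\), \(\rho = 0.3\), and \(B = 90\) (the notebook's tree count); \(\rho\) and \(\sigma\) are assumptions, not estimates from this forest.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma, rho, n_trees, n_reps = 2.0, 0.3, 90, 20_000  # assumed illustrative values

# Equicorrelated tree predictions: a component shared by all trees in a forest
# (same training data) plus tree-specific noise (different subsamples/splits).
shared = rng.normal(0.0, sigma * np.sqrt(rho), size=(n_reps, 1))
specific = rng.normal(0.0, sigma * np.sqrt(1 - rho), size=(n_reps, n_trees))
tree_preds = shared + specific  # each row is one forest; Corr(any two trees) = rho

forest_mean = tree_preds.mean(axis=1)
empirical_var = forest_mean.var()                     # true variability of the forest average
theory_var = rho * sigma**2 + (1 - rho) * sigma**2 / n_trees
# "Naive SE" built from within-forest tree-to-tree spread, as if trees were independent.
naive_var = (tree_preds.var(axis=1, ddof=1) / n_trees).mean()

print(f"empirical variance of forest mean: {empirical_var:.3f}")
print(f"correlated-tree formula:           {theory_var:.3f}")
print(f"naive tree_sd^2 / B:               {naive_var:.3f}")
```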
7. Forests as Adaptive Neighborhoods
Generalized random forests can be understood as local weighting methods. For a target account, each tree defines a local neighborhood: the estimation-sample accounts that land in the same leaf. The forest averages across many such neighborhoods.
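In the generalized random forest framework of Athey, Tibshirani, and Wager (2019), this idea is written as a set of weights. If \(L_b(x)\) is the set of estimation-sample accounts in tree \(b\)'s leaf containing \(x\), account \(i\) receives weight

\[ \alpha_i(x) = \frac{1}{B}\sum_{b=1}^{B} \frac{\mathbf{1}\{X_i \in L_b(x)\}}{|L_b(x)|}, \]

and the weights sum to one whenever every target leaf contains at least one account. A minimal NumPy sketch follows; `forest_weights` is a hypothetical helper that takes precomputed leaf ids, and whether the notebook's `forest_neighbor_weights` uses exactly this normalization is an assumption.

```python
import numpy as np


def forest_weights(train_leaves, target_leaves):
    """alpha_i(x) = (1/B) * sum_b 1{leaf_b(X_i) == leaf_b(x)} / |L_b(x)| (hypothetical helper)."""
    train_leaves = np.asarray(train_leaves)    # shape (n_train, n_trees): leaf id per account and tree
    target_leaves = np.asarray(target_leaves)  # shape (n_trees,): leaf id of the target in each tree
    same_leaf = train_leaves == target_leaves  # broadcasts over trees -> (n_train, n_trees)
    leaf_sizes = same_leaf.sum(axis=0)         # |L_b(x)| for each tree
    # Trees whose target leaf holds no training accounts contribute zero weight.
    per_tree = np.divide(
        same_leaf, leaf_sizes, out=np.zeros(same_leaf.shape), where=leaf_sizes > 0
    )
    return per_tree.mean(axis=1)


# Four training accounts, two trees; the target lands in leaf 0 of tree 0 and leaf 3 of tree 1.
train_leaves = np.array([[0, 2], [0, 3], [1, 3], [1, 3]])
target_leaves = np.array([0, 3])
weights = forest_weights(train_leaves, target_leaves)
print(weights)
```

Account 1 shares the target's leaf in both trees, so it receives the largest weight; accounts 2 and 3 share only the larger leaf in tree 1, so their weight is diluted.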
The next cell picks the account with the highest predicted CATE and shows the training accounts that most often share leaves with it.
target_index = test_scored.sort_values("cate_honest_forest", ascending=False).index[0]
target_row = test_scored.loc[[target_index], feature_cols]
neighbor_weights = forest_neighbor_weights(honest_forest, target_row)
neighbor_table = (
train_df.loc[neighbor_weights.head(12).index, feature_cols + ["treatment", "outcome", "true_cate"]]
.assign(forest_weight=neighbor_weights.head(12).values)
.sort_values("forest_weight", ascending=False)
)
target_summary = test_scored.loc[
[target_index],
feature_cols + ["true_cate", "cate_honest_forest", "forest_tree_sd"],
]
print("Target account")
display(target_summary.round(3))
print("Most heavily weighted forest neighbors from the training estimation leaves")
display(neighbor_table.round(3))
Target account
| | risk_score | engagement_score | log_account_value | tenure_months | enterprise_plan | integration_count | support_tickets | true_cate | cate_honest_forest | forest_tree_sd |
|---|---|---|---|---|---|---|---|---|---|---|
| 1995 | 0.865 | 0.195 | 5.400 | 11.112 | 1 | 6 | 3 | 17.423 | 9.034 | 2.419 |
Most heavily weighted forest neighbors from the training estimation leaves
| | risk_score | engagement_score | log_account_value | tenure_months | enterprise_plan | integration_count | support_tickets | treatment | outcome | true_cate | forest_weight |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 6495 | 0.586 | 0.069 | 5.576 | 15.223 | 1 | 3 | 2 | 1 | 135.614 | 13.588 | 0.003 |
| 2876 | 0.876 | 0.334 | 5.668 | 18.327 | 1 | 1 | 2 | 1 | 120.388 | 12.890 | 0.003 |
| 1862 | 0.673 | 0.116 | 5.322 | 23.136 | 0 | 3 | 5 | 0 | 113.707 | 13.475 | 0.002 |
| 2720 | 0.846 | 0.240 | 5.242 | 14.409 | 1 | 4 | 5 | 0 | 97.487 | 18.794 | 0.002 |
| 4817 | 0.911 | 0.257 | 5.093 | 13.012 | 1 | 6 | 5 | 0 | 99.521 | 20.078 | 0.002 |
| 5533 | 0.803 | 0.405 | 5.294 | 18.607 | 1 | 5 | 4 | 0 | 104.245 | 16.186 | 0.002 |
| 89 | 0.732 | 0.737 | 5.335 | 11.284 | 1 | 6 | 3 | 0 | 107.216 | 8.352 | 0.002 |
| 8936 | 0.539 | 0.234 | 5.525 | 14.725 | 1 | 5 | 4 | 0 | 117.446 | 15.017 | 0.002 |
| 1817 | 0.760 | 0.174 | 6.047 | 7.963 | 1 | 2 | 3 | 1 | 132.099 | 14.494 | 0.002 |
| 7339 | 0.778 | 0.254 | 4.844 | 9.406 | 1 | 7 | 3 | 0 | 94.896 | 15.603 | 0.002 |
| 4471 | 0.868 | 0.616 | 5.456 | 46.977 | 0 | 4 | 1 | 1 | 118.272 | 6.215 | 0.002 |
| 3940 | 0.786 | 0.766 | 5.382 | 16.045 | 1 | 5 | 3 | 0 | 113.123 | 7.892 | 0.002 |
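The neighbors are not chosen by Euclidean distance in covariate space. They are the training accounts that repeatedly land in the same leaf as the target, across trees whose splits were chosen to find treatment-effect variation. That is why forests are often described as adaptive nearest-neighbor methods. No single neighbor dominates: even the most heavily weighted accounts receive weights of only about 0.002 to 0.003, so the estimate pools many accounts. This pooling also pulls extreme effects toward the middle, which is visible here: the target's forest estimate is 9.034 while its true CATE is 17.423.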
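A small illustration of leaf co-membership versus Euclidean distance, using an ordinary scikit-learn `RandomForestRegressor` on simulated data as a stand-in. A causal forest splits on effect heterogeneity rather than outcome level, but its neighbors are defined the same way. The outcome jumps at \(x_0 = 0.5\), and \(x_1\) is irrelevant but measured on a wider scale. Proximity here is the share of trees in which two points land in the same leaf; the data, points, and forest settings are illustrative assumptions.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor

rng = np.random.default_rng(3)
n = 2000
X = np.column_stack([
    rng.uniform(0, 1, size=n),   # x0: the outcome jumps at 0.5
    rng.uniform(0, 10, size=n),  # x1: irrelevant, but on a wider scale
])
y = 5.0 * (X[:, 0] > 0.5) + rng.normal(scale=0.5, size=n)
forest = RandomForestRegressor(n_estimators=200, max_depth=3, min_samples_leaf=100, random_state=0)
forest.fit(X, y)


def proximity(a, b):
    """Share of trees in which points a and b land in the same leaf."""
    return float((forest.apply(a) == forest.apply(b)).mean())


target = np.array([[0.80, 5.0]])
same_side = np.array([[0.80, 5.5]])   # Euclidean distance 0.50, same side of the jump
other_side = np.array([[0.45, 5.0]])  # Euclidean distance 0.35, across the jump

prox_same = proximity(target, same_side)
prox_other = proximity(target, other_side)
print(f"same side of the jump:  distance 0.50, proximity {prox_same:.2f}")
print(f"other side of the jump: distance 0.35, proximity {prox_other:.2f}")
```

The point across the jump is closer in Euclidean terms but is essentially never a neighbor, while the point that differs only in the irrelevant feature usually is.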
8. Compare Against Other Forest-Style Baselines
We compare four estimators:
- Single honest causal tree: interpretable but high variance.
- Transformed-outcome random forest: fits a random forest directly to \(Y^*\), but does not use honesty for effect estimation.
- T-learner random forest: fits separate outcome forests for treated and control accounts.
- Honest causal forest: averages many honest local treatment-effect estimates.
In real data, we do not know the true CATE. Here we use the simulation truth to understand behavior.
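To see why the transformed-outcome baseline is noisy, assume the notebook's `transformed_outcome` uses the standard form for a randomized experiment with known assignment probability \(p\):

\[ Y^* = \frac{Y (W - p)}{p(1-p)} = \frac{W Y}{p} - \frac{(1-W) Y}{1-p}, \qquad E[Y^* \mid X] = \tau(X). \]

The sketch below uses illustrative values chosen to resemble this notebook's scale (outcomes around 100, effects around 5). \(Y^*\) is unbiased for the effect, but it multiplies the outcome by \(1/p\) or \(-1/(1-p)\), so its spread is driven by the outcome level rather than by the effect.

```python
import numpy as np

rng = np.random.default_rng(11)
n, p = 200_000, 0.5
mu0, tau = 100.0, 5.0  # assumed: outcome level near 100, effect near 5
w = rng.binomial(1, p, size=n)
y = mu0 + tau * w + rng.normal(scale=10.0, size=n)

# Transformed outcome for a known assignment probability p.
y_star = y * (w - p) / (p * (1 - p))

print(f"mean of Y*: {y_star.mean():.2f} (true effect {tau})")
print(f"sd of Y*:   {y_star.std():.1f} vs sd of Y: {y.std():.1f}")
```

With a standard deviation around 200 for an effect near 5, a forest fit directly to \(Y^*\) has a very noisy target, which helps explain why its CATE predictions can be far off even though the target is unbiased.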
# Transformed-outcome random forest.
to_rf = RandomForestRegressor(
n_estimators=220,
min_samples_leaf=45,
random_state=1201,
n_jobs=-1,
)
to_rf.fit(train_df[feature_cols], transformed_outcome(train_df))
test_scored["cate_transformed_rf"] = to_rf.predict(test_scored[feature_cols])
# T-learner random forest.
t_rf_1 = RandomForestRegressor(
n_estimators=220,
min_samples_leaf=45,
random_state=1301,
n_jobs=-1,
)
t_rf_0 = RandomForestRegressor(
n_estimators=220,
min_samples_leaf=45,
random_state=1302,
n_jobs=-1,
)
t_rf_1.fit(train_df.loc[train_df["treatment"] == 1, feature_cols], train_df.loc[train_df["treatment"] == 1, "outcome"])
t_rf_0.fit(train_df.loc[train_df["treatment"] == 0, feature_cols], train_df.loc[train_df["treatment"] == 0, "outcome"])
test_scored["cate_t_forest"] = t_rf_1.predict(test_scored[feature_cols]) - t_rf_0.predict(test_scored[feature_cols])
# Oracle score for simulation benchmarking.
test_scored["oracle_true_cate"] = test_scored["true_cate"]
score_map = {
"Single honest tree": "cate_honest_tree",
"Transformed-outcome RF": "cate_transformed_rf",
"T-learner RF": "cate_t_forest",
"Honest causal forest": "cate_honest_forest",
"Oracle true CATE": "oracle_true_cate",
}
metrics = pd.DataFrame(
[model_metrics(test_scored, score_col, label) for label, score_col in score_map.items()]
).sort_values("cate_rmse")
display(metrics.round(3))
| | model | cate_rmse | cate_correlation | mean_predicted_cate | true_cate_top20 | negative_cate_share_top20 |
|---|---|---|---|---|---|---|
| 4 | Oracle true CATE | 0.000 | 1.000 | 5.186 | 11.042 | 0.000 |
| 3 | Honest causal forest | 3.215 | 0.832 | 5.543 | 10.031 | 0.000 |
| 2 | T-learner RF | 3.217 | 0.745 | 5.266 | 9.427 | 0.008 |
| 0 | Single honest tree | 4.350 | 0.082 | 5.998 | 5.390 | 0.078 |
| 1 | Transformed-outcome RF | 19.885 | 0.225 | 4.792 | 6.296 | 0.037 |
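The honest causal forest and the T-learner are nearly tied on CATE RMSE (3.215 versus 3.217), but the forest ranks accounts better: its correlation with the true CATE is 0.832 versus 0.745, and the accounts it places in the top 20% have a mean true effect of 10.031 versus 9.427 (the oracle reaches 11.042). When capacity is limited, that ranking difference is what changes the targeting decision. The single honest tree barely ranks accounts (correlation 0.082), and the transformed-outcome RF has by far the largest RMSE (19.885). Local heterogeneity is persuasive only when the nearby comparison is credible.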
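The metrics in the table come from the notebook's `model_metrics`, which is defined earlier and not shown here. A hypothetical helper with the same column names might compute them as follows; the exact top-20% rule (rounding, ties) is an assumption. The toy example shows why RMSE and correlation can disagree: a score that ranks perfectly but is shifted by a constant has correlation 1 and RMSE equal to the shift.

```python
import numpy as np
import pandas as pd


def cate_metrics(df, score_col, label, true_col="true_cate", top_share=0.20):
    """Simulation-only accuracy and targeting metrics for one CATE score (hypothetical helper)."""
    score, truth = df[score_col], df[true_col]
    n_top = max(1, int(round(top_share * len(df))))
    top = df.nlargest(n_top, score_col)  # accounts this score would target first
    return {
        "model": label,
        # Penalizes both level bias and ranking errors.
        "cate_rmse": float(np.sqrt(np.mean((score - truth) ** 2))),
        # Insensitive to a constant shift: measures ranking/linear agreement only.
        "cate_correlation": float(np.corrcoef(score, truth)[0, 1]),
        "mean_predicted_cate": float(score.mean()),
        "true_cate_top20": float(top[true_col].mean()),
        "negative_cate_share_top20": float((top[true_col] < 0).mean()),
    }


toy = pd.DataFrame({"true_cate": [4.0, 3.0, 2.0, 1.0, -1.0]})
toy["biased_score"] = toy["true_cate"] + 1.0  # perfect ranking, constant bias of +1
m = cate_metrics(toy, "biased_score", "Biased but well-ranked")
print(m)
```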
fig, axes = plt.subplots(1, 2, figsize=(13, 4.8))
plot_metrics = metrics.loc[metrics["model"] != "Oracle true CATE"].copy()
sns.barplot(data=plot_metrics, y="model", x="cate_rmse", color="#38bdf8", ax=axes[0])
axes[0].set_title("CATE RMSE against simulation truth")
axes[0].set_xlabel("RMSE")
axes[0].set_ylabel("")
sns.barplot(data=plot_metrics, y="model", x="true_cate_top20", color="#22c55e", ax=axes[1])
axes[1].axvline(test_scored["true_cate"].mean(), color="#334155", linestyle=":", label="Test ATE")
axes[1].set_title("True effect among top 20% targeted")
axes[1].set_xlabel("Mean true CATE in top 20%")
axes[1].set_ylabel("")
axes[1].legend(loc="lower right")
plt.tight_layout()
plt.show()
The honest causal forest improves substantially over a single tree. The T-learner is competitive because the data-generating process has smooth outcome surfaces and both treatment arms are large. The transformed-outcome random forest is useful as a simple baseline, but its target is noisy and it does not separate split discovery from effect estimation.
# Plot predicted against true CATE for a random sample of test accounts.
plot_sample = test_scored.sample(2600, random_state=16)
fig, axes = plt.subplots(1, 3, figsize=(15, 4.6), sharey=True)
for ax, (label, score_col) in zip(
axes,
[
("Single honest tree", "cate_honest_tree"),
("T-learner RF", "cate_t_forest"),
("Honest causal forest", "cate_honest_forest"),
],
):
sns.scatterplot(
data=plot_sample,
x=score_col,
y="true_cate",
alpha=0.35,
edgecolor=None,
color="#2563eb",
ax=ax,
)
line_min = min(plot_sample[score_col].min(), plot_sample["true_cate"].min())
line_max = max(plot_sample[score_col].max(), plot_sample["true_cate"].max())
ax.plot([line_min, line_max], [line_min, line_max], color="#334155", linestyle="--", linewidth=1)
ax.set_title(label)
ax.set_xlabel("Predicted CATE")
ax.set_ylabel("True CATE")
plt.tight_layout()
plt.show()
The single tree produces coarse predictions because every account in a leaf receives the same effect estimate. The forest smooths over many such partitions and gives a more continuous ranking. Honest splitting and averaging help stabilize local effect estimates that would be too fragile in a single tree.
9. Decile Validation
In a real randomized experiment, individual treatment effects are unobserved, but we can validate a CATE score by sorting accounts into buckets and estimating treatment-control differences within each bucket.
A useful score should produce higher observed lift in top predicted-CATE buckets.
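A minimal sketch of that bucket-level check, assuming a randomized experiment and using the column names `outcome` and `treatment`. The notebook's `decile_validation_table` is not shown in this section, so the binning and the normal-approximation interval (difference in means plus or minus 1.96 unequal-variance standard errors) are assumptions about how such a table can be built; `decile_lift` is a hypothetical helper.

```python
import numpy as np
import pandas as pd


def decile_lift(df, score_col, outcome_col="outcome", treat_col="treatment", n_bins=10):
    """Treatment-control difference in means within predicted-score bins (hypothetical helper)."""
    # Bin 0 holds the highest scores; rank(method="first") breaks ties so qcut gets equal-sized bins.
    ranks = df[score_col].rank(method="first", ascending=False)
    bins = pd.qcut(ranks, n_bins, labels=False)
    rows = []
    for b, g in df.groupby(bins):
        t = g.loc[g[treat_col] == 1, outcome_col]
        c = g.loc[g[treat_col] == 0, outcome_col]
        lift = t.mean() - c.mean()
        se = np.sqrt(t.var(ddof=1) / len(t) + c.var(ddof=1) / len(c))  # unequal-variance SE
        rows.append({"bin": int(b), "n": len(g), "observed_lift": lift,
                     "ci_low": lift - 1.96 * se, "ci_high": lift + 1.96 * se})
    return pd.DataFrame(rows)


# Toy randomized experiment where the true effect is 4 * score.
rng = np.random.default_rng(9)
n = 4000
toy = pd.DataFrame({"score": rng.uniform(size=n), "treatment": rng.integers(0, 2, size=n)})
toy["outcome"] = 4.0 * toy["score"] * toy["treatment"] + rng.normal(size=n)
table = decile_lift(toy, "score")
print(table.round(3))
```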
# Build the decile validation table and plot observed lift against predicted and true CATE.
decile_table = decile_validation_table(test_scored, "cate_honest_forest")
display(decile_table.round(3))
fig, ax = plt.subplots(figsize=(10, 5.2))
ax.plot(
decile_table["decile"],
decile_table["observed_lift"],
marker="o",
color="#2563eb",
label="Observed treatment-control lift",
)
ax.fill_between(
np.arange(len(decile_table)),
decile_table["ci_low"],
decile_table["ci_high"],
color="#bfdbfe",
alpha=0.45,
label="Observed 95% CI",
)
ax.plot(
decile_table["decile"],
decile_table["predicted_cate_mean"],
marker="o",
color="#7c3aed",
label="Predicted CATE mean",
)
ax.plot(
decile_table["decile"],
decile_table["true_cate_mean"],
marker="D",
color="#dc2626",
label="True CATE mean",
)
ax.axhline(0, color="#334155", linestyle=":")
ax.set_title("Decile validation for the honest causal forest")
ax.set_xlabel("Predicted CATE decile")
ax.set_ylabel("Treatment effect")
ax.tick_params(axis="x", rotation=35)
ax.legend(loc="best")
plt.tight_layout()
plt.show()
| | decile | n | treatment_share | observed_lift | ci_low | ci_high | predicted_cate_mean | true_cate_mean | tree_sd_mean |
|---|---|---|---|---|---|---|---|---|---|
| 0 | D1 highest | 315 | 0.508 | 11.636 | 8.683 | 14.589 | 7.791 | 11.336 | 2.506 |
| 1 | D2 | 315 | 0.540 | 7.544 | 4.497 | 10.591 | 6.880 | 8.726 | 2.362 |
| 2 | D3 | 315 | 0.441 | 5.406 | 2.520 | 8.292 | 6.351 | 7.515 | 2.185 |
| 3 | D4 | 315 | 0.470 | 6.278 | 3.850 | 8.705 | 5.980 | 6.585 | 2.108 |
| 4 | D5 | 315 | 0.451 | 8.082 | 5.500 | 10.663 | 5.672 | 5.776 | 2.059 |
| 5 | D6 | 315 | 0.505 | 3.567 | 1.046 | 6.087 | 5.358 | 4.565 | 2.059 |
| 6 | D7 | 315 | 0.489 | 4.161 | 1.732 | 6.591 | 5.043 | 3.647 | 2.138 |
| 7 | D8 | 315 | 0.552 | 2.501 | 0.148 | 4.854 | 4.679 | 2.688 | 2.208 |
| 8 | D9 | 315 | 0.530 | 1.490 | -0.825 | 3.805 | 4.250 | 1.859 | 2.322 |
| 9 | D10 | 315 | 0.489 | -1.938 | -4.110 | 0.234 | 3.426 | -0.839 | 2.324 |

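The top predicted deciles show higher lift than the bottom deciles: D1 has an observed lift of 11.636 (95% CI 8.683 to 14.589), while D10 has -1.938 (95% CI -4.110 to 0.234). The pattern is not perfectly monotone (D5 shows more lift than D3 and D4), which is expected with 315 accounts per decile. The predicted means are also compressed relative to the truth (7.791 versus 11.336 in D1, 3.426 versus -0.839 in D10), so the forest ranks accounts better than it calibrates their effect sizes. This is the validation view you can show to a business team, because it does not pretend that individual effects are directly observed.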
10. Stability and Uncertainty Diagnostics
The forest gives a distribution of tree predictions for each account. Its spread is not a formal confidence interval, but it is still useful for model diagnostics.
If two accounts have similar predicted CATE but one has much higher tree-to-tree disagreement, the high-disagreement account should be treated as less stable.
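One hypothetical way to act on this is to rank by a conservative score that penalizes disagreement, for example the predicted CATE minus a multiple of the tree SD (the notebook's `forest_tree_q10` column plays a similar role). The penalty of 1.0 below is an assumed tuning choice, and the resulting score is a heuristic, not a confidence bound.

```python
import pandas as pd

accounts = pd.DataFrame({
    "account_id": ["A", "B", "C"],
    "cate_hat": [6.0, 6.0, 5.0],
    "tree_sd": [1.0, 3.0, 1.0],
})
# Hypothetical conservative score: subtract a penalty for tree-to-tree disagreement.
penalty = 1.0  # assumed tuning constant, not estimated from data
accounts["conservative_score"] = accounts["cate_hat"] - penalty * accounts["tree_sd"]
ranked = accounts.sort_values("conservative_score", ascending=False)
print(ranked)
```

Accounts A and B tie on predicted CATE, but B drops below C once its larger tree-to-tree disagreement is penalized.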
# Plot tree-to-tree disagreement against the predicted CATE and by score decile.
fig, axes = plt.subplots(1, 2, figsize=(12, 4.6))
sns.scatterplot(
data=test_scored.sample(2600, random_state=21),
x="cate_honest_forest",
y="forest_tree_sd",
hue="risk_score",
palette="viridis",
alpha=0.55,
edgecolor=None,
ax=axes[0],
)
axes[0].set_title("Tree-to-tree disagreement")
axes[0].set_xlabel("Predicted CATE")
axes[0].set_ylabel("Tree prediction SD")
axes[0].legend(title="Risk", loc="best")
sns.lineplot(
data=decile_table,
x="decile",
y="tree_sd_mean",
marker="o",
color="#0f766e",
ax=axes[1],
)
axes[1].set_title("Average tree disagreement by score decile")
axes[1].set_xlabel("Predicted CATE decile")
axes[1].set_ylabel("Mean tree prediction SD")
axes[1].tick_params(axis="x", rotation=35)
plt.tight_layout()
plt.show()
High uncertainty can come from sparse regions, mixed leaves, weak overlap, or genuinely complex local heterogeneity. In production, uncertainty diagnostics should be combined with randomized holdouts and policy guardrails.
11. Feature Importance for Heterogeneity
Feature importance in causal forests is not the same as feature importance in outcome prediction. A feature is important for the forest if it helps find splits with different treatment effects.
We will look at two educational summaries:
- Average split importance across the discovery trees.
- Simulation-only permutation importance: how much CATE RMSE worsens when a feature is shuffled in the test set.
The second measure uses the true CATE and is not available in real projects.
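Split-based importance has a known bias in ordinary tree ensembles: continuous, high-cardinality features offer many candidate thresholds and can accumulate splits even when they carry no signal. The sketch below illustrates this with scikit-learn's impurity-based `feature_importances_` and `sklearn.inspection.permutation_importance` on a regression forest with simulated data. Whether the notebook's `average_split_importance` behaves the same way is an assumption, but the contrast is a reason to read the two columns side by side.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(5)
n = 3000
binary_signal = rng.integers(0, 2, size=n)  # drives the outcome
continuous_noise = rng.normal(size=n)       # unrelated to the outcome
X = np.column_stack([binary_signal, continuous_noise])
y = 2.0 * binary_signal + rng.normal(size=n)

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.5, random_state=0)
# Fully grown trees: the noise feature is used to fit training noise.
forest = RandomForestRegressor(n_estimators=200, random_state=0, n_jobs=-1).fit(X_tr, y_tr)

impurity = forest.feature_importances_  # split-based (mean decrease in impurity)
perm = permutation_importance(forest, X_te, y_te, n_repeats=10, random_state=0).importances_mean

print("impurity importance   [signal, noise]:", np.round(impurity, 3))
print("held-out permutation  [signal, noise]:", np.round(perm, 3))
```

The pure-noise continuous feature collects a large share of impurity importance from trees that fit training noise, yet shuffling it on held-out data barely changes the score.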
# Compare split-based importance with simulation-only permutation importance.
split_importance = pd.DataFrame(
{
"feature": feature_cols,
"average_split_importance": honest_forest["split_importance"],
}
)
baseline_rmse = np.sqrt(mean_squared_error(test_scored["true_cate"], test_scored["cate_honest_forest"]))
rng = np.random.default_rng(606)
perm_rows = []
for feature in feature_cols:
permuted = test_scored[feature_cols].copy()
permuted[feature] = rng.permutation(permuted[feature].to_numpy())
permuted_pred = predict_honest_forest(honest_forest, permuted)["cate_hat"]
permuted_rmse = np.sqrt(mean_squared_error(test_scored["true_cate"], permuted_pred))
perm_rows.append({"feature": feature, "rmse_increase_when_permuted": permuted_rmse - baseline_rmse})
importance_table = split_importance.merge(pd.DataFrame(perm_rows), on="feature").sort_values(
"rmse_increase_when_permuted",
ascending=False,
)
display(importance_table.round(4))
fig, axes = plt.subplots(1, 2, figsize=(13, 4.8))
sns.barplot(
data=importance_table.sort_values("average_split_importance", ascending=True),
y="feature",
x="average_split_importance",
color="#38bdf8",
ax=axes[0],
)
axes[0].set_title("Average split importance")
axes[0].set_xlabel("Importance")
axes[0].set_ylabel("")
sns.barplot(
data=importance_table.sort_values("rmse_increase_when_permuted", ascending=True),
y="feature",
x="rmse_increase_when_permuted",
color="#22c55e",
ax=axes[1],
)
axes[1].set_title("Simulation-only CATE permutation importance")
axes[1].set_xlabel("RMSE increase")
axes[1].set_ylabel("")
plt.tight_layout()
plt.show()
| | feature | average_split_importance | rmse_increase_when_permuted |
|---|---|---|---|
| 0 | risk_score | 0.246 | 0.466 |
| 1 | engagement_score | 0.186 | 0.414 |
| 4 | enterprise_plan | 0.027 | 0.150 |
| 5 | integration_count | 0.065 | 0.068 |
| 2 | log_account_value | 0.240 | 0.059 |
| 6 | support_tickets | 0.031 | 0.051 |
| 3 | tenure_months | 0.203 | 0.020 |

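Feature importance should be treated as a screening tool, not as proof of a causal mechanism. The two summaries already disagree: log_account_value and tenure_months are split on often (0.240 and 0.203) but barely change CATE accuracy when permuted (0.059 and 0.020), while enterprise_plan is rarely split on (0.027) yet ranks third by permutation importance (0.150). The forest can tell us which pre-treatment features help predict treatment-effect variation. Domain knowledge and experimental follow-up are still needed to explain why.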
12. Targeting Curves
Now convert CATE scores into a targeting decision. Suppose the company has limited customer-success capacity. If it can treat only the top \(k\) percent of accounts, how much effect does each scoring rule capture?
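The two quantities plotted below can be checked by hand. If \(\tau_i\) is the true effect and \(S_k\) the set of accounts selected when the top fraction \(k\) of \(N\) accounts is treated, the mean effect among targeted accounts is \(\frac{1}{|S_k|}\sum_{i\in S_k}\tau_i\), and the incremental value per eligible account is \(\frac{1}{N}\sum_{i\in S_k}\tau_i\), which equals \(k\) times the targeted mean when \(kN\) is a whole number. The helper `targeting_curve_point` is hypothetical and may round or break ties differently from the notebook's `top_fraction_mask`.

```python
import numpy as np


def targeting_curve_point(score, true_cate, fraction):
    """Mean effect among the top `fraction` by score, and that effect spread over all accounts."""
    score, true_cate = np.asarray(score), np.asarray(true_cate)
    n_treat = int(round(fraction * len(score)))
    top = np.argsort(-score)[:n_treat]  # indices of the highest scores
    mean_targeted = true_cate[top].mean()
    # Equals fraction * mean_targeted when fraction * N is a whole number.
    value_per_account = true_cate[top].sum() / len(score)
    return mean_targeted, value_per_account


true_cate = np.array([10.0, 8.0, 6.0, 4.0, 2.0, 0.0, -2.0, -4.0])
oracle = targeting_curve_point(true_cate, true_cate, 0.25)
reversed_score = targeting_curve_point(-true_cate, true_cate, 0.25)
print(oracle, reversed_score)
```

With eight accounts and \(k = 0.25\), the oracle ranking treats the two accounts with effects 10 and 8, giving a targeted mean of 9 and a value of 18/8 = 2.25 per eligible account; reversing the ranking yields -3 and -0.75.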
# Build targeting curves: the true effect captured when each score selects the top fraction of accounts.
fractions = np.linspace(0.05, 1.00, 20)
curve_rows = []
policy_score_map = {
"Single honest tree": "cate_honest_tree",
"Transformed-outcome RF": "cate_transformed_rf",
"T-learner RF": "cate_t_forest",
"Honest causal forest": "cate_honest_forest",
"Oracle true CATE": "oracle_true_cate",
}
for label, score_col in policy_score_map.items():
for fraction in fractions:
selected = top_fraction_mask(test_scored, score_col, fraction=fraction)
targeted = test_scored.loc[selected]
curve_rows.append(
{
"model": label,
"treatment_share": fraction,
"mean_true_cate_targeted": targeted["true_cate"].mean(),
"incremental_value_per_account": fraction * targeted["true_cate"].mean(),
"negative_cate_share_targeted": (targeted["true_cate"] < 0).mean(),
}
)
policy_curves = pd.DataFrame(curve_rows)
fig, axes = plt.subplots(1, 2, figsize=(13, 5))
sns.lineplot(
data=policy_curves,
x="treatment_share",
y="mean_true_cate_targeted",
hue="model",
marker="o",
ax=axes[0],
)
axes[0].axhline(test_scored["true_cate"].mean(), color="#334155", linestyle=":", label="Treat-all average")
axes[0].set_title("Mean effect among targeted accounts")
axes[0].set_xlabel("Share treated")
axes[0].set_ylabel("Mean true CATE targeted")
sns.lineplot(
data=policy_curves,
x="treatment_share",
y="incremental_value_per_account",
hue="model",
marker="o",
ax=axes[1],
)
axes[1].set_title("Incremental value per eligible account")
axes[1].set_xlabel("Share treated")
axes[1].set_ylabel("Incremental value")
axes[1].legend_.remove()
plt.tight_layout()
plt.show()
The causal forest is useful when treatment capacity is limited because it ranks accounts by estimated incremental value. If the company can treat everyone, the ranking matters less. If the company can treat only 10% to 30%, the ranking can be the difference between a high-value program and wasted capacity.
13. Practical Notes for Production
In production, use a mature implementation when possible. The simplified implementation here omits many details that matter for inference and reliability.
Practical checks:
- Verify treatment assignment or adjust for confounding with propensity and outcome nuisance models.
- Use only pre-treatment features.
- Check overlap before trusting local effects.
- Validate ranked groups with randomized holdout traffic.
- Report bucket-level effects, not only individual CATE scores.
- Monitor model drift and treatment saturation over time.
- Treat feature importance as heterogeneity screening, not mechanistic proof.
- Keep an exploration policy so future data still contain treatment and control variation across the score range.
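For observational data, the first check usually means replacing raw treatment-control contrasts with doubly robust (AIPW) pseudo-outcomes built from a propensity model \(\hat{e}(x)\) and outcome models \(\hat{\mu}_0(x)\) and \(\hat{\mu}_1(x)\):

\[ \hat{\Gamma}_i = \hat{\mu}_1(X_i) - \hat{\mu}_0(X_i) + \frac{W_i\,(Y_i - \hat{\mu}_1(X_i))}{\hat{e}(X_i)} - \frac{(1 - W_i)\,(Y_i - \hat{\mu}_0(X_i))}{1 - \hat{e}(X_i)}. \]

Averaging \(\hat{\Gamma}_i\) within predicted-CATE buckets gives an observational analogue of the decile lift above. This is a minimal sketch assuming unconfoundedness and nuisance predictions already cross-fitted elsewhere; `aipw_scores` and its clipping threshold are hypothetical choices, and clipping is only a crude guard, not a substitute for the overlap check.

```python
import numpy as np


def aipw_scores(y, w, e_hat, mu0_hat, mu1_hat, clip=0.05):
    """Doubly robust (AIPW) pseudo-outcomes for each unit (hypothetical helper).

    Assumes unconfoundedness; e_hat, mu0_hat, mu1_hat should be cross-fitted predictions.
    """
    y, w = np.asarray(y, float), np.asarray(w, float)
    e = np.clip(np.asarray(e_hat, float), clip, 1 - clip)  # crude guard against extreme propensities
    mu0, mu1 = np.asarray(mu0_hat, float), np.asarray(mu1_hat, float)
    return mu1 - mu0 + w * (y - mu1) / e - (1 - w) * (y - mu0) / (1 - e)


# Two units: one treated, one control, with made-up nuisance predictions.
scores = aipw_scores([3.0, 1.0], [1, 0], [0.5, 0.5], [1.0, 1.0], [2.0, 2.0])
# A near-zero propensity is clipped to 0.05 before inverting.
clipped = aipw_scores([1.0], [1], [0.001], [0.0], [0.0])
print(scores, clipped)
```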
Causal forests are powerful because they search for heterogeneity flexibly. That same flexibility makes validation essential.
readout = pd.DataFrame(
{
"question": [
"What did the forest learn?",
"Who should be targeted first?",
"What should stakeholders see?",
"What is the main caution?",
"What should happen after deployment?",
],
"answer": [
"Treatment effects are highest for accounts with risk and room for engagement improvement.",
"Prioritize the top predicted-CATE deciles, subject to capacity and guardrails.",
"Show decile lift, targeting curves, and uncertainty/stability diagnostics.",
"Do not interpret individual CATEs as observed facts; validate by groups.",
"Keep randomized holdouts and monitor score calibration over time.",
],
}
)
display(readout)
| | question | answer |
|---|---|---|
| 0 | What did the forest learn? | Treatment effects are highest for accounts wit... |
| 1 | Who should be targeted first? | Prioritize the top predicted-CATE deciles, sub... |
| 2 | What should stakeholders see? | Show decile lift, targeting curves, and uncert... |
| 3 | What is the main caution? | Do not interpret individual CATEs as observed ... |
| 4 | What should happen after deployment? | Keep randomized holdouts and monitor score cal... |
Key Takeaways
- A causal forest estimates treatment effects, not just outcomes.
- Honest trees separate split discovery from treatment-effect estimation.
- A forest averages many local treatment-control contrasts, reducing single-tree instability.
- Forest predictions can be understood as adaptive neighborhood estimates.
- Tree-to-tree dispersion is a useful stability diagnostic but not a formal confidence interval.
- Decile validation and policy curves are more useful for business decisions than individual CATE claims.
- Mature causal-forest implementations add optimized splitting, nuisance adjustment, and inference tools beyond this educational version.
References
Athey, S., & Imbens, G. W. (2016). Recursive partitioning for heterogeneous causal effects. Proceedings of the National Academy of Sciences, 113(27), 7353-7360. https://doi.org/10.1073/pnas.1510489113
Athey, S., Tibshirani, J., & Wager, S. (2019). Generalized random forests. The Annals of Statistics, 47(2), 1148-1178. https://doi.org/10.1214/18-AOS1709
Breiman, L. (2001). Random forests. Machine Learning, 45(1), 5-32. https://doi.org/10.1023/A:1010933404324
Nie, X., & Wager, S. (2020). Quasi-oracle estimation of heterogeneous treatment effects. Biometrika, 108(2), 299-319. https://doi.org/10.1093/biomet/asaa076
Wager, S., & Athey, S. (2018). Estimation and inference of heterogeneous treatment effects using random forests. Journal of the American Statistical Association, 113(523), 1228-1242. https://doi.org/10.1080/01621459.2017.1319839