DoubleML Tutorial 11: Sample Splitting, Cross-Fitting, And Repeated Cross-Fitting
This notebook is about the machinery that makes double machine learning more than “fit a flexible model and regress residuals.” The key idea is cross-fitting: nuisance models are trained on one part of the data and used to predict nuisance quantities on held-out rows. Those held-out nuisance predictions are then used inside an orthogonal score for the treatment effect.
Why does this matter?
Flexible learners can overfit. If a nuisance model predicts the same rows it was trained on, the residuals can become artificially small or distorted. In ordinary prediction work, that is a validation problem. In causal estimation, it is worse: distorted residuals can distort the orthogonal score itself.
A PLR score can be written informally as residualizing both outcome and treatment:

$$\tilde{Y}_i = Y_i - \hat{\ell}(X_i), \qquad \tilde{D}_i = D_i - \hat{m}(X_i),$$

where $\hat{\ell}(X)$ predicts the outcome from controls and $\hat{m}(X)$ predicts the treatment from controls. The treatment effect is then estimated from the relationship between $\tilde{Y}_i$ and $\tilde{D}_i$. Cross-fitting asks for $\hat{\ell}(X_i)$ and $\hat{m}(X_i)$ to be predictions from models that did not train on row $i$.
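The residual-on-residual logic can be sketched end to end with scikit-learn alone. This is a minimal illustration on simulated data, not the notebook's implementation; the learner choice and data-generating functions are arbitrary stand-ins.

```python
# Minimal sketch: out-of-fold nuisance predictions feeding a PLR-style
# residual-on-residual regression. Data and learners are illustrative.
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold, cross_val_predict

rng = np.random.default_rng(0)
n = 500
X = rng.normal(size=(n, 5))
d = np.sin(X[:, 0]) + 0.5 * rng.normal(size=n)             # treatment
y = 1.0 * d + np.cos(X[:, 1]) + 0.5 * rng.normal(size=n)   # outcome, true effect = 1.0

cv = KFold(n_splits=5, shuffle=True, random_state=0)
# Each row's prediction comes from a model that never saw that row.
m_hat = cross_val_predict(RandomForestRegressor(n_estimators=100, random_state=0), X, d, cv=cv)
l_hat = cross_val_predict(RandomForestRegressor(n_estimators=100, random_state=0), X, y, cv=cv)

d_tilde, y_tilde = d - m_hat, y - l_hat
theta_hat = LinearRegression(fit_intercept=False).fit(d_tilde.reshape(-1, 1), y_tilde).coef_[0]
print(round(theta_hat, 3))  # should land near the true effect of 1.0
```

DoubleML automates exactly this bookkeeping, plus the score and inference; the sketch only shows why the out-of-fold requirement is the heart of it.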
This notebook covers five practical topics:
What K-fold cross-fitting is doing.
How fold count changes training/validation trade-offs.
How repeated cross-fitting reduces dependence on one random split.
How to set external splits for grouped data.
Why in-sample nuisance residualization can fail with very flexible learners.
Expected runtime: about 30-60 seconds on a laptop. The notebook fits several DoubleML models and a few manual cross-fitting models, but all learners are configured to stay tutorial-friendly.
Setup
The setup cell follows the same pattern as the earlier DoubleML tutorials. It creates output folders, sets a local Matplotlib cache, imports DoubleML and sklearn tools, and fixes global plotting defaults.
from pathlib import Path
import os
import time
import warnings

PROJECT_ROOT = Path.cwd().resolve()
if PROJECT_ROOT.name == "doubleml":
    PROJECT_ROOT = PROJECT_ROOT.parents[2]
OUTPUT_DIR = PROJECT_ROOT / "notebooks" / "tutorials" / "doubleml" / "outputs"
DATASET_DIR = OUTPUT_DIR / "datasets"
FIGURE_DIR = OUTPUT_DIR / "figures"
TABLE_DIR = OUTPUT_DIR / "tables"
REPORT_DIR = OUTPUT_DIR / "reports"
MATPLOTLIB_CACHE_DIR = OUTPUT_DIR / "matplotlib_cache"
for directory in [DATASET_DIR, FIGURE_DIR, TABLE_DIR, REPORT_DIR, MATPLOTLIB_CACHE_DIR]:
    directory.mkdir(parents=True, exist_ok=True)
os.environ.setdefault("MPLCONFIGDIR", str(MATPLOTLIB_CACHE_DIR))

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message="X does not have valid feature names.*")
warnings.filterwarnings("ignore", message="IProgress not found.*")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import Markdown, display

import doubleml as dml
from doubleml import DoubleMLData, DoubleMLPLR
from sklearn.base import clone
from sklearn.ensemble import ExtraTreesRegressor, HistGradientBoostingRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import GroupKFold, KFold

NOTEBOOK_PREFIX = "11"
RANDOM_SEED = 42
TRUE_THETA = 1.00

sns.set_theme(style="whitegrid", context="notebook")
plt.rcParams.update({"figure.dpi": 120, "savefig.dpi": 160})

print(f"DoubleML version: {dml.__version__}")
print(f"Output directory: {OUTPUT_DIR}")
The setup output confirms the package version and the artifact location. The notebook uses only local synthetic data, so every result should be reproducible from the cells below.
Helper Functions
These helpers handle repeated notebook chores: saving tables, building DoubleMLData, computing nuisance metrics, creating sample splits, and fitting a PLR model with a specified split design.
The most important helper is fit_plr_with_splits(). It lets us compare DoubleML models while controlling whether the sample splitting is drawn internally or provided externally.
The helpers make the split structure explicit. A DoubleML repeated split is represented as a list of repetitions, where each repetition is itself a list of train/test index pairs.
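The nested structure described above can be built directly from scikit-learn splitters. This is a hedged sketch of the convention (a list of repetitions, each a list of train/test index pairs), not the notebook's actual helper code.

```python
# Sketch of the repeated-split data structure: all_smpls[rep][fold]
# is a (train_indices, test_indices) pair. Names are illustrative.
import numpy as np
from sklearn.model_selection import KFold

n_obs, n_folds, n_reps = 20, 5, 2
all_smpls = []  # list of repetitions; each repetition is a list of (train, test) pairs
for rep in range(n_reps):
    cv = KFold(n_splits=n_folds, shuffle=True, random_state=rep)
    all_smpls.append(list(cv.split(np.arange(n_obs))))

print(len(all_smpls), len(all_smpls[0]))  # repetitions, folds per repetition
held_out = np.sort(np.concatenate([te for _, te in all_smpls[0]]))
print(np.array_equal(held_out, np.arange(n_obs)))  # every row held out exactly once
```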
Split Vocabulary
This table defines the terms used throughout the notebook. The distinction between sample splitting, cross-fitting, and repeated cross-fitting is subtle but important.
split_vocabulary = pd.DataFrame(
    [
        {
            "term": "Sample splitting",
            "meaning": "Partition rows into training and held-out pieces for nuisance prediction.",
            "why it matters": "Prevents nuisance predictions for a row from being trained on that same row.",
        },
        {
            "term": "K-fold cross-fitting",
            "meaning": "Split data into K folds; train nuisances on K-1 folds and predict the held-out fold, repeated for every fold.",
            "why it matters": "Every row gets an out-of-fold nuisance prediction while still using most data for training.",
        },
        {
            "term": "Repeated cross-fitting",
            "meaning": "Draw multiple K-fold splits and aggregate the resulting DoubleML estimates.",
            "why it matters": "Reduces dependence on one random partition of the sample.",
        },
        {
            "term": "External sample splits",
            "meaning": "User-supplied train/test fold indices passed into DoubleML.",
            "why it matters": "Needed for grouped, temporal, clustered, or otherwise constrained validation designs.",
        },
        {
            "term": "In-sample residualization",
            "meaning": "Fit nuisances and predict on the same rows used for fitting.",
            "why it matters": "Can badly distort residuals with flexible learners; useful here only as a cautionary comparison.",
        },
    ]
)
save_table(split_vocabulary, f"{NOTEBOOK_PREFIX}_split_vocabulary.csv")
display(split_vocabulary)
| term | meaning | why it matters |
| --- | --- | --- |
| Sample splitting | Partition rows into training and held-out pieces for nuisance prediction. | Prevents nuisance predictions for a row from being trained on that same row. |
| K-fold cross-fitting | Split data into K folds; train nuisances on K-1 folds and predict the held-out fold, repeated for every fold. | Every row gets an out-of-fold nuisance prediction while still using most data for training. |
| Repeated cross-fitting | Draw multiple K-fold splits and aggregate the resulting DoubleML estimates. | Reduces dependence on one random partition of the sample. |
| External sample splits | User-supplied train/test fold indices passed into DoubleML. | Needed for grouped, temporal, clustered, or otherwise constrained validation designs. |
| In-sample residualization | Fit nuisances and predict on the same rows used for fitting. | Can badly distort residuals with flexible learners; useful here only as a cautionary comparison. |
The vocabulary separates concepts that often get compressed into one phrase. DoubleML automates a lot of this, but the analyst still needs to know what split design is being used.
Synthetic Panel-Like PLR Data
We simulate a small panel-like dataset with repeated observations per user. The user identifier is excluded from the model controls, but a noisy user-level proxy is included. This lets us demonstrate both ordinary random folds and group-aware external folds.
The true treatment effect is TRUE_THETA = 1.00. The treatment and outcome both depend on nonlinear functions of observed controls, so cross-fitting with a nonlinear learner is useful.
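The notebook's data-generating cell is not reproduced here, but a simulator with the stated shape can be sketched as follows. The functional forms, noise scales, and coefficient choices below are hypothetical stand-ins; only the dimensions (240 users, 4 observations each, 10 controls plus a proxy) and the true effect of 1.0 come from the notebook.

```python
# Hypothetical sketch of a panel-like PLR simulator consistent with the
# audit: 240 users x 4 observations, nonlinear nuisances, true theta = 1.0.
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
n_users, obs_per_user, true_theta = 240, 4, 1.00
n = n_users * obs_per_user

user_id = np.repeat(np.arange(n_users), obs_per_user)
user_effect = rng.normal(size=n_users)[user_id]       # latent user heterogeneity
X = rng.normal(size=(n, 10))
true_m = np.tanh(X[:, 0]) + 0.5 * X[:, 1] ** 2        # nonlinear E[D|X] (illustrative)
true_g = np.sin(X[:, 2]) + 0.5 * X[:, 3]              # nonlinear outcome piece (illustrative)
treatment = true_m + 0.3 * user_effect + rng.normal(scale=0.7, size=n)
outcome = true_theta * treatment + true_g + 0.5 * user_effect + rng.normal(scale=0.7, size=n)

plr_df = pd.DataFrame(X, columns=[f"x{i:02d}" for i in range(10)])
plr_df["user_effect_proxy"] = user_effect + rng.normal(scale=0.5, size=n)
plr_df[["user_id", "treatment", "outcome"]] = np.c_[user_id, treatment, outcome]
print(plr_df.shape)
```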
The first rows show a dataset with repeated users and known oracle nuisance functions. The oracle columns help us evaluate the simulation but are excluded from the DoubleML controls.
Data Audit
The audit records the sample size, number of groups, missingness, and the strength of confounding. A split tutorial still needs a design audit: cross-fitting helps with overfitting, not with a wrong treatment definition or missing confounders.
field_dictionary = pd.DataFrame(
    [
        {"column": "unit_id", "role": "identifier", "description": "Synthetic row identifier; excluded from modeling."},
        {"column": "user_id", "role": "group identifier", "description": "Repeated-observation group id; used only for external GroupKFold splits."},
        {"column": "x00-x09", "role": "observed controls", "description": "Numeric pre-treatment controls."},
        {"column": "user_effect_proxy", "role": "observed control", "description": "Noisy pre-treatment proxy for user-level heterogeneity."},
        {"column": "user_effect", "role": "oracle only", "description": "Latent user effect used only in simulation diagnostics."},
        {"column": "true_m", "role": "oracle only", "description": "True treatment nuisance E[D|X] used only for simulation diagnostics."},
        {"column": "true_g", "role": "oracle only", "description": "True outcome nuisance component g0(X) used only for simulation diagnostics."},
        {"column": "treatment", "role": "treatment", "description": "Continuous treatment D."},
        {"column": "outcome", "role": "outcome", "description": "Continuous outcome Y."},
    ]
)
data_audit = pd.DataFrame(
    {
        "n_rows": [len(plr_df)],
        "n_users": [plr_df["user_id"].nunique()],
        "observations_per_user": [observations_per_user],
        "model_controls": [len(model_x_cols(plr_df))],
        "missing_cells": [int(plr_df.isna().sum().sum())],
        "true_theta": [TRUE_THETA],
        "corr_treatment_true_m": [plr_df["treatment"].corr(plr_df["true_m"])],
        "corr_treatment_true_g": [plr_df["treatment"].corr(plr_df["true_g"])],
    }
)
save_table(field_dictionary, f"{NOTEBOOK_PREFIX}_field_dictionary.csv")
save_table(data_audit, f"{NOTEBOOK_PREFIX}_data_audit.csv")
display(field_dictionary)
display(data_audit)
| column | role | description |
| --- | --- | --- |
| unit_id | identifier | Synthetic row identifier; excluded from modeling. |
| user_id | group identifier | Repeated-observation group id; used only for external GroupKFold splits. |
| x00-x09 | observed controls | Numeric pre-treatment controls. |
| user_effect_proxy | observed control | Noisy pre-treatment proxy for user-level heterogeneity. |
| user_effect | oracle only | Latent user effect used only in simulation diagnostics. |
| true_m | oracle only | True treatment nuisance E[D\|X] used only for simulation diagnostics. |
| true_g | oracle only | True outcome nuisance component g0(X) used only for simulation diagnostics. |
| treatment | treatment | Continuous treatment D. |
| outcome | outcome | Continuous outcome Y. |

| metric | value |
| --- | --- |
| n_rows | 960 |
| n_users | 240 |
| observations_per_user | 4 |
| model_controls | 11 |
| missing_cells | 0 |
| true_theta | 1.0 |
| corr_treatment_true_m | 0.649431 |
| corr_treatment_true_g | 0.1303 |
The treatment is related to both the treatment nuisance and outcome-relevant control structure. That is the confounding pattern the PLR score is designed to address.
Cross-Fitting Design Diagram
The diagram below shows why a row’s nuisance prediction is out-of-fold. A fold’s held-out rows are predicted by a model trained on the other folds. After all folds are predicted, DoubleML has one out-of-fold nuisance prediction per row.
The workflow is simple but powerful. We are not just validating a predictive model; we are constructing the nuisance predictions that enter the final causal score.
Visualizing Fold Assignments
This cell creates a 5-fold split and displays the fold assignment for the first rows. The heatmap is a compact way to see that every row belongs to exactly one held-out fold in each repetition.
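The fold-label computation behind such a heatmap can be sketched without any plotting. This is an illustrative reconstruction, not the notebook's exact cell; the shuffled 5-fold split and row count match the audit above.

```python
# Sketch: assign each of 960 rows the index of the fold that holds it out.
import numpy as np
from sklearn.model_selection import KFold

n_obs = 960
fold_label = np.empty(n_obs, dtype=int)
cv = KFold(n_splits=5, shuffle=True, random_state=42)
for k, (_, test_idx) in enumerate(cv.split(np.arange(n_obs))):
    fold_label[test_idx] = k  # each row is held out in exactly one fold

# Balanced fold sizes: 960 rows over 5 folds -> 192 rows each.
print(np.bincount(fold_label))
```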
The fold sizes are balanced, and the heatmap shows one held-out fold label per row. DoubleML uses this structure to fit nuisance models and assemble out-of-fold predictions.
Choosing The Number Of Folds
The number of folds controls a trade-off. More folds mean each nuisance model trains on more rows, but there are more nuisance fits. Fewer folds run faster, but each nuisance model trains on less data.
This section compares 2, 3, 5, and 8 folds using the same learner family. The goal is to understand sensitivity, not to find a universally best fold count.
The estimates move modestly across fold counts. That movement is part of split sensitivity. In real work, large swings across reasonable fold counts would be a reason to slow down and inspect nuisance quality, overlap, and sample size.
Fold Count Plot
The next plot shows point estimates, confidence intervals, and nuisance RMSE side by side. This keeps the fold-count choice tied to both causal uncertainty and nuisance prediction quality.
The RMSE panel shows that nuisance quality changes with fold count, but not always monotonically. Fold choice is a practical design choice, so reporting it is part of reproducibility.
Repeated Cross-Fitting
Repeated cross-fitting draws multiple K-fold partitions and aggregates the resulting estimates. This helps because a single random partition can be lucky or unlucky, especially in smaller samples or when learners are unstable.
DoubleML stores the per-repetition estimates in all_coef and all_se. The aggregate estimate is available as coef and se.
The detail table shows the estimates from individual repetitions. The aggregate row becomes less dependent on one particular split as the number of repetitions grows.
Repeated Split Distribution
A plot makes repeated cross-fitting easier to understand. Each point below is one repetition-specific estimate; the dashed vertical line is the known true effect.
fig, ax = plt.subplots(figsize=(10, 5.5))
sns.stripplot(
    data=repetition_detail,
    x="theta_hat_rep",
    y="n_rep_setting",
    orient="h",
    size=8,
    jitter=0.16,
    color="#2563eb",
    ax=ax,
)
ax.axvline(TRUE_THETA, color="#dc2626", linestyle="--", linewidth=1.5, label="True effect")
ax.set_title("Repetition-Specific Estimates From Repeated Cross-Fitting")
ax.set_xlabel("Repetition-specific treatment effect estimate")
ax.set_ylabel("Configured number of repetitions")
ax.legend(loc="best")
plt.tight_layout()
fig.savefig(FIGURE_DIR / f"{NOTEBOOK_PREFIX}_repeated_cross_fitting_distribution.png", bbox_inches="tight")
plt.show()
The individual points vary because each repetition uses a different random partition. Repeated cross-fitting is useful when that variation is nontrivial and the extra runtime is acceptable.
External Sample Splits
Sometimes random K-fold splitting is not the right design. If observations are grouped by user, household, region, device, school, or time period, the held-out fold may need to respect that structure.
This section compares ordinary random K-fold splits with GroupKFold splits by user_id. The model controls still exclude user_id; the group is used only to define fold boundaries.
Random K-fold splits can place the same user in both training and held-out folds. GroupKFold prevents group overlap, which is often the right validation logic when rows from the same group are closely related.
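The overlap claim is easy to verify directly. This sketch counts, for each fold, how many user groups appear in both the training and held-out rows; the group layout matches the audit above (240 users, 4 rows each), but the function name is illustrative.

```python
# Sketch: audit group leakage under random KFold vs GroupKFold.
import numpy as np
from sklearn.model_selection import GroupKFold, KFold

n_users, obs_per_user = 240, 4
groups = np.repeat(np.arange(n_users), obs_per_user)
idx = np.arange(groups.size)

def overlap_counts(split_iter):
    # For each fold, count groups present in both train and held-out rows.
    return [len(set(groups[tr]) & set(groups[te])) for tr, te in split_iter]

random_overlap = overlap_counts(KFold(n_splits=4, shuffle=True, random_state=42).split(idx))
group_overlap = overlap_counts(GroupKFold(n_splits=4).split(idx, groups=groups))
print(random_overlap)  # typically most users straddle the train/test boundary
print(group_overlap)   # all zeros: each user is held out as a whole
```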
Random Versus Group-Aware Splits
Now we fit DoubleML with random K-fold splits and with external GroupKFold splits. The comparison shows how to pass a custom split list into DoubleML with set_sample_splitting().
The group-aware estimate is similar in this synthetic data, but the split audit still matters. In real grouped data, preventing group leakage can be more important than a small change in point estimates.
Group Overlap Plot
This plot visualizes the split audit. The goal of GroupKFold is not to change the estimate mechanically; the goal is to enforce a design rule that no user appears in both training and held-out rows for a fold.
fig, ax = plt.subplots(figsize=(10, 5))
sns.barplot(
    data=split_group_audit,
    x="fold",
    y="overlap_groups",
    hue="split_type",
    ax=ax,
    palette=["#2563eb", "#059669"],
)
ax.set_title("Group Overlap Between Train And Held-Out Rows")
ax.set_xlabel("Fold")
ax.set_ylabel("Number of overlapping user groups")
ax.legend(title="Split type")
plt.tight_layout()
fig.savefig(FIGURE_DIR / f"{NOTEBOOK_PREFIX}_group_overlap_audit.png", bbox_inches="tight")
plt.show()
The GroupKFold bars are zero because each user is held out as a whole unit. That is the central reason to use external splits: they let the analyst encode validation constraints that the default random splitter cannot know.
Manual No-Cross-Fitting Caution
DoubleML is designed around out-of-fold nuisance predictions. To see why, we now construct a manual cautionary example with an extremely flexible ExtraTreesRegressor.
The in-sample version fits nuisance models on all rows and predicts those same rows. The cross-fitted version trains on K-1 folds and predicts the held-out fold. Both then run the same residual-on-residual regression.
This is not meant to replace DoubleML’s implementation. It is a visual caution about why in-sample residualization is dangerous.
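The cautionary comparison can be sketched compactly. This is an illustrative reconstruction on simulated data, not the notebook's exact cell; the tree settings are chosen so the in-sample fit memorizes the training rows.

```python
# Sketch: in-sample vs cross-fitted residualization with flexible trees.
import numpy as np
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold, cross_val_predict

rng = np.random.default_rng(42)
n = 500
X = rng.normal(size=(n, 5))
d = np.sin(X[:, 0]) + 0.5 * rng.normal(size=n)
y = 1.0 * d + np.cos(X[:, 1]) + 0.5 * rng.normal(size=n)

trees = ExtraTreesRegressor(n_estimators=200, min_samples_leaf=1, random_state=42)

def theta_from(m_hat, l_hat):
    d_tilde, y_tilde = d - m_hat, y - l_hat
    return LinearRegression(fit_intercept=False).fit(d_tilde.reshape(-1, 1), y_tilde).coef_[0]

# In-sample: predict the same rows used for fitting (residuals near zero).
m_in = trees.fit(X, d).predict(X)
l_in = trees.fit(X, y).predict(X)

# Cross-fitted: every prediction comes from a model trained on other folds.
cv = KFold(n_splits=5, shuffle=True, random_state=42)
m_cf = cross_val_predict(trees, X, d, cv=cv)
l_cf = cross_val_predict(trees, X, y, cv=cv)

print(float(np.std(d - m_in)), float(np.std(d - m_cf)))  # tiny vs honest
print(float(theta_from(m_cf, l_cf)))                      # near the true 1.0
```

With unbagged extra trees and leaves of size one, the in-sample residual standard deviation collapses toward zero, which is exactly the pathology the plot below illustrates.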
The in-sample nuisance RMSE is almost zero because the flexible trees can interpolate the training rows. That looks excellent as prediction output, but it destroys the residual structure needed for causal estimation. The cross-fitted residuals are noisier in a healthy way: they are honest held-out residuals.
No-Cross-Fitting Caution Plot
The plot below puts the cautionary comparison on two panels: effect estimate and treatment-residual standard deviation. The in-sample residual standard deviation can become tiny when the learner memorizes the training data.
This is the clearest practical reason for cross-fitting: training-set predictions can be too good to be useful. The causal score needs held-out nuisance predictions, not memorized training residuals.
Split Reporting Checklist
A reproducible DoubleML analysis should document its split design. This checklist can be copied into applied notebooks and reports.
split_reporting_checklist = pd.DataFrame(
    [
        {"item": "State number of folds", "why": "Fold count changes training size and number of nuisance fits."},
        {"item": "State number of repetitions", "why": "Repeated cross-fitting affects split stability and runtime."},
        {"item": "Record random seeds", "why": "Random sample splits should be reproducible."},
        {"item": "Explain external split constraints", "why": "Grouped, clustered, or temporal data often require custom split logic."},
        {"item": "Report nuisance prediction diagnostics", "why": "Split design should be assessed with held-out nuisance performance."},
        {"item": "Show split sensitivity when material", "why": "Large changes across reasonable splits weaken confidence in one estimate."},
        {"item": "Avoid in-sample nuisance residualization", "why": "Flexible learners can memorize rows and distort the orthogonal score."},
    ]
)
save_table(split_reporting_checklist, f"{NOTEBOOK_PREFIX}_split_reporting_checklist.csv")
display(split_reporting_checklist)
| item | why |
| --- | --- |
| State number of folds | Fold count changes training size and number of nuisance fits. |
| State number of repetitions | Repeated cross-fitting affects split stability and runtime. |
| Record random seeds | Random sample splits should be reproducible. |
| Explain external split constraints | Grouped, clustered, or temporal data often require custom split logic. |
| Report nuisance prediction diagnostics | Split design should be assessed with held-out nuisance performance. |
| Show split sensitivity when material | Large changes across reasonable splits weaken confidence in one estimate. |
| Avoid in-sample nuisance residualization | Flexible learners can memorize rows and distort the orthogonal score. |
The checklist makes split design part of the causal audit trail. It also helps separate honest split sensitivity from casual rerunning until a pleasing estimate appears.
Report Template And Artifact Manifest
The final cell writes a short split-design report template and an artifact manifest. The template is intentionally concise, but it includes the details that should be visible in a serious DoubleML write-up.
report_text = f"""
# Sample Splitting And Cross-Fitting Report Template

## Model
- DoubleML model class:
- Treatment:
- Outcome:
- Control set:
- Primary learner:

## Split Design
- Number of folds:
- Number of repetitions:
- Random seed(s):
- Internal or external splits:
- Group, cluster, or time constraints:

## Diagnostics
- Outcome nuisance RMSE:
- Treatment nuisance RMSE:
- Split sensitivity across fold counts:
- Repetition-specific estimate range:
- Group overlap audit, if applicable:

## Final Estimate
- Point estimate:
- Standard error:
- Confidence interval:
- Split-related caveats:
""".strip()

report_path = REPORT_DIR / f"{NOTEBOOK_PREFIX}_sample_splitting_report_template.md"
report_path.write_text(report_text)

artifact_manifest = pd.DataFrame(
    [
        {"artifact": "synthetic PLR data", "path": str(DATASET_DIR / f"{NOTEBOOK_PREFIX}_synthetic_panel_like_plr_data.csv")},
        {"artifact": "fold count comparison", "path": str(TABLE_DIR / f"{NOTEBOOK_PREFIX}_fold_count_comparison.csv")},
        {"artifact": "repeated cross-fitting summary", "path": str(TABLE_DIR / f"{NOTEBOOK_PREFIX}_repeated_cross_fitting_summary.csv")},
        {"artifact": "external split audit", "path": str(TABLE_DIR / f"{NOTEBOOK_PREFIX}_external_split_group_audit.csv")},
        {"artifact": "manual no-cross-fit caution", "path": str(TABLE_DIR / f"{NOTEBOOK_PREFIX}_manual_no_cross_fit_caution.csv")},
        {"artifact": "report template", "path": str(report_path)},
        {"artifact": "cross-fitting workflow figure", "path": str(FIGURE_DIR / f"{NOTEBOOK_PREFIX}_cross_fitting_workflow.png")},
        {"artifact": "fold assignment heatmap", "path": str(FIGURE_DIR / f"{NOTEBOOK_PREFIX}_fold_assignment_heatmap.png")},
    ]
)
save_table(artifact_manifest, f"{NOTEBOOK_PREFIX}_artifact_manifest.csv")
display(Markdown(f"Report template written to `{report_path}`"))
display(artifact_manifest)
Report template written to /home/apex/Documents/ranking_sys/notebooks/tutorials/doubleml/outputs/reports/11_sample_splitting_report_template.md
| artifact | path |
| --- | --- |
| synthetic PLR data | /home/apex/Documents/ranking_sys/notebooks/tut... |
| fold count comparison | /home/apex/Documents/ranking_sys/notebooks/tut... |
| repeated cross-fitting summary | /home/apex/Documents/ranking_sys/notebooks/tut... |
| external split audit | /home/apex/Documents/ranking_sys/notebooks/tut... |
| manual no-cross-fit caution | /home/apex/Documents/ranking_sys/notebooks/tut... |
| report template | /home/apex/Documents/ranking_sys/notebooks/tut... |
| cross-fitting workflow figure | /home/apex/Documents/ranking_sys/notebooks/tut... |
| fold assignment heatmap | /home/apex/Documents/ranking_sys/notebooks/tut... |

(The display truncates each path; the full locations follow the `outputs/` layout defined in the setup cell.)
The notebook now has a full split-design workflow: theory, API mechanics, fold-count sensitivity, repeated cross-fitting, external group splits, and a concrete warning against in-sample residualization.
What Comes Next
The next natural topic is inference: standard errors, confidence intervals, bootstrap options, joint inference, and how to communicate uncertainty from DoubleML estimates.