Advanced SEM and ML Mediation Models

The previous notebooks built the discovery-quality mediation design, estimated a transparent linear g-computation decomposition, and stress-tested the result across definitions and specifications. This notebook adds a more advanced modeling layer.

The goal is not to replace the earlier notebooks. The goal is to ask whether the same story survives when we use more flexible models and a more formal path-model framing.

This notebook has four advanced pieces:

The main question remains the same: does high discovery exposure increase future user value, and is that effect meaningfully routed through same-day satisfaction depth?

1. Load Libraries and Paths

This cell imports the modeling libraries used in the notebook. LightGBM and XGBoost are used for flexible nuisance models, while linear models provide the transparent baseline and SEM-style path estimates. The path setup mirrors earlier notebooks so the outputs land in the same processed and writeup folders.

from pathlib import Path
import os
import warnings

os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from IPython.display import display
from sklearn.base import clone
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import GroupKFold
from sklearn.preprocessing import StandardScaler

from lightgbm import LGBMRegressor
from xgboost import XGBRegressor

warnings.filterwarnings("ignore", category=FutureWarning)

sns.set_theme(style="whitegrid", context="notebook")
plt.rcParams["figure.figsize"] = (11, 6)
plt.rcParams["axes.titlesize"] = 13
plt.rcParams["axes.labelsize"] = 11
pd.set_option("display.max_columns", 180)
pd.set_option("display.max_colwidth", 140)

PROJECT_ROOT = Path.cwd().resolve()
while not (PROJECT_ROOT / "data").exists() and PROJECT_ROOT.parent != PROJECT_ROOT:
    PROJECT_ROOT = PROJECT_ROOT.parent

PROCESSED_DIR = PROJECT_ROOT / "data" / "processed"
NOTEBOOK_DIR = PROJECT_ROOT / "notebooks" / "discovery_quality_mediation"
WRITEUP_DIR = NOTEBOOK_DIR / "writeup"
FIGURE_DIR = WRITEUP_DIR / "figures"
TABLE_DIR = WRITEUP_DIR / "tables"

FIGURE_DIR.mkdir(parents=True, exist_ok=True)
TABLE_DIR.mkdir(parents=True, exist_ok=True)

ANALYSIS_PANEL_INPUT = PROCESSED_DIR / "kuairec_discovery_quality_mediation_analysis_panel.parquet"
EFFECT_SUMMARY_INPUT = PROCESSED_DIR / "kuairec_discovery_quality_effect_summary.csv"
ROBUSTNESS_SUMMARY_INPUT = PROCESSED_DIR / "kuairec_discovery_quality_robustness_summary.csv"

SEM_PATH_OUTPUT = PROCESSED_DIR / "kuairec_discovery_quality_sem_path_results.csv"
SEM_BOOTSTRAP_OUTPUT = PROCESSED_DIR / "kuairec_discovery_quality_sem_bootstrap.csv"
ML_EFFECTS_OUTPUT = PROCESSED_DIR / "kuairec_discovery_quality_advanced_ml_effects.csv"
ML_PERFORMANCE_OUTPUT = PROCESSED_DIR / "kuairec_discovery_quality_advanced_ml_performance.csv"
HETEROGENEITY_OUTPUT = PROCESSED_DIR / "kuairec_discovery_quality_advanced_heterogeneity.csv"
FEATURE_IMPORTANCE_OUTPUT = PROCESSED_DIR / "kuairec_discovery_quality_advanced_feature_importance.csv"
ADVANCED_SUMMARY_OUTPUT = PROCESSED_DIR / "kuairec_discovery_quality_advanced_model_summary.csv"

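The thread environment variables are set before NumPy and the tree libraries are imported, and the tree models themselves are created with n_jobs=1, so LightGBM and XGBoost do not take over the machine. The dataset is small enough that single-threaded tree models are still fast and easier to reproduce.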
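
One detail of the thread setup is easy to miss: os.environ.setdefault only writes a value when the variable is not already set, so a value exported in the shell takes precedence over the notebook default. A minimal sketch of that behavior, using hypothetical variable names (DEMO_NUM_THREADS and DEMO_UNSET_THREADS are illustrative, not settings used by the notebook):

```python
import os

# Hypothetical variable names, chosen so the demo does not touch real settings.
os.environ["DEMO_NUM_THREADS"] = "8"      # pretend the shell already exported this
os.environ.pop("DEMO_UNSET_THREADS", None)  # make sure this one starts out missing

os.environ.setdefault("DEMO_NUM_THREADS", "1")    # no effect: the key already exists
os.environ.setdefault("DEMO_UNSET_THREADS", "1")  # takes effect: the key was missing

print(os.environ["DEMO_NUM_THREADS"], os.environ["DEMO_UNSET_THREADS"])
```

If strict single-threading matters for reproducibility, assigning the values directly instead of using setdefault would be one way to override an inherited environment.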

2. Load Analysis Inputs

This cell loads the saved mediation analysis panel and the previous result tables. The advanced notebook should not reconstruct the panel from raw logs; it should reuse the cleaned analysis object created earlier.

analysis_panel = pd.read_parquet(ANALYSIS_PANEL_INPUT)
effect_summary = pd.read_csv(EFFECT_SUMMARY_INPUT)
robustness_summary = pd.read_csv(ROBUSTNESS_SUMMARY_INPUT)

load_summary = pd.DataFrame(
    {
        "artifact": ["analysis_panel", "effect_summary", "robustness_summary"],
        "rows": [len(analysis_panel), len(effect_summary), len(robustness_summary)],
        "columns": [analysis_panel.shape[1], effect_summary.shape[1], robustness_summary.shape[1]],
    }
)

main_effects = effect_summary.query(
    "outcome == 'Y_future_interactions' and estimand in ['gcomp_total_effect', 'natural_direct_effect', 'natural_indirect_effect']"
)[["estimand", "estimate", "ci_95_lower", "ci_95_upper", "relative_effect"]]

display(load_summary)
display(main_effects.round(4))
display(robustness_summary.round(4))
             artifact  rows  columns
0      analysis_panel  8199      103
1      effect_summary    21       12
2  robustness_summary     4        9

                  estimand  estimate  ci_95_lower  ci_95_upper  relative_effect
2       gcomp_total_effect   37.2781      32.2444      42.8135           0.1152
3    natural_direct_effect   38.5858      33.3634      43.7154           0.1192
4  natural_indirect_effect   -1.3077      -2.3047      -0.4290          -0.0036

       robustness_family  specifications  total_effect_min  total_effect_max  total_effect_share_positive  \
0  threshold_sensitivity               4           37.2781           40.0555                          1.0
1   mediator_sensitivity               5           37.2560           37.3502                          1.0
2      model_sensitivity               5           34.8085           37.6289                          1.0
3    outcome_sensitivity               1           37.2781           37.2781                          1.0

   indirect_effect_min  indirect_effect_max  indirect_effect_share_positive  median_abs_indirect_to_total_ratio
0              -1.6532              -1.3077                             0.0                              0.0402
1              -1.7145               0.1281                             0.2                              0.0359
2              -1.5581              -0.9360                             0.0                              0.0351
3              -1.3077              -1.3077                             0.0                              0.0351

The earlier result is the benchmark. The advanced models should be judged against this reference, not treated as automatically better just because they are more flexible.

3. Define Variables and Control Sets

This cell defines the treatment, mediator, outcome, and covariates. The covariates are the same pre-treatment history and profile features used in prior notebooks. This keeps the comparison focused on model form rather than changing the adjustment set.

TREATMENT = "A_high_discovery"
MEDIATOR = "M_satisfaction_depth"
PRIMARY_OUTCOME = "Y_future_interactions"
SECONDARY_OUTCOME = "Y_future_play_hours"
WEIGHT_COLUMN = "stabilized_ipw_capped"
GROUP_COLUMN = "user_id"
RANDOM_SEED = 42
N_SPLITS = 3
SEM_BOOTSTRAP_ITERATIONS = 100

base_numeric_covariates = [
    "calendar_day_index",
    "lag_1_active_day",
    "prior_3day_active_day",
    "lag_1_interactions",
    "prior_3day_interactions",
    "lag_1_total_play_duration_sec",
    "prior_3day_total_play_duration_sec",
    "lag_1_valid_play_share",
    "prior_3day_valid_play_share",
    "lag_1_high_satisfaction_share",
    "prior_3day_high_satisfaction_share",
    "lag_1_discovery_candidate_share",
    "prior_3day_discovery_candidate_share",
    "recent_activity_score",
    "register_days",
    "follow_user_num",
    "fans_user_num",
    "friend_user_num",
    "is_lowactive_period",
    "is_live_streamer",
    "is_video_author",
]
profile_onehot_covariates = [col for col in analysis_panel.columns if col.startswith("onehot_feat")]
numeric_covariates = [col for col in base_numeric_covariates + profile_onehot_covariates if col in analysis_panel.columns]
categorical_covariates = [
    col
    for col in [
        "user_active_degree",
        "follow_user_num_range",
        "fans_user_num_range",
        "friend_user_num_range",
        "register_days_range",
    ]
    if col in analysis_panel.columns
]

variable_summary = pd.DataFrame(
    {
        "item": ["treatment", "mediator", "primary_outcome", "secondary_outcome", "numeric_covariates", "categorical_covariates", "group_folds"],
        "value": [TREATMENT, MEDIATOR, PRIMARY_OUTCOME, SECONDARY_OUTCOME, len(numeric_covariates), len(categorical_covariates), N_SPLITS],
    }
)

display(variable_summary)
                     item                  value
0               treatment       A_high_discovery
1                mediator   M_satisfaction_depth
2         primary_outcome  Y_future_interactions
3       secondary_outcome    Y_future_play_hours
4      numeric_covariates                     39
5  categorical_covariates                      5
6             group_folds                      3

The cross-fitting group is user_id, so the same user’s days do not appear in both train and validation folds. That is important because each user contributes repeated active days.
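
The no-overlap property of grouped folds can be checked directly. The sketch below uses a hypothetical toy panel (6 users with 4 days each; toy_groups and toy_X are illustrative names, not notebook objects) and confirms that no user appears on both sides of any split:

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Hypothetical toy panel: 6 users, 4 active days each (24 rows).
toy_groups = np.repeat(np.arange(6), 4)
toy_X = np.zeros((len(toy_groups), 1))  # feature values do not affect the split

splitter = GroupKFold(n_splits=3)
overlaps = []
for train_idx, test_idx in splitter.split(toy_X, groups=toy_groups):
    train_users = set(toy_groups[train_idx])
    test_users = set(toy_groups[test_idx])
    # Count users that show up in both train and test for this fold.
    overlaps.append(len(train_users & test_users))

print(overlaps)
```

A plain KFold on the same rows would generally split a user's days across train and test, which is the leakage the grouped splitter is meant to prevent.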

4. Build the Modeling Matrix

This cell prepares the adjustment matrix used by the SEM and ML models. Numeric columns are median-filled, categorical columns are one-hot encoded, and constant columns are removed.

def make_safe_column_names(columns):
    safe_names = []
    counts = {}
    for col in columns:
        safe = "".join(ch if ch.isalnum() else "_" for ch in str(col))
        safe = "_".join(part for part in safe.split("_") if part)
        if not safe:
            safe = "feature"
        if safe[0].isdigit():
            safe = f"feature_{safe}"
        counts[safe] = counts.get(safe, 0) + 1
        if counts[safe] > 1:
            safe = f"{safe}_{counts[safe]}"
        safe_names.append(safe)
    return safe_names


def build_covariate_matrix(frame, numeric_cols, categorical_cols):
    numeric = frame[numeric_cols].copy()
    for col in numeric.columns:
        numeric[col] = pd.to_numeric(numeric[col], errors="coerce")
        numeric[col] = numeric[col].fillna(numeric[col].median())

    if categorical_cols:
        categorical = frame[categorical_cols].copy()
        for col in categorical.columns:
            categorical[col] = categorical[col].astype("string").fillna("missing")
        encoded = pd.get_dummies(categorical, drop_first=True, dtype=float)
        design = pd.concat([numeric, encoded], axis=1)
    else:
        design = numeric

    design = design.loc[:, design.nunique(dropna=False) > 1]
    design.columns = make_safe_column_names(design.columns)
    return design.astype(float)

X_covariates = build_covariate_matrix(analysis_panel, numeric_covariates, categorical_covariates)
A = analysis_panel[TREATMENT].astype(float).to_numpy()
M = analysis_panel[MEDIATOR].astype(float).to_numpy()
Y = analysis_panel[PRIMARY_OUTCOME].astype(float).to_numpy()
weights = analysis_panel[WEIGHT_COLUMN].astype(float).to_numpy()
groups = analysis_panel[GROUP_COLUMN].to_numpy()

matrix_summary = pd.DataFrame(
    {
        "item": ["rows", "covariate_columns", "treatment_rate", "mediator_mean", "outcome_mean", "unique_users"],
        "value": [len(analysis_panel), X_covariates.shape[1], A.mean(), M.mean(), Y.mean(), analysis_panel[GROUP_COLUMN].nunique()],
    }
)

display(matrix_summary)
display(X_covariates.head())
                item        value
0               rows  8199.000000
1  covariate_columns    54.000000
2     treatment_rate     0.500061
3      mediator_mean     0.618626
4       outcome_mean   340.694475
5       unique_users   133.000000
calendar_day_index lag_1_active_day prior_3day_active_day lag_1_interactions prior_3day_interactions lag_1_total_play_duration_sec prior_3day_total_play_duration_sec lag_1_valid_play_share prior_3day_valid_play_share lag_1_high_satisfaction_share prior_3day_high_satisfaction_share lag_1_discovery_candidate_share prior_3day_discovery_candidate_share recent_activity_score register_days follow_user_num fans_user_num friend_user_num is_video_author onehot_feat0 onehot_feat1 onehot_feat2 onehot_feat3 onehot_feat4 onehot_feat6 onehot_feat7 onehot_feat8 onehot_feat9 onehot_feat10 onehot_feat11 onehot_feat12 onehot_feat13 onehot_feat14 onehot_feat15 onehot_feat17 user_active_degree_full_active user_active_degree_high_active follow_user_num_range_10_50 follow_user_num_range_100_150 follow_user_num_range_150_250 follow_user_num_range_250_500 follow_user_num_range_50_100 follow_user_num_range_0 follow_user_num_range_500 fans_user_num_range_1_10 fans_user_num_range_10_100 friend_user_num_range_1_5 friend_user_num_range_30_60 friend_user_num_range_5_30 register_days_range_31_60 register_days_range_366_730 register_days_range_61_90 register_days_range_730 register_days_range_91_180
0 0.0 0.0 0.0 0.0 0.0 0.000 0.000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 224.0 7.0 3.0 0.0 0.0 0.0 1.0 24.0 876.0 1.0 1.0 4.0 98.0 6.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
1 1.0 1.0 1.0 32.0 32.0 163.970 163.970 0.937500 0.937500 0.156250 0.156250 0.687500 0.687500 0.562990 224.0 7.0 3.0 0.0 0.0 0.0 1.0 24.0 876.0 1.0 1.0 4.0 98.0 6.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
2 2.0 1.0 2.0 20.0 52.0 130.986 294.956 0.950000 1.887500 0.350000 0.506250 0.450000 1.137500 0.639277 224.0 7.0 3.0 0.0 0.0 0.0 1.0 24.0 876.0 1.0 1.0 4.0 98.0 6.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
3 3.0 1.0 3.0 16.0 68.0 100.920 395.876 1.000000 2.887500 0.187500 0.693750 0.437500 1.575000 0.681755 224.0 7.0 3.0 0.0 0.0 0.0 1.0 24.0 876.0 1.0 1.0 4.0 98.0 6.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
4 4.0 1.0 3.0 37.0 73.0 222.720 454.626 0.891892 2.841892 0.216216 0.753716 0.432432 1.319932 0.693019 224.0 7.0 3.0 0.0 0.0 0.0 1.0 24.0 876.0 1.0 1.0 4.0 98.0 6.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0

The modeling matrix is intentionally reused across methods. SEM, Linear Regression, LightGBM, and XGBoost all see the same observed adjustment information.
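
To make the three preparation steps concrete (median fill, one-hot encoding with the first level dropped, removal of constant columns), here is a small sketch on a hypothetical four-row frame. The column names are borrowed from the notebook for readability, but the values are invented, and the code is a simplified stand-in for build_covariate_matrix rather than a copy of it:

```python
import pandas as pd

# Hypothetical toy rows; values are invented for illustration.
toy = pd.DataFrame(
    {
        "lag_1_interactions": [10.0, None, 30.0, 50.0],   # one missing value
        "is_live_streamer": [0, 0, 0, 0],                 # constant column
        "user_active_degree": ["full_active", "high_active", None, "full_active"],
    }
)

# Step 1: coerce to numeric and fill missing values with the column median.
numeric = toy[["lag_1_interactions", "is_live_streamer"]].apply(pd.to_numeric, errors="coerce")
numeric = numeric.fillna(numeric.median())

# Step 2: treat missing categories as their own level, then one-hot encode.
categorical = toy[["user_active_degree"]].fillna("missing").astype(str)
encoded = pd.get_dummies(categorical, columns=["user_active_degree"], drop_first=True, dtype=float)

# Step 3: drop columns that carry no variation.
design = pd.concat([numeric, encoded], axis=1)
design = design.loc[:, design.nunique(dropna=False) > 1].astype(float)

print(design)
```

The constant is_live_streamer column disappears in this toy frame for the same reason that some flag columns listed in the covariate set do not appear in the printed modeling matrix.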

5. SEM-Style Path Model

This cell estimates a simple path model:

  • M = alpha * A + X + error
  • Y = c_prime * A + beta * M + X + error
  • indirect = alpha * beta
  • total_path = c_prime + alpha * beta

This is not a full latent-variable SEM. It is a transparent path-analysis version of the mediation story, which is often enough for product analytics settings where the variables are observed scores.

def design_with_columns(covariates, **named_arrays):
    design = covariates.copy()
    for name, values in reversed(list(named_arrays.items())):
        design.insert(0, name, np.asarray(values, dtype=float))
    return design


def fit_sem_path(frame, covariates):
    a = frame[TREATMENT].astype(float).to_numpy()
    m = frame[MEDIATOR].astype(float).to_numpy()
    y = frame[PRIMARY_OUTCOME].astype(float).to_numpy()
    w = frame[WEIGHT_COLUMN].astype(float).to_numpy()

    mediator_design = design_with_columns(covariates, A=a)
    mediator_model = LinearRegression().fit(mediator_design, m, sample_weight=w)
    alpha = float(mediator_model.coef_[0])

    outcome_design = design_with_columns(covariates, A=a, M=m)
    outcome_model = LinearRegression().fit(outcome_design, y, sample_weight=w)
    c_prime = float(outcome_model.coef_[0])
    beta = float(outcome_model.coef_[1])

    total_design = design_with_columns(covariates, A=a)
    total_model = LinearRegression().fit(total_design, y, sample_weight=w)
    total_coefficient = float(total_model.coef_[0])

    return {
        "alpha_A_to_M": alpha,
        "beta_M_to_Y": beta,
        "c_prime_A_to_Y": c_prime,
        "indirect_alpha_beta": alpha * beta,
        "total_path_c_prime_plus_indirect": c_prime + alpha * beta,
        "total_regression_coefficient": total_coefficient,
        "mediator_model_r2": mediator_model.score(mediator_design, m, sample_weight=w),
        "outcome_model_r2": outcome_model.score(outcome_design, y, sample_weight=w),
        "total_model_r2": total_model.score(total_design, y, sample_weight=w),
    }

sem_path_results = pd.DataFrame([fit_sem_path(analysis_panel, X_covariates)])

display(sem_path_results.round(4))
   alpha_A_to_M  beta_M_to_Y  c_prime_A_to_Y  indirect_alpha_beta  total_path_c_prime_plus_indirect  \
0        0.0231     -48.4626         38.4311              -1.1172                           37.3139

   total_regression_coefficient  mediator_model_r2  outcome_model_r2  total_model_r2
0                       37.3139             0.4653            0.6642          0.6637

The SEM-style table gives a compact path decomposition. Here alpha is positive (0.0231) and beta is negative (-48.46), so the indirect path alpha * beta is small and negative (-1.12): high discovery raises satisfaction depth slightly, but satisfaction depth is not the channel that raises future interaction count.
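
One property of this path model is worth making explicit: when the mediator, outcome, and total regressions are all least-squares fits on the same rows and covariates, c_prime + alpha * beta equals the total regression coefficient exactly, which is why those two columns coincide in the table. A minimal sketch on synthetic data (the simulated coefficients, variable names, and the ols helper are illustrative assumptions, not the notebook's code, and the sketch is unweighted):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 500
x = rng.normal(size=n)                       # one hypothetical covariate
a = (rng.random(n) < 0.5).astype(float)      # binary treatment
m = 0.3 * a + 0.5 * x + rng.normal(size=n)   # mediator
y = 2.0 * a - 1.5 * m + x + rng.normal(size=n)

def ols(columns, target):
    """Least-squares coefficients; index 0 is the intercept."""
    design = np.column_stack([np.ones(n)] + columns)
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    return coef

alpha = ols([a, x], m)[1]                 # A -> M
c_prime, beta = ols([a, m, x], y)[1:3]    # A -> Y given M, and M -> Y
total = ols([a, x], y)[1]                 # A -> Y without M

# The two numbers agree up to floating-point error.
print(c_prime + alpha * beta, total)
```

Because the equality is algebraic, it carries no evidential weight on its own; it only confirms that the decomposition is internally consistent. The same reasoning applies when all three fits share one set of sample weights.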

6. Bootstrap the SEM-Style Paths by User

This cell bootstraps the path model by resampling users. The purpose is to quantify how stable the SEM-style path coefficients are under the repeated-user structure.

rng = np.random.default_rng(RANDOM_SEED)
unique_users = np.sort(analysis_panel[GROUP_COLUMN].unique())
indexed_panel = analysis_panel.set_index(GROUP_COLUMN, drop=False)
sem_bootstrap_rows = []

for iteration in range(SEM_BOOTSTRAP_ITERATIONS):
    sampled_users = rng.choice(unique_users, size=len(unique_users), replace=True)
    sample = indexed_panel.loc[sampled_users].reset_index(drop=True)
    sample_covariates = build_covariate_matrix(sample, numeric_covariates, categorical_covariates)
    sample_covariates = sample_covariates.reindex(columns=X_covariates.columns, fill_value=0.0)
    row = fit_sem_path(sample, sample_covariates)
    row["bootstrap_iteration"] = iteration
    sem_bootstrap_rows.append(row)

sem_bootstrap = pd.DataFrame(sem_bootstrap_rows)
sem_interval_rows = []
for column in [
    "alpha_A_to_M",
    "beta_M_to_Y",
    "c_prime_A_to_Y",
    "indirect_alpha_beta",
    "total_path_c_prime_plus_indirect",
    "total_regression_coefficient",
]:
    sem_interval_rows.append(
        {
            "path_quantity": column,
            "estimate": sem_path_results.loc[0, column],
            "bootstrap_mean": sem_bootstrap[column].mean(),
            "ci_95_lower": sem_bootstrap[column].quantile(0.025),
            "ci_95_upper": sem_bootstrap[column].quantile(0.975),
        }
    )

sem_path_summary = pd.DataFrame(sem_interval_rows)

display(sem_path_summary.round(4))
                      path_quantity  estimate  bootstrap_mean  ci_95_lower  ci_95_upper
0                      alpha_A_to_M    0.0231          0.0235       0.0183       0.0281
1                       beta_M_to_Y  -48.4626        -48.4765     -74.0799     -21.9557
2                    c_prime_A_to_Y   38.4311         38.7951      32.6143      44.0143
3               indirect_alpha_beta   -1.1172         -1.1509      -1.8332      -0.4317
4  total_path_c_prime_plus_indirect   37.3139         37.6442      31.4660      42.8712
5      total_regression_coefficient   37.3139         37.6442      31.4660      42.8712

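The bootstrapped path table is the formal counterpart to the earlier g-computation results. It lets us compare whether the simple path-product indirect effect points in the same direction as the notebook 04 indirect estimate. Here it does: the path product is -1.12 (95% interval -1.83 to -0.43), and the benchmark natural indirect effect loaded in section 2 is -1.31. With only 100 bootstrap iterations the percentile bounds are coarse and should be read as approximate.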
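
The resampling unit in this bootstrap is the user, not the row. The sketch below shows the mechanics on a hypothetical six-row frame: indexing by user_id and selecting with a list that contains repeats returns every row of each drawn user, once per draw. The draw is fixed by hand here for clarity; the notebook draws it with rng.choice.

```python
import numpy as np
import pandas as pd

# Hypothetical toy panel: user 1 has 2 days, user 2 has 1 day, user 3 has 3 days.
toy = pd.DataFrame(
    {"user_id": [1, 1, 2, 3, 3, 3], "y": [10, 11, 20, 30, 31, 32]}
)
indexed = toy.set_index("user_id", drop=False)

# A fixed illustrative draw of users with replacement: user 3 twice, user 1 once.
sampled_users = np.array([3, 3, 1])

# Each drawn user contributes all of their rows; repeated draws repeat the rows.
sample = indexed.loc[sampled_users].reset_index(drop=True)

print(sample)
```

Keeping each user's days together is what lets the interval reflect within-user dependence; resampling individual rows would treat repeated days from one user as independent observations.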

7. Plot SEM Path Quantities

This cell plots the SEM-style path quantities with bootstrap intervals. The visual focuses on the direct path, indirect path, and total path because these are the quantities that correspond most closely to the earlier decomposition.

sem_plot = sem_path_summary.query(
    "path_quantity in ['c_prime_A_to_Y', 'indirect_alpha_beta', 'total_path_c_prime_plus_indirect', 'total_regression_coefficient']"
).copy()
sem_plot["label"] = sem_plot["path_quantity"].map(
    {
        "c_prime_A_to_Y": "Direct path",
        "indirect_alpha_beta": "Indirect path",
        "total_path_c_prime_plus_indirect": "Direct + indirect path",
        "total_regression_coefficient": "Total regression path",
    }
)
sem_plot["lower_error"] = sem_plot["estimate"] - sem_plot["ci_95_lower"]
sem_plot["upper_error"] = sem_plot["ci_95_upper"] - sem_plot["estimate"]

fig, ax = plt.subplots(figsize=(10, 5.5))
sns.barplot(data=sem_plot, x="estimate", y="label", color="steelblue", ax=ax)
for row_index, row in sem_plot.reset_index(drop=True).iterrows():
    ax.errorbar(
        x=row["estimate"],
        y=row_index,
        xerr=[[row["lower_error"]], [row["upper_error"]]],
        fmt="none",
        color="black",
        capsize=4,
        linewidth=1.2,
    )
ax.axvline(0, color="black", linewidth=1)
ax.set_title("SEM-Style Path Decomposition")
ax.set_xlabel("Effect on future 7-day interactions")
ax.set_ylabel("")
plt.tight_layout()
fig.savefig(FIGURE_DIR / "21_sem_path_decomposition.png", dpi=160, bbox_inches="tight")
plt.show()

The SEM plot should be read as a path-model summary, not as a new identification strategy. It is useful because it compresses the mediation story into a few path coefficients.

8. Define Cross-Fitted ML Nuisance Models

This cell defines the model families used in the cross-fitted g-computation. The Linear model is the transparent baseline. LightGBM and XGBoost allow nonlinearities and interactions in the nuisance functions.

model_factories = {
    "linear": lambda: LinearRegression(),
    "lightgbm": lambda: LGBMRegressor(
        n_estimators=80,
        learning_rate=0.05,
        num_leaves=24,
        subsample=0.85,
        colsample_bytree=0.85,
        random_state=RANDOM_SEED,
        n_jobs=1,
        verbose=-1,
    ),
    "xgboost": lambda: XGBRegressor(
        n_estimators=80,
        max_depth=3,
        learning_rate=0.05,
        subsample=0.85,
        colsample_bytree=0.85,
        objective="reg:squarederror",
        tree_method="hist",
        random_state=RANDOM_SEED,
        n_jobs=1,
        verbosity=0,
    ),
}

model_summary = pd.DataFrame(
    {
        "model_family": list(model_factories.keys()),
        "role": [
            "transparent benchmark",
            "tree-based nonlinear nuisance model",
            "boosted-tree nonlinear nuisance model",
        ],
    }
)

display(model_summary)
   model_family                                   role
0        linear                  transparent benchmark
1      lightgbm    tree-based nonlinear nuisance model
2       xgboost  boosted-tree nonlinear nuisance model

All three model families estimate the same counterfactual quantities. That keeps the comparison fair: any differences come from nuisance model flexibility, not from changing the causal estimand.

9. Implement Cross-Fitted ML G-Computation

This cell implements group cross-fitting. For each fold, models are trained on some users and evaluated on held-out users. The estimator predicts M(0), M(1), Y(0, M(0)), Y(1, M(0)), and Y(1, M(1)) out of fold, plus Y(0, M(1)) as a reverse-order check, then averages the implied effects.

def treatment_design(covariates, treatment_values):
    design = covariates.copy()
    design.insert(0, TREATMENT, np.asarray(treatment_values, dtype=float))
    return design


def outcome_design(covariates, treatment_values, mediator_values):
    treatment_array = np.asarray(treatment_values, dtype=float)
    mediator_array = np.asarray(mediator_values, dtype=float)
    design = covariates.copy()
    design.insert(0, "A_x_M", treatment_array * mediator_array)
    design.insert(0, MEDIATOR, mediator_array)
    design.insert(0, TREATMENT, treatment_array)
    return design


def fit_with_optional_weight(model, X, y, sample_weight=None):
    try:
        model.fit(X, y, sample_weight=sample_weight)
    except TypeError:
        model.fit(X, y)
    return model


def regression_metrics(y_true, y_pred):
    rmse = mean_squared_error(y_true, y_pred) ** 0.5
    return {
        "rmse": rmse,
        "mae": mean_absolute_error(y_true, y_pred),
        "r2": r2_score(y_true, y_pred),
    }


def crossfit_ml_mediation(model_name, model_factory, outcome_col=PRIMARY_OUTCOME):
    n = len(analysis_panel)
    splitter = GroupKFold(n_splits=N_SPLITS)

    m_hat = np.zeros(n)
    y_hat = np.zeros(n)
    m0 = np.zeros(n)
    m1 = np.zeros(n)
    y_0_m0 = np.zeros(n)
    y_1_m0 = np.zeros(n)
    y_1_m1 = np.zeros(n)
    y_0_m1 = np.zeros(n)

    fold_rows = []
    for fold_id, (train_idx, test_idx) in enumerate(splitter.split(X_covariates, A, groups=groups), start=1):
        X_train = X_covariates.iloc[train_idx]
        X_test = X_covariates.iloc[test_idx]
        A_train = A[train_idx]
        M_train = M[train_idx]
        Y_train = analysis_panel[outcome_col].astype(float).to_numpy()[train_idx]
        W_train = weights[train_idx]

        mediator_model = model_factory()
        outcome_model = model_factory()

        mediator_model = fit_with_optional_weight(
            mediator_model,
            treatment_design(X_train, A_train),
            M_train,
            sample_weight=W_train,
        )
        train_outcome_design = outcome_design(X_train, A_train, M_train)
        outcome_model = fit_with_optional_weight(
            outcome_model,
            train_outcome_design,
            Y_train,
            sample_weight=W_train,
        )

        zeros = np.zeros(len(test_idx))
        ones = np.ones(len(test_idx))
        observed_a = A[test_idx]
        observed_m = M[test_idx]

        m_hat[test_idx] = mediator_model.predict(treatment_design(X_test, observed_a)).clip(0, 1)
        m0[test_idx] = mediator_model.predict(treatment_design(X_test, zeros)).clip(0, 1)
        m1[test_idx] = mediator_model.predict(treatment_design(X_test, ones)).clip(0, 1)

        y_hat[test_idx] = outcome_model.predict(outcome_design(X_test, observed_a, observed_m))
        y_0_m0[test_idx] = outcome_model.predict(outcome_design(X_test, zeros, m0[test_idx]))
        y_1_m0[test_idx] = outcome_model.predict(outcome_design(X_test, ones, m0[test_idx]))
        y_1_m1[test_idx] = outcome_model.predict(outcome_design(X_test, ones, m1[test_idx]))
        y_0_m1[test_idx] = outcome_model.predict(outcome_design(X_test, zeros, m1[test_idx]))

        fold_rows.append(
            {
                "model_family": model_name,
                "outcome": outcome_col,
                "fold": fold_id,
                "train_rows": len(train_idx),
                "test_rows": len(test_idx),
                "test_users": len(np.unique(groups[test_idx])),
            }
        )

    observed_y = analysis_panel[outcome_col].astype(float).to_numpy()
    mediator_metrics = regression_metrics(M, m_hat)
    outcome_metrics = regression_metrics(observed_y, y_hat)

    effects = {
        "model_family": model_name,
        "outcome": outcome_col,
        "outcome_label": "Future 7-day interactions" if outcome_col == PRIMARY_OUTCOME else "Future 7-day play hours",
        "gcomp_total_effect": float(np.mean(y_1_m1 - y_0_m0)),
        "natural_direct_effect": float(np.mean(y_1_m0 - y_0_m0)),
        "natural_indirect_effect": float(np.mean(y_1_m1 - y_1_m0)),
        "reverse_indirect_check": float(np.mean(y_0_m1 - y_0_m0)),
        "mediator_shift": float(np.mean(m1 - m0)),
        "reference_mean": float(np.mean(y_0_m0)),
        "relative_total_effect": float(np.mean(y_1_m1 - y_0_m0) / np.mean(y_0_m0)),
    }

    performance = [
        {"model_family": model_name, "outcome": outcome_col, "target": "mediator", **mediator_metrics},
        {"model_family": model_name, "outcome": outcome_col, "target": "outcome", **outcome_metrics},
    ]

    row_level = pd.DataFrame(
        {
            "user_id": analysis_panel[GROUP_COLUMN].to_numpy(),
            "event_date": analysis_panel["event_date"].to_numpy(),
            "model_family": model_name,
            "outcome": outcome_col,
            "ite_total": y_1_m1 - y_0_m0,
            "ite_direct": y_1_m0 - y_0_m0,
            "ite_indirect": y_1_m1 - y_1_m0,
            "m0_hat": m0,
            "m1_hat": m1,
            "m_hat": m_hat,
            "y_hat": y_hat,
        }
    )

    return effects, performance, pd.DataFrame(fold_rows), row_level

print("Cross-fitted ML mediation function ready.")
Cross-fitted ML mediation function ready.

The row-level output is useful for heterogeneity. Each held-out row receives an estimated total, direct, and indirect effect, which can then be summarized by user segments.
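
The plug-in logic behind those row-level effects can be seen in a stripped-down form. The sketch below is an assumption-laden illustration rather than the notebook's estimator: it uses synthetic data with invented coefficients, plain least squares instead of the model factories, no sample weights, and no cross-fitting. It keeps the same structure, though: predict the mediator under each treatment level, feed those predictions into an outcome model with an A x M interaction, and average the contrasts.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 4000
x = rng.normal(size=n)                                   # hypothetical covariate
a = (rng.random(n) < 0.5).astype(float)                  # binary treatment
m = 0.2 * a + 0.4 * x + rng.normal(scale=0.5, size=n)    # mediator
y = 3.0 * a + 1.0 * m + 0.5 * a * m + x + rng.normal(size=n)

def fit(design, target):
    coef, *_ = np.linalg.lstsq(design, target, rcond=None)
    return coef

def toy_mediator_design(a_vals):
    return np.column_stack([np.ones(n), a_vals, x])

def toy_outcome_design(a_vals, m_vals):
    # Includes the treatment-by-mediator interaction, as the notebook's design does.
    return np.column_stack([np.ones(n), a_vals, m_vals, a_vals * m_vals, x])

m_coef = fit(toy_mediator_design(a), m)
y_coef = fit(toy_outcome_design(a, m), y)

zeros, ones = np.zeros(n), np.ones(n)
m0 = toy_mediator_design(zeros) @ m_coef   # predicted mediator under A = 0
m1 = toy_mediator_design(ones) @ m_coef    # predicted mediator under A = 1

y_0_m0 = toy_outcome_design(zeros, m0) @ y_coef
y_1_m0 = toy_outcome_design(ones, m0) @ y_coef
y_1_m1 = toy_outcome_design(ones, m1) @ y_coef

total = np.mean(y_1_m1 - y_0_m0)
direct = np.mean(y_1_m0 - y_0_m0)      # change A, hold the mediator at its A = 0 level
indirect = np.mean(y_1_m1 - y_1_m0)    # hold A = 1, shift the mediator

print(total, direct, indirect)
```

In this simulation the data-generating direct effect is 3.0 and the indirect effect is (1.0 + 0.5) * 0.2 = 0.3, so the printed estimates should land near those values. The total always equals direct plus indirect because the three contrasts telescope.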

10. Run Cross-Fitted ML Mediation

This cell runs the cross-fitted estimator for Linear Regression, LightGBM, and XGBoost. It estimates the primary future-interaction outcome and a secondary future-play-hours outcome.

ml_effect_rows = []
ml_performance_rows = []
fold_summaries = []
row_level_effect_frames = []

for outcome_col in [PRIMARY_OUTCOME, SECONDARY_OUTCOME]:
    for model_name, model_factory in model_factories.items():
        effects, performance, folds, row_level = crossfit_ml_mediation(
            model_name,
            model_factory,
            outcome_col=outcome_col,
        )
        ml_effect_rows.append(effects)
        ml_performance_rows.extend(performance)
        fold_summaries.append(folds)
        row_level_effect_frames.append(row_level)

ml_effects = pd.DataFrame(ml_effect_rows)
ml_performance = pd.DataFrame(ml_performance_rows)
ml_fold_summary = pd.concat(fold_summaries, ignore_index=True)
row_level_effects = pd.concat(row_level_effect_frames, ignore_index=True)

display(ml_effects.round(4))
display(ml_performance.round(4))
   model_family                outcome              outcome_label  gcomp_total_effect  natural_direct_effect  \
0        linear  Y_future_interactions  Future 7-day interactions             37.3421                38.6464
1      lightgbm  Y_future_interactions  Future 7-day interactions              3.2069                 3.2652
2       xgboost  Y_future_interactions  Future 7-day interactions              2.6829                 2.7279
3        linear    Y_future_play_hours    Future 7-day play hours              0.0786                 0.0745
4      lightgbm    Y_future_play_hours    Future 7-day play hours              0.0076                 0.0030
5       xgboost    Y_future_play_hours    Future 7-day play hours              0.0062                 0.0034

   natural_indirect_effect  reverse_indirect_check  mediator_shift  reference_mean  relative_total_effect
0                  -1.3043                 -0.9904          0.0231        322.8968                 0.1156
1                  -0.0583                 -0.0137          0.0184        339.1834                 0.0095
2                  -0.0450                 -0.0294          0.0138        339.7969                 0.0079
3                   0.0042                  0.0021          0.0231          0.7480                 0.1051
4                   0.0046                  0.0046          0.0184          0.8177                 0.0093
5                   0.0028                  0.0027          0.0138          0.8230                 0.0075

    model_family                outcome    target     rmse      mae      r2
0         linear  Y_future_interactions  mediator   0.1015   0.0703  0.1878
1         linear  Y_future_interactions   outcome  96.6374  71.1050  0.7129
2       lightgbm  Y_future_interactions  mediator   0.0875   0.0629  0.3959
3       lightgbm  Y_future_interactions   outcome  66.7067  47.0941  0.8632
4        xgboost  Y_future_interactions  mediator   0.0870   0.0625  0.4026
5        xgboost  Y_future_interactions   outcome  66.0357  46.7289  0.8659
6         linear    Y_future_play_hours  mediator   0.1015   0.0703  0.1878
7         linear    Y_future_play_hours   outcome   0.3609   0.2428  0.4056
8       lightgbm    Y_future_play_hours  mediator   0.0875   0.0629  0.3959
9       lightgbm    Y_future_play_hours   outcome   0.2388   0.1672  0.7398
10       xgboost    Y_future_play_hours  mediator   0.0870   0.0625  0.4026
11       xgboost    Y_future_play_hours   outcome   0.2383   0.1694  0.7409

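The model comparison is now cross-fitted. All three families keep a positive total effect and a small mediated pathway for the primary outcome, but the magnitudes differ sharply: the linear model estimates a total effect of about 37.3 future interactions, while LightGBM and XGBoost estimate about 3.2 and 2.7. The sign of the earlier story survives flexible modeling; its size appears sensitive to functional form.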
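
A quick arithmetic check on the primary-outcome rows of the effects table confirms that each total effect is the sum of its direct and indirect parts, and shows how small the mediated share is in every model family. The numbers below are copied from the table above; the dictionary layout and variable names are just an illustrative convenience:

```python
# Primary-outcome estimates copied from the cross-fitted effects table.
reported = {
    "linear":   {"total": 37.3421, "direct": 38.6464, "indirect": -1.3043},
    "lightgbm": {"total": 3.2069,  "direct": 3.2652,  "indirect": -0.0583},
    "xgboost":  {"total": 2.6829,  "direct": 2.7279,  "indirect": -0.0450},
}

gaps = {}
shares = {}
for name, row in reported.items():
    # Total should equal direct + indirect up to the 4-decimal rounding in the table.
    gaps[name] = row["total"] - (row["direct"] + row["indirect"])
    # Indirect effect as a fraction of the total effect.
    shares[name] = row["indirect"] / row["total"]
    print(f"{name}: gap={gaps[name]:.4f}, indirect share={shares[name]:.4f}")
```

In all three rows the indirect share is negative and below 5 percent of the total in absolute value, so the models disagree about the size of the total effect far more than about how little of it is mediated.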

11. Plot Cross-Fitted Model Comparison

This cell visualizes total, direct, and indirect effects across model families for the primary outcome. The plot uses the cross-fitted estimates from held-out users.

primary_ml_plot = ml_effects.query("outcome == @PRIMARY_OUTCOME").melt(
    id_vars=["model_family"],
    value_vars=["gcomp_total_effect", "natural_direct_effect", "natural_indirect_effect"],
    var_name="effect_type",
    value_name="estimate",
)
primary_ml_plot["effect_label"] = primary_ml_plot["effect_type"].map(
    {
        "gcomp_total_effect": "Total",
        "natural_direct_effect": "Direct",
        "natural_indirect_effect": "Indirect",
    }
)
primary_ml_plot["model_label"] = primary_ml_plot["model_family"].str.upper()

fig, ax = plt.subplots(figsize=(11, 6))
sns.barplot(data=primary_ml_plot, x="estimate", y="model_label", hue="effect_label", ax=ax)
ax.axvline(0, color="black", linewidth=1)
ax.set_title("Cross-Fitted ML Mediation Effect Comparison")
ax.set_xlabel("Effect on future 7-day interactions")
ax.set_ylabel("Model family")
plt.tight_layout()
fig.savefig(FIGURE_DIR / "21_crossfit_ml_effect_comparison.png", dpi=160, bbox_inches="tight")
plt.show()

This comparison is the main advanced-model result. Agreement across Linear, LightGBM, and XGBoost is more reassuring than any one model alone.

12. Compare Predictive Performance

This cell compares out-of-fold predictive performance for the mediator and outcome nuisance models. Stronger prediction does not automatically mean a better causal estimate, but very weak prediction can make counterfactual estimates unstable.

performance_plot = ml_performance.copy()
performance_plot["model_label"] = performance_plot["model_family"].str.upper()
performance_plot["target_label"] = performance_plot["target"].str.title()
performance_plot["outcome_label"] = performance_plot["outcome"].map(
    {PRIMARY_OUTCOME: "Future interactions", SECONDARY_OUTCOME: "Future play hours"}
)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for ax, metric in zip(axes, ["rmse", "r2"]):
    sns.barplot(
        data=performance_plot.query("outcome == @PRIMARY_OUTCOME"),
        x=metric,
        y="model_label",
        hue="target_label",
        ax=ax,
    )
    ax.set_title(f"Primary Outcome Nuisance {metric.upper()}")
    ax.set_xlabel(metric.upper())
    ax.set_ylabel("Model family")
plt.tight_layout()
fig.savefig(FIGURE_DIR / "22_crossfit_nuisance_performance.png", dpi=160, bbox_inches="tight")
plt.show()

display(ml_performance.round(4))

    model_family  outcome                target       rmse      mae      r2
 0  linear        Y_future_interactions  mediator   0.1015   0.0703  0.1878
 1  linear        Y_future_interactions  outcome   96.6374  71.1050  0.7129
 2  lightgbm      Y_future_interactions  mediator   0.0875   0.0629  0.3959
 3  lightgbm      Y_future_interactions  outcome   66.7067  47.0941  0.8632
 4  xgboost       Y_future_interactions  mediator   0.0870   0.0625  0.4026
 5  xgboost       Y_future_interactions  outcome   66.0357  46.7289  0.8659
 6  linear        Y_future_play_hours    mediator   0.1015   0.0703  0.1878
 7  linear        Y_future_play_hours    outcome    0.3609   0.2428  0.4056
 8  lightgbm      Y_future_play_hours    mediator   0.0875   0.0629  0.3959
 9  lightgbm      Y_future_play_hours    outcome    0.2388   0.1672  0.7398
10  xgboost       Y_future_play_hours    mediator   0.0870   0.0625  0.4026
11  xgboost       Y_future_play_hours    outcome    0.2383   0.1694  0.7409

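The performance table helps explain model behavior. The mediator rows are identical across the two outcomes, which is consistent with the mediator model not depending on the outcome variable. For the primary outcome, the tree models raise out-of-fold R2 from 0.71 (linear) to about 0.86 and cut RMSE from 96.6 to about 66, so their different effect estimates come with a better held-out fit. The reverse case deserves caution: if a flexible model changes the effect estimate but does not improve held-out performance, that change is hard to trust.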
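For readers who want to see what "out-of-fold" means operationally, here is a minimal sketch of user-grouped cross-validation. It assumes synthetic data, five folds, and a plain linear model; the notebook's own fold count, features, and weighting are not reproduced here. The point it illustrates is that every row is predicted by a model that never saw that row's user.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import GroupKFold

# Synthetic user-day panel (assumption): 40 users, 10 days each.
rng = np.random.default_rng(0)
n_users, n_days = 40, 10
user_id = np.repeat(np.arange(n_users), n_days)
X = rng.normal(size=(n_users * n_days, 3))
y = X @ np.array([1.0, -0.5, 0.25]) + rng.normal(scale=0.5, size=n_users * n_days)

# Fill out-of-fold predictions one held-out user group at a time.
oof_pred = np.full(len(y), np.nan)
for train_idx, test_idx in GroupKFold(n_splits=5).split(X, y, groups=user_id):
    # No user appears on both sides of a split.
    assert set(user_id[train_idx]).isdisjoint(user_id[test_idx])
    model = LinearRegression().fit(X[train_idx], y[train_idx])
    oof_pred[test_idx] = model.predict(X[test_idx])

rmse = float(np.sqrt(mean_squared_error(y, oof_pred)))
r2 = float(r2_score(y, oof_pred))
print(round(rmse, 3), round(r2, 3))
```

Grouping by user matters in a user-day panel because a random row split would let the model see other days from the same user, which tends to flatter the held-out metrics.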

13. Heterogeneity from Cross-Fitted Effects

This cell uses LightGBM’s row-level cross-fitted effects to summarize heterogeneity. The groups are based on pre-treatment information: prior activity, prior satisfaction, prior discovery, and user activity degree.

heterogeneity_base = analysis_panel[[
    "user_id",
    "event_date",
    "prior_3day_interactions",
    "prior_3day_high_satisfaction_share",
    "prior_3day_discovery_candidate_share",
    "user_active_degree",
]].copy()

heterogeneity_base["prior_activity_group"] = pd.qcut(
    heterogeneity_base["prior_3day_interactions"].rank(method="first"),
    q=3,
    labels=["low prior activity", "medium prior activity", "high prior activity"],
)
heterogeneity_base["prior_satisfaction_group"] = pd.qcut(
    heterogeneity_base["prior_3day_high_satisfaction_share"].rank(method="first"),
    q=3,
    labels=["low prior satisfaction", "medium prior satisfaction", "high prior satisfaction"],
)
heterogeneity_base["prior_discovery_group"] = pd.qcut(
    heterogeneity_base["prior_3day_discovery_candidate_share"].rank(method="first"),
    q=3,
    labels=["low prior discovery", "medium prior discovery", "high prior discovery"],
)

lightgbm_row_effects = row_level_effects.query(
    "model_family == 'lightgbm' and outcome == @PRIMARY_OUTCOME"
).merge(
    heterogeneity_base,
    on=["user_id", "event_date"],
    how="left",
)

# Row-level treatment-induced mediator shift: predicted M under treatment minus under control.
lightgbm_row_effects["mediator_shift_row"] = (
    lightgbm_row_effects["m1_hat"] - lightgbm_row_effects["m0_hat"]
)

segment_cols = ["prior_activity_group", "prior_satisfaction_group", "prior_discovery_group", "user_active_degree"]
heterogeneity_rows = []
for segment_col in segment_cols:
    # Average the row-level effects and the mediator shift within each segment value.
    summary = (
        lightgbm_row_effects.groupby(segment_col, observed=True)
        .agg(
            user_days=("user_id", "size"),
            users=("user_id", "nunique"),
            total_effect=("ite_total", "mean"),
            direct_effect=("ite_direct", "mean"),
            indirect_effect=("ite_indirect", "mean"),
            mediator_shift=("mediator_shift_row", "mean"),
        )
        .reset_index()
        .rename(columns={segment_col: "segment_value"})
    )
    summary["segment"] = segment_col
    heterogeneity_rows.append(summary)

heterogeneity_effects = pd.concat(heterogeneity_rows, ignore_index=True)

display(heterogeneity_effects.round(4))
    segment_value              user_days  users  total_effect  direct_effect  indirect_effect  mediator_shift  segment
 0  low prior activity              2733    133        2.1266         2.3286          -0.2021          0.0195  prior_activity_group
 1  medium prior activity           2733    133        3.9690         4.0530          -0.0841          0.0177  prior_activity_group
 2  high prior activity             2733    133        3.5253         3.4140           0.1112          0.0180  prior_activity_group
 3  low prior satisfaction          2733    133        2.9903         2.8696           0.1208          0.0165  prior_satisfaction_group
 4  medium prior satisfaction       2733    130        3.5907         3.7333          -0.1426          0.0158  prior_satisfaction_group
 5  high prior satisfaction         2733    113        3.0398         3.1928          -0.1530          0.0229  prior_satisfaction_group
 6  low prior discovery             2733    133        2.0592         2.1656          -0.1064          0.0184  prior_discovery_group
 7  medium prior discovery          2733    133        3.3925         3.3729           0.0196          0.0166  prior_discovery_group
 8  high prior discovery            2733    133        4.1691         4.2572          -0.0881          0.0201  prior_discovery_group
 9  UNKNOWN                          184      3        4.4181         4.5911          -0.1731          0.0241  user_active_degree
10  full_active                     7133    115        3.1915         3.2485          -0.0570          0.0181  user_active_degree
11  high_active                      882     15        3.0788         3.1241          -0.0453          0.0195  user_active_degree

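The heterogeneity table asks where the discovery effect is largest. Because these are subgroup summaries of cross-fitted predictions, they are descriptive signals for product targeting rather than separate randomized subgroup effects. Segment size matters when reading them: the UNKNOWN activity segment has the largest total effect (4.42) but covers only 3 users and 184 user-days, so it should not be over-read. The indirect effect stays within about 0.2 interactions of zero in every segment.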
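The equal-sized terciles (2,733 user-days each) follow from ranking with method="first" before calling pd.qcut. The toy example below, using made-up values, sketches why that step is there: with many tied values, quantile edges can coincide and pd.qcut raises an error, while ranking first breaks ties by row order.

```python
import pandas as pd

# Toy series with heavy ties at zero (hypothetical values).
prior_interactions = pd.Series([0, 0, 0, 0, 1, 2])

# Direct tercile cut: the first two quantile edges are both 0, so qcut refuses.
try:
    pd.qcut(prior_interactions, q=3)
    raised = False
except ValueError:
    raised = True

# Ranking first gives unique values 1..6, so the cut always succeeds.
terciles = pd.qcut(
    prior_interactions.rank(method="first"),
    q=3,
    labels=["low", "medium", "high"],
)
sizes = terciles.value_counts().to_dict()
print(raised, sizes)
```

The trade-off is that tied rows are split across groups arbitrarily: here two of the four zeros land in "low" and two in "medium". Group boundaries in the heterogeneity table should be read with that in mind when the underlying variable has many ties.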

14. Plot Heterogeneous Effects

This cell plots LightGBM total effects by segment. It focuses on total effects because those are the most stable quantities in the earlier robustness checks.

heterogeneity_plot = heterogeneity_effects.copy()
heterogeneity_plot["segment_label"] = heterogeneity_plot["segment"].str.replace("_", " ").str.title()
heterogeneity_plot["segment_value"] = heterogeneity_plot["segment_value"].astype(str)

fig, axes = plt.subplots(2, 2, figsize=(15, 10))
axes = axes.flatten()
for ax, segment in zip(axes, heterogeneity_plot["segment"].unique()):
    current = heterogeneity_plot.query("segment == @segment").copy()
    sns.barplot(data=current, x="total_effect", y="segment_value", color="steelblue", ax=ax)
    ax.axvline(0, color="black", linewidth=1)
    ax.set_title(segment.replace("_", " ").title())
    ax.set_xlabel("LightGBM total effect on future interactions")
    ax.set_ylabel("")
plt.tight_layout()
fig.savefig(FIGURE_DIR / "23_lightgbm_heterogeneous_effects.png", dpi=160, bbox_inches="tight")
plt.show()

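The heterogeneity plot shows where the LightGBM model expects discovery exposure to matter more or less, which is the product-relevant view of these results. Total effects rise with prior discovery share (about 2.1, 3.4, and 4.2 across terciles), and no segment has a negative total effect.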

15. Train Full Tree Models for Feature Importance

This cell trains full-sample LightGBM and XGBoost nuisance models for feature importance only. These models are not used for effect estimates; the effect estimates above are cross-fitted. The goal here is model auditability.

feature_importance_rows = []
full_design_m = treatment_design(X_covariates, A)
full_design_y = outcome_design(X_covariates, A, M)

importance_model_specs = [
    ("lightgbm", LGBMRegressor(
        n_estimators=120,
        learning_rate=0.05,
        num_leaves=24,
        subsample=0.85,
        colsample_bytree=0.85,
        random_state=RANDOM_SEED,
        n_jobs=1,
        verbose=-1,
    )),
    ("xgboost", XGBRegressor(
        n_estimators=120,
        max_depth=3,
        learning_rate=0.05,
        subsample=0.85,
        colsample_bytree=0.85,
        objective="reg:squarederror",
        tree_method="hist",
        random_state=RANDOM_SEED,
        n_jobs=1,
        verbosity=0,
    )),
]

for model_family, model in importance_model_specs:
    mediator_model = clone(model)
    outcome_model = clone(model)
    fit_with_optional_weight(mediator_model, full_design_m, M, sample_weight=weights)
    fit_with_optional_weight(outcome_model, full_design_y, Y, sample_weight=weights)

    for target_name, fitted_model, design in [
        ("mediator", mediator_model, full_design_m),
        ("outcome", outcome_model, full_design_y),
    ]:
        importances = getattr(fitted_model, "feature_importances_", None)
        if importances is None:
            continue
        current = pd.DataFrame(
            {
                "model_family": model_family,
                "target": target_name,
                "feature": design.columns,
                "importance": importances,
            }
        )
        current["importance_share"] = current["importance"] / current["importance"].sum()
        feature_importance_rows.append(current)

feature_importance = pd.concat(feature_importance_rows, ignore_index=True)
top_feature_importance = (
    feature_importance.sort_values(["model_family", "target", "importance"], ascending=[True, True, False])
    .groupby(["model_family", "target"], group_keys=False)
    .head(12)
    .reset_index(drop=True)
)

display(top_feature_importance.round(4))
    model_family  target    feature                               importance  importance_share
 0  lightgbm      mediator  calendar_day_index                      297.0000            0.1076
 1  lightgbm      mediator  prior_3day_high_satisfaction_share      260.0000            0.0942
 2  lightgbm      mediator  lag_1_high_satisfaction_share           249.0000            0.0902
 3  lightgbm      mediator  prior_3day_discovery_candidate_share    187.0000            0.0678
 4  lightgbm      mediator  onehot_feat8                            147.0000            0.0533
 5  lightgbm      mediator  lag_1_valid_play_share                  137.0000            0.0496
 6  lightgbm      mediator  prior_3day_valid_play_share             122.0000            0.0442
 7  lightgbm      mediator  onehot_feat3                            118.0000            0.0428
 8  lightgbm      mediator  register_days                           111.0000            0.0402
 9  lightgbm      mediator  prior_3day_total_play_duration_sec      108.0000            0.0391
10  lightgbm      mediator  lag_1_total_play_duration_sec           103.0000            0.0373
11  lightgbm      mediator  A_high_discovery                        101.0000            0.0366
12  lightgbm      outcome   calendar_day_index                      592.0000            0.2145
13  lightgbm      outcome   onehot_feat8                            187.0000            0.0678
14  lightgbm      outcome   prior_3day_discovery_candidate_share    180.0000            0.0652
15  lightgbm      outcome   prior_3day_interactions                 157.0000            0.0569
16  lightgbm      outcome   register_days                           148.0000            0.0536
17  lightgbm      outcome   onehot_feat7                            126.0000            0.0457
18  lightgbm      outcome   prior_3day_high_satisfaction_share      121.0000            0.0438
19  lightgbm      outcome   onehot_feat3                            120.0000            0.0435
20  lightgbm      outcome   follow_user_num                         108.0000            0.0391
21  lightgbm      outcome   prior_3day_total_play_duration_sec      105.0000            0.0380
22  lightgbm      outcome   prior_3day_valid_play_share             103.0000            0.0373
23  lightgbm      outcome   fans_user_num                            87.0000            0.0315
24  xgboost       mediator  prior_3day_high_satisfaction_share        0.1697            0.1697
25  xgboost       mediator  lag_1_high_satisfaction_share             0.0686            0.0686
26  xgboost       mediator  lag_1_valid_play_share                    0.0570            0.0570
27  xgboost       mediator  prior_3day_valid_play_share               0.0436            0.0436
28  xgboost       mediator  lag_1_interactions                        0.0432            0.0432
29  xgboost       mediator  onehot_feat11                             0.0416            0.0416
30  xgboost       mediator  fans_user_num_range_1_10                  0.0381            0.0381
31  xgboost       mediator  user_active_degree_high_active            0.0330            0.0330
32  xgboost       mediator  prior_3day_active_day                     0.0312            0.0312
33  xgboost       mediator  onehot_feat6                              0.0262            0.0262
34  xgboost       mediator  lag_1_total_play_duration_sec             0.0261            0.0261
35  xgboost       mediator  A_high_discovery                          0.0227            0.0227
36  xgboost       outcome   calendar_day_index                        0.3093            0.3093
37  xgboost       outcome   prior_3day_active_day                     0.1905            0.1905
38  xgboost       outcome   prior_3day_discovery_candidate_share      0.1517            0.1517
39  xgboost       outcome   prior_3day_interactions                   0.0373            0.0373
40  xgboost       outcome   lag_1_interactions                        0.0202            0.0202
41  xgboost       outcome   lag_1_discovery_candidate_share           0.0190            0.0190
42  xgboost       outcome   recent_activity_score                     0.0169            0.0169
43  xgboost       outcome   user_active_degree_high_active            0.0161            0.0161
44  xgboost       outcome   prior_3day_total_play_duration_sec        0.0145            0.0145
45  xgboost       outcome   register_days_range_61_90                 0.0141            0.0141
46  xgboost       outcome   fans_user_num                             0.0136            0.0136
47  xgboost       outcome   onehot_feat14                             0.0117            0.0117

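The feature-importance audit helps reveal whether models are mostly using treatment, mediator, calendar time, prior activity, or profile signals. That matters because an advanced model can look impressive while leaning heavily on confounding proxies. Two things stand out in this table. First, the raw importance column is not comparable across families: the LightGBM values are integer counts, consistent with its default split-count importance, while the XGBoost values equal their own shares and so already sum to one, consistent with a normalized gain-type importance. Compare families on importance_share, and even then with care. Second, calendar_day_index is the top outcome feature in both families, while the treatment indicator A_high_discovery ranks only twelfth in both mediator models and does not reach the top twelve of either outcome model.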
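The LightGBM shares can be sanity-checked by hand. The sketch below assumes the importances are split counts and that every tree grew to the full num_leaves=24, so each of the 120 trees contributes 23 splits. Both assumptions are inferred from the table rather than confirmed by the code above.

```python
# Assumption: split-count importances, and every tree fully grown to num_leaves.
n_estimators = 120
num_leaves = 24
max_total_splits = n_estimators * (num_leaves - 1)  # a tree with L leaves has L - 1 splits

# Split counts for calendar_day_index, copied from the table above.
calendar_splits = {"mediator": 297, "outcome": 592}
calendar_share = {
    target: round(count / max_total_splits, 4)
    for target, count in calendar_splits.items()
}
print(max_total_splits, calendar_share)
```

This gives 2,760 total splits and shares of 0.1076 and 0.2145, matching the importance_share column. That agreement supports reading the LightGBM importances as "how often a feature was split on", which is a different quantity from how much a split improved the fit.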

16. Plot Feature Importance

This cell plots the top features for the LightGBM nuisance models. LightGBM is used here because its importances are on a single count scale that is easy to read; the plot covers the same top-12 features per target as the table above.

importance_plot = top_feature_importance.query("model_family == 'lightgbm'").copy()
importance_plot["feature_label"] = importance_plot["feature"].str.replace("_", " ").str.slice(0, 45)
importance_plot["target_label"] = importance_plot["target"].str.title()

fig, axes = plt.subplots(1, 2, figsize=(16, 6))
for ax, target in zip(axes, ["mediator", "outcome"]):
    current = importance_plot.query("target == @target").sort_values("importance", ascending=True)
    sns.barplot(data=current, x="importance", y="feature_label", color="steelblue", ax=ax)
    ax.set_title(f"LightGBM {target.title()} Model Features")
    ax.set_xlabel("Feature importance")
    ax.set_ylabel("")
plt.tight_layout()
fig.savefig(FIGURE_DIR / "24_lightgbm_feature_importance.png", dpi=160, bbox_inches="tight")
plt.show()

The feature plot is an audit layer, not an explanation of causality. It tells us what the prediction functions used, which is useful for debugging and for communicating model behavior.

17. Advanced Model Summary

This cell combines the linear g-computation result from notebook 04, SEM-style paths, and cross-fitted ML estimates into one comparison table. The table is the main handoff to the final report.

linear_reference = effect_summary.query(
    "outcome == @PRIMARY_OUTCOME and estimand in ['gcomp_total_effect', 'natural_direct_effect', 'natural_indirect_effect']"
).pivot_table(index="outcome", columns="estimand", values="estimate").reset_index()
linear_reference["model_family"] = "linear_reference_from_notebook_04"
linear_reference = linear_reference.rename(
    columns={
        "gcomp_total_effect": "total_effect",
        "natural_direct_effect": "direct_effect",
        "natural_indirect_effect": "indirect_effect",
    }
)[["model_family", "total_effect", "direct_effect", "indirect_effect"]]

sem_reference = pd.DataFrame(
    [
        {
            "model_family": "sem_style_path_model",
            "total_effect": sem_path_results.loc[0, "total_path_c_prime_plus_indirect"],
            "direct_effect": sem_path_results.loc[0, "c_prime_A_to_Y"],
            "indirect_effect": sem_path_results.loc[0, "indirect_alpha_beta"],
        }
    ]
)

ml_reference = ml_effects.query("outcome == @PRIMARY_OUTCOME").rename(
    columns={
        "gcomp_total_effect": "total_effect",
        "natural_direct_effect": "direct_effect",
        "natural_indirect_effect": "indirect_effect",
    }
)[["model_family", "total_effect", "direct_effect", "indirect_effect", "mediator_shift", "relative_total_effect"]]

advanced_model_summary = pd.concat(
    [linear_reference, sem_reference, ml_reference],
    ignore_index=True,
    sort=False,
)
advanced_model_summary["indirect_to_total_ratio"] = (
    advanced_model_summary["indirect_effect"] / advanced_model_summary["total_effect"].replace(0, np.nan)
)
advanced_model_summary["total_effect_positive"] = advanced_model_summary["total_effect"] > 0
advanced_model_summary["indirect_effect_small_abs_lt_5"] = advanced_model_summary["indirect_effect"].abs() < 5

display(advanced_model_summary.round(4))
   model_family                       total_effect  direct_effect  indirect_effect  mediator_shift  relative_total_effect  indirect_to_total_ratio  total_effect_positive  indirect_effect_small_abs_lt_5
0  linear_reference_from_notebook_04       37.2781        38.5858          -1.3077             NaN                    NaN                  -0.0351                   True                            True
1  sem_style_path_model                    37.3139        38.4311          -1.1172             NaN                    NaN                  -0.0299                   True                            True
2  linear                                  37.3421        38.6464          -1.3043          0.0231                 0.1156                  -0.0349                   True                            True
3  lightgbm                                 3.2069         3.2652          -0.0583          0.0184                 0.0095                  -0.0182                   True                            True
4  xgboost                                  2.6829         2.7279          -0.0450          0.0138                 0.0079                  -0.0168                   True                            True

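The comparison table answers the advanced-model question directly. The total effect is positive in every row and the indirect effect is small and slightly negative, between roughly 1.7 and 3.5 percent of the total in absolute value, so the qualitative notebook 04 result is strengthened. The magnitude is not stable, however: the cross-fitted tree models put the total effect near 3 interactions, against about 37 for the linear and SEM-style fits. The fixed threshold of 5 in indirect_effect_small_abs_lt_5 also exceeds the entire tree-model total effect, so indirect_to_total_ratio is the more informative column for those rows.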
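The stability reading above can be made explicit with a few derived columns. The sketch below re-enters the three cross-fitted rows from the table by hand (so it is independent of the notebook's variables) and checks the additive decomposition, the indirect share, and each total effect relative to the linear one. The column names decomposition_gap and total_vs_linear are illustrative, not part of the notebook.

```python
import pandas as pd

# Cross-fitted rows copied from the comparison table above.
summary = pd.DataFrame(
    {
        "model_family": ["linear", "lightgbm", "xgboost"],
        "total_effect": [37.3421, 3.2069, 2.6829],
        "direct_effect": [38.6464, 3.2652, 2.7279],
        "indirect_effect": [-1.3043, -0.0583, -0.0450],
    }
)

# Total should equal direct + indirect up to rounding in the printed table.
summary["decomposition_gap"] = summary["total_effect"] - (
    summary["direct_effect"] + summary["indirect_effect"]
)
summary["indirect_to_total_ratio"] = summary["indirect_effect"] / summary["total_effect"]

# Scale-free view of magnitude agreement: each total relative to the linear total.
linear_total = summary.loc[summary["model_family"].eq("linear"), "total_effect"].iloc[0]
summary["total_vs_linear"] = summary["total_effect"] / linear_total

same_sign = bool((summary["total_effect"] > 0).all())
print(summary.round(4))
print("all totals positive:", same_sign)
```

On these numbers the sign check passes, while total_vs_linear for the tree models is below 0.1. A relative column like this is one way to keep a sign-only flag from overstating agreement.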

18. Plot Advanced Model Summary

This cell plots total, direct, and indirect effects across the reference, SEM-style, and cross-fitted ML approaches.

summary_plot = advanced_model_summary.melt(
    id_vars=["model_family"],
    value_vars=["total_effect", "direct_effect", "indirect_effect"],
    var_name="effect_type",
    value_name="estimate",
)
summary_plot["effect_label"] = summary_plot["effect_type"].map(
    {
        "total_effect": "Total",
        "direct_effect": "Direct",
        "indirect_effect": "Indirect",
    }
)
summary_plot["model_label"] = summary_plot["model_family"].str.replace("_", " ").str.title()

fig, ax = plt.subplots(figsize=(13, 7))
sns.barplot(data=summary_plot, x="estimate", y="model_label", hue="effect_label", ax=ax)
ax.axvline(0, color="black", linewidth=1)
ax.set_title("Advanced Model Comparison for Discovery Mediation")
ax.set_xlabel("Effect on future 7-day interactions")
ax.set_ylabel("Model approach")
plt.tight_layout()
fig.savefig(FIGURE_DIR / "25_advanced_model_summary.png", dpi=160, bbox_inches="tight")
plt.show()

This final figure is the high-level advanced-model takeaway. It lets a reader see whether the main story depends on the modeling approach.

19. Save Advanced Model Artifacts

This cell saves the SEM path results, bootstrap path samples, ML effect estimates, model performance, heterogeneity table, feature importance table, and advanced summary table.

sem_path_summary.to_csv(SEM_PATH_OUTPUT, index=False)
sem_bootstrap.to_csv(SEM_BOOTSTRAP_OUTPUT, index=False)
ml_effects.to_csv(ML_EFFECTS_OUTPUT, index=False)
ml_performance.to_csv(ML_PERFORMANCE_OUTPUT, index=False)
heterogeneity_effects.to_csv(HETEROGENEITY_OUTPUT, index=False)
feature_importance.to_csv(FEATURE_IMPORTANCE_OUTPUT, index=False)
advanced_model_summary.to_csv(ADVANCED_SUMMARY_OUTPUT, index=False)

sem_path_summary.to_csv(TABLE_DIR / "advanced_sem_path_results.csv", index=False)
ml_effects.to_csv(TABLE_DIR / "advanced_ml_effects.csv", index=False)
ml_performance.to_csv(TABLE_DIR / "advanced_ml_performance.csv", index=False)
heterogeneity_effects.to_csv(TABLE_DIR / "advanced_heterogeneity.csv", index=False)
top_feature_importance.to_csv(TABLE_DIR / "advanced_feature_importance_top.csv", index=False)
advanced_model_summary.to_csv(TABLE_DIR / "advanced_model_summary.csv", index=False)

saved_outputs = pd.DataFrame(
    {
        "artifact": [
            "sem_path_results",
            "sem_bootstrap",
            "ml_effects",
            "ml_performance",
            "heterogeneity_effects",
            "feature_importance",
            "advanced_model_summary",
        ],
        "path": [
            str(SEM_PATH_OUTPUT),
            str(SEM_BOOTSTRAP_OUTPUT),
            str(ML_EFFECTS_OUTPUT),
            str(ML_PERFORMANCE_OUTPUT),
            str(HETEROGENEITY_OUTPUT),
            str(FEATURE_IMPORTANCE_OUTPUT),
            str(ADVANCED_SUMMARY_OUTPUT),
        ],
    }
)

display(saved_outputs)
artifact path
0 sem_path_results /home/apex/Documents/ranking_sys/data/processed/kuairec_discovery_quality_sem_path_results.csv
1 sem_bootstrap /home/apex/Documents/ranking_sys/data/processed/kuairec_discovery_quality_sem_bootstrap.csv
2 ml_effects /home/apex/Documents/ranking_sys/data/processed/kuairec_discovery_quality_advanced_ml_effects.csv
3 ml_performance /home/apex/Documents/ranking_sys/data/processed/kuairec_discovery_quality_advanced_ml_performance.csv
4 heterogeneity_effects /home/apex/Documents/ranking_sys/data/processed/kuairec_discovery_quality_advanced_heterogeneity.csv
5 feature_importance /home/apex/Documents/ranking_sys/data/processed/kuairec_discovery_quality_advanced_feature_importance.csv
6 advanced_model_summary /home/apex/Documents/ranking_sys/data/processed/kuairec_discovery_quality_advanced_model_summary.csv

The saved tables give the final report enough material to compare simple and advanced estimators without rerunning the models.
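Since the final report depends on these files, a quick round-trip check can catch a silently truncated or mis-written table. The sketch below uses a temporary directory and a two-row stand-in table with values taken from the summary above; the file name is reused for illustration only and nothing is written to the real output folders.

```python
import tempfile
from pathlib import Path

import numpy as np
import pandas as pd

# Stand-in table (two rows copied from the advanced model summary).
table = pd.DataFrame(
    {"model_family": ["linear", "lightgbm"], "total_effect": [37.3421, 3.2069]}
)

with tempfile.TemporaryDirectory() as tmp_dir:
    output_path = Path(tmp_dir) / "advanced_model_summary.csv"
    table.to_csv(output_path, index=False)
    reloaded = pd.read_csv(output_path)

# Compare shape, column order, and numeric content after the round trip.
same_shape = reloaded.shape == table.shape
same_columns = list(reloaded.columns) == list(table.columns)
same_values = bool(np.allclose(reloaded["total_effect"], table["total_effect"]))
print(same_shape, same_columns, same_values)
```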

20. Notebook Takeaways

This notebook added advanced model checks to the discovery-quality mediation workflow:

  • SEM-style path modeling gives a compact path-coefficient version of the mediation story.
  • Cross-fitted Linear, LightGBM, and XGBoost nuisance models test whether the decomposition survives nonlinear prediction functions.
  • Heterogeneity summaries show where discovery exposure appears more or less valuable.
  • Feature-importance tables audit which variables the tree-based nuisance models rely on.
  • The comparison table gives the final report a direct check on whether the main result is stable across modeling approaches.

The next notebook should be the final report and figures notebook, using the simple, robustness, and advanced-model outputs together.