XGBoost Algorithm in Machine Learning with Example: Working Code

Financial Distress Prediction with XGBoost Algorithm, SHAP, and LIME: A Python Guide

Predicting financial distress is a critical task for businesses, investors, and analysts. Machine learning algorithms such as XGBoost can improve predictive accuracy on this task, and explainability tools can make their predictions interpretable. In this article, we walk through a Python pipeline for financial distress prediction that covers model training, evaluation, and explainability using SHAP and LIME.

Overview of the Pipeline

The provided Python script trains and evaluates models that predict financial distress events from tabular data. It fits a logistic regression as an interpretable baseline alongside an XGBoost gradient boosting classifier. The script also integrates SHAP (SHapley Additive exPlanations) and LIME (Local Interpretable Model-agnostic Explanations) for model interpretability, which matters in regulated industries where decisions must be explainable.

Key Features of the XGBoost Pipeline Implementation

  1. Flexible Data Input: The script accepts a CSV file of entity-period rows (for example, one row per company per reporting date), with configurable column names for the entity ID, time, and target variable.
  2. Temporal Train-Test Split: The test set is the most recent slice of time, so the model is always evaluated on periods later than those it was trained on, mimicking real-world forecasting.
  3. Class Imbalance Handling: XGBoost's `scale_pos_weight` is set from the ratio of negative to positive training examples, and the logistic regression baseline uses `class_weight="balanced"`.
  4. Comprehensive Evaluation: The script reports ROC AUC, average precision, and Brier score; plots ROC, precision-recall, and calibration curves; and produces a confusion matrix at a top-decile threshold and a classification report at a 0.5 threshold.
  5. Model Explainability: SHAP provides global explanations of the model, and LIME provides local explanations for individual predictions.
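The `temporal_split` helper comes from a local `utils` module that the post does not show. The following is a minimal sketch of how such a split could work, assuming the cutoff is chosen by position in the sorted list of distinct periods (so `test_size` is a fraction of periods, not rows). The name `temporal_split_sketch` is hypothetical; it only mirrors the `(train_df, test_df, cutoff)` return shape the script unpacks.

```python
import pandas as pd


def temporal_split_sketch(df, time_col, test_size=0.2):
    """Hypothetical stand-in for utils.temporal_split (the original is not shown).

    The latest `test_size` fraction of distinct periods becomes the test set.
    Returns (train_df, test_df, cutoff), matching how the script unpacks it.
    """
    times = pd.to_datetime(df[time_col])
    periods = times.drop_duplicates().sort_values().reset_index(drop=True)
    # Position of the last period that still belongs to training
    cut_pos = max(round(len(periods) * (1 - test_size)) - 1, 0)
    cutoff = periods.iloc[cut_pos]
    train_df = df[times <= cutoff]  # everything up to and including the cutoff
    test_df = df[times > cutoff]    # strictly later periods only
    return train_df, test_df, cutoff


# Two hypothetical companies observed at five year-ends
df = pd.DataFrame({
    "company_id": [1, 2] * 5,
    "date": [f"{year}-12-31" for year in range(2019, 2024) for _ in range(2)],
    "distress": [0, 0, 0, 1, 0, 0, 1, 0, 0, 1],
})
train_df, test_df, cutoff = temporal_split_sketch(df, "date", test_size=0.2)
print(cutoff.date(), len(train_df), len(test_df))  # 2022-12-31 8 2
```

Splitting on whole periods keeps every row from the same reporting date on one side of the cutoff, so no information from a test period can appear in training.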

How the Code Works

1. Argument Parsing: The script uses `argparse` to accept command-line arguments for data path, entity column, time column, target column, and test size.
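2. Data Loading and Preprocessing: Data is loaded with the custom helper `load_data`, numeric columns other than the entity, time, and target columns are selected as features, and the preprocessor from `build_preprocessor` is fit on the training period only and then applied to the test period.
3. Model Training: A class-balanced logistic regression baseline and an XGBoost classifier are trained. XGBoost uses early stopping against a stratified validation split drawn from the training period, with fixed hyperparameters (500 trees, depth 4, learning rate 0.05, 0.8 row and column subsampling) that are common starting points for tabular data rather than tuned values.
4. Evaluation: Both models are scored on the held-out test period. ROC AUC, average precision, and Brier score are computed, and ROC, precision-recall, calibration, and confusion-matrix plots are saved to the `artifacts` directory.
5. Explainability: SHAP produces a summary plot of global feature importance and a dependence plot for the single most important feature, using a sample of up to 200 test rows; LIME explains three individual test predictions and saves each as an HTML file.
6. Artifact Saving: The fitted preprocessor, both models, the metrics file (`metrics.json`), and the explanation plots are saved for later reference or deployment.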
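Like `temporal_split`, the `build_preprocessor` helper is not shown. One plausible version, assumed here purely for illustration, is a scikit-learn `ColumnTransformer` that median-imputes and standardizes the numeric features. `build_preprocessor_sketch` and the column names `leverage` and `roa` are hypothetical.

```python
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


def build_preprocessor_sketch(feature_cols):
    """Hypothetical stand-in for utils.build_preprocessor (the original is not shown)."""
    numeric = Pipeline([
        ("impute", SimpleImputer(strategy="median")),  # fill gaps with the training median
        ("scale", StandardScaler()),                   # zero mean, unit variance per column
    ])
    # Apply the numeric pipeline to the listed columns and drop everything else
    return ColumnTransformer([("num", numeric, feature_cols)], remainder="drop")


# Hypothetical feature names; the second row has a missing leverage value
X_raw = pd.DataFrame({"leverage": [0.5, np.nan, 0.9, 0.7], "roa": [0.10, 0.05, -0.20, 0.00]})
pre = build_preprocessor_sketch(["leverage", "roa"])
Xt = pre.fit_transform(X_raw)  # fit on training rows only, then pre.transform(test_rows)
print(Xt.shape)  # (4, 2)
```

Standardization matters for the logistic regression baseline, while XGBoost's tree splits are unaffected by monotonic rescaling. Note that `SimpleImputer` drops training columns that are entirely missing by default, which would misalign the `feature_cols` names the script passes to SHAP and LIME.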

Why Use XGBoost for Financial Distress Prediction?

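XGBoost is known for strong performance on structured (tabular) data. It handles missing values natively by learning a default split direction for them, its tree splits are relatively insensitive to extreme feature values, and its `scale_pos_weight` parameter reweights the minority class. Its tree-based approach captures nonlinear relationships and feature interactions, making it a strong fit for financial datasets where such patterns are common.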
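The class-imbalance weight is simple arithmetic. XGBoost's documentation suggests `sum(negative instances) / sum(positive instances)` as a typical value for `scale_pos_weight`, which is what the script computes in `main()`. The sketch below reproduces that calculation, including the script's floor of 1.0; the function name is ours, not part of the script.

```python
import numpy as np


def scale_pos_weight_from_labels(y):
    """Negative-to-positive ratio, floored at 1.0, mirroring the script's main()."""
    y = np.asarray(y).astype(int)
    pos = int(y.sum())
    neg = len(y) - pos
    # max(pos, 1) avoids division by zero if the training period has no positives
    return max(neg / max(pos, 1), 1.0)


y_train = np.array([0] * 95 + [1] * 5)  # illustrative 5% distress rate
print(scale_pos_weight_from_labels(y_train))  # 19.0
```

With a weight of 19, each distress example contributes as much to the training loss as 19 non-distress examples. The floor of 1.0 means the script never down-weights the positive class, even if distress happened to be the majority.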

Enhancing Trust with SHAP and LIME

Model explainability is essential in finance. SHAP provides a global view of which features matter most and how each feature pushes predictions up or down, while LIME explains individual predictions case by case. Together, they help stakeholders understand and scrutinize the model's decisions.
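To make "how each feature impacts predictions" concrete: a Shapley value averages a feature's marginal contribution over all coalitions of the other features, and the values add up to the prediction minus a baseline prediction. The brute-force sketch below computes exact Shapley values for a toy three-feature function, replacing features outside a coalition with a single fixed baseline. It illustrates the definition only; `shap.TreeExplainer` uses a much faster tree-specific algorithm and its own treatment of absent features, so its numbers for a real model will not match this sketch. The function names are hypothetical.

```python
from itertools import combinations
from math import factorial

import numpy as np


def exact_shapley(f, x, baseline):
    """Exact Shapley values for one input by enumerating every coalition.

    Features outside a coalition are set to their value in `baseline`.
    Cost grows as 2**n, so this is only for tiny illustrative models.
    """
    x = np.asarray(x, dtype=float)
    baseline = np.asarray(baseline, dtype=float)
    n = len(x)

    def value(coalition):
        z = baseline.copy()
        idx = list(coalition)
        if idx:
            z[idx] = x[idx]  # coalition members take their actual values
        return f(z)

    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n!
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                phi[i] += weight * (value(subset + (i,)) - value(subset))
    return phi


def toy_model(z):
    # Two additive effects plus an interaction between features 0 and 2
    return 2 * z[0] + 3 * z[1] + z[0] * z[2]


x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
phi = exact_shapley(toy_model, x, baseline)
print(phi.round(6).tolist())  # [3.5, 6.0, 1.5]
```

The interaction term's contribution (3) is split evenly between features 0 and 2, and the three values sum to 11, the prediction minus the baseline prediction of 0. A SHAP summary plot aggregates this kind of per-row attribution across many rows.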

To use this script, save it as `train.py` and run:

python train.py --data your_data.csv --entity_col company_id --time_col date --target distress

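Ensure you have the required Python packages installed: `numpy`, `pandas`, `xgboost`, `scikit-learn`, `shap`, `lime`, `matplotlib`, and `joblib`. The script also imports `load_data`, `build_preprocessor`, `temporal_split`, and `save_json` from a local `utils` module that is not included in this post, so you need to provide your own `utils.py`.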
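To smoke-test the script before pointing it at real data, you can generate a synthetic panel in the expected shape. All column names (`leverage`, `current_ratio`, `roa`) and the data-generating formula below are invented for illustration, and the sketch assumes the unseen `load_data` helper accepts an ISO date string in the time column. Results on this data say nothing about real-world performance.

```python
import numpy as np
import pandas as pd


def make_synthetic_panel(n_companies=50, years=range(2015, 2023), seed=0):
    """Synthetic entity-period data for smoke-testing the pipeline (not real data)."""
    rng = np.random.default_rng(seed)
    rows = []
    for company in range(n_companies):
        for year in years:
            leverage = rng.uniform(0.1, 1.2)       # hypothetical debt / assets
            current_ratio = rng.uniform(0.5, 3.0)  # hypothetical liquidity measure
            roa = rng.normal(0.03, 0.08)           # hypothetical return on assets
            # Distress is made more likely by high leverage and low profitability
            logit = -3.0 + 3.0 * leverage - 10.0 * roa - 0.5 * current_ratio
            distress = int(rng.random() < 1 / (1 + np.exp(-logit)))
            rows.append({
                "company_id": f"C{company:03d}",
                "date": f"{year}-12-31",
                "leverage": leverage,
                "current_ratio": current_ratio,
                "roa": roa,
                "distress": distress,
            })
    return pd.DataFrame(rows)


df = make_synthetic_panel()
df.to_csv("your_data.csv", index=False)
print(df.shape, df["distress"].mean())
```

Because `company_id` is a string column, `pick_feature_columns` skips it automatically; only the three numeric columns become features.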

This Python pipeline offers a comprehensive solution for financial distress prediction, combining the predictive power of XGBoost with the transparency of SHAP and LIME. By following this approach, organizations can make data-driven decisions with confidence and clarity.

Code


import os, argparse, warnings, json
warnings.filterwarnings("ignore", category=UserWarning)
import numpy as np, pandas as pd, joblib, matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, average_precision_score, brier_score_loss, roc_curve, precision_recall_curve, confusion_matrix, classification_report
from sklearn.calibration import calibration_curve
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from xgboost import XGBClassifier
import shap
from lime.lime_tabular import LimeTabularExplainer


# Local helper module shipped alongside this script (not shown in this post)
from utils import build_preprocessor, temporal_split, save_json, load_data


def pick_feature_columns(df, entity_col, time_col, target):
    ignore = {entity_col, time_col, target}
    # numeric columns only
    feats = [c for c in df.columns if c not in ignore and pd.api.types.is_numeric_dtype(df[c])]
    return feats


def evaluate_and_plot(y_true, y_prob, outdir, tag="test"):
    os.makedirs(outdir, exist_ok=True)
    # Metrics
    roc = roc_auc_score(y_true, y_prob)
    pr  = average_precision_score(y_true, y_prob)
    brier = brier_score_loss(y_true, y_prob)


    # ROC curve
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    plt.figure()
    plt.plot(fpr, tpr, label=f"ROC AUC={roc:.3f}")
    plt.plot([0,1], [0,1], linestyle="--")
    plt.xlabel("False Positive Rate"); plt.ylabel("True Positive Rate"); plt.title(f"ROC Curve ({tag})")
    plt.legend(loc="lower right")
    plt.tight_layout(); plt.savefig(os.path.join(outdir, f"roc_curve_{tag}.png")); plt.close()


    # PR curve
    prec, rec, _ = precision_recall_curve(y_true, y_prob)
    plt.figure()
    plt.plot(rec, prec, label=f"AP={pr:.3f}")
    plt.xlabel("Recall"); plt.ylabel("Precision"); plt.title(f"Precision‑Recall Curve ({tag})")
    plt.legend(loc="lower left")
    plt.tight_layout(); plt.savefig(os.path.join(outdir, f"pr_curve_{tag}.png")); plt.close()


    # Calibration curve
    prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=10, strategy="quantile")
    plt.figure()
    plt.plot(prob_pred, prob_true, marker="o")
    plt.plot([0,1],[0,1], linestyle="--")
    plt.xlabel("Mean predicted probability"); plt.ylabel("Fraction of positives")
    plt.title(f"Calibration plot ({tag})")
    plt.tight_layout(); plt.savefig(os.path.join(outdir, f"calibration_curve_{tag}.png")); plt.close()


    # Confusion matrix at threshold chosen by top decile
    thresh = np.quantile(y_prob, 0.9)
    y_pred = (y_prob >= thresh).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])  # always 2x2, even if one class is absent
    plt.figure()
    plt.imshow(cm, interpolation="nearest")
    plt.title(f"Confusion Matrix ({tag}) @TopDecile")
    plt.colorbar()
    tick_marks = np.arange(2)
    plt.xticks(tick_marks, ["No Distress","Distress"])
    plt.yticks(tick_marks, ["No Distress","Distress"])
    for i in range(2):
        for j in range(2):
            plt.text(j, i, cm[i, j], ha="center", va="center")
    plt.ylabel("True label"); plt.xlabel("Predicted label")
    plt.tight_layout(); plt.savefig(os.path.join(outdir, f"confusion_matrix_{tag}.png")); plt.close()


    # Classification report @0.5
    y_pred_05 = (y_prob >= 0.5).astype(int)
    report = classification_report(y_true, y_pred_05, labels=[0, 1], output_dict=True, zero_division=0)
    return {"roc_auc": roc, "average_precision": pr, "brier": brier, "threshold_top_decile": float(thresh), "report_at_0.5": report}


def main(args):
    artifacts_dir = os.path.join("artifacts")
    os.makedirs(artifacts_dir, exist_ok=True)


    df = load_data(args.data, args.entity_col, args.time_col, args.target)


    # Feature columns
    feature_cols = pick_feature_columns(df, args.entity_col, args.time_col, args.target)


    # Temporal split (holdout is the latest chunk of time)
    train_df, test_df, cutoff = temporal_split(df, args.time_col, test_size=args.test_size)
    print(f"[INFO] Temporal cutoff = {cutoff}  | Train rows={len(train_df)}  Test rows={len(test_df)}")


    X_train_raw = train_df[feature_cols].copy()
    y_train = train_df[args.target].astype(int).values


    X_test_raw = test_df[feature_cols].copy()
    y_test = test_df[args.target].astype(int).values


    # Preprocessor
    preprocessor = build_preprocessor(feature_cols)


    # Fit preprocessor on TRAIN only
    X_train = preprocessor.fit_transform(X_train_raw)
    X_test  = preprocessor.transform(X_test_raw)


    # Class imbalance (for XGB, use scale_pos_weight)
    pos = y_train.sum()
    neg = len(y_train) - pos
    spw = max((neg / max(pos, 1)), 1.0)


    # Models: Logistic Regression baseline & XGBoost (primary)
    logreg = LogisticRegression(max_iter=2000, class_weight="balanced", n_jobs=None)
    logreg.fit(X_train, y_train)
    y_test_prob_lr = logreg.predict_proba(X_test)[:,1]


    xgb = XGBClassifier(
        n_estimators=500,
        max_depth=4,
        learning_rate=0.05,
        subsample=0.8,
        colsample_bytree=0.8,
        reg_lambda=2.0,
        min_child_weight=1.0,
        objective="binary:logistic",
        eval_metric="aucpr",
        tree_method="hist",
        scale_pos_weight=spw,
        early_stopping_rounds=50,  # set in the constructor (supported since XGBoost 1.6; passing it to fit() is deprecated)
        random_state=42
    )
    # Keep a stratified validation split from train for early stopping
    X_tr, X_val, y_tr, y_val = train_test_split(X_train, y_train, test_size=0.2, stratify=y_train, random_state=42)
    # Training stops once validation aucpr fails to improve for 50 rounds;
    # predict_proba then uses the best iteration automatically
    xgb.fit(X_tr, y_tr, eval_set=[(X_val, y_val)], verbose=False)


    y_test_prob_xgb = xgb.predict_proba(X_test)[:,1]


    # Evaluate
    metrics_lr  = evaluate_and_plot(y_test, y_test_prob_lr, artifacts_dir, tag="test_lr")
    metrics_xgb = evaluate_and_plot(y_test, y_test_prob_xgb, artifacts_dir, tag="test_xgb")


    # Save artifacts
    joblib.dump(preprocessor, os.path.join(artifacts_dir, "preprocessor.joblib"))
    joblib.dump(logreg, os.path.join(artifacts_dir, "model_logreg.pkl"))
    joblib.dump(xgb, os.path.join(artifacts_dir, "model.pkl"))
    xgb.save_model(os.path.join(artifacts_dir, "model_xgb.json"))


    with open(os.path.join(artifacts_dir, "metrics.json"), "w") as f:
        json.dump({"logreg": metrics_lr, "xgb": metrics_xgb, "cutoff": str(cutoff), "features": feature_cols}, f, indent=2)


    # -------------------------
    # Global SHAP (TreeExplainer on XGB)
    # -------------------------
    print("[INFO] Computing SHAP values on a sample...")
    explainer = shap.TreeExplainer(xgb)
    # sample to speed up
    sample_idx = np.random.RandomState(42).choice(len(X_test), size=min(200, len(X_test)), replace=False)
    X_sample = X_test[sample_idx]
    shap_values = explainer.shap_values(X_sample)


    # Plot summary
    shap.summary_plot(shap_values, X_sample, feature_names=feature_cols, show=False)
    plt.tight_layout()
    plt.savefig(os.path.join(artifacts_dir, "shap_summary.png"), dpi=144, bbox_inches="tight")
    plt.close()


    # Dependence for the most important feature
    # Compute mean |shap|
    mean_abs = np.mean(np.abs(shap_values), axis=0)
    top_idx = int(np.argmax(mean_abs))
    top_feat = feature_cols[top_idx]
    shap.dependence_plot(top_idx, shap_values, X_sample, feature_names=feature_cols, show=False)
    plt.tight_layout()
    plt.savefig(os.path.join(artifacts_dir, f"shap_dependence_{top_feat}.png"), dpi=144, bbox_inches="tight")
    plt.close()


    # -------------------------
    # LIME local explanations (few examples)
    # -------------------------
    print("[INFO] Generating LIME explanations for a few test rows...")
    lime_explainer = LimeTabularExplainer(
        training_data=X_train,
        feature_names=feature_cols,
        class_names=["NoDistress","Distress"],
        mode="classification",
        discretize_continuous=True,
        sample_around_instance=True,
        random_state=42
    )
    for i, idx in enumerate(sample_idx[:3]):
        exp = lime_explainer.explain_instance(
            X_sample[i],
            xgb.predict_proba,
            num_features=min(10, len(feature_cols))
        )
        out_html = os.path.join(artifacts_dir, f"lime_local_explanation_{i}.html")
        exp.save_to_file(out_html)


    print("[DONE] Artifacts saved in ./artifacts")


if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("--data", type=str, required=True, help="Path to CSV with entity-period rows")
    p.add_argument("--entity_col", type=str, default="company_id")
    p.add_argument("--time_col", type=str, default="date")
    p.add_argument("--target", type=str, default="distress")
    p.add_argument("--test_size", type=float, default=0.2)
    args = p.parse_args()
    main(args)
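One detail in `evaluate_and_plot` is worth spelling out: the confusion matrix uses a threshold at the 90th percentile of the predicted scores (flagging roughly the top 10% of rows), while the classification report uses a fixed 0.5 cutoff. The two can flag very different numbers of cases. The sketch below compares them on synthetic, right-skewed scores (a Beta(1, 8) draw, not model output).

```python
import numpy as np

rng = np.random.default_rng(42)
# Synthetic right-skewed scores standing in for predicted distress probabilities
y_prob = rng.beta(1, 8, size=1000)

thresh = np.quantile(y_prob, 0.9)              # same rule as evaluate_and_plot
flagged_top_decile = (y_prob >= thresh).mean()
flagged_at_half = (y_prob >= 0.5).mean()

print(f"top-decile threshold = {thresh:.3f}")
print(f"share flagged @top decile = {flagged_top_decile:.2%}")  # 10.00%
print(f"share flagged @0.5        = {flagged_at_half:.2%}")
```

A top-decile rule fixes the alert volume regardless of how the probabilities are calibrated, which suits a review team with fixed capacity. A fixed 0.5 cutoff depends on calibration, and `scale_pos_weight` shifts XGBoost's predicted probabilities away from the true base rate, so the 0.5 report should be read with that in mind.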
