ML Algorithms
Supervised Learning

Decision Trees

A flowchart-like model that recursively splits data on feature values to make predictions — highly interpretable and the building block for ensembles.

Intuition

A decision tree asks a sequence of yes/no questions about the input features. Each internal node tests one feature against a threshold, each branch is an answer, and each leaf holds a prediction (a class label for classification, a numeric value for regression). To predict, you start at the root and walk down until you hit a leaf.

Trees are non-parametric (no fixed functional form), non-linear, and require almost no preprocessing: no feature scaling, no need to one-hot encode ordinal features just to split on them, and mixed feature types are handled naturally by the algorithm.
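
To make the root-to-leaf walk concrete, here is a tiny illustrative sketch; the nested-dict tree and its thresholds are invented for this example and are not sklearn's internal representation:

# Hypothetical tree: internal nodes hold a (feature index, threshold) test, leaves hold a class label
tree = {
    "feature": 2, "threshold": 2.45,
    "left": {"leaf": 0},
    "right": {"feature": 3, "threshold": 1.75,
              "left": {"leaf": 1}, "right": {"leaf": 2}},
}

def predict(node, x):
    # Answer one threshold question per level until a leaf is reached
    while "leaf" not in node:
        node = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
    return node["leaf"]

print(predict(tree, [5.1, 3.5, 1.4, 0.2]))  # feature 2 is 1.4 <= 2.45, so class 0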

How a Tree is Built (Recursive Greedy Splitting)

  • Start with all training samples at the root.
  • For every feature and every candidate threshold, compute the impurity reduction of splitting there.
  • Pick the (feature, threshold) pair with the largest reduction. Split the node.
  • Recurse on each child until a stopping rule fires (max depth, min samples, pure node, no useful split).
  • Assign each leaf a prediction: majority class (classification) or mean target (regression).
Greedy, not optimal
Finding the globally optimal tree is NP-hard. CART and friends use a greedy, local-best split at each node, which is fast but can miss better global structure. This is partly why ensembles like Random Forest and Gradient Boosting outperform single trees.
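
As a self-contained sketch of this greedy loop (plain Python/NumPy written for clarity, not sklearn's optimized Cython splitter):

import numpy as np

def gini(y):
    # Gini impurity of a label vector
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_split(X, y):
    # Scan every (feature, threshold) pair; return the one with the largest
    # impurity reduction, or None if no split is possible
    parent, n = gini(y), len(y)
    best = None  # (reduction, feature, threshold)
    for f in range(X.shape[1]):
        values = np.unique(X[:, f])
        for thr in (values[:-1] + values[1:]) / 2:  # midpoints between sorted values
            left = X[:, f] <= thr
            reduction = (parent
                         - left.sum() / n * gini(y[left])
                         - (~left).sum() / n * gini(y[~left]))
            if best is None or reduction > best[0]:
                best = (reduction, f, thr)
    return best

def grow(X, y, depth=0, max_depth=3):
    # Recurse until the node is pure, max_depth is reached, or no split helps
    split = best_split(X, y)
    if depth == max_depth or split is None or split[0] <= 1e-12:
        return {"leaf": np.bincount(y).argmax()}  # majority class
    _, f, thr = split
    mask = X[:, f] <= thr
    return {"feature": f, "threshold": thr,
            "left": grow(X[mask], y[mask], depth + 1, max_depth),
            "right": grow(X[~mask], y[~mask], depth + 1, max_depth)}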

Splitting Criteria — Classification

Gini Impurity

Gini(t) = 1 − Σᵢ pᵢ²

Probability that a randomly drawn sample from node t would be misclassified if labeled by the node's class distribution. Ranges from 0 (pure) to 1 − 1/K (uniform across K classes). Default in scikit-learn — slightly faster than entropy because it avoids the log.
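
As a quick sanity check of the formula, consider a hypothetical node holding 40 samples of one class and 10 of another:

import numpy as np

p = np.array([40, 10]) / 50        # class proportions in the node
print(1 - np.sum(p ** 2))          # 1 - (0.8² + 0.2²) = 0.32; a pure node would give 0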

Entropy & Information Gain

Entropy(t) = − Σᵢ pᵢ log₂(pᵢ)
IG = Entropy(parent) − Σ (Nₖ / N) · Entropy(childₖ)

Entropy measures disorder in bits. Information Gain is the expected reduction in entropy after the split, weighted by child sizes. Used by ID3 and C4.5.

Gain Ratio (C4.5)

GainRatio = IG / SplitInfo, SplitInfo = − Σ (Nₖ/N) log₂(Nₖ/N)

Penalizes splits with many small branches (e.g. splitting on an ID column) — fixes a known bias of plain Information Gain.
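
A worked example of Information Gain and Gain Ratio on an invented candidate split (a parent node with 6 positives and 4 negatives, split into two children of 5 samples with 4/1 and 2/3 positives/negatives):

import numpy as np

def entropy(counts):
    # Entropy in bits of a class-count vector
    p = np.asarray(counts) / np.sum(counts)
    p = p[p > 0]                                  # treat 0 · log₂(0) as 0
    return -np.sum(p * np.log2(p))

parent = [6, 4]                                   # 6 positives, 4 negatives
children = [[4, 1], [2, 3]]                       # one candidate binary split

n = sum(map(sum, children))
ig = entropy(parent) - sum(sum(c) / n * entropy(c) for c in children)
split_info = entropy([sum(c) for c in children])  # SplitInfo = entropy of the child sizes
print(f"IG = {ig:.3f}, GainRatio = {ig / split_info:.3f}")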

Splitting Criteria — Regression

MSE(t) = (1/Nₜ) Σᵢ (yᵢ − ȳₜ)²   (sum over the Nₜ samples in node t)

Regression trees minimize variance within child nodes. The split that maximally reduces weighted MSE (or MAE / Friedman MSE) is chosen. Each leaf predicts the mean of its samples.
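
A small sketch of scoring one candidate threshold for a regression split, on made-up 1-D data:

import numpy as np

def mse(y):
    # Variance around the node mean, i.e. the error of predicting mean(y) everywhere
    return np.mean((y - y.mean()) ** 2)

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
y = np.array([1.1, 0.9, 1.0, 3.2, 2.9, 3.1])      # made-up targets with a jump after x = 3

thr = 3.5                                          # one candidate threshold
left, right = y[x <= thr], y[x > thr]
reduction = mse(y) - (len(left) * mse(left) + len(right) * mse(right)) / len(y)
print(f"weighted MSE reduction at x <= {thr}: {reduction:.3f}")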

CART vs ID3 vs C4.5

  • ID3 — Information Gain, categorical features only, no pruning.
  • C4.5 — Gain Ratio, handles continuous features and missing values, post-pruning.
  • CART — Gini (or MSE for regression), strictly binary splits, cost-complexity pruning. This is what scikit-learn implements.

Stopping Criteria & Pre-Pruning

  • max_depth — hard cap on tree depth.
  • min_samples_split — minimum samples required to consider splitting a node.
  • min_samples_leaf — minimum samples that must remain in each leaf.
  • min_impurity_decrease — only split if impurity drops by at least this amount.
  • max_leaf_nodes — best-first growth capped at K leaves.
  • max_features — random subset of features per split (used by Random Forest).
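
These knobs interact, so they are usually tuned together. A minimal sketch with GridSearchCV on the iris data (the grid values are illustrative starting points, not recommendations):

from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
param_grid = {                                # illustrative ranges
    "max_depth": [3, 5, 8, None],
    "min_samples_leaf": [1, 5, 20],
    "min_impurity_decrease": [0.0, 0.01],
}
search = GridSearchCV(DecisionTreeClassifier(random_state=0), param_grid, cv=5)
search.fit(X, y)
print(search.best_params_, round(search.best_score_, 3))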

Post-Pruning: Cost-Complexity (CCP)

R_α(T) = R(T) + α · |leaves(T)|

Grow a deep tree, then collapse subtrees that don't justify their complexity. Larger α ⇒ smaller tree. Choose α via cross-validation (ccp_alpha in sklearn).

import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score

clf = DecisionTreeClassifier(random_state=0)
path = clf.cost_complexity_pruning_path(X_train, y_train)
# Sweep path.ccp_alphas and keep the alpha with the best cross-validated accuracy
scores = [cross_val_score(DecisionTreeClassifier(ccp_alpha=a, random_state=0),
                          X_train, y_train, cv=5).mean() for a in path.ccp_alphas]
best_alpha = path.ccp_alphas[int(np.argmax(scores))]

Python Implementation

from sklearn.tree import DecisionTreeClassifier, plot_tree, export_text
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

X, y = load_iris(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=42)

tree = DecisionTreeClassifier(
    criterion="gini",
    max_depth=3,
    min_samples_leaf=5,
    ccp_alpha=0.01,
    random_state=42,
)
tree.fit(X_tr, y_tr)
print("Train:", tree.score(X_tr, y_tr), "Test:", tree.score(X_te, y_te))

# Human-readable rules
print(export_text(tree, feature_names=load_iris().feature_names))

# Visualize
plt.figure(figsize=(12, 6))
plot_tree(tree, filled=True, feature_names=load_iris().feature_names)
plt.show()

Regression Tree Example

from sklearn.tree import DecisionTreeRegressor
import numpy as np

rng = np.random.RandomState(0)
X = np.sort(5 * rng.rand(80, 1), axis=0)
y = np.sin(X).ravel() + 0.1 * rng.randn(80)

reg = DecisionTreeRegressor(max_depth=4).fit(X, y)
# Predictions are piecewise-constant: each leaf outputs the mean y of its samples

Feature Importance

sklearn computes Mean Decrease in Impurity (MDI): for each feature, sum the impurity reduction over all nodes that split on it, weighted by samples reaching the node. Beware: MDI is biased toward high-cardinality features. Prefer permutation importance for honest rankings.

from sklearn.inspection import permutation_importance
r = permutation_importance(tree, X_te, y_te, n_repeats=20, random_state=0)
for i in r.importances_mean.argsort()[::-1]:
    print(f"{load_iris().feature_names[i]:25s} {r.importances_mean[i]:.3f}")

Handling Categorical Features & Missing Values

  • Sklearn's tree only does numeric, binary splits — encode categoricals as integers (ordinal) for tree-based models; one-hot encoding is unnecessary and often harmful (it produces deep, narrow splits). A minimal encoding sketch follows this list.
  • For missing values, HistGradientBoosting has long handled NaN natively, and since sklearn 1.3 the plain decision trees do too, routing NaN samples to whichever child gives the better split.
  • CART-style surrogate splits (falling back to a correlated feature when the split feature is missing) are not in sklearn but exist in R's rpart.
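
A minimal sketch of the ordinal-encoding route, with toy data invented for illustration:

import numpy as np
from sklearn.preprocessing import OrdinalEncoder
from sklearn.tree import DecisionTreeClassifier

X_raw = np.array([["red", "small"], ["blue", "large"], ["green", "small"],
                  ["blue", "small"], ["red", "large"], ["green", "large"]])
y = np.array([1, 0, 1, 0, 1, 0])

enc = OrdinalEncoder().fit(X_raw)            # each category becomes an integer code
X_enc = enc.transform(X_raw)
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X_enc, y)
print(clf.predict(enc.transform([["blue", "large"]])))  # predicted class for one encoded row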

Strengths & Weaknesses

Strengths

  • Easy to interpret and visualize — you can literally read the rules.
  • No feature scaling required; handles non-linear interactions natively.
  • Robust to outliers in features (splits care only about ordering).
  • Fast inference: O(depth) per prediction.

Weaknesses

  • High variance — small data changes can produce very different trees.
  • Axis-aligned splits struggle with diagonal decision boundaries.
  • Greedy training can miss globally better trees.
  • Predictions are piecewise-constant — regression trees can't extrapolate beyond the training range (see the sketch after this list).
  • Biased toward features with many possible split points.
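
The extrapolation point is easy to demonstrate on an invented 1-D regression problem: queries beyond the training range simply return the mean of the outermost leaf.

import numpy as np
from sklearn.tree import DecisionTreeRegressor

X = np.linspace(0, 5, 50).reshape(-1, 1)
reg = DecisionTreeRegressor(max_depth=3).fit(X, 2 * X.ravel())  # true function is y = 2x

print(reg.predict([[10.0], [100.0]]))
# Both queries land in the right-most leaf, so both predictions are identical
# (that leaf's mean, just under 10), while the true values would be 20 and 200.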

From Single Tree to Ensembles

  • Bagging / Random Forest — average many trees on bootstrapped samples with random feature subsets to cut variance.
  • Boosting (GBM, XGBoost, LightGBM, CatBoost) — sequentially fit shallow trees on residuals to cut bias.
  • Extra Trees — randomize split thresholds for even more variance reduction.
Practical defaults
Start with max_depth=4–8, min_samples_leaf=10–50, and tune ccp_alpha by CV. If accuracy matters more than interpretability, skip straight to a Random Forest or gradient-boosted trees.