Building Decision Trees from Scratch: a Beginner’s Coding Tutorial
Decision trees are one of the most intuitive and widely used machine learning algorithms for both classification and regression. They work by splitting data into branches based on feature values, mimicking the way humans make decisions. While libraries like scikit‑learn make building decision trees trivial, implementing one from scratch is an excellent way for beginners to grasp the algorithm’s inner workings. This tutorial will guide you through the theory and code, so you can build your own decision tree from the ground up.
What Is a Decision Tree?
A decision tree is a flowchart‑like structure where each internal node represents a test on a feature (e.g., “Is age > 30?”), each branch represents the outcome of that test, and each leaf node holds a class label or continuous value. The goal is to create a model that predicts a target variable by learning simple decision rules inferred from the data features. Decision trees are popular because they are easy to interpret and need little data preprocessing: each split compares a single feature with a threshold, so features do not need to be scaled or normalised.
The tree is built recursively: starting from the root, the algorithm selects the feature and split point that separate the data most cleanly according to a splitting criterion. This process is repeated on each subset until a stopping condition is met. For more background, Wikipedia’s entry on decision tree learning provides a solid overview.
Core Concepts You Must Understand
Nodes, Branches, and Leaves
The root node contains the entire training dataset. Internal nodes test a feature and split the data into two or more child nodes. Branches are the connections that represent the outcome of a test. Leaf nodes (terminal nodes) output the final prediction – the most common class in classification or the mean value in regression.
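The leaf rule described above fits in a few lines. The sketch below is a hypothetical helper (`leaf_value` and its `task` argument are illustrative names, not part of any library) that returns the majority class for classification and the mean target for regression.

```python
from collections import Counter

import numpy as np


def leaf_value(y, task="classification"):
    """Prediction stored in a leaf node (hypothetical helper)."""
    y = np.asarray(y)
    if task == "classification":
        # most_common(1) returns [(label, count)] for the most frequent label
        return Counter(y.tolist()).most_common(1)[0][0]
    # Regression leaves predict the mean target of their training samples
    return float(np.mean(y))


print(leaf_value([0, 1, 1, 1]))                        # → 1
print(leaf_value([2.0, 4.0, 9.0], task="regression"))  # → 5.0
```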
Splitting Criteria
To build a tree, you need a way to measure the quality of a potential split. The most common criteria are:
- Gini impurity – used in classification to measure how often a randomly chosen element would be incorrectly labelled if it were randomly labelled according to the distribution of classes in the subset. Lower Gini is better.
- Entropy – measures the amount of disorder or uncertainty in a set. A good split minimises the weighted entropy of the child nodes; the resulting drop in entropy relative to the parent is the information gain.
- Variance reduction – used for regression trees. It calculates the reduction in variance (or mean squared error) achieved by the split.
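The three criteria translate directly into NumPy. This is a minimal sketch: the function names are illustrative, entropy is measured in bits (log base 2), and `variance_reduction` assumes a binary split into a left and a right child.

```python
import numpy as np


def gini(labels):
    # 1 - sum of squared class proportions
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)


def entropy(labels):
    # -sum p_i * log2(p_i); np.unique only returns classes that occur, so p_i > 0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))


def variance_reduction(y_parent, y_left, y_right):
    # Parent variance minus the size-weighted variance of the two children
    n = len(y_parent)
    weighted = (len(y_left) * np.var(y_left) + len(y_right) * np.var(y_right)) / n
    return np.var(y_parent) - weighted


labels = [0, 0, 1, 1]
print(gini(labels))     # → 0.5
print(entropy(labels))  # → 1.0
```

A perfectly mixed binary node scores 0.5 on Gini and 1 bit on entropy, the maximum of each for two classes.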
The algorithm evaluates every possible split on every feature and picks the one that yields the greatest reduction in impurity (or gain in information).
Information Gain and Gain Ratio
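Information gain is the impurity of the parent node minus the weighted sum of the child impurities, where each child is weighted by the fraction of samples it receives. It is simple, but it tends to favour features with many distinct values, because splitting into many small branches easily produces pure children. The gain ratio (used in C4.5) corrects for this by dividing the information gain by the split information, which is the entropy of the branch proportions. For this tutorial we will use impurity reduction measured with Gini impurity (sometimes called Gini gain), which is the default criterion in CART (Classification and Regression Trees).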
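To see the bias that gain ratio corrects, compare a two-way split with a four-way split that isolates every row. The sketch below uses entropy-based information gain with illustrative function names and made-up labels. Both splits produce perfectly pure children, so their information gain is identical, but gain ratio penalises the split with more branches.

```python
import numpy as np


def entropy(labels):
    # Shannon entropy in bits of a collection of class labels
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))


def information_gain(parent, children):
    # Parent entropy minus the size-weighted entropy of the children
    n = len(parent)
    weighted = sum(len(c) / n * entropy(c) for c in children)
    return entropy(parent) - weighted


def gain_ratio(parent, children):
    # Divide by the "split information": the entropy of the branch sizes.
    # Assumes at least two non-empty children (otherwise it is 0 or NaN).
    n = len(parent)
    w = np.array([len(c) / n for c in children])
    split_info = -np.sum(w * np.log2(w))
    return information_gain(parent, children) / split_info


parent = [0, 0, 1, 1]
two_way = [[0, 0], [1, 1]]        # e.g. a yes/no feature
four_way = [[0], [0], [1], [1]]   # e.g. an ID-like feature, one value per row
print(information_gain(parent, two_way), information_gain(parent, four_way))  # both 1.0
print(gain_ratio(parent, two_way), gain_ratio(parent, four_way))              # 1.0 and 0.5
```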
Building a Decision Tree Step by Step
1. Prepare Your Data
You need a dataset with features and target labels. For simplicity, use a binary classification dataset with numeric features. For example:
- Features: Age, Income
- Target: Approved (1) or Not Approved (0)
Clean the data: handle missing values, remove duplicates, and ensure every column is numeric. The decision tree algorithm itself can handle both numeric and categorical features, but our implementation supports numeric features only.
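A minimal cleaning pass with NumPy might look like the following. The rows are hypothetical Age, Income (in thousands) and Approved values invented for illustration. Rows with a missing value are dropped rather than imputed, which is only one of several reasonable choices.

```python
import numpy as np

# Hypothetical toy data: columns are Age, Income (thousands), Approved
raw = np.array([
    [25, 40.0, 0],
    [35, 60.0, 1],
    [45, np.nan, 1],   # missing income: dropped below
    [35, 60.0, 1],     # exact duplicate of the second row: dropped below
    [52, 85.0, 1],
    [23, 30.0, 0],
])

# Drop rows containing any NaN, then drop exact duplicate rows
clean = raw[~np.isnan(raw).any(axis=1)]
clean = np.unique(clean, axis=0)   # note: this also sorts the rows

X = clean[:, :2]                   # numeric features: Age, Income
y = clean[:, 2].astype(int)        # integer class labels
print(X.shape, y)                  # → (4, 2) [0 0 1 1]
```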
2. Define a Splitting Criterion Function
We will implement Gini impurity. The Gini index for a set of items is:
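$$\text{Gini} = 1 - \sum_{i} p_i^2$$
where $p_i$ is the proportion of items in class $i$. For a binary split, the overall Gini is the average of the two children’s Gini values, each weighted by the fraction of samples it receives: $\text{Gini}_{\text{split}} = \frac{n_L}{n}\,\text{Gini}_L + \frac{n_R}{n}\,\text{Gini}_R$, where $n_L$ and $n_R$ are the child sizes and $n = n_L + n_R$.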
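A small worked example with made-up counts shows the arithmetic. Suppose a node holds 10 samples (6 of class 0, 4 of class 1) and a candidate split sends 4 samples, all of class 0, to the left child and the remaining 6 (2 of class 0, 4 of class 1) to the right.

```latex
\begin{aligned}
\text{Gini}_{\text{parent}} &= 1 - \left(0.6^2 + 0.4^2\right) = 0.48 \\
\text{Gini}_{L} &= 1 - 1^2 = 0 \\
\text{Gini}_{R} &= 1 - \left(\left(\tfrac{2}{6}\right)^2 + \left(\tfrac{4}{6}\right)^2\right) = \tfrac{4}{9} \approx 0.444 \\
\text{Gini}_{\text{split}} &= \tfrac{4}{10}\cdot 0 + \tfrac{6}{10}\cdot\tfrac{4}{9} = \tfrac{4}{15} \approx 0.267 \\
\text{gain} &= 0.48 - \tfrac{4}{15} \approx 0.213
\end{aligned}
```

The split lowers impurity from 0.48 to about 0.267, so it would be preferred over any candidate whose weighted Gini is higher.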
3. Implement the Split Evaluation
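For each feature, sort its unique values and test every candidate threshold, taking the midpoint between each pair of consecutive sorted values. For each candidate threshold, split the data into a left group (values less than or equal to the threshold) and a right group (values greater than it), compute the weighted Gini of the two groups, and keep the split with the lowest value.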
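The threshold scan for a single feature can be sketched as follows. The `best_threshold` helper and the age/label values are hypothetical; the full implementation later in this tutorial repeats the same scan across every feature.

```python
import numpy as np


def gini(labels):
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)


def best_threshold(x, y):
    """Best midpoint threshold for one numeric feature (lowest weighted Gini)."""
    values = np.unique(x)                       # sorted unique values
    midpoints = (values[:-1] + values[1:]) / 2  # one candidate per gap
    best_t, best_score = None, float("inf")
    for t in midpoints:
        left = x <= t
        n_left, n = left.sum(), len(y)
        score = n_left / n * gini(y[left]) + (n - n_left) / n * gini(y[~left])
        if score < best_score:
            best_t, best_score = float(t), float(score)
    return best_t, best_score


ages = np.array([22, 25, 31, 40, 48, 55])
approved = np.array([0, 0, 0, 1, 1, 1])
print(best_threshold(ages, approved))  # → (35.5, 0.0)
```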
4. Build the Tree Recursively
Create a function that takes a subset of data and a current depth. It checks stopping conditions (e.g., maximum depth reached, minimum samples per node, or no information gain). If a condition is met, create a leaf node with the majority class. Otherwise, find the best split and create an internal node, then recursively call the function on the left and right splits.
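The stopping checks can be isolated into small helpers. In this sketch, `should_stop` and `has_gain` are hypothetical names, and `min_gain` is an assumed tolerance that treats tiny floating-point improvements as no gain.

```python
import numpy as np


def gini(labels):
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)


def should_stop(y, depth, max_depth=None, min_samples_split=2):
    """Stopping conditions checked before searching for a split."""
    if len(np.unique(y)) == 1:                        # node is already pure
        return True
    if max_depth is not None and depth >= max_depth:  # depth limit reached
        return True
    return len(y) < min_samples_split                 # too few samples to split


def has_gain(y, y_left, y_right, min_gain=1e-12):
    """True if the split lowers the size-weighted Gini by more than min_gain."""
    n = len(y)
    children = len(y_left) / n * gini(y_left) + len(y_right) / n * gini(y_right)
    return bool(gini(y) - children > min_gain)
```

In a recursive builder, you would return a majority-class leaf whenever `should_stop` is true or the best split fails `has_gain`, and otherwise recurse on the two children with `depth + 1`.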
5. Make Predictions
Once the tree is built, prediction is straightforward: start at the root, follow the branches by evaluating the feature tests on the new sample, and return the value of the leaf you land on.
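Prediction is a walk from the root to a leaf. This sketch uses the same nested-dictionary node layout (`leaf`, `feature`, `threshold`, `left`, `right`, `value`) as the implementation in this tutorial, but loops instead of recursing. The hand-built tree and its thresholds are invented for illustration, with feature 0 as Age and feature 1 as Income.

```python
def predict_one(tree, x):
    # Follow the feature tests down to a leaf; "<=" goes left
    node = tree
    while not node['leaf']:
        branch = 'left' if x[node['feature']] <= node['threshold'] else 'right'
        node = node[branch]
    return node['value']


# Hypothetical tree: "Is income <= 50?" then "Is age <= 30?"
tree = {'leaf': False, 'feature': 1, 'threshold': 50.0,
        'left': {'leaf': True, 'value': 0},
        'right': {'leaf': False, 'feature': 0, 'threshold': 30.0,
                  'left': {'leaf': True, 'value': 0},
                  'right': {'leaf': True, 'value': 1}}}

print(predict_one(tree, [45, 72.0]))  # → 1 (income > 50, age > 30)
print(predict_one(tree, [28, 72.0]))  # → 0 (age <= 30)
```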
Full Implementation in Python
Below is a complete, minimal implementation of a decision tree for classification using Gini impurity. This code is meant for learning – it is not optimised for large datasets.
```python
import numpy as np
from collections import Counter


class DecisionTree:
    """Minimal CART-style classification tree using Gini impurity."""

    def __init__(self, max_depth=None, min_samples_split=2):
        self.max_depth = max_depth                  # None means no depth limit
        self.min_samples_split = min_samples_split  # smallest node we may split
        self.tree = None

    def fit(self, X, y):
        # Store labels as the last column so rows stay aligned when split.
        # column_stack casts labels to the feature dtype (usually float).
        dataset = np.column_stack((X, y))
        self.tree = self._grow_tree(dataset)
        return self

    def _grow_tree(self, dataset, depth=0):
        X, y = dataset[:, :-1], dataset[:, -1]
        n_samples, n_features = X.shape
        n_labels = len(np.unique(y))

        # Stopping conditions: pure node, depth limit, or too few samples
        if (n_labels == 1 or depth == self.max_depth
                or n_samples < self.min_samples_split):
            leaf_value = Counter(y).most_common(1)[0][0]  # majority class
            return {'leaf': True, 'value': leaf_value}

        best_feature, best_threshold = self._best_split(dataset, n_features)
        if best_feature is None:  # no threshold separates the samples
            leaf_value = Counter(y).most_common(1)[0][0]
            return {'leaf': True, 'value': leaf_value}

        left_idx, right_idx = self._split(dataset[:, best_feature], best_threshold)
        left_subtree = self._grow_tree(dataset[left_idx], depth + 1)
        right_subtree = self._grow_tree(dataset[right_idx], depth + 1)
        return {'leaf': False,
                'feature': best_feature,
                'threshold': best_threshold,
                'left': left_subtree,
                'right': right_subtree}

    def _best_split(self, dataset, n_features):
        # Try the midpoint between each pair of consecutive unique values of
        # every feature; keep the split with the lowest weighted Gini.
        best_gini = float('inf')
        best_feature, best_threshold = None, None
        for feature in range(n_features):
            thresholds = np.unique(dataset[:, feature])  # sorted
            for i in range(len(thresholds) - 1):
                thresh = (thresholds[i] + thresholds[i + 1]) / 2
                left_idx, right_idx = self._split(dataset[:, feature], thresh)
                if len(left_idx) == 0 or len(right_idx) == 0:
                    continue
                gini = self._weighted_gini(dataset, left_idx, right_idx)
                if gini < best_gini:
                    best_gini = gini
                    best_feature = feature
                    best_threshold = thresh
        return best_feature, best_threshold

    def _split(self, values, threshold):
        left_idx = np.where(values <= threshold)[0]
        right_idx = np.where(values > threshold)[0]
        return left_idx, right_idx

    def _weighted_gini(self, dataset, left_idx, right_idx):
        # Size-weighted Gini of the two children (lower is better).
        # Minimising this maximises the Gini gain, since the parent is fixed.
        total = len(left_idx) + len(right_idx)
        gini_left = self._gini(dataset[left_idx, -1])
        gini_right = self._gini(dataset[right_idx, -1])
        return (len(left_idx) / total) * gini_left + (len(right_idx) / total) * gini_right

    def _gini(self, labels):
        _, counts = np.unique(labels, return_counts=True)
        p = counts / np.sum(counts)
        return 1 - np.sum(p ** 2)

    def predict(self, X):
        return np.array([self._predict_row(x, self.tree) for x in X])

    def _predict_row(self, x, node):
        # Follow the tests down to a leaf; "<=" goes left, as in _split
        if node['leaf']:
            return node['value']
        if x[node['feature']] <= node['threshold']:
            return self._predict_row(x, node['left'])
        else:
            return self._predict_row(x, node['right'])
```
Testing the Tree
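Test the tree on a simple dataset such as the classic Iris dataset bundled with scikit‑learn, keeping only its first two classes (setosa and versicolor) so that the task is binary classification. As a sanity check, compare your tree’s accuracy with that of scikit‑learn’s DecisionTreeClassifier; these two classes are easy to separate, so both models should score at or near 100%.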
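```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Assumes the DecisionTree class above is defined in the same script.
data = load_iris()
X = data.data[:100]    # first 100 rows = first two classes (setosa, versicolor)
y = data.target[:100]  # all four features are kept
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42)  # fixed seed for reproducibility

tree = DecisionTree(max_depth=3)
tree.fit(X_train, y_train)
preds = tree.predict(X_test)
accuracy = np.mean(preds == y_test)
print(f'Accuracy: {accuracy:.2f}')

# Reference model from scikit-learn with the same depth limit
sk_tree = DecisionTreeClassifier(max_depth=3, random_state=42)
sk_tree.fit(X_train, y_train)
print(f'scikit-learn accuracy: {sk_tree.score(X_test, y_test):.2f}')
```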
Advanced Techniques to Improve Your Tree
Pruning to Avoid Overfitting
A fully grown tree can memorise noise in the training data. Pruning removes branches that have little predictive power. Common methods are pre‑pruning (stopping growth early via max_depth or min_samples_split) and post‑pruning (growing the full tree, then removing branches using a validation set or cost‑complexity pruning). Our implementation already supports pre‑pruning through its max_depth and min_samples_split parameters.
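Post‑pruning with a validation set can be sketched with reduced-error pruning, one simple validation-based method chosen here for illustration (it is not cost‑complexity pruning). It assumes every internal node carries a hypothetical `majority` key holding the majority training class at that node; the trees built in this tutorial do not store it, so you would add it in `_grow_tree`.

```python
import numpy as np


def predict_row(node, x):
    # Walk down the dict tree until a leaf is reached ("<=" goes left)
    while not node['leaf']:
        node = node['left'] if x[node['feature']] <= node['threshold'] else node['right']
    return node['value']


def prune(node, X_val, y_val):
    """Reduced-error pruning on a dict tree, bottom-up.

    Children are pruned first; then the node becomes a leaf predicting its
    (hypothetical) 'majority' class if that does not increase the error on
    the validation rows that reach it. Modifies the tree in place.
    """
    if node['leaf']:
        return node
    go_left = X_val[:, node['feature']] <= node['threshold']
    node['left'] = prune(node['left'], X_val[go_left], y_val[go_left])
    node['right'] = prune(node['right'], X_val[~go_left], y_val[~go_left])
    subtree_errors = sum(predict_row(node, x) != y for x, y in zip(X_val, y_val))
    leaf_errors = int(np.sum(y_val != node['majority']))
    if leaf_errors <= subtree_errors:
        return {'leaf': True, 'value': node['majority']}
    return node
```

Because the comparison uses "less than or equal", a subtree that no validation row reaches is always collapsed, so this approach works best with a reasonably large validation set.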
Handling Continuous and Categorical Features
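For continuous features, we used midpoints between consecutive sorted values as thresholds. For categorical features (e.g., “Color = red/green/blue”), each category can become a separate branch (a multi‑way split, as in C4.5), or you can one‑hot encode the categories and treat each indicator as a numeric feature. CART-style trees instead make a binary split by dividing the categories into two groups. Trying every grouping is exponential in the number of categories, but for binary classification it is enough to sort the categories by their proportion of the positive class and test only the k - 1 splits along that order. scikit‑learn’s DecisionTreeClassifier expects numeric input, so categorical features are usually encoded as numbers before training.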
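The category-ordering shortcut for a binary target can be sketched as below. `best_categorical_split` is a hypothetical helper and the colours and labels are made-up data; the sketch assumes labels are 0/1 so that the mean of a group equals its share of class 1.

```python
import numpy as np


def gini(labels):
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)


def best_categorical_split(categories, y):
    """Best two-group split of a categorical feature for a 0/1 target.

    Sorts categories by their fraction of class 1 and scans only the k - 1
    'prefix' groupings along that order instead of all 2**(k-1) - 1 subsets.
    """
    cats = np.unique(categories)
    rate = {c: y[categories == c].mean() for c in cats}  # share of class 1
    ordered = sorted(cats, key=lambda c: rate[c])
    n = len(y)
    best_set, best_score = None, float("inf")
    for i in range(1, len(ordered)):
        left = np.isin(categories, ordered[:i])  # rows whose category is in the prefix
        n_left = left.sum()
        score = n_left / n * gini(y[left]) + (n - n_left) / n * gini(y[~left])
        if score < best_score:
            best_set, best_score = set(ordered[:i]), float(score)
    return best_set, best_score


colors = np.array(["red", "green", "blue", "red", "green", "blue", "blue"])
labels = np.array([1, 0, 0, 1, 1, 0, 0])
best_set, best_score = best_categorical_split(colors, labels)
print(best_set, round(best_score, 3))  # blue goes left; weighted Gini about 0.214
```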
Dealing with Missing Values
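Real‑world data often has missing values. A simple approach is to send samples with a missing value down the branch that received more of the training samples for which the feature was present. C4.5 instead splits such a sample fractionally across all branches, weighted by the proportion of training samples that went to each, while CART uses surrogate splits on correlated features. Since this is a beginner tutorial, we assume the data is complete.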
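The simple “most frequent branch” rule can be sketched as a variant of the tutorial’s `_split` method. `split_with_missing` is a hypothetical name, missing values are assumed to be encoded as NaN, and ties are sent left by choice.

```python
import numpy as np


def split_with_missing(values, threshold):
    """Split row indices on values <= threshold, sending NaNs to the branch
    that holds more of the non-missing training samples.

    Returns (left_idx, right_idx, missing_goes_left); the flag would be
    stored in the node so that predictions route missing values the same way.
    """
    missing = np.isnan(values)
    left = ~missing & (values <= threshold)
    right = ~missing & (values > threshold)
    missing_goes_left = bool(left.sum() >= right.sum())  # ties go left
    if missing_goes_left:
        left |= missing
    else:
        right |= missing
    return np.where(left)[0], np.where(right)[0], missing_goes_left


values = np.array([1.0, 2.0, np.nan, 7.0, 1.5])
left_idx, right_idx, goes_left = split_with_missing(values, threshold=2.5)
print(left_idx, right_idx, goes_left)  # → [0 1 2 4] [3] True
```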
Comparing with Libraries and Further Reading
While building from scratch is educational, production systems use libraries such as scikit‑learn, whose tree code is written in Cython and compiled for speed. You can learn more from the official scikit‑learn decision trees documentation. For deeper theory, the book “The Elements of Statistical Learning” by Hastie, Tibshirani, and Friedman is an authoritative resource. Another excellent reference is the original CART book, “Classification and Regression Trees”, by Breiman et al.
Conclusion
Building a decision tree from scratch demystifies one of the most fundamental algorithms in machine learning. You have learned how a simple recursive splitting procedure can produce a powerful model. By writing the code yourself, you gain a deeper understanding of impurity measures, split selection, and the trade‑offs between bias and variance. As a next step, try adding regression support, pruning, or handling categorical features. The skills you develop here will serve you well as you move on to more complex ensemble methods like random forests and gradient boosting.