Implementing the ID3 Decision Tree Algorithm from Scratch
Core Concepts
Decision trees come in several variants including CART, ID3, and C4.5. While CART relies on Gini impurity, ID3 and C4.5 both leverage information entropy for splitting criteria. This implementation focuses on the ID3 algorithm.
Information Theory Foundations:
- $p(a_i)$: Probability of event $a_i$ occurring
- $I(a_i) = -\log_2(p(a_i))$: Uncertainty measure for event $a_i$, known as self-information
- $H = \sum_i p(a_i) \cdot I(a_i) = -\sum_i p(a_i)\log_2 p(a_i)$: Average information content of the source $S$, or Shannon entropy
- $Gain = BaseEntropy - newEntropy$: Information gain from a split
The ID3 algorithm constructs decision trees using a top-down recursive approach. The fundamental principle is to build a tree that decreases entropy as rapidly as possible, reaching zero entropy at leaf nodes where all instances share the same class label. ID3 selects splits by maximizing information gain. For a binary classification problem with $p$ positive and $n$ negative examples, the initial entropy is computed as:
$$H = -\frac{p}{p+n}\log_2\frac{p}{p+n} - \frac{n}{p+n}\log_2\frac{n}{p+n}$$
When splitting on a feature with $N$ distinct values (e.g., {rain, sunny} gives $N = 2$), the dataset is partitioned into $N$ subsets; the new entropy is the weighted average of the subset entropies, and the information gain is how far it falls below the base entropy.
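As a concrete check, for the marine dataset used later in this article ($p = 2$ fish, $n = 3$ non-fish):
$$H = -\frac{2}{5}\log_2\frac{2}{5} - \frac{3}{5}\log_2\frac{3}{5} \approx 0.971$$
A split that separated the two classes perfectly would leave every subset with zero entropy, so its information gain would be the full 0.971 bits.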
Limitation: ID3 tends to favor attributes with many distinct values (for example, ID-like fields or finely discretized continuous attributes), since such splits tend to yield higher information gain. This leads to shallow but wide trees with poor generalization. C4.5 addresses this by using a gain ratio, which normalizes the information gain by the split information:
$$GainRatio = \frac{Gain}{SplitInfo}$$
where $SplitInfo = -\sum\frac{|S_v|}{|S|}\log_2\frac{|S_v|}{|S|}$
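As a rough illustration, a feature that splits a 5-record dataset into subsets of sizes 3 and 2 has
$$SplitInfo = -\frac{3}{5}\log_2\frac{3}{5} - \frac{2}{5}\log_2\frac{2}{5} \approx 0.971$$
while a feature that shatters the same data into 5 singleton subsets has $SplitInfo = \log_2 5 \approx 2.32$, so its (inflated) gain is penalized much more heavily.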
Algorithm Characteristics
Advantages:
- Low computational complexity
- Interpretable output
- Handles missing values gracefully
- Robust to irrelevant features
Disadvantages:
- Prone to overfitting
Applicable Data Types: Numerical and categorical
Python Implementation
Consider this dataset containing 5 marine creatures with two features: whether they can survive without surfacing, and whether they have flippers. The target is binary classification (fish vs. non-fish).
| ID | Survives Without Surfacing | Has Flippers | Fish |
|---|---|---|---|
| 1 | Yes | Yes | Yes |
| 2 | Yes | Yes | Yes |
| 3 | Yes | No | No |
| 4 | No | Yes | No |
| 5 | No | Yes | No |
In array format:
dataSet = [
    [1, 1, 'yes'],  # can survive without surfacing, has flippers, is fish
    [1, 1, 'yes'],
    [1, 0, 'no'],
    [0, 1, 'no'],
    [0, 1, 'no']
]
labels = ['survives_without_air', 'has_flippers']
Computing Shannon Entropy
from math import log
def shannon_entropy(dataset):
"""Calculate the Shannon entropy of a dataset."""
entry_count = len(dataset)
label_frequencies = {}
for record in dataset:
category = record[-1]
label_frequencies[category] = label_frequencies.get(category, 0) + 1
entropy = 0.0
for count in label_frequencies.values():
probability = count / entry_count
entropy -= probability * log(probability, 2)
return entropy
Testing the entropy function:
>>> dataset, feature_names = create_dataset()
>>> shannon_entropy(dataset)
0.9709505944546686
>>> dataset[0][-1] = 'maybe'  # Change the first label to introduce a third class
>>> shannon_entropy(dataset)
1.3709505944546687
Higher entropy indicates more mixed data. Adding more class diversity increases entropy.
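A perfectly balanced two-class dataset gives the maximum binary entropy of exactly one bit, which makes for a quick sanity check:
>>> shannon_entropy([[1, 'yes'], [0, 'no']])
1.0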
Dataset Creation Utility
def create_dataset():
"""Generate the sample marine creature classification dataset."""
    data = [[1, 1, 'yes'],
            [1, 1, 'yes'],
            [1, 0, 'no'],
            [0, 1, 'no'],
            [0, 1, 'no']]
features = ['survives_without_air', 'has_flippers']
return data, features
Splitting Datasets by Feature Value
To evaluate which feature provides the best split, we need to partition the dataset for each possible feature value and compute entropy for each partition.
def partition_by_feature(dataset, feature_index, feature_value):
"""Split dataset based on a specific feature value.
Args:
dataset: Input data to split
feature_index: Which column to split on
feature_value: Value to filter on
Returns:
Filtered dataset excluding the split column
"""
filtered_records = []
for record in dataset:
if record[feature_index] == feature_value:
reduced_record = record[:feature_index] + record[feature_index+1:]
filtered_records.append(reduced_record)
return filtered_records
Usage examples:
>>> dataset, _ = create_dataset()
>>> partition_by_feature(dataset, 0, 1) # Filter by feature 0 = 1
[[1, 'yes'], [1, 'yes'], [0, 'no']]
>>> partition_by_feature(dataset, 0, 0) # Filter by feature 0 = 0
[[1, 'no'], [1, 'no']]
Selecting the Optimal Split Feature
def select_best_feature(dataset):
"""Identify the feature that produces maximum information gain."""
feature_count = len(dataset[0]) - 1 # Exclude label column
base_entropy = shannon_entropy(dataset)
best_gain = 0.0
best_feature_index = -1
for i in range(feature_count):
feature_values = [record[i] for record in dataset]
unique_values = set(feature_values)
split_entropy = 0.0
for value in unique_values:
subset = partition_by_feature(dataset, i, value)
weight = len(subset) / len(dataset)
split_entropy += weight * shannon_entropy(subset)
information_gain = base_entropy - split_entropy
if information_gain > best_gain:
best_gain = information_gain
best_feature_index = i
return best_feature_index
Prerequisites: All rows must have equal length, and the final column must contain class labels.
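A minimal sketch of how these assumptions could be checked up front (the validate_dataset helper is an illustration, not part of the original implementation):
def validate_dataset(dataset):
    """Check that records are rectangular, with at least one feature plus the label in the last column."""
    width = len(dataset[0])
    assert width >= 2, "each record needs at least one feature plus a class label"
    assert all(len(record) == width for record in dataset), "all records must have the same length"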
>>> dataset, _ = create_dataset()
>>> select_best_feature(dataset)
0 # Feature 0 is the best split
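Working the gains by hand confirms the choice. Splitting on feature 0 produces the subsets {yes, yes, no} and {no, no}, giving a weighted entropy of $\frac{3}{5}(0.918) + \frac{2}{5}(0) \approx 0.551$ and a gain of $0.971 - 0.551 \approx 0.420$. Splitting on feature 1 produces {yes, yes, no, no} and {no}, giving a weighted entropy of $\frac{4}{5}(1.0) \approx 0.8$ and a gain of only about $0.171$.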
Building the Decision Tree
from collections import Counter
def majority_vote(class_labels):
"""Return the most common class label."""
counter = Counter(class_labels)
return counter.most_common(1)[0][0]
def build_tree(dataset, feature_labels):
"""Recursively construct a decision tree.
Uses a dictionary structure to represent the tree:
{feature_name: {feature_value: subtree_or_label, ...}}
"""
class_labels = [record[-1] for record in dataset]
# Stopping condition: all instances have the same class
if len(set(class_labels)) == 1:
return class_labels[0]
# Stopping condition: no more features to split on
if len(dataset[0]) == 1:
return majority_vote(class_labels)
best_idx = select_best_feature(dataset)
best_label = feature_labels[best_idx]
tree = {best_label: {}}
    del feature_labels[best_idx]  # Remove used feature (note: this mutates the caller's list)
feature_values = [record[best_idx] for record in dataset]
unique_values = set(feature_values)
for value in unique_values:
sub_labels = feature_labels[:]
subtree = build_tree(
partition_by_feature(dataset, best_idx, value),
sub_labels
)
tree[best_label][value] = subtree
return tree
>>> dataset, features = create_dataset()
>>> decision_tree = build_tree(dataset, features)
>>> decision_tree
{'survives_without_air': {0: 'no', 1: {'has_flippers': {0: 'no', 1: 'yes'}}}}
When all features are exhausted but class labels remain impure, majority voting determines the leaf classification.
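For example, a leaf left with the labels ['yes', 'no', 'no'] would be classified as 'no':
>>> majority_vote(['yes', 'no', 'no'])
'no'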
Classification with the Trained Tree
def predict(tree, feature_labels, test_instance):
"""Classify a new instance using the decision tree."""
root_feature = list(tree.keys())[0]
children = tree[root_feature]
feature_idx = feature_labels.index(root_feature)
for value in children.keys():
if test_instance[feature_idx] == value:
subtree = children[value]
if isinstance(subtree, dict):
return predict(subtree, feature_labels, test_instance)
else:
return subtree
>>> tree, features = build_tree(*create_dataset()), ['survives_without_air', 'has_flippers']
>>> predict(tree, features, [1, 0])
'no'
>>> predict(tree, features, [1, 1])
'yes'
Persisting the Trained Model
Decision tree construction is computationally expensive. Serialization allows reusing a trained model without rebuilding it.
import pickle
def save_tree(tree, filepath):
"""Serialize and save the decision tree to disk."""
with open(filepath, 'wb') as f:
pickle.dump(tree, f)
def load_tree(filepath):
"""Load a previously saved decision tree from disk."""
with open(filepath, 'rb') as f:
return pickle.load(f)
>>> tree, features = build_tree(*create_dataset()), ['survives_without_air', 'has_flippers']
>>> save_tree(tree, 'classifier.pkl')
>>> loaded = load_tree('classifier.pkl')
>>> loaded
{'survives_without_air': {0: 'no', 1: {'has_flippers': {0: 'no', 1: 'yes'}}}}
Visualizing Decision Trees
Using matplotlib for tree visualization helps understand the learned decision boundaries.
import matplotlib.pyplot as plt
import numpy as np
# Node styling configurations
DECISION_NODE = dict(boxstyle='sawtooth', fc='0.8')
LEAF_NODE = dict(boxstyle='round4', fc='0.8')
ARROW_STYLE = dict(arrowstyle='->')
def draw_node(text, center_position, parent_position, node_type):
"""Render a single node with connecting arrow."""
ax.annotate(text,
xy=parent_position,
xycoords='axes fraction',
xytext=center_position,
textcoords='axes fraction',
va='center', ha='center',
bbox=node_type,
arrowprops=ARROW_STYLE)
def count_leaves(tree):
"""Count terminal nodes in the tree."""
first_key = list(tree.keys())[0]
children = tree[first_key]
leaf_count = 0
for child in children.values():
if isinstance(child, dict):
leaf_count += count_leaves(child)
else:
leaf_count += 1
return leaf_count
def measure_depth(tree):
"""Calculate maximum tree depth."""
first_key = list(tree.keys())[0]
children = tree[first_key]
max_depth = 0
for child in children.values():
if isinstance(child, dict):
depth = 1 + measure_depth(child)
else:
depth = 1
max_depth = max(max_depth, depth)
return max_depth
def plot_split_position(parent, child, label):
"""Place text label between parent and child nodes."""
mid_x = (parent[0] - child[0]) / 2.0 + child[0]
mid_y = (parent[1] - child[1]) / 2.0 + child[1]
ax.text(mid_x, mid_y, label, fontsize=10)
def render_tree(tree, parent_pos, node_label):
"""Recursively render the decision tree."""
leaf_count = count_leaves(tree)
tree_depth = measure_depth(tree)
root_key = list(tree.keys())[0]
center = (x_offset + (1.0 + leaf_count) / 2.0 / total_width,
y_offset)
plot_split_position(center, parent_pos, node_label)
draw_node(root_key, center, parent_pos, DECISION_NODE)
children = tree[root_key]
y_offset -= 1.0 / total_depth
for value, subtree in children.items():
if isinstance(subtree, dict):
render_tree(subtree, center, str(value))
else:
x_offset += 1.0 / total_width
draw_node(subtree, (x_offset, y_offset), center, LEAF_NODE)
plot_split_position((x_offset, y_offset), center, str(value))
y_offset += 1.0 / total_depth
def visualize(tree):
"""Create the complete tree visualization."""
global ax, x_offset, y_offset, total_width, total_depth
fig = plt.figure(figsize=(12, 8), facecolor='white')
ax = plt.subplot(111, frameon=False)
ax.set_xticks([])
ax.set_yticks([])
total_width = float(count_leaves(tree))
total_depth = float(measure_depth(tree))
x_offset = -0.5 / total_width
y_offset = 1.0
render_tree(tree, (0.5, 1.0), '')
plt.tight_layout()
plt.savefig('decision_tree.pdf', dpi=150)
plt.show()
The visualization places leaf nodes at the bottom and decision nodes higher, with edge labels showing the feature values that lead to each branch. This representation makes the decision-making process transparent and easy to interpret.
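Putting the pieces together, a typical end-to-end session (reusing the pickled classifier.pkl from the previous section) might look like this:
>>> tree = load_tree('classifier.pkl')
>>> visualize(tree)  # writes decision_tree.pdf and displays the figure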