Fading Coder

One Final Commit for the Last Sprint


Implementing the ID3 Decision Tree Algorithm from Scratch

Core Concepts

Decision trees come in several variants including CART, ID3, and C4.5. While CART relies on Gini impurity, ID3 and C4.5 both leverage information entropy for splitting criteria. This implementation focuses on the ID3 algorithm.

Information Theory Foundations:

  • $p(a_i)$: Probability of event $a_i$ occurring
  • $I(a_i) = -\log_2(p(a_i))$: Uncertainty measure for event $a_i$, known as self-information
  • $H = \sum(p(a_i) \cdot I(a_i))$: Average information content of source S, or Shannon entropy
  • $Gain = BaseEntropy - newEntropy$: Information gain from a split

The ID3 algorithm constructs decision trees using a top-down recursive approach. The fundamental principle is building a tree that decreases entropy most rapidly, reaching zero entropy at leaf nodes where all instances share the same class label. ID3 selects splits by maximizing information gain. For a binary classification problem with $p$ positive and $n$ negative examples, the initial entropy is computed as:

$$H = -\frac{p}{p+n}\log_2\frac{p}{p+n} - \frac{n}{p+n}\log_2\frac{n}{p+n}$$

When splitting on a feature with $N$ distinct values (e.g., {rain, sunny} gives $N = 2$), the dataset is partitioned into $N$ subsets, and the split's entropy is the weighted average of the entropies of those subsets.
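As a quick sanity check, the initial-entropy formula can be evaluated directly for the class counts of the marine-creature dataset used below (2 fish, 3 non-fish):

```python
from math import log2

# Initial entropy for p = 2 positive and n = 3 negative examples
# (the fish dataset used in the implementation below).
p, n = 2, 3
H = -(p / (p + n)) * log2(p / (p + n)) - (n / (p + n)) * log2(n / (p + n))
print(round(H, 4))  # 0.971
```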

Limitation: ID3 tends to favor attributes with more possible values, especially continuous ones, since they produce higher information gain. This leads to shallow but wide trees with poor generalization. C4.5 addresses this by using a gain ratio, which normalizes the information gain by the split information:

$$GainRatio = \frac{Gain}{SplitInfo}$$

where $SplitInfo = -\sum\frac{|S_v|}{|S|}\log_2\frac{|S_v|}{|S|}$
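C4.5 itself is not implemented in this article, but the normalization is easy to sketch. The helper names `split_info` and `gain_ratio` below are illustrative, not part of the ID3 code that follows:

```python
from math import log2

def split_info(subset_sizes):
    """Entropy of the partition itself: -sum(|S_v|/|S| * log2(|S_v|/|S|))."""
    total = sum(subset_sizes)
    return -sum(s / total * log2(s / total) for s in subset_sizes if s)

def gain_ratio(info_gain, subset_sizes):
    """Normalize information gain by the split information (C4.5)."""
    si = split_info(subset_sizes)
    return info_gain / si if si > 0 else 0.0

# A many-valued split (e.g., an ID column: one subset per row) is penalized
# by a large SplitInfo divisor, while a coarse split is not:
print(round(split_info([1, 1, 1, 1, 1]), 4))  # 2.3219
print(round(split_info([2, 3]), 4))           # 0.971
```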

Algorithm Characteristics

Advantages:

  • Low computational complexity
  • Interpretable output
  • Handles missing values gracefully
  • Works with irrelevant features

Disadvantages:

  • Prone to overfitting

Applicable Data Types: Numerical and categorical

Python Implementation

Consider this dataset containing 5 marine creatures with two features: whether they can survive without surfacing, and whether they have flippers. The target is binary classification (fish vs. non-fish).

ID Survives Without Surfacing Has Flippers Fish
1 Yes Yes Yes
2 Yes Yes Yes
3 Yes No No
4 No Yes No
5 No Yes No

In array format:

dataSet = [
    [1, 1, 'yes'],  # survives without surfacing, has flippers, is fish
    [1, 1, 'yes'],
    [1, 0, 'no'],   # survives without surfacing, no flippers, not fish
    [0, 1, 'no'],
    [0, 1, 'no']
]
labels = ['survives_without_air', 'has_flippers']

Computing Shannon Entropy

from math import log

def shannon_entropy(dataset):
    """Calculate the Shannon entropy of a dataset."""
    entry_count = len(dataset)
    label_frequencies = {}
    
    for record in dataset:
        category = record[-1]
        label_frequencies[category] = label_frequencies.get(category, 0) + 1
    
    entropy = 0.0
    for count in label_frequencies.values():
        probability = count / entry_count
        entropy -= probability * log(probability, 2)
    
    return entropy

Testing the entropy function:

>>> dataset, feature_names = create_dataset()
>>> shannon_entropy(dataset)
0.9709505944546686

>>> dataset[0][-1] = 'maybe'  # Add third class
>>> shannon_entropy(dataset)
1.3709505944546687

Higher entropy indicates more mixed data. Adding more class diversity increases entropy.
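The extremes are worth noting: a pure dataset has zero entropy, and an even binary split has exactly one bit. A small binary-entropy helper (a convenience function for illustration, not used by the implementation below) makes this easy to verify:

```python
from math import log2

def binary_entropy(p):
    """Entropy of a binary source with positive-class probability p."""
    if p in (0.0, 1.0):
        return 0.0  # pure: no uncertainty
    return -p * log2(p) - (1 - p) * log2(1 - p)

print(binary_entropy(1.0))            # 0.0  (all one class)
print(binary_entropy(0.5))            # 1.0  (maximally mixed)
print(round(binary_entropy(0.4), 4))  # 0.971 (the 2-yes / 3-no dataset)
```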

Dataset Creation Utility

def create_dataset():
    """Generate the sample marine creature classification dataset."""
    data = [[1, 1, 'yes'],
            [1, 1, 'yes'],
            [1, 0, 'no'],
            [0, 1, 'no'],
            [0, 1, 'no']]
    features = ['survives_without_air', 'has_flippers']
    return data, features

Splitting Datasets by Feature Value

To evaluate which feature provides the best split, we need to partition the dataset for each possible feature value and compute entropy for each partition.

def partition_by_feature(dataset, feature_index, feature_value):
    """Split dataset based on a specific feature value.
    
    Args:
        dataset: Input data to split
        feature_index: Which column to split on
        feature_value: Value to filter on
    
    Returns:
        Filtered dataset excluding the split column
    """
    filtered_records = []
    
    for record in dataset:
        if record[feature_index] == feature_value:
            reduced_record = record[:feature_index] + record[feature_index+1:]
            filtered_records.append(reduced_record)
    
    return filtered_records

Usage examples:

>>> dataset, _ = create_dataset()
>>> partition_by_feature(dataset, 0, 1)  # Keep records where feature 0 == 1
[[1, 'yes'], [1, 'yes'], [0, 'no']]
>>> partition_by_feature(dataset, 0, 0)  # Keep records where feature 0 == 0
[[1, 'no'], [1, 'no']]

Selecting the Optimal Split Feature

def select_best_feature(dataset):
    """Identify the feature that produces maximum information gain."""
    feature_count = len(dataset[0]) - 1  # Exclude label column
    base_entropy = shannon_entropy(dataset)
    
    best_gain = 0.0
    best_feature_index = -1
    
    for i in range(feature_count):
        feature_values = [record[i] for record in dataset]
        unique_values = set(feature_values)
        
        split_entropy = 0.0
        for value in unique_values:
            subset = partition_by_feature(dataset, i, value)
            weight = len(subset) / len(dataset)
            split_entropy += weight * shannon_entropy(subset)
        
        information_gain = base_entropy - split_entropy
        
        if information_gain > best_gain:
            best_gain = information_gain
            best_feature_index = i
    
    return best_feature_index

Prerequisites: All rows must have equal length, and the final column must contain class labels.

>>> dataset, _ = create_dataset()
>>> select_best_feature(dataset)
0  # Feature 0 is the best split

Building the Decision Tree

from collections import Counter

def majority_vote(class_labels):
    """Return the most common class label."""
    counter = Counter(class_labels)
    return counter.most_common(1)[0][0]

def build_tree(dataset, feature_labels):
    """Recursively construct a decision tree.
    
    Uses a dictionary structure to represent the tree:
    {feature_name: {feature_value: subtree_or_label, ...}}
    """
    class_labels = [record[-1] for record in dataset]
    
    # Stopping condition: all instances have the same class
    if len(set(class_labels)) == 1:
        return class_labels[0]
    
    # Stopping condition: no more features to split on
    if len(dataset[0]) == 1:
        return majority_vote(class_labels)
    
    best_idx = select_best_feature(dataset)
    best_label = feature_labels[best_idx]
    
    tree = {best_label: {}}
    # Copy instead of deleting in place, so the caller's label list is untouched
    remaining_labels = feature_labels[:best_idx] + feature_labels[best_idx+1:]
    
    feature_values = [record[best_idx] for record in dataset]
    unique_values = set(feature_values)
    
    for value in unique_values:
        subtree = build_tree(
            partition_by_feature(dataset, best_idx, value),
            remaining_labels[:]
        )
        tree[best_label][value] = subtree
    
    return tree

>>> dataset, features = create_dataset()
>>> decision_tree = build_tree(dataset, features)
>>> decision_tree
{'survives_without_air': {0: 'no', 1: {'has_flippers': {0: 'no', 1: 'yes'}}}}

When all features are exhausted but class labels remain impure, majority voting determines the leaf classification.
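To see that tie-breaking in isolation, `majority_vote` can be exercised on a label list where features have been exhausted but the labels still disagree (the helper is repeated here so the snippet runs standalone):

```python
from collections import Counter

def majority_vote(class_labels):
    """Return the most common class label."""
    return Counter(class_labels).most_common(1)[0][0]

# Features exhausted, labels still impure: the majority class wins.
print(majority_vote(['yes', 'no', 'no']))  # no
```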

Classification with the Trained Tree

def predict(tree, feature_labels, test_instance):
    """Classify a new instance using the decision tree."""
    root_feature = list(tree.keys())[0]
    children = tree[root_feature]
    feature_idx = feature_labels.index(root_feature)
    
    for value in children.keys():
        if test_instance[feature_idx] == value:
            subtree = children[value]
            if isinstance(subtree, dict):
                return predict(subtree, feature_labels, test_instance)
            else:
                return subtree

>>> tree, features = build_tree(*create_dataset()), ['survives_without_air', 'has_flippers']
>>> predict(tree, features, [1, 0])
'no'
>>> predict(tree, features, [1, 1])
'yes'

Persisting the Trained Model

Decision tree construction is computationally expensive. Serialization allows reusing a trained model without rebuilding it.

import pickle

def save_tree(tree, filepath):
    """Serialize and save the decision tree to disk."""
    with open(filepath, 'wb') as f:
        pickle.dump(tree, f)

def load_tree(filepath):
    """Load a previously saved decision tree from disk."""
    with open(filepath, 'rb') as f:
        return pickle.load(f)

>>> tree, features = build_tree(*create_dataset()), ['survives_without_air', 'has_flippers']
>>> save_tree(tree, 'classifier.pkl')
>>> loaded = load_tree('classifier.pkl')
>>> loaded
{'survives_without_air': {0: 'no', 1: {'has_flippers': {0: 'no', 1: 'yes'}}}}

Visualizing Decision Trees

Using matplotlib for tree visualization helps in understanding the learned decision boundaries.

import matplotlib.pyplot as plt

# Node styling configurations
DECISION_NODE = dict(boxstyle='sawtooth', fc='0.8')
LEAF_NODE = dict(boxstyle='round4', fc='0.8')
ARROW_STYLE = dict(arrowstyle='->')

def draw_node(text, center_position, parent_position, node_type):
    """Render a single node with connecting arrow."""
    ax.annotate(text,
                xy=parent_position,
                xycoords='axes fraction',
                xytext=center_position,
                textcoords='axes fraction',
                va='center', ha='center',
                bbox=node_type,
                arrowprops=ARROW_STYLE)

def count_leaves(tree):
    """Count terminal nodes in the tree."""
    first_key = list(tree.keys())[0]
    children = tree[first_key]
    leaf_count = 0
    
    for child in children.values():
        if isinstance(child, dict):
            leaf_count += count_leaves(child)
        else:
            leaf_count += 1
    
    return leaf_count

def measure_depth(tree):
    """Calculate maximum tree depth."""
    first_key = list(tree.keys())[0]
    children = tree[first_key]
    max_depth = 0
    
    for child in children.values():
        if isinstance(child, dict):
            depth = 1 + measure_depth(child)
        else:
            depth = 1
        max_depth = max(max_depth, depth)
    
    return max_depth

def plot_split_position(parent, child, label):
    """Place text label between parent and child nodes."""
    mid_x = (parent[0] - child[0]) / 2.0 + child[0]
    mid_y = (parent[1] - child[1]) / 2.0 + child[1]
    ax.text(mid_x, mid_y, label, fontsize=10)

def render_tree(tree, parent_pos, node_label):
    """Recursively render the decision tree."""
    global x_offset, y_offset  # mutated as nodes are placed left to right
    leaf_count = count_leaves(tree)
    
    root_key = list(tree.keys())[0]
    center = (x_offset + (1.0 + leaf_count) / 2.0 / total_width,
              y_offset)
    
    plot_split_position(center, parent_pos, node_label)
    draw_node(root_key, center, parent_pos, DECISION_NODE)
    
    children = tree[root_key]
    y_offset -= 1.0 / total_depth
    
    for value, subtree in children.items():
        if isinstance(subtree, dict):
            render_tree(subtree, center, str(value))
        else:
            x_offset += 1.0 / total_width
            draw_node(subtree, (x_offset, y_offset), center, LEAF_NODE)
            plot_split_position((x_offset, y_offset), center, str(value))
    
    y_offset += 1.0 / total_depth

def visualize(tree):
    """Create the complete tree visualization."""
    global ax, x_offset, y_offset, total_width, total_depth
    
    fig = plt.figure(figsize=(12, 8), facecolor='white')
    ax = plt.subplot(111, frameon=False)
    ax.set_xticks([])
    ax.set_yticks([])
    
    total_width = float(count_leaves(tree))
    total_depth = float(measure_depth(tree))
    x_offset = -0.5 / total_width
    y_offset = 1.0
    
    render_tree(tree, (0.5, 1.0), '')
    plt.tight_layout()
    plt.savefig('decision_tree.pdf', dpi=150)
    plt.show()

The visualization places leaf nodes at the bottom and decision nodes higher, with edge labels showing the feature values that lead to each branch. This representation makes the decision-making process transparent and easy to interpret.
