Implementing K-Nearest Neighbors from Scratch in Python
Understanding the KNN Algorithm
The K-Nearest Neighbors (KNN) algorithm is a non-parametric method used for classification and regression. In the context of classification, it operates on a simple principle: similar data points tend to belong to similar categories.
The process relies on a training dataset in which every entry has a known label. When a new, unlabeled data point arrives, the algorithm calculates the distance (typically Euclidean) between this new point and every point in the training set. It then identifies the k closest neighbors, and the new point is assigned the class that is most frequent among them. Here k is a small positive integer, usually kept below 20, since very large values smooth over the local structure the method depends on.
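As a quick illustration of the distance step, the following sketch computes Euclidean distances from a query point to every row of a small training matrix (the values are arbitrary placeholders, not data from the examples below):

import numpy as np

training = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # placeholder feature rows
query = np.array([2.0, 3.0])                               # placeholder query point
# Euclidean distance: square root of the summed squared feature differences
distances = np.sqrt(((training - query) ** 2).sum(axis=1))
print(distances.argsort())  # neighbor indices, nearest first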
Example: Genre Classification
Consider a scenario where we need to classify a movie as either a "Romance" or "Action" based on two features: the number of fight scenes and kiss scenes.
After computing the Euclidean distances between the unknown movie and every movie in the training set, we sort the distances in ascending order and examine the closest entries.
If we set k=3, the three nearest neighbors are all Romance movies ("He's Not Really into Dudes", "Beautiful Woman", and "California Man"). Consequently, the algorithm classifies the unknown movie as a Romance.
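The final step of that example reduces to a frequency count over the three nearest labels; a minimal sketch with the labels hard-coded from the result above:

from collections import Counter

nearest_labels = ['Romance', 'Romance', 'Romance']  # labels of the k = 3 nearest movies
prediction = Counter(nearest_labels).most_common(1)[0][0]
print(prediction)  # Romance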
Standard Workflow
- Data Collection: Gather the dataset using any available method.
- Data Preparation: Structure the data into a numerical format suitable for distance calculation.
- Exploratory Analysis: Visualize or analyze the data distribution.
- Algorithm Training: Note that KNN does not have a training phase; it simply stores the dataset.
- Testing: Evaluate the model by calculating the error rate on a test set (a sketch of this step follows the list).
- Deployment: Use the model to predict categories for new input vectors.
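For the Testing step, a common metric is the fraction of misclassified test samples. A minimal sketch, assuming the knn_predict function defined in the next section and a held-out test set:

def error_rate(test_features, test_labels, training_set, labels, k):
    """Returns the fraction of test samples that knn_predict misclassifies."""
    errors = 0
    for vector, true_label in zip(test_features, test_labels):
        if knn_predict(vector, training_set, labels, k) != true_label:
            errors += 1
    return errors / len(test_labels)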
Python Implementation
Below is a custom implementation of the KNN algorithm in Python and NumPy. We avoid high-level ML libraries so that the underlying math stays visible.
import numpy as np
import operator


def init_dataset():
    """Creates a mock dataset for testing."""
    features = np.array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
    categories = ['Group_A', 'Group_A', 'Group_B', 'Group_B']
    return features, categories


def knn_predict(input_vector, training_set, labels, k):
    """
    Classifies a data point using the KNN algorithm.

    :param input_vector: The new data point to classify.
    :param training_set: The training data (NumPy array).
    :param labels: List of labels corresponding to the training data.
    :param k: The number of neighbors to consider.
    :return: The predicted label.
    """
    # 1. Calculate Euclidean distances.
    #    Replicate input_vector once per training sample, subtract the
    #    training set, square, sum per row, and take the square root.
    num_samples = training_set.shape[0]
    diff_matrix = np.tile(input_vector, (num_samples, 1)) - training_set
    squared_diff = diff_matrix ** 2
    squared_dist = squared_diff.sum(axis=1)
    distances = squared_dist ** 0.5
    # 2. Get the indices that sort the distances in ascending order.
    sorted_indices = distances.argsort()
    # 3. Count the votes from the k nearest neighbors.
    vote_counter = {}
    for i in range(k):
        nearest_label = labels[sorted_indices[i]]
        vote_counter[nearest_label] = vote_counter.get(nearest_label, 0) + 1
    # 4. Sort the votes by count, descending; the winner is the prediction.
    sorted_votes = sorted(vote_counter.items(), key=operator.itemgetter(1), reverse=True)
    return sorted_votes[0][0]
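As a design note, the vote counting and sorting in steps 3 and 4 can be collapsed using collections.Counter from the standard library. A functionally equivalent sketch (tie-breaking among equally frequent labels may differ in edge cases):

from collections import Counter
import numpy as np

def knn_predict_counter(input_vector, training_set, labels, k):
    """Compact variant of knn_predict; same distances, Counter-based vote."""
    diffs = training_set - np.asarray(input_vector)
    distances = np.sqrt((diffs ** 2).sum(axis=1))
    k_nearest = [labels[i] for i in distances.argsort()[:k]]
    return Counter(k_nearest).most_common(1)[0][0]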
Execution Example
Running the functions in an interactive Python session yields the following output:
>>> import my_knn_module
>>> data, tags = my_knn_module.init_dataset()
>>> print(data)
[[1. 1.1]
[1. 1. ]
[0. 0. ]
[0. 0.1]]
>>> my_knn_module.knn_predict([0, 0], data, tags, 3)
'Group_B'
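The prediction is easy to verify: both Group_B samples lie essentially on top of the query point [0, 0], while the two Group_A samples are roughly 1.4 units away. Recomputing the four distances in the same session (output shown at NumPy's default print precision):

>>> import numpy as np
>>> np.sqrt(((data - np.array([0, 0])) ** 2).sum(axis=1))
array([1.48660687, 1.41421356, 0.        , 0.1       ])

With k=3, the vote is two Group_B neighbors against one Group_A neighbor, so Group_B wins.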
Pros and Cons
Advantages:
- High Accuracy: Effective for many simple classification tasks.
- Robustness: Not sensitive to outliers because the classification is based on a majority vote of neighbors.
- No Assumptions: Makes no underlying assumptions about the data distribution.
Disadvantages:
- Computationally Expensive: Must calculate the distance to every single sample in the dataset for every prediction.
- Memory Intensive: Requires storing the entire training dataset in memory.
- No Insight: Does not provide information about the underlying structure or features of the data.
Data Compatibility: Works best with numerical data (continuous values) and nominal data (categorical values with no order).
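Because Euclidean distance is only defined over numbers, nominal features must be encoded numerically before KNN can use them. A minimal sketch using one-hot encoding, with a hypothetical color feature:

import numpy as np

colors = ['red', 'green', 'red', 'blue']  # hypothetical nominal feature
categories = sorted(set(colors))          # ['blue', 'green', 'red']
# One binary column per category; each row has exactly one 1.
encoded = np.array([[1.0 if value == cat else 0.0 for cat in categories]
                    for value in colors])
print(encoded)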