1# Copyright 2002 by Jeffrey Chang.
2# All rights reserved.
3#
4# This file is part of the Biopython distribution and governed by your
5# choice of the "Biopython License Agreement" or the "BSD 3-Clause License".
6# Please see the LICENSE file that should have been included as part of this
7# package.
8"""Code for doing k-nearest-neighbors classification.
9
k Nearest Neighbors is a supervised learning algorithm that classifies
a new observation based on the classes in its surrounding neighborhood.
12
13Glossary:
14 - distance   The distance between two points in the feature space.
15 - weight     The importance given to each point for classification.
16
17Classes:
18 - kNN           Holds information for a nearest neighbors classifier.
19
20
21Functions:
22 - train        Train a new kNN classifier.
23 - calculate    Calculate the probabilities of each class, given an observation.
24 - classify     Classify an observation into a class.
25
26Weighting Functions:
27 - equal_weight    Every example is given a weight of 1.
28
29"""
30
31import numpy
32
33
class kNN:
    """Container for the state of a nearest neighbors classifier.

    Attributes:
     - classes  Set of the possible classes.
     - xs       The training observations (the neighbors).
     - ys       The class of each training observation, in the same order
                as xs.
     - k        Number of neighbors to look at (None until trained).
    """

    def __init__(self):
        """Initialize an empty, untrained classifier."""
        self.classes, self.xs, self.ys, self.k = set(), [], [], None
50
51
def equal_weight(x, y):
    """Weight every training example identically.

    Both arguments are ignored; the integer 1 is always returned, so each
    neighbor contributes exactly one vote.
    """
    del x, y  # the weight does not depend on either point
    return 1
56
57
def train(xs, ys, k, typecode=None):
    """Train a k nearest neighbors classifier on a training set.

    Arguments:
     - xs is a sequence of observations (the training points).
     - ys is a sequence (or any iterable) of the class assignments, one per
       observation and in the same order as xs.  Thus, xs and ys should
       contain the same number of elements.
     - k is the number of neighbors that should be examined when doing the
       classification.
     - typecode is an optional numpy dtype used when converting xs to an
       array; None lets numpy infer it.

    Returns a kNN object holding the training data.

    Raises ValueError if xs and ys do not contain the same number of
    elements.
    """
    # Materialize ys once: it is turned into a set here and indexed later
    # when classifying, which would fail on a one-shot iterator.
    ys = list(ys)

    knn = kNN()
    knn.xs = numpy.asarray(xs, typecode)
    if len(knn.xs) != len(ys):
        raise ValueError(
            "xs and ys must contain the same number of elements "
            "(got %d observations and %d classes)" % (len(knn.xs), len(ys))
        )
    knn.classes = set(ys)
    knn.ys = ys
    knn.k = k
    return knn
72
73
def calculate(knn, x, weight_fn=None, distance_fn=None):
    """Calculate the probability for each class.

    Arguments:
     - knn is a trained kNN object (as returned by train).
     - x is the observed data.
     - weight_fn is an optional function that takes x and a training
       example, and returns a weight.  If weight_fn is None (the default),
       every neighbor gets a weight of 1.
     - distance_fn is an optional function that takes two points and
       returns the distance between them.  If distance_fn is None (the
       default), the Euclidean distance is used.

    Returns a dictionary of the class to the weight given to the class.
    Every class in knn.classes is present; a class with no neighbor among
    the knn.k nearest gets 0.0.  Equal distances are broken in favor of
    the training example that comes first.
    """
    if weight_fn is None:
        weight_fn = equal_weight

    x = numpy.asarray(x)
    n_points = len(knn.xs)

    if distance_fn:
        distances = [distance_fn(x, knn.xs[i]) for i in range(n_points)]
    elif n_points == 0:
        # Nothing to measure against (and reshape below needs >= 1 row).
        distances = []
    else:
        # Default: Euclidean distance to every training point in a single
        # vectorized step.  Flattening each observation to one row lets
        # scalar observations (1-D training data) work as well as vectors;
        # promoting the differences to float avoids integer overflow when
        # they are squared.
        rows = numpy.asarray(knn.xs).reshape(n_points, -1)
        diffs = numpy.asarray(rows - x.reshape(-1), dtype=float)
        distances = numpy.linalg.norm(diffs, axis=1).tolist()

    # Sort (distance, index) pairs so ties go to the lower training index.
    order = sorted(zip(distances, range(n_points)))

    # The first knn.k entries are the nearest neighbors; each one adds its
    # weight to the tally of its own class.
    weights = dict.fromkeys(knn.classes, 0.0)  # class -> accumulated weight
    for _, i in order[: knn.k]:
        klass = knn.ys[i]
        weights[klass] += weight_fn(x, knn.xs[i])

    return weights
117
118
def classify(knn, x, weight_fn=None, distance_fn=None):
    """Classify an observation into a class.

    If not specified, weight_fn will give all neighbors equal weight.
    distance_fn is an optional function that takes two points and returns
    the distance between them.  If distance_fn is None (the default),
    the Euclidean distance is used.

    Returns the class with the largest total weight, as computed by
    calculate.  When several classes share that weight, the one appearing
    first in calculate's result wins.  Returns None if there are no
    classes at all.
    """
    if weight_fn is None:
        weight_fn = equal_weight

    weights = calculate(knn, x, weight_fn=weight_fn, distance_fn=distance_fn)

    # max() keeps the first of several equal maxima (same as a strict ">"
    # scan) and, unlike a "most_class is None" sentinel, still works when
    # None itself is used as a class label.
    return max(weights, key=weights.get, default=None)
139