# Copyright 2002 by Jeffrey Chang.
# All rights reserved.
#
# This file is part of the Biopython distribution and governed by your
# choice of the "Biopython License Agreement" or the "BSD 3-Clause License".
# Please see the LICENSE file that should have been included as part of this
# package.
"""Code for doing k-nearest-neighbors classification.

k Nearest Neighbors is a supervised learning algorithm that classifies
a new observation based on the classes in its surrounding neighborhood.

Glossary:
 - distance   The distance between two points in the feature space.
 - weight     The importance given to each point for classification.

Classes:
 - kNN           Holds information for a nearest neighbors classifier.


Functions:
 - train        Train a new kNN classifier.
 - calculate    Calculate the weight of each class, given an observation.
 - classify     Classify an observation into a class.

Weighting Functions:
 - equal_weight    Every example is given a weight of 1.

"""

import numpy


class kNN:
    """Holds information necessary to do nearest neighbors classification.

    Attributes:
     - classes  Set of the possible classes.
     - xs       List of the neighbors.
     - ys       List of the classes that the neighbors belong to.
     - k        Number of neighbors to look at.

    """

    def __init__(self):
        """Initialize an empty, untrained classifier."""
        self.classes = set()
        self.xs = []
        self.ys = []
        self.k = None


def equal_weight(x, y):
    """Return integer one (dummy method for equally weighting).

    Both arguments are ignored; they exist so the function matches the
    ``weight_fn(x, example)`` call made by ``calculate``.
    """
    # everything gets 1 vote
    return 1


def train(xs, ys, k, typecode=None):
    """Train a k nearest neighbors classifier on a training set.

    xs is a list of observations and ys is a list of the class assignments.
    Thus, xs and ys should contain the same number of elements.  k is
    the number of neighbors that should be examined when doing the
    classification.  typecode, if given, is passed to numpy.asarray as the
    dtype used to store xs.

    Returns a kNN object holding the training data.
    """
    # ys is read twice (once to collect the classes, then by index during
    # classification).  A one-shot iterable would be exhausted by the first
    # pass and could not be indexed, so materialize it; anything already
    # indexable is stored unchanged.
    if not hasattr(ys, "__getitem__"):
        ys = list(ys)

    knn = kNN()
    knn.classes = set(ys)
    knn.xs = numpy.asarray(xs, typecode)
    knn.ys = ys
    knn.k = k
    return knn


def calculate(knn, x, weight_fn=None, distance_fn=None):
    """Calculate the weight given to each class for an observation.

    Arguments:
     - knn is the kNN object holding the training data.
     - x is the observed data.
     - weight_fn is an optional function that takes x and a training
       example, and returns a weight.  If weight_fn is None (the
       default), every neighbor gets a weight of 1.
     - distance_fn is an optional function that takes two points and
       returns the distance between them.  If distance_fn is None (the
       default), the Euclidean distance is used.

    Returns a dictionary of the class to the weight given to the class.
    The weights are the summed votes of the knn.k nearest training
    examples; they are not normalized to probabilities.  Classes with no
    neighbor among the nearest knn.k get a weight of 0.0.
    """
    if weight_fn is None:
        weight_fn = equal_weight

    x = numpy.asarray(x)

    order = []  # list of (distance, index)
    if distance_fn is not None:
        for i, example in enumerate(knn.xs):
            order.append((distance_fn(x, example), i))
    else:
        # Default: Use a fast implementation of the Euclidean distance
        temp = numpy.zeros(len(x))
        # Predefining temp allows reuse of this array, making this
        # function about twice as fast.
        for i, example in enumerate(knn.xs):
            temp[:] = x - example
            dist = numpy.sqrt(numpy.dot(temp, temp))
            order.append((dist, i))
    # Tuples sort by distance first; equal distances fall back to the
    # training index, so earlier examples win ties.
    order.sort()

    # first 'k' are the ones I want.
    weights = dict.fromkeys(knn.classes, 0.0)  # class -> number of votes
    for _, i in order[: knn.k]:
        klass = knn.ys[i]
        weights[klass] = weights[klass] + weight_fn(x, knn.xs[i])

    return weights


def classify(knn, x, weight_fn=None, distance_fn=None):
    """Classify an observation into a class.

    If not specified, weight_fn will give all neighbors equal weight.
    distance_fn is an optional function that takes two points and returns
    the distance between them.  If distance_fn is None (the default),
    the Euclidean distance is used.

    Returns the class with the largest weight as computed by calculate.
    If several classes share the largest weight, the first one in the
    dictionary returned by calculate is chosen.  Returns None if the
    classifier has no classes.
    """
    # calculate() applies the default weighting itself when weight_fn is None.
    weights = calculate(knn, x, weight_fn=weight_fn, distance_fn=distance_fn)
    # max() keeps the first maximal key it meets, matching a strict
    # "greater than" scan over the dictionary.
    return max(weights, key=weights.get, default=None)