1# $Id$
2#
3#  Copyright (C) 2003 Rational Discovery LLC
4#      All Rights Reserved
5#
""" Define the class _KNNRegressionModel_, used to represent a k-nearest neighbors
regression model.
8
9    Inherits from _KNNModel_
10"""
11
12from rdkit.ML.KNN import KNNModel
13
14
class KNNRegressionModel(KNNModel.KNNModel):
  """ This is used to represent a k-nearest neighbor regression model

  A prediction is the (optionally distance-weighted) average of the last
  element of each of the example's nearest training neighbors.

  """

  def __init__(self, k, attrs, dfunc, radius=None):
    """ Constructor

    **Arguments**

      - k, attrs, dfunc, radius: passed unchanged to _setup(), which is
        inherited from _KNNModel_

    """
    self._setup(k, attrs, dfunc, radius)

    self._badExamples = []  # set via SetBadExamples(); empty until then

  def type(self):
    """ returns the string "Regression Model" (NameModel() uses it as the name) """
    return "Regression Model"

  def SetBadExamples(self, examples):
    """ stores _examples_ as this model's list of bad examples """
    self._badExamples = examples

  def GetBadExamples(self):
    """ returns the list stored by SetBadExamples() (empty by default) """
    return self._badExamples

  def NameModel(self, varNames):
    """ sets the model's name to the value returned by type()

    _varNames_ is ignored by this implementation.

    """
    self.SetName(self.type())

  def PredictExample(self, example, appendExamples=0, weightedAverage=0, neighborList=None):
    """ Generates a prediction for an example by looking at its closest neighbors

    **Arguments**

      - example: the example for which a value is to be predicted

      - appendExamples: if this is nonzero then the example will be stored on this model

      - weightedAverage: if nonzero, each neighbor's contribution to the value is
                         weighted by the reciprocal of its distance.  If one or
                         more neighbors lie at distance zero (exact matches), the
                         prediction is the plain average of those neighbors alone.

      - neighborList: if provided, will be extended with the list returned by
                      GetNeighbors()

    **Returns**

      - the predicted value for _example_: the (weighted) average of the last
        element of each usable neighbor, or 0.0 if there are none

    """
    if appendExamples:
      self._examples.append(example)

    # first find the k-closest examples in the training set
    knnLst = self.GetNeighbors(example)

    accum = 0.0
    denom = 0.0
    # Exact matches (distance 0) are tracked separately: 1/dist is unbounded as
    # dist -> 0, so giving them a fixed weight of 1 (as previously done) made
    # them count less than any neighbor closer than distance 1.
    exactAccum = 0.0
    nExact = 0
    for knn in knnLst:
      if knn[1] is None:
        # entries without an example carry no value; skip them
        continue
      val = knn[1][-1]
      if weightedAverage:
        dist = knn[0]
        if dist == 0.0:
          exactAccum += val
          nExact += 1
          continue
        w = 1. / dist
      else:
        w = 1.0
      accum += w * val
      denom += w
    if nExact:
      # exact matches dominate an inverse-distance average
      accum = exactAccum / nExact
    elif denom:
      accum /= denom
    if neighborList is not None:
      neighborList.extend(knnLst)
    return accum
82