1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package fuzzy implements a fuzzy matching algorithm.
6package fuzzy
7
8import (
9	"bytes"
10	"fmt"
11)
12
const (
	// MaxInputSize is the maximum size, in bytes, of the input scored against the fuzzy matcher.
	// Longer inputs will be truncated to this size.
	MaxInputSize = 127
	// MaxPatternSize is the maximum size, in bytes, of the pattern used to construct the fuzzy
	// matcher. Longer patterns are truncated to this size.
	MaxPatternSize = 63
)
21
// scoreVal packs a score together with a single "previous k" bit: the score
// lives in the upper bits and the low bit records which of the two table
// layers (0 or 1) the score was derived from.
type scoreVal int

// val extracts the score, discarding the low "previous k" bit.
func (s scoreVal) val() int {
	return int(s >> 1)
}

// prevK extracts the low bit, i.e. the layer of the predecessor cell.
func (s scoreVal) prevK() int {
	return int(s & 1)
}

// score packs val and prevK into a scoreVal. prevK is expected to be 0 or 1.
func score(val int, prevK int /*0 or 1*/) scoreVal {
	packed := val<<1 + prevK
	return scoreVal(packed)
}
35
// Matcher implements a fuzzy matching algorithm for scoring candidates against a pattern.
// A Matcher keeps scratch state for the most recently scored candidate (score table,
// candidate length, roles), so it does not support parallel usage.
type Matcher struct {
	pattern       string // the pattern; NewMatcher truncates it to MaxPatternSize bytes
	patternLower  []byte // lower-case version of the pattern
	patternShort  []byte // first characters of the pattern (at most 3 bytes of patternLower)
	caseSensitive bool   // set if the pattern is mix-cased (differs from its lower-case form)

	patternRoles []RuneRole // the role of each character in the pattern
	roles        []RuneRole // the role of each character in the tested string

	// scores is the dynamic-programming table filled by computeScore. It is indexed as
	// [candidate prefix length][pattern prefix length][k], where k == 0 holds the score when
	// the last candidate byte is skipped and k == 1 the score when it is matched.
	scores [MaxInputSize + 1][MaxPatternSize + 1][2]scoreVal

	// scoreScale converts a raw score into the value returned by Score. NewMatcher sets it to
	// 1/(4*len(pattern)); it stays zero for an empty pattern.
	scoreScale float32

	lastCandidateLen     int  // in bytes
	lastCandidateMatched bool // whether the last call to Score reported a match

	// Here we save the last candidate in lower-case. This is basically a byte slice we reuse for
	// performance reasons, so the slice is not reallocated for every candidate.
	lowerBuf [MaxInputSize]byte
	// rolesBuf is the reusable backing storage handed to RuneRoles for the candidate's roles.
	rolesBuf [MaxInputSize]RuneRole
}
59
60func (m *Matcher) bestK(i, j int) int {
61	if m.scores[i][j][0].val() < m.scores[i][j][1].val() {
62		return 1
63	}
64	return 0
65}
66
67// NewMatcher returns a new fuzzy matcher for scoring candidates against the provided pattern.
68func NewMatcher(pattern string) *Matcher {
69	if len(pattern) > MaxPatternSize {
70		pattern = pattern[:MaxPatternSize]
71	}
72
73	m := &Matcher{
74		pattern:      pattern,
75		patternLower: ToLower(pattern, nil),
76	}
77
78	for i, c := range m.patternLower {
79		if pattern[i] != c {
80			m.caseSensitive = true
81			break
82		}
83	}
84
85	if len(pattern) > 3 {
86		m.patternShort = m.patternLower[:3]
87	} else {
88		m.patternShort = m.patternLower
89	}
90
91	m.patternRoles = RuneRoles(pattern, nil)
92
93	if len(pattern) > 0 {
94		maxCharScore := 4
95		m.scoreScale = 1 / float32(maxCharScore*len(pattern))
96	}
97
98	return m
99}
100
101// Score returns the score returned by matching the candidate to the pattern.
102// This is not designed for parallel use. Multiple candidates must be scored sequentially.
103// Returns a score between 0 and 1 (0 - no match, 1 - perfect match).
104func (m *Matcher) Score(candidate string) float32 {
105	if len(candidate) > MaxInputSize {
106		candidate = candidate[:MaxInputSize]
107	}
108	lower := ToLower(candidate, m.lowerBuf[:])
109	m.lastCandidateLen = len(candidate)
110
111	if len(m.pattern) == 0 {
112		// Empty patterns perfectly match candidates.
113		return 1
114	}
115
116	if m.match(candidate, lower) {
117		sc := m.computeScore(candidate, lower)
118		if sc > minScore/2 && !m.poorMatch() {
119			m.lastCandidateMatched = true
120			if len(m.pattern) == len(candidate) {
121				// Perfect match.
122				return 1
123			}
124
125			if sc < 0 {
126				sc = 0
127			}
128			normalizedScore := float32(sc) * m.scoreScale
129			if normalizedScore > 1 {
130				normalizedScore = 1
131			}
132
133			return normalizedScore
134		}
135	}
136
137	m.lastCandidateMatched = false
138	return 0
139}
140
// minScore is the sentinel score stored in score table cells that have no valid alignment.
// Callers compare against minScore/2 so that a sentinel adjusted by a few bonuses or penalties
// is still recognized as "no match".
const minScore = -10000
142
143// MatchedRanges returns matches ranges for the last scored string as a flattened array of
144// [begin, end) byte offset pairs.
145func (m *Matcher) MatchedRanges() []int {
146	if len(m.pattern) == 0 || !m.lastCandidateMatched {
147		return nil
148	}
149	i, j := m.lastCandidateLen, len(m.pattern)
150	if m.scores[i][j][0].val() < minScore/2 && m.scores[i][j][1].val() < minScore/2 {
151		return nil
152	}
153
154	var ret []int
155	k := m.bestK(i, j)
156	for i > 0 {
157		take := (k == 1)
158		k = m.scores[i][j][k].prevK()
159		if take {
160			if len(ret) == 0 || ret[len(ret)-1] != i {
161				ret = append(ret, i)
162				ret = append(ret, i-1)
163			} else {
164				ret[len(ret)-1] = i - 1
165			}
166			j--
167		}
168		i--
169	}
170	// Reverse slice.
171	for i := 0; i < len(ret)/2; i++ {
172		ret[i], ret[len(ret)-1-i] = ret[len(ret)-1-i], ret[i]
173	}
174	return ret
175}
176
177func (m *Matcher) match(candidate string, candidateLower []byte) bool {
178	i, j := 0, 0
179	for ; i < len(candidateLower) && j < len(m.patternLower); i++ {
180		if candidateLower[i] == m.patternLower[j] {
181			j++
182		}
183	}
184	if j != len(m.patternLower) {
185		return false
186	}
187
188	// The input passes the simple test against pattern, so it is time to classify its characters.
189	// Character roles are used below to find the last segment.
190	m.roles = RuneRoles(candidate, m.rolesBuf[:])
191
192	return true
193}
194
// computeScore fills m.scores for candidate and returns the best raw (unnormalized) score for
// aligning the whole pattern with the whole candidate.
//
// m.scores[i][j][k] describes matching the first j pattern bytes against the first i candidate
// bytes: k == 1 means candidate byte i-1 is matched to pattern byte j-1, k == 0 means candidate
// byte i-1 is skipped. Each cell packs the score with the k of the predecessor cell it was
// derived from. Cells whose score is minScore have no valid alignment.
//
// The function reads m.roles, which match populates for the candidate.
func (m *Matcher) computeScore(candidate string, candidateLower []byte) int {
	pattLen, candLen := len(m.pattern), len(candidate)

	// Row 0 (no candidate bytes consumed): every cell is unreachable except the empty pattern
	// prefix. minScore << 1 is the packed form of (minScore, prevK 0).
	for j := 0; j <= len(m.pattern); j++ {
		m.scores[0][j][0] = minScore << 1
		m.scores[0][j][1] = minScore << 1
	}
	m.scores[0][0][0] = score(0, 0) // Start with 0.

	// Count the separator-delimited segments and find where the last one starts.
	segmentsLeft, lastSegStart := 1, 0
	for i := 0; i < candLen; i++ {
		if m.roles[i] == RSep {
			segmentsLeft++
			lastSegStart = i + 1
		}
	}

	// A per-character bonus for a consecutive match.
	consecutiveBonus := 2
	// NOTE(review): wordIdx is updated below but never read in this function.
	wordIdx := 0 // Word count within segment.
	for i := 1; i <= candLen; i++ {

		role := m.roles[i-1]
		isHead := role == RHead

		// Passing a separator moves on to the next segment; segmentsLeft reaches 1 once the
		// last segment has been entered.
		if isHead {
			wordIdx++
		} else if role == RSep && segmentsLeft > 1 {
			wordIdx = 0
			segmentsLeft--
		}

		var skipPenalty int
		if i == 1 || (i-1) == lastSegStart {
			// Skipping the start of first or last segment.
			skipPenalty++
		}

		for j := 0; j <= pattLen; j++ {
			// By default, we don't have a match. Fill in the skip data.
			m.scores[i][j][1] = minScore << 1

			// Compute the skip score.
			// Pick the better layer of the cell above (same pattern prefix, one fewer
			// candidate byte); ties go to layer 0.
			k := 0
			if m.scores[i-1][j][0].val() < m.scores[i-1][j][1].val() {
				k = 1
			}

			skipScore := m.scores[i-1][j][k].val()
			// Do not penalize missing characters after the last matched segment.
			if j != pattLen {
				skipScore -= skipPenalty
			}
			m.scores[i][j][0] = score(skipScore, k)

			if j == 0 || candidateLower[i-1] != m.patternLower[j-1] {
				// Not a match.
				continue
			}
			pRole := m.patternRoles[j-1]

			if role == RTail && pRole == RHead {
				if j > 1 {
					// Not a match: a head in the pattern matches a tail character in the candidate.
					continue
				}
				// Special treatment for the first character of the pattern. We allow
				// matches in the middle of a word if they are long enough, at least
				// min(3, pattern.length) characters.
				if !bytes.HasPrefix(candidateLower[i-1:], m.patternShort) {
					continue
				}
			}

			// Compute the char score.
			var charScore int
			// Bonus 1: the char is in the candidate's last segment.
			if segmentsLeft <= 1 {
				charScore++
			}
			// Bonus 2: Case match or a Head in the pattern aligns with one in the word.
			// Single-case patterns lack segmentation signals and we assume any character
			// can be a head of a segment.
			if candidate[i-1] == m.pattern[j-1] || role == RHead && (!m.caseSensitive || pRole == RHead) {
				charScore++
			}

			// Penalty 1: pattern char is Head, candidate char is Tail.
			if role == RTail && pRole == RHead {
				charScore--
			}
			// Penalty 2: first pattern character matched in the middle of a word.
			if j == 1 && role == RTail {
				charScore -= 4
			}

			// Third dimension encodes whether there is a gap between the previous match and the current
			// one.
			// Try both layers of the diagonal predecessor and keep the better result in the
			// "matched" layer of this cell, remembering which layer it came from.
			for k := 0; k < 2; k++ {
				sc := m.scores[i-1][j-1][k].val() + charScore

				isConsecutive := k == 1 || i-1 == 0 || i-1 == lastSegStart
				if isConsecutive {
					// Bonus 3: a consecutive match. First character match also gets a bonus to
					// ensure prefix final match score normalizes to 1.0.
					// Logically, this is a part of charScore, but we have to compute it here because it
					// only applies for consecutive matches (k == 1).
					sc += consecutiveBonus
				}
				if k == 0 {
					// Penalty 3: Matching inside a segment (and previous char wasn't matched). Penalize for the lack
					// of alignment.
					if role == RTail || role == RUCTail {
						sc -= 3
					}
				}

				if sc > m.scores[i][j][1].val() {
					m.scores[i][j][1] = score(sc, k)
				}
			}
		}
	}

	// The answer is the better of the two layers in the final cell.
	result := m.scores[len(candidate)][len(m.pattern)][m.bestK(len(candidate), len(m.pattern))].val()

	return result
}
323
324// ScoreTable returns the score table computed for the provided candidate. Used only for debugging.
325func (m *Matcher) ScoreTable(candidate string) string {
326	var buf bytes.Buffer
327
328	var line1, line2, separator bytes.Buffer
329	line1.WriteString("\t")
330	line2.WriteString("\t")
331	for j := 0; j < len(m.pattern); j++ {
332		line1.WriteString(fmt.Sprintf("%c\t\t", m.pattern[j]))
333		separator.WriteString("----------------")
334	}
335
336	buf.WriteString(line1.String())
337	buf.WriteString("\n")
338	buf.WriteString(separator.String())
339	buf.WriteString("\n")
340
341	for i := 1; i <= len(candidate); i++ {
342		line1.Reset()
343		line2.Reset()
344
345		line1.WriteString(fmt.Sprintf("%c\t", candidate[i-1]))
346		line2.WriteString("\t")
347
348		for j := 1; j <= len(m.pattern); j++ {
349			line1.WriteString(fmt.Sprintf("M%6d(%c)\t", m.scores[i][j][0].val(), dir(m.scores[i][j][0].prevK())))
350			line2.WriteString(fmt.Sprintf("H%6d(%c)\t", m.scores[i][j][1].val(), dir(m.scores[i][j][1].prevK())))
351		}
352		buf.WriteString(line1.String())
353		buf.WriteString("\n")
354		buf.WriteString(line2.String())
355		buf.WriteString("\n")
356		buf.WriteString(separator.String())
357		buf.WriteString("\n")
358	}
359
360	return buf.String()
361}
362
// dir maps a predecessor layer to the letter used in ScoreTable output:
// 'M' for layer 0 and 'H' for any other value.
func dir(prevK int) rune {
	switch prevK {
	case 0:
		return 'M'
	default:
		return 'H'
	}
}
369
370func (m *Matcher) poorMatch() bool {
371	if len(m.pattern) < 2 {
372		return false
373	}
374
375	i, j := m.lastCandidateLen, len(m.pattern)
376	k := m.bestK(i, j)
377
378	var counter, len int
379	for i > 0 {
380		take := (k == 1)
381		k = m.scores[i][j][k].prevK()
382		if take {
383			len++
384			if k == 0 && len < 3 && m.roles[i-1] == RTail {
385				// Short match in the middle of a word
386				counter++
387				if counter > 1 {
388					return true
389				}
390			}
391			j--
392		} else {
393			len = 0
394		}
395		i--
396	}
397	return false
398}
399