1//  Copyright (c) 2018 Couchbase, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// 		http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package levenshtein
16
17import (
18	"fmt"
19	"math"
20)
21
// SinkState is the state identifier 0, typed as uint32.
//
// NOTE(review): presumably this is the dead-end state from which no match can
// be reached (CanMatch rejects state 0) — confirm against the DFA construction.
const SinkState uint32 = 0
23
24type DFA struct {
25	transitions [][256]uint32
26	distances   []Distance
27	initState   int
28	ed          uint8
29}
30
31/// Returns the initial state
32func (d *DFA) initialState() int {
33	return d.initState
34}
35
36/// Returns the Levenshtein distance associated to the
37/// current state.
38func (d *DFA) distance(stateId int) Distance {
39	return d.distances[stateId]
40}
41
42/// Returns the number of states in the `DFA`.
43func (d *DFA) numStates() int {
44	return len(d.transitions)
45}
46
47/// Returns the destination state reached after consuming a given byte.
48func (d *DFA) transition(fromState int, b uint8) int {
49	return int(d.transitions[fromState][b])
50}
51
52func (d *DFA) eval(bytes []uint8) Distance {
53	state := d.initialState()
54
55	for _, b := range bytes {
56		state = d.transition(state, b)
57	}
58
59	return d.distance(state)
60}
61
62func (d *DFA) Start() int {
63	return int(d.initialState())
64}
65
66func (d *DFA) IsMatch(state int) bool {
67	if _, ok := d.distance(state).(Exact); ok {
68		return true
69	}
70	return false
71}
72
73func (d *DFA) CanMatch(state int) bool {
74	return state > 0 && state < d.numStates()
75}
76
77func (d *DFA) Accept(state int, b byte) int {
78	return int(d.transition(state, b))
79}
80
81// WillAlwaysMatch returns if the specified state will always end in a
82// matching state.
83func (d *DFA) WillAlwaysMatch(state int) bool {
84	return false
85}
86
// fill overwrites every element of dest with val. A nil or empty slice is
// left untouched.
func fill(dest []uint32, val uint32) {
	for i := 0; i < len(dest); i++ {
		dest[i] = val
	}
}
92
// fillTransitions points all 256 entries of the transition row dest at val.
func fillTransitions(dest *[256]uint32, val uint32) {
	row := dest[:]
	for i := range row {
		row[i] = val
	}
}
98
99type Utf8DFAStateBuilder struct {
100	dfaBuilder       *Utf8DFABuilder
101	stateID          uint32
102	defaultSuccessor []uint32
103}
104
105func (sb *Utf8DFAStateBuilder) addTransitionID(fromStateID uint32, b uint8,
106	toStateID uint32) {
107	sb.dfaBuilder.transitions[fromStateID][b] = toStateID
108}
109
110func (sb *Utf8DFAStateBuilder) addTransition(in rune, toStateID uint32) {
111	fromStateID := sb.stateID
112	chars := []byte(string(in))
113	lastByte := chars[len(chars)-1]
114
115	for i, ch := range chars[:len(chars)-1] {
116		remNumBytes := len(chars) - i - 1
117		defaultSuccessor := sb.defaultSuccessor[remNumBytes]
118		intermediateStateID := sb.dfaBuilder.transitions[fromStateID][ch]
119
120		if intermediateStateID == defaultSuccessor {
121			intermediateStateID = sb.dfaBuilder.allocate()
122			fillTransitions(&sb.dfaBuilder.transitions[intermediateStateID],
123				sb.defaultSuccessor[remNumBytes-1])
124		}
125
126		sb.addTransitionID(fromStateID, ch, intermediateStateID)
127		fromStateID = intermediateStateID
128	}
129
130	toStateIDDecoded := sb.dfaBuilder.getOrAllocate(original(toStateID))
131	sb.addTransitionID(fromStateID, lastByte, toStateIDDecoded)
132}
133
// Utf8StateId is an identifier in the expanded numbering used while building
// the byte-level automaton: each source state id s owns the four consecutive
// values 4*s .. 4*s+3 (see predecessor).
type Utf8StateId uint32

// original returns the Utf8StateId of the source state stateId itself, that
// is, its predecessor at zero steps (4*stateId).
func original(stateId uint32) Utf8StateId {
	return Utf8StateId(stateId << 2)
}

// predecessor returns the Utf8StateId of the state that sits numSteps bytes
// before the source state stateId, computed as 4*stateId + numSteps in uint32
// arithmetic (wrapping on overflow).
func predecessor(stateId uint32, numSteps uint8) Utf8StateId {
	base := stateId << 2
	return Utf8StateId(base + uint32(numSteps))
}
143
144// Utf8DFABuilder makes it possible to define a DFA
145// that takes unicode character, and build a `DFA`
146// that operates on utf-8 encoded
147type Utf8DFABuilder struct {
148	index        []uint32
149	distances    []Distance
150	transitions  [][256]uint32
151	initialState uint32
152	numStates    uint32
153	maxNumStates uint32
154}
155
156func withMaxStates(maxStates uint32) *Utf8DFABuilder {
157	rv := &Utf8DFABuilder{
158		index:        make([]uint32, maxStates*2+100),
159		distances:    make([]Distance, 0, maxStates),
160		transitions:  make([][256]uint32, 0, maxStates),
161		maxNumStates: maxStates,
162	}
163
164	for i := range rv.index {
165		rv.index[i] = math.MaxUint32
166	}
167
168	return rv
169}
170
171func (dfab *Utf8DFABuilder) allocate() uint32 {
172	newState := dfab.numStates
173	dfab.numStates++
174
175	dfab.distances = append(dfab.distances, Atleast{d: 255})
176	dfab.transitions = append(dfab.transitions, [256]uint32{})
177
178	return newState
179}
180
181func (dfab *Utf8DFABuilder) getOrAllocate(state Utf8StateId) uint32 {
182	if int(state) >= cap(dfab.index) {
183		cloneIndex := make([]uint32, int(state)*2)
184		copy(cloneIndex, dfab.index)
185		dfab.index = cloneIndex
186	}
187	if dfab.index[state] != math.MaxUint32 {
188		return dfab.index[state]
189	}
190
191	nstate := dfab.allocate()
192	dfab.index[state] = nstate
193
194	return nstate
195}
196
197func (dfab *Utf8DFABuilder) setInitialState(iState uint32) {
198	decodedID := dfab.getOrAllocate(original(iState))
199	dfab.initialState = decodedID
200}
201
202func (dfab *Utf8DFABuilder) build(ed uint8) *DFA {
203	return &DFA{
204		transitions: dfab.transitions,
205		distances:   dfab.distances,
206		initState:   int(dfab.initialState),
207		ed:          ed,
208	}
209}
210
211func (dfab *Utf8DFABuilder) addState(state, default_suc_orig uint32,
212	distance Distance) (*Utf8DFAStateBuilder, error) {
213	if state > dfab.maxNumStates {
214		return nil, fmt.Errorf("State id is larger than maxNumStates")
215	}
216
217	stateID := dfab.getOrAllocate(original(state))
218	dfab.distances[stateID] = distance
219
220	defaultSuccID := dfab.getOrAllocate(original(default_suc_orig))
221	// creates a chain of states of predecessors of `default_suc_orig`.
222	// Accepting k-bytes (whatever the bytes are) from `predecessor_states[k-1]`
223	// leads to the `default_suc_orig` state.
224	predecessorStates := []uint32{defaultSuccID,
225		defaultSuccID,
226		defaultSuccID,
227		defaultSuccID}
228
229	for numBytes := uint8(1); numBytes < 4; numBytes++ {
230		predecessorState := predecessor(default_suc_orig, numBytes)
231		predecessorStateID := dfab.getOrAllocate(predecessorState)
232		predecessorStates[numBytes] = predecessorStateID
233		succ := predecessorStates[numBytes-1]
234		fillTransitions(&dfab.transitions[predecessorStateID], succ)
235	}
236
237	// 1-byte encoded chars.
238	fill(dfab.transitions[stateID][0:192], predecessorStates[0])
239	// 2-bytes encoded chars.
240	fill(dfab.transitions[stateID][192:224], predecessorStates[1])
241	// 3-bytes encoded chars.
242	fill(dfab.transitions[stateID][224:240], predecessorStates[2])
243	// 4-bytes encoded chars.
244	fill(dfab.transitions[stateID][240:256], predecessorStates[3])
245
246	return &Utf8DFAStateBuilder{
247		dfaBuilder:       dfab,
248		stateID:          stateID,
249		defaultSuccessor: predecessorStates}, nil
250}
251