1// Copyright 2014 Richard Lehane. All rights reserved.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Package patterns describes the Pattern interface.
16// Standard patterns are also defined in this package: Sequence (as well as BMH and reverse BMH Sequence), Choice, List and Not.
17package patterns
18
19import (
20	"bytes"
21	"encoding/hex"
22	"errors"
23	"fmt"
24	"strconv"
25	"unicode/utf8"
26
27	"github.com/richardlehane/siegfried/internal/persist"
28)
29
30func init() {
31	Register(sequenceLoader, loadSequence)
32	Register(choiceLoader, loadChoice)
33	Register(listLoader, loadList)
34	Register(notLoader, loadNot)
35	Register(bmhLoader, loadBMH)
36	Register(rbmhLoader, loadRBMH)
37	Register(maskLoader, loadMask)
38	Register(anyMaskLoader, loadAnyMask)
39}
40
// Stringify returns a string version of a byte slice.
// If the bytes are valid UTF-8, the result is a double-quoted string literal
// restricted to ASCII (other characters are escaped by strconv.QuoteToASCII).
// Otherwise the bytes are returned as a hex string.
func Stringify(b []byte) string {
	if !utf8.Valid(b) {
		return hex.EncodeToString(b)
	}
	text := string(b)
	return strconv.QuoteToASCII(text)
}
50
// Pattern is the smallest building block of a format signature.
// Exact byte sequence matches are a type of pattern, as are byte ranges, non-sequence matches etc.
// You can define custom patterns (e.g. for W3C date type) by implementing this interface.
type Pattern interface {
	// Test tests the byte slice against the pattern.
	// For a positive match, it returns a slice of the lengths of the match and the bytes to advance for a subsequent test.
	// For a negative match, it returns nil or an empty slice and the bytes to advance for a subsequent test
	// (or 0 if the length of the pattern is longer than the length of the slice).
	Test([]byte) ([]int, int)
	// TestR is the same as Test but tests in reverse (from the right-most position of the byte slice).
	TestR([]byte) ([]int, int)
	// Equals tests equality with another pattern.
	Equals(Pattern) bool
	// Length returns the minimum and maximum lengths of the pattern.
	Length() (int, int)
	// NumSequences returns the number of simple sequences represented by a pattern.
	// Return 0 if the pattern cannot be represented by a defined number of simple sequences
	// (e.g. for an indirect offset pattern) or if, in your opinion, the number of sequences is unreasonably large.
	NumSequences() int
	// Sequences converts the pattern to a slice of sequences.
	// Return an empty slice if the pattern cannot be represented by a defined number of simple sequences.
	Sequences() []Sequence
	String() string
	// Save encodes the pattern into bytes for saving in a persist file.
	Save(*persist.LoadSaver)
}
64
// Loader loads a Pattern from a persist.LoadSaver.
// Load calls the Loader registered for a pattern's id after reading the leading id byte.
type Loader func(*persist.LoadSaver) Pattern
67
// Ids of the built-in loaders. A pattern's Save method writes its id as the
// first byte (e.g. Sequence.Save writes sequenceLoader) and Load uses that byte
// to pick the Loader. The values are assigned by iota in declaration order
// (0 to 7), so reordering these constants changes the ids.
const (
	sequenceLoader byte = iota
	choiceLoader
	listLoader
	notLoader
	bmhLoader
	rbmhLoader
	maskLoader
	anyMaskLoader
)
78
// loaders is the table of registered Loaders, indexed by loader id.
// It has room for ids 0 to 31; slots without a registered Loader are nil.
var loaders = [32]Loader{}
80
81// Register a new Loader (provide an id higher than 16).
82func Register(id byte, l Loader) {
83	loaders[int(id)] = l
84}
85
86// Load loads the Pattern, choosing the correct Loader by the leading id byte.
87func Load(ls *persist.LoadSaver) Pattern {
88	id := ls.LoadByte()
89	l := loaders[int(id)]
90	if l == nil {
91		if ls.Err == nil {
92			ls.Err = errors.New("bad pattern loader")
93		}
94		return nil
95	}
96	return l(ls)
97}
98
99// Index reports the offset of one pattern within another (or -1 if not contained)
100func Index(a, b Pattern) int {
101	if a.Equals(b) {
102		return 0
103	}
104	seq1, ok := a.(Sequence)
105	seq2, ok2 := b.(Sequence)
106	if ok && ok2 {
107		return bytes.Index(seq1, seq2)
108	}
109	return -1
110}
111
// Sequence is a matching sequence of bytes.
// It matches when the tested bytes are exactly equal to the sequence.
type Sequence []byte
114
115// Test bytes against the pattern.
116func (s Sequence) Test(b []byte) ([]int, int) {
117	if len(b) < len(s) {
118		return nil, 0
119	}
120	if bytes.Equal(s, b[:len(s)]) {
121		return []int{len(s)}, 1
122	}
123	return nil, 1
124}
125
126// Test bytes against the pattern in reverse.
127func (s Sequence) TestR(b []byte) ([]int, int) {
128	if len(b) < len(s) {
129		return nil, 0
130	}
131	if bytes.Equal(s, b[len(b)-len(s):]) {
132		return []int{len(s)}, 1
133	}
134	return nil, 1
135}
136
137// Equals reports whether a pattern is identical to another pattern.
138func (s Sequence) Equals(pat Pattern) bool {
139	seq2, ok := pat.(Sequence)
140	if ok {
141		return bytes.Equal(s, seq2)
142	}
143	return false
144}
145
146// Length returns a minimum and maximum length for the pattern.
147func (s Sequence) Length() (int, int) {
148	return len(s), len(s)
149}
150
151// NumSequences reports how many plain sequences are needed to represent this pattern.
152func (s Sequence) NumSequences() int {
153	return 1
154}
155
156// Sequences converts the pattern into a slice of plain sequences.
157func (s Sequence) Sequences() []Sequence {
158	return []Sequence{s}
159}
160
161func (s Sequence) String() string {
162	return "seq " + Stringify(s)
163}
164
165// The Reverse method is unique to this pattern. It is used for the EOF byte sequence set
166func (s Sequence) Reverse() Sequence {
167	p := make(Sequence, len(s))
168	for i, j := 0, len(s)-1; j > -1; i, j = i+1, j-1 {
169		p[i] = s[j]
170	}
171	return p
172}
173
174// Save persists the pattern.
175func (s Sequence) Save(ls *persist.LoadSaver) {
176	ls.SaveByte(sequenceLoader)
177	ls.SaveBytes(s)
178}
179
180func loadSequence(ls *persist.LoadSaver) Pattern {
181	return Sequence(ls.LoadBytes())
182}
183
// Choice is a slice of patterns, any of which can test successfully for the pattern to succeed.
// A successful test returns the match lengths of every alternative that matched.
// For the advance, the shortest non-zero advance reported by the alternatives is returned.
type Choice []Pattern
186
187func (c Choice) test(b []byte, f func(Pattern, []byte) ([]int, int)) ([]int, int) {
188	var r, res []int
189	var tl, fl, adv int // trueLen and falseLen
190	for _, pat := range c {
191		res, adv = f(pat, b)
192		if len(res) > 0 {
193			r = append(r, res...)
194			if tl == 0 || (adv > 0 && adv < tl) {
195				tl = adv
196			}
197		} else if fl == 0 || (adv > 0 && adv < fl) {
198			fl = adv
199		}
200	}
201	if len(r) > 0 {
202		return r, tl
203	}
204	return nil, fl
205}
206
207// Test bytes against the pattern.
208func (c Choice) Test(b []byte) ([]int, int) {
209	return c.test(b, Pattern.Test)
210}
211
212// Test bytes against the pattern in reverse.
213func (c Choice) TestR(b []byte) ([]int, int) {
214	return c.test(b, Pattern.TestR)
215}
216
217// Equals reports whether a pattern is identical to another pattern.
218func (c Choice) Equals(pat Pattern) bool {
219	c2, ok := pat.(Choice)
220	if ok {
221		if len(c) == len(c2) {
222			for _, p := range c {
223				ident := false
224				for _, p2 := range c2 {
225					if p.Equals(p2) {
226						ident = true
227					}
228				}
229				if !ident {
230					return false
231				}
232			}
233			return true
234		}
235	}
236	return false
237}
238
239// Length returns a minimum and maximum length for the pattern.
240func (c Choice) Length() (int, int) {
241	var min, max int
242	if len(c) > 0 {
243		min, max = c[0].Length()
244	}
245	for _, pat := range c {
246		min2, max2 := pat.Length()
247		if min2 < min {
248			min = min2
249		}
250		if max2 > max {
251			max = max2
252		}
253	}
254	return min, max
255}
256
257// NumSequences reports how many plain sequences are needed to represent this pattern.
258func (c Choice) NumSequences() int {
259	var s int
260	for _, pat := range c {
261		num := pat.NumSequences()
262		if num == 0 { // if any of the patterns can't be converted to sequences, don't return any
263			return 0
264		}
265		s += num
266	}
267	return s
268}
269
270// Sequences converts the pattern into a slice of plain sequences.
271func (c Choice) Sequences() []Sequence {
272	num := c.NumSequences()
273	seqs := make([]Sequence, 0, num)
274	for _, pat := range c {
275		seqs = append(seqs, pat.Sequences()...)
276	}
277	return seqs
278}
279
280func (c Choice) String() string {
281	s := "c["
282	for i, pat := range c {
283		s += pat.String()
284		if i < len(c)-1 {
285			s += ","
286		}
287	}
288	return s + "]"
289}
290
291// Save persists the pattern.
292func (c Choice) Save(ls *persist.LoadSaver) {
293	ls.SaveByte(choiceLoader)
294	ls.SaveSmallInt(len(c))
295	for _, pat := range c {
296		pat.Save(ls)
297	}
298}
299
300func loadChoice(ls *persist.LoadSaver) Pattern {
301	l := ls.LoadSmallInt()
302	choices := make(Choice, l)
303	for i := range choices {
304		choices[i] = Load(ls)
305	}
306	return choices
307}
308
// List is a slice of patterns, all of which must test true sequentially in order for the pattern to succeed.
// Each pattern is tested at the position where a match of the preceding pattern ended.
type List []Pattern
311
312// Test bytes against the pattern.
313func (l List) Test(b []byte) ([]int, int) {
314	if len(l) < 1 {
315		return nil, 0
316	}
317	totals := []int{0}
318	for _, pat := range l {
319		nts := make([]int, 0, len(totals))
320		for _, t := range totals {
321			les, _ := pat.Test(b[t:])
322			for _, le := range les {
323				nts = append(nts, t+le)
324			}
325		}
326		if len(nts) < 1 {
327			return nil, 1
328		}
329		totals = nts
330	}
331	return totals, 1
332}
333
334// Test bytes against the pattern in reverse.
335func (l List) TestR(b []byte) ([]int, int) {
336	if len(l) < 1 {
337		return nil, 0
338	}
339	totals := []int{0}
340	for i := len(l) - 1; i >= 0; i-- {
341		nts := make([]int, 0, len(totals))
342		for _, t := range totals {
343			les, _ := l[i].TestR(b[:len(b)-t])
344			for _, le := range les {
345				nts = append(nts, t+le)
346			}
347		}
348		if len(nts) < 1 {
349			return nil, 1
350		}
351		totals = nts
352	}
353	return totals, 1
354}
355
356// Equals reports whether a pattern is identical to another pattern.
357func (l List) Equals(pat Pattern) bool {
358	l2, ok := pat.(List)
359	if ok {
360		if len(l) == len(l2) {
361			for i, p := range l {
362				if !p.Equals(l2[i]) {
363					return false
364				}
365			}
366		}
367	}
368	return true
369}
370
371// Length returns a minimum and maximum length for the pattern.
372func (l List) Length() (int, int) {
373	var min, max int
374	for _, pat := range l {
375		pmin, pmax := pat.Length()
376		min += pmin
377		max += pmax
378	}
379	return min, max
380}
381
382// NumSequences reports how many plain sequences are needed to represent this pattern.
383func (l List) NumSequences() int {
384	s := 1
385	for _, pat := range l {
386		num := pat.NumSequences()
387		if num == 0 { // if any of the patterns can't be converted to sequences, don't return any
388			return 0
389		}
390		s *= num
391	}
392	return s
393}
394
395// Sequences converts the pattern into a slice of plain sequences.
396func (l List) Sequences() []Sequence {
397	total := l.NumSequences()
398	seqs := make([]Sequence, total)
399	for _, pat := range l {
400		num := pat.NumSequences()
401		times := total / num
402		idx := 0
403		for _, seq := range pat.Sequences() {
404			for i := 0; i < times; i++ {
405				seqs[idx] = append(seqs[idx], seq...)
406				idx++
407			}
408		}
409	}
410	return seqs
411}
412
413func (l List) String() string {
414	s := "l["
415	for i, pat := range l {
416		s += pat.String()
417		if i < len(l)-1 {
418			s += ","
419		}
420	}
421	return s + "]"
422}
423
424// Save persists the pattern.
425func (l List) Save(ls *persist.LoadSaver) {
426	ls.SaveByte(listLoader)
427	ls.SaveSmallInt(len(l))
428	for _, pat := range l {
429		pat.Save(ls)
430	}
431}
432
433func loadList(ls *persist.LoadSaver) Pattern {
434	le := ls.LoadSmallInt()
435	list := make(List, le)
436	for i := range list {
437		list[i] = Load(ls)
438	}
439	return list
440}
441
// Not contains a pattern and reports the opposite of that pattern's result when testing.
// A positive Not match has the minimum length of the contained pattern.
type Not struct{ Pattern }
444
445// Test bytes against the pattern.
446func (n Not) Test(b []byte) ([]int, int) {
447	min, _ := n.Pattern.Length()
448	if len(b) < min {
449		return nil, 0
450	}
451	ok, _ := n.Pattern.Test(b)
452	if len(ok) < 1 {
453		return []int{min}, 1
454	}
455	return nil, 1
456}
457
458// Test bytes against the pattern in reverse.
459func (n Not) TestR(b []byte) ([]int, int) {
460	min, _ := n.Pattern.Length()
461	if len(b) < min {
462		return nil, 0
463	}
464	ok, _ := n.Pattern.TestR(b)
465	if len(ok) < 1 {
466		return []int{min}, 1
467	}
468	return nil, 1
469}
470
471// Equals reports whether a pattern is identical to another pattern.
472func (n Not) Equals(pat Pattern) bool {
473	n2, ok := pat.(Not)
474	if ok {
475		return n.Pattern.Equals(n2.Pattern)
476	}
477	return false
478}
479
480// Length returns a minimum and maximum length for the pattern.
481func (n Not) Length() (int, int) {
482	min, _ := n.Pattern.Length()
483	return min, min
484}
485
486// NumSequences reports how many plain sequences are needed to represent this pattern.
487func (n Not) NumSequences() int {
488	_, max := n.Pattern.Length()
489	if max > 1 {
490		return 0
491	}
492	num := n.Pattern.NumSequences()
493	if num == 0 {
494		return 0
495	}
496	return 256 - num
497}
498
499// Sequences converts the pattern into a slice of plain sequences.
500func (n Not) Sequences() []Sequence {
501	num := n.NumSequences()
502	if num < 1 {
503		return nil
504	}
505	seqs := make([]Sequence, 0, num)
506	pseqs := n.Pattern.Sequences()
507	allBytes := make([]Sequence, 256)
508	for i := 0; i < 256; i++ {
509		allBytes[i] = Sequence{byte(i)}
510	}
511	for _, v := range allBytes {
512		eq := false
513		for _, w := range pseqs {
514			if v.Equals(w) {
515				eq = true
516				break
517			}
518		}
519		if eq {
520			continue
521		}
522		seqs = append(seqs, v)
523	}
524	return seqs
525}
526
527func (n Not) String() string {
528	return "not[" + n.Pattern.String() + "]"
529}
530
531// Save persists the pattern.
532func (n Not) Save(ls *persist.LoadSaver) {
533	ls.SaveByte(notLoader)
534	n.Pattern.Save(ls)
535}
536
537func loadNot(ls *persist.LoadSaver) Pattern {
538	return Not{Load(ls)}
539}
540
// Mask is a single-byte pattern. A byte matches when every bit that is set in
// the mask is also set in the byte.
type Mask byte
542
543func (m Mask) Test(b []byte) ([]int, int) {
544	if len(b) == 0 {
545		return nil, 0
546	}
547	if byte(m)&b[0] == byte(m) {
548		return []int{1}, 1
549	}
550	return nil, 1
551}
552
553func (m Mask) TestR(b []byte) ([]int, int) {
554	if len(b) == 0 {
555		return nil, 0
556	}
557	if byte(m)&b[len(b)-1] == byte(m) {
558		return []int{1}, 1
559	}
560	return nil, 1
561}
562
563func (m Mask) Equals(pat Pattern) bool {
564	msk, ok := pat.(Mask)
565	if ok {
566		if m == msk {
567			return true
568		}
569	}
570	return false
571}
572
573func (m Mask) Length() (int, int) {
574	return 1, 1
575}
576
// countBits returns 256 divided by 2 to the power of the number of bits set in b.
// This is the number of byte values that have all of b's bits set (and,
// equally, the number that have none of them set).
func countBits(b byte) int {
	var set uint
	// Each step clears the lowest set bit, so the loop runs once per set bit.
	for ; b != 0; b &= b - 1 {
		set++
	}
	return 256 >> set
}
585
// allBytes returns a new slice holding every byte value from 0 to 255 in ascending order.
func allBytes() []byte {
	all := make([]byte, 0, 256)
	for v := 0; v < 256; v++ {
		all = append(all, byte(v))
	}
	return all
}
593
594func (m Mask) NumSequences() int {
595	return countBits(byte(m))
596}
597
598func (m Mask) Sequences() []Sequence {
599	seqs := make([]Sequence, 0, m.NumSequences())
600	for _, b := range allBytes() {
601		if byte(m)&b == byte(m) {
602			seqs = append(seqs, Sequence{b})
603		}
604	}
605	return seqs
606}
607
608func (m Mask) String() string {
609	return fmt.Sprintf("m %#x", byte(m))
610}
611
612func (m Mask) Save(ls *persist.LoadSaver) {
613	ls.SaveByte(maskLoader)
614	ls.SaveByte(byte(m))
615}
616
617func loadMask(ls *persist.LoadSaver) Pattern {
618	return Mask(ls.LoadByte())
619}
620
// AnyMask is a single-byte pattern. A byte matches when it has at least one
// bit in common with the mask.
type AnyMask byte
622
623func (am AnyMask) Test(b []byte) ([]int, int) {
624	if len(b) == 0 {
625		return nil, 0
626	}
627	if byte(am)&b[0] != 0 {
628		return []int{1}, 1
629	}
630	return nil, 1
631}
632
633func (am AnyMask) TestR(b []byte) ([]int, int) {
634	if len(b) == 0 {
635		return nil, 0
636	}
637	if byte(am)&b[len(b)-1] != 0 {
638		return []int{1}, 1
639	}
640	return nil, 1
641}
642
643func (am AnyMask) Equals(pat Pattern) bool {
644	amsk, ok := pat.(AnyMask)
645	if ok {
646		if am == amsk {
647			return true
648		}
649	}
650	return false
651}
652
653func (am AnyMask) Length() (int, int) {
654	return 1, 1
655}
656
657func (am AnyMask) NumSequences() int {
658	return 256 - countBits(byte(am))
659}
660
661func (am AnyMask) Sequences() []Sequence {
662	seqs := make([]Sequence, 0, am.NumSequences())
663	for _, b := range allBytes() {
664		if byte(am)&b != 0 {
665			seqs = append(seqs, Sequence{b})
666		}
667	}
668	return seqs
669}
670
671func (am AnyMask) String() string {
672	return fmt.Sprintf("am %#x", byte(am))
673}
674
675func (am AnyMask) Save(ls *persist.LoadSaver) {
676	ls.SaveByte(anyMaskLoader)
677	ls.SaveByte(byte(am))
678}
679
680func loadAnyMask(ls *persist.LoadSaver) Pattern {
681	return AnyMask(ls.LoadByte())
682}
683