1package ldap
2
3import (
4	"bytes"
5	hexpac "encoding/hex"
6	"errors"
7	"fmt"
8	"strings"
9	"unicode/utf8"
10
11	"gopkg.in/asn1-ber.v1"
12)
13
// Filter choice tags. Each constant is the context-specific tag placed on the
// BER packet for the corresponding kind of search filter.
const (
	FilterAnd             = iota // 0
	FilterOr                     // 1
	FilterNot                    // 2
	FilterEqualityMatch          // 3
	FilterSubstrings             // 4
	FilterGreaterOrEqual         // 5
	FilterLessOrEqual            // 6
	FilterPresent                // 7
	FilterApproxMatch            // 8
	FilterExtensibleMatch        // 9
)

// FilterMap maps every Filter choice tag to the human readable description
// attached to packets carrying that tag.
var FilterMap = map[uint64]string{
	FilterAnd:             "And",
	FilterOr:              "Or",
	FilterNot:             "Not",
	FilterEqualityMatch:   "Equality Match",
	FilterSubstrings:      "Substrings",
	FilterGreaterOrEqual:  "Greater Or Equal",
	FilterLessOrEqual:     "Less Or Equal",
	FilterPresent:         "Present",
	FilterApproxMatch:     "Approx Match",
	FilterExtensibleMatch: "Extensible Match",
}
41
// Substring filter component tags. They mark whether a piece of a substring
// assertion is the leading part, a middle part, or the trailing part.
const (
	FilterSubstringsInitial = iota // 0
	FilterSubstringsAny            // 1
	FilterSubstringsFinal          // 2
)

// FilterSubstringsMap maps every substring component tag to its human
// readable description.
var FilterSubstringsMap = map[uint64]string{
	FilterSubstringsInitial: "Substrings Initial",
	FilterSubstringsAny:     "Substrings Any",
	FilterSubstringsFinal:   "Substrings Final",
}
55
// MatchingRuleAssertion field tags. They identify the children of an
// extensible match filter packet; numbering starts at 1.
const (
	MatchingRuleAssertionMatchingRule = iota + 1 // 1
	MatchingRuleAssertionType                    // 2
	MatchingRuleAssertionMatchValue              // 3
	MatchingRuleAssertionDNAttributes            // 4
)

// MatchingRuleAssertionMap maps every MatchingRuleAssertion field tag to its
// human readable description.
var MatchingRuleAssertionMap = map[uint64]string{
	MatchingRuleAssertionMatchingRule: "Matching Rule Assertion Matching Rule",
	MatchingRuleAssertionType:         "Matching Rule Assertion Type",
	MatchingRuleAssertionMatchValue:   "Matching Rule Assertion Match Value",
	MatchingRuleAssertionDNAttributes: "Matching Rule Assertion DN Attributes",
}
71
72// CompileFilter converts a string representation of a filter into a BER-encoded packet
73func CompileFilter(filter string) (*ber.Packet, error) {
74	if len(filter) == 0 || filter[0] != '(' {
75		return nil, NewError(ErrorFilterCompile, errors.New("ldap: filter does not start with an '('"))
76	}
77	packet, pos, err := compileFilter(filter, 1)
78	if err != nil {
79		return nil, err
80	}
81	switch {
82	case pos > len(filter):
83		return nil, NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter"))
84	case pos < len(filter):
85		return nil, NewError(ErrorFilterCompile, errors.New("ldap: finished compiling filter with extra at end: "+fmt.Sprint(filter[pos:])))
86	}
87	return packet, nil
88}
89
90// DecompileFilter converts a packet representation of a filter into a string representation
91func DecompileFilter(packet *ber.Packet) (ret string, err error) {
92	defer func() {
93		if r := recover(); r != nil {
94			err = NewError(ErrorFilterDecompile, errors.New("ldap: error decompiling filter"))
95		}
96	}()
97	ret = "("
98	err = nil
99	childStr := ""
100
101	switch packet.Tag {
102	case FilterAnd:
103		ret += "&"
104		for _, child := range packet.Children {
105			childStr, err = DecompileFilter(child)
106			if err != nil {
107				return
108			}
109			ret += childStr
110		}
111	case FilterOr:
112		ret += "|"
113		for _, child := range packet.Children {
114			childStr, err = DecompileFilter(child)
115			if err != nil {
116				return
117			}
118			ret += childStr
119		}
120	case FilterNot:
121		ret += "!"
122		childStr, err = DecompileFilter(packet.Children[0])
123		if err != nil {
124			return
125		}
126		ret += childStr
127
128	case FilterSubstrings:
129		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
130		ret += "="
131		for i, child := range packet.Children[1].Children {
132			if i == 0 && child.Tag != FilterSubstringsInitial {
133				ret += "*"
134			}
135			ret += EscapeFilter(ber.DecodeString(child.Data.Bytes()))
136			if child.Tag != FilterSubstringsFinal {
137				ret += "*"
138			}
139		}
140	case FilterEqualityMatch:
141		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
142		ret += "="
143		ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
144	case FilterGreaterOrEqual:
145		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
146		ret += ">="
147		ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
148	case FilterLessOrEqual:
149		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
150		ret += "<="
151		ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
152	case FilterPresent:
153		ret += ber.DecodeString(packet.Data.Bytes())
154		ret += "=*"
155	case FilterApproxMatch:
156		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
157		ret += "~="
158		ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
159	case FilterExtensibleMatch:
160		attr := ""
161		dnAttributes := false
162		matchingRule := ""
163		value := ""
164
165		for _, child := range packet.Children {
166			switch child.Tag {
167			case MatchingRuleAssertionMatchingRule:
168				matchingRule = ber.DecodeString(child.Data.Bytes())
169			case MatchingRuleAssertionType:
170				attr = ber.DecodeString(child.Data.Bytes())
171			case MatchingRuleAssertionMatchValue:
172				value = ber.DecodeString(child.Data.Bytes())
173			case MatchingRuleAssertionDNAttributes:
174				dnAttributes = child.Value.(bool)
175			}
176		}
177
178		if len(attr) > 0 {
179			ret += attr
180		}
181		if dnAttributes {
182			ret += ":dn"
183		}
184		if len(matchingRule) > 0 {
185			ret += ":"
186			ret += matchingRule
187		}
188		ret += ":="
189		ret += EscapeFilter(value)
190	}
191
192	ret += ")"
193	return
194}
195
196func compileFilterSet(filter string, pos int, parent *ber.Packet) (int, error) {
197	for pos < len(filter) && filter[pos] == '(' {
198		child, newPos, err := compileFilter(filter, pos+1)
199		if err != nil {
200			return pos, err
201		}
202		pos = newPos
203		parent.AppendChild(child)
204	}
205	if pos == len(filter) {
206		return pos, NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter"))
207	}
208
209	return pos + 1, nil
210}
211
212func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
213	var (
214		packet *ber.Packet
215		err    error
216	)
217
218	defer func() {
219		if r := recover(); r != nil {
220			err = NewError(ErrorFilterCompile, errors.New("ldap: error compiling filter"))
221		}
222	}()
223	newPos := pos
224
225	currentRune, currentWidth := utf8.DecodeRuneInString(filter[newPos:])
226
227	switch currentRune {
228	case utf8.RuneError:
229		return nil, 0, NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", newPos))
230	case '(':
231		packet, newPos, err = compileFilter(filter, pos+currentWidth)
232		newPos++
233		return packet, newPos, err
234	case '&':
235		packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, FilterMap[FilterAnd])
236		newPos, err = compileFilterSet(filter, pos+currentWidth, packet)
237		return packet, newPos, err
238	case '|':
239		packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, FilterMap[FilterOr])
240		newPos, err = compileFilterSet(filter, pos+currentWidth, packet)
241		return packet, newPos, err
242	case '!':
243		packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, FilterMap[FilterNot])
244		var child *ber.Packet
245		child, newPos, err = compileFilter(filter, pos+currentWidth)
246		packet.AppendChild(child)
247		return packet, newPos, err
248	default:
249		const (
250			stateReadingAttr                   = 0
251			stateReadingExtensibleMatchingRule = 1
252			stateReadingCondition              = 2
253		)
254
255		state := stateReadingAttr
256
257		attribute := ""
258		extensibleDNAttributes := false
259		extensibleMatchingRule := ""
260		condition := ""
261
262		for newPos < len(filter) {
263			remainingFilter := filter[newPos:]
264			currentRune, currentWidth = utf8.DecodeRuneInString(remainingFilter)
265			if currentRune == ')' {
266				break
267			}
268			if currentRune == utf8.RuneError {
269				return packet, newPos, NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", newPos))
270			}
271
272			switch state {
273			case stateReadingAttr:
274				switch {
275				// Extensible rule, with only DN-matching
276				case currentRune == ':' && strings.HasPrefix(remainingFilter, ":dn:="):
277					packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch])
278					extensibleDNAttributes = true
279					state = stateReadingCondition
280					newPos += 5
281
282				// Extensible rule, with DN-matching and a matching OID
283				case currentRune == ':' && strings.HasPrefix(remainingFilter, ":dn:"):
284					packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch])
285					extensibleDNAttributes = true
286					state = stateReadingExtensibleMatchingRule
287					newPos += 4
288
289				// Extensible rule, with attr only
290				case currentRune == ':' && strings.HasPrefix(remainingFilter, ":="):
291					packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch])
292					state = stateReadingCondition
293					newPos += 2
294
295				// Extensible rule, with no DN attribute matching
296				case currentRune == ':':
297					packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch])
298					state = stateReadingExtensibleMatchingRule
299					newPos++
300
301				// Equality condition
302				case currentRune == '=':
303					packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[FilterEqualityMatch])
304					state = stateReadingCondition
305					newPos++
306
307				// Greater-than or equal
308				case currentRune == '>' && strings.HasPrefix(remainingFilter, ">="):
309					packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[FilterGreaterOrEqual])
310					state = stateReadingCondition
311					newPos += 2
312
313				// Less-than or equal
314				case currentRune == '<' && strings.HasPrefix(remainingFilter, "<="):
315					packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[FilterLessOrEqual])
316					state = stateReadingCondition
317					newPos += 2
318
319				// Approx
320				case currentRune == '~' && strings.HasPrefix(remainingFilter, "~="):
321					packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[FilterApproxMatch])
322					state = stateReadingCondition
323					newPos += 2
324
325				// Still reading the attribute name
326				default:
327					attribute += fmt.Sprintf("%c", currentRune)
328					newPos += currentWidth
329				}
330
331			case stateReadingExtensibleMatchingRule:
332				switch {
333
334				// Matching rule OID is done
335				case currentRune == ':' && strings.HasPrefix(remainingFilter, ":="):
336					state = stateReadingCondition
337					newPos += 2
338
339				// Still reading the matching rule oid
340				default:
341					extensibleMatchingRule += fmt.Sprintf("%c", currentRune)
342					newPos += currentWidth
343				}
344
345			case stateReadingCondition:
346				// append to the condition
347				condition += fmt.Sprintf("%c", currentRune)
348				newPos += currentWidth
349			}
350		}
351
352		if newPos == len(filter) {
353			err = NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter"))
354			return packet, newPos, err
355		}
356		if packet == nil {
357			err = NewError(ErrorFilterCompile, errors.New("ldap: error parsing filter"))
358			return packet, newPos, err
359		}
360
361		switch {
362		case packet.Tag == FilterExtensibleMatch:
363			// MatchingRuleAssertion ::= SEQUENCE {
364			//         matchingRule    [1] MatchingRuleID OPTIONAL,
365			//         type            [2] AttributeDescription OPTIONAL,
366			//         matchValue      [3] AssertionValue,
367			//         dnAttributes    [4] BOOLEAN DEFAULT FALSE
368			// }
369
370			// Include the matching rule oid, if specified
371			if len(extensibleMatchingRule) > 0 {
372				packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionMatchingRule, extensibleMatchingRule, MatchingRuleAssertionMap[MatchingRuleAssertionMatchingRule]))
373			}
374
375			// Include the attribute, if specified
376			if len(attribute) > 0 {
377				packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionType, attribute, MatchingRuleAssertionMap[MatchingRuleAssertionType]))
378			}
379
380			// Add the value (only required child)
381			encodedString, encodeErr := escapedStringToEncodedBytes(condition)
382			if encodeErr != nil {
383				return packet, newPos, encodeErr
384			}
385			packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionMatchValue, encodedString, MatchingRuleAssertionMap[MatchingRuleAssertionMatchValue]))
386
387			// Defaults to false, so only include in the sequence if true
388			if extensibleDNAttributes {
389				packet.AppendChild(ber.NewBoolean(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionDNAttributes, extensibleDNAttributes, MatchingRuleAssertionMap[MatchingRuleAssertionDNAttributes]))
390			}
391
392		case packet.Tag == FilterEqualityMatch && condition == "*":
393			packet = ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterPresent, attribute, FilterMap[FilterPresent])
394		case packet.Tag == FilterEqualityMatch && strings.Contains(condition, "*"):
395			packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
396			packet.Tag = FilterSubstrings
397			packet.Description = FilterMap[uint64(packet.Tag)]
398			seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
399			parts := strings.Split(condition, "*")
400			for i, part := range parts {
401				if part == "" {
402					continue
403				}
404				var tag ber.Tag
405				switch i {
406				case 0:
407					tag = FilterSubstringsInitial
408				case len(parts) - 1:
409					tag = FilterSubstringsFinal
410				default:
411					tag = FilterSubstringsAny
412				}
413				encodedString, encodeErr := escapedStringToEncodedBytes(part)
414				if encodeErr != nil {
415					return packet, newPos, encodeErr
416				}
417				seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, tag, encodedString, FilterSubstringsMap[uint64(tag)]))
418			}
419			packet.AppendChild(seq)
420		default:
421			encodedString, encodeErr := escapedStringToEncodedBytes(condition)
422			if encodeErr != nil {
423				return packet, newPos, encodeErr
424			}
425			packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
426			packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, encodedString, "Condition"))
427		}
428
429		newPos += currentWidth
430		return packet, newPos, err
431	}
432}
433
434// Convert from "ABC\xx\xx\xx" form to literal bytes for transport
435func escapedStringToEncodedBytes(escapedString string) (string, error) {
436	var buffer bytes.Buffer
437	i := 0
438	for i < len(escapedString) {
439		currentRune, currentWidth := utf8.DecodeRuneInString(escapedString[i:])
440		if currentRune == utf8.RuneError {
441			return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", i))
442		}
443
444		// Check for escaped hex characters and convert them to their literal value for transport.
445		if currentRune == '\\' {
446			// http://tools.ietf.org/search/rfc4515
447			// \ (%x5C) is not a valid character unless it is followed by two HEX characters due to not
448			// being a member of UTF1SUBSET.
449			if i+2 > len(escapedString) {
450				return "", NewError(ErrorFilterCompile, errors.New("ldap: missing characters for escape in filter"))
451			}
452			escByte, decodeErr := hexpac.DecodeString(escapedString[i+1 : i+3])
453			if decodeErr != nil {
454				return "", NewError(ErrorFilterCompile, errors.New("ldap: invalid characters for escape in filter"))
455			}
456			buffer.WriteByte(escByte[0])
457			i += 2 // +1 from end of loop, so 3 total for \xx.
458		} else {
459			buffer.WriteRune(currentRune)
460		}
461
462		i += currentWidth
463	}
464	return buffer.String(), nil
465}
466