1package ber
2
3import (
4	"bytes"
5	"errors"
6	"fmt"
7	"io"
8	"math"
9	"os"
10	"reflect"
11	"time"
12	"unicode/utf8"
13)
14
// MaxPacketLengthBytes bounds the definite contents length of a primitive
// element accepted by ReadPacket, DecodePacket and DecodePacketErr; longer
// elements are rejected with an error. It defaults to math.MaxInt32. Set it
// to 0 (or any non-positive value) to disable the check.
var MaxPacketLengthBytes = int64(math.MaxInt32)
18
// Packet is a single BER element, either decoded by readPacket or built with
// Encode and the New* constructors.
type Packet struct {
	// Identifier holds the class, primitive/constructed flag and tag number.
	Identifier

	// Value is the Go value associated with the element, or nil. readPacket
	// fills it for some universal primitives (e.g. bool for BOOLEAN, int64 for
	// INTEGER/ENUMERATED, string for OCTET STRING and several string types);
	// the New* constructors store the value they were given.
	Value interface{}

	// ByteValue is the raw contents of a universal-class primitive as read by
	// readPacket. Encode leaves it nil.
	ByteValue []byte

	// Data holds the contents octets. For constructed packets built through
	// AppendChild it is the concatenation of the children's encodings.
	Data *bytes.Buffer

	// Children are the nested elements of a constructed packet.
	Children []*Packet

	// Description is a free-form label printed by WritePacket/PrintPacket.
	Description string
}

// Identifier is the decoded identifier of an element: its class, whether it
// is primitive or constructed, and its tag number.
type Identifier struct {
	ClassType Class
	TagType   Type
	Tag       Tag
}
33
// Tag is an identifier's tag number. The Tag* constants below are the
// universal-class tag numbers; tagMap gives their display names.
type Tag uint64

const (
	TagEOC              Tag = 0x00
	TagBoolean          Tag = 0x01
	TagInteger          Tag = 0x02
	TagBitString        Tag = 0x03
	TagOctetString      Tag = 0x04
	TagNULL             Tag = 0x05
	TagObjectIdentifier Tag = 0x06
	TagObjectDescriptor Tag = 0x07
	TagExternal         Tag = 0x08
	TagRealFloat        Tag = 0x09
	TagEnumerated       Tag = 0x0a
	TagEmbeddedPDV      Tag = 0x0b
	TagUTF8String       Tag = 0x0c
	TagRelativeOID      Tag = 0x0d
	TagSequence         Tag = 0x10
	TagSet              Tag = 0x11
	TagNumericString    Tag = 0x12
	TagPrintableString  Tag = 0x13
	TagT61String        Tag = 0x14
	TagVideotexString   Tag = 0x15
	TagIA5String        Tag = 0x16
	TagUTCTime          Tag = 0x17
	TagGeneralizedTime  Tag = 0x18
	TagGraphicString    Tag = 0x19
	TagVisibleString    Tag = 0x1a
	TagGeneralString    Tag = 0x1b
	TagUniversalString  Tag = 0x1c
	TagCharacterString  Tag = 0x1d
	TagBMPString        Tag = 0x1e
	// TagBitmask selects the tag-number bits of the first identifier octet.
	// It has the same value as HighTag.
	TagBitmask Tag = 0x1f // xxx11111b

	// HighTag indicates the start of a high-tag byte sequence
	HighTag Tag = 0x1f // xxx11111b
	// HighTagContinueBitmask indicates the high-tag byte sequence should continue
	HighTagContinueBitmask Tag = 0x80 // 10000000b
	// HighTagValueBitmask obtains the tag value from a high-tag byte sequence byte
	HighTagValueBitmask Tag = 0x7f // 01111111b
)
75
// Masks and sentinels for the length octets that follow an identifier.
const (
	// LengthLongFormBitmask is the mask to apply to the length byte to see if a long-form byte sequence is used
	LengthLongFormBitmask = 1 << 7 // 0x80

	// LengthValueBitmask is the mask to apply to the length byte to get the number of bytes in the long-form byte sequence
	LengthValueBitmask = 1<<7 - 1 // 0x7f

	// LengthIndefinite is returned from readLength to indicate an indefinite length
	LengthIndefinite = -1
)
85
// tagMap maps universal-class tag numbers to the display names printed by
// printPacket. Tags that are not listed here (such as TagBitmask/HighTag) have
// no name in this map.
var tagMap = map[Tag]string{
	TagEOC:              "EOC (End-of-Content)",
	TagBoolean:          "Boolean",
	TagInteger:          "Integer",
	TagBitString:        "Bit String",
	TagOctetString:      "Octet String",
	TagNULL:             "NULL",
	TagObjectIdentifier: "Object Identifier",
	TagObjectDescriptor: "Object Descriptor",
	TagExternal:         "External",
	TagRealFloat:        "Real (float)",
	TagEnumerated:       "Enumerated",
	TagEmbeddedPDV:      "Embedded PDV",
	TagUTF8String:       "UTF8 String",
	TagRelativeOID:      "Relative-OID",
	TagSequence:         "Sequence and Sequence of",
	TagSet:              "Set and Set OF",
	TagNumericString:    "Numeric String",
	TagPrintableString:  "Printable String",
	TagT61String:        "T61 String",
	TagVideotexString:   "Videotex String",
	TagIA5String:        "IA5 String",
	TagUTCTime:          "UTC Time",
	TagGeneralizedTime:  "Generalized Time",
	TagGraphicString:    "Graphic String",
	TagVisibleString:    "Visible String",
	TagGeneralString:    "General String",
	TagUniversalString:  "Universal String",
	TagCharacterString:  "Character String",
	TagBMPString:        "BMP String",
}
117
// Class is the class field of an identifier octet, stored in its top two
// bits (ClassBitmask selects them).
type Class uint8

const (
	ClassUniversal   Class = 0 << 6 // 00xxxxxxb
	ClassApplication Class = 1 << 6 // 01xxxxxxb
	ClassContext     Class = 2 << 6 // 10xxxxxxb
	ClassPrivate     Class = 3 << 6 // 11xxxxxxb
	ClassBitmask     Class = 3 << 6 // 11xxxxxxb
)

// ClassMap maps each Class to the display name printed by printPacket.
var ClassMap = map[Class]string{
	ClassPrivate:     "Private",
	ClassContext:     "Context",
	ClassApplication: "Application",
	ClassUniversal:   "Universal",
}
134
// Type records whether an element is primitive or constructed; it is the
// 0x20 bit of the identifier octet (TypeBitmask selects it).
type Type uint8

const (
	TypePrimitive   Type = 0      // xx0xxxxxb
	TypeConstructed Type = 1 << 5 // xx1xxxxxb
	TypeBitmask     Type = 1 << 5 // xx1xxxxxb
)

// TypeMap maps each Type to the display name printed by printPacket.
var TypeMap = map[Type]string{
	TypeConstructed: "Constructed",
	TypePrimitive:   "Primitive",
}
147
// Debug is an exported package-level flag that defaults to false.
// NOTE(review): it is not read anywhere in this section; presumably it gates
// debug output elsewhere in the package — confirm.
var Debug bool
149
// PrintBytes writes a hex dump of buf to out, 30 bytes per row. Each row of
// two-digit hex values is followed by a row of the bytes' 1-based positions
// (modulo 100) and a blank line; every line is prefixed with indent.
//
// An empty buf still produces a single (empty) row pair. A buf whose length
// is an exact multiple of 30 produces exactly len(buf)/30 row pairs, with no
// trailing empty pair. Write errors are ignored.
func PrintBytes(out io.Writer, buf []byte, indent string) {
	const bytesPerRow = 30

	// Ceiling division, so a full final row does not add an empty extra row.
	rows := (len(buf) + bytesPerRow - 1) / bytesPerRow
	if rows == 0 {
		rows = 1 // keep the single empty row pair for empty input
	}

	var data, nums bytes.Buffer
	for row := 0; row < rows; row++ {
		start := row * bytesPerRow
		end := start + bytesPerRow
		if end > len(buf) {
			end = len(buf)
		}

		data.Reset()
		nums.Reset()
		for i := start; i < end; i++ {
			fmt.Fprintf(&data, "%02x ", buf[i])
			fmt.Fprintf(&nums, "%02d ", (i+1)%100)
		}

		_, _ = io.WriteString(out, indent+data.String()+"\n")
		_, _ = io.WriteString(out, indent+nums.String()+"\n\n")
	}
}
164
165func WritePacket(out io.Writer, p *Packet) {
166	printPacket(out, p, 0, false)
167}
168
169func PrintPacket(p *Packet) {
170	printPacket(os.Stdout, p, 0, false)
171}
172
173func printPacket(out io.Writer, p *Packet, indent int, printBytes bool) {
174	indentStr := ""
175
176	for len(indentStr) != indent {
177		indentStr += " "
178	}
179
180	classStr := ClassMap[p.ClassType]
181
182	tagTypeStr := TypeMap[p.TagType]
183
184	tagStr := fmt.Sprintf("0x%02X", p.Tag)
185
186	if p.ClassType == ClassUniversal {
187		tagStr = tagMap[p.Tag]
188	}
189
190	value := fmt.Sprint(p.Value)
191	description := ""
192
193	if p.Description != "" {
194		description = p.Description + ": "
195	}
196
197	_, _ = fmt.Fprintf(out, "%s%s(%s, %s, %s) Len=%d %q\n", indentStr, description, classStr, tagTypeStr, tagStr, p.Data.Len(), value)
198
199	if printBytes {
200		PrintBytes(out, p.Bytes(), indentStr)
201	}
202
203	for _, child := range p.Children {
204		printPacket(out, child, indent+1, printBytes)
205	}
206}
207
208// ReadPacket reads a single Packet from the reader.
209func ReadPacket(reader io.Reader) (*Packet, error) {
210	p, _, err := readPacket(reader)
211	if err != nil {
212		return nil, err
213	}
214	return p, nil
215}
216
// DecodeString returns the contents as a Go string, copying the bytes
// unchanged and without validating any particular character encoding.
func DecodeString(data []byte) (s string) {
	s = string(data)
	return s
}
220
// ParseInt64 decodes bytes as a big-endian two's-complement integer (the
// contents of a BER INTEGER). Inputs longer than 8 bytes cannot fit in an
// int64 and return an error. An empty input decodes to 0 without error.
func ParseInt64(bytes []byte) (ret int64, err error) {
	size := len(bytes)
	if size > 8 {
		// Nine or more octets would overflow an int64.
		return 0, fmt.Errorf("integer too large")
	}

	for _, octet := range bytes {
		ret = ret<<8 | int64(octet)
	}

	// Push the leading octet's top bit into bit 63, then shift back
	// arithmetically so the value is sign-extended.
	shift := uint(64 - 8*size)
	return ret << shift >> shift, nil
}
237
// encodeInteger returns the minimal big-endian two's-complement encoding of i
// (always at least one byte), as used for BER INTEGER contents.
func encodeInteger(i int64) []byte {
	size := int64Length(i)
	encoded := make([]byte, 0, size)
	for shift := (size - 1) * 8; shift >= 0; shift -= 8 {
		encoded = append(encoded, byte(i>>uint(shift)))
	}
	return encoded
}

// int64Length reports how many bytes the minimal two's-complement encoding
// of i needs: one byte for -128..127, plus one for each further factor of 256.
func int64Length(i int64) (numBytes int) {
	numBytes = 1
	for ; i > 127 || i < -128; i >>= 8 {
		numBytes++
	}
	return numBytes
}
266
267// DecodePacket decodes the given bytes into a single Packet
268// If a decode error is encountered, nil is returned.
269func DecodePacket(data []byte) *Packet {
270	p, _, _ := readPacket(bytes.NewBuffer(data))
271
272	return p
273}
274
275// DecodePacketErr decodes the given bytes into a single Packet
276// If a decode error is encountered, nil is returned.
277func DecodePacketErr(data []byte) (*Packet, error) {
278	p, _, err := readPacket(bytes.NewBuffer(data))
279	if err != nil {
280		return nil, err
281	}
282	return p, nil
283}
284
285// readPacket reads a single Packet from the reader, returning the number of bytes read.
286func readPacket(reader io.Reader) (*Packet, int, error) {
287	identifier, length, read, err := readHeader(reader)
288	if err != nil {
289		return nil, read, err
290	}
291
292	p := &Packet{
293		Identifier: identifier,
294	}
295
296	p.Data = new(bytes.Buffer)
297	p.Children = make([]*Packet, 0, 2)
298	p.Value = nil
299
300	if p.TagType == TypeConstructed {
301		// TODO: if universal, ensure tag type is allowed to be constructed
302
303		// Track how much content we've read
304		contentRead := 0
305		for {
306			if length != LengthIndefinite {
307				// End if we've read what we've been told to
308				if contentRead == length {
309					break
310				}
311				// Detect if a packet boundary didn't fall on the expected length
312				if contentRead > length {
313					return nil, read, fmt.Errorf("expected to read %d bytes, read %d", length, contentRead)
314				}
315			}
316
317			// Read the next packet
318			child, r, err := readPacket(reader)
319			if err != nil {
320				return nil, read, err
321			}
322			contentRead += r
323			read += r
324
325			// Test is this is the EOC marker for our packet
326			if isEOCPacket(child) {
327				if length == LengthIndefinite {
328					break
329				}
330				return nil, read, errors.New("eoc child not allowed with definite length")
331			}
332
333			// Append and continue
334			p.AppendChild(child)
335		}
336		return p, read, nil
337	}
338
339	if length == LengthIndefinite {
340		return nil, read, errors.New("indefinite length used with primitive type")
341	}
342
343	// Read definite-length content
344	if MaxPacketLengthBytes > 0 && int64(length) > MaxPacketLengthBytes {
345		return nil, read, fmt.Errorf("length %d greater than maximum %d", length, MaxPacketLengthBytes)
346	}
347	content := make([]byte, length)
348	if length > 0 {
349		_, err := io.ReadFull(reader, content)
350		if err != nil {
351			if err == io.EOF {
352				return nil, read, io.ErrUnexpectedEOF
353			}
354			return nil, read, err
355		}
356		read += length
357	}
358
359	if p.ClassType == ClassUniversal {
360		p.Data.Write(content)
361		p.ByteValue = content
362
363		switch p.Tag {
364		case TagEOC:
365		case TagBoolean:
366			val, _ := ParseInt64(content)
367
368			p.Value = val != 0
369		case TagInteger:
370			p.Value, _ = ParseInt64(content)
371		case TagBitString:
372		case TagOctetString:
373			// the actual string encoding is not known here
374			// (e.g. for LDAP content is already an UTF8-encoded
375			// string). Return the data without further processing
376			p.Value = DecodeString(content)
377		case TagNULL:
378		case TagObjectIdentifier:
379		case TagObjectDescriptor:
380		case TagExternal:
381		case TagRealFloat:
382			p.Value, err = ParseReal(content)
383		case TagEnumerated:
384			p.Value, _ = ParseInt64(content)
385		case TagEmbeddedPDV:
386		case TagUTF8String:
387			val := DecodeString(content)
388			if !utf8.Valid([]byte(val)) {
389				err = errors.New("invalid UTF-8 string")
390			} else {
391				p.Value = val
392			}
393		case TagRelativeOID:
394		case TagSequence:
395		case TagSet:
396		case TagNumericString:
397		case TagPrintableString:
398			val := DecodeString(content)
399			if err = isPrintableString(val); err == nil {
400				p.Value = val
401			}
402		case TagT61String:
403		case TagVideotexString:
404		case TagIA5String:
405			val := DecodeString(content)
406			for i, c := range val {
407				if c >= 0x7F {
408					err = fmt.Errorf("invalid character for IA5String at pos %d: %c", i, c)
409					break
410				}
411			}
412			if err == nil {
413				p.Value = val
414			}
415		case TagUTCTime:
416		case TagGeneralizedTime:
417			p.Value, err = ParseGeneralizedTime(content)
418		case TagGraphicString:
419		case TagVisibleString:
420		case TagGeneralString:
421		case TagUniversalString:
422		case TagCharacterString:
423		case TagBMPString:
424		}
425	} else {
426		p.Data.Write(content)
427	}
428
429	return p, read, err
430}
431
// isPrintableString reports an error naming the byte offset of the first
// character in val that is not allowed in a PrintableString: ASCII letters,
// digits, space, and ' ( ) + , - . / : = ?
func isPrintableString(val string) error {
	for pos, r := range val {
		isAlnum := ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z') || ('0' <= r && r <= '9')
		if isAlnum {
			continue
		}
		switch r {
		case ' ', '\'', '(', ')', '+', ',', '-', '.', '/', ':', '=', '?':
			continue
		}
		return fmt.Errorf("invalid character in position %d", pos)
	}
	return nil
}
448
449func (p *Packet) Bytes() []byte {
450	var out bytes.Buffer
451
452	out.Write(encodeIdentifier(p.Identifier))
453	out.Write(encodeLength(p.Data.Len()))
454	out.Write(p.Data.Bytes())
455
456	return out.Bytes()
457}
458
459func (p *Packet) AppendChild(child *Packet) {
460	p.Data.Write(child.Bytes())
461	p.Children = append(p.Children, child)
462}
463
464func Encode(classType Class, tagType Type, tag Tag, value interface{}, description string) *Packet {
465	p := new(Packet)
466
467	p.ClassType = classType
468	p.TagType = tagType
469	p.Tag = tag
470	p.Data = new(bytes.Buffer)
471
472	p.Children = make([]*Packet, 0, 2)
473
474	p.Value = value
475	p.Description = description
476
477	if value != nil {
478		v := reflect.ValueOf(value)
479
480		if classType == ClassUniversal {
481			switch tag {
482			case TagOctetString:
483				sv, ok := v.Interface().(string)
484
485				if ok {
486					p.Data.Write([]byte(sv))
487				}
488			case TagEnumerated:
489				bv, ok := v.Interface().([]byte)
490				if ok {
491					p.Data.Write(bv)
492				}
493			case TagEmbeddedPDV:
494				bv, ok := v.Interface().([]byte)
495				if ok {
496					p.Data.Write(bv)
497				}
498			}
499		} else if classType == ClassContext {
500			switch tag {
501			case TagEnumerated:
502				bv, ok := v.Interface().([]byte)
503				if ok {
504					p.Data.Write(bv)
505				}
506			case TagEmbeddedPDV:
507				bv, ok := v.Interface().([]byte)
508				if ok {
509					p.Data.Write(bv)
510				}
511			}
512		}
513	}
514	return p
515}
516
517func NewSequence(description string) *Packet {
518	return Encode(ClassUniversal, TypeConstructed, TagSequence, nil, description)
519}
520
521func NewBoolean(classType Class, tagType Type, tag Tag, value bool, description string) *Packet {
522	intValue := int64(0)
523
524	if value {
525		intValue = 1
526	}
527
528	p := Encode(classType, tagType, tag, nil, description)
529
530	p.Value = value
531	p.Data.Write(encodeInteger(intValue))
532
533	return p
534}
535
536// NewLDAPBoolean returns a RFC 4511-compliant Boolean packet.
537func NewLDAPBoolean(classType Class, tagType Type, tag Tag, value bool, description string) *Packet {
538	intValue := int64(0)
539
540	if value {
541		intValue = 255
542	}
543
544	p := Encode(classType, tagType, tag, nil, description)
545
546	p.Value = value
547	p.Data.Write(encodeInteger(intValue))
548
549	return p
550}
551
552func NewInteger(classType Class, tagType Type, tag Tag, value interface{}, description string) *Packet {
553	p := Encode(classType, tagType, tag, nil, description)
554
555	p.Value = value
556	switch v := value.(type) {
557	case int:
558		p.Data.Write(encodeInteger(int64(v)))
559	case uint:
560		p.Data.Write(encodeInteger(int64(v)))
561	case int64:
562		p.Data.Write(encodeInteger(v))
563	case uint64:
564		// TODO : check range or add encodeUInt...
565		p.Data.Write(encodeInteger(int64(v)))
566	case int32:
567		p.Data.Write(encodeInteger(int64(v)))
568	case uint32:
569		p.Data.Write(encodeInteger(int64(v)))
570	case int16:
571		p.Data.Write(encodeInteger(int64(v)))
572	case uint16:
573		p.Data.Write(encodeInteger(int64(v)))
574	case int8:
575		p.Data.Write(encodeInteger(int64(v)))
576	case uint8:
577		p.Data.Write(encodeInteger(int64(v)))
578	default:
579		// TODO : add support for big.Int ?
580		panic(fmt.Sprintf("Invalid type %T, expected {u|}int{64|32|16|8}", v))
581	}
582
583	return p
584}
585
586func NewString(classType Class, tagType Type, tag Tag, value, description string) *Packet {
587	p := Encode(classType, tagType, tag, nil, description)
588
589	p.Value = value
590	p.Data.Write([]byte(value))
591
592	return p
593}
594
595func NewGeneralizedTime(classType Class, tagType Type, tag Tag, value time.Time, description string) *Packet {
596	p := Encode(classType, tagType, tag, nil, description)
597	var s string
598	if value.Nanosecond() != 0 {
599		s = value.Format(`20060102150405.000000000Z`)
600	} else {
601		s = value.Format(`20060102150405Z`)
602	}
603	p.Value = s
604	p.Data.Write([]byte(s))
605	return p
606}
607
608func NewReal(classType Class, tagType Type, tag Tag, value interface{}, description string) *Packet {
609	p := Encode(classType, tagType, tag, nil, description)
610
611	switch v := value.(type) {
612	case float64:
613		p.Data.Write(encodeFloat(v))
614	case float32:
615		p.Data.Write(encodeFloat(float64(v)))
616	default:
617		panic(fmt.Sprintf("Invalid type %T, expected float{64|32}", v))
618	}
619	return p
620}
621