1#
2# This file is part of pyasn1 software.
3#
4# Copyright (c) 2005-2017, Ilya Etingof <etingof@gmail.com>
5# License: http://pyasn1.sf.net/license.html
6#
7from pyasn1.type import tag, univ, char, useful
8from pyasn1.codec.ber import eoo
9from pyasn1.compat.octets import int2oct, oct2int, ints2octs, null, str2octs
10from pyasn1.compat.integer import to_bytes
11from pyasn1 import debug, error
12
# Names exported by a star-import of this module: only `encode`.
__all__ = ['encode']
14
15
class AbstractItemEncoder(object):
    """Common tag/length framing shared by the concrete item encoders.

    Subclasses provide `encodeValue()`; `encode()` then wraps whatever
    it returns into one tag+length header per tag of the value's tag set.
    """
    supportIndefLenMode = 1

    # An outcome of otherwise legit call `encodeFun(eoo.endOfOctets)`
    eooIntegerSubstrate = (0, 0)
    eooOctetsSubstrate = ints2octs(eooIntegerSubstrate)

    # noinspection PyMethodMayBeStatic
    def encodeTag(self, singleTag, isConstructed):
        """Return the identifier octets of `singleTag` as a tuple of ints.

        `singleTag` unpacks into (class, format, id); the constructed
        bit is OR-ed in when `isConstructed` is true.
        """
        tagClass, tagFormat, tagId = singleTag

        leadOctet = tagClass | tagFormat
        if isConstructed:
            leadOctet |= tag.tagFormatConstructed

        # small tag ids share the leading octet
        if tagId < 31:
            return (leadOctet | tagId,)

        # otherwise the id follows in 7-bit groups, least significant
        # group collected first; all but that group carry the 0x80 flag
        groups = [tagId & 0x7f]
        tagId >>= 7
        while tagId:
            groups.append(0x80 | (tagId & 0x7f))
            tagId >>= 7
        groups.reverse()

        return (leadOctet | 0x1F,) + tuple(groups)

    def encodeLength(self, length, defMode):
        """Return the length octets for `length` as a tuple of ints.

        Yields the single octet 0x80 when `defMode` is false and this
        encoder supports indefinite length mode.

        Raises `error.PyAsn1Error` if more than 126 length octets
        would be required.
        """
        if not defMode and self.supportIndefLenMode:
            return (0x80,)

        if length < 0x80:
            return (length,)

        # collect octets least significant first, flip afterwards
        lengthOctets = []
        while length:
            lengthOctets.append(length & 0xff)
            length >>= 8

        octetCount = len(lengthOctets)
        if octetCount > 126:
            raise error.PyAsn1Error('Length octets overflow (%d)' % octetCount)

        lengthOctets.reverse()
        return (0x80 | octetCount,) + tuple(lengthOctets)

    def encodeValue(self, value, encodeFun, **options):
        """Produce (substrate, isConstructed, isOctets); subclasses override."""
        raise error.PyAsn1Error('Not implemented')

    def encode(self, value, encodeFun, **options):
        """Encode `value` including a header for each of its tags.

        Recognised options: `defMode` (default True) and `ifNotEmpty`
        (default False; an empty value substrate is returned as is).
        """
        tagSet = value.tagSet

        # untagged item? hand the bare value substrate back
        if not tagSet:
            payload, isConstructed, isOctets = self.encodeValue(
                value, encodeFun, **options
            )
            return payload

        requestedDefMode = options.get('defMode', True)

        for tagIdx, singleTag in enumerate(tagSet.superTags):

            useDefMode = requestedDefMode

            # the value itself is encoded once, at the base tag
            if tagIdx == 0:
                payload, isConstructed, isOctets = self.encodeValue(
                    value, encodeFun, **options
                )

                if options.get('ifNotEmpty', False) and not payload:
                    return payload

                # primitive form implies definite mode
                if not isConstructed:
                    useDefMode = True

            header = (self.encodeTag(singleTag, isConstructed) +
                      self.encodeLength(len(payload), useDefMode))

            # header and terminator must match the payload's flavour
            if isOctets:
                header = ints2octs(header)
                terminator = self.eooOctetsSubstrate
            else:
                terminator = self.eooIntegerSubstrate

            payload = header + payload

            if not useDefMode:
                payload += terminator

        if not isOctets:
            payload = ints2octs(payload)

        return payload
106
107
class EndOfOctetsEncoder(AbstractItemEncoder):
    """Encoder whose value part is always empty."""

    def encodeValue(self, value, encodeFun, **options):
        """Return an empty, primitive, octets-typed substrate."""
        isConstructed = False
        isOctets = True
        return null, isConstructed, isOctets
111
112
class BooleanEncoder(AbstractItemEncoder):
    """Encoder producing a single content octet: 1 for true, 0 for false."""
    supportIndefLenMode = False

    def encodeValue(self, value, encodeFun, **options):
        """Return ((1,) or (0,), False, False) depending on truth of `value`."""
        if value:
            content = (1,)
        else:
            content = (0,)
        return content, False, False
118
119
class IntegerEncoder(AbstractItemEncoder):
    """Encoder for integer-like values.

    Zero is returned as a tuple of ints (empty when `supportCompactZero`
    is set); everything else as the signed octets built by `to_bytes`.
    """
    supportIndefLenMode = False
    supportCompactZero = False

    def encodeValue(self, value, encodeFun, **options):
        """Return (content, False, isOctets) for `value`."""
        if value == 0:
            # de-facto way to encode zero
            if self.supportCompactZero:
                zeroContent = ()
            else:
                zeroContent = (0,)
            return zeroContent, False, False

        content = to_bytes(int(value), signed=True)
        return content, False, True
133
134
class BitStringEncoder(AbstractItemEncoder):
    """Encoder for bit strings, optionally split into chunks.

    Without chunking the content is one octet holding the number of
    padding bits followed by the left-aligned bits. With `maxChunkSize`
    set and exceeded, the value is cut into pieces of at most
    `maxChunkSize * 8` bits, each encoded through `encodeFun`.
    """

    def encodeValue(self, value, encodeFun, **options):
        """Return (substrate, isConstructed, True) for `value`."""
        bitCount = len(value)

        # shift left so that the bits fill whole octets
        remainder = bitCount % 8
        if remainder:
            aligned = value << (8 - remainder)
        else:
            aligned = value

        chunkOctets = options.get('maxChunkSize', 0)
        if not chunkOctets or len(aligned) <= chunkOctets * 8:
            body = aligned.asOctets()
            padBits = len(body) * 8 - bitCount
            return int2oct(padBits) + body, False, True

        # strip off explicit tags
        aligned = aligned.clone(
            tagSet=tag.TagSet(value.tagSet.baseTag, value.tagSet.baseTag)
        )

        pieces = []
        upper = 0
        while upper < bitCount:
            lower = upper
            upper = min(lower + chunkOctets * 8, bitCount)
            pieces.append(encodeFun(aligned[lower:upper], **options))

        return null.join(pieces), True, True
161
162
class OctetStringEncoder(AbstractItemEncoder):
    """Encoder for octet-string-like values, optionally chunked.

    A value longer than a non-zero `maxChunkSize` option is cut into
    slices of that many items, each encoded through `encodeFun`;
    otherwise the value's octets are returned as they are.
    """

    def encodeValue(self, value, encodeFun, **options):
        """Return (substrate, isConstructed, True) for `value`."""
        chunkSize = options.get('maxChunkSize', 0)

        if chunkSize and len(value) > chunkSize:
            # will strip off explicit tags
            bareTagSet = tag.TagSet(value.tagSet.baseTag, value.tagSet.baseTag)

            pieces = []
            offset = 0
            while True:
                piece = value.clone(value[offset:offset + chunkSize],
                                    tagSet=bareTagSet)
                if not piece:
                    break
                pieces.append(encodeFun(piece, **options))
                offset += chunkSize

            return null.join(pieces), True, True

        return value.asOctets(), False, True
184
185
class NullEncoder(AbstractItemEncoder):
    """Encoder whose content is always empty; definite length only."""
    supportIndefLenMode = False

    def encodeValue(self, value, encodeFun, **options):
        """Return an empty, primitive, octets-typed substrate."""
        isConstructed = False
        isOctets = True
        return null, isConstructed, isOctets
191
192
class ObjectIdentifierEncoder(AbstractItemEncoder):
    """Encoder for object identifiers.

    The first two arcs are folded into one sub-identifier; every
    sub-identifier is then written in 7-bit groups, with bit 0x80 set
    on all groups but the last one.
    """
    supportIndefLenMode = False

    def encodeValue(self, value, encodeFun, **options):
        """Return (tuple of ints, False, False) for `value`.

        Raises `error.PyAsn1Error` for OIDs with fewer than two arcs,
        for impossible first/second arc pairs and for negative arcs.
        """
        arcs = value.asTuple()

        # Build the first pair
        try:
            first, second = arcs[0], arcs[1]

        except IndexError:
            raise error.PyAsn1Error('Short OID %s' % (value,))

        if first == 2:
            # any second arc is accepted under the top arc 2
            leading = second + 80
        elif first in (0, 1) and 0 <= second <= 39:
            leading = first * 40 + second
        else:
            raise error.PyAsn1Error('Impossible first/second arcs at %s' % (value,))

        packed = []

        # Cycle through subIds
        for subId in (leading,) + arcs[2:]:
            if subId < 0:
                raise error.PyAsn1Error('Negative OID arc %s at %s' % (subId, value))

            if subId <= 127:
                # Optimize for the common case
                packed.append(subId)
                continue

            # Pack large Sub-Object IDs, lowest group first
            groups = [subId & 0x7f]
            subId >>= 7
            while subId:
                groups.append(0x80 | (subId & 0x7f))
                subId >>= 7
            groups.reverse()

            # Add packed Sub-Object ID to resulted Object ID
            packed.extend(groups)

        return tuple(packed), False, False
241
242
class RealEncoder(AbstractItemEncoder):
    """Value encoder for real numbers given as (mantissa, base, exponent).

    Infinities map to single fixed octets, a zero mantissa to empty
    content, base 10 to a character form prefixed with octet 0x03 and
    base 2 to a binary form made of one leading octet, the exponent
    octets and the mantissa octets.
    """
    # falsy, so the inherited encodeLength() skips its indefinite-length branch
    supportIndefLenMode = 0
    binEncBase = 2  # set to None to choose encoding base automatically

    @staticmethod
    def _dropFloatingPoint(m, encbase, e):
        """Rescale mantissa `m` to an integer for encoding base `encbase`.

        Returns a tuple (mantissa sign, integer mantissa magnitude,
        encbase, adjusted exponent).
        """
        ms, es = 1, 1
        if m < 0:
            ms = -1  # mantissa sign
        if e < 0:
            es = -1  # exponent sign
        # continue with the magnitude; the sign is returned separately
        m *= ms
        # for base 8/16 the exponent is divided by 3/4 and the remainder
        # is folded into the mantissa as a power of two
        if encbase == 8:
            m *= 2 ** (abs(e) % 3 * es)
            e = abs(e) // 3 * es
        elif encbase == 16:
            m *= 2 ** (abs(e) % 4 * es)
            e = abs(e) // 4 * es

        # scale up by the base until no fractional part is left,
        # lowering the exponent by one for each step
        while True:
            if int(m) != m:
                m *= encbase
                e -= 1
                continue
            break
        return ms, int(m), encbase, e

    def _chooseEncBase(self, value):
        """Pick the binary encoding base and rescale `value` for it.

        Preference order: `value.binEncBase` if it is 2, 8 or 16, then
        this encoder's `binEncBase` if it is 2, 8 or 16, otherwise all
        three bases are tried and one is selected by exponent size.

        Returns (sign, mantissa, base, exponent).
        """
        m, b, e = value
        encBase = [2, 8, 16]
        if value.binEncBase in encBase:
            return self._dropFloatingPoint(m, value.binEncBase, e)
        elif self.binEncBase in encBase:
            return self._dropFloatingPoint(m, self.binEncBase, e)
        # auto choosing base 2/8/16
        mantissa = [m, m, m]
        exponenta = [e, e, e]
        sign = 1
        encbase = 2
        # start from infinity so the first candidate is always accepted
        e = float('inf')
        for i in range(3):
            (sign,
             mantissa[i],
             encBase[i],
             exponenta[i]) = self._dropFloatingPoint(mantissa[i], encBase[i], exponenta[i])
            # keep the candidate with the smaller absolute exponent; on a
            # tie, the one whose mantissa is below the currently kept `m`
            if abs(exponenta[i]) < abs(e) or (abs(exponenta[i]) == abs(e) and mantissa[i] < m):
                e = exponenta[i]
                m = int(mantissa[i])
                encbase = encBase[i]
        return sign, m, encbase, e

    def encodeValue(self, value, encodeFun, **options):
        """Return (substrate, False, isOctets) for a real `value`.

        Raises `error.PyAsn1Error` when the base is neither 10 nor 2,
        or when the scale factor or exponent length overflows.
        """
        if value.isPlusInf:
            return (0x40,), False, False
        if value.isMinusInf:
            return (0x41,), False, False
        m, b, e = value
        # zero mantissa: empty content
        if not m:
            return null, False, True
        if b == 10:
            # character form: octet 0x03, mantissa, 'E', exponent
            # (a '+' is inserted only in front of a zero exponent)
            return str2octs('\x03%dE%s%d' % (m, e == 0 and '+' or '', e)), False, True
        elif b == 2:
            # `fo` accumulates the leading octet of the binary form
            fo = 0x80  # binary encoding
            ms, m, encbase, e = self._chooseEncBase(value)
            if ms < 0:  # mantissa sign
                fo |= 0x40  # sign bit
            # exponent & mantissa normalization: strip trailing zero
            # digits (in the chosen base) off the mantissa and move
            # them into the exponent
            if encbase == 2:
                while m & 0x1 == 0:
                    m >>= 1
                    e += 1
            elif encbase == 8:
                while m & 0x7 == 0:
                    m >>= 3
                    e += 1
                fo |= 0x10
            else:  # encbase = 16
                while m & 0xf == 0:
                    m >>= 4
                    e += 1
                fo |= 0x20
            # remaining trailing zero bits are counted as the scale factor
            sf = 0  # scale factor
            while m & 0x1 == 0:
                m >>= 1
                sf += 1
            if sf > 3:
                raise error.PyAsn1Error('Scale factor overflow')  # bug if raised
            fo |= sf << 2
            # `eo` collects the exponent octets, most significant first
            eo = null
            if e == 0 or e == -1:
                eo = int2oct(e & 0xff)
            else:
                while e not in (0, -1):
                    eo = int2oct(e & 0xff) + eo
                    e >>= 8
                # add a leading octet when the top bit of the first octet
                # disagrees with the sign left in `e` (0 or -1)
                if e == 0 and eo and oct2int(eo[0]) & 0x80:
                    eo = int2oct(0) + eo
                if e == -1 and eo and not (oct2int(eo[0]) & 0x80):
                    eo = int2oct(0xff) + eo
            n = len(eo)
            if n > 0xff:
                raise error.PyAsn1Error('Real exponent overflow')
            # low two bits of `fo` describe the exponent length: 0, 1, 2
            # for one, two, three octets; 3 means an explicit length
            # octet is put in front of the exponent
            if n == 1:
                pass
            elif n == 2:
                fo |= 1
            elif n == 3:
                fo |= 2
            else:
                fo |= 3
                eo = int2oct(n & 0xff) + eo
            # `po` collects the mantissa octets, most significant first
            po = null
            while m:
                po = int2oct(m & 0xff) + po
                m >>= 8
            substrate = int2oct(fo) + eo + po
            return substrate, False, True
        else:
            raise error.PyAsn1Error('Prohibited Real base %s' % b)
362
363
class SequenceEncoder(AbstractItemEncoder):
    """Encoder for values with named components.

    Components are visited from last to first; optional components
    without a value and defaulted components equal to their default
    are left out of the output.
    """

    def encodeValue(self, value, encodeFun, **options):
        """Return (concatenated component encodings, True, True)."""
        value.verifySizeSpec()

        namedTypes = value.componentType
        pieces = []

        for position in reversed(range(len(value))):
            if namedTypes:
                namedType = namedTypes[position]
                if namedType.isOptional and not value[position].isValue:
                    continue
                if namedType.isDefaulted and value[position] == namedType.asn1Object:
                    continue
            pieces.append(encodeFun(value[position], **options))

        # pieces were gathered back to front
        pieces.reverse()

        return null.join(pieces), True, True
382
383
class SequenceOfEncoder(AbstractItemEncoder):
    """Encoder for homogeneous collections: every item is encoded."""

    def encodeValue(self, value, encodeFun, **options):
        """Return (concatenated item encodings, True, True)."""
        value.verifySizeSpec()

        pieces = []
        # items are visited from last to first
        for position in reversed(range(len(value))):
            pieces.append(encodeFun(value[position], **options))
        pieces.reverse()

        return null.join(pieces), True, True
393
394
class ChoiceEncoder(AbstractItemEncoder):
    """Encoder delegating to the component returned by `getComponent()`."""

    def encodeValue(self, value, encodeFun, **options):
        """Return (encoding of the component, True, True)."""
        component = value.getComponent()
        encoded = encodeFun(component, **options)
        return encoded, True, True
398
399
class AnyEncoder(OctetStringEncoder):
    """Encoder passing the value's octets through without chunking."""

    def encodeValue(self, value, encodeFun, **options):
        """Return (octets, isConstructed, True).

        The substrate is reported as constructed exactly when the
        `defMode` option (default True) is false.
        """
        octets = value.asOctets()
        isConstructed = not options.get('defMode', True)
        return octets, isConstructed, True
403
404
# Encoder instances keyed by the tag set of the type they handle.
# `Encoder.__call__` falls back to this map (using a tag set built from
# the value's base tag) when the value's typeId is not in the type map.
tagMap = {
    eoo.endOfOctets.tagSet: EndOfOctetsEncoder(),
    univ.Boolean.tagSet: BooleanEncoder(),
    univ.Integer.tagSet: IntegerEncoder(),
    univ.BitString.tagSet: BitStringEncoder(),
    univ.OctetString.tagSet: OctetStringEncoder(),
    univ.Null.tagSet: NullEncoder(),
    univ.ObjectIdentifier.tagSet: ObjectIdentifierEncoder(),
    univ.Enumerated.tagSet: IntegerEncoder(),
    univ.Real.tagSet: RealEncoder(),
    # Sequence & Set have same tags as SequenceOf & SetOf
    univ.SequenceOf.tagSet: SequenceOfEncoder(),
    univ.SetOf.tagSet: SequenceOfEncoder(),
    univ.Choice.tagSet: ChoiceEncoder(),
    # character string types
    char.UTF8String.tagSet: OctetStringEncoder(),
    char.NumericString.tagSet: OctetStringEncoder(),
    char.PrintableString.tagSet: OctetStringEncoder(),
    char.TeletexString.tagSet: OctetStringEncoder(),
    char.VideotexString.tagSet: OctetStringEncoder(),
    char.IA5String.tagSet: OctetStringEncoder(),
    char.GraphicString.tagSet: OctetStringEncoder(),
    char.VisibleString.tagSet: OctetStringEncoder(),
    char.GeneralString.tagSet: OctetStringEncoder(),
    char.UniversalString.tagSet: OctetStringEncoder(),
    char.BMPString.tagSet: OctetStringEncoder(),
    # useful types
    useful.ObjectDescriptor.tagSet: OctetStringEncoder(),
    useful.GeneralizedTime.tagSet: OctetStringEncoder(),
    useful.UTCTime.tagSet: OctetStringEncoder()
}
436
# Put in ambiguous & non-ambiguous types for faster codec lookup.
# Encoder instances keyed by typeId; `Encoder.__call__` consults this
# map first and only then falls back to the tag-based map.
typeMap = {
    univ.Boolean.typeId: BooleanEncoder(),
    univ.Integer.typeId: IntegerEncoder(),
    univ.BitString.typeId: BitStringEncoder(),
    univ.OctetString.typeId: OctetStringEncoder(),
    univ.Null.typeId: NullEncoder(),
    univ.ObjectIdentifier.typeId: ObjectIdentifierEncoder(),
    univ.Enumerated.typeId: IntegerEncoder(),
    univ.Real.typeId: RealEncoder(),
    # Sequence & Set have same tags as SequenceOf & SetOf
    # (typeIds tell them apart, so each gets its own entry here)
    univ.Set.typeId: SequenceEncoder(),
    univ.SetOf.typeId: SequenceOfEncoder(),
    univ.Sequence.typeId: SequenceEncoder(),
    univ.SequenceOf.typeId: SequenceOfEncoder(),
    univ.Choice.typeId: ChoiceEncoder(),
    univ.Any.typeId: AnyEncoder(),
    # character string types
    char.UTF8String.typeId: OctetStringEncoder(),
    char.NumericString.typeId: OctetStringEncoder(),
    char.PrintableString.typeId: OctetStringEncoder(),
    char.TeletexString.typeId: OctetStringEncoder(),
    char.VideotexString.typeId: OctetStringEncoder(),
    char.IA5String.typeId: OctetStringEncoder(),
    char.GraphicString.typeId: OctetStringEncoder(),
    char.VisibleString.typeId: OctetStringEncoder(),
    char.GeneralString.typeId: OctetStringEncoder(),
    char.UniversalString.typeId: OctetStringEncoder(),
    char.BMPString.typeId: OctetStringEncoder(),
    # useful types
    useful.ObjectDescriptor.typeId: OctetStringEncoder(),
    useful.GeneralizedTime.typeId: OctetStringEncoder(),
    useful.UTCTime.typeId: OctetStringEncoder()
}
471
472
class Encoder(object):
    """Callable that picks a concrete item encoder for a value and runs it.

    The concrete encoder is looked up by the value's `typeId` in the
    type map first; if that fails, by a tag set built from the value's
    base tag in the tag map.
    """
    # When not None, these replace the caller's `defMode` / `maxChunkSize`
    # options on every call.
    fixedDefLengthMode = None
    fixedChunkSize = None

    def __init__(self, tagMap, typeMap=None):
        """Remember the lookup tables.

        Parameters
        ----------
        tagMap: mapping of tag set -> item encoder
        typeMap: mapping of typeId -> item encoder, optional
            When omitted (or None), a fresh empty dict is used so that
            instances never share a default mapping object.
        """
        if typeMap is None:
            typeMap = {}
        self.__tagMap = tagMap
        self.__typeMap = typeMap

    def __call__(self, value, **options):
        """Encode `value` and return the resulting substrate.

        Options are handed down to the concrete encoder (which receives
        this object as its `encodeFun` for nested items).

        Raises `error.PyAsn1Error` when neither map has an encoder
        for `value`.
        """

        if debug.logger & debug.flagEncoder:
            logger = debug.logger
        else:
            logger = None

        # NOTE: this reports the options as passed by the caller, i.e.
        # before the fixed* overrides below are applied.
        if logger:
            logger('encoder called in %sdef mode, chunk size %s for type %s, value:\n%s' % (not options.get('defMode', True) and 'in' or '', options.get('maxChunkSize', 0), value.prettyPrintType(), value.prettyPrint()))

        if self.fixedDefLengthMode is not None:
            options.update(defMode=self.fixedDefLengthMode)

        if self.fixedChunkSize is not None:
            options.update(maxChunkSize=self.fixedChunkSize)

        tagSet = value.tagSet

        try:
            concreteEncoder = self.__typeMap[value.typeId]

        except KeyError:
            # use base type for codec lookup to recover untagged types
            baseTagSet = tag.TagSet(value.tagSet.baseTag, value.tagSet.baseTag)

            try:
                concreteEncoder = self.__tagMap[baseTagSet]

            except KeyError:
                raise error.PyAsn1Error('No encoder for %s' % (value,))

        if logger:
            logger('using value codec %s chosen by %s' % (concreteEncoder.__class__.__name__, tagSet))

        substrate = concreteEncoder.encode(value, self, **options)

        if logger:
            logger('codec %s built %s octets of substrate: %s\nencoder completed' % (concreteEncoder, len(substrate), debug.hexdump(substrate)))

        return substrate
522
#: Turns ASN.1 object into BER octet stream.
#:
#: Takes any ASN.1 object (e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative)
#: walks all its components recursively and produces a BER octet stream.
#:
#: Parameters
#: ----------
#: value: any pyasn1 object (e.g. :py:class:`~pyasn1.type.base.PyAsn1Item` derivative)
#:     A pyasn1 object to encode
#:
#: defMode: :py:class:`bool`
#:     If `False`, produces indefinite length encoding
#:
#: maxChunkSize: :py:class:`int`
#:     Maximum chunk size in chunked encoding mode (0 denotes unlimited chunk size)
#:
#: Returns
#: -------
#: : :py:class:`bytes` (Python 3) or :py:class:`str` (Python 2)
#:     Given ASN.1 object encoded into BER octet stream
#:
#: Raises
#: ------
#: : :py:class:`pyasn1.error.PyAsn1Error`
#:     On encoding errors
encode = Encoder(tagMap, typeMap)
549