1#! /usr/bin/env python3
2
3__author__ = "Berry van Halderen"
4__date__ = "$$"
5
6import os
7import sys
8import getopt
9import yaml
10import re
11import pkcs11
12import base64
13import binascii
14import xml.dom.minidom
15
'''
Limitations:
- Only one PKCS#11 token for exporting public keys is supported at any one time
- The hardware PKCS#11 token should be specified as the first repository, the
  SoftHSM based token as the second repository.
- Only migrates RSA public keys
- The token PIN must be specified in the configuration file
'''
24
# Module-level defaults.  readconf() overwrites the three token settings when
# it is given a repository index, and always (re)sets signconffname.
tokenmodule   = "/usr/local/lib/softhsm/libsofthsm2.so"  # handed to pkcs11.lib() in main()
tokenlabel    = "SoftHSM"  # token label used for lib.get_token() in main()
tokenpin      = "1234"  # user PIN used for token.open() in main()
signconffname = None  # signer configuration file name of the selected zone
29
def main():
    """Run the ``export`` or ``import`` command given on the command line.

    Expected arguments: ``<command> <conffile> <zonename>`` where command is
    ``export`` or ``import``.

    * ``export`` reads the first repository (index 0) from the configuration,
      opens a read-only session, reads the zone's signer configuration,
      fetches the key material for every key in it, and writes the patched
      document to the ``pseudo-`` prefixed signer configuration file.
    * ``import`` reads the second repository (index 1), opens a read-write
      session, reads the ``pseudo-`` prefixed file, imports the keys and
      writes the merged result to the ``new-`` prefixed file.

    Returns:
        0 on success, 2 when the command line is incomplete or the command
        is not recognised.

    Exits the process with status 1 (via ``sys.exit``) when the token cannot
    be found.
    """
    # Validate the command line up front instead of failing with IndexError
    # or silently doing nothing for an unknown command.
    if len(sys.argv) < 4 or sys.argv[1] not in ("export", "import"):
        prog = sys.argv[0] if sys.argv else "prog"
        print("usage: " + prog + " export|import <conffile> <zonename>",
              file=sys.stderr)
        return 2
    command, conffname, zonename = sys.argv[1:4]

    try:
        if command == "export":
            readconf(conffname, zonename, 0)
            lib = pkcs11.lib(tokenmodule)
            token = lib.get_token(token_label=tokenlabel)
            session = token.open(user_pin=tokenpin, rw=False)
            try:
                (signconf, keys) = readsignconf(signconffname)
                exportkeys(session, keys)
            finally:
                # Release the session even when reading/exporting fails.
                session.close()
            patchsignconf(signconf, keys)
            writesignconf(signconf, signconffname, "pseudo")
        else:
            readconf(conffname, zonename, 1)
            lib = pkcs11.lib(tokenmodule)
            token = lib.get_token(token_label=tokenlabel)
            session = token.open(user_pin=tokenpin, rw=True)
            try:
                (signconf, keys) = readsignconf(signconffname, "pseudo")
                importkeys(session, keys)
                newsignconf = mergesignconf(signconf, keys, signconffname)
                writesignconf(newsignconf, signconffname, "new")
            finally:
                session.close()
    except pkcs11.exceptions.NoSuchToken:
        print("Unable to access token", file=sys.stderr)
        sys.exit(1)
    return 0
55
class KeyNotFound(Exception):
    """Raised when a requested key cannot be located.

    The value handed to the constructor is kept on the instance as
    ``message``.
    """

    # Class-level default, replaced per instance by __init__.
    message = None

    def __init__(self, message):
        super().__init__(message)
        self.message = message
61
def readconf(conffname, zonename, repoindex=None):
    """Load token and zone settings from the XML configuration file.

    When ``repoindex`` is given, the Repository element at that position
    under Configuration/RepositoryList supplies the module path, token label
    and PIN, which are stored in the module-level ``tokenmodule``,
    ``tokenlabel`` and ``tokenpin``.

    The zone list is read from ``<WorkingDirectory>/zones.xml`` when that
    file exists, otherwise from the file named by
    Configuration/Common/ZoneListFile.  The module-level ``signconffname`` is
    set to the SignerConfiguration text of the zone called ``zonename``; it
    stays None when no zone matches.
    """
    global tokenmodule, tokenlabel, tokenpin, signconffname

    confdoc = xml.dom.minidom.parse(conffname)

    if repoindex is not None:
        repolist = getxpath(confdoc, ['Configuration', 'RepositoryList'])
        for position, reponode in enumerate(repolist.getElementsByTagName('Repository')):
            if position != repoindex:
                continue
            tokenmodule = getxpath(reponode, ['Module', None])
            tokenlabel = getxpath(reponode, ['TokenLabel', None])
            tokenpin = getxpath(reponode, ['PIN', None])

    workdir = getxpath(confdoc, ['Configuration', 'Enforcer', 'WorkingDirectory', None], "")
    zonelistfile = workdir + "/" + "zones.xml"
    if not os.path.exists(zonelistfile):
        zonelistfile = getxpath(confdoc, ['Configuration', 'Common', 'ZoneListFile', None])
    zonelistdoc = xml.dom.minidom.parse(zonelistfile)

    signconffname = None
    zonelist = getxpath(zonelistdoc, ['ZoneList'])
    for zonenode in zonelist.getElementsByTagName('Zone'):
        # No early exit: if several zones share the name, the last one wins.
        if zonenode.getAttribute('name') == zonename:
            signconffname = getxpath(zonenode, ['SignerConfiguration', None])
85
def importkey(session, keyname, modulus, exponent):
    """Create RSA public and private key objects for ``keyname`` on the token.

    Args:
        session: open PKCS#11 session providing ``get_objects`` and
            ``create_object``.
        keyname: hexadecimal key identifier; used verbatim as the object
            label and, hex-decoded, as the object ID.
        modulus: RSA modulus stored on both objects.
        exponent: exponent value; stored as PUBLIC_EXPONENT on the public
            object and as PRIVATE_EXPONENT on the private object.

    Returns:
        False (after printing a notice) when a public or private key with the
        same ID already exists; True once both objects have been created.
    """
    Attribute = pkcs11.constants.Attribute
    keyid = binascii.a2b_hex(keyname)

    # Refuse to create duplicates of a key that is already on the token.
    for handle in session.get_objects({Attribute.ID: keyid}):
        if isinstance(handle, pkcs11.PublicKey):
            print("Found public key")
            return False
        if isinstance(handle, pkcs11.PrivateKey):
            print("Found private key")
            return False

    session.create_object({
        Attribute.TOKEN: True,
        Attribute.PRIVATE: True,
        Attribute.LABEL: keyname,
        Attribute.ID: keyid,
        Attribute.CLASS: pkcs11.ObjectClass.PUBLIC_KEY,
        Attribute.KEY_TYPE: pkcs11.KeyType.RSA,
        Attribute.ENCRYPT: True,
        Attribute.VERIFY: True,
        Attribute.VERIFY_RECOVER: True,
        Attribute.WRAP: True,
        Attribute.MODULUS: modulus,
        Attribute.PUBLIC_EXPONENT: exponent,
    })

    # NOTE(review): the private object receives the same exponent value as
    # the public one, so it presumably only serves as a placeholder that can
    # be found by ID, not as a usable private key -- confirm.
    session.create_object({
        Attribute.TOKEN: True,
        Attribute.PRIVATE: True,
        Attribute.LABEL: keyname,
        Attribute.ID: keyid,
        Attribute.CLASS: pkcs11.ObjectClass.PRIVATE_KEY,
        Attribute.KEY_TYPE: pkcs11.KeyType.RSA,
        Attribute.DECRYPT: True,
        Attribute.SIGN: True,
        Attribute.SIGN_RECOVER: True,
        Attribute.UNWRAP: True,
        Attribute.MODULUS: modulus,
        Attribute.PRIVATE_EXPONENT: exponent,
    })
    return True
129
def importkeys(session, keys):
    """Import every key that carries public key data into the token.

    Args:
        session: open PKCS#11 session handed on to ``importkey``.
        keys: mapping of key name to a dict; the base64 text under
            ``'keydata'`` is decoded, split into modulus and exponent and
            imported.

    Entries without key data (missing, None or empty) are skipped rather
    than passed to the base64 decoder.  A line is printed for each key that
    was actually imported.
    """
    for keyname, entry in keys.items():
        encoded = entry.get('keydata')
        if not encoded:
            # Nothing to import for this key.
            continue
        (modulus, exponent) = decomposekeydata(base64.b64decode(encoded))
        if importkey(session, keyname, modulus, exponent):
            print("imported key " + keyname)
138
def exportkey(session, keyname):
    """Fetch modulus and public exponent of the key with ID ``keyname``.

    Args:
        session: open PKCS#11 session providing ``get_objects``.
        keyname: hexadecimal key identifier; hex-decoded to the object ID.

    Returns:
        Tuple ``(modulus, exponent)``.  A public key object is returned as
        soon as it is seen; otherwise the values read from the last matching
        private key object are used.

    Raises:
        KeyNotFound: no public or private key object has this ID.
    """
    Attribute = pkcs11.constants.Attribute
    fallback = None
    for handle in session.get_objects({Attribute.ID: binascii.a2b_hex(keyname)}):
        if isinstance(handle, pkcs11.PublicKey):
            return (handle[Attribute.MODULUS], handle[Attribute.PUBLIC_EXPONENT])
        if isinstance(handle, pkcs11.PrivateKey):
            # Remember it, but keep looking for a public key object.
            fallback = (handle[Attribute.MODULUS], handle[Attribute.PUBLIC_EXPONENT])
    if fallback is None:
        raise KeyNotFound(keyname)
    return fallback
155
def exportkeys(session, keys):
    """Replace each entry of ``keys`` with the key material found on the token.

    For every key name the modulus and exponent are fetched through
    ``exportkey`` and the entry becomes a dict with ``'modulus'``,
    ``'exponent'`` and ``'keydata'`` (the composed key data).  Keys that
    cannot be found are reported on standard output and left untouched.
    """
    for keyname in keys:
        try:
            (modulus, exponent) = exportkey(session, keyname)
            material = {
                'modulus': modulus,
                'exponent': exponent,
                'keydata': composekeydata(modulus, exponent),
            }
        except KeyNotFound:
            print("key "+keyname+" not found")
            continue
        keys[keyname] = material
166
def composekeydata(modulus, exponent):
    """Pack an RSA exponent and modulus into a single key data buffer.

    Leading zero bytes are removed from both values.  The buffer starts with
    the exponent length: one byte when the length is at most 255, otherwise
    a zero byte followed by the length as two bytes (high byte first).  The
    exponent and then the modulus follow.

    Args:
        modulus: bytes-like modulus.
        exponent: bytes-like exponent.

    Returns:
        bytearray holding the packed data.

    Raises:
        ValueError: the stripped exponent is longer than 65535 bytes.
    """
    # ``++i`` is not an increment in Python, so strip the padding with the
    # bytes API instead of a counting loop.
    modulus = bytes(modulus).lstrip(b"\x00")
    exponent = bytes(exponent).lstrip(b"\x00")
    exponent_len = len(exponent)
    if exponent_len > 65535:
        raise ValueError("len exponent longer than allowed (" + str(exponent_len) + ")")

    buffer = bytearray()
    if exponent_len > 255:
        # Long form: zero marker plus 16-bit length.
        buffer.append(0)
        buffer.append(exponent_len >> 8)
        buffer.append(exponent_len & 0xff)
    else:
        buffer.append(exponent_len)
    buffer.extend(exponent)
    buffer.extend(modulus)
    return buffer
189
def decomposekeydata(buffer):
    """Split a packed key data buffer into its modulus and exponent.

    A non-zero first byte is the exponent length; a zero first byte means the
    length is held in the following two bytes (high byte first).

    Returns:
        Tuple ``(modulus, exponent)`` of slices of ``buffer``.
    """
    if buffer[0] != 0:
        header_len = 1
        exponent_len = buffer[0]
    else:
        header_len = 3
        exponent_len = (buffer[1] << 8) | buffer[2]
    boundary = header_len + exponent_len
    return (buffer[boundary:], buffer[header_len:boundary])
200
def processkeys(keys):
    """Add composed ``'keydata'`` to every entry that has a ``'modulus'``.

    Entries without a ``'modulus'`` are left unchanged.
    """
    for entry in keys.values():
        if 'modulus' not in entry:
            continue
        entry['keydata'] = composekeydata(entry['modulus'], entry['exponent'])
206
def readsignconf(signconf, prefix=None):
    """Parse a signer configuration file and index its Key elements.

    Args:
        signconf: path of the signer configuration file.
        prefix: when given, the file ``<prefix>-<basename>`` in the same
            directory is read instead.

    Returns:
        Tuple ``(doc, keys)``: the parsed document and a dict mapping each
        key's Locator text to ``{'keynode': <Key element>, 'keydata':
        <PublicKeyData text or None>}``.
    """
    fname = signconf
    if prefix is not None:
        fname = os.path.join(os.path.dirname(signconf), prefix + "-" + os.path.basename(signconf))
    doc = xml.dom.minidom.parse(fname)

    signconfkeys = {}
    for keynode in doc.getElementsByTagName('Key'):
        keyname = getxpath(keynode, ['Locator', None])
        signconfkeys[keyname] = {
            'keynode': keynode,
            'keydata': getxpath(keynode, ['PublicKeyData', None]),
        }
    return (doc, signconfkeys)
221
def getxpath(node, path, defaultValue=None):
    """Walk down a DOM tree along ``path`` and return what is found there.

    Each string in ``path`` selects the first child whose local name matches.
    A ``None`` step ends the walk and returns the ``data`` of the current
    node's first child.

    Returns:
        The node reached, the first child's data for a ``None`` step, or
        ``defaultValue`` when a step has no matching child (or, for a
        ``None`` step, no children at all).
    """
    for step in path:
        if step is None:
            for child in node.childNodes:
                return child.data
            return defaultValue
        match = None
        for child in node.childNodes:
            if child.localName == step:
                match = child
                break
        if match is None:
            return defaultValue
        node = match
    return node
236
def mergesignconf(signconf, keys, fname, prefix=None):
    """Merge the given keys into the signer configuration stored in ``fname``.

    Args:
        signconf: unused; kept so existing callers keep working.
        keys: mapping of locator to a dict whose ``'keynode'`` is the Key
            element the Flags, Algorithm, Locator and Publish values are
            taken from.
        fname: path of the signer configuration file to merge into.
        prefix: when given, ``<prefix>-<basename>`` in the same directory is
            read instead of ``fname``.

    Returns:
        The parsed document with, for every entry in ``keys``, a freshly
        built Key element appended under SignerConfiguration/Zone/Keys.
        Key elements already present there with the same locator are removed
        first, so a locator never appears twice.
    """
    if prefix is not None:
        # The path must be derived from the file name, not from the
        # ``signconf`` argument (callers pass a parsed document there).
        fname = os.path.join(os.path.dirname(fname), prefix + "-" + os.path.basename(fname))
    doc = xml.dom.minidom.parse(fname)
    keysnode = getxpath(doc, ['SignerConfiguration', 'Zone', 'Keys'])

    # Remove the target document's own entries for keys that are about to be
    # re-added.  Iterate over a copy because children are removed on the way.
    for existing in list(keysnode.childNodes):
        if existing.localName != 'Key':
            continue
        locator = getxpath(existing, ['Locator', None])
        if locator is not None and locator in keys:
            keysnode.removeChild(existing)

    indent = "\n\t\t\t\t"
    for key in keys:
        source = keys[key]['keynode']
        values = (
            ("Flags", getxpath(source, ['Flags', None])),
            ("Algorithm", getxpath(source, ['Algorithm', None])),
            ("Locator", getxpath(source, ['Locator', None])),
        )
        publish = getxpath(source, ['Publish'])

        keynode = doc.createElement("Key")
        for tag, value in values:
            # One whitespace node in front of every child element.
            keynode.appendChild(doc.createTextNode(indent))
            node = doc.createElement(tag)
            node.appendChild(doc.createTextNode(str(value)))
            keynode.appendChild(node)
        # NOTE: only <Publish/> is carried over; <KSK/> and <ZSK/> markers of
        # the source key are not copied.
        if publish is not None:
            keynode.appendChild(doc.createTextNode(indent))
            keynode.appendChild(doc.createElement("Publish"))
        keynode.appendChild(doc.createTextNode("\n\t\t\t"))

        keysnode.appendChild(doc.createTextNode("\t"))
        keysnode.appendChild(keynode)
        keysnode.appendChild(doc.createTextNode("\n\t\t"))
    return doc
284
def writesignconf(doc, signconf, prefix=None):
    """Serialise ``doc`` to the signer configuration file.

    Args:
        doc: DOM document to write.
        signconf: path of the signer configuration file.
        prefix: when given, the document is written to
            ``<prefix>-<basename>`` in the same directory instead.

    The XML is written without added indentation or newlines, followed by a
    single trailing newline.
    """
    target = signconf
    if prefix is not None:
        directory, basename = os.path.split(signconf)
        target = os.path.join(directory, prefix + "-" + basename)
    serialised = doc.toprettyxml(newl="", indent="")
    with open(target, "w") as f:
        f.write(serialised + "\n")
292
def patchsignconf(doc, signconfkeys):
    """Attach exported public key data to the Key elements of ``doc``.

    Args:
        doc: DOM document containing Keys/Key elements.
        signconfkeys: mapping of locator to a dict; binary data under
            ``'keydata'`` is base64-encoded into a new PublicKeyData child.

    Key elements whose locator has no binary key data (unknown locator,
    missing entry, or a value that is not bytes/bytearray such as None) are
    removed from their Keys element.  Key elements without a Locator child
    are left alone.
    """
    for keys in doc.getElementsByTagName('Keys'):
        for key in keys.getElementsByTagName('Key'):
            locators = key.getElementsByTagName('Locator')
            if not locators:
                continue
            # Only the first Locator of a Key is considered.
            keyname = locators[0].childNodes[0].data
            keydata = signconfkeys.get(keyname, {}).get('keydata')
            if isinstance(keydata, (bytes, bytearray)):
                encoded = base64.b64encode(keydata).decode('ascii')
                node = doc.createElement("PublicKeyData")
                node.appendChild(doc.createTextNode(encoded))
                key.appendChild(doc.createTextNode("	"))
                key.appendChild(node)
                key.appendChild(doc.createTextNode("\n			"))
            else:
                # No exported material for this key: drop it from the output.
                keys.removeChild(key)
309
'''
Main program. In principle this module could be used from another program, in
which case no action is taken unless a method is explicitly called.
'''
if __name__ == "__main__":
    # Script entry point: run main() and hand any result other than 0 to
    # sys.exit(); a result of 0 lets the interpreter finish normally.
    exit_status = main()
    if exit_status != 0:
        sys.exit(exit_status)
318