1
# Swift file skeleton for a generated class whose state is a fixed-size byte
# array (it subclasses ByteArray and exposes SIZE). print_class fills in
# imports, class_name, size, static_methods, constructors, methods and
# serialize_method; it also passes dir_section, which this template does not
# reference.
template_class = \
"""//
// Copyright (C) 2020 Signal Messenger, LLC.
// All rights reserved.
//
// SPDX-License-Identifier: GPL-3.0-only
//
// Generated by zkgroup/codegen/codegen.py - do not edit

%(imports)s

public class %(class_name)s : ByteArray {

  public static let SIZE: Int = %(size)s
%(static_methods)s%(constructors)s
%(methods)s%(serialize_method)s}
"""
19
# Designated initializer taking raw bytes. It forwards to the ByteArray
# superclass init with expectedLength: <class>.SIZE (plus ", unrecoverable: true"
# via runtime_error_bool for runtime-error classes). check_valid_contents is the
# optional FFI validation snippet, or "".
# Note: the access level is hardcoded to `public`; the constructor_access value
# print_class supplies is not referenced here.
template_constructor = \
"""
  public init(contents: %(constructor_contents_type)s) %(constructor_exception_decl)s {
    try super.init(newContents: %(constructor_contents)s, expectedLength: %(class_name)s.SIZE%(runtime_error_bool)s)
%(check_valid_contents)s
  }
"""
27
# Secondary initializer for String-backed classes. Swift parameters are
# declared `label: Type`; the previous Java-style "Type name" order does not
# compile. constructor_contents must be a Swift expression that converts the
# `contents` String to [UInt8]. Only referenced from a disabled branch of
# print_class.
template_constructor_for_string_contents = \
"""
  public init(contents: %(constructor_contents_type)s) %(constructor_exception_decl)s {
    try super.init(newContents: %(constructor_contents)s, expectedLength: %(class_name)s.SIZE%(runtime_error_bool)s)
%(check_valid_contents)s
  }
"""
35
# serialize() for String-backed classes: decodes the contents as UTF-8.
# The previous body was Java (new String(...), a malformed catch clause) and
# could never compile as Swift. An invalid UTF-8 payload is reported as an
# internal AssertionError. Only referenced from a disabled branch of
# print_class. This template is inserted verbatim, not %-formatted.
serialize_method_string = \
"""
  public func serialize() throws -> String {
    guard let string = String(bytes: contents, encoding: .utf8) else {
      throw ZkGroupException.AssertionError
    }
    return string
  }

"""
47
# serialize() for byte-backed classes: returns the raw contents array.
# print_class currently selects this for every generated class. It is
# inserted verbatim (not %-formatted).
serialize_method_binary = \
"""
  public func serialize() -> [UInt8] {
    return contents
  }

"""
55
# Swift file skeleton for a class that wraps another generated class rather
# than holding bytes itself. print_class points its methods' FFI "self"
# argument at <wrapped>.getInternalContentsForFFI(). Unlike template_class it
# has no SIZE, no byte constructors and no serialize(). It also carries no
# copyright/SPDX header.
# NOTE(review): confirm whether omitting the license header is intentional.
template_wrapping_class = \
"""// Generated by zkgroup/codegen/codegen.py - do not edit

%(imports)s

public class %(class_name)s {

  let %(wrapped_class_var)s: %(wrapped_class_type)s
%(static_methods)s
  public init(%(wrapped_class_var)s: %(wrapped_class_type)s) {
    self.%(wrapped_class_var)s = %(wrapped_class_var)s
  }
%(methods)s
}
"""
71
# Validation snippet appended to the byte constructor of classes that have
# check_valid_contents. It calls FFI_<Class>_checkValidContents on the new
# contents. FFI_RETURN_INPUT_ERROR becomes ZkGroupException.InvalidInput, and
# any other non-OK code becomes ZkGroupError.
template_check_valid_contents_constructor = \
"""

    let ffi_return = FFI_%(class_name_camel)s_checkValidContents(self.contents, UInt32(self.contents.count))

    if (ffi_return == Native.FFI_RETURN_INPUT_ERROR) {
      throw ZkGroupException.InvalidInput
    }

    if (ffi_return != Native.FFI_RETURN_OK) {
      throw ZkGroupException.ZkGroupError
    }"""
84
# Variant of the validation snippet used for classes flagged
# runtime_error_on_serialize. It is identical except that
# FFI_RETURN_INPUT_ERROR becomes ZkGroupException.IllegalArgument instead of
# InvalidInput.
template_check_valid_contents_constructor_runtime_error = \
"""

    let ffi_return = FFI_%(class_name_camel)s_checkValidContents(self.contents, UInt32(self.contents.count))

    if (ffi_return == Native.FFI_RETURN_INPUT_ERROR) {
      throw ZkGroupException.IllegalArgument
    }

    if (ffi_return != Native.FFI_RETURN_OK) {
      throw ZkGroupException.ZkGroupError
    }"""
97
# Static factory that calls FFI_<jni_method_name> into a <return_name>.SIZE
# output buffer and wraps the result. The result's constructor throws
# ZkGroupException.InvalidInput when its contents fail validation. The FFI has
# just produced those bytes, so that case is re-thrown as an internal
# AssertionError. (The previous `catch ZkGroupException.Invalid` named a case
# the constructor never throws.)
template_static_method = \
"""
  %(access)s static func %(method_name)s(%(param_decls)s) %(exception_decl)s -> %(return_name)s {
    var newContents: [UInt8] = Array(repeating: 0, count: %(return_name)s.SIZE)%(get_rand)s

    let ffi_return = FFI_%(jni_method_name)s(%(param_args)s&newContents, UInt32(newContents.count))%(exception_check)s

    if (ffi_return != Native.FFI_RETURN_OK) {
      throw ZkGroupException.ZkGroupError
    }

    do {
      return try %(return_name)s(contents: newContents)
    } catch ZkGroupException.InvalidInput {
      throw ZkGroupException.AssertionError
    }
  }
"""
116
# Static factory variant used when the return type is flagged
# runtime_error_on_serialize. It matches template_static_method except that it
# catches IllegalArgument, which is what that class's validation snippet
# throws, and always emits `public` (the access key is not referenced).
template_static_method_retval_runtime_error_on_serialize = \
"""
  public static func %(method_name)s(%(param_decls)s) %(exception_decl)s -> %(return_name)s {
    var newContents: [UInt8] = Array(repeating: 0, count: %(return_name)s.SIZE)%(get_rand)s

    let ffi_return = FFI_%(jni_method_name)s(%(param_args)s&newContents, UInt32(newContents.count))%(exception_check)s

    if (ffi_return != Native.FFI_RETURN_OK) {
      throw ZkGroupException.ZkGroupError
    }

    do {
      return try %(return_name)s(contents: newContents)
    } catch ZkGroupException.IllegalArgument {
      throw ZkGroupException.AssertionError
    }
  }
"""
135
# Convenience overload emitted for a static *Deterministic method. It draws
# 32 bytes with SecRandomCopyBytes and forwards them (labelled `randomness:`)
# to the deterministic overload named full_method_name. An RNG failure throws
# AssertionError. The access and exception_check keys print_class passes are
# not referenced.
template_static_method_rand_wrapper = \
"""
  public static func %(method_name)s(%(param_decls)s) %(exception_decl)s -> %(return_name)s {
    var randomness: [UInt8] = Array(repeating: 0, count: Int(32))
    let result = SecRandomCopyBytes(kSecRandomDefault, randomness.count, &randomness)
    guard result == errSecSuccess else {
      throw ZkGroupException.AssertionError
    }

    return try %(full_method_name)s(%(param_args)s)
  }
"""
148
# Instance method returning another generated class. It allocates return_len
# bytes and calls FFI_<jni_method_name> with this object's contents, the
# parameters and the output buffer, then wraps the output. An InvalidInput from
# the result's constructor is re-thrown as an internal AssertionError.
template_method = \
"""
  public func %(method_name)s(%(param_decls)s) %(exception_decl)s -> %(return_name)s {
    var newContents: [UInt8] = Array(repeating: 0, count: %(return_len)s)

    let ffi_return = FFI_%(jni_method_name)s(%(contents)s, %(param_args)s&newContents, UInt32(newContents.count))%(exception_check)s

    if (ffi_return != Native.FFI_RETURN_OK) {
      throw ZkGroupException.ZkGroupError
    }

    do {
      return try %(return_name)s(contents: newContents)
    } catch ZkGroupException.InvalidInput {
      throw ZkGroupException.AssertionError
    }

  }
"""
168
# Instance method variant used when the return type is flagged
# runtime_error_on_serialize. Errors from the result's constructor propagate
# unchanged. Unlike template_method it also splices %(get_rand)s in after the
# buffer allocation, so get_rand must be valid Swift (or "").
template_method_retval_runtime_error_on_serialize = \
"""
  public func %(method_name)s(%(param_decls)s) %(exception_decl)s -> %(return_name)s {
    var newContents: [UInt8] = Array(repeating: 0, count: Int(%(return_len)s))%(get_rand)s

    let ffi_return = FFI_%(jni_method_name)s(%(contents)s, %(param_args)s&newContents, UInt32(newContents.count))%(exception_check)s

    if (ffi_return != Native.FFI_RETURN_OK) {
      throw ZkGroupException.ZkGroupError
    }

    return try %(return_name)s(contents: newContents)
  }
"""
183
# Instance method for FFI calls whose only result is success/failure (used for
# spec return type "boolean"). The generated Swift returns nothing and throws
# on any non-OK code.
# NOTE(review): print_class passes param_args without a trailing comma, so a
# method with no parameters would render "..., UInt32(...count), )". That
# trailing comma is rejected by older Swift compilers; confirm no such method
# exists.
template_method_bool = \
"""
  public func %(method_name)s(%(param_decls)s) %(exception_decl)s{
    let ffi_return = FFI_%(jni_method_name)s(%(contents)s, %(param_args)s)%(exception_check)s

    if (ffi_return != Native.FFI_RETURN_OK) {
      throw ZkGroupException.ZkGroupError
    }
  }
"""
194
# Instance method that decodes its FFI output with UUIDUtil.deserialize.
# NOTE(review): print_class does not reference this template. UUID return
# types are rendered with template_method and return_name "ZKGUuid" instead,
# so this may be dead code.
template_method_uuid = \
"""
  public func %(method_name)s(%(param_decls)s) %(exception_decl)s -> %(return_name)s {
    var newContents: [UInt8] = Array(repeating: 0, count: Int(%(return_len)s))

    let ffi_return = FFI_%(jni_method_name)s(%(contents)s, %(param_args)s&newContents, UInt32(newContents.count))%(exception_check)s

    if (ffi_return != Native.FFI_RETURN_OK) {
      throw ZkGroupException.ZkGroupError
    }

    return UUIDUtil.deserialize(newContents)
  }
"""
209
# Instance method returning raw bytes. return_len is a Swift expression that
# print_class derives from the first parameter's length plus
# return_size_increment.
template_method_bytearray = \
"""
  public func %(method_name)s(%(param_decls)s) %(exception_decl)s -> [UInt8] {
    var newContents: [UInt8] = Array(repeating: 0, count: Int(%(return_len)s))

    let ffi_return = FFI_%(jni_method_name)s(%(contents)s, %(param_args)s&newContents, UInt32(newContents.count))%(exception_check)s

    if (ffi_return != Native.FFI_RETURN_OK) {
      throw ZkGroupException.ZkGroupError
    }

    return newContents
  }
"""
224
# Instance method whose FFI output is a 4-byte big-endian unsigned integer.
# The bytes are folded most-significant first. This replaces the deprecated
# Data(bytes:) and the typed-pointer withUnsafeBytes { $0.pointee } read, which
# is deprecated in Swift 5 and assumes the buffer is suitably aligned.
# param_args must end with ", " when non-empty, because the output buffer
# follows it.
template_method_int = \
"""
  public func %(method_name)s(%(param_decls)s) %(exception_decl)s -> UInt32 {
    var newContents: [UInt8] = Array(repeating: 0, count: Int(4))

    let ffi_return = FFI_%(jni_method_name)s(%(contents)s, %(param_args)s&newContents, UInt32(newContents.count))%(exception_check)s

    if (ffi_return != Native.FFI_RETURN_OK) {
      throw ZkGroupException.ZkGroupError
    }

    // The FFI writes the value as big-endian bytes.
    return newContents.reduce(UInt32(0)) { ($0 << 8) | UInt32($1) }
  }
"""
241
# Instance convenience overload emitted for a *Deterministic method. It draws
# 32 bytes with SecRandomCopyBytes and forwards them (labelled `randomness:`)
# to the deterministic overload named full_method_name. An RNG failure throws
# AssertionError.
template_method_rand_wrapper = \
"""
  public func %(method_name)s(%(param_decls)s) %(exception_decl)s -> %(return_name)s {
    var randomness: [UInt8] = Array(repeating: 0, count: Int(32))
    let result = SecRandomCopyBytes(kSecRandomDefault, randomness.count, &randomness)
    guard result == errSecSuccess else {
      throw ZkGroupException.AssertionError
    }

    return try %(full_method_name)s(%(param_args)s)
  }
"""
254
def add_import(import_strings, class_dir_dict, my_dir_name, class_name):
    """Hook for emitting an import of another generated class.

    The directory of `class_name` is looked up in `class_dir_dict`, so an
    unknown class raises KeyError. Nothing is ever appended to
    `import_strings`, because the Swift output needs no per-class imports.
    Always returns None.
    """
    target_dir = class_dir_dict[class_name.snake()].snake()
    own_dir = my_dir_name.snake()
    if len(target_dir) == 0 and len(own_dir) == 0:
        return
    if target_dir == own_dir:
        return
    # Disabled Java-era emission, kept for reference:
    #   if target_dir:
    #       import_strings.append("import org.signal.zkgroup.%s.%s;" % (target_dir, class_name.camel()))
    #   else:
    #       import_strings.append("import org.signal.zkgroup.%s;" % (class_name.camel()))
267
def get_decls(params, import_strings, class_dir_dict, my_dir_name):
    """Build the Swift parameter list ("label: Type, ...") for a generated method.

    Each param is a (type_tag, name) pair. A parameter whose snake-case name is
    "randomness" is always declared `randomness: [UInt8]`. Otherwise the type
    tag is mapped as follows:
      - "class" becomes the parameter's CamelCase name;
      - "byte[]" becomes [UInt8];
      - "int" becomes UInt32;
      - "UUID" becomes ZKGUuid;
      - any other tag is used verbatim.

    import_strings, class_dir_dict and my_dir_name are accepted for signature
    compatibility; they are not used because Swift needs no per-type imports.
    """
    builtin_types = {"byte[]": "[UInt8]", "int": "UInt32", "UUID": "ZKGUuid"}
    declarations = []
    for param in params:
        type_tag, name = param[0], param[1]
        if name.snake() == "randomness":
            declarations.append("randomness: [UInt8]")
            continue
        swift_type = name.camel() if type_tag == "class" else builtin_types.get(type_tag, type_tag)
        declarations.append(name.lower_camel() + ": " + swift_type)
    return ", ".join(declarations)
291
def get_rand_wrapper_decls(params):
    """Build the Swift parameter list for a random-wrapper overload.

    This is the same mapping as get_decls, except that the "randomness"
    parameter is omitted because the wrapper generates it itself. The mapping
    is:
      - "class" becomes CamelCase name;
      - "byte[]" becomes [UInt8];
      - "int" becomes UInt32;
      - "UUID" becomes ZKGUuid;
      - any other tag is used verbatim.
    """
    def swift_type_of(type_tag, name):
        if type_tag == "class":
            return name.camel()
        return {"byte[]": "[UInt8]", "int": "UInt32", "UUID": "ZKGUuid"}.get(type_tag, type_tag)

    return ", ".join(
        p[1].lower_camel() + ": " + swift_type_of(p[0], p[1])
        for p in params
        if p[1].snake() != "randomness"
    )
309
310
def get_args(params, import_strings, commaAtEnd):
    """Build the argument list passed to an FFI_* call for the given params.

    Each argument is rendered as follows:
      - "int" params are passed by name only.
      - "byte[]" params and the randomness param are passed as the array
        followed by UInt32(<array>.count).
      - Other params (generated classes) pass getInternalContentsForFFI() and
        its count.

    When commaAtEnd is true and the list is non-empty, a trailing ", " is kept
    so the caller can append the output buffer. import_strings is unused.
    """
    rendered = []
    for param in params:
        type_tag, name = param[0], param[1]
        if type_tag in ("byte[]", "int"):
            term = name.lower_camel()
        elif name.snake() == "randomness":
            term = "randomness"
        else:
            term = name.lower_camel() + ".getInternalContentsForFFI()"
        if type_tag == "int":
            rendered.append(term)
        else:
            rendered.append(term + ", UInt32(" + term + ".count)")

    joined = ", ".join(rendered)
    if rendered and commaAtEnd:
        joined += ", "
    return joined
331
def get_jni_arg_decls(params, selfBool, commaAtEndBool):
    """Build a Java/JNI-style argument declaration list, e.g. "byte[] a, int b".

    "int" params are declared `int`; every other param is declared `byte[]`.
    selfBool prepends "byte[] self". commaAtEndBool appends "byte[] output".
    NOTE(review): not called by the Swift generator in the visible code.
    """
    declarations = ["byte[] self"] if selfBool else []
    for param in params:
        java_type = "int" if param[0] == "int" else "byte[]"
        declarations.append("%s %s" % (java_type, param[1].lower_camel()))
    if commaAtEndBool:
        declarations.append("byte[] output")
    return ", ".join(declarations)
357
def get_rand_wrapper_args(params, commaAtEnd):
    """Build labelled forwarding arguments ("x: x, y: y") for a wrapper call.

    Every param, including randomness, is forwarded under its own lowerCamel
    name. When commaAtEnd is true and the list is non-empty, a trailing ", "
    is kept.
    """
    labelled = []
    for param in params:
        label = param[1].lower_camel()
        labelled.append(label + ": " + label)
    result = ", ".join(labelled)
    return result + ", " if (labelled and commaAtEnd) else result
368
def _ffi_exception_check(method):
    """Return the Swift snippet that maps an FFI error code to ZkGroupException.

    Methods flagged runtime_error get no extra check (""). Otherwise a
    verification method maps FFI_RETURN_VERIFICATION_FAILED, and any other
    method maps FFI_RETURN_INPUT_ERROR. Both throw
    ZkGroupException.VerificationFailed.
    """
    # Compared with `== False` (not `not ...`) to keep the original handling
    # of non-bool flag values.
    if method.runtime_error == False:
        if method.verification == False:
            return ("\n    if (ffi_return == Native.FFI_RETURN_INPUT_ERROR) {\n"
                    "      throw ZkGroupException.VerificationFailed\n"
                    "    }")
        return ("\n    if (ffi_return == Native.FFI_RETURN_VERIFICATION_FAILED) {\n"
                "      throw ZkGroupException.VerificationFailed\n"
                "    }")
    return ""


def print_class(c, runtime_error_on_serialize_dict, class_dir_dict):
    """Render the Swift source file for one class description.

    Args:
        c: class description. The attributes read here are class_name,
            dir_name, wrap_class, methods, static_methods,
            check_valid_contents, runtime_error_on_serialize, no_serialize,
            class_len and class_len_int.
        runtime_error_on_serialize_dict: maps a class's snake-case name to its
            runtime_error_on_serialize flag. It selects the template used for
            methods that return that class.
        class_dir_dict: maps a class's snake-case name to its dir_name. It is
            forwarded to add_import and get_decls.

    Returns:
        The complete text of the generated .swift file.

    Raises:
        KeyError: if a method's return class is missing from
            runtime_error_on_serialize_dict or class_dir_dict.
    """
    # Every generated Swift signature uses a plain `throws`; Swift has no
    # Java-style "throws SomeError" list.
    exception_decl = "throws "

    static_methods_string = ""
    if len(c.methods) == 0 and len(c.static_methods) == 0:
        import_strings = []
    else:
        # Only classes that actually call into the FFI need these modules.
        import_strings = [
            "import libzkgroup",
            "import Foundation",
        ]

    my_dir_name = c.dir_name

    # The FFI "self" argument is either our own bytes or the wrapped object's.
    if c.wrap_class is None:
        contents = "self.contents"
    else:
        contents = c.wrap_class.lower_camel() + ".getInternalContentsForFFI()"
        add_import(import_strings, class_dir_dict, my_dir_name, c.wrap_class)
    contents += ", UInt32(%s.count)" % contents

    # Wrapping classes call the wrapped class's FFI entry points.
    ffi_owner = c.class_name if c.wrap_class is None else c.wrap_class

    for method in c.static_methods:
        is_deterministic = method.method_name.snake().endswith("_deterministic")

        exception_check = ""
        if len(method.params) > 1 or (len(method.params) == 1 and not is_deterministic):
            exception_check = _ffi_exception_check(method)

        access = "public"
        method_name = method.method_name.lower_camel()
        get_rand = ""
        if is_deterministic:
            # Emit an overload without `randomness:` that generates it and
            # forwards to the deterministic variant rendered below under the
            # same base name (Swift overloads on argument labels).
            method_name = method_name[:-len("Deterministic")]
            static_methods_string += template_static_method_rand_wrapper % {
                "method_name": method_name,
                "return_name": method.return_name.camel(),
                "full_method_name": method_name,
                "param_decls": get_rand_wrapper_decls(method.params),
                "param_args": get_rand_wrapper_args(method.params, False),
                "access": access,
                "exception_decl": exception_decl,
                "exception_check": exception_check,
            }

        param_args = get_args(method.params, import_strings, True)
        jni_method_name = ffi_owner.camel() + "_" + method.method_name.lower_camel()
        if runtime_error_on_serialize_dict[method.return_name.snake()]:
            template = template_static_method_retval_runtime_error_on_serialize
        else:
            template = template_static_method
        static_methods_string += template % {
            "method_name": method_name,
            "return_name": method.return_name.camel(),
            "return_len": c.class_len,
            "param_decls": get_decls(method.params, import_strings, class_dir_dict, my_dir_name),
            "param_args": param_args,
            "jni_method_name": jni_method_name,
            "access": access,
            "exception_decl": exception_decl,
            "exception_check": exception_check,
            "get_rand": get_rand,
        }

    methods_string = ""
    for method in c.methods:

        # Validation is emitted inside the constructor, not as a method.
        if method.method_name.snake() == "check_valid_contents":
            continue

        exception_check = ""
        if len(method.params) != 0:
            exception_check = _ffi_exception_check(method)

        access = "public"
        method_name = method.method_name.lower_camel()
        # The deterministic variant receives `randomness` as a parameter, so
        # nothing is spliced in here. (This used to be a Java snippet that
        # produced uncompilable Swift.)
        get_rand = ""
        if method.method_name.snake().endswith("_deterministic"):
            method_name = method_name[:-len("Deterministic")]
            methods_string += template_method_rand_wrapper % {
                "contents": contents,
                "method_name": method_name,
                "return_name": method.return_name.camel(),
                "full_method_name": method_name,
                "param_decls": get_rand_wrapper_decls(method.params),
                "param_args": get_rand_wrapper_args(method.params, False),
                "access": access,
                "exception_decl": exception_decl,
                "exception_check": exception_check,
            }

        jni_method_name = ffi_owner.camel() + "_" + method.method_name.lower_camel()

        return_name = None
        return_len = None

        if method.return_type == "UUID":
            return_name = "ZKGUuid"
            return_len = "ZKGUuid.SIZE"

        if method.return_type == "boolean":
            # No output buffer follows the arguments, so no trailing comma.
            template = template_method_bool
            param_args = get_args(method.params, import_strings, False)
        elif method.return_type == "int":
            # The output buffer follows the arguments, so keep the trailing ", ".
            template = template_method_int
            param_args = get_args(method.params, import_strings, True)
        elif method.return_type == "byte[]":
            template = template_method_bytearray
            param_args = get_args(method.params, import_strings, True)
            # The output length is derived from the first parameter's length.
            increment = method.return_size_increment
            return_len = method.params[0][1].lower_camel()
            if increment >= 0:
                return_len += ".count+%d" % increment
            else:
                # Swift lexes "+-" as a single (undefined) operator, so a
                # negative increment must be emitted as a subtraction.
                return_len += ".count-%d" % -increment
        else:
            add_import(import_strings, class_dir_dict, my_dir_name, method.return_name)
            if runtime_error_on_serialize_dict[method.return_name.snake()]:
                template = template_method_retval_runtime_error_on_serialize
            else:
                template = template_method
            param_args = get_args(method.params, import_strings, True)
        if return_name is None:
            return_name = method.return_name.camel()
        if return_len is None:
            return_len = method.return_name.camel() + ".SIZE"

        methods_string += template % {
            "contents": contents,
            "method_name": method_name,
            "return_name": return_name,
            "return_len": return_len,
            "param_decls": get_decls(method.params, import_strings, class_dir_dict, my_dir_name),
            "param_args": param_args,
            "jni_method_name": jni_method_name,
            "access": access,
            "exception_decl": exception_decl,
            "exception_check": exception_check,
            "get_rand": get_rand,
        }

    if c.dir_name.snake() != "":
        dir_section = "." + c.dir_name.snake()
    else:
        dir_section = ""

    constructor_exception_decl = "throws "
    runtime_error_bool = ""
    if c.check_valid_contents:
        if c.runtime_error_on_serialize:
            runtime_error_bool = ", unrecoverable: true"
            check_template = template_check_valid_contents_constructor_runtime_error
        else:
            check_template = template_check_valid_contents_constructor
        check_valid_contents = check_template % {
            "class_name_camel": c.class_name.camel(),
        }
    else:
        check_valid_contents = ""

    if c.no_serialize:
        constructor_access = "private"
    else:
        constructor_access = "public"

    import_strings = sorted(set(import_strings))

    # constructors
    constructor_contents = "contents"
    constructor_contents_type = "[UInt8]"
    constructors_string = template_constructor % {
        "class_name": c.class_name.camel(),
        "constructor_contents": constructor_contents,
        "constructor_contents_type": constructor_contents_type,
        "constructor_access": constructor_access,
        "constructor_exception_decl": constructor_exception_decl,
        "runtime_error_bool": runtime_error_bool,
        "check_valid_contents": check_valid_contents,
    }

    # String-backed classes are disabled: the original condition
    # (c.string_contents == False) is commented out, so every class gets the
    # binary serialize(). The else branch is kept as valid Swift for
    # re-enabling.
    # if c.string_contents == False:
    if True:
        serialize_method = serialize_method_binary
    else:
        constructor_contents = "Array(contents.utf8)"
        constructor_contents_type = "String"
        serialize_method = serialize_method_string
        constructors_string += template_constructor_for_string_contents % {
            "class_name": c.class_name.camel(),
            "constructor_contents": constructor_contents,
            "constructor_contents_type": constructor_contents_type,
            "constructor_access": constructor_access,
            "constructor_exception_decl": constructor_exception_decl,
            "runtime_error_bool": runtime_error_bool,
            "check_valid_contents": check_valid_contents,
        }
    # Drop the template's final newline; template_class supplies its own.
    constructors_string = constructors_string[:-1]

    fields = {
        "imports": "\n".join(import_strings),
        "dir_section": dir_section,
        "class_name": c.class_name.camel(),
        "size": c.class_len_int,
        "constructors": constructors_string,
        "static_methods": static_methods_string,
        "methods": methods_string,
        "serialize_method": serialize_method,
    }
    if c.wrap_class is not None:
        fields["wrapped_class_type"] = c.wrap_class.camel()
        fields["wrapped_class_var"] = c.wrap_class.lower_camel()
        return template_wrapping_class % fields
    return template_class % fields
632
def produce_output(classes):
    """Write one swift/<ClassName>.swift file for each class description.

    Two lookup tables are first built over *all* classes, including those
    flagged no_class, because print_class consults them for classes that
    other methods take or return. Classes flagged no_class get no file.
    The "swift/" output directory must already exist.
    """
    runtime_error_on_serialize_dict = {}
    class_dir_dict = {}
    for c in classes:
        key = c.class_name.snake()
        runtime_error_on_serialize_dict[key] = c.runtime_error_on_serialize
        class_dir_dict[key] = c.dir_name

    for c in classes:
        if c.no_class:
            continue
        # Render before opening so a generation error cannot leave a truncated
        # or empty file behind; `with` always closes the handle.
        source = print_class(c, runtime_error_on_serialize_dict, class_dir_dict)
        with open("swift/%s.swift" % c.class_name.camel(), "w") as f:
            f.write(source)
647