1#!/usr/local/bin/python3.8
2#
3# Copyright (c) 2017 Nuxi (https://nuxi.nl/) and contributors.
4#
5# Redistribution and use in source and binary forms, with or without
6# modification, are permitted provided that the following conditions
7# are met:
8# 1. Redistributions of source code must retain the above copyright
9#    notice, this list of conditions and the following disclaimer.
10# 2. Redistributions in binary form must reproduce the above copyright
11#    notice, this list of conditions and the following disclaimer in the
12#    documentation and/or other materials provided with the distribution.
13#
14# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
15# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17# ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
18# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
20# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
21# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
22# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
23# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
24# SUCH DAMAGE.
25
26import hashlib
27import pypeg2
28import re
29import sys
30import toposort
31
32
# Field names that cannot be used verbatim as C++ identifiers in generated
# code. MessageFieldDeclaration.get_name(True) appends a trailing underscore
# to these names for accessors and members; the key string written to and
# read from argdata (get_name(False)) is never altered.
# This covers the reserved keywords and alternative operator tokens of C++
# up to C++20. Contextual identifiers such as `final`, `override`, `import`
# and `module` remain valid identifiers and are intentionally omitted.
FORBIDDEN_WORDS = {
    'alignas', 'alignof', 'and', 'and_eq', 'asm', 'auto', 'bitand', 'bitor',
    'bool', 'break', 'case', 'catch', 'char', 'char8_t', 'char16_t',
    'char32_t', 'class', 'co_await', 'co_return', 'co_yield', 'compl',
    'concept', 'const', 'const_cast', 'consteval', 'constexpr', 'constinit',
    'continue', 'decltype', 'default', 'delete', 'do', 'double',
    'dynamic_cast', 'else', 'enum', 'explicit', 'export', 'extern', 'false',
    'float', 'for', 'friend', 'goto', 'if', 'inline', 'int', 'long',
    'mutable', 'namespace', 'new', 'noexcept', 'not', 'not_eq', 'nullptr',
    'operator', 'or', 'or_eq', 'private', 'protected', 'public', 'register',
    'reinterpret_cast', 'requires', 'return', 'short', 'signed', 'sizeof',
    'static', 'static_assert', 'static_cast', 'struct', 'switch', 'template',
    'this', 'thread_local', 'throw', 'true', 'try', 'typedef', 'typeid',
    'typename', 'union', 'unsigned', 'using', 'virtual', 'void', 'volatile',
    'wchar_t', 'while', 'xor', 'xor_eq',
}
35
36
class ScalarType:
    """Common base for field types that hold a single plain value.

    Subclasses supply get_storage_type(); this base derives the private
    member declaration from it.
    """

    def get_dependencies(self):
        """Scalars never refer to user-declared types."""
        return set()

    def print_fields(self, name, declarations):
        """Emit the private data member `<name>_` backing the field."""
        storage = self.get_storage_type(declarations)
        print('  {} {}_;'.format(storage, name))
44
45
class NumericType(ScalarType):
    """Base for types stored by value in a member (integers, floats, bools).

    Subclasses provide get_default_value() and get_storage_type(). A field
    counts as set whenever it differs from its default value.
    """

    def get_initializer(self, name, declarations):
        """Return the constructor mem-initializer `<name>_(<default>)`."""
        return '{}_({})'.format(name, self.get_default_value())

    def get_isset_expression(self, name, declarations):
        """Return a C++ condition that holds when the field is non-default."""
        return '{}_ != {}'.format(name, self.get_default_value())

    def print_accessors(self, name, declarations):
        """Emit the getter, setter and clear_ method of a singular field."""
        ctype = self.get_storage_type(declarations)
        print('  {0} {1}() const {{ return {1}_; }}'.format(ctype, name))
        ctype = self.get_storage_type(declarations)
        print('  void set_{1}({0} value) {{ {1}_ = value; }}'.format(ctype, name))
        default = self.get_default_value()
        print('  void clear_{0}() {{ {0}_ = {1}; }}'.format(name, default))

    def print_accessors_repeated(self, name, declarations):
        """Emit indexed getter/setter and add_ for a repeated field."""
        ctype = self.get_storage_type(declarations)
        print('  {0} {1}(std::size_t index) const {{ return {1}_[index]; }}'.format(ctype, name))
        ctype = self.get_storage_type(declarations)
        print('  void set_{1}(std::size_t index, {0} value) {{ {1}_[index] = value; }}'.format(ctype, name))
        ctype = self.get_storage_type(declarations)
        print('  void add_{1}({0} value) {{ {1}_.push_back(value); }}'.format(ctype, name))
63
64
class IntegerType(NumericType):
    """Fixed-width integer field type, stored as `std::<name>_t`.

    Values are encoded with argdata_builder->BuildInt() and decoded with
    argdata_get_int(). The MapType grammar also accepts integers as map keys.
    """

    def __init__(self, name):
        # Width/signedness prefix such as 'int32' or 'uint64'.
        self._name = name

    def get_default_value(self):
        return '0'

    def get_storage_type(self, declarations):
        return 'std::%s_t' % self._name

    def print_building(self, name, declarations):
        print('      values.push_back(argdata_builder->BuildInt(%s_));' % name)

    def print_building_map_key(self):
        # MapType.print_building() calls this on the key type. Without it,
        # map<intNN, ...> fields made code generation fail.
        print('        mapkeys.push_back(argdata_builder->BuildInt(mapentry.first));')

    def print_building_map_value(self, declarations):
        print('        mapvalues.push_back(argdata_builder->BuildInt(mapentry.second));')

    def print_building_repeated(self, declarations):
        print('        elements.push_back(argdata_builder->BuildInt(element));')

    def print_parsing(self, name, declarations):
        print('          argdata_get_int(value, &%s_);' % name)

    def print_parsing_map_key(self):
        # Opens a block that MapType.print_parsing() closes once the value
        # has been handled.
        print('            std::%s_t mapkey;' % self._name)
        print('            if (argdata_get_int(key2, &mapkey) == 0) {')

    def print_parsing_map_value(self, name, declarations):
        print('              std::%s_t value2int;' % self._name)
        print('              if (argdata_get_int(value2, &value2int) == 0)')
        print('                %s_.emplace(mapkey, 0).first->second = value2int;' % name)

    def print_parsing_repeated(self, name, declarations):
        print('            std::%s_t elementint;' % self._name)
        print('            if (argdata_get_int(element, &elementint) == 0)')
        print('              %s_.push_back(elementint);' % name)
98
99
class Int32Type(IntegerType):
    """The int32, sint32 and sfixed32 keywords, all stored as std::int32_t."""

    grammar = ['int32', 'sint32', 'sfixed32']

    def __init__(self):
        super().__init__('int32')
106
107
class UInt32Type(IntegerType):
    """The uint32 and fixed32 keywords, both stored as std::uint32_t."""

    grammar = ['uint32', 'fixed32']

    def __init__(self):
        super().__init__('uint32')
114
115
class Int64Type(IntegerType):
    """The int64, sint64 and sfixed64 keywords, all stored as std::int64_t."""

    grammar = ['int64', 'sint64', 'sfixed64']

    def __init__(self):
        super().__init__('int64')
122
123
class UInt64Type(IntegerType):
    """The uint64 and fixed64 keywords, both stored as std::uint64_t."""

    grammar = ['uint64', 'fixed64']

    def __init__(self):
        super().__init__('uint64')
130
131
class FloatingPointType(NumericType):
    """The `double` and `float` keywords, both stored as a C++ double.

    Values are decoded with argdata_get_float(). Encoding assumes
    arpc::ArgdataBuilder provides BuildFloat(double).
    """

    grammar = ['double', 'float']

    def get_default_value(self):
        return '0.0'

    def get_storage_type(self, declarations):
        return 'double'

    def print_building(self, name, declarations):
        # MessageDeclaration.print_code() calls print_building() for every
        # field. Without this method, any double/float field made code
        # generation fail with AttributeError.
        print('      values.push_back(argdata_builder->BuildFloat(%s_));' % name)

    def print_building_map_value(self, declarations):
        print('        mapvalues.push_back(argdata_builder->BuildFloat(mapentry.second));')

    def print_building_repeated(self, declarations):
        print('        elements.push_back(argdata_builder->BuildFloat(element));')

    def print_parsing(self, name, declarations):
        print('          argdata_get_float(value, &%s_);' % name)

    def print_parsing_map_value(self, name, declarations):
        print('              double value2float;')
        print('              if (argdata_get_float(value2, &value2float) == 0)')
        print('                %s_.emplace(mapkey, 0.0).first->second = value2float;' % name)

    def print_parsing_repeated(self, name, declarations):
        print('            double elementfloat;')
        print('            if (argdata_get_float(element, &elementfloat) == 0)')
        print('              %s_.push_back(elementfloat);' % name)
144
145
class BooleanType(NumericType):
    """The `bool` keyword, stored as a C++ bool.

    Booleans are encoded as the shared &argdata_true / &argdata_false values
    and decoded with argdata_get_bool(). The MapType grammar also accepts
    bool as a map key.
    """

    grammar = ['bool']

    def get_default_value(self):
        return 'false'

    def get_storage_type(self, declarations):
        return 'bool'

    def print_building(self, name, declarations):
        # Only emitted under the isset check (`<name>_ != false`), so the
        # value is known to be true at this point.
        print('      values.push_back(&argdata_true);')

    def print_building_map_key(self):
        # MapType.print_building() calls this on the key type.
        print('        mapkeys.push_back(mapentry.first ? &argdata_true : &argdata_false);')

    def print_building_map_value(self, declarations):
        print('        mapvalues.push_back(mapentry.second ? &argdata_true : &argdata_false);')

    def print_building_repeated(self, declarations):
        print('        elements.push_back(element ? &argdata_true : &argdata_false);')

    def print_parsing(self, name, declarations):
        print('          argdata_get_bool(value, &%s_);' % name)

    def print_parsing_map_key(self):
        # Opens a block that MapType.print_parsing() closes.
        print('            bool mapkey;')
        print('            if (argdata_get_bool(key2, &mapkey) == 0) {')

    def print_parsing_map_value(self, name, declarations):
        print('              bool value2bool;')
        print('              if (argdata_get_bool(value2, &value2bool) == 0)')
        print('                %s_.emplace(mapkey, false).first->second = value2bool;' % name)

    def print_parsing_repeated(self, name, declarations):
        print('            bool elementbool;')
        print('            if (argdata_get_bool(element, &elementbool) == 0)')
        print('              %s_.push_back(elementbool);' % name)
161
162
class StringlikeType(ScalarType):
    """Shared behaviour of fields backed by a std::string (string, bytes).

    Such a field counts as set when non-empty. The std::string default
    constructor is relied upon, so no mem-initializer is emitted.
    """

    def get_initializer(self, name, declarations):
        return ''

    def get_isset_expression(self, name, declarations):
        return '!{}_.empty()'.format(name)

    def get_storage_type(self, declarations):
        return 'std::string'

    def print_accessors(self, name, declarations):
        """Emit getter, setter, mutable_ pointer accessor and clear_."""
        templates = (
            '  const std::string& {0}() const {{ return {0}_; }}',
            '  void set_{0}(std::string_view value) {{ {0}_ = value; }}',
            '  std::string* mutable_{0}() {{ return &{0}_; }}',
            '  void clear_{0}() {{ {0}_.clear(); }}',
        )
        for template in templates:
            print(template.format(name))

    def print_accessors_repeated(self, name, declarations):
        """Emit indexed getter/setter/mutable_ plus both add_ overloads."""
        templates = (
            '  const std::string& {0}(std::size_t index) const {{ return {0}_[index]; }}',
            '  void set_{0}(std::size_t index, std::string_view value) {{ {0}_[index] = value; }}',
            '  std::string* mutable_{0}(std::size_t index) {{ return &{0}_[index]; }}',
            '  void add_{0}(std::string_view value) {{ {0}_.emplace_back(value); }}',
            '  std::string* add_{0}() {{ return &{0}_.emplace_back(); }}',
        )
        for template in templates:
            print(template.format(name))
186
187
class StringType(StringlikeType):
    """The `string` keyword, encoded as argdata strings (BuildStr / get_str)."""

    grammar = ['string']

    def print_building(self, name, declarations):
        print('      values.push_back(argdata_builder->BuildStr({}_));'.format(name))

    def print_building_map_key(self):
        print('        mapkeys.push_back(argdata_builder->BuildStr(mapentry.first));')

    def print_building_map_value(self, declarations):
        print('        mapvalues.push_back(argdata_builder->BuildStr(mapentry.second));')

    def print_building_repeated(self, declarations):
        print('        elements.push_back(argdata_builder->BuildStr(element));')

    def print_parsing(self, name, declarations):
        print('\n'.join((
            '          const char* valuestr;',
            '          std::size_t valuelen;',
            '          if (argdata_get_str(value, &valuestr, &valuelen) == 0)',
            '            {}_ = std::string_view(valuestr, valuelen);'.format(name),
        )))

    def print_parsing_map_key(self):
        # The trailing `{` is closed by MapType.print_parsing().
        print('\n'.join((
            '            const char* key2str;',
            '            std::size_t key2len;',
            '            if (argdata_get_str(key2, &key2str, &key2len) == 0) {',
            '              std::string_view mapkey(key2str, key2len);',
        )))

    def print_parsing_map_value(self, name, declarations):
        print('\n'.join((
            '              const char* value2str;',
            '              std::size_t value2len;',
            '              if (argdata_get_str(value2, &value2str, &value2len) == 0)',
            '                {}_.emplace(mapkey, std::string()).first->second = std::string_view(value2str, value2len);'.format(name),
        )))

    def print_parsing_repeated(self, name, declarations):
        print('\n'.join((
            '            const char* elementstr;',
            '            std::size_t elementlen;',
            '            if (argdata_get_str(element, &elementstr, &elementlen) == 0)',
            '              {}_.emplace_back(std::string_view(elementstr, elementlen));'.format(name),
        )))
227
228
class BytesType(StringlikeType):
    """The `bytes` keyword, decoded from argdata binary values.

    NOTE(review): unlike StringType, no print_building* or map/repeated
    parsing hooks are defined here or in StringlikeType. Emitting Build()
    for a message with a bytes field therefore has nothing to call —
    confirm whether a binary builder hook is still to be added.
    """

    grammar = ['bytes']

    def print_parsing(self, name, declarations):
        print('\n'.join((
            '          const void* valuestr;',
            '          std::size_t valuelen;',
            '          if (argdata_get_binary(value, &valuestr, &valuelen) == 0)',
            '            {}_ = std::string_view(static_cast<const char*>(valuestr), valuelen);'.format(name),
        )))
238
239
class FileDescriptorType:
    """The `fd` keyword, held as std::shared_ptr<arpc::FileDescriptor>.

    A field counts as set when its pointer is non-null. Descriptors are
    encoded with argdata_builder->BuildFd() and decoded with
    argdata_parser->ParseFileDescriptor().
    """

    grammar = ['fd']

    def get_dependencies(self):
        return set()

    def get_initializer(self, name, declarations):
        # A default-constructed std::shared_ptr is already null.
        return ''

    def get_isset_expression(self, name, declarations):
        return name + '_'

    def get_storage_type(self, declarations):
        return 'std::shared_ptr<arpc::FileDescriptor>'

    def print_accessors(self, name, declarations):
        print('  const std::shared_ptr<arpc::FileDescriptor>& %s() const { return %s_; }' % (name, name))
        print('  void set_%s(const std::shared_ptr<arpc::FileDescriptor>& value) { %s_ = value; }' % (name, name))
        print('  void clear_%s() { %s_.reset(); }' % (name, name))

    def print_accessors_repeated(self, name, declarations):
        print('  const std::shared_ptr<arpc::FileDescriptor>& %s(std::size_t index) const { return %s_[index]; }' % (name, name))
        print('  void set_%s(std::size_t index, const std::shared_ptr<arpc::FileDescriptor>& value) { %s_[index] = value; }' % (name, name))
        print('  void add_%s(const std::shared_ptr<arpc::FileDescriptor>& value) { %s_.push_back(value); }' % (name, name))

    def print_building(self, name, declarations):
        print('      values.push_back(argdata_builder->BuildFd(%s_));' % name)

    def print_building_map_value(self, declarations):
        # Called by MapType.print_building() for map<K, fd> fields.
        print('        mapvalues.push_back(argdata_builder->BuildFd(mapentry.second));')

    def print_building_repeated(self, declarations):
        # Called by RepeatedType.print_building() for `repeated fd` fields.
        print('        elements.push_back(argdata_builder->BuildFd(element));')

    def print_fields(self, name, declarations):
        print('  std::shared_ptr<arpc::FileDescriptor> %s_;' % name)

    def print_parsing(self, name, declarations):
        print('          std::shared_ptr<arpc::FileDescriptor> fd = argdata_parser->ParseFileDescriptor(*value);')
        print('          if (fd)')
        print('            %s_ = std::move(fd);' % name)

    def print_parsing_map_value(self, name, declarations):
        # The descriptor is the map entry's value (value2). key2 has already
        # been decoded into `mapkey` by the key type's print_parsing_map_key().
        print('              std::shared_ptr<arpc::FileDescriptor> fd = argdata_parser->ParseFileDescriptor(*value2);')
        print('              if (fd)')
        print('                %s_.emplace(mapkey, nullptr).first->second = std::move(fd);' % name)

    def print_parsing_repeated(self, name, declarations):
        print('            std::shared_ptr<arpc::FileDescriptor> fd = argdata_parser->ParseFileDescriptor(*element);')
        print('            if (fd)')
        print('              %s_.emplace_back(std::move(fd));' % name)
286
287
class AnyType:
    """`google.protobuf.Any` fields, kept as a raw `const argdata_t*`.

    A null pointer means "unset"; the getter substitutes &argdata_null.
    """

    grammar = ['google.protobuf.Any']

    def get_dependencies(self):
        return set()

    def get_initializer(self, name, declarations):
        return '{}_(nullptr)'.format(name)

    def get_isset_expression(self, name, declarations):
        return '{}_ != nullptr'.format(name)

    def print_accessors(self, name, declarations):
        templates = (
            '  bool has_{0}() const {{ return {0}_ != nullptr; }}',
            '  const argdata_t* {0}() const {{ return {0}_ == nullptr ? &argdata_null : {0}_; }}',
            '  void set_{0}(const argdata_t* value) {{ {0}_ = value; }}',
            '  void clear_{0}() {{ {0}_ = nullptr; }}',
        )
        for template in templates:
            print(template.format(name))

    def print_building(self, name, declarations):
        print('      values.push_back({}_);'.format(name))

    def print_fields(self, name, declarations):
        print('  const argdata_t* {}_;'.format(name))

    def print_parsing(self, name, declarations):
        # Uses the map iterator `it` declared by the generated Parse() loop.
        print('          {}_ = argdata_parser->ParseAnyFromMap(it);'.format(name))
315
316
class ReferenceType:
    """A field type naming a user-declared enum or message.

    Each code-generation hook looks the name up in `declarations` and
    forwards to that declaration. The declaration-level methods do not take
    `declarations` themselves.
    """

    grammar = pypeg2.word

    def __init__(self, name):
        self._name = name

    def get_dependencies(self):
        return {self._name}

    def get_initializer(self, name, declarations):
        return declarations[self._name].get_initializer(name)

    def get_isset_expression(self, name, declarations):
        return declarations[self._name].get_isset_expression(name)

    def get_storage_type(self, declarations):
        # The generated C++ type carries the declaration's own name.
        return self._name

    def is_stream(self):
        return False

    def print_accessors(self, name, declarations):
        declarations[self._name].print_accessors(name)

    def print_accessors_repeated(self, name, declarations):
        declarations[self._name].print_accessors_repeated(name)

    def print_building(self, name, declarations):
        declarations[self._name].print_building(name)

    def print_building_map_value(self, declarations):
        # MapType.print_building() calls this when a map's value type is a
        # named enum or message; previously this hook was not forwarded.
        declarations[self._name].print_building_map_value()

    def print_building_repeated(self, declarations):
        declarations[self._name].print_building_repeated()

    def print_fields(self, name, declarations):
        declarations[self._name].print_fields(name)

    def print_parsing(self, name, declarations):
        declarations[self._name].print_parsing(name)

    def print_parsing_map_value(self, name, declarations):
        declarations[self._name].print_parsing_map_value(name)

    def print_parsing_repeated(self, name, declarations):
        declarations[self._name].print_parsing_repeated(name)
362
363
# pypeg2 choice of single (non-map, non-repeated) field types. Alternatives
# are tried in order. ReferenceType matches any bare word, so it stays last
# and the built-in type keywords are recognised first.
PrimitiveType = [
    # Integers.
    Int32Type, UInt32Type, Int64Type, UInt64Type,
    # Other scalars.
    FloatingPointType, BooleanType, StringType, BytesType,
    # Special-purpose types.
    FileDescriptorType, AnyType,
    # Catch-all for user-declared enums and messages.
    ReferenceType,
]
377
378
class MapType:
    """`map<K, V>` fields, stored as std::map<K, V, std::less<>>.

    The grammar restricts keys to integer, bool and string types; values may
    be any PrimitiveType. The field is encoded as an argdata map.
    """

    grammar = 'map', '<', [
        Int32Type,
        UInt32Type,
        Int64Type,
        UInt64Type,
        BooleanType,
        StringType,
    ], ',', PrimitiveType, '>'

    def __init__(self, arguments):
        # arguments is [key_type, value_type].
        self._key_type = arguments[0]
        self._value_type = arguments[1]

    def get_dependencies(self):
        return (self._key_type.get_dependencies() |
                self._value_type.get_dependencies())

    def get_isset_expression(self, name, declarations):
        return '!%s_.empty()' % name

    def get_initializer(self, name, declarations):
        return ''

    def get_storage_type(self, declarations):
        # std::less<> is a transparent comparator; presumably chosen so that
        # string-keyed maps can be searched with std::string_view.
        return 'std::map<%s, %s, std::less<>>' % (self._key_type.get_storage_type(declarations),
                                                  self._value_type.get_storage_type(declarations))

    def print_accessors(self, name, declarations):
        """Emit _size(), const getter, mutable_ and clear_ accessors.

        _size() and clear_ mirror what RepeatedType emits for its container.
        """
        storage_type = self.get_storage_type(declarations)
        print('  std::size_t %s_size() const { return %s_.size(); }' % (name, name))
        print('  const %s& %s() const { return %s_; }' % (storage_type, name, name))
        print('  %s* mutable_%s() { return &%s_; }' % (storage_type, name, name))
        print('  void clear_%s() { %s_.clear(); }' % (name, name))

    def print_building(self, name, declarations):
        print('      std::vector<const argdata_t*> mapkeys;')
        print('      std::vector<const argdata_t*> mapvalues;')
        print('      for (const auto& mapentry : %s_) {' % name)
        self._key_type.print_building_map_key()
        self._value_type.print_building_map_value(declarations)
        print('      }')
        print('      values.push_back(argdata_builder->BuildMap(std::move(mapkeys), std::move(mapvalues)));')

    def print_fields(self, name, declarations):
        print('  %s %s_;' % (self.get_storage_type(declarations), name))

    def print_parsing(self, name, declarations):
        print('          argdata_map_iterator_t it2;')
        print('          argdata_map_iterate(value, &it2);')
        print('          const argdata_t* key2, *value2;')
        print('          while (argdata_map_get(&it2, &key2, &value2)) {')
        # The key type opens an `if (...) {` block that declares `mapkey`;
        # it is closed right after the value type's parsing code.
        self._key_type.print_parsing_map_key()
        self._value_type.print_parsing_map_value(name, declarations)
        print('            }')
        print('            argdata_map_next(&it2);')
        print('          }')
434
435
class RepeatedType:
    """`repeated T` fields, held as a std::vector of T's storage type.

    The element type supplies the per-element hooks (print_accessors_repeated,
    print_building_repeated, print_parsing_repeated); this class wraps them
    in the vector-level code.
    """

    grammar = 'repeated', PrimitiveType

    def __init__(self, type):
        self._type = type

    def get_dependencies(self):
        return self._type.get_dependencies()

    def get_initializer(self, name, declarations):
        # std::vector default-constructs empty.
        return ''

    def get_isset_expression(self, name, declarations):
        return '!{}_.empty()'.format(name)

    def get_storage_type(self, declarations):
        element_type = self._type.get_storage_type(declarations)
        return 'std::vector<{}>'.format(element_type)

    def print_accessors(self, name, declarations):
        print('  std::size_t {0}_size() const {{ return {0}_.size(); }}'.format(name))
        self._type.print_accessors_repeated(name, declarations)
        print('  void clear_{0}() {{ {0}_.clear(); }}'.format(name))
        vector_type = self.get_storage_type(declarations)
        print('  const {0}& {1}() const {{ return {1}_; }}'.format(vector_type, name))
        print('  {0}* mutable_{1}() {{ return &{1}_; }}'.format(vector_type, name))

    def print_building(self, name, declarations):
        print('      std::vector<const argdata_t*> elements;')
        print('      for (const auto& element : {}_) {{'.format(name))
        self._type.print_building_repeated(declarations)
        print('      }')
        print('      values.push_back(argdata_builder->BuildSeq(std::move(elements)));')

    def print_fields(self, name, declarations):
        print('  {} {}_;'.format(self.get_storage_type(declarations), name))

    def print_parsing(self, name, declarations):
        header = (
            '          argdata_seq_iterator_t it2;',
            '          argdata_seq_iterate(value, &it2);',
            '          const argdata_t* element;',
            '          while (argdata_seq_get(&it2, &element)) {',
        )
        for line in header:
            print(line)
        self._type.print_parsing_repeated(name, declarations)
        print('            argdata_seq_next(&it2);')
        print('          }')
480
481
class StreamType:
    """`stream T` argument or return type of an RPC, where T is a name.

    Storage type and dependencies come from the wrapped ReferenceType.
    """

    grammar = 'stream', ReferenceType

    def __init__(self, type):
        self._type = type

    def is_stream(self):
        return True

    def get_dependencies(self):
        wrapped = self._type
        return wrapped.get_dependencies()

    def get_storage_type(self, declarations):
        wrapped = self._type
        return wrapped.get_storage_type(declarations)
497
498
class EnumDeclaration:
    """A top-level `enum Name { CONSTANT = number; ... }` declaration.

    Generates a C++ enum plus protobuf-style helpers: <Name>_IsValid,
    <Name>_Name, <Name>_Parse and <Name>_MIN/_MAX/_ARRAYSIZE. Values are
    encoded on the wire by constant name, as argdata strings.
    """

    grammar = 'enum', pypeg2.word, '{', pypeg2.some(
        pypeg2.word, '=', re.compile(r'\d+'), ';'
    ), '}'

    def __init__(self, arguments):
        # arguments is [name, const1, value1, const2, value2, ...].
        self._name = arguments[0]
        # Every declared constant name -> its numeric value.
        self._constants = {}
        # Numeric value -> first constant declared with it (aliases skipped).
        # Used whenever a value has to be turned back into a name.
        self._canonical = {}
        for i in range(1, len(arguments), 2):
            key = arguments[i]
            value = int(arguments[i + 1])
            self._constants[key] = value
            if value not in self._canonical:
                self._canonical[value] = key

    def get_dependencies(self):
        return set()

    # The zero-valued constant serves as the default value. A KeyError is
    # raised below if the enum declares no constant with value 0.

    def get_isset_expression(self, name):
        return '%s_ != %s::%s' % (name, self._name, self._canonical[0])

    def get_initializer(self, name):
        return '%s_(%s::%s)' % (name, self._name, self._canonical[0])

    def get_name(self):
        return self._name

    def print_accessors(self, name):
        print('  %s %s() const { return %s_; }' % (self._name, name, name))
        print('  void set_%s(%s value) { %s_ = value; }' % (name, self._name, name))
        print('  void clear_%s() { %s_ = %s::%s; }' % (name, name, self._name, self._canonical[0]))

    def print_accessors_repeated(self, name):
        print('  %s %s(std::size_t index) const { return %s_[index]; }' % (self._name, name, name))
        print('  void set_%s(std::size_t index, %s value) { %s_[index] = value; }' % (name, self._name, name))
        print('  void add_%s(%s value) { %s_.push_back(value); }' % (name, self._name, name))

    def print_building(self, name):
        print('      values.push_back(argdata_builder->BuildStr(%s_Name(%s_)));' % (self._name, name))

    def print_building_map_value(self):
        # Reached through ReferenceType when a map's value type is this enum.
        print('        mapvalues.push_back(argdata_builder->BuildStr(%s_Name(mapentry.second)));' % self._name)

    def print_building_repeated(self):
        print('        elements.push_back(argdata_builder->BuildStr(%s_Name(element)));' % self._name)

    def print_code(self, declarations):
        print('enum %s {' % self._name)
        print('  %s' % ',\n  '.join('%s = %d' % constant for constant in sorted(self._constants.items())))
        print('};')
        print()
        print('namespace {')
        print()
        print('inline bool %s_IsValid(int value) {' % self._name)
        print('  return %s;' % ' || '.join('value == %d' % v for v in sorted(self._canonical)))
        print('}')
        print()
        print('inline const char* %s_Name(int value) {' % self._name)
        print('  switch (value) {')
        for value, name in sorted(self._canonical.items()):
            print('  case %d: return "%s";' % (value, name))
        print('  default: return "";')
        print('  }')
        print('}')
        print()
        print('inline bool %s_Parse(std::string_view name, %s* value) {' % (self._name, self._name))
        for name in sorted(self._constants):
            print('  if (name == "%s") { *value = %s::%s; return true; }' % (name, self._name, name))
        print('  return false;')
        print('}')
        print()
        print('const %s %s_MIN = %s::%s;' % (self._name, self._name, self._name, self._canonical[min(self._canonical)]))
        print('const %s %s_MAX = %s::%s;' % (self._name, self._name, self._name, self._canonical[max(self._canonical)]))
        print('const std::size_t %s_ARRAYSIZE = %d;' % (self._name, max(self._canonical) + 1))
        print()
        print('}  // namespace')

    def print_fields(self, name):
        print('  %s %s_;' % (self._name, name))

    def print_parsing(self, name):
        # Unknown constant names leave the field unchanged (_Parse fails).
        print('          const char* valuestr;')
        print('          std::size_t valuelen;')
        print('          if (argdata_get_str(value, &valuestr, &valuelen) == 0)')
        print('            %s_Parse(std::string_view(valuestr, valuelen), &%s_);' % (self._name, name))

    def print_parsing_map_value(self, name):
        # Emitted inside the key's `if (...) {` block opened by the key type.
        print('              const char* value2str;')
        print('              std::size_t value2len;')
        print('              if (argdata_get_str(value2, &value2str, &value2len) == 0)')
        print('                %s_Parse(std::string_view(value2str, value2len), &%s_.emplace(mapkey, %s::%s).first->second);' % (self._name, name, self._name, self._canonical[0]))

    def print_parsing_repeated(self, name):
        print('            const char* elementstr;')
        print('            std::size_t elementlen;')
        print('            if (argdata_get_str(element, &elementstr, &elementlen) == 0)')
        print('              %s_Parse(std::string_view(elementstr, elementlen), &%s_.emplace_back(%s::%s));' % (self._name, name, self._name, self._canonical[0]))
595
596
class MessageFieldDeclaration:
    """One `type name = number;` line inside a message body.

    The field number is matched but discarded (pypeg2.ignore). Generated code
    identifies fields by name instead.
    """

    grammar = [
        MapType,
        RepeatedType,
        PrimitiveType,
    ], pypeg2.word, '=', pypeg2.ignore(re.compile(r'\d+')), ';',

    def __init__(self, arguments):
        self._type, self._name = arguments[0], arguments[1]

    def get_name(self, sanitized):
        """Return the field name.

        When `sanitized` is true, names listed in FORBIDDEN_WORDS get a
        trailing underscore so they are usable as C++ identifiers.
        """
        if not sanitized or self._name not in FORBIDDEN_WORDS:
            return self._name
        return self._name + '_'

    def get_type(self):
        return self._type
616
617
class MessageDeclaration:
    """A top-level `message Name { fields... }` declaration.

    Generates a C++ class deriving from arpc::Message. Its Build()/Parse()
    methods encode the message as an argdata map keyed by field name, and it
    has accessors for every field.
    """

    grammar = 'message', pypeg2.word, '{', pypeg2.maybe_some(
        MessageFieldDeclaration
    ), '}'

    def __init__(self, arguments):
        # arguments is [name, field1, field2, ...].
        self._name = arguments[0]
        self._fields = arguments[1:]

    def _sorted_fields(self):
        """Return the fields ordered by unsanitized name.

        All generated sections (initializers, Build, Parse, accessors,
        members) follow this order.
        """
        return sorted(self._fields, key=lambda field: field.get_name(False))

    def get_dependencies(self):
        """Return the names of all declared types referenced by the fields."""
        dependencies = set()
        for field in self._fields:
            dependencies |= field.get_type().get_dependencies()
        return dependencies

    def get_isset_expression(self, name):
        return 'has_%s_' % name

    def get_initializer(self, name):
        return 'has_%s_(false)' % name

    def get_name(self):
        return self._name

    def print_accessors(self, name):
        print('  bool has_%s() const { return has_%s_; }' % (name, name))
        print('  const %s& %s() const { return %s_; }' % (self._name, name, name))
        print('  %s* mutable_%s() {' % (self._name, name))
        print('    has_%s_ = true;' % name)
        print('    return &%s_;' % name)
        print('  }')
        print('  void clear_%s() {' % name)
        print('    has_%s_ = false;' % name)
        print('    %s_ = %s();' % (name, self._name))
        print('  }')

    def print_accessors_repeated(self, name):
        print('  const %s& %s(std::size_t index) const { return %s_[index]; }' % (self._name, name, name))
        print('  %s* mutable_%s(std::size_t index) { return &%s_[index]; }' % (self._name, name, name))
        print('  %s* add_%s() { return &%s_.emplace_back(); }' % (self._name, name, name))

    def print_building(self, name):
        print('      values.push_back(%s_.Build(argdata_builder));' % name)

    def print_building_map_value(self):
        # Reached through ReferenceType when a map's value type is this
        # message.
        print('        mapvalues.push_back(mapentry.second.Build(argdata_builder));')

    def print_building_repeated(self):
        print('        elements.push_back(element.Build(argdata_builder));')

    def print_code(self, declarations):
        fields = self._sorted_fields()
        print('class %s final : public arpc::Message {' % self._name)
        print(' public:')
        # Types needing no explicit initialization return '' and are skipped.
        initializers = [
            initializer for initializer in (
                field.get_type().get_initializer(field.get_name(True), declarations)
                for field in fields)
            if initializer
        ]
        if initializers:
            print('  %s() : %s {}' % (self._name, ', '.join(initializers)))
            print()
        print('  const argdata_t* Build(arpc::ArgdataBuilder* argdata_builder) const override {')
        if fields:
            print('    std::vector<const argdata_t*> keys;')
            print('    std::vector<const argdata_t*> values;')
            for field in fields:
                # Only fields that are set are written to the map.
                print('    if (%s) {' % (field.get_type().get_isset_expression(field.get_name(True), declarations)))
                print('      keys.push_back(argdata_builder->BuildStr("%s"));' % field.get_name(False))
                field.get_type().print_building(field.get_name(True), declarations)
                print('    }')
            print('    return argdata_builder->BuildMap(std::move(keys), std::move(values));')
        else:
            print('    return &argdata_null;')
        print('  }')
        print()
        print('  void Clear() override {')
        print('    *this = %s();' % self._name)
        print('  }')
        print()
        print('  void Parse(const argdata_t& ad, arpc::ArgdataParser* argdata_parser) override {')
        if fields:
            print('    argdata_map_iterator_t it;')
            print('    argdata_map_iterate(&ad, &it);')
            print('    const argdata_t* key;')
            print('    const argdata_t* value;')
            print('    while (argdata_map_get(&it, &key, &value)) {')
            print('      const char* keystr;')
            print('      std::size_t keylen;')
            print('      if (argdata_get_str(key, &keystr, &keylen) == 0) {')
            print('        std::string_view keyss(keystr, keylen);')
            # Emits an `if (...) { ... } else if (...) { ... }` chain over
            # the field names; unknown keys are ignored.
            prefix = ''
            for field in fields:
                print('        %sif (keyss == "%s") {' % (prefix, field.get_name(False)))
                field.get_type().print_parsing(field.get_name(True), declarations)
                prefix = '} else '
            print('        }')
            print('      }')
            print('      argdata_map_next(&it);')
            print('    }')
        print('  }')
        print()

        for field in fields:
            field.get_type().print_accessors(field.get_name(True), declarations)
            print()

        print(' private:')
        for field in fields:
            field.get_type().print_fields(field.get_name(True), declarations)

        print('};')

    def print_fields(self, name):
        print('  bool has_%s_;' % name)
        print('  %s %s_;' % (self._name, name))

    def print_parsing(self, name):
        print('          has_%s_ = true;' % name)
        print('          %s_.Parse(*value, argdata_parser);' % name)

    def print_parsing_map_value(self, name):
        print('              %s_.emplace(mapkey, %s()).first->second.Parse(*value2, argdata_parser);' % (name, self._name))

    def print_parsing_repeated(self, name):
        print('            %s_.emplace_back().Parse(*element, argdata_parser);' % name)
739
740
class ServiceRpcDeclaration:
    """One ``rpc Name(Request) returns (Response)`` entry of a service.

    Either side may be a stream type. The combination of the two decides
    which server dispatch branch, which handler signature and which client
    stub signature get printed.
    """

    grammar = 'rpc', pypeg2.word, '(', [
        StreamType,
        ReferenceType,
    ], ')', 'returns', '(', [
        StreamType,
        ReferenceType,
    ], ')', [
        ';',
        ('{', '}'),
    ]

    def __init__(self, arguments):
        # Parse results arrive in grammar order: name, request, response.
        name, request, response = arguments[0], arguments[1], arguments[2]
        self._name = name
        self._argument_type = request
        self._return_type = response

    def get_dependencies(self):
        # Everything the request type needs plus everything the response needs.
        needed = self._argument_type.get_dependencies()
        return needed | self._return_type.get_dependencies()

    def get_name(self):
        return self._name

    def _streaming(self):
        """Return (request is a stream, response is a stream)."""
        return self._argument_type.is_stream(), self._return_type.is_stream()

    def _storage_types(self, declarations):
        """Return the C++ storage types of the request and of the response."""
        return (self._argument_type.get_storage_type(declarations),
                self._return_type.get_storage_type(declarations))

    def _print_call_and_reply(self, source):
        """Print the handler call, the response Build() and the branch end."""
        print('      arpc::Status status = %s(context, &%s, &response_object);' % (self._name, source))
        print('      if (status.ok())')
        print('        *response = response_object.Build(argdata_builder);')
        print('      return status;')
        print('    }')

    def print_service_blocking_client_streaming_call(self, declarations):
        # Only rpcs with a streamed request and a single response belong here.
        client, server = self._streaming()
        if not client or server:
            return
        request, response = self._storage_types(declarations)
        print('    if (rpc == "%s") {' % self._name)
        print('      arpc::ServerReader<%s> reader_object(reader);' % request)
        print('      %s response_object;' % response)
        self._print_call_and_reply('reader_object')

    def print_service_blocking_server_streaming_call(self, declarations):
        # Only rpcs with a single request and a streamed response belong here.
        client, server = self._streaming()
        if client or not server:
            return
        request, response = self._storage_types(declarations)
        print('    if (rpc == "%s") {' % self._name)
        print('      %s request_object;' % request)
        print('      request_object.Parse(request, argdata_parser);')
        print('      arpc::ServerWriter<%s> writer_object(writer);' % response)
        print('      return %s(context, &request_object, &writer_object);' % self._name)
        print('    }')

    def print_service_blocking_unary_call(self, declarations):
        # Only rpcs where neither side is a stream belong here.
        client, server = self._streaming()
        if client or server:
            return
        request, response = self._storage_types(declarations)
        print('    if (rpc == "%s") {' % self._name)
        print('      %s request_object;' % request)
        print('      request_object.Parse(request, argdata_parser);')
        print('      %s response_object;' % response)
        self._print_call_and_reply('request_object')

    def print_service_function(self, declarations):
        """Print the virtual handler that servers override for this rpc.

        The default body reports UNIMPLEMENTED.
        """
        request, response = self._storage_types(declarations)
        client, server = self._streaming()
        if client and server:
            params = 'arpc::ServerReaderWriter<%s, %s>* stream' % (request, response)
        elif client:
            params = 'arpc::ServerReader<%s>* reader, %s* response' % (request, response)
        elif server:
            params = 'const %s* request, arpc::ServerWriter<%s>* writer' % (request, response)
        else:
            params = 'const %s* request, %s* response' % (request, response)
        print('  virtual arpc::Status %s(arpc::ServerContext* context, %s) {' % (self._name, params))
        print('    return arpc::Status(arpc::StatusCode::UNIMPLEMENTED, "Operation not provided by this implementation");')
        print('  }')

    def print_stub_function(self, service, declarations):
        """Print the client Stub method that invokes this rpc on ``service``."""
        request, response = self._storage_types(declarations)
        client, server = self._streaming()
        method = 'arpc::RpcMethod("%s", "%s")' % (service, self._name)
        if client and server:
            stream = 'arpc::ClientReaderWriter<%s, %s>' % (request, response)
            header = 'std::unique_ptr<%s> %s(arpc::ClientContext* context)' % (stream, self._name)
            result = 'std::make_unique<%s>(channel_.get(), %s, context)' % (stream, method)
        elif client:
            stream = 'arpc::ClientWriter<%s>' % request
            header = 'std::unique_ptr<%s> %s(arpc::ClientContext* context, %s* response)' % (stream, self._name, response)
            result = 'std::make_unique<%s>(channel_.get(), %s, context, response)' % (stream, method)
        elif server:
            stream = 'arpc::ClientReader<%s>' % response
            header = 'std::unique_ptr<%s> %s(arpc::ClientContext* context, const %s& request)' % (stream, self._name, request)
            result = 'std::make_unique<%s>(channel_.get(), %s, context, request)' % (stream, method)
        else:
            header = 'arpc::Status %s(arpc::ClientContext* context, const %s& request, %s* response)' % (self._name, request, response)
            result = 'channel_->BlockingUnaryCall(%s, context, request, response)' % method
        print('  %s {' % header)
        print('    return %s;' % result)
        print('  }')
831
832
class ServiceDeclaration:
    """A ``service Name { rpc ...; }`` block of the .proto file.

    print_code() emits a C++ ``struct`` named after the service. The struct
    contains a ``Service`` base class for server implementations, a ``Stub``
    client class and a ``NewStub()`` factory.
    """

    grammar = 'service', pypeg2.word, '{', pypeg2.maybe_some(
        ServiceRpcDeclaration
    ), '}'

    def __init__(self, arguments):
        # NOTE(review): pypeg2 appears to pass a lone parse result directly
        # instead of wrapping it in a list. That happens for a service without
        # any rpcs. Slicing the bare name string would then treat its
        # characters as rpcs, so wrap it in a list first.
        if isinstance(arguments, str):
            arguments = [arguments]
        self._name = arguments[0]
        # RPCs are kept sorted by name; every loop in print_code emits them
        # in this order.
        self._rpcs = sorted(arguments[1:], key=lambda rpc: rpc.get_name())

    def get_name(self):
        return self._name

    def get_dependencies(self):
        # Union of the type dependencies of all rpcs of this service.
        return set().union(*(rpc.get_dependencies() for rpc in self._rpcs))

    def _print_dispatcher(self, signature, print_case):
        """Print one dispatch override of arpc::Service.

        ``print_case(rpc)`` is called for every rpc. Each
        ServiceRpcDeclaration only prints a branch for the matching kind, so
        the trailing UNIMPLEMENTED status is the fallback.
        """
        print('  arpc::Status %s override {' % signature)
        for rpc in self._rpcs:
            print_case(rpc)
        print('    return arpc::Status(arpc::StatusCode::UNIMPLEMENTED, "Operation not provided by this service");')
        print('  }')

    def print_code(self, declarations):
        """Print the C++ code for this service to stdout."""
        print('struct %s {' % self._name)
        print()
        print('class Service : public arpc::Service {')
        print(' public:')
        print('  std::string_view GetName() override {')
        print('    return "%s";' % self._name)
        print('  }')
        print()
        self._print_dispatcher(
            'BlockingUnaryCall(std::string_view rpc, arpc::ServerContext* context, const argdata_t& request, arpc::ArgdataParser* argdata_parser, const argdata_t** response, arpc::ArgdataBuilder* argdata_builder)',
            lambda rpc: rpc.print_service_blocking_unary_call(declarations))
        print()
        self._print_dispatcher(
            'BlockingClientStreamingCall(std::string_view rpc, arpc::ServerContext* context, arpc::ServerReaderImpl* reader, const argdata_t** response, arpc::ArgdataBuilder* argdata_builder)',
            lambda rpc: rpc.print_service_blocking_client_streaming_call(declarations))
        print()
        self._print_dispatcher(
            'BlockingServerStreamingCall(std::string_view rpc, arpc::ServerContext* context, const argdata_t& request, arpc::ArgdataParser* argdata_parser, arpc::ServerWriterImpl* writer)',
            lambda rpc: rpc.print_service_blocking_server_streaming_call(declarations))

        # Each overridable handler is preceded by its own blank line.
        for rpc in self._rpcs:
            print()
            rpc.print_service_function(declarations)
        print('};')
        print()
        print('class Stub {')
        print(' public:')
        print('  explicit Stub(const std::shared_ptr<arpc::Channel>& channel)')
        print('      : channel_(channel) {}')
        print()
        for rpc in self._rpcs:
            rpc.print_stub_function(self._name, declarations)
            print()
        print(' private:')
        print('  const std::shared_ptr<arpc::Channel> channel_;')
        print('};')
        print()
        print('static std::unique_ptr<Stub> NewStub(const std::shared_ptr<arpc::Channel>& channel) {')
        print('  return std::make_unique<Stub>(channel);')
        print('}')
        print()
        print('};')
903
904
# Grammar of a whole .proto input file. It requires a proto3 syntax line and
# a package statement, followed by any number of enum, message and service
# declarations. Import and option lines are consumed and discarded.
ProtoFile = (
    'syntax', '=', ['"proto3"', '\'proto3\''], ';',
    # The dotted package name is split into its components. They appear as
    # the leading strings of the parse result and later become nested C++
    # namespaces.
    'package', pypeg2.csl(pypeg2.word, separator='.'), ';',
    # NOTE(review): this ignore() clause presumably still leaves one
    # placeholder element in the parse result. The parsing code below skips
    # the first non-string element with declarations[1:]; keep the two in
    # sync.
    pypeg2.ignore((
        pypeg2.maybe_some(['import', 'option'], pypeg2.restline),
    )),
    pypeg2.maybe_some([
        EnumDeclaration,
        MessageDeclaration,
        ServiceDeclaration,
    ])
)
917
918
# Read the .proto source from stdin. Its SHA-256 later names the include
# guard of the generated header.
input_str = sys.stdin.read()
input_sha256 = hashlib.sha256(input_str.encode('UTF-8')).hexdigest()
declarations = pypeg2.parse(input_str, ProtoFile, comment=pypeg2.comment_cpp)
# The parse result starts with the package name components as plain strings.
package = []
while isinstance(declarations[0], str):
    package.append(declarations[0])
    declarations = declarations[1:]
# Index the remaining declarations by name. The first remaining element is
# skipped. NOTE(review): presumably it is the placeholder produced by the
# pypeg2.ignore() clause of ProtoFile; confirm before changing.
declarations = {declaration.get_name(): declaration
                for declaration in declarations[1:]}
928
def sort_declarations_by_dependencies(declarations):
    """Return declaration names ordered so that dependencies come first.

    A graph maps each declaration's name to the set returned by its
    get_dependencies(). toposort flattens that graph; ``sort=True`` orders
    names within each dependency level. Names appearing only as dependencies
    are included in the result as well.
    """
    graph = {}
    for declaration in declarations:
        name = declaration.get_name()
        graph[name] = declaration.get_dependencies()
    return toposort.toposort_flatten(graph, sort=True)
934
# Emit the generated C++ header on stdout, in this order:
# - an include guard derived from the input's SHA-256;
# - the headers the generated code uses;
# - every declaration in dependency order, inside the package's nested
#   namespaces.
print('\n'.join((
    '#ifndef APROTOC_' + input_sha256,
    '#define APROTOC_' + input_sha256,
    '',
    '#include <cstdint>',
    '#include <map>',
    '#include <memory>',
    '#include <string>',
    '#include <string_view>',
    '#include <vector>',
    '',
    '#include <argdata.h>',
    '#include <arpc++/arpc++.h>',
    '',
)))
# One opening namespace per package component, then a blank line.
print(''.join('namespace %s {\n' % part for part in package))

# Each declaration is printed only after everything it depends on.
for declaration in sort_declarations_by_dependencies(declarations.values()):
    declarations[declaration].print_code(declarations)
    print()

# Close the namespaces innermost first, then a blank line.
print(''.join('}  // namespace %s\n' % part for part in reversed(package)))

print('#endif')
961