1#!/usr/bin/env python
2# Copyright 2016 The Chromium Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5
6"""protoc plugin to create C++ reader/writer for JSON-encoded protobufs
7
8The reader/writer use Chrome's base::Values.
9"""
10
11import os
12import sys
13
14from util import plugin_protos, types, writer
15
16
class CppConverterWriter(writer.CodeWriter):
  """Generates C++ converters between protobuf messages and base::Values.

  For every message in a .proto file, WriteProtoFile() emits a C++ class
  (named by message.CppConverterClassName()) with two static methods:
    ReadFromValue(const base::Value* json, Message* message): fills |message|
        from a base::DictionaryValue keyed by each field's JavascriptIndex();
        returns false if |json| is not a dictionary or a present key has an
        unexpected type.
    WriteToValue(const Message& message): returns a new DictionaryValue that
        holds every repeated field and each optional field that is set.

  Templates are passed to Output() with format-style placeholders, so literal
  C++ braces are written as '{{' and '}}'.
  """

  def WriteProtoFile(self, proto_file, output_dir):
    """Writes the converter header for every message in |proto_file|.

    Args:
      proto_file: the proto file to generate converters for.
      output_dir: directory prepended to the generated .pb.h include path, or
          '' to include it unprefixed.

    Unsupported input is reported through AddError() and produces no output.
    """
    err = proto_file.CheckSupported()
    if err:
      self.AddError(err)
      return

    # Imports are not supported. Report this through AddError (like the
    # CheckSupported failure above) instead of an assert, which is stripped
    # under `python -O`; check before anything has been written.
    dependencies = list(proto_file.GetDependencies())
    if dependencies:
      self.AddError(
          'import is not supported (dependencies: {})'.format(dependencies))
      return

    self.WriteCStyleHeader()

    self.Output('#include "{output_dir}{generated_pb_h}"',
                output_dir=output_dir + '/' if output_dir else '',
                generated_pb_h=proto_file.CppBaseHeader())
    self.Output('')

    self.Output('// base dependencies')
    self.Output('#include "base/values.h"')
    self.Output('')
    self.Output('#include <memory>')
    self.Output('#include <string>')
    self.Output('#include <utility>')
    self.Output('')

    # All converters live in the proto's namespaces plus a nested 'json' one.
    namespaces = proto_file.ProtoNamespaces() + ['json']
    for name in namespaces:
      self.Output('namespace {name} {{', name=name)
      self.IncreaseIndent()

    for message in proto_file.GetMessages():
      self.WriteMessage(message)

    # Nothing to do for enums

    for _ in namespaces:
      self.DecreaseIndent()
      self.Output('}}')

  def WriteMessage(self, message):
    """Emits the converter class for |message|.

    Converters for nested messages are emitted first, as nested classes in
    the public section, followed by ReadFromValue() and WriteToValue().
    """
    self.Output('class {class_name} {{',
                class_name=message.CppConverterClassName())
    self.Output(' public:')
    with self.AddIndent():
      for nested_class in message.GetMessages():
        self.WriteMessage(nested_class)

      generated_class_name = message.QualifiedTypes().cpp_base
      # Nothing to write for enums.

      # ReadFromValue(): each failing lookup jumps to the shared |error| label.
      self.Output(
          'static bool ReadFromValue(const base::Value* json, {generated_class_name}* message) {{\n'
          '  const base::DictionaryValue* dict;\n'
          '  if (!json->GetAsDictionary(&dict)) goto error;\n'
          '',
          generated_class_name=generated_class_name)

      with self.AddIndent():
        for field_proto in message.GetFields():
          self.WriteFieldRead(field_proto)

      # Close ReadFromValue() and open WriteToValue().
      self.Output(
          '  return true;\n'
          '\n'
          'error:\n'
          '  return false;\n'
          '}}\n'
          '\n'
          'static std::unique_ptr<base::DictionaryValue> WriteToValue(const {generated_class_name}& message) {{\n'
          '  std::unique_ptr<base::DictionaryValue> dict(new base::DictionaryValue());\n'
          '',
          generated_class_name=generated_class_name)

      with self.AddIndent():
        for field_proto in message.GetFields():
          self.FieldWriteToValue(field_proto)

      self.Output(
          '  return dict;\n'
          '')
      self.Output('}}')

    self.Output('}};')
    self.Output('')

  def FieldWriteToValue(self, field):
    """Emits the WriteToValue() statements for one |field|.

    Repeated fields are always written, inside their own scope; optional
    fields only when message.has_<field>() is true.
    """
    if field.IsRepeated():
      self.Output('{{')
    else:
      # Like the '{{' branch above, no trailing newline on the opening line.
      self.Output('if (message.has_{field_name}()) {{', field_name=field.name)

    with self.AddIndent():
      if field.IsRepeated():
        self.RepeatedMemberFieldWriteToValue(field)
      else:
        self.OptionalMemberFieldWriteToValue(field)

    self.Output('}}')

  def RepeatedMemberFieldWriteToValue(self, field):
    """Emits code copying repeated |field| into a base::ListValue in |dict|."""
    prologue = (
        'auto field_list = std::make_unique<base::ListValue>();\n'
        'for (int i = 0; i < message.{field_name}_size(); ++i) {{\n'
    )
    # Only compute the substitutions the selected template uses (as
    # OptionalMemberFieldRead does), so primitive fields never need a
    # converter type and message fields never need a value type.
    params = {
        'field_number': field.JavascriptIndex(),
        'field_name': field.name,
    }
    if field.IsClassType():
      middle = (
          'std::unique_ptr<base::Value> inner_message_value =\n'
          '    {inner_class_converter}::WriteToValue(message.{field_name}(i));\n'
          'field_list->Append(std::move(inner_message_value));\n'
      )
      params['inner_class_converter'] = field.CppConverterType()
    else:
      middle = 'field_list->Append{value_type}(message.{field_name}(i));\n'
      params['value_type'] = field.CppValueType()

    epilogue = (
        '\n}}\n'
        'dict->Set("{field_number}", std::move(field_list));'
    )
    self.Output(prologue + Indented(middle) + epilogue, **params)

  def OptionalMemberFieldWriteToValue(self, field):
    """Emits code storing optional |field| in |dict| under its index."""
    params = {
        'field_number': field.JavascriptIndex(),
        'field_name': field.name,
    }
    if field.IsClassType():
      body = (
          'std::unique_ptr<base::Value> inner_message_value =\n'
          '    {inner_class_converter}::WriteToValue(message.{field_name}());\n'
          'dict->Set("{field_number}", std::move(inner_message_value));\n'
      )
      params['inner_class_converter'] = field.CppConverterType()
    else:
      body = 'dict->Set{value_type}("{field_number}", message.{field_name}());\n'
      params['value_type'] = field.CppValueType()

    self.Output(body, **params)

  def WriteFieldRead(self, field):
    """Emits the ReadFromValue() statements for one |field|.

    If the field's key is absent from |dict|, the generated code does not
    touch the field.
    """
    self.Output('if (dict->HasKey("{field_number}")) {{',
                field_number=field.JavascriptIndex())

    with self.AddIndent():
      if field.IsRepeated():
        self.RepeatedMemberFieldRead(field)
      else:
        self.OptionalMemberFieldRead(field)

    self.Output('}}')

  def RepeatedMemberFieldRead(self, field):
    """Emits code appending every list element to repeated |field|.

    The generated code jumps to |error| if the key is not a list or an
    element cannot be read as the field's type.
    """
    prologue = (
        'const base::ListValue* field_list;\n'
        'if (!dict->GetList("{field_number}", &field_list)) {{\n'
        '  goto error;\n'
        '}}\n'
        'for (size_t i = 0; i < field_list->GetSize(); ++i) {{\n'
    )
    params = {
        'field_number': field.JavascriptIndex(),
        'field_name': field.name,
    }
    if field.IsClassType():
      middle = (
          'const base::Value* inner_message_value;\n'
          'if (!field_list->Get(i, &inner_message_value)) {{\n'
          '  goto error;\n'
          '}}\n'
          'if (!{inner_class_parser}::ReadFromValue(inner_message_value, message->add_{field_name}())) {{\n'
          '  goto error;\n'
          '}}\n'
      )
      params['inner_class_parser'] = field.CppConverterType()
    else:
      middle = (
          '{cpp_type} field_value;\n'
          'if (!field_list->Get{value_type}(i, &field_value)) {{\n'
          '  goto error;\n'
          '}}\n'
          'message->add_{field_name}(field_value);\n'
      )
      params['cpp_type'] = field.CppPrimitiveType()
      params['value_type'] = field.CppValueType()

    self.Output(prologue + Indented(middle) + '\n}}', **params)

  def OptionalMemberFieldRead(self, field):
    """Emits code reading optional |field| from |dict|.

    The generated code jumps to |error| if the value cannot be read as the
    field's type.
    """
    if field.IsClassType():
      self.Output(
          'const base::Value* inner_message_value;\n'
          'if (!dict->Get("{field_number}", &inner_message_value)) {{\n'
          '  goto error;\n'
          '}}\n'
          'if (!{inner_class_parser}::ReadFromValue(inner_message_value, message->mutable_{field_name}())) {{\n'
          '  goto error;\n'
          '}}\n'
          '',
          field_number=field.JavascriptIndex(),
          field_name=field.name,
          inner_class_parser=field.CppConverterType()
      )
    else:
      self.Output(
          '{cpp_type} field_value;\n'
          'if (!dict->Get{value_type}("{field_number}", &field_value)) {{\n'
          '  goto error;\n'
          '}}\n'
          'message->set_{field_name}(field_value);\n'
          '',
          field_number=field.JavascriptIndex(),
          field_name=field.name,
          cpp_type=field.CppPrimitiveType(),
          value_type=field.CppValueType()
      )
243
244
def Indented(s, indent=2):
  """Returns |s| with every non-empty line prefixed by |indent| spaces.

  Trailing newlines are stripped first, so the result never ends with one.
  Empty lines stay empty rather than being padded with spaces, which would
  leave trailing whitespace in the generated code.
  """
  pad = ' ' * indent
  return '\n'.join(pad + line if line else line
                   for line in s.rstrip('\n').split('\n'))
247
248
def SetBinaryStdio():
  """Puts stdin and stdout into binary mode when running on Windows.

  The plugin exchanges serialized protobufs over these streams, so Windows
  text-mode newline translation must be disabled. Does nothing elsewhere.
  """
  import platform
  if platform.system() != 'Windows':
    return

  import msvcrt
  for stream in (sys.stdin, sys.stdout):
    msvcrt.setmode(stream.fileno(), os.O_BINARY)
255
256
def ReadRequestFromStdin():
  """Reads the serialized plugin request from stdin and parses it.

  The request is binary protobuf data, so it has to be read as bytes. On
  Python 3, sys.stdin is a text stream that would try to decode it, so the
  underlying binary |buffer| is used instead. Python 2's sys.stdin has no
  |buffer| attribute and already yields bytes.
  """
  stream = getattr(sys.stdin, 'buffer', sys.stdin)
  data = stream.read()
  return plugin_protos.PluginRequestFromString(data)
260
261
def main():
  """Plugin entry point: writes one C++ converter file per requested proto."""
  SetBinaryStdio()
  request = ReadRequestFromStdin()
  response = plugin_protos.PluginResponse()

  # Optional plugin argument; when set, converters are placed directly in it.
  output_dir = request.GetArgs().get('output_dir', '')

  for proto_file in request.GetAllFiles():
    types.RegisterProtoFile(proto_file)

    converter_writer = CppConverterWriter()
    converter_writer.WriteProtoFile(proto_file, output_dir)

    filename = proto_file.CppConverterFilename()
    if output_dir:
      # Keep only the base name and relocate it under output_dir.
      filename = os.path.join(output_dir, os.path.basename(filename))
    response.AddFileWithContent(filename, converter_writer.GetValue())

    if converter_writer.GetErrors():
      response.AddError('\n'.join(converter_writer.GetErrors()))

  response.WriteToStdout()
285
286
if __name__ == '__main__':
  # Runs only when executed as a script (e.g. when invoked as a protoc plugin).
  main()
289