-
Notifications
You must be signed in to change notification settings - Fork 233
Expand file tree
/
Copy pathparser.py
More file actions
204 lines (178 loc) · 7.13 KB
/
parser.py
File metadata and controls
204 lines (178 loc) · 7.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
from betterproto.lib.google.protobuf import (
DescriptorProto,
EnumDescriptorProto,
FieldDescriptorProto,
FileDescriptorProto,
ServiceDescriptorProto,
)
from betterproto.lib.google.protobuf.compiler import (
CodeGeneratorRequest,
CodeGeneratorResponse,
CodeGeneratorResponseFeature,
CodeGeneratorResponseFile,
)
import itertools
import pathlib
import sys
from typing import Iterator, List, Set, Tuple, TYPE_CHECKING, Union
from .compiler import outputfile_compiler
from .models import (
EnumDefinitionCompiler,
FieldCompiler,
MapEntryCompiler,
MessageCompiler,
Options,
OneOfFieldCompiler,
OutputTemplate,
PluginRequestCompiler,
ServiceCompiler,
ServiceMethodCompiler,
is_map,
is_oneof,
)
if TYPE_CHECKING:
from google.protobuf.descriptor import Descriptor
def traverse(
    proto_file: "FileDescriptorProto",
) -> "itertools.chain[Tuple[Union[DescriptorProto, EnumDescriptorProto], List[int]]]":
    """Iterate over every enum and message declared in ``proto_file``.

    Each yielded pair is ``(item, path)``: ``item`` is the descriptor itself
    and ``path`` is the list of field numbers / indices that locates it inside
    the file descriptor (the form used by ``SourceCodeInfo.Location.path``).

    Top-level enums are yielded first, then top-level messages. Each message
    is followed by its directly nested enums and then, recursively, by its
    nested messages.

    Side effect: ``item.name`` is rewritten in place. Because the hierarchy is
    flattened, a nested type's name becomes the concatenation of all of its
    parents' names followed by its own name.
    """
    # Todo: Keep information about nested hierarchy
    def _traverse(
        path: List[int],
        items: List[Union["DescriptorProto", "EnumDescriptorProto"]],
        prefix: str = "",
    ) -> Iterator[Tuple[Union["DescriptorProto", "EnumDescriptorProto"], List[int]]]:
        for i, item in enumerate(items):
            # Adjust the name since we flatten the hierarchy.
            # Todo: don't change the name, but include full name in returned tuple
            item.name = next_prefix = prefix + item.name
            yield item, path + [i]

            if isinstance(item, DescriptorProto):
                # 4 is the field number of ``enum_type`` inside a message.
                # The enum's own index must be part of the path; without it
                # every nested enum of this message would share one path that
                # points at the repeated field instead of at a single enum.
                for enum_index, enum in enumerate(item.enum_type):
                    enum.name = next_prefix + enum.name
                    yield enum, path + [i, 4, enum_index]

                if item.nested_type:
                    # 3 is the field number of ``nested_type`` inside a message.
                    yield from _traverse(
                        path + [i, 3], item.nested_type, next_prefix
                    )

    # 5 and 4 are the field numbers of ``enum_type`` and ``message_type``
    # inside a file descriptor.
    return itertools.chain(
        _traverse([5], proto_file.enum_type), _traverse([4], proto_file.message_type)
    )
def parse_options(plugin_options: List[str]) -> Options:
    """Translate the raw plugin option strings into an ``Options`` object.

    Recognised entries:

    * ``grpc=<value>`` stores ``<value>`` (everything after the first ``=``)
      in ``grpc_kind``.
    * ``INCLUDE_GOOGLE`` sets ``include_google`` to ``True``.

    Any other entry is ignored.
    """
    parsed = Options()
    for raw_option in plugin_options:
        key, separator, value = raw_option.partition("=")
        if separator and key == "grpc":
            parsed.grpc_kind = value
        elif raw_option == "INCLUDE_GOOGLE":
            parsed.include_google = True
    return parsed
def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
    """Build the plugin response for a code generation request.

    The work happens in four passes:

    1. Group the request's proto files into one output per proto package.
    2. Read all messages and enums of every output package.
    3. Read all services (after messages, so that the input/output message
       references of each service can be resolved).
    4. Render one ``__init__.py`` per package and add an empty
       ``__init__.py`` for every parent directory that has none.

    Files of the ``google.protobuf`` package are skipped unless the
    ``INCLUDE_GOOGLE`` plugin option was given.

    The name of every file placed in the response is also printed to stderr.
    """
    response = CodeGeneratorResponse()

    plugin_options = request.parameter.split(",") if request.parameter else []
    response.supported_features = CodeGeneratorResponseFeature.FEATURE_PROTO3_OPTIONAL

    options = parse_options(plugin_options)
    request_data = PluginRequestCompiler(
        plugin_request_obj=request, options=options
    )

    # Gather output packages
    for proto_file in request.proto_file:
        if proto_file.package == "google.protobuf" and not options.include_google:
            # If not INCLUDE_GOOGLE,
            # skip re-compiling Google's well-known types
            continue

        output_package_name = proto_file.package
        if output_package_name not in request_data.output_packages:
            # Create a new output if there is no output for this package
            request_data.output_packages[output_package_name] = OutputTemplate(
                parent_request=request_data, package_proto_obj=proto_file
            )
        # Add this input file to the output corresponding to this package
        request_data.output_packages[output_package_name].input_files.append(proto_file)

    # Read Messages and Enums
    # We need to read Messages before Services so that we can
    # get the references to input/output messages for each service
    for output_package in request_data.output_packages.values():
        for proto_input_file in output_package.input_files:
            for item, path in traverse(proto_input_file):
                read_protobuf_type(
                    source_file=proto_input_file,
                    item=item,
                    path=path,
                    output_package=output_package,
                )

    # Read Services
    for output_package in request_data.output_packages.values():
        for proto_input_file in output_package.input_files:
            for index, service in enumerate(proto_input_file.service):
                read_protobuf_service(service, index, output_package)

    # Generate output files
    output_paths: Set[pathlib.Path] = set()
    for output_package_name, output_package in request_data.output_packages.items():
        # Add files to the response object
        output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
        output_paths.add(output_path)

        response.file.append(
            CodeGeneratorResponseFile(
                name=str(output_path),
                # Render and then format the output file
                content=outputfile_compiler(output_file=output_package),
            )
        )

    # Make each output directory a package with __init__ file
    init_files = {
        directory.joinpath("__init__.py")
        for path in output_paths
        for directory in path.parents
    } - output_paths

    for init_file in init_files:
        response.file.append(CodeGeneratorResponseFile(name=str(init_file)))

    # Report every file that ends up in the response, in a stable order.
    for written_path in sorted(output_paths.union(init_files)):
        print(f"Writing {written_path}", file=sys.stderr)

    return response
def read_protobuf_type(
    item: DescriptorProto,
    path: List[int],
    source_file: "FileDescriptorProto",
    output_package: OutputTemplate,
) -> None:
    """Create the compiler objects for one message or enum descriptor.

    * An ``EnumDescriptorProto`` produces an ``EnumDefinitionCompiler``.
    * A ``DescriptorProto`` produces a ``MessageCompiler`` plus one field
      compiler per field, unless the message is a generated map entry, in
      which case nothing is created.
    * Any other kind of ``item`` is ignored.

    The constructed objects are not returned; ``output_package`` (or the new
    message) is handed to them as ``parent``.
    """
    if not isinstance(item, DescriptorProto):
        if isinstance(item, EnumDescriptorProto):
            # Enum
            EnumDefinitionCompiler(
                source_file=source_file,
                parent=output_package,
                proto_obj=item,
                path=path,
            )
        return

    if item.options.map_entry:
        # Skip generated map entry messages since we just use dicts
        return

    # Process Message
    message = MessageCompiler(
        source_file=source_file, parent=output_package, proto_obj=item, path=path
    )

    for field_index, field in enumerate(item.field):
        # Pick the compiler class first; all three take the same arguments.
        if is_map(field, item):
            field_compiler = MapEntryCompiler
        elif is_oneof(field):
            field_compiler = OneOfFieldCompiler
        else:
            field_compiler = FieldCompiler

        field_compiler(
            source_file=source_file,
            parent=message,
            proto_obj=field,
            # Field path: the message path extended by [2, position in item.field].
            path=path + [2, field_index],
        )
def read_protobuf_service(
    service: ServiceDescriptorProto, index: int, output_package: OutputTemplate
) -> None:
    """Create the compiler objects for one service and all of its methods.

    ``index`` is the position of ``service`` in its file's service list. The
    service gets the path ``[6, index]`` and its ``j``-th method gets
    ``[6, index, 2, j]``. Nothing is returned; ``output_package`` is handed
    to the service compiler as ``parent``.
    """
    service_prefix = (6, index)

    compiled_service = ServiceCompiler(
        parent=output_package, proto_obj=service, path=list(service_prefix)
    )

    for method_index, method in enumerate(service.method):
        ServiceMethodCompiler(
            parent=compiled_service,
            proto_obj=method,
            path=[*service_prefix, 2, method_index],
        )