forked from BachiLi/loma_public
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathasdl_gen.py
367 lines (317 loc) · 13.1 KB
/
asdl_gen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
"""
This is the main parsing and class metaprogramming module in ASDL-ADT.
Code written by Gilbert Bernstein, Jonathan Ragan-Kelley, and Alex Reinking.
Available under MIT open source-license. Please see ASDL-ADT project
on Github for more info.
"""
import asdl
import textwrap
import inspect
import os
from typing import Any, Mapping, Optional, Collection, List
from collections import ChainMap
from yapf.yapflib.yapf_api import FormatCode
# --------------------------------------------------------------------------- #
# Main Interface API
def ADT(
asdl_str : str,
*,
header : Optional[str] = None,
ext_types : Optional[Mapping[str, str]] = None,
checks : Optional[Mapping[str, str]] = None,
no_checks : bool = False,
memoize : Optional[Collection[str]] = None,
filename : Optional[str] = None
):
"""
Function that converts an ASDL grammar into a Python Module.
The generated module will contain one class for every ASDL type
declared in the input grammar, and one (sub-)class for every
constructor in each of those types. These constructors will
type-check objects on __new__ to ensure conformity with the
given grammar.
You should `import module_name` after running the ADT command.
Some additional features are:
- The new module will be written to a file -- either `filename.py` or
by default `module_name.py`
Arguments:
=================
- `header` is a string that will get pre-pended to the generated file
- `ext_types` is a mapping from undefined ASDL type names to Python
types used to back those types. Note that the following types are
built in: 'bool', 'float', 'int', 'object' (i.e. any),
'string' (i.e. str)
- `checks` is a mapping from (undefined) ASDL type names to Python
functions used to type check those types instead of a default
`isinstance(...)` check
- `no_checks` will turn off the dynamic type-checking feature, which
may improve performance. We recommend against using this option.
- `memoize` is a collection of ASDL class names for which you would
like to generate a memoized constructor
- `filename` allows for specifying an alternate filename (see above)
ASDL Syntax
=================
The grammar of ASDL follows this BNF::
module ::= "module" Id "{" [definitions] "}"
definitions ::= { TypeId "=" type }+
type ::= product | sum
product ::= fields
fields ::= "()" | "(" { field, "," }* field ")"
field ::= TypeId ["?" | "*"]? Id
sum ::= constructor { "|" constructor }*
["attributes" fields]?
constructor ::= ConstructorId [fields]
note: *Id is a valid posix name string
Example
=================
::
ADT(\"\"\" module PolyMod {
expr = Var ( id name )
| Const ( float val )
| Sum ( expr* terms )
| Prod ( float coeff, expr* terms )
attributes( string? tag )
}\"\"\",
header=\"\"\"
from my_library import is_valid_id
\"\"\",
ext_types = { 'id' : 'str' },
checks = { 'id' : 'is_valid_id' },
memoize=['Var'])
import PolyMod
"""
asdl_ast = asdl.ASDLParser().parse(asdl_str)
assert isinstance(asdl_ast, asdl.Module)
header = header or ""
ext_types = ext_types or dict()
checks = checks or dict()
memoize = memoize or set()
mod_str = _BuildClasses(asdl_ast, header,
ext_types,
checks,
not no_checks,
memoize).file()
mod_str = ("\"\"\"\n"+textwrap.dedent(
"""\
ASDL Module generated by asdl_adt
Original ASDL description:
""")
+ textwrap.dedent(asdl_str)
+ "\n\"\"\"\n"
+ textwrap.dedent(mod_str))
# determine the filename to dump to
caller_frame = inspect.stack()[1]
if not filename:
caller_dir = os.path.dirname(caller_frame.filename)
try:
os.mkdir(os.path.join(caller_dir, '_asdl'))
except FileExistsError:
pass
filename = os.path.join(caller_dir, '_asdl',
f"{asdl_ast.name}.py")
with open(filename, "w") as new_file:
new_file.write(mod_str)
# --------------------------------------------------------------------------- #
# Implementation
class _BuildClasses:
""" Constructs an output file to dump """
_builtin_types = {
"bool": "bool",
"float": "float",
"int": "int",
"object": "object",
"string": "str",
}
_builtin_checks = {
"object": "(lambda x: x is not None)",
}
_default_header = """
from __future__ import annotations
import attrs as _attrs
from typing import Tuple as _Tuple
from typing import Optional as _Optional
def _list_to_tuple(xs):
return tuple(xs) if isinstance(xs, list) else xs
"""
def __init__(self, mod : asdl.Module,
header : str,
ext_types : Mapping[str, str],
checks : Mapping[str, str],
do_checks : bool,
memoize : Collection[str]):
self._header = textwrap.dedent(self._default_header)
if header:
self._header += textwrap.dedent(header) + "\n\n"
self._types = ChainMap({}, ext_types, self._builtin_types)
self._checks = ChainMap({}, checks, self._builtin_checks)
self._do_checks = do_checks
self._memoize = memoize
self._classes = []
self._mod_name = None
self.register_module(mod)
self.build_module(mod)
def file(self):
""" Return the results of building """
final_str = FormatCode(self._header +
"\n\n".join(self._classes) +
"\n",
style_config="pep8")[0]
return final_str
def register_module(self, mod : asdl.Module):
""" This is the first pass over the definition. It sets up
anything that might be needed by the second pass """
self._mod_name = mod.name
# register all classes
for dfn in mod.dfns:
if isinstance(dfn.value, asdl.Product):
self.register_class(dfn.name)
elif isinstance(dfn.value, asdl.Sum):
self.register_sum(dfn.value, dfn.name)
else:
assert False, f"bad case: {type(dfn)}"
def register_class(self, name : str):
self._types[name] = name
def register_sum(self, sum_node : asdl.Sum, name : str):
self.register_class(name)
for t in sum_node.types:
self.register_class(t.name)
def build_module(self, mod : asdl.Module):
""" Go through all the top-level type definitions and
dispatch to the product or sum builders """
for dfn in mod.dfns:
if isinstance(dfn.value, asdl.Product):
self.build_product(dfn.value, dfn.name)
elif isinstance(dfn.value, asdl.Sum):
self.build_sum(dfn.value, dfn.name)
else:
assert False, f"bad case: {type(dfn)}"
def build_product(self, prod : asdl.Product, name : str):
""" Create a full-featured class for each Product """
self.build_class(name, prod.fields)
def build_sum(self, sum_node : asdl.Sum, name : str):
""" Create an abstract base class for each Sum type.
Then create sub-classes for each Constructor """
self._classes.append(
f"class {name}:\n"
f" def __init__(self, *args, **kwargs):\n"
f" assert False,\"Cannot instantiate {name}\"")
attributes = sum_node.attributes
for t in sum_node.types:
self.build_class(t.name, t.fields + attributes, super_cls = name)
def build_class(self, name : str, fields : List[asdl.Field],
super_cls : Optional[str] = None):
""" Build a fully-featured class """
# optional inheritance from a super-class
super_str = f"({super_cls})" if super_cls else ""
# Define the class name and set up `lines` as an accumulator
lines = [f"@_attrs.define(frozen=True)",
f"class {name}{super_str}:"]
if len(fields) == 0:
#lines.append(" pass")
pass
else:
lines += self.build_field_decls(fields)
lines.append(" ")
if name in self._memoize:
lines += [" _memo_cache = dict()",
" "]
lines += self.build_new_fn(fields, name)
lines += [" "]
lines += self.build_init_fn(fields, name)
self._classes.append("\n".join(lines))
def build_field_decls(self, fields : List[asdl.Field]):
lines = []
for f in fields:
if f.name in self._types:
raise TypeError(f"Cannot use '{f.name}' as a field name; "
f"it is already being used as a type name")
seq = ""
if f.seq:
seq = " = _attrs.field(converter=_list_to_tuple)"
l = f" {f.name} : {self.field_type(f)}{seq}"
if f.opt:
l += f" = None"
lines.append(l)
return lines
def build_new_fn(self, fields : List[asdl.Field], clsname):
fnames = []
fnames_init = []
for f in fields:
n = f.name
fnames.append(n)
if f.opt:
n += ' = None'
fnames_init.append(n)
argstr = ', '.join(['cls']+fnames_init)
fargs = ', '.join(fnames)
lines = [f" def __new__({argstr}):"]
# do conversion from list to tuple as needed for key
if clsname in self._memoize:
for f in fields:
lines += self.build_seq_coercion(f.name, f, pre=" ")
if len(lines) > 1:
lines += [" "]
if clsname in self._memoize:
lines += [f" _key_ = tuple([{fargs}])",
f" if _key_ in cls._memo_cache:",
f" result = cls._memo_cache[_key_]",
f" else:",
f" result = super().__new__(cls)",
f" cls._memo_cache[_key_] = result",
f" return result"]
else:
lines += [f" return super().__new__(cls)"]
return lines
def build_init_fn(self, fields : List[asdl.Field], clsname):
if self._do_checks:
lines = [f" def __attrs_post_init__(self):"]
for k,f in enumerate(fields):
lines += self.build_field_check(clsname, f, k, pre=" ")
if len(lines) == 1:
lines += [" pass"]
else:
lines = []
return lines
def build_seq_coercion(self, val : str, field : asdl.Field, pre=""):
if not field.seq:
return []
else:
return [f"{pre}if isinstance({val}, list):",
f"{pre} {val} = tuple({val})"]
def build_field_check(self, clsname : str,
field : asdl.Field, k : int, pre=""):
fname = f"self.{field.name}"
ftype = field.type + ("*" if field.seq else
"?" if field.opt else "")
errmsg = (f"\"{clsname}(...) argument {k+1}: \" + \"invalid instance "
f"of '{ftype} {field.name}'\"")
def check(x):
if field.type in self._checks:
return f"{self._checks[field.type]}({x})"
else:
btyp = self.field_type(field,base=True)
return f"isinstance({x}, {btyp})"
if field.seq:
check_expr = (f"not (isinstance({fname}, (tuple,list)) and "
f"all({check('x')} for x in {fname}))")
elif field.opt:
check_expr = f"not ({fname} is None or {check(fname)})"
else:
check_expr = f"not {check(fname)}"
return [f"{pre}if {check_expr}:",
f"{pre} raise TypeError({errmsg})"]
def field_type(self, field : asdl.Field, base=False):
# safely look up the concrete type name for the annotation
if field.type not in self._types:
raise TypeError(f"Unrecognized type '{field.type}'; did you "
f"mean to include it in ext_types?")
typ = self._types[field.type]
if base:
return typ
# handle sequence and optional qualifiers
assert not (field.seq and field.opt), "cannot qualify as both * and ?"
if field.seq:
typ = f"_Tuple[{typ}]"
elif field.opt:
typ = f"_Optional[{typ}]"
return typ