OILS / asdl / gen_python.py View on Github | oils.pub

671 lines, 427 significant
1#!/usr/bin/env python2
2"""gen_python.py: Generate Python code from an ASDL schema."""
3from __future__ import print_function
4
5from collections import defaultdict
6
7from asdl import ast
8from asdl import visitor
9from asdl.util import log
10
11from typing import List, Dict, Optional, Union, Tuple, IO, Any, cast
12
13_ = log # shut up lint
14
15_PRIMITIVES = {
16 'string': 'str',
17 'int': 'int',
18 'uint16': 'int',
19 'BigInt': 'mops.BigInt',
20 'float': 'float',
21 'bool': 'bool',
22 'any': 'Any',
23}
24
25
def _MyPyType(typ):
    # type: (ast.type_expr_t) -> str
    """Translate an ASDL type expression into a MyPy type string.

    Parameterized types (Dict, List, Optional) are translated recursively.
    A named type that has been resolved is rendered according to the kind of
    node it resolves to; an unresolved name is looked up in _PRIMITIVES.

    Raises:
      AssertionError: unknown parameterized type, unexpected resolved node,
        or a type expression that is neither parameterized nor named.
      KeyError: unresolved name that is not a primitive.
    """
    if isinstance(typ, ast.ParameterizedType):
        container = typ.type_name

        if container == 'Dict':
            key_str = _MyPyType(typ.children[0])
            val_str = _MyPyType(typ.children[1])
            return 'Dict[%s, %s]' % (key_str, val_str)

        # Both single-parameter containers keep their ASDL spelling.
        if container in ('List', 'Optional'):
            return '%s[%s]' % (container, _MyPyType(typ.children[0]))

        raise AssertionError()

    if not isinstance(typ, ast.NamedType):
        raise AssertionError()

    target = typ.resolved
    if not target:
        return _PRIMITIVES[typ.name]

    if isinstance(target, ast.Sum):  # includes SimpleSum
        return '%s_t' % typ.name

    if isinstance(target, ast.Product):
        return typ.name

    if isinstance(target, ast.Use):
        py_name, _unused = ast.TypeNameHeuristic(typ.name)
        return py_name

    if isinstance(target, ast.Extern):
        # Rendered as <module>.<type>, taken from the last two names.
        type_name = target.names[-1]
        py_module = target.names[-2]
        return '%s.%s' % (py_module, type_name)

    raise AssertionError(typ)
64
65
66def _CastedNull(mypy_type):
67 # type: (str) -> str
68 return "cast('%s', None)" % mypy_type
69
70
def _DefaultValue(typ, mypy_type):
    # type: (ast.type_expr_t, str) -> str
    """Return the source text of the value CreateNull() passes for a field.

    Where there is no natural "empty" value, None is cast to mypy_type so
    that generated code still passes mypy --strict.  CreateNull() therefore
    sidesteps the type system, and the caller is expected to fill in every
    field afterward.

    The expression returned for List fields refers to a variable named
    alloc_lists, so it is only meaningful where that name is in scope.

    Raises:
      AssertionError: unknown parameterized type, or a type expression that
        is neither parameterized nor named.
    """
    if isinstance(typ, ast.ParameterizedType):
        container = typ.type_name

        if container == 'List':
            return "[] if alloc_lists else cast('%s', None)" % mypy_type

        # TODO: Dict could respect alloc_dicts=True
        if container in ('Optional', 'Dict'):
            return _CastedNull(mypy_type)

        raise AssertionError(container)

    if not isinstance(typ, ast.NamedType):
        raise AssertionError()

    name = typ.name

    # Names with a literal default.
    # NOTE(review): 'uint16' is in _PRIMITIVES but has no entry here, so it
    # ends up as a cast None rather than -1 -- confirm that is intended.
    literal_defaults = {
        'id': '-1',  # hard-coded HACK
        'int': '-1',
        'BigInt': 'mops.BigInt(-1)',
        'bool': 'False',
        'float': '0.0',  # or should it be NaN?
        'string': "''",
    }
    if name in literal_defaults:
        return literal_defaults[name]

    resolved = typ.resolved
    if isinstance(resolved, ast.SimpleSum):
        # Just use the first variant.  An "Undef" member per enum doesn't
        # seem worth it.
        return '%s_e.%s' % (name, resolved.types[0].name)

    # CompoundSum or Product type
    return _CastedNull(mypy_type)
127
128
def _HNodeExpr(typ, var_name):
    # type: (ast.type_expr_t, str) -> Tuple[str, bool]
    """Return source text that builds an hnode from the variable var_name.

    An Optional[T] is treated as T.

    Returns:
      (code_str, none_guard).  none_guard is True when code_str calls
      var_name.PrettyTree(), i.e. when the value is an object rather than
      something rendered directly as a leaf.

    Raises:
      AssertionError: type expression is neither parameterized nor named.
    """
    if ast.IsOptional(typ):
        # descend one level
        typ = cast(ast.ParameterizedType, typ).children[0]

    if isinstance(typ, ast.ParameterizedType):
        return '%s.PrettyTree()' % var_name, True

    if not isinstance(typ, ast.NamedType):
        raise AssertionError()

    name = typ.name

    # Types rendered as a single leaf; each template takes var_name.
    str_leaf = 'hnode.Leaf(str(%s), color_e.OtherConst)'
    leaf_templates = {
        'bool': "hnode.Leaf('T' if %s else 'F', color_e.OtherConst)",
        'int': str_leaf,
        'uint16': str_leaf,
        'BigInt': 'hnode.Leaf(mops.ToStr(%s), color_e.OtherConst)',
        'float': str_leaf,
        'string': 'NewLeaf(%s, color_e.StringConst)',
        'any': 'NewLeaf(str(%s), color_e.External)',
        # was meta.UserType.  This assumes it's Id, which is a simple
        # SumType.  TODO: Remove this.
        'id': 'hnode.Leaf(Id_str(%s, dot=False), color_e.UserType)',
    }
    template = leaf_templates.get(name)
    if template is not None:
        return template % var_name, False

    if typ.resolved and isinstance(typ.resolved, ast.SimpleSum):
        code_str = 'hnode.Leaf(%s_str(%s), color_e.TypeName)' % (name,
                                                                 var_name)
        return code_str, False

    return '%s.PrettyTree(do_abbrev, trav=trav)' % var_name, True
178
179
class GenMyPyVisitor(visitor.AsdlVisitor):
    """Generate Python code with MyPy type annotations.

    Sum types are emitted as they are visited.  Product types and subtypes
    are only recorded during the visit and emitted by EmitFooter(), because
    the sum types they inherit from may not have been seen yet.

    Every template passed to Emit() indents the generated code two spaces
    per nesting level.  That leading whitespace is part of the generated
    program and must be preserved exactly.
    """

    def __init__(
            self,
            f,  # type: IO[bytes]
            abbrev_mod_entries=None,  # type: Optional[List[str]]
            pretty_print_methods=True,  # type: bool
            py_init_n=False,  # type: bool
    ):
        # type: (...) -> None
        """
        Args:
          f: output file, handed to the base visitor
          abbrev_mod_entries: names such as '_Token' or '_command__Simple';
            a class whose abbreviation name is listed gets a do_abbrev
            shortcut in its PrettyTree()
          pretty_print_methods: whether to emit PrettyTree() methods
          py_init_n: when true, CreateNull() is not emitted
        """
        visitor.AsdlVisitor.__init__(self, f)
        self.abbrev_mod_entries = abbrev_mod_entries or []
        self.pretty_print_methods = pretty_print_methods
        self.py_init_n = py_init_n

        # Product / subtype name -> type tag
        self._shared_type_tags = {}  # type: Dict[str, int]
        self._product_counter = 64  # matches asdl/gen_cpp.py

        # (product, name, depth, tag_num) tuples, emitted by EmitFooter()
        self._products = []  # type: List[Any]
        # Product / subtype name -> sum base classes like 'word_part_t'
        self._base_classes = defaultdict(list)  # type: Dict[str, List[str]]

        # (subtype, tag_num) tuples, emitted by EmitFooter()
        self._subtypes = []  # type: List[Any]

    def _EmitDict(self, name, d, depth):
        # type: (str, Dict[int, str], int) -> None
        """Emit the dict literal _<name>_str, with keys in sorted order."""
        self.Emit('_%s_str = {' % name, depth)
        for tag_num in sorted(d):
            self.Emit('%d: %r,' % (tag_num, d[tag_num]), depth + 1)
        self.Emit('}', depth)
        self.Emit('', depth)

    def VisitSimpleSum(self, sum, sum_name, depth):
        # type: (ast.SimpleSum, str, int) -> None
        """Emit the code for a sum type whose variants have no fields.

        Depending on sum.generate, the variants are emitted either as plain
        integers (with '<name>_t = int') or as instances of a generated
        pybase.SimpleObj subclass.  A _<name>_str dict and a <name>_str()
        function follow in both cases.
        """
        # Tags are 1-based, in declaration order.
        variants = [(variant, i + 1) for i, variant in enumerate(sum.types)]
        int_to_str = dict(
            (tag_num, variant.name) for variant, tag_num in variants)

        add_suffix = 'no_namespace_suffix' not in sum.generate
        gen_integers = 'integers' in sum.generate or 'uint16' in sum.generate

        if gen_integers:
            self.Emit('%s_t = int # type alias for integer' % sum_name)
            self.Emit('')

            i_name = ('%s_i' % sum_name) if add_suffix else sum_name

            self.Emit('class %s(object):' % i_name, depth)

            for variant, tag_num in variants:
                self.Emit('  %s = %d' % (variant.name, tag_num), depth)

            # Help in sizing array.  Note that we're 1-based.
            self.Emit('  %s = %d' % ('ARRAY_SIZE', len(variants) + 1), depth)

        else:
            # First emit a type
            self.Emit('class %s_t(pybase.SimpleObj):' % sum_name, depth)
            self.Emit('  pass', depth)
            self.Emit('', depth)

            # Now emit a namespace
            e_name = ('%s_e' % sum_name) if add_suffix else sum_name
            self.Emit('class %s(object):' % e_name, depth)

            for variant, tag_num in variants:
                self.Emit(
                    '  %s = %s_t(%d)' % (variant.name, sum_name, tag_num),
                    depth)

        self.Emit('', depth)

        self._EmitDict(sum_name, int_to_str, depth)

        # NOTE(review): everything above used the schema name, while the
        # function below uses the NameHack() result (including when it reads
        # _<name>_str) -- confirm the two always agree where it matters.
        sum_name = ast.NameHack(sum_name)

        self.Emit('def %s_str(val, dot=True):' % sum_name, depth)
        self.Emit('  # type: (%s_t, bool) -> str' % sum_name, depth)
        self.Emit('  v = _%s_str[val]' % sum_name, depth)
        self.Emit('  if dot:', depth)
        self.Emit('    return "%s.%%s" %% v' % sum_name, depth)
        self.Emit('  else:', depth)
        self.Emit('    return v', depth)
        self.Emit('', depth)

    def _EmitCodeForField(self, field, counter):
        # type: (ast.Field, int) -> None
        """Emit statements that build an hnode for a field and add it to L.

        The emitted code appends Field(<name>, <hnode>) to a list named L.
        List, Dict and Optional fields are wrapped in an 'is not None'
        check, so a missing value produces no Field at all.

        Args:
          field: the field to print
          counter: makes the emitted temporaries (x<N>, i<N>, ...) unique
        """
        out_val_name = 'x%d' % counter

        if ast.IsList(field.typ):
            iter_name = 'i%d' % counter

            typ = field.typ
            if ast.IsOptional(typ):  # descend one level
                typ = cast(ast.ParameterizedType, typ).children[0]
            item_type = cast(ast.ParameterizedType, typ).children[0]

            self.Emit('  if self.%s is not None: # List' % field.name)
            self.Emit('    %s = hnode.Array([])' % out_val_name)
            self.Emit('    for %s in self.%s:' % (iter_name, field.name))
            child_code_str, none_guard = _HNodeExpr(item_type, iter_name)

            if none_guard:  # e.g. for List[Optional[value_t]]
                # TODO: could consolidate with asdl/runtime.py NewLeaf(),
                # which also uses _ to mean None/nullptr
                self.Emit(
                    '      h = (hnode.Leaf("_", color_e.OtherConst) if %s is None else %s)'
                    % (iter_name, child_code_str))
                self.Emit('      %s.children.append(h)' % out_val_name)
            else:
                self.Emit('      %s.children.append(%s)' %
                          (out_val_name, child_code_str))

            self.Emit('    L.append(Field(%r, %s))' %
                      (field.name, out_val_name))

        elif ast.IsDict(field.typ):
            k = 'k%d' % counter
            v = 'v%d' % counter

            typ = field.typ
            if ast.IsOptional(typ):  # descend one level
                typ = cast(ast.ParameterizedType, typ).children[0]

            k_typ = cast(ast.ParameterizedType, typ).children[0]
            v_typ = cast(ast.ParameterizedType, typ).children[1]

            k_code_str, _ = _HNodeExpr(k_typ, k)
            v_code_str, _ = _HNodeExpr(v_typ, v)

            # Keys and values alternate in one flat list of unnamed fields.
            unnamed = 'unnamed%d' % counter
            self.Emit('  if self.%s is not None: # Dict' % field.name)
            self.Emit('    %s = [] # type: List[hnode_t]' % unnamed)
            self.Emit('    %s = hnode.Record("", "{", "}", [], %s)' %
                      (out_val_name, unnamed))
            self.Emit('    for %s, %s in self.%s.iteritems():' %
                      (k, v, field.name))
            self.Emit('      %s.append(%s)' % (unnamed, k_code_str))
            self.Emit('      %s.append(%s)' % (unnamed, v_code_str))
            self.Emit('    L.append(Field(%r, %s))' %
                      (field.name, out_val_name))

        elif ast.IsOptional(field.typ):
            typ = cast(ast.ParameterizedType, field.typ).children[0]

            self.Emit('  if self.%s is not None: # Optional' % field.name)
            child_code_str, _ = _HNodeExpr(typ, 'self.%s' % field.name)
            self.Emit('    %s = %s' % (out_val_name, child_code_str))
            self.Emit('    L.append(Field(%r, %s))' %
                      (field.name, out_val_name))

        else:
            var_name = 'self.%s' % field.name
            code_str, obj_none_guard = _HNodeExpr(field.typ, var_name)
            depth = self.current_depth
            if obj_none_guard:  # to satisfy MyPy type system
                self.Emit('  assert self.%s is not None' % field.name)
            self.Emit('  %s = %s' % (out_val_name, code_str), depth)

            self.Emit(
                '  L.append(Field(%r, %s))' % (field.name, out_val_name),
                depth)

    def _GenClassBegin(self, class_name, base_classes, tag_num):
        # type: (str, Union[List[str], Tuple[str]], int) -> None
        """Emit the 'class' line and the _type_tag class attribute."""
        self.Emit('class %s(%s):' % (class_name, ', '.join(base_classes)))
        self.Emit('  _type_tag = %d' % tag_num)

    def _GenListSubclass(self, class_name, base_classes, tag_num, class_ns=''):
        # type: (str, List[str], int, str) -> None
        """Emit a class for a subtype whose base is a List[...].

        The class gets two static constructors, New() and Take(), plus
        PrettyTree() when pretty printing is enabled.

        Args:
          base_classes: must contain an entry starting with 'List['
          class_ns: accepted for symmetry with _GenClass(); not used

        Raises:
          IndexError: base_classes has no 'List[' entry.
        """
        self._GenClassBegin(class_name, base_classes, tag_num)

        # TODO: Do something nicer
        base_class_str = [b for b in base_classes if b.startswith('List[')][0]

        # No __init__ is generated for list subclasses; New() and Take()
        # below are the emitted constructors.
        self.Emit('  @staticmethod')
        self.Emit('  def New():')
        self.Emit('    # type: () -> %s' % class_name)
        self.Emit('    return %s()' % class_name)
        self.Emit('')

        # Take() copies the items into the new object and then empties the
        # list that was passed in.
        self.Emit('  @staticmethod')
        self.Emit('  def Take(plain_list):')
        self.Emit('    # type: (%s) -> %s' % (base_class_str, class_name))
        self.Emit('    result = %s(plain_list)' % class_name)
        self.Emit('    del plain_list[:]')
        self.Emit('    return result')
        self.Emit('')

        if self.pretty_print_methods:
            self._EmitPrettyPrintMethodsForList(class_name)

    def _GenClass(
            self,
            fields,  # type: List[ast.Field]
            class_name,  # type: str
            base_classes,  # type: Union[List[str], Tuple[str]]
            tag_num,  # type: int
            class_ns='',  # type: str
    ):
        # type: (...) -> None
        """Generate a typed Python class.

        Used for both Sum variants ("constructors") and Product types.  The
        class gets __slots__, an __init__ taking every field positionally,
        CreateNull() (unless there are no fields or py_init_n is set), and
        PrettyTree() when pretty printing is enabled.

        Args:
          class_ns: for variants like value.Str
        """
        self._GenClassBegin(class_name, base_classes, tag_num)

        field_names = [f.name for f in fields]

        self.Emit('  __slots__ = %s' % repr(tuple(field_names)))
        self.Emit('')

        #
        # __init__
        #

        arg_types = []
        default_vals = []
        for f in fields:
            mypy_type = _MyPyType(f.typ)
            arg_types.append(mypy_type)
            default_vals.append(_DefaultValue(f.typ, mypy_type))

        self.Emit('  def __init__(self, %s):' % ', '.join(field_names))
        # don't wrap the type comment
        self.Emit('    # type: (%s) -> None' % ', '.join(arg_types),
                  reflow=False)

        if not fields:
            self.Emit('    pass')  # for types like NoOp

        for name in field_names:
            self.Emit('    self.%s = %s' % (name, name), reflow=False)

        self.Emit('')

        # CreateNull() - another way of initializing
        if fields and not self.py_init_n:
            self.Emit('  @staticmethod')
            self.Emit('  def CreateNull(alloc_lists=False):')
            # The signature comment has to account for alloc_lists.
            self.Emit('    # type: (bool) -> %s%s' % (class_ns, class_name))
            self.Emit('    return %s%s(%s)' %
                      (class_ns, class_name, ', '.join(default_vals)),
                      reflow=False)
            self.Emit('')

        # PrettyTree()
        if self.pretty_print_methods:
            self._EmitPrettyPrintMethods(class_name, class_ns, fields)

    def _EmitPrettyBegin(self):
        # type: () -> None
        """Emit the start of PrettyTree(): signature and seen-object check."""
        self.Emit('  def PrettyTree(self, do_abbrev, trav=None):')
        self.Emit('    # type: (bool, Optional[TraversalState]) -> hnode_t')
        self.Emit('    trav = trav or TraversalState()')
        self.Emit('    heap_id = id(self)')
        self.Emit('    if heap_id in trav.seen:')
        # cut off recursion
        self.Emit('      return hnode.AlreadySeen(heap_id)')
        self.Emit('    trav.seen[heap_id] = True')

    def _EmitPrettyPrintMethodsForList(self, class_name):
        # type: (str) -> None
        """Emit PrettyTree() for a list subclass: one unnamed field per item."""
        self._EmitPrettyBegin()
        self.Emit('    h = runtime.NewRecord(%r)' % class_name)
        # NOTE(review): unlike the code emitted for object fields, 'trav' is
        # not forwarded to the children here -- confirm that is intended.
        self.Emit(
            '    h.unnamed_fields = [c.PrettyTree(do_abbrev) for c in self]')
        self.Emit('    return h')
        self.Emit('')

    def _EmitPrettyPrintMethods(self, class_name, class_ns, fields):
        # type: (str, str, List[ast.Field]) -> None
        """Emit PrettyTree() for a class generated by _GenClass().

        The generated method returns a record named after the class, with
        one Field per field that _EmitCodeForField() decides to print.
        """
        if not fields:
            # value__Stdin -> value.Stdin (defined at top level)
            pretty_cls_name = class_name.replace('__', '.')
        else:
            # value.Str (defined inside the 'class value') namespace
            pretty_cls_name = '%s%s' % (class_ns, class_name)

        # def PrettyTree(...):

        self._EmitPrettyBegin()
        self.Emit('')

        if class_ns:
            # e.g. _command__Simple
            assert class_ns.endswith('.')
            abbrev_name = '_%s__%s' % (class_ns[:-1], class_name)
        else:
            # e.g. _Token
            abbrev_name = '_%s' % class_name

        if abbrev_name in self.abbrev_mod_entries:
            # The abbreviation function may decline by returning a false
            # value, in which case the full record is printed.
            self.Emit('    if do_abbrev:')
            self.Emit('      p = %s(self)' % abbrev_name)
            self.Emit('      if p:')
            self.Emit('        return p')
            self.Emit('')

        self.Emit('    out_node = NewRecord(%r)' % pretty_cls_name)
        self.Emit('    L = out_node.fields')
        self.Emit('')

        # Use the runtime type to be more like asdl/format.py
        for local_id, field in enumerate(fields):
            # The field templates are written one level shallower than the
            # method body, so indent around them.
            self.Indent()
            self._EmitCodeForField(field, local_id)
            self.Dedent()
            self.Emit('')
        self.Emit('    return out_node')
        self.Emit('')

    def VisitCompoundSum(self, sum, sum_name, depth):
        # type: (ast.Sum, str, int) -> None
        """Emit the code for a sum type with at least one variant with fields.

        Note that the following is_simple:

            cflow = Break | Continue

        But this is compound:

            cflow = Break | Continue | Return(int val)

        The generated code changes depending on which one it is.

        Raises:
          RuntimeError: two variants of this sum share the same product type.
          KeyError: a variant's shared type has not been assigned a tag.
        """
        # We emit THREE Python types for each meta.CompoundType:
        #
        # 1. enum for tag (cflow_e)
        # 2. base class for inheritance (cflow_t)
        # 3. namespace for classes (cflow) -- TODO: Get rid of this one.
        #
        # Should code use cflow_e.tag or isinstance()?
        # isinstance() is better for MyPy I think.  But tag is better for C++.
        # int tag = static_cast<cflow>(node).tag;

        int_to_str = {}  # type: Dict[int, str]

        # enum for the tag
        self.Emit('class %s_e(object):' % sum_name, depth)

        for i, variant in enumerate(sum.types):
            if variant.shared_type:
                tag_num = self._shared_type_tags[variant.shared_type]
                # e.g. DoubleQuoted may have base types expr_t, word_part_t
                base_class = sum_name + '_t'
                bases = self._base_classes[variant.shared_type]
                if base_class in bases:
                    raise RuntimeError(
                        "Two tags in sum %r refer to product type %r" %
                        (sum_name, variant.shared_type))
                bases.append(base_class)
            else:
                tag_num = i + 1
            self.Emit('  %s = %d' % (variant.name, tag_num), depth)
            int_to_str[tag_num] = variant.name
        self.Emit('', depth)

        self._EmitDict(sum_name, int_to_str, depth)

        # NOTE(review): everything above used the schema name; everything
        # below uses the NameHack() result -- confirm the two always agree
        # where it matters.
        sum_name = ast.NameHack(sum_name)

        self.Emit('def %s_str(tag, dot=True):' % sum_name, depth)
        self.Emit('  # type: (int, bool) -> str', depth)
        self.Emit('  v = _%s_str[tag]' % sum_name, depth)
        self.Emit('  if dot:', depth)
        self.Emit('    return "%s.%%s" %% v' % sum_name, depth)
        self.Emit('  else:', depth)
        self.Emit('    return v', depth)
        self.Emit('', depth)

        # the base class, e.g. 'command_t'
        self.Emit('class %s_t(pybase.CompoundObj):' % sum_name, depth)
        self.Indent()

        # To imitate C++ API
        self.Emit('def tag(self):')
        self.Emit('  # type: () -> int')
        self.Emit('  return self._type_tag')

        self.Dedent()
        # From here on, explicit depths follow the visitor's own depth.
        depth = self.current_depth

        self.Emit('')

        base_tuple = (sum_name + '_t', )

        # Declare any zero argument singleton classes outside of the main
        # "namespace" class.
        for i, variant in enumerate(sum.types):
            if variant.shared_type:
                continue  # Don't generate a class for shared types.
            if not variant.fields:
                # We must use the old-style naming here, ie. command__NoOp, in
                # order to support zero field variants as constants.
                class_name = '%s__%s' % (sum_name, variant.name)
                self._GenClass(variant.fields, class_name, base_tuple, i + 1)

        # Class that's just a NAMESPACE, e.g. for value.Str
        self.Emit('class %s(object):' % sum_name, depth)

        self.Indent()

        for i, variant in enumerate(sum.types):
            if variant.shared_type:
                continue

            if not variant.fields:
                # Singleton instance of the class declared above
                self.Emit('%s = %s__%s()' %
                          (variant.name, sum_name, variant.name))
                self.Emit('')
            else:
                # Use fully-qualified name, so we can have osh_cmd.Simple and
                # oil_cmd.Simple.
                fq_name = variant.name
                self._GenClass(variant.fields,
                               fq_name,
                               base_tuple,
                               i + 1,
                               class_ns=sum_name + '.')
        self.Emit('  pass', depth)  # in case every variant is first class

        self.Dedent()
        self.Emit('')

    def VisitSubType(self, subtype):
        # type: (ast.SubTypeDecl) -> None
        """Assign the next tag to a subtype and defer it to EmitFooter()."""
        self._shared_type_tags[subtype.name] = self._product_counter

        # Also create these last.  They may inherit from sum types that have
        # yet to be defined.
        self._subtypes.append((subtype, self._product_counter))
        self._product_counter += 1

    def VisitProduct(self, product, name, depth):
        # type: (ast.Product, str, int) -> None
        """Assign the next tag to a product type and defer it to EmitFooter()."""
        self._shared_type_tags[name] = self._product_counter
        # Create a tuple of _GenClass args to create LAST.  They may inherit
        # from sum types that have yet to be defined.
        self._products.append((product, name, depth, self._product_counter))
        self._product_counter += 1

    def EmitFooter(self):
        # type: () -> None
        """Emit every product type and subtype that was deferred."""
        # Now generate all the product types we deferred.
        for ast_node, name, unused_depth, tag_num in self._products:
            # Figure out base classes AFTERWARD.
            bases = self._base_classes[name]
            if not bases:
                bases = ['pybase.CompoundObj']
            self._GenClass(ast_node.fields, name, bases, tag_num)

        for subtype, tag_num in self._subtypes:
            # Figure out base classes AFTERWARD.
            bases = self._base_classes[subtype.name]
            if not bases:
                bases = ['pybase.CompoundObj']

            # The declared base comes after any sum base classes.
            bases.append(_MyPyType(subtype.base_class))

            if ast.IsList(subtype.base_class):
                self._GenListSubclass(subtype.name, bases, tag_num)
            else:
                self._GenClass([], subtype.name, bases, tag_num)