OILS / asdl / gen_python.py View on Github | oilshell.org

649 lines, 419 significant
1#!/usr/bin/env python2
2"""gen_python.py: Generate Python code from an ASDL schema."""
3from __future__ import print_function
4
5from collections import defaultdict
6
7from asdl import ast
8from asdl import visitor
9from asdl.util import log
10
11_ = log # shut up lint
12
13_PRIMITIVES = {
14 'string': 'str',
15 'int': 'int',
16 'uint16': 'int',
17 'BigInt': 'mops.BigInt',
18 'float': 'float',
19 'bool': 'bool',
20 'any': 'Any',
21 # TODO: frontend/syntax.asdl should properly import id enum instead of
22 # hard-coding it here.
23 'id': 'Id_t',
24}
25
26
def _MyPyType(typ):
    """Return the MyPy type, as source text, for an ASDL type expression.

    Args:
      typ: an ast.ParameterizedType (Dict, List or Optional) or an
        ast.NamedType.

    Returns:
      A string such as 'Dict[str, int]', 'List[str]', 'Optional[bool]' or
      'Id_t'.  A NamedType that resolved to a Sum gets a '_t' suffix, one that
      resolved to a Product is returned as is, and one that resolved to a Use
      goes through ast.TypeNameHeuristic().  Anything else is looked up in
      _PRIMITIVES.

    Raises:
      AssertionError: if typ is neither kind of node, or is a
        ParameterizedType other than Dict, List or Optional.
      KeyError: if a NamedType falls through to _PRIMITIVES and has no entry
        there.
    """
    if isinstance(typ, ast.ParameterizedType):
        type_name = typ.type_name

        if type_name == 'Dict':
            k_type = _MyPyType(typ.children[0])
            v_type = _MyPyType(typ.children[1])
            return 'Dict[%s, %s]' % (k_type, v_type)

        if type_name in ('List', 'Optional'):
            return '%s[%s]' % (type_name, _MyPyType(typ.children[0]))

        # Fail loudly, like _DefaultValue() does.  Falling off the end here
        # would return None, which callers would format into the generated
        # code as the text 'None'.
        raise AssertionError(type_name)

    if isinstance(typ, ast.NamedType):
        resolved = typ.resolved
        if resolved:
            if isinstance(resolved, ast.Sum):  # includes SimpleSum
                return '%s_t' % typ.name
            if isinstance(resolved, ast.Product):
                return typ.name
            if isinstance(resolved, ast.Use):
                return ast.TypeNameHeuristic(typ.name)

        # 'id' falls through here
        return _PRIMITIVES[typ.name]

    raise AssertionError()
56
57
def _DefaultValue(typ, mypy_type):
    """Return the source text of the value CreateNull() passes for one field.

    mypy_type is used to cast None, to maintain mypy --strict for ASDL.

    CreateNull() deliberately circumvents the type system: the caller is
    expected to fill in every field afterwards, and code that reads the fields
    at runtime relies on that having happened.

    Args:
      typ: the field's ast.ParameterizedType or ast.NamedType
      mypy_type: the field's MyPy type as a string, i.e. _MyPyType(typ)

    Raises:
      AssertionError: for an unknown ParameterizedType name, or when typ is
        neither kind of node.
    """
    if isinstance(typ, ast.ParameterizedType):
        container = typ.type_name

        if container == 'List':
            return "[] if alloc_lists else cast('%s', None)" % mypy_type

        # TODO: Dict can respect alloc_dicts=True
        if container in ('Optional', 'Dict'):
            return "cast('%s', None)" % mypy_type

        raise AssertionError(container)

    if not isinstance(typ, ast.NamedType):
        raise AssertionError()

    name = typ.name

    # Primitives that have a literal placeholder.
    # NOTE(review): 'uint16' has no entry, so it ends up in the
    # cast(..., None) case at the bottom -- confirm that is intended.
    literals = {
        'id': '-1',  # hard-coded HACK
        'int': '-1',
        'BigInt': 'mops.BigInt(-1)',
        'bool': 'False',
        'float': '0.0',  # or should it be NaN?
        'string': "''",
    }
    if name in literals:
        return literals[name]

    enum = typ.resolved
    if isinstance(enum, ast.SimpleSum):
        # Just make it the first variant.  We could define "Undef" for each
        # enum, but it doesn't seem worth it.
        return '%s_e.%s' % (name, enum.types[0].name)

    # CompoundSum or Product type
    return 'cast(%s, None)' % mypy_type
113
114
def _HNodeExpr(abbrev, typ, var_name):
    # type: (str, ast.TypeExpr, str) -> Tuple[str, bool]
    """Build the source text of an expression that converts a value to an hnode.

    Args:
      abbrev: name of the tree method to call on nested nodes.  The callers in
        this file pass 'PrettyTree' or 'AbbreviatedTree'.
      typ: ASDL type of the value.  One level of Optional is unwrapped first.
      var_name: source text that evaluates to the value, e.g. 'self.foo'

    Returns:
      A (code_str, none_guard) pair.  none_guard is True when code_str calls a
      method on the value itself, so the caller must deal with None before
      evaluating it.  It is False for primitives and simple sums.

    Raises:
      AssertionError: if typ (after unwrapping Optional) is neither an
        ast.ParameterizedType nor an ast.NamedType.
    """
    # NOTE(review): 'Tuple' is only named in the type comment above; this file
    # has no typing import, so a type checker run on it would need one.

    if typ.IsOptional():
        typ = typ.children[0]  # descend one level

    if isinstance(typ, ast.ParameterizedType):
        return '%s.%s()' % (var_name, abbrev), True

    if not isinstance(typ, ast.NamedType):
        raise AssertionError()

    type_name = typ.name

    # Leaf expressions for the built-in types.  Each template takes var_name.
    number_leaf = 'hnode.Leaf(str(%s), color_e.OtherConst)'
    leaf_templates = {
        'bool': "hnode.Leaf('T' if %s else 'F', color_e.OtherConst)",
        'int': number_leaf,
        'uint16': number_leaf,
        'float': number_leaf,
        'BigInt': 'hnode.Leaf(mops.ToStr(%s), color_e.OtherConst)',
        'string': 'NewLeaf(%s, color_e.StringConst)',
        # TODO: Remove this.  Used for value.Obj().
        'any': 'hnode.External(%s)',
        # was meta.UserType
        # This assumes it's Id, which is a simple SumType.  TODO: Remove this.
        'id': 'hnode.Leaf(Id_str(%s), color_e.UserType)',
    }
    template = leaf_templates.get(type_name)
    if template is not None:
        return template % var_name, False

    if typ.resolved and isinstance(typ.resolved, ast.SimpleSum):
        code_str = 'hnode.Leaf(%s_str(%s), color_e.TypeName)' % (type_name,
                                                                 var_name)
        return code_str, False

    # Compound sum or product: delegate to the object's own tree method.
    return '%s.%s(trav=trav)' % (var_name, abbrev), True
163
164
class GenMyPyVisitor(visitor.AsdlVisitor):
    """Generate Python code with MyPy type annotations.

    Sum types are emitted as they are visited.  Product types and subtypes are
    only recorded while visiting and are emitted by EmitFooter(), because
    their base classes depend on which sum types turn out to share them.
    """

    # NOTE(review): the leading whitespace inside the emitted template strings
    # is reproduced exactly as found.  Several templates give a block header
    # and its body the same one-space indent (e.g. ' if self.%s is not None:'
    # followed by ' %s = hnode.Array([])'), which suggests runs of spaces were
    # collapsed before this review -- confirm against the canonical file.

    def __init__(self,
                 f,
                 abbrev_mod_entries=None,
                 pretty_print_methods=True,
                 py_init_n=False,
                 simple_int_sums=None):
        """
        Args:
          f: passed through to visitor.AsdlVisitor.__init__()
          abbrev_mod_entries: names of user-provided abbreviation functions.
            A class Foo calls '_Foo' from its AbbreviatedTree() if '_Foo' is
            listed here.
          pretty_print_methods: if false, _GenClass() stops after emitting
            __init__() and CreateNull().
          py_init_n: if true, CreateNull() is not emitted.
          simple_int_sums: stored on the instance; not read in this class.
        """
        visitor.AsdlVisitor.__init__(self, f)
        self.abbrev_mod_entries = abbrev_mod_entries or []
        self.pretty_print_methods = pretty_print_methods
        self.py_init_n = py_init_n

        # For Id to use different code gen.  It's used like an integer, not
        # just like an enum.
        self.simple_int_sums = simple_int_sums or []

        # Product / subtype name -> its type tag
        self._shared_type_tags = {}
        self._product_counter = 64  # matches asdl/gen_cpp.py

        # (product, name, depth, tag_num) tuples, emitted by EmitFooter()
        self._products = []
        # Product / subtype name -> list of 'xxx_t' sum base classes
        self._base_classes = defaultdict(list)

        # (subtype, tag_num) tuples, emitted by EmitFooter()
        self._subtypes = []

    def _EmitDict(self, name, d, depth):
        """Emit the dict _<name>_str, with the integer keys of d in order."""
        self.Emit('_%s_str = {' % name, depth)
        for k in sorted(d):
            self.Emit('%d: %r,' % (k, d[k]), depth + 1)
        self.Emit('}', depth)
        self.Emit('', depth)

    def VisitSimpleSum(self, sum, sum_name, depth):
        """Emit a sum whose variants have no fields, plus <sum_name>_str().

        Tags are 1-based.  If 'integers' or 'uint16' is in sum.generate, the
        variants are plain ints in a namespace class; otherwise they are
        instances of a <sum_name>_t class.
        """
        int_to_str = {}
        variants = []
        for i, variant in enumerate(sum.types):
            tag_num = i + 1
            tag_str = '%s.%s' % (sum_name, variant.name)
            int_to_str[tag_num] = tag_str
            variants.append((variant, tag_num))

        add_suffix = 'no_namespace_suffix' not in sum.generate
        gen_integers = 'integers' in sum.generate or 'uint16' in sum.generate

        if gen_integers:
            self.Emit('%s_t = int # type alias for integer' % sum_name)
            self.Emit('')

            i_name = ('%s_i' % sum_name) if add_suffix else sum_name

            self.Emit('class %s(object):' % i_name, depth)

            for variant, tag_num in variants:
                line = ' %s = %d' % (variant.name, tag_num)
                self.Emit(line, depth)

            # Help in sizing array.  Note that we're 1-based.
            line = ' %s = %d' % ('ARRAY_SIZE', len(variants) + 1)
            self.Emit(line, depth)

        else:
            # First emit a type
            self.Emit('class %s_t(pybase.SimpleObj):' % sum_name, depth)
            self.Emit(' pass', depth)
            self.Emit('', depth)

            # Now emit a namespace
            e_name = ('%s_e' % sum_name) if add_suffix else sum_name
            self.Emit('class %s(object):' % e_name, depth)

            for variant, tag_num in variants:
                line = ' %s = %s_t(%d)' % (variant.name, sum_name, tag_num)
                self.Emit(line, depth)

        self.Emit('', depth)

        self._EmitDict(sum_name, int_to_str, depth)

        self.Emit('def %s_str(val):' % sum_name, depth)
        self.Emit(' # type: (%s_t) -> str' % sum_name, depth)
        self.Emit(' return _%s_str[val]' % sum_name, depth)
        self.Emit('', depth)

    def _EmitCodeForField(self, abbrev, field, counter):
        """Generate code that returns an hnode for a field.

        Args:
          abbrev: tree method name, forwarded to _HNodeExpr()
          field: the field; its .name and .typ are read
          counter: makes the emitted local names (x0, i0, k0, v0, ...) unique
        """
        out_val_name = 'x%d' % counter

        if field.typ.IsList():
            iter_name = 'i%d' % counter

            typ = field.typ
            if typ.type_name == 'Optional':  # descend one level
                typ = typ.children[0]
            item_type = typ.children[0]

            self.Emit(' if self.%s is not None: # List' % field.name)
            self.Emit(' %s = hnode.Array([])' % out_val_name)
            self.Emit(' for %s in self.%s:' % (iter_name, field.name))
            child_code_str, none_guard = _HNodeExpr(abbrev, item_type,
                                                    iter_name)

            if none_guard:  # e.g. for List[Optional[value_t]]
                # TODO: could consolidate with asdl/runtime.py NewLeaf(), which
                # also uses _ to mean None/nullptr
                self.Emit(
                    ' h = (hnode.Leaf("_", color_e.OtherConst) if %s is None else %s)'
                    % (iter_name, child_code_str))
                self.Emit(' %s.children.append(h)' % out_val_name)
            else:
                self.Emit(' %s.children.append(%s)' %
                          (out_val_name, child_code_str))

            self.Emit(' L.append(Field(%r, %s))' %
                      (field.name, out_val_name))

        elif field.typ.IsDict():
            k = 'k%d' % counter
            v = 'v%d' % counter

            typ = field.typ
            if typ.type_name == 'Optional':  # descend one level
                typ = typ.children[0]

            k_typ = typ.children[0]
            v_typ = typ.children[1]

            # The none_guard results are not used for dict keys and values.
            k_code_str, _ = _HNodeExpr(abbrev, k_typ, k)
            v_code_str, _ = _HNodeExpr(abbrev, v_typ, v)

            self.Emit(' if self.%s is not None: # Dict' % field.name)
            self.Emit(' m = hnode.Leaf("Dict", color_e.OtherConst)')
            self.Emit(' %s = hnode.Array([m])' % out_val_name)
            self.Emit(' for %s, %s in self.%s.iteritems():' %
                      (k, v, field.name))
            self.Emit(' %s.children.append(%s)' %
                      (out_val_name, k_code_str))
            self.Emit(' %s.children.append(%s)' %
                      (out_val_name, v_code_str))
            self.Emit(' L.append(Field(%r, %s))' %
                      (field.name, out_val_name))

        elif field.typ.IsOptional():
            typ = field.typ.children[0]

            self.Emit(' if self.%s is not None: # Optional' % field.name)
            child_code_str, _ = _HNodeExpr(abbrev, typ, 'self.%s' % field.name)
            self.Emit(' %s = %s' % (out_val_name, child_code_str))
            self.Emit(' L.append(Field(%r, %s))' %
                      (field.name, out_val_name))

        else:
            var_name = 'self.%s' % field.name
            code_str, obj_none_guard = _HNodeExpr(abbrev, field.typ, var_name)
            depth = self.current_depth
            if obj_none_guard:  # to satisfy MyPy type system
                self.Emit(' assert self.%s is not None' % field.name)
            self.Emit(' %s = %s' % (out_val_name, code_str), depth)

            self.Emit(' L.append(Field(%r, %s))' % (field.name, out_val_name),
                      depth)

    def _GenClass(self,
                  fields,
                  class_name,
                  base_classes,
                  tag_num,
                  class_ns=''):
        """Used for both Sum variants ("constructors") and Product types.

        Emits the class header, _type_tag, __slots__ and __init__(); then
        CreateNull() unless there are no fields or py_init_n is set; then,
        if pretty_print_methods is set, PrettyTree(), _AbbreviatedTree() and
        AbbreviatedTree().

        Args:
          fields: the fields of the class, in constructor order
          class_name: name in the 'class' statement
          base_classes: sequence of base class names.  A name starting with
            'List[' makes the tree methods emit an hnode.Array of the
            elements instead of a record.
          tag_num: value of _type_tag
          class_ns: for variants like value.Str
        """
        self.Emit('class %s(%s):' % (class_name, ', '.join(base_classes)))
        self.Emit(' _type_tag = %d' % tag_num)

        field_names = [f.name for f in fields]

        quoted_fields = repr(tuple(field_names))
        self.Emit(' __slots__ = %s' % quoted_fields)
        self.Emit('')

        #
        # __init__
        #

        self.Emit(' def __init__(self, %s):' % ', '.join(field_names))

        arg_types = []
        default_vals = []
        for f in fields:
            mypy_type = _MyPyType(f.typ)
            arg_types.append(mypy_type)

            d_str = _DefaultValue(f.typ, mypy_type)
            default_vals.append(d_str)

        self.Emit(' # type: (%s) -> None' % ', '.join(arg_types),
                  reflow=False)

        if not fields:
            self.Emit(' pass')  # for types like NoOp

        for f in fields:
            # don't wrap the type comment
            self.Emit(' self.%s = %s' % (f.name, f.name), reflow=False)

        self.Emit('')

        pretty_cls_name = '%s%s' % (class_ns, class_name)

        if fields and not self.py_init_n:
            # NOTE(review): the emitted def takes alloc_lists, but the emitted
            # type comment lists no arguments -- confirm this is what the
            # type checker for the generated code expects.
            self.Emit(' @staticmethod')
            self.Emit(' def CreateNull(alloc_lists=False):')
            self.Emit(' # type: () -> %s%s' % (class_ns, class_name))
            self.Emit(' return %s%s(%s)' %
                      (class_ns, class_name, ', '.join(default_vals)),
                      reflow=False)
            self.Emit('')

        if not self.pretty_print_methods:
            return

        is_list = any(b.startswith('List[') for b in base_classes)
        is_dict = any(b.startswith('Dict[') for b in base_classes)
        assert not (is_list and is_dict), base_classes

        #
        # PrettyTree
        #

        self.Emit(' def PrettyTree(self, trav=None):')
        self.Emit(' # type: (Optional[TraversalState]) -> hnode_t')
        self.Emit(' trav = trav or TraversalState()')
        self.Emit(' heap_id = id(self)')
        self.Emit(' if heap_id in trav.seen:')
        # cut off recursion
        self.Emit(' return hnode.AlreadySeen(heap_id)')
        self.Emit(' trav.seen[heap_id] = True')

        if is_list:
            # TODO: emit hnode.Subtype
            self.Emit(
                ' out_node = hnode.Array([c.PrettyTree() for c in self])')
        else:
            self.Emit(' out_node = NewRecord(%r)' % pretty_cls_name)
            self.Emit(' L = out_node.fields')
            self.Emit('')

            # Use the runtime type to be more like asdl/format.py
            for local_id, field in enumerate(fields):
                self.Indent()
                self._EmitCodeForField('PrettyTree', field, local_id)
                self.Dedent()
                self.Emit('')
        self.Emit(' return out_node')
        self.Emit('')

        #
        # _AbbreviatedTree
        #

        self.Emit(' def _AbbreviatedTree(self, trav=None):')
        self.Emit(' # type: (Optional[TraversalState]) -> hnode_t')
        self.Emit(' trav = trav or TraversalState()')
        self.Emit(' heap_id = id(self)')
        self.Emit(' if heap_id in trav.seen:')
        # cut off recursion
        self.Emit(' return hnode.AlreadySeen(heap_id)')
        self.Emit(' trav.seen[heap_id] = True')

        if is_list:
            # TODO: emit hnode.Subtype
            # NOTE(review): this is the same template as in PrettyTree above,
            # so list elements are printed with PrettyTree() here too --
            # confirm whether the abbreviated form was intended.
            self.Emit(
                ' out_node = hnode.Array([c.PrettyTree() for c in self])')
        else:
            self.Emit(' out_node = NewRecord(%r)' % pretty_cls_name)
            self.Emit(' L = out_node.fields')

            for local_id, field in enumerate(fields):
                self.Indent()
                self._EmitCodeForField('AbbreviatedTree', field, local_id)
                self.Dedent()
                self.Emit('')

        self.Emit(' return out_node')
        self.Emit('')

        self.Emit(' def AbbreviatedTree(self, trav=None):')
        self.Emit(' # type: (Optional[TraversalState]) -> hnode_t')
        abbrev_name = '_%s' % class_name
        if abbrev_name in self.abbrev_mod_entries:
            self.Emit(' p = %s(self)' % abbrev_name)
            # If the user function didn't return anything, fall back.
            self.Emit(
                ' return p if p else self._AbbreviatedTree(trav=trav)')
        else:
            self.Emit(' return self._AbbreviatedTree(trav=trav)')
        self.Emit('')

    def VisitCompoundSum(self, sum, sum_name, depth):
        """Note that the following is_simple:

          cflow = Break | Continue

        But this is compound:

          cflow = Break | Continue | Return(int val)

        The generated code changes depending on which one it is.

        Raises:
          RuntimeError: if two variants of this sum share the same product
            type.
        """
        # We emit THREE Python types for each meta.CompoundType:
        #
        # 1. enum for tag (cflow_e)
        # 2. base class for inheritance (cflow_t)
        # 3. namespace for classes (cflow) -- TODO: Get rid of this one.
        #
        # Should code use cflow_e.tag or isinstance()?
        # isinstance() is better for MyPy I think.  But tag is better for C++.
        # int tag = static_cast<cflow>(node).tag;

        int_to_str = {}

        # enum for the tag
        self.Emit('class %s_e(object):' % sum_name, depth)

        for i, variant in enumerate(sum.types):
            if variant.shared_type:
                tag_num = self._shared_type_tags[variant.shared_type]
                # e.g. DoubleQuoted may have base types expr_t, word_part_t
                base_class = sum_name + '_t'
                bases = self._base_classes[variant.shared_type]
                if base_class in bases:
                    raise RuntimeError(
                        "Two tags in sum %r refer to product type %r" %
                        (sum_name, variant.shared_type))

                else:
                    bases.append(base_class)
            else:
                tag_num = i + 1
            self.Emit(' %s = %d' % (variant.name, tag_num), depth)
            int_to_str[tag_num] = variant.name
        self.Emit('', depth)

        self._EmitDict(sum_name, int_to_str, depth)

        self.Emit('def %s_str(tag, dot=True):' % sum_name, depth)
        self.Emit(' # type: (int, bool) -> str', depth)
        self.Emit(' v = _%s_str[tag]' % sum_name, depth)
        self.Emit(' if dot:', depth)
        self.Emit(' return "%s.%%s" %% v' % sum_name, depth)
        self.Emit(' else:', depth)
        self.Emit(' return v', depth)
        self.Emit('', depth)

        # the base class, e.g. 'oil_cmd'
        self.Emit('class %s_t(pybase.CompoundObj):' % sum_name, depth)
        self.Indent()
        depth = self.current_depth

        # To imitate C++ API
        self.Emit('def tag(self):')
        self.Emit(' # type: () -> int')
        self.Emit(' return self._type_tag')

        # This is what we would do in C++, but we don't need it in Python
        # because every function is virtual.  The branch is disabled on
        # purpose.
        if 0:
            #if self.pretty_print_methods:
            for abbrev in 'PrettyTree', '_AbbreviatedTree', 'AbbreviatedTree':
                self.Emit('')
                self.Emit('def %s(self):' % abbrev, depth)
                self.Emit(' # type: () -> hnode_t', depth)
                self.Indent()
                depth = self.current_depth
                self.Emit('UP_self = self', depth)
                self.Emit('', depth)

                for variant in sum.types:
                    if variant.shared_type:
                        subtype_name = variant.shared_type
                    else:
                        subtype_name = '%s__%s' % (sum_name, variant.name)

                    self.Emit(
                        'if self.tag() == %s_e.%s:' % (sum_name, variant.name),
                        depth)
                    self.Emit(' self = cast(%s, UP_self)' % subtype_name,
                              depth)
                    self.Emit(' return self.%s()' % abbrev, depth)

                self.Emit('raise AssertionError()', depth)

                self.Dedent()
                depth = self.current_depth
        else:
            # Otherwise it's empty
            self.Emit('pass', depth)

        self.Dedent()
        depth = self.current_depth
        self.Emit('')

        # Declare any zero argument singleton classes outside of the main
        # "namespace" class.
        for i, variant in enumerate(sum.types):
            if variant.shared_type:
                continue  # Don't generate a class for shared types.
            if len(variant.fields) == 0:
                # We must use the old-style naming here, ie. command__NoOp, in
                # order to support zero field variants as constants.
                class_name = '%s__%s' % (sum_name, variant.name)
                self._GenClass(variant.fields, class_name, (sum_name + '_t', ),
                               i + 1)

        # Class that's just a NAMESPACE, e.g. for value.Str
        self.Emit('class %s(object):' % sum_name, depth)

        self.Indent()

        for i, variant in enumerate(sum.types):
            if variant.shared_type:
                continue

            if len(variant.fields) == 0:
                self.Emit('%s = %s__%s()' %
                          (variant.name, sum_name, variant.name))
                self.Emit('')
            else:
                # Use fully-qualified name, so we can have osh_cmd.Simple and
                # oil_cmd.Simple.
                fq_name = variant.name
                self._GenClass(variant.fields,
                               fq_name, (sum_name + '_t', ),
                               i + 1,
                               class_ns=sum_name + '.')
        self.Emit(' pass', depth)  # in case every variant is first class

        self.Dedent()
        self.Emit('')

    def VisitSubType(self, subtype):
        """Assign the subtype a tag and defer its class to EmitFooter()."""
        self._shared_type_tags[subtype.name] = self._product_counter

        # Also create these last.  They may inherit from sum types that have
        # yet to be defined.
        self._subtypes.append((subtype, self._product_counter))
        self._product_counter += 1

    def VisitProduct(self, product, name, depth):
        """Assign the product a tag and defer its class to EmitFooter()."""
        self._shared_type_tags[name] = self._product_counter
        # Create a tuple of _GenClass args to create LAST.  They may inherit
        # from sum types that have yet to be defined.
        self._products.append((product, name, depth, self._product_counter))
        self._product_counter += 1

    def EmitFooter(self):
        """Emit the product and subtype classes that were deferred.

        A product or subtype that no sum type shares derives from
        pybase.CompoundObj.  A subtype additionally derives from the MyPy form
        of its declared base class, which is listed last.
        """
        # Now generate all the product types we deferred.
        for ast_node, name, depth, tag_num in self._products:
            # Figure out base classes AFTERWARD.
            bases = self._base_classes[name]
            if not bases:
                bases = ('pybase.CompoundObj', )
            self._GenClass(ast_node.fields, name, bases, tag_num)

        for subtype, tag_num in self._subtypes:
            # Figure out base classes AFTERWARD.
            #
            # This must be a list, since we append to it: the fallback used to
            # be a tuple, so a subtype that no sum refers to raised
            # AttributeError on append().  It is also a copy, so the list
            # stored in self._base_classes isn't modified.
            bases = list(self._base_classes[subtype.name])
            if not bases:
                bases = ['pybase.CompoundObj']
            bases.append(_MyPyType(subtype.base_class))
            self._GenClass([], subtype.name, bases, tag_num)