OILS / asdl / gen_python.py View on Github | oilshell.org

631 lines, 408 significant
1#!/usr/bin/env python2
2"""gen_python.py: Generate Python code from an ASDL schema."""
3from __future__ import print_function
4
5from collections import defaultdict
6
7from asdl import ast
8from asdl import visitor
9from asdl.util import log
10
11_ = log # shut up lint
12
# Maps ASDL primitive type names to the type names written into generated
# MyPy annotations.  _MyPyType() falls back to this table for any named type
# that did not resolve to a Sum, Product, or Use.
_PRIMITIVES = {
    'string': 'str',
    'int': 'int',
    'uint16': 'int',
    'BigInt': 'mops.BigInt',
    'float': 'float',
    'bool': 'bool',
    'any': 'Any',
    # TODO: frontend/syntax.asdl should properly import id enum instead of
    # hard-coding it here.
    'id': 'Id_t',
}
25
26
def _MyPyType(typ):
    """Translate an ASDL type expression into a MyPy type string.

    Args:
      typ: an ast.ParameterizedType (Dict, List or Optional) or an
        ast.NamedType.

    Returns:
      The type as text for a generated type comment, e.g. 'List[str]' or
      'Dict[str, int]'.  Resolved sums get a '_t' suffix, resolved products
      keep their name, resolved 'use' types go through
      ast.TypeNameHeuristic(), and everything else is looked up in
      _PRIMITIVES.

    Raises:
      AssertionError: if typ is neither kind of node, or is a parameterized
        type other than Dict, List or Optional.
      KeyError: if a named type is not handled above and is not a key of
        _PRIMITIVES.
    """
    if isinstance(typ, ast.ParameterizedType):
        type_name = typ.type_name

        if type_name == 'Dict':
            k_type = _MyPyType(typ.children[0])
            v_type = _MyPyType(typ.children[1])
            return 'Dict[%s, %s]' % (k_type, v_type)

        if type_name in ('List', 'Optional'):
            return '%s[%s]' % (type_name, _MyPyType(typ.children[0]))

        # Fail loudly, as _DefaultValue() does, rather than falling off the
        # end and returning None, which would end up inside a generated type
        # comment.
        raise AssertionError(type_name)

    if isinstance(typ, ast.NamedType):
        resolved = typ.resolved
        if resolved:
            if isinstance(resolved, ast.Sum):  # includes SimpleSum
                return '%s_t' % typ.name
            if isinstance(resolved, ast.Product):
                return typ.name
            if isinstance(resolved, ast.Use):
                return ast.TypeNameHeuristic(typ.name)

        # 'id' falls through here
        return _PRIMITIVES[typ.name]

    raise AssertionError()
56
57
def _DefaultValue(typ, mypy_type):
    """Return the expression that CreateNull() passes for one field.

    The result is Python source text.  Where the placeholder is None, it is
    wrapped in cast() with mypy_type so the generated code still passes
    mypy --strict.

    CreateNull() deliberately sidesteps the type system: the caller is
    expected to fill in every field afterwards, and code reading the fields
    at runtime relies on that having happened.

    Args:
      typ: ASDL type of the field (ast.ParameterizedType or ast.NamedType).
      mypy_type: the MyPy spelling of that type, as text.

    Raises:
      AssertionError: for an unknown parameterized type name, or for a typ
        that is neither kind of node.
    """
    if isinstance(typ, ast.ParameterizedType):
        container = typ.type_name

        if container == 'List':
            return "[] if alloc_lists else cast('%s', None)" % mypy_type

        # Dict TODO: can respect alloc_dicts=True
        if container in ('Optional', 'Dict'):
            return "cast('%s', None)" % mypy_type

        raise AssertionError(container)

    if not isinstance(typ, ast.NamedType):
        raise AssertionError()

    name = typ.name

    # Primitive names whose placeholder is a plain literal.
    literal_defaults = {
        'id': '-1',  # hard-coded HACK
        'int': '-1',
        'BigInt': 'mops.BigInt(-1)',
        'bool': 'False',
        'float': '0.0',  # or should it be NaN?
        'string': "''",
    }
    if name in literal_defaults:
        return literal_defaults[name]

    enum_type = typ.resolved
    if isinstance(enum_type, ast.SimpleSum):
        # Just make it the first variant.  We could define "Undef" for each
        # enum, but it doesn't seem worth it.
        return '%s_e.%s' % (name, enum_type.types[0].name)

    # CompoundSum or Product type
    return 'cast(%s, None)' % mypy_type
113
114
def _HNodeExpr(abbrev, typ, var_name):
    # type: (str, ast.TypeExpr, str) -> tuple
    """Build the Python expression that turns a value into an hnode.

    Args:
      abbrev: name of the tree method to call on nested nodes; the callers
        in this file pass 'PrettyTree' or 'AbbreviatedTree'.
      typ: ASDL type of the value.  One level of Optional is stripped before
        the type is inspected.
      var_name: source text of the Python expression naming the value.

    Returns:
      A (code_str, none_guard) pair, i.e. Tuple[str, bool].  code_str is the
      generated expression.  none_guard is True when code_str calls a method
      on var_name, so the caller has to rule out None before evaluating it.

    Raises:
      AssertionError: if typ (after stripping Optional) is neither an
        ast.ParameterizedType nor an ast.NamedType.
    """
    none_guard = False

    if typ.IsOptional():
        typ = typ.children[0]  # descend one level

    if isinstance(typ, ast.ParameterizedType):
        code_str = '%s.%s()' % (var_name, abbrev)
        none_guard = True

    elif isinstance(typ, ast.NamedType):
        type_name = typ.name

        if type_name == 'bool':
            code_str = ("hnode.Leaf('T' if %s else 'F', color_e.OtherConst)" %
                        var_name)

        elif type_name in ('int', 'uint16'):
            code_str = 'hnode.Leaf(str(%s), color_e.OtherConst)' % var_name

        elif type_name == 'BigInt':
            code_str = ('hnode.Leaf(mops.ToStr(%s), color_e.OtherConst)' %
                        var_name)

        elif type_name == 'float':
            code_str = 'hnode.Leaf(str(%s), color_e.OtherConst)' % var_name

        elif type_name == 'string':
            code_str = 'NewLeaf(%s, color_e.StringConst)' % var_name

        elif type_name == 'any':  # TODO: Remove this.  Used for value.Obj().
            code_str = 'hnode.External(%s)' % var_name

        elif type_name == 'id':  # was meta.UserType
            # This assumes it's Id, which is a simple SumType.  TODO: Remove
            # this.
            code_str = 'hnode.Leaf(Id_str(%s), color_e.UserType)' % var_name

        elif typ.resolved and isinstance(typ.resolved, ast.SimpleSum):
            code_str = 'hnode.Leaf(%s_str(%s), color_e.TypeName)' % (type_name,
                                                                     var_name)

        else:
            # Anything else is a node with its own tree method; pass the
            # traversal state along.
            code_str = '%s.%s(trav=trav)' % (var_name, abbrev)
            none_guard = True

    else:
        raise AssertionError()

    return code_str, none_guard
163
164
165class GenMyPyVisitor(visitor.AsdlVisitor):
166 """Generate Python code with MyPy type annotations."""
167
168 def __init__(self,
169 f,
170 abbrev_mod_entries=None,
171 pretty_print_methods=True,
172 py_init_n=False,
173 simple_int_sums=None):
174
175 visitor.AsdlVisitor.__init__(self, f)
176 self.abbrev_mod_entries = abbrev_mod_entries or []
177 self.pretty_print_methods = pretty_print_methods
178 self.py_init_n = py_init_n
179
180 # For Id to use different code gen. It's used like an integer, not just
181 # like an enum.
182 self.simple_int_sums = simple_int_sums or []
183
184 self._shared_type_tags = {}
185 self._product_counter = 64 # matches asdl/gen_cpp.py
186
187 self._products = []
188 self._base_classes = defaultdict(list)
189
190 self._subtypes = []
191
192 def _EmitDict(self, name, d, depth):
193 self.Emit('_%s_str = {' % name, depth)
194 for k in sorted(d):
195 self.Emit('%d: %r,' % (k, d[k]), depth + 1)
196 self.Emit('}', depth)
197 self.Emit('', depth)
198
199 def VisitSimpleSum(self, sum, sum_name, depth):
200 int_to_str = {}
201 variants = []
202 for i, variant in enumerate(sum.types):
203 tag_num = i + 1
204 tag_str = '%s.%s' % (sum_name, variant.name)
205 int_to_str[tag_num] = tag_str
206 variants.append((variant, tag_num))
207
208 add_suffix = not ('no_namespace_suffix' in sum.generate)
209 gen_integers = 'integers' in sum.generate or 'uint16' in sum.generate
210
211 if gen_integers:
212 self.Emit('%s_t = int # type alias for integer' % sum_name)
213 self.Emit('')
214
215 i_name = ('%s_i' % sum_name) if add_suffix else sum_name
216
217 self.Emit('class %s(object):' % i_name, depth)
218
219 for variant, tag_num in variants:
220 line = ' %s = %d' % (variant.name, tag_num)
221 self.Emit(line, depth)
222
223 # Help in sizing array. Note that we're 1-based.
224 line = ' %s = %d' % ('ARRAY_SIZE', len(variants) + 1)
225 self.Emit(line, depth)
226
227 else:
228 # First emit a type
229 self.Emit('class %s_t(pybase.SimpleObj):' % sum_name, depth)
230 self.Emit(' pass', depth)
231 self.Emit('', depth)
232
233 # Now emit a namespace
234 e_name = ('%s_e' % sum_name) if add_suffix else sum_name
235 self.Emit('class %s(object):' % e_name, depth)
236
237 for variant, tag_num in variants:
238 line = ' %s = %s_t(%d)' % (variant.name, sum_name, tag_num)
239 self.Emit(line, depth)
240
241 self.Emit('', depth)
242
243 self._EmitDict(sum_name, int_to_str, depth)
244
245 self.Emit('def %s_str(val):' % sum_name, depth)
246 self.Emit(' # type: (%s_t) -> str' % sum_name, depth)
247 self.Emit(' return _%s_str[val]' % sum_name, depth)
248 self.Emit('', depth)
249
    def _EmitCodeForField(self, abbrev, field, counter):
        """Generate code that returns an hnode for a field.

        The emitted statements build an hnode for self.<field.name> and
        append it to the generated local list L as Field(name, hnode).

        Args:
          abbrev: name of the tree method to call on child nodes; _GenClass()
            passes 'PrettyTree' or 'AbbreviatedTree'.
          field: the ASDL field; its .name and .typ are read.
          counter: integer used to build unique generated local names
            (x<N>, i<N>, k<N>, v<N>).
        """
        out_val_name = 'x%d' % counter

        if field.typ.IsList():
            iter_name = 'i%d' % counter

            typ = field.typ
            if typ.type_name == 'Optional':  # descend one level
                typ = typ.children[0]
            item_type = typ.children[0]

            # The whole field is skipped when the list itself is None.
            self.Emit(' if self.%s is not None: # List' % field.name)
            self.Emit(' %s = hnode.Array([])' % out_val_name)
            self.Emit(' for %s in self.%s:' % (iter_name, field.name))
            child_code_str, none_guard = _HNodeExpr(abbrev, item_type,
                                                    iter_name)

            if none_guard:  # e.g. for List[Optional[value_t]]
                # TODO: could consolidate with asdl/runtime.py NewLeaf(), which
                # also uses _ to mean None/nullptr
                self.Emit(
                    ' h = (hnode.Leaf("_", color_e.OtherConst) if %s is None else %s)'
                    % (iter_name, child_code_str))
                self.Emit(' %s.children.append(h)' % out_val_name)
            else:
                self.Emit(' %s.children.append(%s)' %
                          (out_val_name, child_code_str))

            self.Emit(' L.append(Field(%r, %s))' %
                      (field.name, out_val_name))

        elif field.typ.IsDict():
            k = 'k%d' % counter
            v = 'v%d' % counter

            typ = field.typ
            if typ.type_name == 'Optional':  # descend one level
                typ = typ.children[0]

            k_typ = typ.children[0]
            v_typ = typ.children[1]

            # The none_guard results are discarded here: keys and values are
            # emitted without a None check.
            k_code_str, _ = _HNodeExpr(abbrev, k_typ, k)
            v_code_str, _ = _HNodeExpr(abbrev, v_typ, v)

            # A dict becomes an Array: a "Dict" marker leaf followed by
            # alternating key and value nodes.
            self.Emit(' if self.%s is not None: # Dict' % field.name)
            self.Emit(' m = hnode.Leaf("Dict", color_e.OtherConst)')
            self.Emit(' %s = hnode.Array([m])' % out_val_name)
            self.Emit(' for %s, %s in self.%s.iteritems():' %
                      (k, v, field.name))
            self.Emit(' %s.children.append(%s)' %
                      (out_val_name, k_code_str))
            self.Emit(' %s.children.append(%s)' %
                      (out_val_name, v_code_str))
            self.Emit(' L.append(Field(%r, %s))' %
                      (field.name, out_val_name))

        elif field.typ.IsOptional():
            typ = field.typ.children[0]

            # The emitted 'is not None' test replaces the none_guard, so
            # that result is discarded.
            self.Emit(' if self.%s is not None: # Optional' % field.name)
            child_code_str, _ = _HNodeExpr(abbrev, typ, 'self.%s' % field.name)
            self.Emit(' %s = %s' % (out_val_name, child_code_str))
            self.Emit(' L.append(Field(%r, %s))' %
                      (field.name, out_val_name))

        else:
            # Required field: always emitted, with no enclosing 'if'.
            var_name = 'self.%s' % field.name
            code_str, obj_none_guard = _HNodeExpr(abbrev, field.typ, var_name)
            depth = self.current_depth
            if obj_none_guard:  # to satisfy MyPy type system
                self.Emit(' assert self.%s is not None' % field.name)
            self.Emit(' %s = %s' % (out_val_name, code_str), depth)

            self.Emit(' L.append(Field(%r, %s))' % (field.name, out_val_name),
                      depth)
327
    def _GenClass(self,
                  fields,
                  class_name,
                  base_classes,
                  tag_num,
                  class_ns=''):
        """Used for both Sum variants ("constructors") and Product types.

        Emits, in order: the class header with _type_tag and __slots__,
        __init__, CreateNull() (only when there are fields and py_init_n is
        off), and -- unless pretty_print_methods is off -- PrettyTree(),
        _AbbreviatedTree() and AbbreviatedTree().

        Args:
          fields: the ASDL fields of the class; each has .name and .typ.
          class_name: name of the generated class.
          base_classes: sequence of base class names, joined with ', '.
          tag_num: integer written as the class's _type_tag.
          class_ns: for variants like value.Str
        """
        self.Emit('class %s(%s):' % (class_name, ', '.join(base_classes)))
        self.Emit(' _type_tag = %d' % tag_num)

        all_fields = fields

        field_names = [f.name for f in all_fields]

        # repr() of a tuple gives a valid Python literal for __slots__,
        # including the empty and one-element cases.
        quoted_fields = repr(tuple(field_names))
        self.Emit(' __slots__ = %s' % quoted_fields)
        self.Emit('')

        #
        # __init__
        #

        args = [f.name for f in fields]

        self.Emit(' def __init__(self, %s):' % ', '.join(args))

        # Collected in field order: the MyPy type of each argument, and the
        # placeholder expression CreateNull() passes for it.
        arg_types = []
        default_vals = []
        for f in fields:
            mypy_type = _MyPyType(f.typ)
            arg_types.append(mypy_type)

            d_str = _DefaultValue(f.typ, mypy_type)
            default_vals.append(d_str)

        self.Emit(' # type: (%s) -> None' % ', '.join(arg_types),
                  reflow=False)

        if not all_fields:
            self.Emit(' pass')  # for types like NoOp

        for f in fields:
            # don't wrap the type comment
            self.Emit(' self.%s = %s' % (f.name, f.name), reflow=False)

        self.Emit('')

        # Name shown in pretty-printed output, e.g. 'value.Str'.
        pretty_cls_name = '%s%s' % (class_ns, class_name)

        if len(all_fields) and not self.py_init_n:
            self.Emit(' @staticmethod')
            self.Emit(' def CreateNull(alloc_lists=False):')
            self.Emit(' # type: () -> %s%s' % (class_ns, class_name))
            self.Emit(' return %s%s(%s)' %
                      (class_ns, class_name, ', '.join(default_vals)),
                      reflow=False)
            self.Emit('')

        if not self.pretty_print_methods:
            return

        #
        # PrettyTree
        #

        self.Emit(' def PrettyTree(self, trav=None):')
        self.Emit(' # type: (Optional[TraversalState]) -> hnode_t')
        self.Emit(' trav = trav or TraversalState()')
        self.Emit(' heap_id = id(self)')
        self.Emit(' if heap_id in trav.seen:')
        # cut off recursion
        self.Emit(' return hnode.AlreadySeen(heap_id)')
        self.Emit(' trav.seen[heap_id] = True')

        self.Emit(' out_node = NewRecord(%r)' % pretty_cls_name)
        self.Emit(' L = out_node.fields')
        self.Emit('')

        # Use the runtime type to be more like asdl/format.py
        for local_id, field in enumerate(all_fields):
            #log('%s :: %s', field_name, field_desc)
            self.Indent()
            self._EmitCodeForField('PrettyTree', field, local_id)
            self.Dedent()
            self.Emit('')
        self.Emit(' return out_node')
        self.Emit('')

        #
        # _AbbreviatedTree
        #

        self.Emit(' def _AbbreviatedTree(self, trav=None):')
        self.Emit(' # type: (Optional[TraversalState]) -> hnode_t')
        self.Emit(' trav = trav or TraversalState()')
        self.Emit(' heap_id = id(self)')
        self.Emit(' if heap_id in trav.seen:')
        # cut off recursion
        self.Emit(' return hnode.AlreadySeen(heap_id)')
        self.Emit(' trav.seen[heap_id] = True')
        self.Emit(' out_node = NewRecord(%r)' % pretty_cls_name)
        self.Emit(' L = out_node.fields')

        for local_id, field in enumerate(fields):
            self.Indent()
            self._EmitCodeForField('AbbreviatedTree', field, local_id)
            self.Dedent()
            self.Emit('')
        self.Emit(' return out_node')
        self.Emit('')

        # Public AbbreviatedTree(): tries the user-supplied '_<class_name>'
        # function when one is listed in abbrev_mod_entries, and otherwise
        # delegates to the generated _AbbreviatedTree().
        self.Emit(' def AbbreviatedTree(self, trav=None):')
        self.Emit(' # type: (Optional[TraversalState]) -> hnode_t')
        abbrev_name = '_%s' % class_name
        if abbrev_name in self.abbrev_mod_entries:
            self.Emit(' p = %s(self)' % abbrev_name)
            # If the user function didn't return anything, fall back.
            self.Emit(
                ' return p if p else self._AbbreviatedTree(trav=trav)')
        else:
            self.Emit(' return self._AbbreviatedTree(trav=trav)')
        self.Emit('')
454
    def VisitCompoundSum(self, sum, sum_name, depth):
        """Emit the Python definitions for a compound sum type.

        Note that the following is simple:

          cflow = Break | Continue

        But this is compound:

          cflow = Break | Continue | Return(int val)

        The generated code changes depending on which one it is.
        """
        #log('%d variants in %s', len(sum.types), sum_name)

        # We emit THREE Python types for each meta.CompoundType:
        #
        # 1. enum for tag (cflow_e)
        # 2. base class for inheritance (cflow_t)
        # 3. namespace for classes (cflow) -- TODO: Get rid of this one.
        #
        # Should code use cflow_e.tag or isinstance()?
        # isinstance() is better for MyPy I think. But tag is better for C++.
        # int tag = static_cast<cflow>(node).tag;

        int_to_str = {}

        # enum for the tag
        self.Emit('class %s_e(object):' % sum_name, depth)

        for i, variant in enumerate(sum.types):
            if variant.shared_type:
                # A shared type keeps the tag recorded for it in
                # _shared_type_tags, rather than a position in this sum.
                tag_num = self._shared_type_tags[variant.shared_type]
                # e.g. DoubleQuoted may have base types expr_t, word_part_t
                base_class = sum_name + '_t'
                bases = self._base_classes[variant.shared_type]
                if base_class in bases:
                    raise RuntimeError(
                        "Two tags in sum %r refer to product type %r" %
                        (sum_name, variant.shared_type))

                else:
                    # Recorded now; EmitFooter() reads it when it emits the
                    # shared type's class.
                    bases.append(base_class)
            else:
                tag_num = i + 1
            self.Emit(' %s = %d' % (variant.name, tag_num), depth)
            int_to_str[tag_num] = variant.name
        self.Emit('', depth)

        self._EmitDict(sum_name, int_to_str, depth)

        self.Emit('def %s_str(tag, dot=True):' % sum_name, depth)
        self.Emit(' # type: (int, bool) -> str', depth)
        self.Emit(' v = _%s_str[tag]' % sum_name, depth)
        self.Emit(' if dot:', depth)
        self.Emit(' return "%s.%%s" %% v' % sum_name, depth)
        self.Emit(' else:', depth)
        self.Emit(' return v', depth)
        self.Emit('', depth)

        # the base class, e.g. 'oil_cmd'
        self.Emit('class %s_t(pybase.CompoundObj):' % sum_name, depth)
        self.Indent()
        depth = self.current_depth

        # To imitate C++ API
        self.Emit('def tag(self):')
        self.Emit(' # type: () -> int')
        self.Emit(' return self._type_tag')

        # This is what we would do in C++, but we don't need it in Python because
        # every function is virtual.
        # NOTE: 'if 0' disables this branch on purpose; only the 'else' below
        # ever runs.
        if 0:
            #if self.pretty_print_methods:
            for abbrev in 'PrettyTree', '_AbbreviatedTree', 'AbbreviatedTree':
                self.Emit('')
                self.Emit('def %s(self):' % abbrev, depth)
                self.Emit(' # type: () -> hnode_t', depth)
                self.Indent()
                depth = self.current_depth
                self.Emit('UP_self = self', depth)
                self.Emit('', depth)

                for variant in sum.types:
                    if variant.shared_type:
                        subtype_name = variant.shared_type
                    else:
                        subtype_name = '%s__%s' % (sum_name, variant.name)

                    self.Emit(
                        'if self.tag() == %s_e.%s:' % (sum_name, variant.name),
                        depth)
                    self.Emit(' self = cast(%s, UP_self)' % subtype_name,
                              depth)
                    self.Emit(' return self.%s()' % abbrev, depth)

                self.Emit('raise AssertionError()', depth)

                self.Dedent()
                depth = self.current_depth
        else:
            # Otherwise it's empty
            self.Emit('pass', depth)

        self.Dedent()
        depth = self.current_depth
        self.Emit('')

        # Declare any zero argument singleton classes outside of the main
        # "namespace" class.
        for i, variant in enumerate(sum.types):
            if variant.shared_type:
                continue  # Don't generate a class for shared types.
            if len(variant.fields) == 0:
                # We must use the old-style naming here, ie. command__NoOp, in order
                # to support zero field variants as constants.
                class_name = '%s__%s' % (sum_name, variant.name)
                self._GenClass(variant.fields, class_name, (sum_name + '_t', ),
                               i + 1)

        # Class that's just a NAMESPACE, e.g. for value.Str
        self.Emit('class %s(object):' % sum_name, depth)

        self.Indent()

        for i, variant in enumerate(sum.types):
            if variant.shared_type:
                continue

            if len(variant.fields) == 0:
                # Zero-field variants become a single shared instance of the
                # '<sum>__<variant>' class emitted above.
                self.Emit('%s = %s__%s()' %
                          (variant.name, sum_name, variant.name))
                self.Emit('')
            else:
                # Use fully-qualified name, so we can have osh_cmd.Simple and
                # oil_cmd.Simple.
                fq_name = variant.name
                self._GenClass(variant.fields,
                               fq_name, (sum_name + '_t', ),
                               i + 1,
                               class_ns=sum_name + '.')
        self.Emit(' pass', depth)  # in case every variant is first class

        self.Dedent()
        self.Emit('')
598
599 def VisitSubType(self, subtype):
600 self._shared_type_tags[subtype.name] = self._product_counter
601
602 # Also create these last. They may inherit from sum types that have yet
603 # to be defined.
604 self._subtypes.append((subtype, self._product_counter))
605 self._product_counter += 1
606
607 def VisitProduct(self, product, name, depth):
608 self._shared_type_tags[name] = self._product_counter
609 # Create a tuple of _GenClass args to create LAST. They may inherit from
610 # sum types that have yet to be defined.
611 self._products.append((product, name, depth, self._product_counter))
612 self._product_counter += 1
613
614 def EmitFooter(self):
615 # Now generate all the product types we deferred.
616 for args in self._products:
617 ast_node, name, depth, tag_num = args
618 # Figure out base classes AFTERWARD.
619 bases = self._base_classes[name]
620 if not bases:
621 bases = ('pybase.CompoundObj', )
622 self._GenClass(ast_node.fields, name, bases, tag_num)
623
624 for args in self._subtypes:
625 subtype, tag_num = args
626 # Figure out base classes AFTERWARD.
627 bases = self._base_classes[subtype.name]
628 if not bases:
629 bases = ('pybase.CompoundObj', )
630 bases.append(_MyPyType(subtype.base_class))
631 self._GenClass([], subtype.name, bases, tag_num)