OILS / mycpp / control_flow_pass.py View on Github | oilshell.org

559 lines, 354 significant
1"""
2control_flow_pass.py - AST pass that builds a control flow graph.
3"""
4import collections
5from typing import overload, Union, Optional, Dict
6
7import mypy
8from mypy.nodes import (Block, Expression, Statement, ExpressionStmt, StrExpr,
9 CallExpr, FuncDef, IfStmt, NameExpr, MemberExpr,
10 IndexExpr, TupleExpr, IntExpr)
11
12from mypy.types import CallableType, Instance, Type, UnionType, NoneTyp, TupleType
13
14from mycpp.crash import catch_errors
15from mycpp.util import join_name, split_py_name
16from mycpp.visitor import SimpleVisitor, T
17from mycpp import util
18from mycpp import pass_state
19
20
class UnsupportedException(Exception):
    """Exception that Build.accept() catches so traversal can continue."""
23
24
def GetObjectTypeName(t: Type) -> util.SymbolPath:
    """Return the symbol path of the class named by the given mypy type.

    An Instance maps to the full name of its class, split with
    split_py_name(). A UnionType must have exactly two items; the name is
    taken from the first item, unless that item is NoneTyp, in which case the
    second item is used.

    Raises:
        AssertionError: if a union doesn't have exactly two items, or if `t`
            is neither an Instance nor a UnionType.
    """
    if isinstance(t, Instance):
        return split_py_name(t.type.fullname)

    if isinstance(t, UnionType):
        assert len(t.items) == 2, t
        if isinstance(t.items[0], NoneTyp):
            return GetObjectTypeName(t.items[1])

        return GetObjectTypeName(t.items[0])

    # Raise explicitly instead of `assert False`: assert statements are removed
    # under `python -O`, which would make this function silently return None.
    raise AssertionError(t)
37
38
39class Build(SimpleVisitor):
40
41 def __init__(self, types: Dict[Expression, Type], virtual, local_vars,
42 dot_exprs):
43
44 self.types = types
45 self.cfgs = collections.defaultdict(pass_state.ControlFlowGraph)
46 self.current_statement_id = None
47 self.current_class_name = None
48 self.current_func_node = None
49 self.loop_stack = []
50 self.virtual = virtual
51 self.local_vars = local_vars
52 self.dot_exprs = dot_exprs
53 self.heap_counter = 0
54 self.callees = {} # statement object -> SymbolPath of the callee
55 self.current_lval = None
56
57 def current_cfg(self):
58 if not self.current_func_node:
59 return None
60
61 return self.cfgs[split_py_name(self.current_func_node.fullname)]
62
63 def resolve_callee(self, o: CallExpr) -> Optional[util.SymbolPath]:
64 """
65 Returns the fully qualified name of the callee in the given call
66 expression.
67
68 Member functions are prefixed by the names of the classes that contain
69 them. For example, the name of the callee in the last statement of the
70 snippet below is `module.SomeObject.Foo`.
71
72 x = module.SomeObject()
73 x.Foo()
74
75 Free-functions defined in the local module are referred to by their
76 normal fully qualified names. The function `foo` in a module called
77 `moduleA` would is named `moduleA.foo`. Calls to free-functions defined
78 in imported modules are named the same way.
79 """
80
81 if isinstance(o.callee, NameExpr):
82 return split_py_name(o.callee.fullname)
83
84 elif isinstance(o.callee, MemberExpr):
85 if isinstance(o.callee.expr, NameExpr):
86 is_module = isinstance(self.dot_exprs.get(o.callee),
87 pass_state.ModuleMember)
88 if is_module:
89 return split_py_name(
90 o.callee.expr.fullname) + (o.callee.name, )
91
92 elif o.callee.expr.name == 'self':
93 assert self.current_class_name
94 return self.current_class_name + (o.callee.name, )
95
96 else:
97 local_type = None
98 for name, t in self.local_vars.get(self.current_func_node,
99 []):
100 if name == o.callee.expr.name:
101 local_type = t
102 break
103
104 if local_type:
105 if isinstance(local_type, str):
106 return split_py_name(local_type) + (
107 o.callee.name, )
108
109 elif isinstance(local_type, Instance):
110 return split_py_name(
111 local_type.type.fullname) + (o.callee.name, )
112
113 elif isinstance(local_type, UnionType):
114 assert len(local_type.items) == 2
115 return split_py_name(
116 local_type.items[0].type.fullname) + (
117 o.callee.expr.name, )
118
119 else:
120 assert not isinstance(local_type, CallableType)
121 # primitive type or string. don't care.
122 return None
123
124 else:
125 # context or exception handler. probably safe to ignore.
126 return None
127
128 else:
129 t = self.types.get(o.callee.expr)
130 if isinstance(t, Instance):
131 return split_py_name(t.type.fullname) + (o.callee.name, )
132
133 elif isinstance(t, UnionType):
134 assert len(t.items) == 2
135 return split_py_name(
136 t.items[0].type.fullname) + (o.callee.name, )
137
138 elif o.callee.expr and getattr(o.callee.expr, 'fullname',
139 None):
140 return split_py_name(
141 o.callee.expr.fullname) + (o.callee.name, )
142
143 else:
144 # constructors of things that we don't care about.
145 return None
146
147 # Don't currently get here
148 raise AssertionError()
149
150 def get_ref_name(self, expr: Expression) -> Optional[util.SymbolPath]:
151 """
152 To do dataflow analysis we need to track changes to objects, which
153 requires naming them. This function returns the name of the object
154 referred to by the given expression. If the expression doesn't refer to
155 an object or variable it returns None.
156
157 Objects are named slightly differently than they appear in the source
158 code.
159
160 Objects referenced by local variables are referred to by the name of the
161 local. For example, the name of the object in both statements below is
162 `x`.
163
164 x = module.SomeObject()
165 x = None
166
167 Member expressions are named after the parent object's type. For
168 example, the names of the objects in the member assignment statements
169 below are both `module.SomeObject.member_a`. This makes it possible to
170 track data flow across object members without having to track individual
171 heap objects, which would increase the search space for analyses and
172 slow things down.
173
174 x = module.SomeObject()
175 y = module.SomeObject()
176 x.member_a = 'foo'
177 y.member_a = 'bar'
178
179 Index expressions are named after their bases, for the same reasons as
180 member expressions. The coarse-grained precision should lead to an
181 over-approximation of where objects are in use, but should not miss any
182 references. This should be fine for our purposes. In the snippet below
183 the last two assignments are named `x` and `module.SomeObject.a_list`.
184
185 x = [None] # list[Thing]
186 y = module.SomeObject()
187 x[0] = Thing()
188 y.a_list[1] = Blah()
189
190 Index expressions over tuples are treated differently, though. Tuples
191 have a fixed size, tend to be small, and their elements have distinct
192 types. So, each element can be (and probably needs to be) individually
193 named. In the snippet below, the name of the RHS in the second
194 assignment is `t.0`.
195
196 t = (1, 2, 3, 4)
197 x = t[0]
198
199 The examples above all deal with assignments, but these rules apply to
200 any expression that uses an object or variable.
201 """
202 if isinstance(expr,
203 NameExpr) and expr.name not in {'True', 'False', 'None'}:
204 return (expr.name, )
205
206 elif isinstance(expr, MemberExpr):
207 dot_expr = self.dot_exprs[expr]
208 if isinstance(dot_expr, pass_state.ModuleMember):
209 return dot_expr.module_path + (dot_expr.member, )
210
211 elif isinstance(dot_expr, pass_state.HeapObjectMember):
212 obj_name = self.get_ref_name(dot_expr.object_expr)
213 if obj_name:
214 # XXX: add a new case like pass_state.ExpressionMember for
215 # cases when the LHS of . isn't a reference (e.g.
216 # builtin/assign_osh.py:54)
217 return obj_name + (dot_expr.member, )
218
219 elif isinstance(dot_expr, pass_state.StackObjectMember):
220 return self.get_ref_name(
221 dot_expr.object_expr) + (dot_expr.member, )
222
223 elif isinstance(expr, IndexExpr):
224 if isinstance(self.types[expr.base], TupleType):
225 assert isinstance(expr.index, IntExpr)
226 return self.get_ref_name(expr.base) + (str(expr.index.value), )
227
228 return self.get_ref_name(expr.base)
229
230 return None
231
232 #
233 # COPIED from IRBuilder
234 #
235
236 @overload
237 def accept(self, node: Expression) -> T:
238 ...
239
240 @overload
241 def accept(self, node: Statement) -> None:
242 ...
243
244 def accept(self, node: Union[Statement, Expression]) -> Optional[T]:
245 with catch_errors(self.module_path, node.line):
246 if isinstance(node, Expression):
247 try:
248 res = node.accept(self)
249 #res = self.coerce(res, self.node_type(node), node.line)
250
251 # If we hit an error during compilation, we want to
252 # keep trying, so we can produce more error
253 # messages. Generate a temp of the right type to keep
254 # from causing more downstream trouble.
255 except UnsupportedException:
256 res = self.alloc_temp(self.node_type(node))
257 return res
258 else:
259 try:
260 cfg = self.current_cfg()
261 # Most statements have empty visitors because they don't
262 # require any special logic. Create statements for them
263 # here. Don't create statements from blocks to avoid
264 # stuttering.
265 if cfg and not isinstance(node, Block):
266 self.current_statement_id = cfg.AddStatement()
267
268 node.accept(self)
269 except UnsupportedException:
270 pass
271 return None
272
273 # Not in superclasses:
274
275 def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T:
276 if util.ShouldSkipPyFile(o):
277 return
278
279 self.module_path = o.path
280
281 for node in o.defs:
282 # skip module docstring
283 if isinstance(node, ExpressionStmt) and isinstance(
284 node.expr, StrExpr):
285 continue
286 self.accept(node)
287
288 # Statements
289
290 def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> T:
291 cfg = self.current_cfg()
292 with pass_state.CfgLoopContext(
293 cfg, entry=self.current_statement_id) as loop:
294 self.accept(o.expr)
295 self.loop_stack.append(loop)
296 self.accept(o.body)
297 self.loop_stack.pop()
298
299 def _handle_switch(self, expr, o, cfg):
300 assert len(o.body.body) == 1, o.body.body
301 if_node = o.body.body[0]
302 assert isinstance(if_node, IfStmt), if_node
303 cases = []
304 default_block = util._collect_cases(self.module_path, if_node, cases)
305 with pass_state.CfgBranchContext(
306 cfg, self.current_statement_id) as branch_ctx:
307 for expr, body in cases:
308 self.accept(expr)
309 assert expr is not None, expr
310 with branch_ctx.AddBranch():
311 self.accept(body)
312
313 if default_block:
314 with branch_ctx.AddBranch():
315 self.accept(default_block)
316
317 def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> T:
318 cfg = self.current_cfg()
319 assert len(o.expr) == 1, o.expr
320 expr = o.expr[0]
321 assert isinstance(expr, CallExpr), expr
322 self.accept(expr)
323
324 callee_name = expr.callee.name
325 if callee_name == 'switch':
326 self._handle_switch(expr, o, cfg)
327 elif callee_name == 'str_switch':
328 self._handle_switch(expr, o, cfg)
329 elif callee_name == 'tagswitch':
330 self._handle_switch(expr, o, cfg)
331 else:
332 with pass_state.CfgBlockContext(cfg, self.current_statement_id):
333 self.accept(o.body)
334
335 def visit_func_def(self, o: 'mypy.nodes.FuncDef') -> T:
336 if o.name == '__repr__': # Don't translate
337 return
338
339 # For virtual methods, pretend that the method on the base class calls
340 # the same method on every subclass. This way call sites using the
341 # abstract base class will over-approximate the set of call paths they
342 # can take when checking if they can reach MaybeCollect().
343 if self.current_class_name and self.virtual.IsVirtual(
344 self.current_class_name, o.name):
345 key = (self.current_class_name, o.name)
346 base = self.virtual.virtuals[key]
347 if base:
348 sub = join_name(self.current_class_name + (o.name, ),
349 delim='.')
350 base_key = base[0] + (base[1], )
351 cfg = self.cfgs[base_key]
352 cfg.AddFact(0, pass_state.FunctionCall(sub))
353
354 self.current_func_node = o
355 cfg = self.current_cfg()
356 for arg in o.arguments:
357 cfg.AddFact(0,
358 pass_state.Definition((arg.variable.name, ), '$Empty'))
359
360 self.accept(o.body)
361 self.current_func_node = None
362 self.current_statement_id = None
363
364 def visit_class_def(self, o: 'mypy.nodes.ClassDef') -> T:
365 self.current_class_name = split_py_name(o.fullname)
366 for stmt in o.defs.body:
367 # Ignore things that look like docstrings
368 if (isinstance(stmt, ExpressionStmt) and
369 isinstance(stmt.expr, StrExpr)):
370 continue
371
372 if isinstance(stmt, FuncDef) and stmt.name == '__repr__':
373 continue
374
375 self.accept(stmt)
376
377 self.current_class_name = None
378
379 def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> T:
380 cfg = self.current_cfg()
381 with pass_state.CfgLoopContext(
382 cfg, entry=self.current_statement_id) as loop:
383 self.accept(o.expr)
384 self.loop_stack.append(loop)
385 self.accept(o.body)
386 self.loop_stack.pop()
387
388 def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> T:
389 cfg = self.current_cfg()
390 if cfg:
391 cfg.AddDeadend(self.current_statement_id)
392
393 if o.expr:
394 self.accept(o.expr)
395
396 def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> T:
397 cfg = self.current_cfg()
398
399 if util.ShouldVisitIfExpr(o):
400 for expr in o.expr:
401 self.accept(expr)
402
403 with pass_state.CfgBranchContext(
404 cfg, self.current_statement_id) as branch_ctx:
405 if util.ShouldVisitIfBody(o):
406 with branch_ctx.AddBranch():
407 for node in o.body:
408 self.accept(node)
409
410 if util.ShouldVisitElseBody(o):
411 with branch_ctx.AddBranch():
412 self.accept(o.else_body)
413
414 def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> T:
415 if len(self.loop_stack):
416 self.loop_stack[-1].AddBreak(self.current_statement_id)
417
418 def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> T:
419 if len(self.loop_stack):
420 self.loop_stack[-1].AddContinue(self.current_statement_id)
421
422 def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> T:
423 cfg = self.current_cfg()
424 if cfg:
425 cfg.AddDeadend(self.current_statement_id)
426
427 if o.expr:
428 self.accept(o.expr)
429
430 def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> T:
431 cfg = self.current_cfg()
432 with pass_state.CfgBranchContext(cfg,
433 self.current_statement_id) as try_ctx:
434 with try_ctx.AddBranch() as try_block:
435 self.accept(o.body)
436
437 for t, v, handler in zip(o.types, o.vars, o.handlers):
438 with try_ctx.AddBranch(try_block.exit):
439 self.accept(handler)
440
441 def visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt') -> T:
442 cfg = self.current_cfg()
443 if cfg:
444 assert len(o.lvalues) == 1
445 lval = o.lvalues[0]
446 lval_names = []
447 if isinstance(lval, TupleExpr):
448 lval_names.extend(
449 [self.get_ref_name(item) for item in lval.items])
450
451 else:
452 lval_names.append(self.get_ref_name(lval))
453
454 assert lval_names, o
455
456 rval_type = self.types[o.rvalue]
457 rval_names = []
458 if isinstance(o.rvalue, CallExpr):
459 # The RHS is either an object constructor or something that
460 # returns a primitive type (e.g. Tuple[int, int] or str).
461 # XXX: When we add inter-procedural analysis we should treat
462 # these not as definitions but as some new kind of assignment.
463 rval_names = [None for _ in lval_names]
464
465 elif isinstance(o.rvalue, TupleExpr) and len(lval_names) == 1:
466 # We're constructing a tuple. Since tuples have have a fixed
467 # (and usually small) size, we can name each of the
468 # elements.
469 base = lval_names[0]
470 lval_names = [
471 base + (str(i), ) for i in range(len(o.rvalue.items))
472 ]
473 rval_names = [
474 self.get_ref_name(item) for item in o.rvalue.items
475 ]
476
477 elif isinstance(rval_type, TupleType):
478 # We're unpacking a tuple. Like the tuple construction case,
479 # give each element a name.
480 rval_name = self.get_ref_name(o.rvalue)
481 assert rval_name, o.rvalue
482 rval_names = [
483 rval_name + (str(i), ) for i in range(len(lval_names))
484 ]
485
486 else:
487 rval_names = [self.get_ref_name(o.rvalue)]
488
489 assert len(rval_names) == len(lval_names)
490
491 for lhs, rhs in zip(lval_names, rval_names):
492 assert lhs, lval
493 if rhs:
494 # In this case rhe RHS is another variable. Record the
495 # assignment so we can keep track of aliases.
496 cfg.AddFact(self.current_statement_id,
497 pass_state.Assignment(lhs, rhs))
498 else:
499 # In this case the RHS is either some kind of literal (e.g.
500 # [] or 'foo') or a call to an object constructor. Mark this
501 # statement as an (re-)definition of a variable.
502 cfg.AddFact(
503 self.current_statement_id,
504 pass_state.Definition(
505 lhs, '$HeapObject(h{})'.format(self.heap_counter)),
506 )
507 self.heap_counter += 1
508
509 for lval in o.lvalues:
510 self.current_lval = lval
511 self.accept(lval)
512 self.current_lval = None
513
514 self.accept(o.rvalue)
515
516 # Expressions
517
518 def visit_member_expr(self, o: 'mypy.nodes.MemberExpr') -> T:
519 self.accept(o.expr)
520 cfg = self.current_cfg()
521 if (cfg and
522 not isinstance(self.dot_exprs[o], pass_state.ModuleMember) and
523 o != self.current_lval):
524 ref = self.get_ref_name(o)
525 if ref:
526 cfg.AddFact(self.current_statement_id, pass_state.Use(ref))
527
528 def visit_name_expr(self, o: 'mypy.nodes.NameExpr') -> T:
529 cfg = self.current_cfg()
530 if cfg and o != self.current_lval:
531 is_local = False
532 for name, t in self.local_vars.get(self.current_func_node, []):
533 if name == o.name:
534 is_local = True
535 break
536
537 ref = self.get_ref_name(o)
538 if ref and is_local:
539 cfg.AddFact(self.current_statement_id, pass_state.Use(ref))
540
541 def visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> T:
542 cfg = self.current_cfg()
543 if self.current_func_node:
544 full_callee = self.resolve_callee(o)
545 if full_callee:
546 self.callees[o] = full_callee
547 cfg.AddFact(
548 self.current_statement_id,
549 pass_state.FunctionCall(join_name(full_callee, delim='.')))
550
551 for i, arg in enumerate(o.args):
552 arg_ref = self.get_ref_name(arg)
553 if arg_ref:
554 cfg.AddFact(self.current_statement_id,
555 pass_state.Bind(arg_ref, full_callee, i))
556
557 self.accept(o.callee)
558 for arg in o.args:
559 self.accept(arg)