OILS / mycpp / control_flow_pass.py View on Github | oils.pub

587 lines, 362 significant
1"""
2control_flow_pass.py - AST pass that builds a control flow graph.
3"""
4import collections
5
6import mypy
7from mypy.nodes import (Block, Expression, Statement, CallExpr, FuncDef,
8 IfStmt, NameExpr, MemberExpr, IndexExpr, TupleExpr,
9 IntExpr)
10
11from mypy.types import CallableType, Instance, Type, UnionType, NoneTyp, TupleType
12
13from mycpp.crash import catch_errors
14from mycpp.util import SymbolToString, SplitPyName
15from mycpp import visitor
16from mycpp import util
17from mycpp.util import SymbolPath
18from mycpp import pass_state
19
20from typing import Dict, List, Union, Optional, overload, TYPE_CHECKING
21
22if TYPE_CHECKING:
23 from mycpp import conversion_pass
24 from mycpp import cppgen_pass
25
26
def GetObjectTypeName(t: Type) -> SymbolPath:
    """Return the split fully qualified class name of an object type.

    Args:
      t: a mypy type.  An Instance maps to the name of its class.  A
        two-member UnionType (e.g. Optional[Foo]) maps to the name of its
        non-None member; if the first member is not None, the first member
        is used.

    Raises:
      AssertionError: for any other kind of type, or a union whose size is
        not 2.
    """
    if isinstance(t, Instance):
        return SplitPyName(t.type.fullname)

    if isinstance(t, UnionType):
        assert len(t.items) == 2, t
        # The None member of an Optional may be listed first; skip it.
        if isinstance(t.items[0], NoneTyp):
            return GetObjectTypeName(t.items[1])
        return GetObjectTypeName(t.items[0])

    # An explicit raise (rather than `assert False`) still fails loudly under
    # `python -O`, instead of silently returning None.
    raise AssertionError(t)
39
40
# Placeholder for Build.current_statement_id when no statement ID is current
# (it is set initially and again after each function). Real statement IDs
# are positive, so this value can never collide with one.
INVALID_ID = -99
42
43
class Build(visitor.SimpleVisitor):
    """AST visitor that builds one control flow graph per function.

    Visiting a module fills in:

    - self.cflow_graphs: function SymbolPath -> pass_state.ControlFlowGraph.
      Each statement visited inside a function gets a statement ID, and the
      facts it produces (Definition, Assignment, Use, FunctionCall, Bind)
      are attached to that ID.
    - self.callees: CallExpr -> SymbolPath of the callee, for every call
      whose callee could be resolved.

    It also reports a `break` that sits in a switch case without a loop of
    its own inside that case, since it would behave differently in C++.
    """

    def __init__(self, types: Dict[Expression, Type],
                 virtual: pass_state.Virtual,
                 local_vars: 'cppgen_pass.AllLocalVars',
                 dot_exprs: 'conversion_pass.DotExprs') -> None:
        visitor.SimpleVisitor.__init__(self)

        # mypy's expression -> type map
        self.types = types
        # function SymbolPath -> CFG; a graph is created on first access
        self.cflow_graphs: Dict[
            SymbolPath, pass_state.ControlFlowGraph] = collections.defaultdict(
                pass_state.ControlFlowGraph)
        # ID of the statement being visited; reset after each function
        self.current_statement_id = INVALID_ID
        self.current_func_node: Optional[FuncDef] = None
        # Enclosing loops, innermost last; targets for break/continue
        self.loop_stack: List[pass_state.CfgLoopContext] = []
        self.virtual = virtual
        self.local_vars = local_vars
        self.dot_exprs = dot_exprs
        # Numbers the '$HeapObject(hN)' values created for definitions
        self.heap_counter = 0
        # call expression -> SymbolPath of the callee
        self.callees: Dict['mypy.nodes.CallExpr', SymbolPath] = {}
        # Assignment target being visited, so it is not recorded as a Use
        self.current_lval: Optional[Expression] = None

        # Used to reject `break` inside a switch case (see visit_break_stmt)
        self.inside_switch = False
        self.inside_loop = False

    def current_cfg(self) -> Optional[pass_state.ControlFlowGraph]:
        """Return the CFG of the function being visited, or None when not
        inside a function."""
        if not self.current_func_node:
            return None

        return self.cflow_graphs[SplitPyName(self.current_func_node.fullname)]

    def _find_local_type(self, name: str) -> Optional[object]:
        """Return the recorded type of local variable `name` in the current
        function, or None if there is no such local (or its type is None)."""
        for local_name, local_type in self.local_vars.get(
                self.current_func_node, []):
            if local_name == name:
                return local_type
        return None

    def resolve_callee(
            self, o: CallExpr, current_class_name: Optional[util.SymbolPath]
    ) -> Optional[util.SymbolPath]:
        """
        Returns the fully qualified name of the callee in the given call
        expression, or None if it can't or needn't be resolved.

        Member functions are prefixed by the names of the classes that contain
        them. For example, the name of the callee in the last statement of the
        snippet below is `module.SomeObject.Foo`.

            x = module.SomeObject()
            x.Foo()

        For a receiver of type Optional[SomeObject], the non-None member of
        the union names the class.

        Free-functions defined in the local module are referred to by their
        normal fully qualified names. The function `foo` in a module called
        `moduleA` is named `moduleA.foo`. Calls to free-functions defined
        in imported modules are named the same way.
        """
        callee = o.callee

        if isinstance(callee, NameExpr):
            return SplitPyName(callee.fullname)

        if isinstance(callee, MemberExpr):
            method_name = callee.name

            if isinstance(callee.expr, NameExpr):
                if isinstance(self.dot_exprs.get(callee),
                              pass_state.ModuleMember):
                    return SplitPyName(callee.expr.fullname) + (method_name, )

                if callee.expr.name == 'self':
                    assert current_class_name
                    return current_class_name + (method_name, )

                local_type = self._find_local_type(callee.expr.name)
                if not local_type:
                    # context or exception handler. probably safe to ignore.
                    return None

                if isinstance(local_type, str):
                    return SplitPyName(local_type) + (method_name, )

                if isinstance(local_type, (Instance, UnionType)):
                    # The method name (not the receiver's variable name) is
                    # appended to the class name.
                    return GetObjectTypeName(local_type) + (method_name, )

                assert not isinstance(local_type, CallableType)
                # primitive type or string. don't care.
                return None

            receiver_type = self.types.get(callee.expr)
            if isinstance(receiver_type, (Instance, UnionType)):
                return GetObjectTypeName(receiver_type) + (method_name, )

            if callee.expr and getattr(callee.expr, 'fullname', None):
                return SplitPyName(callee.expr.fullname) + (method_name, )

            # constructors of things that we don't care about.
            return None

        # Don't currently get here
        raise AssertionError()

    def get_ref_name(self, expr: Expression) -> Optional[util.SymbolPath]:
        """
        To do dataflow analysis we need to track changes to objects, which
        requires naming them. This function returns the name of the object
        referred to by the given expression. If the expression doesn't refer to
        an object or variable it returns None. That includes a member or tuple
        element whose base has no name, e.g. the result of a call.

        Objects are named slightly differently than they appear in the source
        code.

        Objects referenced by local variables are referred to by the name of the
        local. For example, the name of the object in both statements below is
        `x`.

            x = module.SomeObject()
            x = None

        Member expressions are named after the parent object's type. For
        example, the names of the objects in the member assignment statements
        below are both `module.SomeObject.member_a`. This makes it possible to
        track data flow across object members without having to track individual
        heap objects, which would increase the search space for analyses and
        slow things down.

            x = module.SomeObject()
            y = module.SomeObject()
            x.member_a = 'foo'
            y.member_a = 'bar'

        Index expressions are named after their bases, for the same reasons as
        member expressions. The coarse-grained precision should lead to an
        over-approximation of where objects are in use, but should not miss any
        references. This should be fine for our purposes. In the snippet below
        the last two assignments are named `x` and `module.SomeObject.a_list`.

            x = [None]  # list[Thing]
            y = module.SomeObject()
            x[0] = Thing()
            y.a_list[1] = Blah()

        Index expressions over tuples are treated differently, though. Tuples
        have a fixed size, tend to be small, and their elements have distinct
        types. So, each element can be (and probably needs to be) individually
        named. In the snippet below, the name of the RHS in the second
        assignment is `t.0`.

            t = (1, 2, 3, 4)
            x = t[0]

        The examples above all deal with assignments, but these rules apply to
        any expression that uses an object or variable.
        """
        if isinstance(expr, NameExpr):
            if expr.name in {'True', 'False', 'None'}:
                return None
            return (expr.name, )

        if isinstance(expr, MemberExpr):
            dot_expr = self.dot_exprs[expr]
            if isinstance(dot_expr, pass_state.ModuleMember):
                return dot_expr.module_path + (dot_expr.member, )

            if isinstance(dot_expr, (pass_state.HeapObjectMember,
                                     pass_state.StackObjectMember)):
                obj_name = self.get_ref_name(dot_expr.object_expr)
                if obj_name:
                    return obj_name + (dot_expr.member, )
                # XXX: add a new case like pass_state.ExpressionMember for
                # cases when the LHS of . isn't a reference (e.g.
                # builtin/assign_osh.py:54)

            return None

        if isinstance(expr, IndexExpr):
            base_name = self.get_ref_name(expr.base)
            if isinstance(self.types[expr.base], TupleType):
                # Tuple elements are named individually, which needs a
                # constant index.
                assert isinstance(expr.index, IntExpr)
                if not base_name:
                    return None
                return base_name + (str(expr.index.value), )

            return base_name

        return None

    #
    # COPIED from IRBuilder
    #

    @overload
    def accept(self, node: Expression) -> None:
        ...

    @overload
    def accept(self, node: Statement) -> None:
        ...

    def accept(self, node: Union[Statement, Expression]) -> None:
        """Visit `node` inside catch_errors(module_path, node.line).

        Inside a function, every statement except a Block first gets a fresh
        statement ID in the CFG, so the facts its visitor records attach to
        that ID.
        """
        with catch_errors(self.module_path, node.line):
            if isinstance(node, Expression):
                node.accept(self)
            else:
                cfg = self.current_cfg()
                # Most statements have empty visitors because they don't
                # require any special logic. Create statements for them
                # here. Don't create statements from blocks to avoid
                # stuttering.
                if cfg and not isinstance(node, Block):
                    self.current_statement_id = cfg.AddStatement()

                node.accept(self)

    # Statements

    def _visit_loop(self, head: Expression, body: Block) -> None:
        """Shared CFG construction for `for` and `while` loops.

        `head` is the iterable of a `for` or the condition of a `while`.
        """
        cfg = self.current_cfg()
        was_inside_loop = self.inside_loop
        self.inside_loop = True
        with pass_state.CfgLoopContext(
                cfg, entry=self.current_statement_id) as loop:
            self.accept(head)
            self.loop_stack.append(loop)
            self.accept(body)
            self.loop_stack.pop()

        self.inside_loop = was_inside_loop

    def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> None:
        self._visit_loop(o.expr, o.body)

    def _handle_switch(self, expr: Expression, o: 'mypy.nodes.WithStmt',
                       cfg: pass_state.ControlFlowGraph) -> None:
        """Add one CFG branch per switch case, plus one for the default
        block when there is one.

        `expr` is the switch call; the caller has already visited it.
        """
        assert len(o.body.body) == 1, o.body.body
        if_node = o.body.body[0]
        assert isinstance(if_node, IfStmt), if_node
        cases: util.CaseList = []
        default_block = util.CollectSwitchCases(self.module_path, if_node,
                                                cases)
        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            for case_expr, case_body in cases:
                assert case_expr is not None, case_expr
                self.accept(case_expr)
                with branch_ctx.AddBranch():
                    self.accept(case_body)

            # NOTE(review): an int result from CollectSwitchCases appears to
            # mean "no default block" -- confirm in util.
            if not isinstance(default_block, int):
                with branch_ctx.AddBranch():
                    self.accept(default_block)

    def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> None:
        cfg = self.current_cfg()

        assert len(o.expr) == 1, o.expr
        expr = o.expr[0]

        assert isinstance(expr, CallExpr), expr
        self.accept(expr)

        # Note: we have 'with alloc.ctx_SourceCode'
        #assert isinstance(expr.callee, NameExpr), expr.callee
        callee_name = expr.callee.name

        if callee_name in ('switch', 'str_switch', 'tagswitch'):
            # Save and restore both flags, so that leaving a nested switch
            # doesn't clear the state of an enclosing one.
            was_inside_switch = self.inside_switch
            was_inside_loop = self.inside_loop
            self.inside_switch = True
            # A loop enclosing the switch doesn't count: in C++ a `break` in
            # a case would only leave the switch.
            self.inside_loop = False
            self._handle_switch(expr, o, cfg)
            self.inside_switch = was_inside_switch
            self.inside_loop = was_inside_loop
        else:
            with pass_state.CfgBlockContext(cfg, self.current_statement_id):
                self.accept(o.body)

    def oils_visit_func_def(self, o: 'mypy.nodes.FuncDef',
                            current_class_name: Optional[util.SymbolPath],
                            current_method_name: Optional[str]) -> None:
        # For virtual methods, pretend that the method on the base class calls
        # the same method on every subclass. This way call sites using the
        # abstract base class will over-approximate the set of call paths they
        # can take when checking if they can reach MaybeCollect().
        if current_class_name and self.virtual.IsVirtual(
                current_class_name, o.name):
            key = (current_class_name, o.name)
            base = self.virtual.virtuals[key]
            if base:
                sub = SymbolToString(current_class_name + (o.name, ),
                                     delim='.')
                base_key = base[0] + (base[1], )
                base_cfg = self.cflow_graphs[base_key]
                base_cfg.AddFact(0, pass_state.FunctionCall(sub))

        self.current_func_node = o
        cfg = self.current_cfg()
        assert cfg is not None  # current_func_node was just set
        # Parameters are defined on entry (statement 0).
        for arg in o.arguments:
            cfg.AddFact(0,
                        pass_state.Definition((arg.variable.name, ), '$Empty'))

        self.accept(o.body)
        self.current_func_node = None
        self.current_statement_id = INVALID_ID

    def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> None:
        self._visit_loop(o.expr, o.body)

    def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> None:
        cfg = self.current_cfg()
        if cfg:
            cfg.AddDeadend(self.current_statement_id)

        if o.expr:
            self.accept(o.expr)

    def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> None:
        cfg = self.current_cfg()

        if util.ShouldVisitIfExpr(o):
            for expr in o.expr:
                self.accept(expr)

        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            if util.ShouldVisitIfBody(o):
                with branch_ctx.AddBranch():
                    for node in o.body:
                        self.accept(node)

            if util.ShouldVisitElseBody(o):
                with branch_ctx.AddBranch():
                    self.accept(o.else_body)

    def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> None:
        if self.inside_switch and not self.inside_loop:
            # In Python, `break` leaves the enclosing loop, but in C++ it
            # only leaves the switch case, so disallow it inside a switch.
            self.report_error(
                o, "'break' is not allowed to be used inside a switch")
        if self.loop_stack:
            self.loop_stack[-1].AddBreak(self.current_statement_id)

    def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> None:
        if self.loop_stack:
            self.loop_stack[-1].AddContinue(self.current_statement_id)

    def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> None:
        cfg = self.current_cfg()
        if cfg:
            cfg.AddDeadend(self.current_statement_id)

        if o.expr:
            self.accept(o.expr)

    def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> None:
        cfg = self.current_cfg()
        with pass_state.CfgBranchContext(cfg,
                                         self.current_statement_id) as try_ctx:
            with try_ctx.AddBranch() as try_block:
                self.accept(o.body)

            # Each handler is added as a branch from the try body's exit.
            for handler in o.handlers:
                with try_ctx.AddBranch(try_block.exit):
                    self.accept(handler)

    # 2024-12(andy): SimpleVisitor now has a special case for:
    #
    #   myvar = [x for x in other]
    #
    # This seems like it should affect the control flow graph, since we are no
    # longer calling oils_visit_assignment_stmt, and are instead calling
    # oils_visit_assign_to_listcomp.
    #
    # We may need more test coverage?
    # List comprehensions are arguably a weird/legacy part of the mycpp IR that
    # should be cleaned up.
    #
    # We do NOT allow:
    #
    #   myfunc([x for x in other])
    #
    # See mycpp/visitor.py and mycpp/cppgen_pass.py for how this is used.

    #def oils_visit_assign_to_listcomp(self, o: 'mypy.nodes.AssignmentStmt', lval: NameExpr) -> None:

    def oils_visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt',
                                   lval: Expression, rval: Expression,
                                   current_method_name: Optional[str],
                                   at_global_scope: bool) -> None:
        """Record Assignment/Definition facts for `lval = rval`, then visit
        both sides (the LHS with current_lval set, so it isn't a Use)."""
        cfg = self.current_cfg()
        if cfg:
            lval_names: List[Optional[util.SymbolPath]]
            if isinstance(lval, TupleExpr):
                lval_names = [self.get_ref_name(item) for item in lval.items]
            else:
                lval_names = [self.get_ref_name(lval)]

            assert lval_names, lval

            rval_type = self.types[rval]
            rval_names: List[Optional[util.SymbolPath]]
            if isinstance(rval, CallExpr):
                # The RHS is either an object constructor or something that
                # returns a primitive type (e.g. Tuple[int, int] or str).
                # XXX: When we add inter-procedural analysis we should treat
                # these not as definitions but as some new kind of assignment.
                rval_names = [None] * len(lval_names)

            elif isinstance(rval, TupleExpr) and len(lval_names) == 1:
                # We're constructing a tuple. Since tuples have a fixed
                # (and usually small) size, we can name each of the
                # elements.
                base = lval_names[0]
                assert base, lval
                lval_names = [
                    base + (str(i), ) for i in range(len(rval.items))
                ]
                rval_names = [self.get_ref_name(item) for item in rval.items]

            elif isinstance(rval_type, TupleType):
                # We're unpacking a tuple. Like the tuple construction case,
                # give each element a name.
                rval_name = self.get_ref_name(rval)
                assert rval_name, rval
                rval_names = [
                    rval_name + (str(i), ) for i in range(len(lval_names))
                ]

            else:
                rval_names = [self.get_ref_name(rval)]

            assert len(rval_names) == len(lval_names)

            for lhs, rhs in zip(lval_names, rval_names):
                assert lhs, lval
                if rhs:
                    # In this case the RHS is another variable. Record the
                    # assignment so we can keep track of aliases.
                    cfg.AddFact(self.current_statement_id,
                                pass_state.Assignment(lhs, rhs))
                else:
                    # In this case the RHS is either some kind of literal (e.g.
                    # [] or 'foo') or a call to an object constructor. Mark this
                    # statement as an (re-)definition of a variable.
                    cfg.AddFact(
                        self.current_statement_id,
                        pass_state.Definition(
                            lhs, '$HeapObject(h{})'.format(self.heap_counter)),
                    )
                    self.heap_counter += 1

        # TODO: Could simplify this
        self.current_lval = lval
        self.accept(lval)
        self.current_lval = None

        self.accept(rval)

    # Expressions

    def oils_visit_member_expr(self, o: 'mypy.nodes.MemberExpr') -> None:
        """Record a Use of a non-module member (unless it's the current
        assignment target)."""
        self.accept(o.expr)
        cfg = self.current_cfg()
        if (cfg and
                not isinstance(self.dot_exprs[o], pass_state.ModuleMember) and
                o != self.current_lval):
            ref = self.get_ref_name(o)
            if ref:
                cfg.AddFact(self.current_statement_id, pass_state.Use(ref))

    def oils_visit_name_expr(self, o: 'mypy.nodes.NameExpr') -> None:
        """Record a Use of a local variable (unless it's the current
        assignment target)."""
        cfg = self.current_cfg()
        if not cfg or o == self.current_lval:
            return

        is_local = any(
            name == o.name
            for name, _ in self.local_vars.get(self.current_func_node, []))
        ref = self.get_ref_name(o)
        if ref and is_local:
            cfg.AddFact(self.current_statement_id, pass_state.Use(ref))

    def oils_visit_call_expr(
            self, o: 'mypy.nodes.CallExpr',
            current_class_name: Optional[util.SymbolPath]) -> None:
        """Inside a function, record the resolved callee and a Bind fact for
        each named argument; then visit the callee and arguments."""
        cfg = self.current_cfg()
        if cfg is not None:
            full_callee = self.resolve_callee(o, current_class_name)
            if full_callee:
                self.callees[o] = full_callee
                cfg.AddFact(
                    self.current_statement_id,
                    pass_state.FunctionCall(
                        SymbolToString(full_callee, delim='.')))

                # Bind each argument that names an object/variable to its
                # index in o.args.
                for i, arg in enumerate(o.args):
                    arg_ref = self.get_ref_name(arg)
                    if arg_ref:
                        cfg.AddFact(self.current_statement_id,
                                    pass_state.Bind(arg_ref, full_callee, i))

        self.accept(o.callee)
        for arg in o.args:
            self.accept(arg)