OILS / mycpp / control_flow_pass.py View on Github | oils.pub

586 lines, 360 significant
1"""
2control_flow_pass.py - AST pass that builds a control flow graph.
3"""
4import collections
5
6import mypy
7from mypy.nodes import (Block, Expression, Statement, CallExpr, FuncDef,
8 IfStmt, NameExpr, MemberExpr, IndexExpr, TupleExpr,
9 IntExpr)
10
11from mypy.types import CallableType, Instance, Type, UnionType, NoneTyp, TupleType
12
13from mycpp.crash import catch_errors
14from mycpp.util import SymbolToString, SplitPyName
15from mycpp import visitor
16from mycpp import util
17from mycpp.util import SymbolPath
18from mycpp import pass_state
19
20from typing import Dict, List, Union, Optional, overload, TYPE_CHECKING
21
22if TYPE_CHECKING:
23 from mycpp import conversion_pass
24 from mycpp import cppgen_pass
25
26
def GetObjectTypeName(t: Type) -> SymbolPath:
    """Return the symbol path of the class that the type `t` refers to.

    An Instance is named after the full name of its class.  A union is
    expected to have exactly two members; if the first member is NoneTyp the
    second member is used, otherwise the first member is used.

    Raises:
        AssertionError: if `t` is neither an Instance nor a UnionType.
    """
    if isinstance(t, Instance):
        return SplitPyName(t.type.fullname)

    if isinstance(t, UnionType):
        assert len(t.items) == 2, t
        if isinstance(t.items[0], NoneTyp):
            return GetObjectTypeName(t.items[1])

        return GetObjectTypeName(t.items[0])

    # Raised explicitly (rather than `assert False`) so that unsupported types
    # are still rejected when Python runs with assertions disabled (-O).
    raise AssertionError(t)
39
40
# Sentinel stored in Build.current_statement_id while no function body is
# being visited; it is negative so it cannot be mistaken for a statement ID.
INVALID_ID = -99  # statement IDs are positive
42
43
class Build(visitor.SimpleVisitor):
    """AST pass that builds a control flow graph for each visited function.

    Graphs are stored in `self.cflow_graphs`, keyed by the SymbolPath of the
    function.  While a function body is walked, every non-Block statement gets
    a statement ID from the graph, and facts (Definition, Assignment, Use,
    FunctionCall, Bind) are attached to those IDs.  Call expressions whose
    callee can be resolved are also recorded in `self.callees`.
    """

    def __init__(self, types: Dict[Expression, Type],
                 virtual: pass_state.Virtual,
                 local_vars: 'cppgen_pass.AllLocalVars',
                 dot_exprs: 'conversion_pass.DotExprs') -> None:
        """
        Args:
            types: maps expressions to their mypy types.
            virtual: queried to find out which methods are virtual and what
                base method they override.
            local_vars: maps a function node to (name, type) pairs for its
                local variables.
            dot_exprs: maps member expressions to their classification
                (module member, heap object member, stack object member).
        """
        visitor.SimpleVisitor.__init__(self)

        self.types = types
        # Graphs are created on first access.
        self.cflow_graphs: Dict[
            SymbolPath, pass_state.ControlFlowGraph] = collections.defaultdict(
                pass_state.ControlFlowGraph)
        self.current_statement_id = INVALID_ID
        self.current_func_node: Optional[FuncDef] = None
        # Innermost enclosing loop is last.
        self.loop_stack: List[pass_state.CfgLoopContext] = []
        self.virtual = virtual
        self.local_vars = local_vars
        self.dot_exprs = dot_exprs
        # Used to give each '$HeapObject(hN)' definition a distinct name.
        self.heap_counter = 0
        # statement object -> SymbolPath of the callee
        self.callees: Dict['mypy.nodes.CallExpr', SymbolPath] = {}
        # The assignment target currently being visited, if any.  It is not
        # recorded as a Use.
        self.current_lval: Optional[Expression] = None

        # Together these decide whether a 'break' is an error; see
        # visit_break_stmt().
        self.inside_switch = False
        self.inside_loop = False

    def current_cfg(self) -> Optional[pass_state.ControlFlowGraph]:
        """Return the graph of the function being visited, or None if we are
        not inside a function."""
        if not self.current_func_node:
            return None

        return self.cflow_graphs[SplitPyName(self.current_func_node.fullname)]

    def resolve_callee(
            self, o: CallExpr, current_class_name: Optional[util.SymbolPath]
    ) -> Optional[util.SymbolPath]:
        """
        Returns the fully qualified name of the callee in the given call
        expression, or None if the callee is something we don't track.

        Member functions are prefixed by the names of the classes that contain
        them. For example, the name of the callee in the last statement of the
        snippet below is `module.SomeObject.Foo`.

            x = module.SomeObject()
            x.Foo()

        Free-functions defined in the local module are referred to by their
        normal fully qualified names. The function `foo` in a module called
        `moduleA` is named `moduleA.foo`. Calls to free-functions defined in
        imported modules are named the same way.
        """
        callee = o.callee
        if isinstance(callee, NameExpr):
            return SplitPyName(callee.fullname)

        if not isinstance(callee, MemberExpr):
            # Don't currently get here
            raise AssertionError()

        method = (callee.name, )
        receiver = callee.expr

        if not isinstance(receiver, NameExpr):
            # The receiver is itself an expression; name the callee after the
            # receiver's type.
            t = self.types.get(receiver)
            if isinstance(t, Instance):
                return SplitPyName(t.type.fullname) + method

            if isinstance(t, UnionType):
                # GetObjectTypeName() skips a leading None member.
                return GetObjectTypeName(t) + method

            if receiver and getattr(receiver, 'fullname', None):
                return SplitPyName(receiver.fullname) + method

            # constructors of things that we don't care about.
            return None

        if isinstance(self.dot_exprs.get(callee), pass_state.ModuleMember):
            return SplitPyName(receiver.fullname) + method

        if receiver.name == 'self':
            assert current_class_name
            return current_class_name + method

        # The receiver is a plain name: look for a local variable with that
        # name and use its type.
        local_type = next(
            (t for name, t in self.local_vars.get(self.current_func_node, [])
             if name == receiver.name), None)

        if not local_type:
            # context or exception handler. probably safe to ignore.
            return None

        if isinstance(local_type, str):
            return SplitPyName(local_type) + method

        if isinstance(local_type, (Instance, UnionType)):
            # The callee is <class of the local>.<method name>.
            return GetObjectTypeName(local_type) + method

        assert not isinstance(local_type, CallableType)
        # primitive type or string. don't care.
        return None

    def get_ref_name(self, expr: Expression) -> Optional[util.SymbolPath]:
        """
        To do dataflow analysis we need to track changes to objects, which
        requires naming them. This function returns the name of the object
        referred to by the given expression. If the expression doesn't refer to
        an object or variable it returns None.

        Objects are named slightly differently than they appear in the source
        code.

        Objects referenced by local variables are referred to by the name of the
        local. For example, the name of the object in both statements below is
        `x`.

            x = module.SomeObject()
            x = None

        Member expressions are named after the parent object's type. For
        example, the names of the objects in the member assignment statements
        below are both `module.SomeObject.member_a`. This makes it possible to
        track data flow across object members without having to track individual
        heap objects, which would increase the search space for analyses and
        slow things down.

            x = module.SomeObject()
            y = module.SomeObject()
            x.member_a = 'foo'
            y.member_a = 'bar'

        Index expressions are named after their bases, for the same reasons as
        member expressions. The coarse-grained precision should lead to an
        over-approximation of where objects are in use, but should not miss any
        references. This should be fine for our purposes. In the snippet below
        the last two assignments are named `x` and `module.SomeObject.a_list`.

            x = [None] # list[Thing]
            y = module.SomeObject()
            x[0] = Thing()
            y.a_list[1] = Blah()

        Index expressions over tuples are treated differently, though. Tuples
        have a fixed size, tend to be small, and their elements have distinct
        types. So, each element can be (and probably needs to be) individually
        named. In the snippet below, the name of the RHS in the second
        assignment is `t.0`.

            t = (1, 2, 3, 4)
            x = t[0]

        The examples above all deal with assignments, but these rules apply to
        any expression that uses an object or variable.
        """
        if isinstance(expr, NameExpr):
            if expr.name in {'True', 'False', 'None'}:
                return None
            return (expr.name, )

        if isinstance(expr, MemberExpr):
            dot_expr = self.dot_exprs[expr]
            if isinstance(dot_expr, pass_state.ModuleMember):
                return dot_expr.module_path + (dot_expr.member, )

            if isinstance(dot_expr, pass_state.HeapObjectMember):
                obj_name = self.get_ref_name(dot_expr.object_expr)
                if obj_name:
                    # XXX: add a new case like pass_state.ExpressionMember for
                    # cases when the LHS of . isn't a reference (e.g.
                    # builtin/assign_osh.py:54)
                    return obj_name + (dot_expr.member, )

            elif isinstance(dot_expr, pass_state.StackObjectMember):
                obj_name = self.get_ref_name(dot_expr.object_expr)
                # Same guard as for heap objects: the LHS of . may not be a
                # reference, in which case there is nothing to name.
                if obj_name:
                    return obj_name + (dot_expr.member, )

            return None

        if isinstance(expr, IndexExpr):
            if isinstance(self.types[expr.base], TupleType):
                assert isinstance(expr.index, IntExpr)
                base_name = self.get_ref_name(expr.base)
                if base_name is None:
                    # The tuple itself has no name, so neither do its items.
                    return None
                return base_name + (str(expr.index.value), )

            return self.get_ref_name(expr.base)

        return None

    #
    # COPIED from IRBuilder
    #

    @overload
    def accept(self, node: Expression) -> None:
        ...

    @overload
    def accept(self, node: Statement) -> None:
        ...

    def accept(self, node: Union[Statement, Expression]) -> None:
        """Visit `node`, first allocating a statement ID if it is a statement
        inside a function."""
        with catch_errors(self.module_path, node.line):
            if isinstance(node, Expression):
                node.accept(self)
            else:
                cfg = self.current_cfg()
                # Most statements have empty visitors because they don't
                # require any special logic. Create statements for them
                # here. Don't create statements from blocks to avoid
                # stuttering.
                if cfg and not isinstance(node, Block):
                    self.current_statement_id = cfg.AddStatement()

                node.accept(self)

    # Statements

    def _cfg_visit_loop(self, head: Expression, body: Block) -> None:
        """Shared logic for 'for' and 'while': visit the loop's head
        expression and body inside a CfgLoopContext.

        The loop context is on self.loop_stack only while the body is being
        visited, so that break/continue statements attach to it.
        """
        cfg = self.current_cfg()
        was_inside_loop = self.inside_loop
        self.inside_loop = True
        with pass_state.CfgLoopContext(
                cfg, entry=self.current_statement_id) as loop:
            self.accept(head)
            self.loop_stack.append(loop)
            self.accept(body)
            self.loop_stack.pop()

        # Only the outermost loop turns the flag back off.
        self.inside_loop = was_inside_loop

    def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> None:
        """Record a 'for' loop; o.expr is visited before the body."""
        self._cfg_visit_loop(o.expr, o.body)

    def _handle_switch(self, expr: Expression, o: 'mypy.nodes.WithStmt',
                       cfg: pass_state.ControlFlowGraph) -> None:
        """Add one branch per switch case, plus one for the default block.

        The body of the 'with' statement must be a single 'if' statement; its
        cases are extracted with util.CollectSwitchCases().
        """
        assert len(o.body.body) == 1, o.body.body
        if_node = o.body.body[0]
        assert isinstance(if_node, IfStmt), if_node
        cases: util.CaseList = []
        default_block = util.CollectSwitchCases(self.module_path, if_node,
                                                cases)
        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            for case_expr, case_body in cases:
                assert case_expr is not None, case_expr
                self.accept(case_expr)
                with branch_ctx.AddBranch():
                    self.accept(case_body)

            # An int here means there is no default block to visit.
            if not isinstance(default_block, int):
                with branch_ctx.AddBranch():
                    self.accept(default_block)

    def oils_visit_switch(self, expr: 'mypy.nodes.CallExpr',
                          o: 'mypy.nodes.WithStmt',
                          switch_type: str) -> None:
        """Build control flow graph for switch statements."""
        cfg = self.current_cfg()
        was_inside_switch = self.inside_switch
        was_inside_loop = self.inside_loop
        self.inside_switch = True
        # A 'break' directly inside the switch is not inside a loop as far as
        # visit_break_stmt() is concerned, even if the switch is in a loop.
        self.inside_loop = False
        self._handle_switch(expr, o, cfg)
        # Restore both flags so that an enclosing switch or loop is still
        # seen correctly after a nested switch ends.
        self.inside_switch = was_inside_switch
        self.inside_loop = was_inside_loop

    def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> None:
        """Record a 'with' statement, either as a switch or a plain block."""
        cfg = self.current_cfg()

        assert len(o.expr) == 1, o.expr
        expr = o.expr[0]

        assert isinstance(expr, CallExpr), expr
        self.accept(expr)

        # Note: we have 'with alloc.ctx_SourceCode'
        #assert isinstance(expr.callee, NameExpr), expr.callee
        if isinstance(expr.callee, NameExpr):
            callee_name = expr.callee.name
            if callee_name in ('switch', 'str_switch', 'tagswitch'):
                self.oils_visit_switch(expr, o, callee_name)
                return

        with pass_state.CfgBlockContext(cfg, self.current_statement_id):
            self.accept(o.body)

    def oils_visit_func_def(self, o: 'mypy.nodes.FuncDef',
                            current_class_name: Optional[util.SymbolPath],
                            current_method_name: Optional[str]) -> None:
        """Build the graph for one function.

        Facts that describe the function as a whole (its parameters, and the
        calls a virtual base method pretends to make) are attached to
        statement 0.
        """
        # For virtual methods, pretend that the method on the base class calls
        # the same method on every subclass. This way call sites using the
        # abstract base class will over-approximate the set of call paths they
        # can take when checking if they can reach MaybeCollect().
        if current_class_name and self.virtual.IsVirtual(
                current_class_name, o.name):
            key = (current_class_name, o.name)
            base = self.virtual.virtuals[key]
            if base:
                sub = SymbolToString(current_class_name + (o.name, ),
                                     delim='.')
                base_key = base[0] + (base[1], )
                cfg = self.cflow_graphs[base_key]
                cfg.AddFact(0, pass_state.FunctionCall(sub))

        self.current_func_node = o
        cfg = self.current_cfg()
        for arg in o.arguments:
            cfg.AddFact(0,
                        pass_state.Definition((arg.variable.name, ), '$Empty'))

        self.accept(o.body)
        self.current_func_node = None
        self.current_statement_id = INVALID_ID

    def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> None:
        """Record a 'while' loop; o.expr is visited before the body."""
        self._cfg_visit_loop(o.expr, o.body)

    def _cfg_visit_deadend(self, expr: Optional[Expression]) -> None:
        """Shared logic for 'return' and 'raise': mark the current statement
        as a dead end, then visit the statement's expression, if any."""
        cfg = self.current_cfg()
        if cfg:
            cfg.AddDeadend(self.current_statement_id)

        if expr:
            self.accept(expr)

    def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> None:
        """Record a 'return' statement as a dead end."""
        self._cfg_visit_deadend(o.expr)

    def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> None:
        """Record an 'if' statement with a branch for the body and one for
        the else body, subject to the util.ShouldVisit*() filters."""
        cfg = self.current_cfg()

        if util.ShouldVisitIfExpr(o):
            for expr in o.expr:
                self.accept(expr)

        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            if util.ShouldVisitIfBody(o):
                with branch_ctx.AddBranch():
                    for node in o.body:
                        self.accept(node)

            if util.ShouldVisitElseBody(o):
                with branch_ctx.AddBranch():
                    self.accept(o.else_body)

    def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> None:
        """Report 'break' directly inside a switch, and register the break
        with the innermost loop, if there is one."""
        if self.inside_switch and not self.inside_loop:
            # Since it will break out of the loop in Python, but only leave
            # the switch case in C++, disallow break inside a switch.
            self.report_error(
                o, "'break' is not allowed to be used inside a switch")
        if self.loop_stack:
            self.loop_stack[-1].AddBreak(self.current_statement_id)

    def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> None:
        """Register the 'continue' with the innermost loop, if there is one."""
        if self.loop_stack:
            self.loop_stack[-1].AddContinue(self.current_statement_id)

    def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> None:
        """Record a 'raise' statement as a dead end."""
        self._cfg_visit_deadend(o.expr)

    def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> None:
        """Record a 'try' statement: one branch for the body and one branch
        per handler, each handler branch added with the try body's exit."""
        cfg = self.current_cfg()
        with pass_state.CfgBranchContext(cfg,
                                         self.current_statement_id) as try_ctx:
            with try_ctx.AddBranch() as try_block:
                self.accept(o.body)

            # Only the handler bodies are visited; the exception types and
            # variables are not.
            for handler in o.handlers:
                with try_ctx.AddBranch(try_block.exit):
                    self.accept(handler)

    # 2024-12(andy): SimpleVisitor now has a special case for:
    #
    #   myvar = [x for x in other]
    #
    # This seems like it should affect the control flow graph, since we are no
    # longer calling oils_visit_assignment_stmt, and are instead calling
    # oils_visit_assign_to_listcomp.
    #
    # We may need more test coverage?
    # List comprehensions are arguably a weird/legacy part of the mycpp IR that
    # should be cleaned up.
    #
    # We do NOT allow:
    #
    #   myfunc([x for x in other])
    #
    # See mycpp/visitor.py and mycpp/cppgen_pass.py for how this is used.

    #def oils_visit_assign_to_listcomp(self, o: 'mypy.nodes.AssignmentStmt', lval: NameExpr) -> None:

    def oils_visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt',
                                   lval: Expression, rval: Expression,
                                   current_method_name: Optional[str],
                                   at_global_scope: bool) -> None:
        """Record an assignment.

        Inside a function, each named target gets either an Assignment fact
        (when the right-hand side is itself a named object) or a Definition
        fact naming a fresh heap object.  The target and the right-hand side
        are then visited; the target is not recorded as a Use.
        """
        cfg = self.current_cfg()
        if cfg:
            lval_names: List[Optional[util.SymbolPath]] = []
            if isinstance(lval, TupleExpr):
                lval_names.extend(
                    [self.get_ref_name(item) for item in lval.items])

            else:
                lval_names.append(self.get_ref_name(lval))

            assert len(lval_names), lval

            rval_type = self.types[rval]
            rval_names: List[Optional[util.SymbolPath]] = []
            if isinstance(rval, CallExpr):
                # The RHS is either an object constructor or something that
                # returns a primitive type (e.g. Tuple[int, int] or str).
                # XXX: When we add inter-procedural analysis we should treat
                # these not as definitions but as some new kind of assignment.
                rval_names = [None] * len(lval_names)

            elif isinstance(rval, TupleExpr) and len(lval_names) == 1:
                # We're constructing a tuple. Since tuples have a fixed
                # (and usually small) size, we can name each of the
                # elements.
                base = lval_names[0]
                lval_names = [
                    base + (str(i), ) for i in range(len(rval.items))
                ]
                rval_names = [self.get_ref_name(item) for item in rval.items]

            elif isinstance(rval_type, TupleType):
                # We're unpacking a tuple. Like the tuple construction case,
                # give each element a name.
                rval_name = self.get_ref_name(rval)
                assert rval_name, rval
                rval_names = [
                    rval_name + (str(i), ) for i in range(len(lval_names))
                ]

            else:
                rval_names = [self.get_ref_name(rval)]

            assert len(rval_names) == len(lval_names)

            for lhs, rhs in zip(lval_names, rval_names):
                assert lhs, lval
                if rhs:
                    # In this case the RHS is another variable. Record the
                    # assignment so we can keep track of aliases.
                    cfg.AddFact(self.current_statement_id,
                                pass_state.Assignment(lhs, rhs))
                else:
                    # In this case the RHS is either some kind of literal (e.g.
                    # [] or 'foo') or a call to an object constructor. Mark this
                    # statement as an (re-)definition of a variable.
                    cfg.AddFact(
                        self.current_statement_id,
                        pass_state.Definition(
                            lhs, '$HeapObject(h{})'.format(self.heap_counter)),
                    )
                    self.heap_counter += 1

        # TODO: Could simplify this
        self.current_lval = lval
        self.accept(lval)
        self.current_lval = None

        self.accept(rval)

    # Expressions

    def oils_visit_member_expr(self, o: 'mypy.nodes.MemberExpr') -> None:
        """Visit the object of a member expression and record a Use of the
        member, unless it is a module member or the current assignment
        target."""
        self.accept(o.expr)
        cfg = self.current_cfg()
        if (cfg and
                not isinstance(self.dot_exprs[o], pass_state.ModuleMember) and
                o != self.current_lval):
            ref = self.get_ref_name(o)
            if ref:
                cfg.AddFact(self.current_statement_id, pass_state.Use(ref))

    def oils_visit_name_expr(self, o: 'mypy.nodes.NameExpr') -> None:
        """Record a Use of a local variable, unless the name is the current
        assignment target."""
        cfg = self.current_cfg()
        if cfg and o != self.current_lval:
            is_local = any(
                name == o.name
                for name, _ in self.local_vars.get(self.current_func_node, []))

            ref = self.get_ref_name(o)
            if ref and is_local:
                cfg.AddFact(self.current_statement_id, pass_state.Use(ref))

    def oils_visit_call_expr(
            self, o: 'mypy.nodes.CallExpr',
            current_class_name: Optional[util.SymbolPath]) -> None:
        """Record a call and then visit the callee and the arguments.

        Inside a function, when the callee can be resolved, it is remembered
        in self.callees, a FunctionCall fact is added, and a Bind fact is
        added for each argument that names an object (with its position).
        """
        cfg = self.current_cfg()
        if self.current_func_node:
            full_callee = self.resolve_callee(o, current_class_name)
            if full_callee:
                self.callees[o] = full_callee
                cfg.AddFact(
                    self.current_statement_id,
                    pass_state.FunctionCall(
                        SymbolToString(full_callee, delim='.')))

                for i, arg in enumerate(o.args):
                    arg_ref = self.get_ref_name(arg)
                    if arg_ref:
                        cfg.AddFact(self.current_statement_id,
                                    pass_state.Bind(arg_ref, full_callee, i))

        self.accept(o.callee)
        for arg in o.args:
            self.accept(arg)