OILS / mycpp / control_flow_pass.py View on Github | oils.pub

586 lines, 361 significant
1"""
2control_flow_pass.py - AST pass that builds a control flow graph.
3"""
4import collections
5
6import mypy
7from mypy.nodes import (Block, Expression, Statement, CallExpr, FuncDef,
8 IfStmt, NameExpr, MemberExpr, IndexExpr, TupleExpr,
9 IntExpr)
10
11from mypy.types import CallableType, Instance, Type, UnionType, NoneTyp, TupleType
12
13from mycpp.crash import catch_errors
14from mycpp.util import SymbolToString, SplitPyName
15from mycpp import visitor
16from mycpp import util
17from mycpp.util import SymbolPath
18from mycpp import pass_state
19
20from typing import Dict, List, Union, Optional, overload, TYPE_CHECKING
21
22if TYPE_CHECKING:
23 from mycpp import conversion_pass
24 from mycpp import cppgen_pass
25
26
def GetObjectTypeName(t: Type) -> SymbolPath:
    """Return the fully qualified, split name of the class denoted by ``t``.

    - For an Instance, this is the split ``fullname`` of its TypeInfo.
    - For a two-member UnionType (e.g. an Optional), this is the name of the
      first member that is not None; None may appear in either position.

    Raises:
      AssertionError: for any other kind of type.
    """
    if isinstance(t, Instance):
        return SplitPyName(t.type.fullname)

    if isinstance(t, UnionType):
        assert len(t.items) == 2, t
        # Optional[T] may be spelled with None first or second.
        if isinstance(t.items[0], NoneTyp):
            return GetObjectTypeName(t.items[1])
        return GetObjectTypeName(t.items[0])

    # Not 'assert False': asserts are stripped under 'python -O', which would
    # make this function silently return None.
    raise AssertionError(t)
39
40
# Sentinel for "no current statement": the value used before entering and
# after leaving a function, and while visiting module-level code (where no
# CFG statements are created). Statement IDs are expected to be positive and
# ID 0 is used for facts attached to a function's entry, so a negative value
# cannot collide with a real statement.
INVALID_ID: int = -99
42
43
class Build(visitor.TypedVisitor):
    """AST pass that builds a control flow graph for every function.

    One pass_state.ControlFlowGraph is kept per function, keyed by the
    function's SymbolPath.

    While visiting:
    - Each non-Block statement inside a function is registered with the
      current CFG.
    - Facts are attached to the current statement: Definition, Assignment,
      Use, FunctionCall and Bind.
    - The resolved callee of each call expression is recorded in
      ``self.callees``.
    - A ``break`` used directly inside a switch is reported as an error.
    """

    def __init__(self, types: Dict[Expression, Type],
                 virtual: pass_state.Virtual,
                 local_vars: 'cppgen_pass.AllLocalVars',
                 dot_exprs: 'conversion_pass.DotExprs') -> None:
        visitor.TypedVisitor.__init__(self, types)

        # Function name -> its CFG; a graph is created on first access.
        self.cflow_graphs: Dict[
            SymbolPath, pass_state.ControlFlowGraph] = collections.defaultdict(
                pass_state.ControlFlowGraph)
        # ID of the statement that new facts are attached to.
        self.current_statement_id = INVALID_ID
        self.current_func_node: Optional[FuncDef] = None
        # Contexts of the enclosing loops, innermost last; used to wire up
        # 'break' and 'continue'.
        self.loop_stack: List[pass_state.CfgLoopContext] = []
        self.virtual = virtual
        self.local_vars = local_vars
        self.dot_exprs = dot_exprs
        # Incremented for every Definition so each gets a distinct
        # '$HeapObject(hN)' name.
        self.heap_counter = 0
        # call expression -> SymbolPath of the callee
        self.callees: Dict['mypy.nodes.CallExpr', SymbolPath] = {}
        # LHS of the assignment being visited; it is not recorded as a Use.
        self.current_lval: Optional[Expression] = None

        # Used by visit_break_stmt to reject 'break' directly inside a switch.
        self.inside_switch = False
        self.inside_loop = False

    def current_cfg(self) -> Optional[pass_state.ControlFlowGraph]:
        """Return the CFG of the function being visited, or None outside one."""
        if not self.current_func_node:
            return None

        return self.cflow_graphs[SplitPyName(self.current_func_node.fullname)]

    def resolve_callee(
            self, o: CallExpr, current_class_name: Optional[util.SymbolPath]
    ) -> Optional[util.SymbolPath]:
        """
        Returns the fully qualified name of the callee in the given call
        expression, or None if the callee is something we don't track.

        Member functions are prefixed by the names of the classes that contain
        them. For example, the name of the callee in the last statement of the
        snippet below is `module.SomeObject.Foo`.

            x = module.SomeObject()
            x.Foo()

        For a receiver of Optional type, the non-None member of the union
        names the class.

        Free-functions defined in the local module are referred to by their
        normal fully qualified names. The function `foo` in a module called
        `moduleA` is named `moduleA.foo`. Calls to free-functions defined
        in imported modules are named the same way.
        """
        callee = o.callee

        if isinstance(callee, NameExpr):
            return SplitPyName(callee.fullname)

        if not isinstance(callee, MemberExpr):
            # Don't currently get here
            raise AssertionError()

        method = callee.name
        receiver = callee.expr

        if isinstance(receiver, NameExpr):
            if isinstance(self.dot_exprs.get(callee), pass_state.ModuleMember):
                return SplitPyName(receiver.fullname) + (method, )

            if receiver.name == 'self':
                assert current_class_name
                return current_class_name + (method, )

            local_type = next(
                (t for name, t in self.local_vars.get(self.current_func_node,
                                                      [])
                 if name == receiver.name), None)

            if not local_type:
                # context or exception handler. probably safe to ignore.
                return None

            if isinstance(local_type, str):
                return SplitPyName(local_type) + (method, )

            if isinstance(local_type, (Instance, UnionType)):
                # Append the *method* name. The union case used to append the
                # receiver variable's name instead.
                return GetObjectTypeName(local_type) + (method, )

            assert not isinstance(local_type, CallableType)
            # primitive type or string. don't care.
            return None

        receiver_type = self._GetType(receiver)
        if isinstance(receiver_type, (Instance, UnionType)):
            return GetObjectTypeName(receiver_type) + (method, )

        if receiver and getattr(receiver, 'fullname', None):
            return SplitPyName(receiver.fullname) + (method, )

        # constructors of things that we don't care about.
        return None

    def get_ref_name(self, expr: Expression) -> Optional[util.SymbolPath]:
        """
        To do dataflow analysis we need to track changes to objects, which
        requires naming them. This function returns the name of the object
        referred to by the given expression. If the expression doesn't refer to
        an object or variable it returns None.

        Objects are named slightly differently than they appear in the source
        code.

        Objects referenced by local variables are referred to by the name of the
        local. For example, the name of the object in both statements below is
        `x`.

            x = module.SomeObject()
            x = None

        Member expressions are named after the parent object's type. For
        example, the names of the objects in the member assignment statements
        below are both `module.SomeObject.member_a`. This makes it possible to
        track data flow across object members without having to track individual
        heap objects, which would increase the search space for analyses and
        slow things down.

            x = module.SomeObject()
            y = module.SomeObject()
            x.member_a = 'foo'
            y.member_a = 'bar'

        Index expressions are named after their bases, for the same reasons as
        member expressions. The coarse-grained precision should lead to an
        over-approximation of where objects are in use, but should not miss any
        references. This should be fine for our purposes. In the snippet below
        the last two assignments are named `x` and `module.SomeObject.a_list`.

            x = [None]  # list[Thing]
            y = module.SomeObject()
            x[0] = Thing()
            y.a_list[1] = Blah()

        Index expressions over tuples are treated differently, though. Tuples
        have a fixed size, tend to be small, and their elements have distinct
        types. So, each element can be (and probably needs to be) individually
        named. In the snippet below, the name of the RHS in the second
        assignment is `t.0`.

            t = (1, 2, 3, 4)
            x = t[0]

        The examples above all deal with assignments, but these rules apply to
        any expression that uses an object or variable.

        If the object or base of a member/index expression has no name itself
        (e.g. a call result), the expression has no name either.
        """
        if isinstance(expr, NameExpr):
            if expr.name in ('True', 'False', 'None'):
                return None
            return (expr.name, )

        if isinstance(expr, MemberExpr):
            dot_expr = self.dot_exprs[expr]
            if isinstance(dot_expr, pass_state.ModuleMember):
                return dot_expr.module_path + (dot_expr.member, )

            if isinstance(dot_expr, (pass_state.HeapObjectMember,
                                     pass_state.StackObjectMember)):
                obj_name = self.get_ref_name(dot_expr.object_expr)
                if obj_name:
                    # XXX: add a new case like pass_state.ExpressionMember for
                    # cases when the LHS of . isn't a reference (e.g.
                    # builtin/assign_osh.py:54)
                    return obj_name + (dot_expr.member, )

            return None

        if isinstance(expr, IndexExpr):
            base_name = self.get_ref_name(expr.base)
            if isinstance(self._GetType(expr.base), TupleType):
                assert isinstance(expr.index, IntExpr), expr.index
                if not base_name:
                    return None
                return base_name + (str(expr.index.value), )

            return base_name

        return None

    #
    # COPIED from IRBuilder
    #

    @overload
    def accept(self, node: Expression) -> None:
        ...

    @overload
    def accept(self, node: Statement) -> None:
        ...

    def accept(self, node: Union[Statement, Expression]) -> None:
        """Visit ``node``; inside a function, each non-Block statement gets
        a fresh CFG statement ID before its visitor runs."""
        with catch_errors(self.module_path, node.line):
            if isinstance(node, Expression):
                node.accept(self)
            else:
                cfg = self.current_cfg()
                # Most statements have empty visitors because they don't
                # require any special logic. Create statements for them
                # here. Don't create statements from blocks to avoid
                # stuttering.
                if cfg and not isinstance(node, Block):
                    self.current_statement_id = cfg.AddStatement()

                node.accept(self)

    # Statements

    def _visit_loop(self, head: Expression, body: Block) -> None:
        """Shared logic for 'for' and 'while'.

        ``head`` is the iterable or the condition. It is visited inside the
        loop context but before the loop is pushed, so 'break'/'continue' only
        bind to it from within ``body``.
        """
        cfg = self.current_cfg()
        was_inside_loop = self.inside_loop
        self.inside_loop = True
        with pass_state.CfgLoopContext(
                cfg, entry=self.current_statement_id) as loop:
            self.accept(head)
            self.loop_stack.append(loop)
            self.accept(body)
            self.loop_stack.pop()

        # Restore the enclosing scope's state (an outer loop keeps it on).
        self.inside_loop = was_inside_loop

    def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> None:
        self._visit_loop(o.expr, o.body)

    def _handle_switch(self, expr: Expression, o: 'mypy.nodes.WithStmt',
                       cfg: pass_state.ControlFlowGraph) -> None:
        """Add one CFG branch per case of a switch-like 'with' statement.

        The 'with' body must be a single if/elif chain, whose cases are
        collected by util.CollectSwitchCases. ``expr`` (the switch call
        itself) is not used here; the caller has already visited it.
        """
        assert len(o.body.body) == 1, o.body.body
        if_node = o.body.body[0]
        assert isinstance(if_node, IfStmt), if_node
        cases: util.CaseList = []
        default_block = util.CollectSwitchCases(self.module_path, if_node,
                                                cases)
        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            for case_expr, case_body in cases:
                assert case_expr is not None, case_expr
                self.accept(case_expr)
                with branch_ctx.AddBranch():
                    self.accept(case_body)

            # An int here means there is no default block (presumably a
            # sentinel returned by CollectSwitchCases).
            if not isinstance(default_block, int):
                with branch_ctx.AddBranch():
                    self.accept(default_block)

    def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> None:
        cfg = self.current_cfg()

        assert len(o.expr) == 1, o.expr
        expr = o.expr[0]

        assert isinstance(expr, CallExpr), expr
        self.accept(expr)

        # Note: we have 'with alloc.ctx_SourceCode', so the callee may be a
        # MemberExpr as well as a NameExpr.
        #assert isinstance(expr.callee, NameExpr), expr.callee
        callee_name = expr.callee.name

        was_inside_switch = self.inside_switch
        was_inside_loop = self.inside_loop
        if callee_name in ('switch', 'str_switch', 'tagswitch'):
            # Within the cases, a 'break' no longer targets an enclosing loop
            # as far as the C++ translation is concerned.
            self.inside_switch = True
            self.inside_loop = False
            self._handle_switch(expr, o, cfg)
        else:
            with pass_state.CfgBlockContext(cfg, self.current_statement_id):
                self.accept(o.body)

        # Restore the enclosing scope's state. inside_switch used to be
        # cleared unconditionally, which forgot an enclosing switch after any
        # nested 'with' and let a later 'break' in the same case go unreported.
        self.inside_switch = was_inside_switch
        self.inside_loop = was_inside_loop

    def oils_visit_func_def(self, o: 'mypy.nodes.FuncDef',
                            current_class_name: Optional[util.SymbolPath],
                            current_method_name: Optional[str]) -> None:
        # For virtual methods, pretend that the method on the base class calls
        # the same method on every subclass. This way call sites using the
        # abstract base class will over-approximate the set of call paths they
        # can take when checking if they can reach MaybeCollect().
        if current_class_name and self.virtual.IsVirtual(
                current_class_name, o.name):
            base = self.virtual.virtuals[(current_class_name, o.name)]
            if base:
                sub = SymbolToString(current_class_name + (o.name, ),
                                     delim='.')
                base_key = base[0] + (base[1], )
                self.cflow_graphs[base_key].AddFact(
                    0, pass_state.FunctionCall(sub))

        self.current_func_node = o
        cfg = self.current_cfg()
        # Parameters are defined at function entry (statement 0).
        for arg in o.arguments:
            cfg.AddFact(0,
                        pass_state.Definition((arg.variable.name, ), '$Empty'))

        self.accept(o.body)
        self.current_func_node = None
        self.current_statement_id = INVALID_ID

    def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> None:
        self._visit_loop(o.expr, o.body)

    def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> None:
        cfg = self.current_cfg()
        if cfg:
            cfg.AddDeadend(self.current_statement_id)

        if o.expr:
            self.accept(o.expr)

    def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> None:
        cfg = self.current_cfg()

        if util.ShouldVisitIfExpr(o):
            for expr in o.expr:
                self.accept(expr)

        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            if util.ShouldVisitIfBody(o):
                with branch_ctx.AddBranch():
                    for node in o.body:
                        self.accept(node)

            if util.ShouldVisitElseBody(o):
                with branch_ctx.AddBranch():
                    self.accept(o.else_body)

    def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> None:
        if self.inside_switch and not self.inside_loop:
            # In Python this would break out of the enclosing loop, but in C++
            # it only leaves the switch case, so disallow 'break' inside a
            # switch.
            self.report_error(
                o, "'break' is not allowed to be used inside a switch")
        if self.loop_stack:
            self.loop_stack[-1].AddBreak(self.current_statement_id)

    def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> None:
        if self.loop_stack:
            self.loop_stack[-1].AddContinue(self.current_statement_id)

    def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> None:
        cfg = self.current_cfg()
        if cfg:
            cfg.AddDeadend(self.current_statement_id)

        if o.expr:
            self.accept(o.expr)

    def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> None:
        """Model 'try' as a branch: the body, plus one branch per handler
        that starts from the body's exit."""
        # NOTE(review): o.else_body and o.finally_body are not visited here;
        # confirm that mycpp rejects those constructs elsewhere.
        cfg = self.current_cfg()
        with pass_state.CfgBranchContext(cfg,
                                         self.current_statement_id) as try_ctx:
            with try_ctx.AddBranch() as try_block:
                self.accept(o.body)

            for handler in o.handlers:
                with try_ctx.AddBranch(try_block.exit):
                    self.accept(handler)

    # 2024-12(andy): SimpleVisitor now has a special case for:
    #
    #   myvar = [x for x in other]
    #
    # This seems like it should affect the control flow graph, since we are no
    # longer calling oils_visit_assignment_stmt, and are instead calling
    # oils_visit_assign_to_listcomp.
    #
    # We may need more test coverage?
    # List comprehensions are arguably a weird/legacy part of the mycpp IR that
    # should be cleaned up.
    #
    # We do NOT allow:
    #
    #   myfunc([x for x in other])
    #
    # See mycpp/visitor.py and mycpp/cppgen_pass.py for how this is used.

    #def oils_visit_assign_to_listcomp(self, o: 'mypy.nodes.AssignmentStmt', lval: NameExpr) -> None:

    def oils_visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt',
                                   lval: Expression, rval: Expression,
                                   current_method_name: Optional[str],
                                   at_global_scope: bool) -> None:
        """Record one Assignment or Definition fact per assigned name, then
        visit both sides."""
        cfg = self.current_cfg()
        if cfg:
            lval_names: List[Optional[util.SymbolPath]]
            if isinstance(lval, TupleExpr):
                lval_names = [self.get_ref_name(item) for item in lval.items]
            else:
                lval_names = [self.get_ref_name(lval)]

            assert len(lval_names), lval

            rval_type = self._GetType(rval)
            rval_names: List[Optional[util.SymbolPath]]
            if isinstance(rval, CallExpr):
                # The RHS is either an object constructor or something that
                # returns a primitive type (e.g. Tuple[int, int] or str).
                # XXX: When we add inter-procedural analysis we should treat
                # these not as definitions but as some new kind of assignment.
                rval_names = [None] * len(lval_names)

            elif isinstance(rval, TupleExpr) and len(lval_names) == 1:
                # We're constructing a tuple. Since tuples have a fixed
                # (and usually small) size, we can name each of the
                # elements.
                base = lval_names[0]
                assert base, lval
                lval_names = [
                    base + (str(i), ) for i in range(len(rval.items))
                ]
                rval_names = [self.get_ref_name(item) for item in rval.items]

            elif (isinstance(rval, TupleExpr) and
                  len(rval.items) == len(lval_names)):
                # Parallel assignment, e.g. 'a, b = b, a': pair the items up.
                # This used to fall into the unpacking case below and fail
                # its assertion, since a tuple literal has no name.
                rval_names = [self.get_ref_name(item) for item in rval.items]

            elif isinstance(rval_type, TupleType):
                # We're unpacking a tuple. Like the tuple construction case,
                # give each element a name.
                rval_name = self.get_ref_name(rval)
                assert rval_name, rval
                rval_names = [
                    rval_name + (str(i), ) for i in range(len(lval_names))
                ]

            else:
                rval_names = [self.get_ref_name(rval)]

            assert len(rval_names) == len(lval_names)

            for lhs, rhs in zip(lval_names, rval_names):
                assert lhs, lval
                if rhs:
                    # In this case the RHS is another variable. Record the
                    # assignment so we can keep track of aliases.
                    cfg.AddFact(self.current_statement_id,
                                pass_state.Assignment(lhs, rhs))
                else:
                    # In this case the RHS is either some kind of literal (e.g.
                    # [] or 'foo') or a call to an object constructor. Mark this
                    # statement as an (re-)definition of a variable.
                    cfg.AddFact(
                        self.current_statement_id,
                        pass_state.Definition(
                            lhs, '$HeapObject(h{})'.format(self.heap_counter)),
                    )
                    self.heap_counter += 1

        # TODO: Could simplify this
        self.current_lval = lval
        self.accept(lval)
        self.current_lval = None

        self.accept(rval)

    # Expressions

    def oils_visit_member_expr(self, o: 'mypy.nodes.MemberExpr') -> None:
        """Record a Use of a non-module member access (except as the LHS of
        the current assignment)."""
        self.accept(o.expr)
        cfg = self.current_cfg()
        if (cfg and
                not isinstance(self.dot_exprs[o], pass_state.ModuleMember) and
                o is not self.current_lval):
            ref = self.get_ref_name(o)
            if ref:
                cfg.AddFact(self.current_statement_id, pass_state.Use(ref))

    def oils_visit_name_expr(self, o: 'mypy.nodes.NameExpr') -> None:
        """Record a Use of a local variable (except as the LHS of the current
        assignment)."""
        cfg = self.current_cfg()
        if cfg and o is not self.current_lval:
            is_local = any(
                name == o.name
                for name, _ in self.local_vars.get(self.current_func_node, []))

            ref = self.get_ref_name(o)
            if ref and is_local:
                cfg.AddFact(self.current_statement_id, pass_state.Use(ref))

    def oils_visit_call_expr(
            self, o: 'mypy.nodes.CallExpr',
            current_class_name: Optional[util.SymbolPath]) -> None:
        """Record the callee, a FunctionCall fact, and a Bind fact for each
        named argument, then visit the callee and arguments."""
        cfg = self.current_cfg()
        if self.current_func_node:
            full_callee = self.resolve_callee(o, current_class_name)
            if full_callee:
                self.callees[o] = full_callee
                cfg.AddFact(
                    self.current_statement_id,
                    pass_state.FunctionCall(
                        SymbolToString(full_callee, delim='.')))

                for i, arg in enumerate(o.args):
                    arg_ref = self.get_ref_name(arg)
                    if arg_ref:
                        cfg.AddFact(self.current_statement_id,
                                    pass_state.Bind(arg_ref, full_callee, i))

        self.accept(o.callee)
        for arg in o.args:
            self.accept(arg)