OILS / mycpp / control_flow_pass.py View on Github | oils.pub

555 lines, 339 significant
1"""
2control_flow_pass.py - AST pass that builds a control flow graph.
3"""
4import collections
5
6import mypy
7from mypy.nodes import (Block, Expression, Statement, CallExpr, FuncDef,
8 IfStmt, NameExpr, MemberExpr, IndexExpr, TupleExpr,
9 IntExpr)
10
11from mypy.types import CallableType, Instance, Type, UnionType, NoneTyp, TupleType
12
13from mycpp.crash import catch_errors
14from mycpp.util import SymbolToString, SplitPyName
15from mycpp import visitor
16from mycpp import util
17from mycpp.util import SymbolPath
18from mycpp import pass_state
19
20from typing import Dict, List, Union, Optional, overload, TYPE_CHECKING
21
22if TYPE_CHECKING:
23 from mycpp import conversion_pass
24 from mycpp import cppgen_pass
25
26
def GetObjectTypeName(t: Type) -> SymbolPath:
    """
    Return the fully qualified name of the class behind the given type.

    An Instance is named after its class. A two-member union is named after
    whichever member is not None (the first member if neither is None). Any
    other kind of type trips an assertion.
    """
    if isinstance(t, Instance):
        return SplitPyName(t.type.fullname)

    if isinstance(t, UnionType):
        assert len(t.items) == 2
        none_comes_first = isinstance(t.items[0], NoneTyp)
        return GetObjectTypeName(
            t.items[1] if none_comes_first else t.items[0])

    assert False, t
39
40
# Sentinel stored in Build.current_statement_id while no statement is being
# visited (initially, and again after each function body has been processed).
INVALID_ID = -99  # statement IDs are positive
42
43
class Build(visitor.SimpleVisitor):
    """
    AST visitor that builds a control flow graph for every function it visits.

    Graphs are stored in `cflow_graphs`, keyed by the SymbolPath of the
    function. While walking a function body the visitor adds statements to the
    graph and attaches facts to them (Definition, Assignment, Use,
    FunctionCall and Bind from pass_state).
    """

    def __init__(self, types: Dict[Expression, Type],
                 virtual: pass_state.Virtual,
                 local_vars: 'cppgen_pass.AllLocalVars',
                 dot_exprs: 'conversion_pass.DotExprs') -> None:
        """
        Args:
          types: type of each expression, used to name callees and tuples.
          virtual: virtual method table; see oils_visit_func_def().
          local_vars: per-function (name, type) pairs of local variables.
          dot_exprs: classification of each MemberExpr (module member, heap
            object member, stack object member).
        """
        visitor.SimpleVisitor.__init__(self)

        self.types = types
        # Graphs are created on first access.
        self.cflow_graphs: Dict[
            SymbolPath, pass_state.ControlFlowGraph] = collections.defaultdict(
                pass_state.ControlFlowGraph)
        self.current_statement_id = INVALID_ID
        self.current_func_node: Optional[FuncDef] = None
        # Innermost enclosing loop is last; used by break/continue.
        self.loop_stack: List[pass_state.CfgLoopContext] = []
        self.virtual = virtual
        self.local_vars = local_vars
        self.dot_exprs = dot_exprs
        # Used to give each Definition a distinct '$HeapObject(hN)' label.
        self.heap_counter = 0
        # call expression -> SymbolPath of the callee
        self.callees: Dict['mypy.nodes.CallExpr', SymbolPath] = {}
        # The assignment target currently being visited, so that it is not
        # also recorded as a Use.
        self.current_lval: Optional[Expression] = None

    def current_cfg(self) -> pass_state.ControlFlowGraph:
        """
        Returns the graph of the function being visited, or None when the
        visitor is not inside a function.
        """
        if not self.current_func_node:
            return None

        return self.cflow_graphs[SplitPyName(self.current_func_node.fullname)]

    def resolve_callee(
        self, o: CallExpr, current_class_name: Optional[util.SymbolPath]
    ) -> Optional[util.SymbolPath]:
        """
        Returns the fully qualified name of the callee in the given call
        expression.

        Member functions are prefixed by the names of the classes that contain
        them. For example, the name of the callee in the last statement of the
        snippet below is `module.SomeObject.Foo`.

            x = module.SomeObject()
            x.Foo()

        Free-functions defined in the local module are referred to by their
        normal fully qualified names. The function `foo` in a module called
        `moduleA` is named `moduleA.foo`. Calls to free-functions defined
        in imported modules are named the same way.

        Returns None for callees that are not tracked (e.g. methods on
        primitive values). Raises AssertionError if the callee is neither a
        NameExpr nor a MemberExpr.
        """
        callee = o.callee

        if isinstance(callee, NameExpr):
            return SplitPyName(callee.fullname)

        elif isinstance(callee, MemberExpr):
            method = (callee.name, )

            if isinstance(callee.expr, NameExpr):
                receiver = callee.expr
                is_module = isinstance(self.dot_exprs.get(callee),
                                       pass_state.ModuleMember)
                if is_module:
                    return SplitPyName(receiver.fullname) + method

                elif receiver.name == 'self':
                    assert current_class_name
                    return current_class_name + method

                else:
                    # The receiver is a plain name: look up its type among
                    # the locals of the current function.
                    local_type = next(
                        (t for name, t in self.local_vars.get(
                            self.current_func_node, [])
                         if name == receiver.name), None)

                    if local_type:
                        if isinstance(local_type, str):
                            return SplitPyName(local_type) + method

                        elif isinstance(local_type, Instance):
                            return (SplitPyName(local_type.type.fullname) +
                                    method)

                        elif isinstance(local_type, UnionType):
                            # Name the callee after the method, not after the
                            # receiver variable. GetObjectTypeName() picks
                            # the non-None member whichever side it is on.
                            return GetObjectTypeName(local_type) + method

                        else:
                            assert not isinstance(local_type, CallableType)
                            # primitive type or string. don't care.
                            return None

                    else:
                        # context or exception handler. probably safe to
                        # ignore.
                        return None

            else:
                t = self.types.get(callee.expr)
                if isinstance(t, Instance):
                    return SplitPyName(t.type.fullname) + method

                elif isinstance(t, UnionType):
                    return GetObjectTypeName(t) + method

                elif callee.expr and getattr(callee.expr, 'fullname', None):
                    return SplitPyName(callee.expr.fullname) + method

                else:
                    # constructors of things that we don't care about.
                    return None

        # Don't currently get here
        raise AssertionError()

    def get_ref_name(self, expr: Expression) -> Optional[util.SymbolPath]:
        """
        To do dataflow analysis we need to track changes to objects, which
        requires naming them. This function returns the name of the object
        referred to by the given expression. If the expression doesn't refer to
        an object or variable it returns None.

        Objects are named slightly differently than they appear in the source
        code.

        Objects referenced by local variables are referred to by the name of the
        local. For example, the name of the object in both statements below is
        `x`.

            x = module.SomeObject()
            x = None

        Member expressions are named after the parent object's type. For
        example, the names of the objects in the member assignment statements
        below are both `module.SomeObject.member_a`. This makes it possible to
        track data flow across object members without having to track individual
        heap objects, which would increase the search space for analyses and
        slow things down.

            x = module.SomeObject()
            y = module.SomeObject()
            x.member_a = 'foo'
            y.member_a = 'bar'

        Index expressions are named after their bases, for the same reasons as
        member expressions. The coarse-grained precision should lead to an
        over-approximation of where objects are in use, but should not miss any
        references. This should be fine for our purposes. In the snippet below
        the last two assignments are named `x` and `module.SomeObject.a_list`.

            x = [None] # list[Thing]
            y = module.SomeObject()
            x[0] = Thing()
            y.a_list[1] = Blah()

        Index expressions over tuples are treated differently, though. Tuples
        have a fixed size, tend to be small, and their elements have distinct
        types. So, each element can be (and probably needs to be) individually
        named. In the snippet below, the name of the RHS in the second
        assignment is `t.0`.

            t = (1, 2, 3, 4)
            x = t[0]

        The examples above all deal with assignments, but these rules apply to
        any expression that uses an object or variable.
        """
        if (isinstance(expr, NameExpr) and
                expr.name not in {'True', 'False', 'None'}):
            return (expr.name, )

        elif isinstance(expr, MemberExpr):
            dot_expr = self.dot_exprs[expr]
            if isinstance(dot_expr, pass_state.ModuleMember):
                return dot_expr.module_path + (dot_expr.member, )

            elif isinstance(dot_expr, (pass_state.HeapObjectMember,
                                       pass_state.StackObjectMember)):
                obj_name = self.get_ref_name(dot_expr.object_expr)
                if obj_name:
                    # XXX: add a new case like pass_state.ExpressionMember for
                    # cases when the LHS of . isn't a reference (e.g.
                    # builtin/assign_osh.py:54)
                    return obj_name + (dot_expr.member, )

        elif isinstance(expr, IndexExpr):
            if isinstance(self.types[expr.base], TupleType):
                assert isinstance(expr.index, IntExpr)
                base_name = self.get_ref_name(expr.base)
                if base_name:
                    return base_name + (str(expr.index.value), )

                # The tuple itself has no name, so neither do its elements.
                return None

            return self.get_ref_name(expr.base)

        return None

    #
    # COPIED from IRBuilder
    #

    @overload
    def accept(self, node: Expression) -> None:
        ...

    @overload
    def accept(self, node: Statement) -> None:
        ...

    def accept(self, node: Union[Statement, Expression]) -> None:
        """
        Dispatch to the visit method for the given node.

        For statements (other than blocks) inside a function, a new statement
        is added to the current graph first and becomes the current statement.
        """
        with catch_errors(self.module_path, node.line):
            if isinstance(node, Expression):
                node.accept(self)
            else:
                cfg = self.current_cfg()
                # Most statements have empty visitors because they don't
                # require any special logic. Create statements for them
                # here. Don't create statements from blocks to avoid
                # stuttering.
                if cfg and not isinstance(node, Block):
                    self.current_statement_id = cfg.AddStatement()

                node.accept(self)

    # Statements

    def _visit_loop(self, header: Expression, body: Block) -> None:
        """Visit a loop's header expression and body inside a loop context."""
        cfg = self.current_cfg()
        with pass_state.CfgLoopContext(
                cfg, entry=self.current_statement_id) as loop:
            self.accept(header)
            self.loop_stack.append(loop)
            try:
                self.accept(body)
            finally:
                # Keep the stack balanced even if visiting the body fails.
                self.loop_stack.pop()

    def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> None:
        self._visit_loop(o.expr, o.body)

    def _handle_switch(self, expr: Expression, o: 'mypy.nodes.WithStmt',
                       cfg: pass_state.ControlFlowGraph) -> None:
        """
        Visit a `with switch(...)`-style statement, whose body must be a single
        if statement. Each case (and the default block, if any) becomes a
        branch of one CfgBranchContext.
        """
        assert len(o.body.body) == 1, o.body.body
        if_node = o.body.body[0]
        assert isinstance(if_node, IfStmt), if_node
        cases: util.CaseList = []
        default_block = util.CollectSwitchCases(self.module_path, if_node,
                                                cases)
        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            for case_expr, case_body in cases:
                assert case_expr is not None, case_expr
                self.accept(case_expr)
                with branch_ctx.AddBranch():
                    self.accept(case_body)

            # An int in place of a block means there is no default case to
            # visit.
            if not isinstance(default_block, int):
                with branch_ctx.AddBranch():
                    self.accept(default_block)

    def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> None:
        cfg = self.current_cfg()

        assert len(o.expr) == 1, o.expr
        expr = o.expr[0]

        assert isinstance(expr, CallExpr), expr
        self.accept(expr)

        # Note: we have 'with alloc.ctx_SourceCode'
        #assert isinstance(expr.callee, NameExpr), expr.callee
        callee_name = expr.callee.name

        if callee_name in ('switch', 'str_switch', 'tagswitch'):
            self._handle_switch(expr, o, cfg)
        else:
            with pass_state.CfgBlockContext(cfg, self.current_statement_id):
                self.accept(o.body)

    def oils_visit_func_def(self, o: 'mypy.nodes.FuncDef',
                            current_class_name: Optional[util.SymbolPath],
                            current_method_name: Optional[str]) -> None:
        # For virtual methods, pretend that the method on the base class calls
        # the same method on every subclass. This way call sites using the
        # abstract base class will over-approximate the set of call paths they
        # can take when checking if they can reach MaybeCollect().
        if current_class_name and self.virtual.IsVirtual(
                current_class_name, o.name):
            key = (current_class_name, o.name)
            base = self.virtual.virtuals[key]
            if base:
                sub = SymbolToString(current_class_name + (o.name, ),
                                     delim='.')
                base_key = base[0] + (base[1], )
                cfg = self.cflow_graphs[base_key]
                cfg.AddFact(0, pass_state.FunctionCall(sub))

        self.current_func_node = o
        cfg = self.current_cfg()
        # Every parameter is defined at statement 0.
        for arg in o.arguments:
            cfg.AddFact(0,
                        pass_state.Definition((arg.variable.name, ), '$Empty'))

        self.accept(o.body)
        self.current_func_node = None
        self.current_statement_id = INVALID_ID

    def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> None:
        self._visit_loop(o.expr, o.body)

    def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> None:
        cfg = self.current_cfg()
        if cfg:
            cfg.AddDeadend(self.current_statement_id)

        if o.expr:
            self.accept(o.expr)

    def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> None:
        cfg = self.current_cfg()

        if util.ShouldVisitIfExpr(o):
            for expr in o.expr:
                self.accept(expr)

        with pass_state.CfgBranchContext(
                cfg, self.current_statement_id) as branch_ctx:
            if util.ShouldVisitIfBody(o):
                with branch_ctx.AddBranch():
                    for node in o.body:
                        self.accept(node)

            if util.ShouldVisitElseBody(o):
                with branch_ctx.AddBranch():
                    self.accept(o.else_body)

    def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> None:
        if self.loop_stack:
            self.loop_stack[-1].AddBreak(self.current_statement_id)

    def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> None:
        if self.loop_stack:
            self.loop_stack[-1].AddContinue(self.current_statement_id)

    def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> None:
        cfg = self.current_cfg()
        if cfg:
            cfg.AddDeadend(self.current_statement_id)

        if o.expr:
            self.accept(o.expr)

    def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> None:
        cfg = self.current_cfg()
        with pass_state.CfgBranchContext(cfg,
                                         self.current_statement_id) as try_ctx:
            with try_ctx.AddBranch() as try_block:
                self.accept(o.body)

            # Each handler is a branch attached to the exit of the try block.
            for handler in o.handlers:
                with try_ctx.AddBranch(try_block.exit):
                    self.accept(handler)

    # 2024-12(andy): SimpleVisitor now has a special case for:
    #
    #    myvar = [x for x in other]
    #
    # This seems like it should affect the control flow graph, since we are no
    # longer calling oils_visit_assignment_stmt, and are instead calling
    # oils_visit_assign_to_listcomp.
    #
    # We may need more test coverage?
    # List comprehensions are arguably a weird/legacy part of the mycpp IR that
    # should be cleaned up.
    #
    # We do NOT allow:
    #
    #    myfunc([x for x in other])
    #
    # See mycpp/visitor.py and mycpp/cppgen_pass.py for how this is used.

    #def oils_visit_assign_to_listcomp(self, o: 'mypy.nodes.AssignmentStmt', lval: NameExpr) -> None:

    def oils_visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt',
                                   lval: Expression, rval: Expression,
                                   current_method_name: Optional[str],
                                   at_global_scope: bool) -> None:
        """
        Record an Assignment fact (RHS is a named reference) or a Definition
        fact (anything else) for every name on the LHS, then visit both sides.
        """
        cfg = self.current_cfg()
        if cfg:
            lval_names: List[Optional[util.SymbolPath]] = []
            if isinstance(lval, TupleExpr):
                lval_names.extend(
                    [self.get_ref_name(item) for item in lval.items])

            else:
                lval_names.append(self.get_ref_name(lval))

            assert lval_names, lval

            rval_type = self.types[rval]
            rval_names: List[Optional[util.SymbolPath]] = []
            if isinstance(rval, CallExpr):
                # The RHS is either an object constructor or something that
                # returns a primitive type (e.g. Tuple[int, int] or str).
                # XXX: When we add inter-procedural analysis we should treat
                # these not as definitions but as some new kind of assignment.
                rval_names = [None for _ in lval_names]

            elif isinstance(rval, TupleExpr) and len(lval_names) == 1:
                # We're constructing a tuple. Since tuples have a fixed
                # (and usually small) size, we can name each of the
                # elements.
                base = lval_names[0]
                assert base, lval
                lval_names = [
                    base + (str(i), ) for i in range(len(rval.items))
                ]
                rval_names = [self.get_ref_name(item) for item in rval.items]

            elif isinstance(rval_type, TupleType):
                # We're unpacking a tuple. Like the tuple construction case,
                # give each element a name.
                rval_name = self.get_ref_name(rval)
                assert rval_name, rval
                rval_names = [
                    rval_name + (str(i), ) for i in range(len(lval_names))
                ]

            else:
                rval_names = [self.get_ref_name(rval)]

            assert len(rval_names) == len(lval_names)

            for lhs, rhs in zip(lval_names, rval_names):
                assert lhs, lval
                if rhs:
                    # In this case the RHS is another variable. Record the
                    # assignment so we can keep track of aliases.
                    cfg.AddFact(self.current_statement_id,
                                pass_state.Assignment(lhs, rhs))
                else:
                    # In this case the RHS is either some kind of literal (e.g.
                    # [] or 'foo') or a call to an object constructor. Mark this
                    # statement as an (re-)definition of a variable.
                    cfg.AddFact(
                        self.current_statement_id,
                        pass_state.Definition(
                            lhs, '$HeapObject(h{})'.format(self.heap_counter)),
                    )
                    self.heap_counter += 1

        # TODO: Could simplify this
        self.current_lval = lval
        self.accept(lval)
        self.current_lval = None

        self.accept(rval)

    # Expressions

    def oils_visit_member_expr(self, o: 'mypy.nodes.MemberExpr') -> None:
        self.accept(o.expr)
        cfg = self.current_cfg()
        # Module members and the current assignment target are not uses.
        if (cfg and
                not isinstance(self.dot_exprs[o], pass_state.ModuleMember) and
                o != self.current_lval):
            ref = self.get_ref_name(o)
            if ref:
                cfg.AddFact(self.current_statement_id, pass_state.Use(ref))

    def oils_visit_name_expr(self, o: 'mypy.nodes.NameExpr') -> None:
        cfg = self.current_cfg()
        if cfg and o != self.current_lval:
            # Only names of locals of the current function count as uses.
            is_local = any(
                name == o.name
                for name, _ in self.local_vars.get(self.current_func_node, []))

            ref = self.get_ref_name(o)
            if ref and is_local:
                cfg.AddFact(self.current_statement_id, pass_state.Use(ref))

    def oils_visit_call_expr(
            self, o: 'mypy.nodes.CallExpr',
            current_class_name: Optional[util.SymbolPath]) -> None:
        cfg = self.current_cfg()
        if self.current_func_node:
            full_callee = self.resolve_callee(o, current_class_name)
            if full_callee:
                self.callees[o] = full_callee
                cfg.AddFact(
                    self.current_statement_id,
                    pass_state.FunctionCall(
                        SymbolToString(full_callee, delim='.')))

                # Record which named references are passed at which argument
                # position.
                for i, arg in enumerate(o.args):
                    arg_ref = self.get_ref_name(arg)
                    if arg_ref:
                        cfg.AddFact(self.current_statement_id,
                                    pass_state.Bind(arg_ref, full_callee, i))

        self.accept(o.callee)
        for arg in o.args:
            self.accept(arg)