OILS / mycpp / control_flow_pass.py View on Github | oils.pub

548 lines, 332 significant
1"""
2control_flow_pass.py - AST pass that builds a control flow graph.
3"""
4import collections
5
6import mypy
7from mypy.nodes import (Block, Expression, Statement, CallExpr, FuncDef,
8 IfStmt, NameExpr, MemberExpr, IndexExpr, TupleExpr,
9 IntExpr)
10
11from mypy.types import CallableType, Instance, Type, UnionType, NoneTyp, TupleType
12
13from mycpp.crash import catch_errors
14from mycpp.util import SymbolToString, SplitPyName
15from mycpp import visitor
16from mycpp import util
17from mycpp.util import SymbolPath
18from mycpp import pass_state
19
20from typing import Dict, List, Union, Optional, overload, TYPE_CHECKING
21
22if TYPE_CHECKING:
23 from mycpp import conversion_pass
24 from mycpp import cppgen_pass
25
26
def GetObjectTypeName(t: Type) -> SymbolPath:
    """Return the split fully-qualified name of the class behind type `t`.

    An Instance yields the name of its class.  A two-item UnionType is
    resolved through one of its items: the second item when the first is
    NoneTyp, otherwise the first.  Any other kind of type trips an assertion.
    """
    if isinstance(t, Instance):
        return SplitPyName(t.type.fullname)

    if isinstance(t, UnionType):
        assert len(t.items) == 2
        first = t.items[0]
        # Skip over the None half of the union when it comes first.
        chosen = t.items[1] if isinstance(first, NoneTyp) else first
        return GetObjectTypeName(chosen)

    assert False, t
39
40
# Sentinel stored in Build.current_statement_id when no statement is being
# visited (initially, and again after each function body has been walked).
INVALID_ID = -99  # statement IDs are positive
42
43
44class Build(visitor.SimpleVisitor):
45
46 def __init__(self, types: Dict[Expression,
47 Type], virtual: pass_state.Virtual,
48 local_vars: 'cppgen_pass.AllLocalVars',
49 dot_exprs: 'conversion_pass.DotExprs') -> None:
50 visitor.SimpleVisitor.__init__(self)
51
52 self.types = types
53 self.cflow_graphs: Dict[
54 SymbolPath, pass_state.ControlFlowGraph] = collections.defaultdict(
55 pass_state.ControlFlowGraph)
56 self.current_statement_id = INVALID_ID
57 self.current_class_name: Optional[SymbolPath] = None
58 self.current_func_node: Optional[FuncDef] = None
59 self.loop_stack: List[pass_state.CfgLoopContext] = []
60 self.virtual = virtual
61 self.local_vars = local_vars
62 self.dot_exprs = dot_exprs
63 self.heap_counter = 0
64 # statement object -> SymbolPath of the callee
65 self.callees: Dict['mypy.nodes.CallExpr', SymbolPath] = {}
66 self.current_lval: Optional[Expression] = None
67
68 def current_cfg(self) -> pass_state.ControlFlowGraph:
69 if not self.current_func_node:
70 return None
71
72 return self.cflow_graphs[SplitPyName(self.current_func_node.fullname)]
73
74 def resolve_callee(self, o: CallExpr) -> Optional[util.SymbolPath]:
75 """
76 Returns the fully qualified name of the callee in the given call
77 expression.
78
79 Member functions are prefixed by the names of the classes that contain
80 them. For example, the name of the callee in the last statement of the
81 snippet below is `module.SomeObject.Foo`.
82
83 x = module.SomeObject()
84 x.Foo()
85
86 Free-functions defined in the local module are referred to by their
87 normal fully qualified names. The function `foo` in a module called
88 `moduleA` would is named `moduleA.foo`. Calls to free-functions defined
89 in imported modules are named the same way.
90 """
91
92 if isinstance(o.callee, NameExpr):
93 return SplitPyName(o.callee.fullname)
94
95 elif isinstance(o.callee, MemberExpr):
96 if isinstance(o.callee.expr, NameExpr):
97 is_module = isinstance(self.dot_exprs.get(o.callee),
98 pass_state.ModuleMember)
99 if is_module:
100 return (SplitPyName(o.callee.expr.fullname) +
101 (o.callee.name, ))
102
103 elif o.callee.expr.name == 'self':
104 assert self.current_class_name
105 return self.current_class_name + (o.callee.name, )
106
107 else:
108 local_type = None
109 for name, t in self.local_vars.get(self.current_func_node,
110 []):
111 if name == o.callee.expr.name:
112 local_type = t
113 break
114
115 if local_type:
116 if isinstance(local_type, str):
117 return (SplitPyName(local_type) +
118 (o.callee.name, ))
119
120 elif isinstance(local_type, Instance):
121 return (SplitPyName(local_type.type.fullname) +
122 (o.callee.name, ))
123
124 elif isinstance(local_type, UnionType):
125 assert len(local_type.items) == 2
126 return (SplitPyName(
127 local_type.items[0].type.fullname) +
128 (o.callee.expr.name, ))
129
130 else:
131 assert not isinstance(local_type, CallableType)
132 # primitive type or string. don't care.
133 return None
134
135 else:
136 # context or exception handler. probably safe to ignore.
137 return None
138
139 else:
140 t = self.types.get(o.callee.expr)
141 if isinstance(t, Instance):
142 return SplitPyName(t.type.fullname) + (o.callee.name, )
143
144 elif isinstance(t, UnionType):
145 assert len(t.items) == 2
146 return (SplitPyName(t.items[0].type.fullname) +
147 (o.callee.name, ))
148
149 elif o.callee.expr and getattr(o.callee.expr, 'fullname',
150 None):
151 return (SplitPyName(o.callee.expr.fullname) +
152 (o.callee.name, ))
153
154 else:
155 # constructors of things that we don't care about.
156 return None
157
158 # Don't currently get here
159 raise AssertionError()
160
161 def get_ref_name(self, expr: Expression) -> Optional[util.SymbolPath]:
162 """
163 To do dataflow analysis we need to track changes to objects, which
164 requires naming them. This function returns the name of the object
165 referred to by the given expression. If the expression doesn't refer to
166 an object or variable it returns None.
167
168 Objects are named slightly differently than they appear in the source
169 code.
170
171 Objects referenced by local variables are referred to by the name of the
172 local. For example, the name of the object in both statements below is
173 `x`.
174
175 x = module.SomeObject()
176 x = None
177
178 Member expressions are named after the parent object's type. For
179 example, the names of the objects in the member assignment statements
180 below are both `module.SomeObject.member_a`. This makes it possible to
181 track data flow across object members without having to track individual
182 heap objects, which would increase the search space for analyses and
183 slow things down.
184
185 x = module.SomeObject()
186 y = module.SomeObject()
187 x.member_a = 'foo'
188 y.member_a = 'bar'
189
190 Index expressions are named after their bases, for the same reasons as
191 member expressions. The coarse-grained precision should lead to an
192 over-approximation of where objects are in use, but should not miss any
193 references. This should be fine for our purposes. In the snippet below
194 the last two assignments are named `x` and `module.SomeObject.a_list`.
195
196 x = [None] # list[Thing]
197 y = module.SomeObject()
198 x[0] = Thing()
199 y.a_list[1] = Blah()
200
201 Index expressions over tuples are treated differently, though. Tuples
202 have a fixed size, tend to be small, and their elements have distinct
203 types. So, each element can be (and probably needs to be) individually
204 named. In the snippet below, the name of the RHS in the second
205 assignment is `t.0`.
206
207 t = (1, 2, 3, 4)
208 x = t[0]
209
210 The examples above all deal with assignments, but these rules apply to
211 any expression that uses an object or variable.
212 """
213 if isinstance(expr,
214 NameExpr) and expr.name not in {'True', 'False', 'None'}:
215 return (expr.name, )
216
217 elif isinstance(expr, MemberExpr):
218 dot_expr = self.dot_exprs[expr]
219 if isinstance(dot_expr, pass_state.ModuleMember):
220 return dot_expr.module_path + (dot_expr.member, )
221
222 elif isinstance(dot_expr, pass_state.HeapObjectMember):
223 obj_name = self.get_ref_name(dot_expr.object_expr)
224 if obj_name:
225 # XXX: add a new case like pass_state.ExpressionMember for
226 # cases when the LHS of . isn't a reference (e.g.
227 # builtin/assign_osh.py:54)
228 return obj_name + (dot_expr.member, )
229
230 elif isinstance(dot_expr, pass_state.StackObjectMember):
231 return (self.get_ref_name(dot_expr.object_expr) +
232 (dot_expr.member, ))
233
234 elif isinstance(expr, IndexExpr):
235 if isinstance(self.types[expr.base], TupleType):
236 assert isinstance(expr.index, IntExpr)
237 return self.get_ref_name(expr.base) + (str(expr.index.value), )
238
239 return self.get_ref_name(expr.base)
240
241 return None
242
243 #
244 # COPIED from IRBuilder
245 #
246
247 @overload
248 def accept(self, node: Expression) -> None:
249 ...
250
251 @overload
252 def accept(self, node: Statement) -> None:
253 ...
254
255 def accept(self, node: Union[Statement, Expression]) -> None:
256 with catch_errors(self.module_path, node.line):
257 if isinstance(node, Expression):
258 node.accept(self)
259 else:
260 cfg = self.current_cfg()
261 # Most statements have empty visitors because they don't
262 # require any special logic. Create statements for them
263 # here. Don't create statements from blocks to avoid
264 # stuttering.
265 if cfg and not isinstance(node, Block):
266 self.current_statement_id = cfg.AddStatement()
267
268 node.accept(self)
269
270 # Statements
271
272 def visit_for_stmt(self, o: 'mypy.nodes.ForStmt') -> None:
273 cfg = self.current_cfg()
274 with pass_state.CfgLoopContext(
275 cfg, entry=self.current_statement_id) as loop:
276 self.accept(o.expr)
277 self.loop_stack.append(loop)
278 self.accept(o.body)
279 self.loop_stack.pop()
280
281 def _handle_switch(self, expr: Expression, o: 'mypy.nodes.WithStmt',
282 cfg: pass_state.ControlFlowGraph) -> None:
283 assert len(o.body.body) == 1, o.body.body
284 if_node = o.body.body[0]
285 assert isinstance(if_node, IfStmt), if_node
286 cases: util.CaseList = []
287 default_block = util.CollectSwitchCases(self.module_path, if_node,
288 cases)
289 with pass_state.CfgBranchContext(
290 cfg, self.current_statement_id) as branch_ctx:
291 for expr, body in cases:
292 self.accept(expr)
293 assert expr is not None, expr
294 with branch_ctx.AddBranch():
295 self.accept(body)
296
297 if not isinstance(default_block, int):
298 with branch_ctx.AddBranch():
299 self.accept(default_block)
300
301 def visit_with_stmt(self, o: 'mypy.nodes.WithStmt') -> None:
302 cfg = self.current_cfg()
303
304 assert len(o.expr) == 1, o.expr
305 expr = o.expr[0]
306
307 assert isinstance(expr, CallExpr), expr
308 self.accept(expr)
309
310 # Note: we have 'with alloc.ctx_SourceCode'
311 #assert isinstance(expr.callee, NameExpr), expr.callee
312 callee_name = expr.callee.name
313
314 if callee_name == 'switch':
315 self._handle_switch(expr, o, cfg)
316 elif callee_name == 'str_switch':
317 self._handle_switch(expr, o, cfg)
318 elif callee_name == 'tagswitch':
319 self._handle_switch(expr, o, cfg)
320 else:
321 with pass_state.CfgBlockContext(cfg, self.current_statement_id):
322 self.accept(o.body)
323
324 def oils_visit_func_def(self, o: 'mypy.nodes.FuncDef') -> None:
325 # For virtual methods, pretend that the method on the base class calls
326 # the same method on every subclass. This way call sites using the
327 # abstract base class will over-approximate the set of call paths they
328 # can take when checking if they can reach MaybeCollect().
329 if self.current_class_name and self.virtual.IsVirtual(
330 self.current_class_name, o.name):
331 key = (self.current_class_name, o.name)
332 base = self.virtual.virtuals[key]
333 if base:
334 sub = SymbolToString(self.current_class_name + (o.name, ),
335 delim='.')
336 base_key = base[0] + (base[1], )
337 cfg = self.cflow_graphs[base_key]
338 cfg.AddFact(0, pass_state.FunctionCall(sub))
339
340 self.current_func_node = o
341 cfg = self.current_cfg()
342 for arg in o.arguments:
343 cfg.AddFact(0,
344 pass_state.Definition((arg.variable.name, ), '$Empty'))
345
346 self.accept(o.body)
347 self.current_func_node = None
348 self.current_statement_id = INVALID_ID
349
350 def visit_while_stmt(self, o: 'mypy.nodes.WhileStmt') -> None:
351 cfg = self.current_cfg()
352 with pass_state.CfgLoopContext(
353 cfg, entry=self.current_statement_id) as loop:
354 self.accept(o.expr)
355 self.loop_stack.append(loop)
356 self.accept(o.body)
357 self.loop_stack.pop()
358
359 def visit_return_stmt(self, o: 'mypy.nodes.ReturnStmt') -> None:
360 cfg = self.current_cfg()
361 if cfg:
362 cfg.AddDeadend(self.current_statement_id)
363
364 if o.expr:
365 self.accept(o.expr)
366
367 def visit_if_stmt(self, o: 'mypy.nodes.IfStmt') -> None:
368 cfg = self.current_cfg()
369
370 if util.ShouldVisitIfExpr(o):
371 for expr in o.expr:
372 self.accept(expr)
373
374 with pass_state.CfgBranchContext(
375 cfg, self.current_statement_id) as branch_ctx:
376 if util.ShouldVisitIfBody(o):
377 with branch_ctx.AddBranch():
378 for node in o.body:
379 self.accept(node)
380
381 if util.ShouldVisitElseBody(o):
382 with branch_ctx.AddBranch():
383 self.accept(o.else_body)
384
385 def visit_break_stmt(self, o: 'mypy.nodes.BreakStmt') -> None:
386 if len(self.loop_stack):
387 self.loop_stack[-1].AddBreak(self.current_statement_id)
388
389 def visit_continue_stmt(self, o: 'mypy.nodes.ContinueStmt') -> None:
390 if len(self.loop_stack):
391 self.loop_stack[-1].AddContinue(self.current_statement_id)
392
393 def visit_raise_stmt(self, o: 'mypy.nodes.RaiseStmt') -> None:
394 cfg = self.current_cfg()
395 if cfg:
396 cfg.AddDeadend(self.current_statement_id)
397
398 if o.expr:
399 self.accept(o.expr)
400
401 def visit_try_stmt(self, o: 'mypy.nodes.TryStmt') -> None:
402 cfg = self.current_cfg()
403 with pass_state.CfgBranchContext(cfg,
404 self.current_statement_id) as try_ctx:
405 with try_ctx.AddBranch() as try_block:
406 self.accept(o.body)
407
408 for t, v, handler in zip(o.types, o.vars, o.handlers):
409 with try_ctx.AddBranch(try_block.exit):
410 self.accept(handler)
411
412 # 2024-12(andy): SimpleVisitor now has a special case for:
413 #
414 # myvar = [x for x in other]
415 #
416 # This seems like it should affect the control flow graph, since we are no
417 # longer calling oils_visit_assignment_stmt, and are instead calling
418 # oils_visit_assign_to_listcomp.
419 #
420 # We may need more test coverage?
421 # List comprehensions are arguably a weird/legacy part of the mycpp IR that
422 # should be cleaned up.
423 #
424 # We do NOT allow:
425 #
426 # myfunc([x for x in other])
427 #
428 # See mycpp/visitor.py and mycpp/cppgen_pass.py for how this is used.
429
430 #def oils_visit_assign_to_listcomp(self, o: 'mypy.nodes.AssignmentStmt', lval: NameExpr) -> None:
431
432 def oils_visit_assignment_stmt(self, o: 'mypy.nodes.AssignmentStmt',
433 lval: Expression, rval: Expression) -> None:
434 cfg = self.current_cfg()
435 if cfg:
436 lval_names = []
437 if isinstance(lval, TupleExpr):
438 lval_names.extend(
439 [self.get_ref_name(item) for item in lval.items])
440
441 else:
442 lval_names.append(self.get_ref_name(lval))
443
444 assert len(lval_names), lval
445
446 rval_type = self.types[rval]
447 rval_names: List[Optional[util.SymbolPath]] = []
448 if isinstance(rval, CallExpr):
449 # The RHS is either an object constructor or something that
450 # returns a primitive type (e.g. Tuple[int, int] or str).
451 # XXX: When we add inter-procedural analysis we should treat
452 # these not as definitions but as some new kind of assignment.
453 rval_names = [None for _ in lval_names]
454
455 elif isinstance(rval, TupleExpr) and len(lval_names) == 1:
456 # We're constructing a tuple. Since tuples have have a fixed
457 # (and usually small) size, we can name each of the
458 # elements.
459 base = lval_names[0]
460 lval_names = [
461 base + (str(i), ) for i in range(len(rval.items))
462 ]
463 rval_names = [self.get_ref_name(item) for item in rval.items]
464
465 elif isinstance(rval_type, TupleType):
466 # We're unpacking a tuple. Like the tuple construction case,
467 # give each element a name.
468 rval_name = self.get_ref_name(rval)
469 assert rval_name, rval
470 rval_names = [
471 rval_name + (str(i), ) for i in range(len(lval_names))
472 ]
473
474 else:
475 rval_names = [self.get_ref_name(rval)]
476
477 assert len(rval_names) == len(lval_names)
478
479 for lhs, rhs in zip(lval_names, rval_names):
480 assert lhs, lval
481 if rhs:
482 # In this case rhe RHS is another variable. Record the
483 # assignment so we can keep track of aliases.
484 cfg.AddFact(self.current_statement_id,
485 pass_state.Assignment(lhs, rhs))
486 else:
487 # In this case the RHS is either some kind of literal (e.g.
488 # [] or 'foo') or a call to an object constructor. Mark this
489 # statement as an (re-)definition of a variable.
490 cfg.AddFact(
491 self.current_statement_id,
492 pass_state.Definition(
493 lhs, '$HeapObject(h{})'.format(self.heap_counter)),
494 )
495 self.heap_counter += 1
496
497 # TODO: Could simplify this
498 self.current_lval = lval
499 self.accept(lval)
500 self.current_lval = None
501
502 self.accept(rval)
503
504 # Expressions
505
506 def oils_visit_member_expr(self, o: 'mypy.nodes.MemberExpr') -> None:
507 self.accept(o.expr)
508 cfg = self.current_cfg()
509 if (cfg and
510 not isinstance(self.dot_exprs[o], pass_state.ModuleMember) and
511 o != self.current_lval):
512 ref = self.get_ref_name(o)
513 if ref:
514 cfg.AddFact(self.current_statement_id, pass_state.Use(ref))
515
516 def oils_visit_name_expr(self, o: 'mypy.nodes.NameExpr') -> None:
517 cfg = self.current_cfg()
518 if cfg and o != self.current_lval:
519 is_local = False
520 for name, t in self.local_vars.get(self.current_func_node, []):
521 if name == o.name:
522 is_local = True
523 break
524
525 ref = self.get_ref_name(o)
526 if ref and is_local:
527 cfg.AddFact(self.current_statement_id, pass_state.Use(ref))
528
529 def oils_visit_call_expr(self, o: 'mypy.nodes.CallExpr') -> None:
530 cfg = self.current_cfg()
531 if self.current_func_node:
532 full_callee = self.resolve_callee(o)
533 if full_callee:
534 self.callees[o] = full_callee
535 cfg.AddFact(
536 self.current_statement_id,
537 pass_state.FunctionCall(
538 SymbolToString(full_callee, delim='.')))
539
540 for i, arg in enumerate(o.args):
541 arg_ref = self.get_ref_name(arg)
542 if arg_ref:
543 cfg.AddFact(self.current_statement_id,
544 pass_state.Bind(arg_ref, full_callee, i))
545
546 self.accept(o.callee)
547 for arg in o.args:
548 self.accept(arg)