OILS / mycpp / const_pass.py View on Github | oils.pub

298 lines, 181 significant
1"""
2const_pass.py - AST pass that collects string constants.
3
4Instead of emitting a dynamic allocation StrFromC("foo"), we emit a
5GLOBAL_STR(str99, "foo"), and then a reference to str99.
6"""
7import collections
8import json
9import hashlib
10import string
11
12from mypy.nodes import (Expression, StrExpr, CallExpr, FuncDef, ClassDef,
13 MypyFile)
14from mypy.types import Type
15
16from mycpp import format_strings
17from mycpp import util
18from mycpp.util import log
19from mycpp import visitor
20
21from typing import Dict, List, Tuple, Counter, TextIO, Union, Optional
22
# Bind log to a throwaway name (presumably so the import always counts as
# used -- confirm against the project's lint setup).
_ = log

# The 32 symbols used as base-32 digits in short hashes: a-z followed by A-F.
_ALPHABET = (string.ascii_lowercase + string.ascii_uppercase)[:32]

# Key (int in tests, StrExpr node in production) -> raw string
AllStrings = Dict[Union[int, StrExpr], str]
# SHA1 digest -> raw string
UniqueStrings = Dict[bytes, str]
# Short hash -> every raw string that shares that short hash
HashedStrings = Dict[str, List[str]]
# Raw string -> variable name
VarNames = Dict[str, str]

# Class name -> list of method names
MethodDefinitions = Dict[util.SymbolPath, List[str]]

# Class name -> namespace name
ClassNamespaceDict = Dict[util.SymbolPath, str]
38
class GlobalStrings:
    """
    Registry of every string literal seen, plus the variable name chosen for
    each distinct string value.
    """

    def __init__(self) -> None:
        # Key (int in tests, StrExpr in production) -> raw string
        self.all_strings: AllStrings = {}
        # Raw string -> variable name; filled by ComputeStableVarNames()
        self.var_names: VarNames = {}

        # OLD
        self.unique: Dict[bytes, bytes] = {}
        self.int_id_lookup: Dict[Expression, str] = {}
        self.pairs: List[Tuple[str, str]] = []

    def Add(self, key: Union[int, StrExpr], s: str) -> None:
        """
        Record the raw string s under key.

        key: int for tests
             StrExpr node for production
        """
        self.all_strings[key] = s

    def ComputeStableVarNames(self) -> None:
        """
        Derive self.var_names from the strings collected so far.

        Pipeline: de-duplicate by SHA1, bucket by short hash, then give each
        string in a bucket its own name.
        """
        self.var_names = _HandleCollisions(
            _HashAndCollect(_MakeUniqueStrings(self.all_strings)))

    def GetVarName(self, node: StrExpr) -> str:
        """
        Return the variable name for a previously added node.

        Raises KeyError if the node was never added, or if
        ComputeStableVarNames() has not assigned a name to its string.
        """
        # StrExpr -> str -> variable name
        raw_string = self.all_strings[node]
        return self.var_names[raw_string]

    def WriteConstants(self, out_f: TextIO) -> None:
        """
        Write one macro invocation per unique string to out_f, followed by a
        blank line.  The macro name depends on util.SMALL_STR.
        """
        macro_name = 'GLOBAL_STR2' if util.SMALL_STR else 'GLOBAL_STR'

        # Sorted by the string value itself; keys are unique, so sorting the
        # items never has to compare the variable names.
        for raw_string, var_name in sorted(self.var_names.items()):
            line = '%s(%s, %s);\n' % (macro_name, var_name,
                                      json.dumps(raw_string))
            out_f.write(line)

        out_f.write('\n')
80
81
class MethodDefs:
    """
    Records which method names were seen for each class.
    """

    def __init__(self) -> None:
        # Class name -> method names, in the order they were added
        self.method_defs: MethodDefinitions = {}

    def Add(self, class_name: util.SymbolPath, method_name: str) -> None:
        """
        Append method_name to the list kept for class_name, creating the list
        on first use.
        """
        self.method_defs.setdefault(class_name, []).append(method_name)

    def ClassHasMethod(self, class_name: util.SymbolPath,
                       method_name: str) -> bool:
        """
        Return True if method_name was added for class_name, False otherwise
        (including when the class is unknown).
        """
        methods = self.method_defs.get(class_name)
        if methods is None:
            return False
        return method_name in methods
97
98
class ClassNamespaces:
    """
    Maps each class name to the namespace name recorded for it.
    """

    def __init__(self) -> None:
        # Class name -> namespace name
        self.class_namespaces: ClassNamespaceDict = {}

    def Set(self, class_name: util.SymbolPath, namespace_name: str) -> None:
        """
        Record namespace_name for class_name, replacing any earlier entry.
        """
        self.class_namespaces.update({class_name: namespace_name})

    def GetClassNamespace(self, class_name: util.SymbolPath) -> str:
        """
        Return the namespace recorded for class_name.

        Raises KeyError if Set() was never called for that class.
        """
        namespace_name = self.class_namespaces[class_name]
        return namespace_name
109
110
class Collect(visitor.TypedVisitor):
    """
    AST pass that registers string literals with a GlobalStrings instance and
    records, per class, its method names and the file it was defined in.
    """

    def __init__(self, types: Dict[Expression, Type],
                 global_strings: GlobalStrings, method_defs: MethodDefs,
                 class_namespaces: ClassNamespaces) -> None:
        visitor.TypedVisitor.__init__(self, types)
        self.global_strings = global_strings

        # Only generate unique strings.
        # Before this optimization, _gen/bin/oils_for_unix.mycpp.cc went up to:
        #     "str2824"
        # After:
        #     "str1789"
        #
        # So it saved over 1000 strings.
        #
        # The C++ compiler should also optimize it, but it's easy for us to
        # generate less source code.

        # unique string value -> id
        self.unique: Dict[str, str] = {}
        self.unique_id = 0

        self.method_defs = method_defs
        self.class_namespaces = class_namespaces

        # Name of the file being visited; set by oils_visit_mypy_file()
        self.current_file_name: Optional[str] = None

    def verify_format_string(self, fmt: StrExpr) -> None:
        """
        Parse the format string and report a RuntimeError from the parser as
        an error on the fmt node instead of letting it propagate.
        """
        try:
            format_strings.Parse(fmt.value)
        except RuntimeError as e:
            self.report_error(fmt, str(e))

    def oils_visit_format_expr(self, left: Expression,
                               right: Expression) -> None:
        """
        Handle a format expression with operands left and right.
        """
        if not isinstance(left, StrExpr):
            self.accept(left)
        else:
            # Do NOT visit the left, because we write it literally.  It is
            # only checked for validity.
            self.verify_format_string(left)
        self.accept(right)

    def visit_str_expr(self, o: StrExpr) -> None:
        """
        Decode the literal's value and register it under its node.
        """
        self.global_strings.Add(o, format_strings.DecodeMyPyString(o.value))

    def oils_visit_probe_call(self, o: CallExpr) -> None:
        """
        Deliberately do nothing.
        """
        # Don't generate constants for DTRACE_PROBE()
        return None

    def oils_visit_log_call(self, fmt: StrExpr,
                            args: List[Expression]) -> None:
        """
        Validate the format string of a log call, then visit what needs a
        constant: the format string itself when there are no arguments,
        otherwise only the arguments.
        """
        self.verify_format_string(fmt)
        if not args:
            self.accept(fmt)
            return

        # Don't generate a string constant for the format string, which is an
        # inlined C string, not a mycpp GC string
        for arg in args:
            self.accept(arg)

    def oils_visit_class_def(
            self, o: ClassDef, base_class_sym: Optional[util.SymbolPath],
            current_class_name: Optional[util.SymbolPath]) -> None:
        """
        Record the class's directly defined functions and the current file
        name as its namespace, then continue with the base visitor.
        """
        method_names = [
            stmt.name for stmt in o.defs.body if isinstance(stmt, FuncDef)
        ]
        for method_name in method_names:
            self.method_defs.Add(current_class_name, method_name)

        self.class_namespaces.Set(current_class_name, self.current_file_name)
        super().oils_visit_class_def(o, base_class_sym, current_class_name)

    def oils_visit_mypy_file(self, o: MypyFile) -> None:
        """
        Remember the file's name before visiting its contents, so that class
        definitions inside it can be associated with it.
        """
        self.current_file_name = o.name
        super().oils_visit_mypy_file(o)
189
190
def _MakeUniqueStrings(all_strings: AllStrings) -> UniqueStrings:
    """
    Given all the strings, make a smaller set of unique strings.

    Each raw string is encoded as UTF-8 and keyed by the SHA1 digest of those
    bytes, so repeated values collapse into a single entry.

    Returns a dict of SHA1 digest -> raw string.

    Raises AssertionError if two different strings have the same digest.
    """
    unique: UniqueStrings = {}
    # Only the values matter here; the keys identify where a string came from.
    for raw_string in all_strings.values():
        digest = hashlib.sha1(raw_string.encode('utf-8')).digest()

        if digest in unique:
            # extremely unlikely
            # Report both sides as raw strings so they can be compared
            # directly in the message.
            assert unique[digest] == raw_string, (
                "SHA1 hash collision! %r and %r" %
                (unique[digest], raw_string))
        unique[digest] = raw_string
    return unique
207
208
def _ShortHash15(h: bytes) -> str:
    """
    Given a SHA1, create a 15 bit hash value.

    The first two bytes are combined into a 16-bit number (h[0] is the low
    byte).  Its low 15 bits are split into three 5-bit digits, least
    significant first, and each digit is encoded as a letter from _ALPHABET.
    The 16th bit is not used.
    """
    value = h[0] | (h[1] << 8)

    assert 0 <= value < 2**16, value

    # Three base-(2**5) aka base-32 digits, starting with the 5 least
    # significant bits
    digits = [(value >> shift) & 0b11111 for shift in (0, 5, 10)]

    return ''.join(_ALPHABET[d] for d in digits)
228
229
def _HashAndCollect(unique: UniqueStrings) -> HashedStrings:
    """
    Group the unique strings by the short hash of their SHA1 digest.

    Returns a defaultdict of short hash -> list of raw strings, with each list
    in the iteration order of `unique`.
    """
    buckets = collections.defaultdict(list)
    for digest, raw_string in unique.items():
        buckets[_ShortHash15(digest)].append(raw_string)
    return buckets
239
240
def _SummarizeCollisions(hash15: HashedStrings) -> None:
    """
    Log a histogram of bucket sizes.

    ITEM is the number of strings sharing one short hash; COUNT is how many
    short hashes have that many strings.
    """
    # Tally bucket sizes in the iteration order of hash15.
    size_histogram: Counter[int] = collections.Counter(
        len(strs) for strs in hash15.values())

    log('%10s %s', 'COUNT', 'ITEM')
    for bucket_size, num_buckets in size_histogram.most_common():
        log('%10d %s', num_buckets, bucket_size)
254
255
def _HandleCollisions(hash15: HashedStrings) -> VarNames:
    """
    Assign a variable name to every string.

    Within a bucket, the first string in sorted order gets the plain
    'S_<short hash>' name; the others get a numeric suffix equal to their
    position in the sorted bucket.

    Note: each bucket list in hash15 is sorted in place.

    Returns a dict of raw string -> variable name.
    """
    names: VarNames = {}
    for short_hash, candidates in hash15.items():
        candidates.sort()  # stable order, will bump some of the strings
        plain_name = 'S_%s' % short_hash
        for rank, raw_string in enumerate(candidates):
            names[raw_string] = (plain_name if rank == 0 else 'S_%s_%d' %
                                 (short_hash, rank))
    return names
266
267
def HashDemo() -> None:
    """
    Read lines from stdin, treat each stripped line as a string constant, run
    the naming pipeline on them, and log summary statistics.
    """
    import sys

    # The letters used for the 5-bit digits
    log('alpha %r', _ALPHABET)

    global_strings = GlobalStrings()

    # The line index stands in for the AST node key used in production.
    for line_num, line in enumerate(sys.stdin.readlines()):
        global_strings.Add(line_num, line.strip())

    unique = _MakeUniqueStrings(global_strings.all_strings)
    hash15 = _HashAndCollect(unique)
    # The resulting names are not reported; the call is kept so the demo still
    # runs the whole pipeline.
    _HandleCollisions(hash15)

    log('Unique %d' % len(unique))
    log('hash15 %d' % len(hash15))

    _SummarizeCollisions(hash15)
295
296
# When run as a script, run the demo, which reads strings from stdin.
if __name__ == '__main__':
    HashDemo()