OILS / mycpp / const_pass.py View on Github | oils.pub

290 lines, 174 significant
1"""
2const_pass.py - AST pass that collects string constants.
3
Instead of emitting a dynamic allocation StrFromC("foo"), we emit a
GLOBAL_STR(S_abc, "foo"), and then a reference to S_abc.  The variable name
is derived from a short hash of the string's contents.
6"""
7import collections
8import json
9import hashlib
10import string
11
12from mypy.nodes import (Expression, StrExpr, CallExpr, FuncDef, ClassDef,
13 MypyFile)
14from mypy.types import Type
15
16from mycpp import format_strings
17from mycpp import util
18from mycpp.util import log
19from mycpp import visitor
20
21from typing import Dict, List, Tuple, Counter, TextIO, Union, Optional
22
# Bind `log` to a throwaway name.  Presumably this keeps the import above
# counted as "used" even when debug logging is removed -- TODO confirm.
_ = log
24
# Digits for the short hash: the first 32 ASCII letters (a-z, then A-F), so
# that one letter encodes exactly 5 bits.
_ALPHABET = (string.ascii_lowercase + string.ascii_uppercase)[:32]
27
# Type aliases for the string-naming helpers in this module.

# Node -> raw string.  The key is a StrExpr node in production, or an int in
# tests.
AllStrings = Dict[Union[int, StrExpr], str]
# SHA1 digest -> raw string
UniqueStrings = Dict[bytes, str]
# short hash -> list of raw strings that share that short hash
HashedStrings = Dict[str, List[str]]
# raw string -> variable name
VarNames = Dict[str, str]

# Class name -> List of method names
MethodDefinitions = Dict[util.SymbolPath,
                         List[str]]

# Class name -> Namespace name
ClassNamespaceDict = Dict[util.SymbolPath, str]
37
38
class GlobalStrings:
    """
    Registry of string constants and the variable names chosen for them.

    Add() every string first, then call ComputeStableVarNames() once; after
    that GetVarName() and WriteConstants() can be used.
    """

    def __init__(self) -> None:
        # key given to Add() (StrExpr node, or int in tests) -> raw string
        self.all_strings: AllStrings = {}
        # raw string -> variable name; empty until ComputeStableVarNames()
        self.var_names: VarNames = {}

        # OLD: not used by any method of this class
        self.unique: Dict[bytes, bytes] = {}
        self.int_id_lookup: Dict[Expression, str] = {}
        self.pairs: List[Tuple[str, str]] = []

    def Add(self, key: Union[int, StrExpr], s: str) -> None:
        """
        Record the raw string s under key.

        key: int for tests
             StrExpr node for production

        Adding the same key again replaces the earlier string.
        """
        self.all_strings[key] = s

    def ComputeStableVarNames(self) -> None:
        """
        Fill self.var_names from everything added so far: dedupe the strings,
        bucket them by short hash, then name each bucket's members.
        """
        self.var_names = _HandleCollisions(
            _HashAndCollect(_MakeUniqueStrings(self.all_strings)))

    def GetVarName(self, node: StrExpr) -> str:
        """
        Return the variable name for a node that was passed to Add().

        Raises KeyError if the node was never added, or if
        ComputeStableVarNames() has not run since it was added.
        """
        # StrExpr -> str -> variable name
        raw_string = self.all_strings[node]
        return self.var_names[raw_string]

    def WriteConstants(self, out_f: TextIO) -> None:
        """
        Write one `MACRO(var_name, "literal");` line per unique string to
        out_f, followed by a blank line.
        """
        macro_name = 'GLOBAL_STR2' if util.SMALL_STR else 'GLOBAL_STR'

        # Keys are unique, so sorting the items orders by the string value
        # itself.
        for raw_string, var_name in sorted(self.var_names.items()):
            # json.dumps() supplies the double-quoted, escaped literal
            literal = json.dumps(raw_string)
            out_f.write('%s(%s, %s);\n' % (macro_name, var_name, literal))

        out_f.write('\n')
80
81
class MethodDefs:
    """
    Records, per class, the names of the methods added for it.
    """

    def __init__(self) -> None:
        # Class name -> method names, in the order they were added
        self.method_defs: MethodDefinitions = {}

    def Add(self, class_name: util.SymbolPath, method_name: str) -> None:
        """
        Append method_name to the list kept for class_name, creating the list
        on first use.  Duplicates are not filtered.
        """
        self.method_defs.setdefault(class_name, []).append(method_name)

    def ClassHasMethod(self, class_name: util.SymbolPath,
                       method_name: str) -> bool:
        """
        Return True if method_name was added for class_name.  An unknown
        class gives False rather than an error.
        """
        known_methods = self.method_defs.get(class_name)
        if known_methods is None:
            return False
        return method_name in known_methods
97
98
class ClassNamespaces:
    """
    Maps each class name to a namespace name.
    """

    def __init__(self) -> None:
        # Class name -> Namespace name
        self.class_namespaces: ClassNamespaceDict = {}

    def Set(self, class_name: util.SymbolPath, namespace_name: str) -> None:
        """
        Record namespace_name for class_name, replacing any earlier value.
        """
        table = self.class_namespaces
        table[class_name] = namespace_name

    def GetClassNamespace(self, class_name: util.SymbolPath) -> str:
        """
        Return the namespace recorded for class_name.

        Raises KeyError if Set() was never called for that class.
        """
        namespace_name = self.class_namespaces[class_name]
        return namespace_name
109
110
class Collect(visitor.TypedVisitor):
    """
    Visitor that fills in the three collectors it is given: string constants,
    the methods each class defines, and the namespace of each class.
    """

    def __init__(self, types: Dict[Expression, Type],
                 global_strings: GlobalStrings, method_defs: MethodDefs,
                 class_namespaces: ClassNamespaces) -> None:
        visitor.TypedVisitor.__init__(self, types)

        # Collectors owned by the caller; this pass only adds to them.
        self.global_strings = global_strings
        self.method_defs = method_defs
        self.class_namespaces = class_namespaces

        # Only generate unique strings.
        # Before this optimization, _gen/bin/oils_for_unix.mycpp.cc went up
        # to "str2824"; after, "str1789".  So it saved over 1000 strings.
        #
        # The C++ compiler should also optimize it, but it's easy for us to
        # generate less source code.

        # unique string value -> id
        self.unique: Dict[str, str] = {}
        self.unique_id = 0

        # Name of the MypyFile being visited; set by oils_visit_mypy_file()
        self.current_file_name: Optional[str] = None

    def oils_visit_format_expr(self, left: Expression,
                               right: Expression) -> None:
        """
        Visit both sides of a format expression, except a literal left side.
        """
        # Do NOT visit a literal left side, because we write it literally
        if not isinstance(left, StrExpr):
            self.accept(left)
        self.accept(right)

    def visit_str_expr(self, o: StrExpr) -> None:
        """
        Record the decoded value of a string literal, keyed by its node.
        """
        self.global_strings.Add(o, format_strings.DecodeMyPyString(o.value))

    def oils_visit_probe_call(self, o: CallExpr) -> None:
        """
        Deliberately does nothing.
        """
        # Don't generate constants for DTRACE_PROBE()
        return

    def oils_visit_log_call(self, fmt: StrExpr,
                            args: List[Expression]) -> None:
        """
        Collect constants from a log() call.

        With no args, the format string itself is visited.  With args, only
        the args are visited.
        """
        if not args:
            self.accept(fmt)
            return

        # Don't generate a string constant for the format string, which is an
        # inlined C string, not a mycpp GC string
        for arg in args:
            self.accept(arg)

    def oils_visit_class_def(
            self, o: ClassDef, base_class_sym: Optional[util.SymbolPath],
            current_class_name: Optional[util.SymbolPath]) -> None:
        """
        Record each method defined directly in the class body, and the file
        the class lives in, then continue the normal traversal.
        """
        # NOTE(review): the namespace is (re)recorded once per method found,
        # so a class with no FuncDef in its body gets no namespace entry --
        # confirm that is intended.
        for stmt in o.defs.body:
            if not isinstance(stmt, FuncDef):
                continue
            self.method_defs.Add(current_class_name, stmt.name)
            self.class_namespaces.Set(current_class_name,
                                      self.current_file_name)

        super().oils_visit_class_def(o, base_class_sym, current_class_name)

    def oils_visit_mypy_file(self, o: MypyFile) -> None:
        """
        Remember the file's name before visiting it, so that classes found
        inside can be associated with it.
        """
        self.current_file_name = o.name
        super().oils_visit_mypy_file(o)
181
182
def _MakeUniqueStrings(all_strings: AllStrings) -> UniqueStrings:
    """
    Collapse all the collected strings into one entry per distinct string.

    The result is keyed by the SHA1 digest of the string's UTF-8 encoding.
    The keys of all_strings are ignored.
    """
    by_digest: UniqueStrings = {}
    for raw_string in all_strings.values():
        encoded = raw_string.encode('utf-8')
        digest = hashlib.sha1(encoded).digest()

        if digest in by_digest:
            # Same digest seen before: it must be the same string.  Anything
            # else is a SHA1 collision, which is extremely unlikely.
            assert by_digest[digest] == raw_string, (
                "SHA1 hash collision! %r and %r" %
                (by_digest[digest], encoded))
        by_digest[digest] = raw_string
    return by_digest
199
200
def _ShortHash15(h: bytes) -> str:
    """
    Reduce a SHA1 digest to a 15 bit hash, spelled as three letters.

    Only the first two bytes of h are read.  The low 15 bits of that 16 bit
    value are split into three 5 bit digits, least significant first, and
    each digit is looked up in _ALPHABET.
    """
    low_byte = h[0]
    high_byte = h[1]
    bits16 = (high_byte << 8) | low_byte

    assert 0 <= bits16 < 2**16, bits16

    # Shifts of 0, 5 and 10 pick out the three base-32 digits.  The topmost
    # bit of the 16 is not used.
    digits = [(bits16 >> shift) & 0b11111 for shift in (0, 5, 10)]

    return ''.join(_ALPHABET[d] for d in digits)
220
221
def _HashAndCollect(unique: UniqueStrings) -> HashedStrings:
    """
    Group the unique strings by the short hash of their SHA1 digest.

    Strings whose digests share a short hash end up in the same list, in the
    order they appear in `unique`.  The returned mapping is a defaultdict.
    """
    buckets = collections.defaultdict(list)
    for digest, raw_string in unique.items():
        buckets[_ShortHash15(digest)].append(raw_string)
    return buckets
231
232
def _SummarizeCollisions(hash15: HashedStrings) -> None:
    """
    Log a histogram of bucket sizes.

    Each row shows how many short hashes (COUNT) have a given number of
    strings (ITEM), most common size first.
    """
    # bucket size -> number of short hashes with that many strings
    size_counts: Counter[int] = collections.Counter(
        len(strs) for strs in hash15.values())

    log('%10s %s', 'COUNT', 'ITEM')
    for bucket_size, num_hashes in size_counts.most_common():
        log('%10d %s', num_hashes, bucket_size)
246
247
def _HandleCollisions(hash15: HashedStrings) -> VarNames:
    """
    Choose a variable name for every string.

    Within each short-hash bucket the strings are taken in sorted order: the
    first one is named S_<hash>, and the rest are bumped to S_<hash>_1,
    S_<hash>_2, ...  Sorting makes the assignment independent of the order
    in which the strings were collected.

    The lists inside hash15 are left untouched.

    Returns a dict of raw string -> variable name.
    """
    var_names: VarNames = {}
    for short_hash, strs in hash15.items():
        # sorted() makes a copy, so the caller's lists are not reordered
        for i, raw_string in enumerate(sorted(strs)):
            if i == 0:
                var_names[raw_string] = 'S_%s' % short_hash
            else:
                var_names[raw_string] = 'S_%s_%d' % (short_hash, i)
    return var_names
258
259
def HashDemo() -> None:
    """
    Demo: treat each line of stdin as a string constant, run it through the
    naming pipeline, and log how many unique strings and short hashes there
    are, plus a summary of collisions.
    """
    import sys

    # 5 bits per letter
    log('alpha %r', _ALPHABET)

    global_strings = GlobalStrings()

    # Integer keys stand in for StrExpr nodes here
    for line_num, line in enumerate(sys.stdin.readlines()):
        global_strings.Add(line_num, line.strip())

    unique = _MakeUniqueStrings(global_strings.all_strings)
    hash15 = _HashAndCollect(unique)
    var_names = _HandleCollisions(hash15)

    if 0:  # flip on to list the names that needed a numeric suffix
        for raw_string, var_name in var_names.items():
            if var_name[-1].isdigit():
                log('%r %r', var_name, raw_string)

    log('Unique %d' % len(unique))
    log('hash15 %d' % len(hash15))

    _SummarizeCollisions(hash15)


if __name__ == '__main__':
    HashDemo()