1 | """
|
2 | const_pass.py - AST pass that collects string constants.
|
3 |
|
4 | Instead of emitting a dynamic allocation StrFromC("foo"), we emit a
|
5 | GLOBAL_STR(str99, "foo"), and then a reference to str99.
|
6 | """
|
7 | import collections
|
8 | import json
|
9 | import hashlib
|
10 | import string
|
11 |
|
12 | from mypy.nodes import (Expression, StrExpr, CallExpr)
|
13 | from mypy.types import Type
|
14 |
|
15 | from mycpp import format_strings
|
16 | from mycpp import util
|
17 | from mycpp.util import log
|
18 | from mycpp import visitor
|
19 |
|
20 | from typing import Dict, List, Tuple, Counter, TextIO, Union
|
21 |
|
22 | _ = log
|
23 |
|
# 32 symbols, one per base-32 digit: 'a'-'z' followed by 'A'-'F'.
_ALPHABET = (string.ascii_lowercase + string.ascii_uppercase)[:32]
|
26 |
|
# Type aliases for the successive stages of the naming pipeline.

# key (StrExpr node, or int in tests) -> raw string
AllStrings = Dict[Union[int, StrExpr], str]

# SHA1 digest -> raw string
UniqueStrings = Dict[bytes, str]

# short hash -> all raw strings that share that short hash
HashedStrings = Dict[str, List[str]]

# raw string -> variable name
VarNames = Dict[str, str]
|
31 |
|
32 |
|
class GlobalStrings:
    """Registry of string constants and the variable names assigned to them.

    Strings are recorded with Add(); ComputeStableVarNames() then fills in
    var_names, which GetVarName() and WriteConstants() read.
    """

    def __init__(self) -> None:
        # key (StrExpr node, or int in tests) -> raw string
        self.all_strings: AllStrings = {}
        # raw string -> variable name; filled in by ComputeStableVarNames()
        self.var_names: VarNames = {}

        # OLD: not read or written by any other method of this class
        self.unique: Dict[bytes, bytes] = {}
        self.int_id_lookup: Dict[Expression, str] = {}
        self.pairs: List[Tuple[str, str]] = []

    def Add(self, key: Union[int, StrExpr], s: str) -> None:
        """Record raw string `s` under `key`.

        key: int for tests
             StrExpr node for production
        """
        self.all_strings[key] = s

    def ComputeStableVarNames(self) -> None:
        """Assign a variable name to each distinct string added so far."""
        self.var_names = _HandleCollisions(
            _HashAndCollect(_MakeUniqueStrings(self.all_strings)))

    def GetVarName(self, node: StrExpr) -> str:
        """Return the variable name of the string recorded under `node`.

        Raises KeyError if `node` was never added, or if its string has no
        entry in var_names.
        """
        raw_string = self.all_strings[node]  # StrExpr -> str
        return self.var_names[raw_string]  # str -> variable name

    def WriteConstants(self, out_f: TextIO) -> None:
        """Write one GLOBAL_STR / GLOBAL_STR2 line per entry of var_names,
        followed by a blank line."""
        macro_name = 'GLOBAL_STR2' if util.SMALL_STR else 'GLOBAL_STR'

        # Sorted by the string value itself.  The keys are unique, so sorting
        # the (string, name) pairs never has to compare the names.
        for raw_string, var_name in sorted(self.var_names.items()):
            out_f.write('%s(%s, %s);\n' %
                        (macro_name, var_name, json.dumps(raw_string)))

        out_f.write('\n')
|
74 |
|
75 |
|
class Collect(visitor.TypedVisitor):
    """Visitor that records the string literals it visits in a GlobalStrings."""

    def __init__(self, types: Dict[Expression, Type],
                 global_strings: GlobalStrings) -> None:
        visitor.TypedVisitor.__init__(self, types)
        self.global_strings = global_strings

        # unique string value -> id.  Neither attribute below is used by the
        # other methods defined in this class.
        #
        # Note kept from the original optimization: generating only unique
        # strings took _gen/bin/oils_for_unix.mycpp.cc from "str2824" down to
        # "str1789", so it saved over 1000 strings.  The C++ compiler should
        # also optimize it, but it's easy for us to generate less source code.
        self.unique: Dict[str, str] = {}
        self.unique_id = 0

    def oils_visit_format_expr(self, left: Expression,
                               right: Expression) -> None:
        # Do NOT visit a literal left-hand side, because we write it literally
        if not isinstance(left, StrExpr):
            self.accept(left)
        self.accept(right)

    def visit_str_expr(self, o: StrExpr) -> None:
        # Register the decoded literal, keyed by its AST node
        self.global_strings.Add(o, format_strings.DecodeMyPyString(o.value))

    def oils_visit_probe_call(self, o: CallExpr) -> None:
        # Don't generate constants for DTRACE_PROBE()
        pass

    def oils_visit_log_call(self, fmt: StrExpr,
                            args: List[Expression]) -> None:
        if not args:
            # No arguments: visit the format string itself
            self.accept(fmt)
            return

        # Don't generate a string constant for the format string, which is an
        # inlined C string, not a mycpp GC string
        for arg in args:
            self.accept(arg)
|
125 |
|
126 |
|
127 | def _MakeUniqueStrings(all_strings: AllStrings) -> UniqueStrings:
|
128 | """
|
129 | Given all the strings, make a smaller set of unique strings.
|
130 | """
|
131 | unique: UniqueStrings = {}
|
132 | for _, raw_string in all_strings.items():
|
133 | b = raw_string.encode('utf-8')
|
134 | h = hashlib.sha1(b).digest()
|
135 | #print(repr(h))
|
136 |
|
137 | if h in unique:
|
138 | # extremely unlikely
|
139 | assert unique[h] == raw_string, ("SHA1 hash collision! %r and %r" %
|
140 | (unique[h], b))
|
141 | unique[h] = raw_string
|
142 | return unique
|
143 |
|
144 |
|
def _ShortHash15(h: bytes) -> str:
    """
    Turn the first two bytes of a SHA1 digest into a 15 bit hash value.

    The result is three base-(2**5) aka base-32 digits, each encoded as a
    letter of _ALPHABET, least significant digit first.
    """
    # Little-endian 16-bit word; its top bit is never consumed below.
    word = (h[1] << 8) | h[0]

    assert 0 <= word < 2**16, word

    # Peel off three 5-bit digits, starting from the least significant bits.
    digits = [(word >> shift) & 0b11111 for shift in (0, 5, 10)]
    return ''.join(_ALPHABET[d] for d in digits)
|
164 |
|
165 |
|
def _HashAndCollect(unique: UniqueStrings) -> HashedStrings:
    """
    Group the unique strings by their short hash.

    Returns a defaultdict of short hash -> list of raw strings, so strings
    whose short hashes collide end up in the same list.
    """
    buckets: HashedStrings = collections.defaultdict(list)
    for digest, raw_string in unique.items():
        buckets[_ShortHash15(digest)].append(raw_string)
    return buckets
|
175 |
|
176 |
|
def _SummarizeCollisions(hash15: HashedStrings) -> None:
    """
    Log a histogram of bucket sizes.

    Under the COUNT / ITEM header, ITEM is a bucket size (how many strings
    share one short hash) and COUNT is the number of buckets of that size.
    """
    # To debug collisions, print the (short_hash, strs) pairs with len(strs) > 1.
    size_histogram: Counter[int] = collections.Counter(
        len(strs) for strs in hash15.values())

    log('%10s %s', 'COUNT', 'ITEM')
    for bucket_size, num_buckets in size_histogram.most_common():
        log('%10d %s', num_buckets, bucket_size)
|
190 |
|
191 |
|
192 | def _HandleCollisions(hash15: HashedStrings) -> VarNames:
|
193 | var_names: VarNames = {}
|
194 | for short_hash, bytes_list in hash15.items():
|
195 | bytes_list.sort() # stable order, will bump some of the strings
|
196 | for i, b in enumerate(bytes_list):
|
197 | if i == 0:
|
198 | var_names[b] = 'S_%s' % short_hash
|
199 | else:
|
200 | var_names[b] = 'S_%s_%d' % (short_hash, i)
|
201 | return var_names
|
202 |
|
203 |
|
def HashDemo() -> None:
    """Read strings from stdin, one per line, and log hash/collision stats."""
    import sys

    # 5 bits
    # TODO: use a nicer alphabet?  e.g. _ALPHABET.replace('l', 'Z')
    log('alpha %r', _ALPHABET)

    global_strings = GlobalStrings()
    for line_num, line in enumerate(sys.stdin.readlines()):
        global_strings.Add(line_num, line.strip())

    unique = _MakeUniqueStrings(global_strings.all_strings)
    hash15 = _HashAndCollect(unique)
    var_names = _HandleCollisions(hash15)

    # Flip on to log every variable name that ends in a digit.
    show_bumped = False
    if show_bumped:
        for raw_string, var_name in var_names.items():
            if var_name[-1].isdigit():
                log('%r %r', var_name, raw_string)

    log('Unique %d' % len(unique))
    log('hash15 %d' % len(hash15))

    _SummarizeCollisions(hash15)


if __name__ == '__main__':
    HashDemo()
|