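"""
const_pass.py - AST pass that collects string constants.

Instead of emitting a dynamic allocation StrFromC("foo"), we emit a global
constant such as GLOBAL_STR(S_abc, "foo") (GLOBAL_STR2 when util.SMALL_STR is
set), and then a reference to S_abc.  The variable name is derived from a
short hash of the string's contents, so it is mostly stable across unrelated
changes to the program; hash collisions get a numeric suffix.
"""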
7 | import collections
8 | import json
9 | import hashlib
10 | import string
11 |
12 | from mypy.nodes import (Expression, StrExpr, CallExpr)
13 | from mypy.types import Type
14 |
15 | from mycpp import format_strings
16 | from mycpp import util
17 | from mycpp.util import log
18 | from mycpp import visitor
19 |
20 | from typing import Dict, List, Tuple, Counter, TextIO, Union
21 |
# Reference the imported 'log' so it is not flagged as unused.
_ = log

# Digit alphabet for _ShortHash15.  Each digit is 5 bits wide, so only the
# first 32 characters (a-z, then A-F) are ever used.
_ALPHABET = string.ascii_letters

# Type aliases for the stages of computing stable variable names.
AllStrings = Dict[Union[int, StrExpr], str]  # Add() key (int in tests, StrExpr node otherwise) -> raw string
UniqueStrings = Dict[bytes, str]  # SHA1 digest -> raw string
HashedStrings = Dict[str, List[str]]  # short hash -> raw strings sharing that hash
VarNames = Dict[str, str]  # raw string -> generated variable name
31 |
32 |
class GlobalStrings:
    """Registry of string literals that are emitted as global constants.

    Literals are registered with Add().  ComputeStableVarNames() fills in the
    name table that GetVarName() and WriteConstants() read.
    """

    def __init__(self) -> None:
        # Add() key (StrExpr node, or int in tests) -> raw string
        self.all_strings: AllStrings = {}
        # raw string -> variable name; populated by ComputeStableVarNames()
        self.var_names: VarNames = {}

        # OLD: legacy fields, not read anywhere in this class.
        self.unique: Dict[bytes, bytes] = {}
        self.int_id_lookup: Dict[Expression, str] = {}
        self.pairs: List[Tuple[str, str]] = []

    def Add(self, key: Union[int, StrExpr], s: str) -> None:
        """Record that the literal identified by 'key' has the value 's'.

        key: an int in tests, a StrExpr node in production.
        """
        self.all_strings[key] = s

    def ComputeStableVarNames(self) -> None:
        """Assign a variable name to every distinct string added so far."""
        by_short_hash = _HashAndCollect(_MakeUniqueStrings(self.all_strings))
        self.var_names = _HandleCollisions(by_short_hash)

    def GetVarName(self, node: StrExpr) -> str:
        """Map a StrExpr node to its string, then to that string's name.

        Raises KeyError if the node was never added or names were not computed.
        """
        raw_string = self.all_strings[node]
        return self.var_names[raw_string]

    def WriteConstants(self, out_f: TextIO) -> None:
        """Write one macro invocation per distinct string, then a blank line."""
        macro_name = 'GLOBAL_STR2' if util.SMALL_STR else 'GLOBAL_STR'

        # Keys are unique, so sorting (string, name) pairs orders by string.
        for raw_string, var_name in sorted(self.var_names.items()):
            out_f.write('%s(%s, %s);\n' %
                        (macro_name, var_name, json.dumps(raw_string)))

        out_f.write('\n')
74 |
75 |
class Collect(visitor.TypedVisitor):
    """AST visitor that registers string literals with a GlobalStrings."""

    def __init__(self, types: Dict[Expression, Type],
                 global_strings: GlobalStrings) -> None:
        visitor.TypedVisitor.__init__(self, types)
        self.global_strings = global_strings

        # History: generating only unique strings shrank
        # _gen/bin/oils_for_unix.mycpp.cc from "str2824" to "str1789" -- over
        # 1000 fewer strings.  The C++ compiler could dedupe too, but emitting
        # less source is easy.
        #
        # NOTE(review): the two fields below are not read in this class;
        # deduplication now appears to happen in GlobalStrings.

        # unique string value -> id
        self.unique: Dict[str, str] = {}
        self.unique_id = 0

    def oils_visit_format_expr(self, left: Expression,
                               right: Expression) -> None:
        # A literal format string is written out literally, so it needs no
        # constant; any other left operand is visited normally.
        if not isinstance(left, StrExpr):
            self.accept(left)
        self.accept(right)

    def visit_str_expr(self, o: StrExpr) -> None:
        self.global_strings.Add(o, format_strings.DecodeMyPyString(o.value))

    def oils_visit_probe_call(self, o: CallExpr) -> None:
        """Don't generate constants for DTRACE_PROBE()."""

    def oils_visit_log_call(self, fmt: StrExpr,
                            args: List[Expression]) -> None:
        if args:
            # The format string is inlined as a C string, not a mycpp GC
            # string, so only the arguments can produce constants.
            for arg in args:
                self.accept(arg)
        else:
            self.accept(fmt)
125 |
126 |
def _MakeUniqueStrings(all_strings: AllStrings) -> UniqueStrings:
    """
    Given all the strings, make a smaller set of unique strings.

    Returns a dict keyed by the SHA1 digest of each string's UTF-8 encoding.
    Asserts that no two different strings share a digest (a SHA1 collision is
    astronomically unlikely).
    """
    unique: UniqueStrings = {}
    # Only the string values matter; the keys (AST nodes / test ints) don't.
    for raw_string in all_strings.values():
        digest = hashlib.sha1(raw_string.encode('utf-8')).digest()

        if digest in unique:
            # Report both strings in the same form so they can be compared.
            assert unique[digest] == raw_string, (
                'SHA1 hash collision! %r and %r' %
                (unique[digest], raw_string))
        unique[digest] = raw_string
    return unique
143 |
144 |
def _ShortHash15(h: bytes) -> str:
    """
    Reduce a SHA1 digest to a 15-bit hash rendered as three letters.

    The first two bytes form a little-endian 16-bit number.  Its low 15 bits
    are split into three base-32 (5-bit) digits, least significant first, and
    each digit is mapped through _ALPHABET.
    """
    bits16 = h[0] | (h[1] << 8)
    assert 0 <= bits16 < 2**16, bits16

    return ''.join(_ALPHABET[(bits16 >> shift) & 0b11111]
                   for shift in (0, 5, 10))
164 |
165 |
def _HashAndCollect(unique: UniqueStrings) -> HashedStrings:
    """
    Group the unique strings by their 15-bit short hash.

    Returns a defaultdict(list) mapping short hash -> strings with that hash,
    each list in the iteration order of 'unique'.
    """
    buckets: HashedStrings = collections.defaultdict(list)
    for digest in unique:
        buckets[_ShortHash15(digest)].append(unique[digest])
    return buckets
175 |
176 |
def _SummarizeCollisions(hash15: HashedStrings) -> None:
    """
    Log a histogram of bucket sizes.

    Each output row is (COUNT, ITEM): COUNT short hashes have exactly ITEM
    strings each, most common first.
    """
    # bucket size -> number of short hashes with that many strings
    collisions: Counter[int] = collections.Counter(
        len(strs) for strs in hash15.values())

    log('%10s %s', 'COUNT', 'ITEM')
    for item, count in collisions.most_common():
        log('%10d %s', count, item)
190 |
191 |
def _HandleCollisions(hash15: HashedStrings) -> VarNames:
    """
    Assign a variable name to every string, disambiguating short-hash
    collisions.

    Within a bucket, the strings are sorted so the assignment does not depend
    on insertion order.  The first string gets S_<hash> and the rest get
    S_<hash>_1, S_<hash>_2, ...  Adding a colliding string can therefore bump
    the suffixes of others in the same bucket.  The input is not modified.
    """
    var_names: VarNames = {}
    for short_hash, strs in hash15.items():
        # sorted(), not list.sort(): don't mutate the caller's buckets.
        for i, raw_string in enumerate(sorted(strs)):
            if i == 0:
                var_names[raw_string] = 'S_%s' % short_hash
            else:
                var_names[raw_string] = 'S_%s_%d' % (short_hash, i)
    return var_names
202 |
203 |
def HashDemo() -> None:
    """
    Read one string per line from stdin and log how often the 15-bit short
    hashes collide.
    """
    import sys

    # Digits are 5 bits each; see _ShortHash15.
    log('alpha %r', _ALPHABET)

    global_strings = GlobalStrings()
    for i, line in enumerate(sys.stdin):
        global_strings.Add(i, line.strip())

    unique = _MakeUniqueStrings(global_strings.all_strings)
    hash15 = _HashAndCollect(unique)
    var_names = _HandleCollisions(hash15)

    if 0:  # debugging: show strings that needed a numeric suffix
        for raw_string, var_name in var_names.items():
            if var_name[-1].isdigit():
                log('%r %r', var_name, raw_string)

    # Pass arguments to log() like the rest of the file does.
    log('Unique %d', len(unique))
    log('hash15 %d', len(hash15))

    _SummarizeCollisions(hash15)


if __name__ == '__main__':
    HashDemo()