OILS / mycpp / const_pass.py View on Github | oilshell.org

234 lines, 135 significant
1"""
2const_pass.py - AST pass that collects string constants.
3
4Instead of emitting a dynamic allocation StrFromC("foo"), we emit a
5GLOBAL_STR(str99, "foo"), and then a reference to str99.
6"""
7import collections
8import json
9import hashlib
10import string
11
12from mypy.nodes import (Expression, StrExpr, CallExpr)
13from mypy.types import Type
14
15from mycpp import format_strings
16from mycpp import util
17from mycpp.util import log
18from mycpp import visitor
19
20from typing import Dict, List, Tuple, Counter, TextIO, Union
21
# Bind `log` to a throwaway name.  Presumably this keeps the import from being
# reported as unused when no debug logging is present -- TODO confirm.
_ = log
23
24_ALPHABET = string.ascii_lowercase + string.ascii_uppercase
25_ALPHABET = _ALPHABET[:32]
26
# The key is a StrExpr node in production, or a plain int in tests.
AllStrings = Dict[Union[int, StrExpr], str]  # Node -> raw string
UniqueStrings = Dict[bytes, str]  # SHA1 digest -> raw string
# Several raw strings may share one short hash, hence the list.
HashedStrings = Dict[str, List[str]]  # short hash -> list of raw strings
VarNames = Dict[str, str]  # raw string -> variable name
31
32
class GlobalStrings:
    """
    Registry of string constants and the variable names assigned to them.

    Strings are registered with Add().  ComputeStableVarNames() then fills in
    var_names, which GetVarName() and WriteConstants() read.
    """

    def __init__(self) -> None:
        # key given to Add() (StrExpr node, or int in tests) -> raw string
        self.all_strings: AllStrings = {}
        # raw string -> variable name, set by ComputeStableVarNames()
        self.var_names: VarNames = {}

        # OLD: nothing in this class reads the attributes below
        self.unique: Dict[bytes, bytes] = {}
        self.int_id_lookup: Dict[Expression, str] = {}
        self.pairs: List[Tuple[str, str]] = []

    def Add(self, key: Union[int, StrExpr], s: str) -> None:
        """
        Record the string s under key, replacing any earlier entry for key.

        key: int for tests
             StrExpr node for production
        """
        self.all_strings[key] = s

    def ComputeStableVarNames(self) -> None:
        """
        Derive var_names from every string registered so far.

        Pipeline: deduplicate by SHA1, group by short hash, then give each
        string in a group its own name.
        """
        by_digest = _MakeUniqueStrings(self.all_strings)
        by_short_hash = _HashAndCollect(by_digest)
        self.var_names = _HandleCollisions(by_short_hash)

    def GetVarName(self, node: StrExpr) -> str:
        """
        Return the variable name for the string registered under node.

        Raises KeyError if node was never passed to Add(), or if var_names has
        no entry for its string.
        """
        # StrExpr -> str -> variable name
        raw_string = self.all_strings[node]
        return self.var_names[raw_string]

    def WriteConstants(self, out_f: TextIO) -> None:
        """
        Write one macro invocation per entry of var_names, then a blank line.

        Each line has the form MACRO(var_name, <json.dumps of the string>);
        where MACRO is GLOBAL_STR2 if util.SMALL_STR is true, else GLOBAL_STR.
        """
        macro_name = 'GLOBAL_STR2' if util.SMALL_STR else 'GLOBAL_STR'

        # Sorted by the string value itself.  The keys are unique, so sorting
        # the (string, name) pairs never compares the names.
        for raw_string, var_name in sorted(self.var_names.items()):
            quoted = json.dumps(raw_string)
            out_f.write('%s(%s, %s);\n' % (macro_name, var_name, quoted))

        out_f.write('\n')
74
75
class Collect(visitor.TypedVisitor):
    """
    Visitor that registers string literals with a GlobalStrings instance.
    """

    def __init__(self, types: Dict[Expression, Type],
                 global_strings: GlobalStrings) -> None:
        visitor.TypedVisitor.__init__(self, types)
        self.global_strings = global_strings

        # Historical note on deduplication: before only unique strings were
        # generated, _gen/bin/oils_for_unix.mycpp.cc went up to "str2824".
        # Afterwards it stopped at "str1789", so it saved over 1000 strings.
        #
        # The C++ compiler should also optimize duplicates, but it's easy for
        # us to generate less source code.
        #
        # NOTE(review): nothing in this class reads the two attributes below.

        # unique string value -> id
        self.unique: Dict[str, str] = {}
        self.unique_id = 0

    def oils_visit_format_expr(self, left: Expression,
                               right: Expression) -> None:
        """
        Visit the operands of a format expression.

        A literal left operand is skipped; the right operand is always
        visited.
        """
        # Do NOT visit a literal left operand, because we write it literally
        if not isinstance(left, StrExpr):
            self.accept(left)
        self.accept(right)

    def visit_str_expr(self, o: StrExpr) -> None:
        """
        Register the decoded value of a string literal, keyed by its node.
        """
        self.global_strings.Add(o, format_strings.DecodeMyPyString(o.value))

    def oils_visit_probe_call(self, o: CallExpr) -> None:
        """
        Do nothing: no constants are generated for DTRACE_PROBE().
        """
        pass

    def oils_visit_log_call(self, fmt: StrExpr,
                            args: List[Expression]) -> None:
        """
        Visit a log call.

        Without args, the format string itself is visited.  With args, only
        the args are visited.
        """
        if not args:
            self.accept(fmt)
            return

        # Don't generate a string constant for the format string, which is an
        # inlined C string, not a mycpp GC string
        for arg in args:
            self.accept(arg)
125
126
127def _MakeUniqueStrings(all_strings: AllStrings) -> UniqueStrings:
128 """
129 Given all the strings, make a smaller set of unique strings.
130 """
131 unique: UniqueStrings = {}
132 for _, raw_string in all_strings.items():
133 b = raw_string.encode('utf-8')
134 h = hashlib.sha1(b).digest()
135 #print(repr(h))
136
137 if h in unique:
138 # extremely unlikely
139 assert unique[h] == raw_string, ("SHA1 hash collision! %r and %r" %
140 (unique[h], b))
141 unique[h] = raw_string
142 return unique
143
144
def _ShortHash15(h: bytes) -> str:
    """
    Given a SHA1, create a 15 bit hash value, encoded as three letters.

    The first two bytes of h are combined, with h[0] as the low byte.  The low
    15 bits are split into three base-(2**5) aka base-32 digits, least
    significant first, and each digit selects a letter from _ALPHABET.  The
    top bit of h[1] is not used.
    """
    low16 = h[0] | (h[1] << 8)

    assert 0 <= low16 < 2**16, low16

    letters = []
    for shift in (0, 5, 10):
        digit = (low16 >> shift) & 0b11111  # one 5-bit digit
        letters.append(_ALPHABET[digit])
    return ''.join(letters)
164
165
def _HashAndCollect(unique: UniqueStrings) -> HashedStrings:
    """
    Group the unique strings by the short hash of their SHA1 digest.

    Returns a defaultdict mapping short hash -> list of strings, in the
    iteration order of unique.
    """
    buckets: HashedStrings = collections.defaultdict(list)
    for digest, raw_string in unique.items():
        buckets[_ShortHash15(digest)].append(raw_string)
    return buckets
175
176
def _SummarizeCollisions(hash15: HashedStrings) -> None:
    """
    Log a histogram of how many strings share each short hash.

    Each row is a count of short hashes (COUNT) and a number of strings per
    short hash (ITEM), most common first.
    """
    if 0:  # debugging aid: dump every short hash with its strings
        for short_hash, strs in hash15.items():
            print(short_hash)
            print(strs)

    # number of strings sharing a short hash -> number of such short hashes
    bucket_sizes: Counter[int] = collections.Counter(
        len(strs) for strs in hash15.values())

    log('%10s %s', 'COUNT', 'ITEM')
    for size, num_hashes in bucket_sizes.most_common():
        log('%10d %s', num_hashes, size)
190
191
192def _HandleCollisions(hash15: HashedStrings) -> VarNames:
193 var_names: VarNames = {}
194 for short_hash, bytes_list in hash15.items():
195 bytes_list.sort() # stable order, will bump some of the strings
196 for i, b in enumerate(bytes_list):
197 if i == 0:
198 var_names[b] = 'S_%s' % short_hash
199 else:
200 var_names[b] = 'S_%s_%d' % (short_hash, i)
201 return var_names
202
203
def HashDemo() -> None:
    """
    Read strings from stdin, one per line, and log short hash statistics.

    Logs the alphabet, the number of unique strings, the number of distinct
    short hashes, and a summary of collisions.
    """
    import sys

    # Each letter of the alphabet encodes 5 bits.
    # Possible tweak: _ALPHABET.replace('l', 'Z') to use a nicer one?
    log('alpha %r', _ALPHABET)

    global_strings = GlobalStrings()
    for line_num, line in enumerate(sys.stdin.readlines()):
        global_strings.Add(line_num, line.strip())

    unique = _MakeUniqueStrings(global_strings.all_strings)
    hash15 = _HashAndCollect(unique)
    var_names = _HandleCollisions(hash15)

    if 0:  # debugging aid: list the names that end in a digit
        for raw_string, var_name in var_names.items():
            if var_name[-1].isdigit():
                log('%r %r', var_name, raw_string)

    log('Unique %d' % len(unique))
    log('hash15 %d' % len(hash15))

    _SummarizeCollisions(hash15)
231
232
# Running this module as a script runs the demo, which reads strings from
# stdin.
if __name__ == '__main__':
    HashDemo()