| 1 | """
|
| 2 | const_pass.py - AST pass that collects string constants.
|
| 3 |
|
| 4 | Instead of emitting a dynamic allocation StrFromC("foo"), we emit a
|
| 5 | GLOBAL_STR(str99, "foo"), and then a reference to str99.
|
| 6 | """
|
| 7 | import collections
|
| 8 | import json
|
| 9 | import hashlib
|
| 10 | import string
|
| 11 |
|
| 12 | from mypy.nodes import (Expression, StrExpr, CallExpr, FuncDef, ClassDef,
|
| 13 | MypyFile)
|
| 14 | from mypy.types import Type
|
| 15 |
|
| 16 | from mycpp import format_strings
|
| 17 | from mycpp import util
|
| 18 | from mycpp.util import log
|
| 19 | from mycpp import visitor
|
| 20 |
|
| 21 | from typing import Dict, List, Tuple, Counter, TextIO, Union, Optional
|
| 22 |
|
# NOTE(review): presumably keeps the 'log' import referenced for linters --
# confirm; 'log' is also called directly further down in this module.
_ = log

# Digit alphabet for the short hash: the first 32 ASCII letters (a-z followed
# by A-F), so that one character encodes exactly 5 bits.
_ALPHABET = string.ascii_lowercase + string.ascii_uppercase
_ALPHABET = _ALPHABET[:32]

AllStrings = Dict[Union[int, StrExpr], str]  # Node -> raw string
UniqueStrings = Dict[bytes, str]  # SHA1 digest -> raw string
HashedStrings = Dict[str, List[str]]  # short hash -> raw strings sharing it
VarNames = Dict[str, str]  # raw string -> variable name

MethodDefinitions = Dict[util.SymbolPath,
                         List[str]]  # Class name -> List of method names

ClassNamespaceDict = Dict[util.SymbolPath, str]  # Class name -> Namespace name
|
| 37 |
|
| 38 |
|
class GlobalStrings:
    """
    Registry of string literals and the variable names assigned to them.

    Usage: Add() every literal, call ComputeStableVarNames() once, then use
    GetVarName() / WriteConstants().
    """

    def __init__(self) -> None:
        # key (StrExpr node, or int in tests) -> raw string
        self.all_strings: AllStrings = {}
        # raw string -> variable name; empty until ComputeStableVarNames()
        self.var_names: VarNames = {}

        # OLD: not read by any method of this class
        self.unique: Dict[bytes, bytes] = {}
        self.int_id_lookup: Dict[Expression, str] = {}
        self.pairs: List[Tuple[str, str]] = []

    def Add(self, key: Union[int, StrExpr], s: str) -> None:
        """
        Record raw string s under key.

        key: int for tests
             StrExpr node for production
        """
        self.all_strings[key] = s

    def ComputeStableVarNames(self) -> None:
        """Derive var_names from everything added so far."""
        # dedupe -> bucket by short hash -> disambiguate within each bucket
        deduped = _MakeUniqueStrings(self.all_strings)
        buckets = _HashAndCollect(deduped)
        self.var_names = _HandleCollisions(buckets)

    def GetVarName(self, node: StrExpr) -> str:
        """
        Return the variable name for a previously added node.

        Raises KeyError if the node was never added, or if its string has no
        entry in var_names.
        """
        # StrExpr -> str -> variable names
        raw_string = self.all_strings[node]
        return self.var_names[raw_string]

    def WriteConstants(self, out_f: TextIO) -> None:
        """
        Write one 'MACRO(var_name, "literal");' line per unique string,
        followed by a blank line.
        """
        macro_name = 'GLOBAL_STR2' if util.SMALL_STR else 'GLOBAL_STR'

        # Keys are unique, so sorting the items sorts by the string value
        # itself.
        for raw_string, var_name in sorted(self.var_names.items()):
            literal = json.dumps(raw_string)
            out_f.write('%s(%s, %s);\n' % (macro_name, var_name, literal))

        out_f.write('\n')
|
| 80 |
|
| 81 |
|
class MethodDefs:
    """Remembers which method names were defined on which class."""

    def __init__(self) -> None:
        # class name -> method names, in the order they were added
        self.method_defs: MethodDefinitions = {}

    def Add(self, class_name: util.SymbolPath, method_name: str) -> None:
        """Append method_name to the list kept for class_name."""
        self.method_defs.setdefault(class_name, []).append(method_name)

    def ClassHasMethod(self, class_name: util.SymbolPath,
                       method_name: str) -> bool:
        """Return True iff method_name was added for class_name."""
        # An unknown class behaves like a class with no methods.
        return method_name in self.method_defs.get(class_name, ())
|
| 97 |
|
| 98 |
|
class ClassNamespaces:
    """Maps each class name to the namespace name recorded for it."""

    def __init__(self) -> None:
        # class name -> namespace name
        self.class_namespaces: ClassNamespaceDict = {}

    def Set(self, class_name: util.SymbolPath, namespace_name: str) -> None:
        """Record namespace_name for class_name, replacing any earlier value."""
        self.class_namespaces[class_name] = namespace_name

    def GetClassNamespace(self, class_name: util.SymbolPath) -> str:
        """
        Return the namespace recorded for class_name.

        Raises KeyError if Set() was never called for class_name.
        """
        return self.class_namespaces[class_name]
|
| 109 |
|
| 110 |
|
class Collect(visitor.TypedVisitor):
    """
    Visitor pass that records string literals in global_strings, the method
    names of each class in method_defs, and the file name of each class in
    class_namespaces.  Format strings are validated along the way.
    """

    def __init__(self, types: Dict[Expression, Type],
                 global_strings: GlobalStrings, method_defs: MethodDefs,
                 class_namespaces: ClassNamespaces) -> None:
        visitor.TypedVisitor.__init__(self, types)

        # Output collections, owned by the caller
        self.global_strings = global_strings
        self.method_defs = method_defs
        self.class_namespaces = class_namespaces

        # Name of the file being visited; set in oils_visit_mypy_file()
        self.current_file_name: Optional[str] = None

        # Only generate unique strings.
        # Before this optimization, _gen/bin/oils_for_unix.mycpp.cc went up to:
        #     "str2824"
        # After:
        #     "str1789"
        #
        # So it saved over 1000 strings.
        #
        # The C++ compiler should also optimize it, but it's easy for us to
        # generate less source code.

        # unique string value -> id
        # (not read by any method of this class)
        self.unique: Dict[str, str] = {}
        self.unique_id = 0

    def verify_format_string(self, fmt: StrExpr) -> None:
        """Report an error on fmt if format_strings.Parse() rejects it."""
        try:
            format_strings.Parse(fmt.value)
        except RuntimeError as e:
            self.report_error(fmt, str(e))

    def oils_visit_format_expr(self, left: Expression,
                               right: Expression) -> None:
        """Handle a format expression: check a literal left side, visit the rest."""
        if not isinstance(left, StrExpr):
            self.accept(left)
        else:
            # Do NOT visit the left, because we write it literally
            self.verify_format_string(left)

        self.accept(right)

    def visit_str_expr(self, o: StrExpr) -> None:
        """Register the decoded value of a string literal."""
        self.global_strings.Add(o, format_strings.DecodeMyPyString(o.value))

    def oils_visit_probe_call(self, o: CallExpr) -> None:
        """Deliberately a no-op."""
        # Don't generate constants for DTRACE_PROBE()
        pass

    def oils_visit_log_call(self, fmt: StrExpr,
                            args: List[Expression]) -> None:
        """Check the log format string, then visit what needs constants."""
        self.verify_format_string(fmt)

        if not args:
            # With no arguments, the format string itself is visited
            self.accept(fmt)
            return

        # Don't generate a string constant for the format string, which is an
        # inlined C string, not a mycpp GC string
        for arg in args:
            self.accept(arg)

    def oils_visit_class_def(
            self, o: ClassDef, base_class_sym: Optional[util.SymbolPath],
            current_class_name: Optional[util.SymbolPath]) -> None:
        """Record the class's methods and file name, then visit its body."""
        method_names = [
            stmt.name for stmt in o.defs.body if isinstance(stmt, FuncDef)
        ]
        for method_name in method_names:
            self.method_defs.Add(current_class_name, method_name)

        self.class_namespaces.Set(current_class_name,
                                  self.current_file_name)
        super().oils_visit_class_def(o, base_class_sym, current_class_name)

    def oils_visit_mypy_file(self, o: MypyFile) -> None:
        """Remember the file's name before visiting its contents."""
        self.current_file_name = o.name
        super().oils_visit_mypy_file(o)
|
| 189 |
|
| 190 |
|
| 191 | def _MakeUniqueStrings(all_strings: AllStrings) -> UniqueStrings:
|
| 192 | """
|
| 193 | Given all the strings, make a smaller set of unique strings.
|
| 194 | """
|
| 195 | unique: UniqueStrings = {}
|
| 196 | for _, raw_string in all_strings.items():
|
| 197 | b = raw_string.encode('utf-8')
|
| 198 | h = hashlib.sha1(b).digest()
|
| 199 | #print(repr(h))
|
| 200 |
|
| 201 | if h in unique:
|
| 202 | # extremely unlikely
|
| 203 | assert unique[h] == raw_string, ("SHA1 hash collision! %r and %r" %
|
| 204 | (unique[h], b))
|
| 205 | unique[h] = raw_string
|
| 206 | return unique
|
| 207 |
|
| 208 |
|
def _ShortHash15(h: bytes) -> str:
    """
    Given a SHA1, create a 15 bit hash value.

    The first two bytes of h are combined into a 16-bit number (h[0] is the
    low byte).  Its low 15 bits are emitted as three base-32 digits from
    _ALPHABET, least significant digit first.
    """
    low16 = (h[1] << 8) | h[0]

    assert 0 <= low16 < 2**16, low16

    # Three 5-bit groups, starting from the least significant bits
    digits = [(low16 >> shift) & 0b11111 for shift in (0, 5, 10)]

    return ''.join(_ALPHABET[d] for d in digits)
|
| 228 |
|
| 229 |
|
def _HashAndCollect(unique: UniqueStrings) -> HashedStrings:
    """
    Group the unique strings by their short hash.

    Returns a defaultdict mapping short hash -> list of raw strings, in the
    iteration order of 'unique'.
    """
    buckets = collections.defaultdict(list)
    for digest, raw_string in unique.items():
        buckets[_ShortHash15(digest)].append(raw_string)
    return buckets
|
| 239 |
|
| 240 |
|
def _SummarizeCollisions(hash15: HashedStrings) -> None:
    """
    Log a histogram of bucket sizes: for each size, how many short hashes
    have that many strings.  Rows are ordered by most_common().
    """
    # bucket size -> number of short hashes with that many strings
    size_counts: Counter[int] = collections.Counter(
        len(strs) for strs in hash15.values())

    log('%10s %s', 'COUNT', 'ITEM')
    for bucket_size, num_hashes in size_counts.most_common():
        log('%10d %s', num_hashes, bucket_size)
|
| 254 |
|
| 255 |
|
| 256 | def _HandleCollisions(hash15: HashedStrings) -> VarNames:
|
| 257 | var_names: VarNames = {}
|
| 258 | for short_hash, bytes_list in hash15.items():
|
| 259 | bytes_list.sort() # stable order, will bump some of the strings
|
| 260 | for i, b in enumerate(bytes_list):
|
| 261 | if i == 0:
|
| 262 | var_names[b] = 'S_%s' % short_hash
|
| 263 | else:
|
| 264 | var_names[b] = 'S_%s_%d' % (short_hash, i)
|
| 265 | return var_names
|
| 266 |
|
| 267 |
|
def HashDemo() -> None:
    """
    Demo: read lines from stdin, treat each stripped line as a string
    constant, and log statistics about the resulting short hashes.
    """
    import sys

    # 5 bits
    log('alpha %r', _ALPHABET)

    global_strings = GlobalStrings()
    for line_num, line in enumerate(sys.stdin.readlines()):
        global_strings.Add(line_num, line.strip())

    unique = _MakeUniqueStrings(global_strings.all_strings)
    hash15 = _HashAndCollect(unique)
    # The names themselves are not printed; the call is kept so the demo
    # still runs the whole pipeline.
    _HandleCollisions(hash15)

    log('Unique %d' % len(unique))
    log('hash15 %d' % len(hash15))

    _SummarizeCollisions(hash15)
|
| 295 |
|
| 296 |
|
# When run as a script, run the hashing demo, which reads strings from stdin.
if __name__ == '__main__':
    HashDemo()
|