OILS / builtin / method_str.py View on Github | oilshell.org

588 lines, 354 significant
1"""YSH Str methods"""
2
3from __future__ import print_function
4
5from _devbuild.gen.syntax_asdl import loc_t
6from _devbuild.gen.value_asdl import (value, value_e, value_t, eggex_ops,
7 eggex_ops_t, RegexMatch)
8from core import error
9from core import state
10from core import vm
11from frontend import typed_args
12from mycpp import mops
13from mycpp.mylib import log, tagswitch
14from osh import string_ops
15from ysh import expr_eval
16from ysh import regex_translate
17from ysh import val_ops
18
19import libc
20from libc import REG_NOTBOL
21
22from typing import cast, Dict, List, Tuple
23
# 'log' is not otherwise referenced in this module; binding it here keeps
# the import in use (presumably so it stays available for debugging).
_ = log
25
26
27def _StrMatchStart(s, p):
28 # type: (str, str) -> Tuple[bool, int, int]
29 """Returns the range of bytes in 's' that match string pattern `p`. the
30 pattern matches if 's' starts with all the characters in 'p'.
31
32 The returned match result is the tuple "(matched, begin, end)". 'matched'
33 is true if the pattern matched. 'begin' and 'end' give the half-open range
34 "[begin, end)" of byte indices from 's' for the match, and are a valid but
35 empty range if 'match' is false.
36
37 Used for shell functions like 'trimStart' when trimming a prefix string.
38 """
39 if s.startswith(p):
40 return (True, 0, len(p))
41 else:
42 return (False, 0, 0)
43
44
45def _StrMatchEnd(s, p):
46 # type: (str, str) -> Tuple[bool, int, int]
47 """Returns a match result for the bytes in 's' that match string pattern
48 `p`. the pattern matches if 's' ends with all the characters in 'p'.
49
50 The returned match result is the tuple "(matched, begin, end)". 'matched'
51 is true if the pattern matched. 'begin' and 'end' give the half-open range
52 "[begin, end)" of byte indices from 's' for the match, and are a valid but
53 empty range if 'match' is false.
54
55 Used for shell functions like 'trimEnd' when trimming a suffix string.
56 """
57 len_s = len(s)
58 if s.endswith(p):
59 return (True, len_s - len(p), len_s)
60 else:
61 return (False, len_s, len_s)
62
63
def _EggexMatchCommon(s, p, ere, empty_p):
    # type: (str, value.Eggex, str, int) -> Tuple[bool, int, int]
    """Search 's' for the POSIX ERE 'ere', compiled with the flags of 'p'.

    Returns (matched, begin, end), where [begin, end) is the span of the
    whole match; capture groups are ignored. When nothing matches, returns
    (False, empty_p, empty_p), so the caller decides where the empty range
    sits.
    """
    flags = regex_translate.LibcFlags(p.canonical_flags)
    span = libc.regex_search(ere, flags, s, 0)
    if span is not None:
        return (True, span[0], span[1])
    return (False, empty_p, empty_p)
76
77
def _EggexMatchStart(s, p):
    # type: (str, value.Eggex) -> Tuple[bool, int, int]
    """Match Eggex 'p' against the beginning of 's'.

    The translated ERE is prefixed with '^' (unless it already starts with
    one), so only matches that begin at byte 0 count. Any capturing done by
    the pattern is ignored.

    Returns the tuple (matched, begin, end). On success, [begin, end) is the
    matched prefix. On failure, 'matched' is False and the range is the valid
    but empty range [0, 0).

    Used by `startsWith()` and `trimStart()` with Eggex patterns.
    """
    pattern = regex_translate.AsPosixEre(p)
    # NOTE(review): in an ERE, '^' binds tighter than a top-level '|'. If
    # AsPosixEre can emit an unparenthesized alternation like 'a|b', only
    # the first alternative is anchored here -- confirm against
    # regex_translate.
    anchored = pattern if pattern.startswith('^') else '^' + pattern
    return _EggexMatchCommon(s, p, anchored, 0)
97
98
def _EggexMatchEnd(s, p):
    # type: (str, value.Eggex) -> Tuple[bool, int, int]
    """Match Eggex 'p' against the end of 's'.

    The translated ERE is suffixed with '$' unless it already ends with an
    unescaped '$' anchor, so only matches that end at len(s) count. Any
    capturing done by the pattern is ignored.

    Returns the tuple (matched, begin, end). On success, [begin, end) is the
    matched suffix. On failure, 'matched' is False and the range is the valid
    but empty range [len(s), len(s)).

    Used by `endsWith()` and `trimEnd()` with Eggex patterns.
    """
    ere = regex_translate.AsPosixEre(p)

    # A trailing '$' is only an anchor if it is not escaped. '\$' matches a
    # literal dollar sign anywhere, so such a pattern still needs an anchor.
    # Count the backslashes directly before the final '$'; an odd count
    # means the '$' itself is escaped.
    has_anchor = False
    if ere.endswith('$'):
        i = len(ere) - 2
        num_backslashes = 0
        while i >= 0 and ere[i] == '\\':
            num_backslashes += 1
            i -= 1
        has_anchor = num_backslashes % 2 == 0

    if not has_anchor:
        ere = ere + '$'
    return _EggexMatchCommon(s, p, ere, len(s))
108
109
# Bit flags naming the end(s) of a string that HasAffix and Trim operate on.
# Trim additionally accepts START | END to mean both ends.
START = 1  # 0b01
END = 2  # 0b10
112
113
class HasAffix(vm._Callable):
    """Implements `startsWith()` and `endsWith()`.

    The 'anchor' given at construction (START or END) selects which end of
    the receiver string is tested. The pattern may be a Str (literal prefix
    or suffix) or an Eggex.
    """

    def __init__(self, anchor):
        # type: (int) -> None
        assert anchor in (START, END), ("Anchor must be START or END")
        self.anchor = anchor

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """
        string => startsWith(pattern_str)    # => bool
        string => startsWith(pattern_eggex)  # => bool
        string => endsWith(pattern_str)      # => bool
        string => endsWith(pattern_eggex)    # => bool
        """
        subject = rd.PosStr()
        pattern = rd.PosValue()

        literal = None  # type: str
        eggex = None  # type: value.Eggex
        with tagswitch(pattern) as case:
            if case(value_e.Eggex):
                eggex = cast(value.Eggex, pattern)
            elif case(value_e.Str):
                literal = cast(value.Str, pattern).s
            else:
                raise error.TypeErr(pattern,
                                    'expected pattern to be Eggex or Str',
                                    rd.LeftParenToken())
        rd.Done()

        at_start = (self.anchor & START) != 0
        found = False
        try:
            if literal is None:
                assert eggex is not None
                if at_start:
                    found, _, _ = _EggexMatchStart(subject, eggex)
                else:
                    found, _, _ = _EggexMatchEnd(subject, eggex)
            elif at_start:
                found, _, _ = _StrMatchStart(subject, literal)
            else:
                found, _, _ = _StrMatchEnd(subject, literal)
        except error.Strict as e:
            # Surface regex/strictness failures as expression errors.
            raise error.Expr(e.msg, e.location)

        return value.Bool(found)
163
164
class Trim(vm._Callable):
    """Implements `trimStart()`, `trimEnd()`, and `trim()`.

    'anchor' is START, END, or START | END and selects which side(s) of the
    receiver are trimmed. With no argument, the whitespace byte ranges
    reported by string_ops are removed. A Str argument is removed once, only
    where it appears literally as a prefix/suffix. An Eggex argument removes
    whatever the pattern matches at that side.
    """

    def __init__(self, anchor):
        # type: (int) -> None
        assert anchor in (START, END, START | END), (
            "Anchor must be START, END, or START|END")
        self.anchor = anchor

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """
        string => trimStart()               # => Str
        string => trimEnd()                 # => Str
        string => trim()                    # => Str
        string => trimStart(pattern_str)    # => Str
        string => trimEnd(pattern_str)      # => Str
        string => trim(pattern_str)         # => Str
        string => trimStart(pattern_eggex)  # => Str
        string => trimEnd(pattern_eggex)    # => Str
        string => trim(pattern_eggex)       # => Str
        """
        subject = rd.PosStr()
        pattern = rd.OptionalValue()

        literal = None  # type: str
        eggex = None  # type: value.Eggex
        if pattern:
            with tagswitch(pattern) as case:
                if case(value_e.Eggex):
                    eggex = cast(value.Eggex, pattern)
                elif case(value_e.Str):
                    literal = cast(value.Str, pattern).s
                else:
                    raise error.TypeErr(pattern,
                                        'expected pattern to be Eggex or Str',
                                        rd.LeftParenToken())
        rd.Done()

        trim_start = (self.anchor & START) != 0
        trim_end = (self.anchor & END) != 0

        # The result is subject[lo:hi]; each trimmed side narrows the range.
        lo = 0
        hi = len(subject)
        try:
            if literal is not None:
                if trim_start:
                    _, _, lo = _StrMatchStart(subject, literal)
                if trim_end:
                    _, hi, _ = _StrMatchEnd(subject, literal)
            elif eggex is not None:
                if trim_start:
                    _, _, lo = _EggexMatchStart(subject, eggex)
                if trim_end:
                    _, hi, _ = _EggexMatchEnd(subject, eggex)
            else:
                if trim_start:
                    _, lo = string_ops.StartsWithWhitespaceByteRange(subject)
                if trim_end:
                    hi, _ = string_ops.EndsWithWhitespaceByteRange(subject)
        except error.Strict as e:
            # Surface regex/strictness failures as expression errors.
            raise error.Expr(e.msg, e.location)

        return value.Str(subject[lo:hi])
227
228
class Upper(vm._Callable):
    """Str method that returns an upper-cased copy of its receiver."""

    def __init__(self):
        # type: () -> None
        pass

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        receiver = rd.PosStr()
        rd.Done()

        # TODO: unicode support -- this relies on the host str.upper().
        upper = receiver.upper()
        return value.Str(upper)
243
244
class Lower(vm._Callable):
    """Str method that returns a lower-cased copy of its receiver."""

    def __init__(self):
        # type: () -> None
        pass

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        receiver = rd.PosStr()
        rd.Done()

        # TODO: unicode support -- this relies on the host str.lower().
        lower = receiver.lower()
        return value.Str(lower)
259
260
# Values for SearchMatch's 'which_method': SEARCH looks for a match anywhere
# at or after 'pos'; LEFT_MATCH requires the match to begin exactly at 'pos'.
SEARCH = 0
LEFT_MATCH = 1
263
264
class SearchMatch(vm._Callable):
    """Regex search on a Str, returning a match object or null.

    'which_method' is SEARCH for an unanchored search, or LEFT_MATCH to
    require the match to begin exactly at 'pos'.
    """

    def __init__(self, which_method):
        # type: (int) -> None
        self.which_method = which_method

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """
        s => search(eggex, pos=0)

        The pattern may also be a Str holding a POSIX ERE; then no capture
        metadata (names, conversion funcs) is attached to the match.
        """
        subject = rd.PosStr()

        pattern = rd.PosValue()  # Eggex or ERE Str
        with tagswitch(pattern) as case:
            if case(value_e.Eggex):
                eggex_val = cast(value.Eggex, pattern)

                # lazily converts to ERE
                ere = regex_translate.AsPosixEre(eggex_val)
                cflags = regex_translate.LibcFlags(eggex_val.canonical_flags)
                capture = eggex_ops.Yes(
                    eggex_val.convert_funcs, eggex_val.convert_toks,
                    eggex_val.capture_names)  # type: eggex_ops_t

            elif case(value_e.Str):
                ere = cast(value.Str, pattern).s
                cflags = 0
                capture = eggex_ops.No

            else:
                # TODO: add method name to this error
                raise error.TypeErr(pattern, 'expected Eggex or Str',
                                    rd.LeftParenToken())

        # It's called 'pos', not 'start' like Python. Python has 2 kinds of
        # 'start' in its regex API, which can be confusing.
        pos = mops.BigTruncate(rd.NamedInt('pos', 0))
        rd.Done()

        if self.which_method == LEFT_MATCH:
            # Anchor the pattern. With eflags=0, '^' matches at 'pos' even
            # when pos > 0, which is exactly what a left match needs.
            if not ere.startswith('^'):
                ere = '^' + ere
            eflags = 0
        elif pos == 0:
            eflags = 0
        else:
            # Searching from the middle: 'pos' is not the start of the
            # string, so '^' must not match there.
            eflags = REG_NOTBOL

        indices = libc.regex_search(ere, cflags, subject, eflags, pos)
        if indices is None:
            return value.Null
        return RegexMatch(subject, indices, capture)
320
321
class Replace(vm._Callable):
    """Implements `replace()` on Str.

    The pattern is either a Str (replaced literally) or an Eggex. The
    substitution is either a Str or an Expr evaluated for the match.
    """

    def __init__(self, mem, expr_ev):
        # type: (state.Mem, expr_eval.ExprEvaluator) -> None
        self.mem = mem
        self.expr_ev = expr_ev

    def EvalSubstExpr(self, expr, blame_loc):
        # type: (value.Expr, loc_t) -> str
        """Evaluate the substitution 'expr' and return its Str value.

        Raises error.TypeErr, blamed on 'blame_loc', if the result is not a
        Str.
        """
        res = self.expr_ev.EvalExpr(expr.e, blame_loc)
        if res.tag() == value_e.Str:
            return cast(value.Str, res).s

        raise error.TypeErr(res, "expected expr to eval to a Str", blame_loc)

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """
        s => replace(string_val, subst_str, count=-1)
        s => replace(string_val, subst_expr, count=-1)
        s => replace(eggex_val, subst_str, count=-1)
        s => replace(eggex_val, subst_expr, count=-1)

        For count in [0, MAX_INT], there will be no more than count
        replacements. Any negative count should read as unset, and replace will
        replace all occurrences of the pattern.

        With a Str pattern, subst_expr is evaluated once, with $0 set to the
        pattern string. With an Eggex pattern, it is evaluated once per match:
        $0 is the whole match, $1, $2, ... are the stringified capture groups,
        and named captures are passed as named variables (holding the value
        after any conversion function).

        Errors: an Eggex match of zero length, or an Eggex pattern applied to
        a string containing NUL bytes.
        """
        string = rd.PosStr()

        string_val = None  # type: value.Str
        eggex_val = None  # type: value.Eggex
        subst_str = None  # type: value.Str
        subst_expr = None  # type: value.Expr

        pattern = rd.PosValue()
        with tagswitch(pattern) as case:
            if case(value_e.Eggex):
                # HACK: mycpp will otherwise generate:
                #  value::Eggex* eggex_val ...
                eggex_val_ = cast(value.Eggex, pattern)
                eggex_val = eggex_val_

            elif case(value_e.Str):
                string_val_ = cast(value.Str, pattern)
                string_val = string_val_

            else:
                raise error.TypeErr(pattern,
                                    'expected pattern to be Eggex or Str',
                                    rd.LeftParenToken())

        subst = rd.PosValue()
        with tagswitch(subst) as case:
            if case(value_e.Str):
                subst_str_ = cast(value.Str, subst)
                subst_str = subst_str_

            elif case(value_e.Expr):
                subst_expr_ = cast(value.Expr, subst)
                subst_expr = subst_expr_

            else:
                raise error.TypeErr(subst,
                                    'expected substitution to be Str or Expr',
                                    rd.LeftParenToken())

        count = mops.BigTruncate(rd.NamedInt("count", -1))
        rd.Done()

        if count == 0:
            return value.Str(string)

        if string_val:
            if subst_str:
                s = subst_str.s
            if subst_expr:
                # Eval with $0 set to string_val (the matched substring)
                with state.ctx_Eval(self.mem, string_val.s, None, None):
                    s = self.EvalSubstExpr(subst_expr, rd.LeftParenToken())
                assert s is not None

            result = string.replace(string_val.s, s, count)

            return value.Str(result)

        if eggex_val:
            if '\0' in string:
                raise error.Structured(
                    3,
                    "cannot replace by eggex on a string with NUL bytes",
                    rd.LeftParenToken())

            ere = regex_translate.AsPosixEre(eggex_val)
            cflags = regex_translate.LibcFlags(eggex_val.canonical_flags)

            # Walk through the string finding all matches of the compiled ere.
            # Then, collect unmatched substrings and substitutions into the
            # `parts` list.
            pos = 0
            parts = []  # type: List[str]
            replace_count = 0
            while pos < len(string):
                # As in SearchMatch, only the true start of the string may
                # match '^'. Without REG_NOTBOL an anchored eggex would match
                # again at every later 'pos'.
                eflags = 0 if pos == 0 else REG_NOTBOL
                indices = libc.regex_search(ere, cflags, string, eflags, pos)
                if indices is None:
                    break

                start = indices[0]
                end = indices[1]
                # Reject empty matches before evaluating anything for them.
                # Testing start == end (not just pos == end) also catches
                # empty matches that occur after 'pos', e.g. at the end of
                # the string.
                if start == end:
                    raise error.Structured(
                        3,
                        "eggex should never match the empty string",
                        rd.LeftParenToken())

                # Collect captures
                arg0 = None  # type: str
                argv = []  # type: List[str]
                named_vars = {}  # type: Dict[str, value_t]
                num_groups = len(indices) / 2
                for group in xrange(num_groups):
                    g_start = indices[2 * group]
                    g_end = indices[2 * group + 1]
                    captured = string[g_start:g_end]
                    val = value.Str(captured)  # type: value_t

                    # Per-group metadata lists start at group 1 (index 0);
                    # group 0 is the whole match.
                    if len(eggex_val.convert_funcs) and group != 0:
                        convert_func = eggex_val.convert_funcs[group - 1]
                        convert_tok = eggex_val.convert_toks[group - 1]

                        if convert_func:
                            val = self.expr_ev.CallConvertFunc(
                                convert_func, val, convert_tok,
                                rd.LeftParenToken())

                    # $0, $1, $2 variables are argv values, which must be
                    # strings. Furthermore, they can only be used in string
                    # contexts
                    # eg. "$[1]" != "$1".
                    val_str = val_ops.Stringify(val, rd.LeftParenToken())
                    if group == 0:
                        arg0 = val_str
                    else:
                        argv.append(val_str)

                        # $0 cannot be named. capture_names is passed
                        # alongside convert_funcs (see SearchMatch), so it
                        # uses the same group - 1 index; group - 2 would map
                        # group 1 to the last name.
                        name = eggex_val.capture_names[group - 1]
                        if name is not None:
                            named_vars[name] = val

                if subst_str:
                    s = subst_str.s
                if subst_expr:
                    with state.ctx_Eval(self.mem, arg0, argv, named_vars):
                        s = self.EvalSubstExpr(subst_expr,
                                               rd.LeftParenToken())
                    assert s is not None

                parts.append(string[pos:start])  # Unmatched substring
                parts.append(s)  # Replacement
                pos = end  # Move to end of match

                replace_count += 1
                # count == 0 returned early; any negative count means "all".
                if count > 0 and replace_count == count:
                    break

            parts.append(string[pos:])  # Remaining unmatched substring

            return value.Str("".join(parts))

        raise AssertionError()
492
493
class Split(vm._Callable):
    """Implements `split()` on Str, with a Str or Eggex separator."""

    def __init__(self):
        # type: () -> None
        pass

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """
        s.split(string_sep, count=-1)
        s.split(eggex_sep, count=-1)

        Count behaves like in replace() in that:
          - `count` < 0 -> ignore
          - `count` >= 0 -> there will be at most `count` splits

        An empty string splits into an empty List. An empty Str separator, an
        Eggex separator match of zero length, and an Eggex separator applied
        to a string containing NUL bytes are all errors.
        """
        string = rd.PosStr()

        string_sep = None  # type: str
        eggex_sep = None  # type: value.Eggex

        sep = rd.PosValue()
        with tagswitch(sep) as case:
            if case(value_e.Eggex):
                eggex_sep_ = cast(value.Eggex, sep)
                eggex_sep = eggex_sep_

            elif case(value_e.Str):
                string_sep_ = cast(value.Str, sep)
                string_sep = string_sep_.s

            else:
                raise error.TypeErr(sep, 'expected separator to be Eggex or Str',
                                    rd.LeftParenToken())

        count = mops.BigTruncate(rd.NamedInt("count", -1))
        rd.Done()

        if len(string) == 0:
            return value.List([])

        if string_sep is not None:
            if len(string_sep) == 0:
                raise error.Structured(3, "separator must be non-empty",
                                       rd.LeftParenToken())

            cursor = 0
            chunks = []  # type: List[value_t]
            while cursor < len(string) and count != 0:
                idx = string.find(string_sep, cursor)
                if idx == -1:
                    break

                chunks.append(value.Str(string[cursor:idx]))
                cursor = idx + len(string_sep)
                count -= 1

            chunks.append(value.Str(string[cursor:]))

            return value.List(chunks)

        if eggex_sep is not None:
            if '\0' in string:
                raise error.Structured(
                    3, "cannot split a string with a NUL byte",
                    rd.LeftParenToken())

            regex = regex_translate.AsPosixEre(eggex_sep)
            cflags = regex_translate.LibcFlags(eggex_sep.canonical_flags)

            cursor = 0
            chunks = []
            while cursor < len(string) and count != 0:
                # As in SearchMatch, only the true start of the string may
                # match '^'. Without REG_NOTBOL an anchored separator would
                # match again at every later 'cursor'.
                eflags = 0 if cursor == 0 else REG_NOTBOL
                m = libc.regex_search(regex, cflags, string, eflags, cursor)
                if m is None:
                    break

                start = m[0]
                end = m[1]
                if start == end:
                    raise error.Structured(
                        3,
                        "eggex separators should never match the empty string",
                        rd.LeftParenToken())

                chunks.append(value.Str(string[cursor:start]))
                cursor = end

                count -= 1

            chunks.append(value.Str(string[cursor:]))

            return value.List(chunks)

        raise AssertionError()