OILS / builtin / method_str.py View on Github | oils.pub

661 lines, 397 significant
1"""YSH Str methods"""
2
3from __future__ import print_function
4
5from _devbuild.gen.syntax_asdl import loc_t
6from _devbuild.gen.value_asdl import (value, value_e, value_t, eggex_ops,
7 eggex_ops_t, RegexMatch)
8from core import error
9from core import state
10from core import vm
11from frontend import typed_args
12from mycpp import mops
13from mycpp.mylib import log, tagswitch
14from osh import string_ops
15from ysh import expr_eval
16from ysh import regex_translate
17from ysh import val_ops
18
19import libc
20from libc import REG_NOTBOL
21
22from typing import cast, Dict, List, Optional, Tuple
23
24_ = log
25
26
27def _StrMatchStart(s, p):
28 # type: (str, str) -> Tuple[bool, int, int]
29 """Returns the range of bytes in 's' that match string pattern `p`. the
30 pattern matches if 's' starts with all the characters in 'p'.
31
32 The returned match result is the tuple "(matched, begin, end)". 'matched'
33 is true if the pattern matched. 'begin' and 'end' give the half-open range
34 "[begin, end)" of byte indices from 's' for the match, and are a valid but
35 empty range if 'match' is false.
36
37 Used for shell functions like 'trimStart' when trimming a prefix string.
38 """
39 if s.startswith(p):
40 return (True, 0, len(p))
41 else:
42 return (False, 0, 0)
43
44
45def _StrMatchEnd(s, p):
46 # type: (str, str) -> Tuple[bool, int, int]
47 """Returns a match result for the bytes in 's' that match string pattern
48 `p`. the pattern matches if 's' ends with all the characters in 'p'.
49
50 The returned match result is the tuple "(matched, begin, end)". 'matched'
51 is true if the pattern matched. 'begin' and 'end' give the half-open range
52 "[begin, end)" of byte indices from 's' for the match, and are a valid but
53 empty range if 'match' is false.
54
55 Used for shell functions like 'trimEnd' when trimming a suffix string.
56 """
57 len_s = len(s)
58 if s.endswith(p):
59 return (True, len_s - len(p), len_s)
60 else:
61 return (False, len_s, len_s)
62
63
def _EggexMatchCommon(s, p, ere, empty_p):
    # type: (str, value.Eggex, str, int) -> Tuple[bool, int, int]
    """Search 's' with the ERE text 'ere' and report the whole-match range.

    'p' only supplies the regex compile flags (its canonical_flags); the
    pattern text that is actually searched for is 'ere', which the callers
    have already anchored.

    Returns "(True, begin, end)" built from the first two match indices, or
    "(False, empty_p, empty_p)" when the search finds nothing.  Capture
    groups beyond the whole match are ignored.
    """
    indices = libc.regex_search(ere,
                                regex_translate.LibcFlags(p.canonical_flags),
                                s, 0)
    if indices is not None:
        # indices[0] and indices[1] delimit the whole match
        return (True, indices[0], indices[1])

    return (False, empty_p, empty_p)
76
77
def _EggexMatchStart(s, p):
    # type: (str, value.Eggex) -> Tuple[bool, int, int]
    """Returns a match result for the bytes in 's' that match Eggex pattern
    `p` when constrained to match at the start of the string.

    Any capturing done by the Eggex pattern is ignored.

    The returned match result is the tuple "(matched, begin, end)". 'matched'
    is true if the pattern matched. 'begin' and 'end' give the half-open range
    "[begin, end)" of byte indices from 's' for the match, and are the valid
    but empty range "[0, 0)" if 'matched' is false.

    Used for shell functions like 'startsWith' and 'trimStart' when the
    pattern is an Eggex.
    """
    # Wrap the translated pattern in a group before anchoring it.  In an ERE
    # the anchor binds tighter than '|', so a bare '^' prefix would turn
    # 'a|b' into '^a|b', where the 'b' branch can match anywhere in 's'.
    # The extra group is harmless because only the whole-match indices are
    # read, and an already-anchored pattern stays anchored.
    ere = '^(%s)' % regex_translate.AsPosixEre(p)
    return _EggexMatchCommon(s, p, ere, 0)
97
98
def _EggexMatchEnd(s, p):
    # type: (str, value.Eggex) -> Tuple[bool, int, int]
    """Returns a match result for the bytes in 's' that match Eggex pattern
    `p` when constrained to match at the end of the string.

    Any capturing done by the Eggex pattern is ignored.

    The returned match result is the tuple "(matched, begin, end)", with
    "[begin, end)" the half-open byte range of the match.  If 'matched' is
    false the range is the valid but empty "[len(s), len(s))".

    Used for shell functions like 'endsWith' and 'trimEnd' when the pattern
    is an Eggex.
    """
    # Wrap the translated pattern in a group before anchoring it.  In an ERE
    # the anchor binds tighter than '|', so a bare '$' suffix would turn
    # 'a|b' into 'a|b$', where the 'a' branch can match anywhere in 's'.
    # Always grouping also avoids the textual "ends with '$'" test, which
    # cannot tell a real anchor from a trailing escaped '\$'.
    ere = '(%s)$' % regex_translate.AsPosixEre(p)
    return _EggexMatchCommon(s, p, ere, len(s))
108
109
# Bit flags naming the end(s) of a string that an operation applies to.
# HasAffix accepts exactly one of them, Trim also accepts START | END, and
# Find tests the START bit to choose between find() and rfind().
START = 0b01
END = 0b10
112
113
class HasAffix(vm._Callable):
    """Callable behind `startsWith()` and `endsWith()`.

    The 'anchor' given to the constructor (START or END) selects which end
    of the string is tested.
    """

    def __init__(self, anchor):
        # type: (int) -> None
        assert anchor in (START, END), ("Anchor must be START or END")
        self.anchor = anchor

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """
        string => startsWith(pattern_str)   # => bool
        string => startsWith(pattern_eggex) # => bool
        string => endsWith(pattern_str)     # => bool
        string => endsWith(pattern_eggex)   # => bool

        Raises error.TypeErr when the pattern is neither a Str nor an Eggex.
        An error.Strict raised while matching is re-raised as error.Expr.
        """
        subject = rd.PosStr()
        pat = rd.PosValue()

        str_pat = None  # type: Optional[str]
        eggex_pat = None  # type: value.Eggex
        with tagswitch(pat) as case:
            if case(value_e.Str):
                str_pat = cast(value.Str, pat).s
            elif case(value_e.Eggex):
                eggex_pat = cast(value.Eggex, pat)
            else:
                raise error.TypeErr(pat,
                                    'expected pattern to be Eggex or Str',
                                    rd.LeftParenToken())
        rd.Done()

        # Pick the end of the string first, then the kind of pattern.  Only
        # the 'matched' flag of each match result is needed here.
        matched = False
        try:
            if self.anchor & START:
                if str_pat is not None:
                    matched, _, _ = _StrMatchStart(subject, str_pat)
                else:
                    assert eggex_pat is not None
                    matched, _, _ = _EggexMatchStart(subject, eggex_pat)
            else:
                if str_pat is not None:
                    matched, _, _ = _StrMatchEnd(subject, str_pat)
                else:
                    assert eggex_pat is not None
                    matched, _, _ = _EggexMatchEnd(subject, eggex_pat)
        except error.Strict as e:
            raise error.Expr(e.msg, e.location)

        return value.Bool(matched)
163
164
class Trim(vm._Callable):
    """Callable behind `trimStart()`, `trimEnd()`, and `trim()`.

    The 'anchor' given to the constructor is START, END, or START | END and
    selects which end(s) of the string are trimmed.
    """

    def __init__(self, anchor):
        # type: (int) -> None
        assert anchor in (START, END, START
                          | END), ("Anchor must be START, END, or START|END")
        self.anchor = anchor

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """
        string => trimStart()               # => Str
        string => trimEnd()                 # => Str
        string => trim()                    # => Str
        string => trimStart(pattern_str)    # => Str
        string => trimEnd(pattern_str)      # => Str
        string => trim(pattern_str)         # => Str
        string => trimStart(pattern_eggex)  # => Str
        string => trimEnd(pattern_eggex)    # => Str
        string => trim(pattern_eggex)       # => Str

        Without a pattern, whitespace is trimmed.  Raises error.TypeErr when
        a pattern is given that is neither a Str nor an Eggex.  An
        error.Strict raised while matching is re-raised as error.Expr.
        """
        subject = rd.PosStr()
        pat = rd.OptionalValue()

        str_pat = None  # type: Optional[str]
        eggex_pat = None  # type: value.Eggex
        if pat:
            with tagswitch(pat) as case:
                if case(value_e.Str):
                    str_pat = cast(value.Str, pat).s
                elif case(value_e.Eggex):
                    eggex_pat = cast(value.Eggex, pat)
                else:
                    raise error.TypeErr(pat,
                                        'expected pattern to be Eggex or Str',
                                        rd.LeftParenToken())
        rd.Done()

        # [lo, hi) is the slice to keep.  It starts as the whole string and
        # is narrowed from each requested end.
        lo = 0
        hi = len(subject)
        try:
            if self.anchor & START:
                # Keep everything after the end of the leading match
                if str_pat is not None:
                    _, _, lo = _StrMatchStart(subject, str_pat)
                elif eggex_pat is not None:
                    _, _, lo = _EggexMatchStart(subject, eggex_pat)
                else:
                    _, lo = string_ops.StartsWithWhitespaceByteRange(subject)

            if self.anchor & END:
                # Keep everything before the start of the trailing match
                if str_pat is not None:
                    _, hi, _ = _StrMatchEnd(subject, str_pat)
                elif eggex_pat is not None:
                    _, hi, _ = _EggexMatchEnd(subject, eggex_pat)
                else:
                    hi, _ = string_ops.EndsWithWhitespaceByteRange(subject)
        except error.Strict as e:
            raise error.Expr(e.msg, e.location)

        return value.Str(subject[lo:hi])
227
228
class Upper(vm._Callable):
    """Callable that returns its string argument upper-cased."""

    def __init__(self):
        # type: () -> None
        pass

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """Takes one positional Str and returns str.upper() of it as a Str."""
        text = rd.PosStr()
        rd.Done()

        # TODO: unicode support
        uppercased = text.upper()
        return value.Str(uppercased)
243
244
class Lower(vm._Callable):
    """Callable that returns its string argument lower-cased."""

    def __init__(self):
        # type: () -> None
        pass

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """Takes one positional Str and returns str.lower() of it as a Str."""
        text = rd.PosStr()
        rd.Done()

        # TODO: unicode support
        lowercased = text.lower()
        return value.Str(lowercased)
259
260
# Modes for SearchMatch, passed to its constructor as 'which_method'.
SEARCH = 0  # the pattern is searched for as given
LEFT_MATCH = 1  # the pattern is wrapped as '^(...)' before searching
263
264
class SearchMatch(vm._Callable):
    """Regex search over a string, in one of two modes.

    'which_method' is SEARCH for an unanchored search, or LEFT_MATCH to
    anchor the pattern with '^'.
    """

    def __init__(self, which_method):
        # type: (int) -> None
        self.which_method = which_method

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """
        s => search(eggex, pos=0)

        The pattern may be an Eggex or a Str holding an ERE.  Returns
        value.Null when nothing matches, otherwise a RegexMatch.  Raises
        error.TypeErr for any other pattern type.
        """
        string = rd.PosStr()

        pattern = rd.PosValue()  # Eggex or ERE Str
        with tagswitch(pattern) as case:
            if case(value_e.Eggex):
                eggex_val = cast(value.Eggex, pattern)

                # lazily converts to ERE
                ere = regex_translate.AsPosixEre(eggex_val)
                cflags = regex_translate.LibcFlags(eggex_val.canonical_flags)
                capture = eggex_ops.Yes(
                    eggex_val.convert_funcs, eggex_val.convert_toks,
                    eggex_val.capture_names)  # type: eggex_ops_t

            elif case(value_e.Str):
                ere = cast(value.Str, pattern).s
                cflags = 0
                capture = eggex_ops.No

            else:
                # TODO: add method name to this error
                raise error.TypeErr(pattern, 'expected Eggex or Str',
                                    rd.LeftParenToken())

        # It's called 'pos', not 'start' like Python.  Python has 2 kinds of
        # 'start' in its regex API, which can be confusing.
        pos = mops.BigTruncate(rd.NamedInt('pos', 0))
        rd.Done()

        anchored = (self.which_method == LEFT_MATCH)
        if anchored:
            # Make it anchored.  The group keeps '^' applying to the whole
            # pattern rather than only its first alternative.
            ere = '^(%s)' % ere
            eflags = 0  # ^ matches beginning even if pos=5
        elif pos == 0:
            eflags = 0
        else:
            eflags = REG_NOTBOL  # ^ only matches when pos=0

        indices = libc.regex_search(ere, cflags, string, eflags, pos)
        if indices is None:
            return value.Null

        if anchored:
            # undo the ^() transformation: drop the outer whole-match pair so
            # the wrapper group takes its place as group 0
            indices = indices[2:]

        return RegexMatch(string, indices, capture)
324
325
class Contains(vm._Callable):
    """Callable that tests whether one string occurs inside another."""

    def __init__(self):
        # type: () -> None
        pass

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """Takes two positional Str arguments, the string and the substring.

        Returns value.Bool(true) when the substring occurs anywhere in the
        string.
        """
        haystack = rd.PosStr()
        needle = rd.PosStr()
        rd.Done()

        return value.Bool(needle in haystack)
342
class Find(vm._Callable):
    """Callable that locates a substring, from the left or from the right.

    When the START bit is set in 'direction' the search uses str.find();
    otherwise it uses str.rfind().
    """

    def __init__(self, direction):
        # type: (int) -> None
        self.direction = direction

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """Takes two positional Str arguments, the string and the substring.

        Named Int arguments 'start' (default 0) and 'end' (default
        len(string)) bound the search the way they do for str.find().
        Returns the index found as a value.Int, which is -1 when the
        substring does not occur.
        """
        haystack = rd.PosStr()
        needle = rd.PosStr()

        lo = mops.BigTruncate(rd.NamedInt("start", 0))
        hi = mops.BigTruncate(rd.NamedInt("end", len(haystack)))
        rd.Done()

        if self.direction & START:
            index = haystack.find(needle, lo, hi)
        else:
            index = haystack.rfind(needle, lo, hi)

        return value.Int(mops.BigInt(index))
364
class Replace(vm._Callable):
    """Callable behind `replace()`.

    Replaces occurrences of a Str or Eggex pattern with a Str, or with the
    result of evaluating an Expr once per match.
    """

    def __init__(self, mem, expr_ev):
        # type: (state.Mem, expr_eval.ExprEvaluator) -> None
        self.mem = mem
        self.expr_ev = expr_ev

    def EvalSubstExpr(self, expr, blame_loc):
        # type: (value.Expr, loc_t) -> str
        """Evaluate a substitution expression and return its Str result.

        Raises error.TypeErr, blamed on 'blame_loc', when the expression
        evaluates to anything other than a Str.
        """
        res = self.expr_ev.EvalExprClosure(expr, blame_loc)
        if res.tag() == value_e.Str:
            return cast(value.Str, res).s

        raise error.TypeErr(res, "expected expr to eval to a Str", blame_loc)

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """
        s => replace(string_val, subst_str, count=-1)
        s => replace(string_val, subst_expr, count=-1)
        s => replace(eggex_val, subst_str, count=-1)
        s => replace(eggex_val, subst_expr, count=-1)

        For count in [0, MAX_INT], there will be no more than count
        replacements. Any negative count should read as unset, and replace will
        replace all occurrences of the pattern.

        Raises error.TypeErr when the pattern is not an Eggex or Str, or the
        substitution is not a Str or Expr.  With an Eggex pattern, raises
        error.Structured when the string contains a NUL byte or when the
        pattern matches the empty string at the current position.
        """
        string = rd.PosStr()

        string_val = None  # type: value.Str
        eggex_val = None  # type: value.Eggex
        subst_str = None  # type: value.Str
        subst_expr = None  # type: value.Expr

        pattern = rd.PosValue()
        with tagswitch(pattern) as case:
            if case(value_e.Eggex):
                # HACK: mycpp will otherwise generate:
                #  value::Eggex* eggex_val ...
                eggex_val_ = cast(value.Eggex, pattern)
                eggex_val = eggex_val_

            elif case(value_e.Str):
                string_val_ = cast(value.Str, pattern)
                string_val = string_val_

            else:
                raise error.TypeErr(pattern,
                                    'expected pattern to be Eggex or Str',
                                    rd.LeftParenToken())

        subst = rd.PosValue()
        with tagswitch(subst) as case:
            if case(value_e.Str):
                subst_str_ = cast(value.Str, subst)
                subst_str = subst_str_

            elif case(value_e.Expr):
                subst_expr_ = cast(value.Expr, subst)
                subst_expr = subst_expr_

            else:
                raise error.TypeErr(subst,
                                    'expected substitution to be Str or Expr',
                                    rd.LeftParenToken())

        count = mops.BigTruncate(rd.NamedInt("count", -1))
        rd.Done()

        # Zero replacements requested: nothing to do
        if count == 0:
            return value.Str(string)

        if string_val:
            # Literal pattern.  The replacement is the same for every
            # occurrence, so it is computed once.
            if subst_str:
                s = subst_str.s
            if subst_expr:
                # Eval with $0 set to string_val (the matched substring)
                with state.ctx_Eval(self.mem, string_val.s, None, None):
                    s = self.EvalSubstExpr(subst_expr, rd.LeftParenToken())
            assert s is not None

            result = string.replace(string_val.s, s, count)

            return value.Str(result)

        if eggex_val:
            if '\0' in string:
                raise error.Structured(
                    3, "cannot replace by eggex on a string with NUL bytes",
                    rd.LeftParenToken())

            ere = regex_translate.AsPosixEre(eggex_val)
            cflags = regex_translate.LibcFlags(eggex_val.canonical_flags)

            # Walk through the string finding all matches of the compiled ere.
            # Then, collect unmatched substrings and substitutions into the
            # `parts` list.
            pos = 0
            parts = []  # type: List[str]
            replace_count = 0
            while pos < len(string):
                indices = libc.regex_search(ere, cflags, string, 0, pos)
                if indices is None:
                    break

                # Collect captures
                arg0 = None  # type: Optional[str]
                argv = []  # type: List[str]
                named_vars = {}  # type: Dict[str, value_t]
                num_groups = len(indices) / 2
                for group in xrange(num_groups):
                    start = indices[2 * group]
                    end = indices[2 * group + 1]
                    captured = string[start:end]
                    val = value.Str(captured)  # type: value_t

                    # Group 0 is the whole match and has no conversion func.
                    # Entry i of the per-group lists describes group i + 1.
                    if len(eggex_val.convert_funcs) and group != 0:
                        convert_func = eggex_val.convert_funcs[group - 1]
                        convert_tok = eggex_val.convert_toks[group - 1]

                        if convert_func:
                            val = self.expr_ev.CallConvertFunc(
                                convert_func, val, convert_tok,
                                rd.LeftParenToken())

                    # $0, $1, $2 variables are argv values, which must be
                    # strings. Furthermore, they can only be used in string
                    # contexts
                    #   eg. "$[1]" != "$1".
                    val_str = val_ops.Stringify(val, rd.LeftParenToken(), '')
                    if group == 0:
                        arg0 = val_str
                    else:
                        argv.append(val_str)

                    # $0 cannot be named
                    if group != 0:
                        # Use the same 'group - 1' index as convert_funcs
                        # and convert_toks above.  An offset of 2 would read
                        # index -1 for group 1 and shift every later name.
                        # NOTE(review): assumes capture_names holds one entry
                        # per capture group, like convert_funcs -- confirm
                        # against regex_translate.AsPosixEre.
                        name = eggex_val.capture_names[group - 1]
                        if name is not None:
                            named_vars[name] = val

                if subst_str:
                    s = subst_str.s
                if subst_expr:
                    # Re-evaluated per match, with that match's captures
                    with state.ctx_Eval(self.mem, arg0, argv, named_vars):
                        s = self.EvalSubstExpr(subst_expr, rd.LeftParenToken())
                assert s is not None

                start = indices[0]
                end = indices[1]
                # A match that ends where the search began would never
                # advance 'pos'
                if pos == end:
                    raise error.Structured(
                        3, "eggex should never match the empty string",
                        rd.LeftParenToken())

                parts.append(string[pos:start])  # Unmatched substring
                parts.append(s)  # Replacement
                pos = end  # Move to end of match

                replace_count += 1
                if count != -1 and replace_count == count:
                    break

            parts.append(string[pos:])  # Remaining unmatched substring

            return value.Str("".join(parts))

        # Unreachable: the pattern tagswitch sets string_val or eggex_val
        raise AssertionError()
533
534
class Split(vm._Callable):
    """Callable that splits a string on a Str or Eggex separator."""

    def __init__(self):
        # type: () -> None
        pass

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """
        s.split(string_sep, count=-1)
        s.split(eggex_sep, count=-1)

        Count behaves like in replace() in that:
        - `count` < 0 -> ignore
        - `count` >= 0 -> there will be at most `count` splits

        Returns a value.List of value.Str.  An empty input string yields an
        empty list.  Otherwise the text after the last separator is always
        appended, so a trailing separator produces a final empty string.

        Raises error.TypeErr when the separator is neither an Eggex nor a
        Str, and error.Structured when a Str separator is empty, when an
        Eggex separator matches the empty string, or when the input contains
        a NUL byte and the separator is an Eggex.
        """
        string = rd.PosStr()

        string_sep = None  # type: Optional[str]
        eggex_sep = None  # type: value.Eggex

        sep = rd.PosValue()
        with tagswitch(sep) as case:
            if case(value_e.Eggex):
                eggex_sep_ = cast(value.Eggex, sep)
                eggex_sep = eggex_sep_

            elif case(value_e.Str):
                string_sep_ = cast(value.Str, sep)
                string_sep = string_sep_.s

            else:
                raise error.TypeErr(sep,
                                    'expected separator to be Eggex or Str',
                                    rd.LeftParenToken())

        count = mops.BigTruncate(rd.NamedInt("count", -1))
        rd.Done()

        # The empty string splits into no items.  This check runs before the
        # separator is validated.
        if len(string) == 0:
            return value.List([])

        if string_sep is not None:
            # An empty separator would never advance the cursor
            if len(string_sep) == 0:
                raise error.Structured(3, "separator must be non-empty",
                                       rd.LeftParenToken())

            cursor = 0
            chunks = []  # type: List[value_t]
            length = len(string)
            # 'count' is decremented per split.  A negative count never
            # reaches 0, so it places no limit.
            while cursor < length and count != 0:
                next_pos = string.find(string_sep, cursor)
                if next_pos == -1:
                    break

                chunks.append(value.Str(string[cursor:next_pos]))
                cursor = next_pos + len(string_sep)
                count -= 1

            # Remainder after the last separator (may be empty)
            chunks.append(value.Str(string[cursor:]))

            return value.List(chunks)

        if eggex_sep is not None:
            # NOTE(review): presumably rejected because the libc regex API
            # operates on NUL-terminated strings -- confirm.
            if '\0' in string:
                raise error.Structured(
                    3, "cannot split a string with a NUL byte",
                    rd.LeftParenToken())

            regex = regex_translate.AsPosixEre(eggex_sep)
            cflags = regex_translate.LibcFlags(eggex_sep.canonical_flags)

            cursor = 0
            chunks = []
            while cursor < len(string) and count != 0:
                m = libc.regex_search(regex, cflags, string, 0, cursor)
                if m is None:
                    break

                # m[0] and m[1] delimit the whole match
                start = m[0]
                end = m[1]
                if start == end:
                    raise error.Structured(
                        3,
                        "eggex separators should never match the empty string",
                        rd.LeftParenToken())

                chunks.append(value.Str(string[cursor:start]))
                cursor = end

                count -= 1

            # Remainder after the last separator (may be empty)
            chunks.append(value.Str(string[cursor:]))

            return value.List(chunks)

        # Unreachable: the tagswitch sets string_sep or eggex_sep, or raises
        raise AssertionError()
632
633
class Lines(vm._Callable):
    """Callable that splits a string into lines on an end-of-line string."""

    def __init__(self):
        # type: () -> None
        pass

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """Takes one positional Str and an optional named Str 'eol'.

        'eol' defaults to a single newline.  Returns a value.List of
        value.Str, one per line, with the 'eol' delimiters removed.  Text
        after the last delimiter is included only when it is non-empty, so a
        trailing 'eol' does not produce a final empty item.

        Raises error.Structured when 'eol' is the empty string.
        """
        string = rd.PosStr()
        eol = rd.NamedStr('eol', '\n')
        # Reject unexpected extra arguments, as the other Str methods do
        rd.Done()

        # find('', cursor) returns cursor itself, so an empty delimiter would
        # never advance the loop below.
        if len(eol) == 0:
            raise error.Structured(3, "eol must be non-empty",
                                   rd.LeftParenToken())

        # Adapted from Str.split() above, except for the handling of the last item

        cursor = 0
        chunks = []  # type: List[value_t]
        length = len(string)
        while cursor < length:
            next_pos = string.find(eol, cursor)
            if next_pos == -1:
                break

            chunks.append(value.Str(string[cursor:next_pos]))
            cursor = next_pos + len(eol)

        # Unlike split(), drop the empty remainder after a trailing delimiter
        if cursor < length:
            chunks.append(value.Str(string[cursor:]))

        return value.List(chunks)