OILS / builtin / method_str.py View on Github | oils.pub

671 lines, 405 significant
1"""YSH Str methods"""
2
3from __future__ import print_function
4
5from _devbuild.gen.syntax_asdl import loc_t
6from _devbuild.gen.value_asdl import (value, value_e, value_t, eggex_ops,
7 eggex_ops_t, RegexMatch)
8from core import error
9from core import state
10from core import vm
11from frontend import typed_args
12from mycpp import mops
13from mycpp.mylib import log, tagswitch
14from osh import string_ops
15from ysh import expr_eval
16from ysh import regex_translate
17from ysh import val_ops
18
19import libc
20from libc import REG_NOTBOL
21
22from typing import cast, Dict, List, Optional, Tuple
23
# Bind the imported `log` helper to a throwaway name.  Presumably this keeps
# the debug import from being reported as unused -- TODO confirm.
_ = log
25
26
27def _StrMatchStart(s, p):
28 # type: (str, str) -> Tuple[bool, int, int]
29 """Returns the range of bytes in 's' that match string pattern `p`. the
30 pattern matches if 's' starts with all the characters in 'p'.
31
32 The returned match result is the tuple "(matched, begin, end)". 'matched'
33 is true if the pattern matched. 'begin' and 'end' give the half-open range
34 "[begin, end)" of byte indices from 's' for the match, and are a valid but
35 empty range if 'match' is false.
36
37 Used for shell functions like 'trimStart' when trimming a prefix string.
38 """
39 if s.startswith(p):
40 return (True, 0, len(p))
41 else:
42 return (False, 0, 0)
43
44
45def _StrMatchEnd(s, p):
46 # type: (str, str) -> Tuple[bool, int, int]
47 """Returns a match result for the bytes in 's' that match string pattern
48 `p`. the pattern matches if 's' ends with all the characters in 'p'.
49
50 The returned match result is the tuple "(matched, begin, end)". 'matched'
51 is true if the pattern matched. 'begin' and 'end' give the half-open range
52 "[begin, end)" of byte indices from 's' for the match, and are a valid but
53 empty range if 'match' is false.
54
55 Used for shell functions like 'trimEnd' when trimming a suffix string.
56 """
57 len_s = len(s)
58 if s.endswith(p):
59 return (True, len_s - len(p), len_s)
60 else:
61 return (False, len_s, len_s)
62
63
def _EggexMatchCommon(s, p, ere, empty_p):
    # type: (str, value.Eggex, str, int) -> Tuple[bool, int, int]
    """Search 's' with the ERE string 'ere' and report the whole-match range.

    The regex compile flags are derived from the Eggex 'p'; no exec flags are
    passed.  Only the first index pair of the search result (the overall
    match) is used, so capture groups in 'ere' are ignored.

    Returns "(True, begin, end)" on a match, otherwise
    "(False, empty_p, empty_p)", i.e. an empty range located at 'empty_p'.
    """
    match_indices = libc.regex_search(
        ere, regex_translate.LibcFlags(p.canonical_flags), s, 0)
    if match_indices is not None:
        return (True, match_indices[0], match_indices[1])

    return (False, empty_p, empty_p)
76
77
def _EggexMatchStart(s, p):
    # type: (str, value.Eggex) -> Tuple[bool, int, int]
    """Returns a match result for the bytes in 's' that match Eggex pattern
    `p` when constrained to match at the start of the string.

    Any capturing done by the Eggex pattern is ignored.

    The returned match result is the tuple "(matched, begin, end)". 'matched'
    is true if the pattern matched. 'begin' and 'end' give the half-open range
    "[begin, end)" of byte indices from 's' for the match, and are a valid but
    empty range ("[0, 0)") if 'matched' is false.

    Used for shell functions like 'trimStart' when trimming with an Eggex
    pattern.
    """
    # Group the translated pattern before anchoring it.  In POSIX ERE the
    # anchor binds tighter than '|', so simply prefixing '^' would turn 'a|b'
    # into '^a|b' and leave the 'b' alternative free to match anywhere in 's'.
    # The extra group is harmless because only the whole-match range is read
    # back.  This is the same '^(...)' form SearchMatch uses for LEFT_MATCH.
    ere = '^(%s)' % regex_translate.AsPosixEre(p)
    return _EggexMatchCommon(s, p, ere, 0)
97
98
def _EggexMatchEnd(s, p):
    # type: (str, value.Eggex) -> Tuple[bool, int, int]
    """Returns a match result for the bytes in 's' that match Eggex pattern
    `p` when constrained to match at the end of the string.

    Any capturing done by the Eggex pattern is ignored.

    The result is the tuple "(matched, begin, end)", where "[begin, end)" is
    the half-open byte range of the match in 's'.  If 'matched' is false the
    range is the valid but empty "[len(s), len(s))".

    Used for shell functions like 'trimEnd' when trimming with an Eggex
    pattern.
    """
    # Always group the pattern and append the anchor:
    # - Checking whether the ERE text already ends with '$' is unreliable: an
    #   escaped literal dollar ('\$') ends with that character too, and would
    #   be left unanchored.
    # - '$' binds tighter than '|', so 'a|b$' would only anchor 'b'.
    # The extra group is harmless because only the whole-match range is read
    # back.
    ere = '(%s)$' % regex_translate.AsPosixEre(p)
    return _EggexMatchCommon(s, p, ere, len(s))
108
109
# Bit flags selecting which end of a string an operation applies to.  They are
# tested with '&' and may be OR'ed together: Trim accepts START | END to work
# on both ends.
START = 0b01
END = 0b10
112
113
class HasAffix(vm._Callable):
    """Implements `startsWith()` and `endsWith()`.

    The 'anchor' given to the constructor (START or END) selects which end of
    the string is tested.
    """

    def __init__(self, anchor):
        # type: (int) -> None
        assert anchor in (START, END), ("Anchor must be START or END")
        self.anchor = anchor

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """
        string => startsWith(pattern_str)   # => bool
        string => startsWith(pattern_eggex) # => bool
        string => endsWith(pattern_str)     # => bool
        string => endsWith(pattern_eggex)   # => bool

        Raises error.TypeErr if the pattern is neither a Str nor an Eggex.
        """
        string = rd.PosStr()
        pattern_val = rd.PosValue()

        # Exactly one of these is set by the type switch below.
        str_pat = None  # type: Optional[str]
        eggex_pat = None  # type: value.Eggex
        with tagswitch(pattern_val) as case:
            if case(value_e.Str):
                str_pat = cast(value.Str, pattern_val).s
            elif case(value_e.Eggex):
                eggex_pat = cast(value.Eggex, pattern_val)
            else:
                raise error.TypeErr(pattern_val,
                                    'expected pattern to be Eggex or Str',
                                    rd.LeftParenToken())
        rd.Done()

        at_start = (self.anchor & START) != 0

        matched = False
        try:
            if str_pat is None:
                assert eggex_pat is not None
                if at_start:
                    matched, _, _ = _EggexMatchStart(string, eggex_pat)
                else:
                    matched, _, _ = _EggexMatchEnd(string, eggex_pat)
            elif at_start:
                matched, _, _ = _StrMatchStart(string, str_pat)
            else:
                matched, _, _ = _StrMatchEnd(string, str_pat)
        except error.Strict as e:
            # Re-raise strict errors as expression errors at the same location
            raise error.Expr(e.msg, e.location)

        return value.Bool(matched)
163
164
class Trim(vm._Callable):
    """Implements `trimStart()`, `trimEnd()`, and `trim()`.

    The 'anchor' given to the constructor is START, END, or START | END and
    selects which side(s) of the string are trimmed.
    """

    def __init__(self, anchor):
        # type: (int) -> None
        assert anchor in (START, END, START
                          | END), ("Anchor must be START, END, or START|END")
        self.anchor = anchor

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """
        string => trimStart()               # => Str
        string => trimEnd()                 # => Str
        string => trim()                    # => Str
        string => trimStart(pattern_str)    # => Str
        string => trimEnd(pattern_str)      # => Str
        string => trim(pattern_str)         # => Str
        string => trimStart(pattern_eggex)  # => Str
        string => trimEnd(pattern_eggex)    # => Str
        string => trim(pattern_eggex)       # => Str

        Without a pattern, whitespace is trimmed.  Raises error.TypeErr if a
        pattern is given that is neither a Str nor an Eggex.
        """
        string = rd.PosStr()
        pattern_val = rd.OptionalValue()

        # At most one of these is set; both stay None when no pattern is given.
        str_pat = None  # type: Optional[str]
        eggex_pat = None  # type: value.Eggex
        if pattern_val:
            with tagswitch(pattern_val) as case:
                if case(value_e.Str):
                    str_pat = cast(value.Str, pattern_val).s
                elif case(value_e.Eggex):
                    eggex_pat = cast(value.Eggex, pattern_val)
                else:
                    raise error.TypeErr(pattern_val,
                                        'expected pattern to be Eggex or Str',
                                        rd.LeftParenToken())
        rd.Done()

        trim_left = (self.anchor & START) != 0
        trim_right = (self.anchor & END) != 0

        # The kept range starts out as the whole string.  Each matcher below
        # returns an empty range at its own edge when nothing matches, which
        # leaves that side untouched.
        start = 0
        end = len(string)
        try:
            if str_pat is not None:
                if trim_left:
                    _, _, start = _StrMatchStart(string, str_pat)
                if trim_right:
                    _, end, _ = _StrMatchEnd(string, str_pat)
            elif eggex_pat is not None:
                if trim_left:
                    _, _, start = _EggexMatchStart(string, eggex_pat)
                if trim_right:
                    _, end, _ = _EggexMatchEnd(string, eggex_pat)
            else:
                if trim_left:
                    _, start = string_ops.StartsWithWhitespaceByteRange(string)
                if trim_right:
                    end, _ = string_ops.EndsWithWhitespaceByteRange(string)
        except error.Strict as e:
            # Re-raise strict errors as expression errors at the same location
            raise error.Expr(e.msg, e.location)

        return value.Str(string[start:end])
227
228
class Upper(vm._Callable):
    """Str method that returns an upper-cased copy of the string."""

    def __init__(self):
        # type: () -> None
        pass

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """Read the single Str argument and return it converted by upper()."""
        text = rd.PosStr()
        rd.Done()

        # TODO: unicode support
        upper_cased = text.upper()
        return value.Str(upper_cased)
243
244
class Lower(vm._Callable):
    """Str method that returns a lower-cased copy of the string."""

    def __init__(self):
        # type: () -> None
        pass

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """Read the single Str argument and return it converted by lower()."""
        text = rd.PosStr()
        rd.Done()

        # TODO: unicode support
        lower_cased = text.lower()
        return value.Str(lower_cased)
259
260
# Values for the `which_method` argument of SearchMatch.
# SEARCH: the pattern is used as given; '^' only matches when pos == 0.
# LEFT_MATCH: the pattern is wrapped as '^(...)' so it must match at `pos`.
SEARCH = 0
LEFT_MATCH = 1
263
264
class SearchMatch(vm._Callable):
    """Regex search on a Str, either unanchored (SEARCH) or anchored at the
    starting position (LEFT_MATCH), as selected by 'which_method'.
    """

    def __init__(self, which_method):
        # type: (int) -> None
        self.which_method = which_method

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """
        s => search(eggex, pos=0)

        The pattern may be an Eggex or a Str holding an ERE.  Returns
        value.Null when nothing matches, otherwise a RegexMatch.  Raises
        error.TypeErr for any other pattern type.
        """
        string = rd.PosStr()

        pattern = rd.PosValue()  # Eggex or ERE Str
        with tagswitch(pattern) as case:
            if case(value_e.Eggex):
                eggex_val = cast(value.Eggex, pattern)

                # lazily converts to ERE
                ere = regex_translate.AsPosixEre(eggex_val)
                cflags = regex_translate.LibcFlags(eggex_val.canonical_flags)
                capture = eggex_ops.Yes(
                    eggex_val.convert_funcs, eggex_val.convert_toks,
                    eggex_val.capture_names)  # type: eggex_ops_t

            elif case(value_e.Str):
                ere = cast(value.Str, pattern).s
                cflags = 0
                capture = eggex_ops.No

            else:
                # TODO: add method name to this error
                raise error.TypeErr(pattern, 'expected Eggex or Str',
                                    rd.LeftParenToken())

        # It's called 'pos', not 'start' like Python.  Python has 2 kinds of
        # 'start' in its regex API, which can be confusing.
        pos = mops.BigTruncate(rd.NamedInt('pos', 0))
        rd.Done()

        anchored = (self.which_method == LEFT_MATCH)
        if anchored:
            # Make it anchored; ^ matches beginning even if pos=5
            ere = '^(%s)' % ere
            eflags = 0
        elif pos == 0:
            eflags = 0
        else:
            # ^ only matches when pos=0
            eflags = REG_NOTBOL

        indices = libc.regex_search(ere, cflags, string, eflags, pos)
        if indices is None:
            return value.Null

        if anchored:
            # Undo the ^() transformation: drop the index pair of the overall
            # match so the wrapping group takes its place.
            indices = indices[2:]

        return RegexMatch(string, indices, capture)
324
325
class Contains(vm._Callable):
    """Str method that tests whether a substring occurs in the string."""

    def __init__(self):
        # type: () -> None
        pass

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """Return value.Bool telling whether the Str argument occurs in the
        string.  Raises error.TypeErr if the argument is not a Str.
        """
        string = rd.PosStr()

        substr = rd.PosValue()
        if substr.tag() != value_e.Str:
            raise error.TypeErr(substr, "expected argument to 'contains' to be Str", rd.LeftParenToken())
        needle = cast(value.Str, substr).s

        rd.Done()

        # find() yields -1 when absent and a non-negative index otherwise
        found_at = string.find(needle)
        return value.Bool(found_at >= 0)
347
class Find(vm._Callable):
    """Str method that returns the index of a substring.

    'direction' selects the search order: when the START bit is set the first
    occurrence is reported (str.find), otherwise the last one (str.rfind).
    """

    def __init__(self, direction):
        # type: (int) -> None
        self.direction = direction

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """Return the index of the Str argument within the string as a
        value.Int, or -1 if it does not occur.

        The named arguments 'start' (default 0) and 'end' (default
        len(string)) bound the searched range.  Raises error.TypeErr if the
        argument is not a Str.
        """
        string = rd.PosStr()

        substr = rd.PosValue()
        if substr.tag() != value_e.Str:
            raise error.TypeErr(substr, "expected argument to 'find' to be Str", rd.LeftParenToken())
        needle = cast(value.Str, substr).s

        start = mops.BigTruncate(rd.NamedInt("start", 0))
        end = mops.BigTruncate(rd.NamedInt("end", len(string)))
        rd.Done()

        if self.direction & START:
            index = string.find(needle, start, end)
        else:
            index = string.rfind(needle, start, end)

        return value.Int(mops.BigInt(index))
374
class Replace(vm._Callable):
    """Implements `replace()` with a Str or Eggex pattern and a Str or Expr
    substitution.
    """

    def __init__(self, mem, expr_ev):
        # type: (state.Mem, expr_eval.ExprEvaluator) -> None
        self.mem = mem
        self.expr_ev = expr_ev

    def EvalSubstExpr(self, expr, blame_loc):
        # type: (value.Expr, loc_t) -> str
        """Evaluate a substitution expression, which must produce a Str.

        Raises error.TypeErr, blamed on 'blame_loc', for any other result.
        """
        res = self.expr_ev.EvalExprClosure(expr, blame_loc)
        if res.tag() == value_e.Str:
            return cast(value.Str, res).s

        raise error.TypeErr(res, "expected expr to eval to a Str", blame_loc)

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """
        s => replace(string_val, subst_str, count=-1)
        s => replace(string_val, subst_expr, count=-1)
        s => replace(eggex_val, subst_str, count=-1)
        s => replace(eggex_val, subst_expr, count=-1)

        For count in [0, MAX_INT], there will be no more than count
        replacements. Any negative count should read as unset, and replace will
        replace all occurrences of the pattern.

        Raises error.TypeErr for a pattern that is not Eggex/Str or a
        substitution that is not Str/Expr, and error.Structured when an eggex
        is applied to a string with NUL bytes or matches the empty string.
        """
        string = rd.PosStr()

        string_val = None  # type: value.Str
        eggex_val = None  # type: value.Eggex
        subst_str = None  # type: value.Str
        subst_expr = None  # type: value.Expr

        pattern = rd.PosValue()
        with tagswitch(pattern) as case:
            if case(value_e.Eggex):
                # HACK: mycpp will otherwise generate:
                #  value::Eggex* eggex_val ...
                eggex_val_ = cast(value.Eggex, pattern)
                eggex_val = eggex_val_

            elif case(value_e.Str):
                string_val_ = cast(value.Str, pattern)
                string_val = string_val_

            else:
                raise error.TypeErr(pattern,
                                    'expected pattern to be Eggex or Str',
                                    rd.LeftParenToken())

        subst = rd.PosValue()
        with tagswitch(subst) as case:
            if case(value_e.Str):
                subst_str_ = cast(value.Str, subst)
                subst_str = subst_str_

            elif case(value_e.Expr):
                subst_expr_ = cast(value.Expr, subst)
                subst_expr = subst_expr_

            else:
                raise error.TypeErr(subst,
                                    'expected substitution to be Str or Expr',
                                    rd.LeftParenToken())

        count = mops.BigTruncate(rd.NamedInt("count", -1))
        rd.Done()

        # Zero replacements requested: the string is returned unchanged
        if count == 0:
            return value.Str(string)

        if string_val:
            if subst_str:
                s = subst_str.s
            if subst_expr:
                # Eval with $0 set to string_val (the matched substring)
                with state.ctx_Eval(self.mem, string_val.s, None, None):
                    s = self.EvalSubstExpr(subst_expr, rd.LeftParenToken())
            assert s is not None

            result = string.replace(string_val.s, s, count)

            return value.Str(result)

        if eggex_val:
            if '\0' in string:
                raise error.Structured(
                    3, "cannot replace by eggex on a string with NUL bytes",
                    rd.LeftParenToken())

            ere = regex_translate.AsPosixEre(eggex_val)
            cflags = regex_translate.LibcFlags(eggex_val.canonical_flags)

            # Walk through the string finding all matches of the compiled ere.
            # Then, collect unmatched substrings and substitutions into the
            # `parts` list.
            pos = 0
            parts = []  # type: List[str]
            replace_count = 0
            while pos < len(string):
                # When resuming after an earlier match, 'pos' is not the
                # beginning of the string, so '^' must not match there.
                # SearchMatch passes the same flag for pos != 0.
                eflags = 0 if pos == 0 else REG_NOTBOL
                indices = libc.regex_search(ere, cflags, string, eflags, pos)
                if indices is None:
                    break

                # Collect captures
                arg0 = None  # type: Optional[str]
                argv = []  # type: List[str]
                named_vars = {}  # type: Dict[str, value_t]
                num_groups = len(indices) / 2
                for group in xrange(num_groups):
                    start = indices[2 * group]
                    end = indices[2 * group + 1]
                    captured = string[start:end]
                    val = value.Str(captured)  # type: value_t

                    if len(eggex_val.convert_funcs) and group != 0:
                        convert_func = eggex_val.convert_funcs[group - 1]
                        convert_tok = eggex_val.convert_toks[group - 1]

                        if convert_func:
                            val = self.expr_ev.CallConvertFunc(
                                convert_func, val, convert_tok,
                                rd.LeftParenToken())

                    # $0, $1, $2 variables are argv values, which must be
                    # strings. Furthermore, they can only be used in string
                    # contexts
                    #   eg. "$[1]" != "$1".
                    val_str = val_ops.Stringify(val, rd.LeftParenToken(), '')
                    if group == 0:
                        arg0 = val_str
                    else:
                        argv.append(val_str)

                    # $0 cannot be named
                    if group != 0:
                        # Same 'group - 1' offset as convert_funcs above.  An
                        # offset of 2 would read the last name for group 1
                        # (negative index) and shift every other name by one.
                        name = eggex_val.capture_names[group - 1]
                        if name is not None:
                            named_vars[name] = val

                if subst_str:
                    s = subst_str.s
                if subst_expr:
                    with state.ctx_Eval(self.mem, arg0, argv, named_vars):
                        s = self.EvalSubstExpr(subst_expr, rd.LeftParenToken())
                assert s is not None

                start = indices[0]
                end = indices[1]
                # A match that ends where the search began would never advance
                # 'pos', so it is rejected instead of looping forever.
                if pos == end:
                    raise error.Structured(
                        3, "eggex should never match the empty string",
                        rd.LeftParenToken())

                parts.append(string[pos:start])  # Unmatched substring
                parts.append(s)  # Replacement
                pos = end  # Move to end of match

                replace_count += 1
                if count != -1 and replace_count == count:
                    break

            parts.append(string[pos:])  # Remaining unmatched substring

            return value.Str("".join(parts))

        raise AssertionError()
543
544
class Split(vm._Callable):
    """Implements `split()` with a Str or Eggex separator."""

    def __init__(self):
        # type: () -> None
        pass

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """
        s.split(string_sep, count=-1)
        s.split(eggex_sep, count=-1)

        Count behaves like in replace() in that:
        - `count` < 0 -> ignore
        - `count` >= 0 -> there will be at most `count` splits

        An empty string splits into an empty list.  Raises error.TypeErr if
        the separator is neither Eggex nor Str, and error.Structured for an
        empty Str separator, an eggex separator that matches the empty string,
        or an eggex split of a string containing a NUL byte.
        """
        string = rd.PosStr()

        string_sep = None  # type: Optional[str]
        eggex_sep = None  # type: value.Eggex

        sep = rd.PosValue()
        with tagswitch(sep) as case:
            if case(value_e.Eggex):
                eggex_sep_ = cast(value.Eggex, sep)
                eggex_sep = eggex_sep_

            elif case(value_e.Str):
                string_sep_ = cast(value.Str, sep)
                string_sep = string_sep_.s

            else:
                raise error.TypeErr(sep,
                                    'expected separator to be Eggex or Str',
                                    rd.LeftParenToken())

        count = mops.BigTruncate(rd.NamedInt("count", -1))
        rd.Done()

        if len(string) == 0:
            return value.List([])

        if string_sep is not None:
            if len(string_sep) == 0:
                raise error.Structured(3, "separator must be non-empty",
                                       rd.LeftParenToken())

            cursor = 0
            chunks = []  # type: List[value_t]
            length = len(string)
            # A negative count never reaches 0, so it places no limit
            while cursor < length and count != 0:
                next_pos = string.find(string_sep, cursor)
                if next_pos == -1:
                    break

                chunks.append(value.Str(string[cursor:next_pos]))
                cursor = next_pos + len(string_sep)
                count -= 1

            # Whatever follows the last separator (possibly empty)
            chunks.append(value.Str(string[cursor:]))

            return value.List(chunks)

        if eggex_sep is not None:
            if '\0' in string:
                raise error.Structured(
                    3, "cannot split a string with a NUL byte",
                    rd.LeftParenToken())

            regex = regex_translate.AsPosixEre(eggex_sep)
            cflags = regex_translate.LibcFlags(eggex_sep.canonical_flags)

            cursor = 0
            chunks = []
            while cursor < len(string) and count != 0:
                # When resuming after an earlier separator, 'cursor' is not
                # the beginning of the string, so '^' must not match there.
                # SearchMatch passes the same flag for pos != 0.
                eflags = 0 if cursor == 0 else REG_NOTBOL
                m = libc.regex_search(regex, cflags, string, eflags, cursor)
                if m is None:
                    break

                start = m[0]
                end = m[1]
                if start == end:
                    raise error.Structured(
                        3,
                        "eggex separators should never match the empty string",
                        rd.LeftParenToken())

                chunks.append(value.Str(string[cursor:start]))
                cursor = end

                count -= 1

            # Whatever follows the last separator (possibly empty)
            chunks.append(value.Str(string[cursor:]))

            return value.List(chunks)

        raise AssertionError()
642
643
class Lines(vm._Callable):
    """Str method that splits a string into lines."""

    def __init__(self):
        # type: () -> None
        pass

    def Call(self, rd):
        # type: (typed_args.Reader) -> value_t
        """Return a value.List of the lines in the string.

        Lines are delimited by the named argument 'eol' (default '\\n'), which
        is not included in the results.  Text after the last delimiter is
        returned as a final line only if it is non-empty, so a trailing
        delimiter does not produce an empty last item.  An empty string gives
        an empty list.

        Raises error.Structured if 'eol' is empty.
        """
        string = rd.PosStr()
        eol = rd.NamedStr('eol', '\n')
        # Finish argument processing, as every other method in this file does
        rd.Done()

        # An empty delimiter is found at every position without advancing the
        # cursor, so the loop below would never terminate.
        if len(eol) == 0:
            raise error.Structured(3, "eol must be non-empty",
                                   rd.LeftParenToken())

        # Adapted from Str.split() above, except for the handling of the last item

        cursor = 0
        chunks = []  # type: List[value_t]
        length = len(string)
        while cursor < length:
            next_pos = string.find(eol, cursor)
            if next_pos == -1:
                break

            chunks.append(value.Str(string[cursor:next_pos]))
            cursor = next_pos + len(eol)

        # Unterminated last line
        if cursor < length:
            chunks.append(value.Str(string[cursor:]))

        return value.List(chunks)