OILS / builtin / printf_osh.py View on Github | oils.pub

600 lines, 388 significant
1#!/usr/bin/env python2
2from __future__ import print_function
3
4import time as time_ # avoid name conflict
5
6from _devbuild.gen import arg_types
7from _devbuild.gen.id_kind_asdl import Id, Id_t, Id_str, Kind, Kind_t
8from _devbuild.gen.runtime_asdl import cmd_value
9from _devbuild.gen.syntax_asdl import (
10 loc,
11 loc_e,
12 loc_t,
13 source,
14 Token,
15 CompoundWord,
16 printf_part,
17 printf_part_e,
18 printf_part_t,
19)
20from _devbuild.gen.types_asdl import lex_mode_e, lex_mode_t
21from _devbuild.gen.value_asdl import (value, value_e)
22
23from core import alloc
24from core import error
25from core.error import p_die
26from core import state
27from core import vm
28from frontend import flag_util
29from frontend import consts
30from frontend import lexer
31from frontend import match
32from frontend import reader
33from mycpp import mops
34from mycpp import mylib
35from mycpp.mylib import log
36from osh import sh_expr_eval
37from osh import string_ops
38from osh import word_compile
39from data_lang import j8_lite
40
41import posix_ as posix
42
43from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, cast
44
45if TYPE_CHECKING:
46 from display import ui
47 from frontend import parse_lib
48
49_ = log
50
51
def _ParsePrintfInteger(s):
    # type: (str) -> Tuple[bool, mops.BigInt]
    """Parse a printf integer argument.

    Grammar:
        Whitespace? ('-' | '+')? (Dec | Oct | Hex)

    Leading whitespace is skipped, but anything left over after the number
    (including trailing whitespace) makes the parse fail.

    Returns:
        (True, value) when the whole string is an integer literal
        (False, zero) when it is not, or when the conversion reports failure
                      (e.g. overflow)
    """
    # Shells ignore space on the left, but not on the right!
    text = s.lstrip()

    # The shell number lexer has no notion of a sign, so peel off a single
    # leading '-' or '+' before handing the rest to it.
    is_negative = text.startswith('-')
    if is_negative or text.startswith('+'):
        text = text[1:]

    # Borrow the lexer for $(( )); the token must span the entire string.
    tok_id, end_pos = match.MatchShNumberToken(text, 0)
    if end_pos != len(text):
        return (False, mops.BigInt(0))

    magnitude = mops.ZERO
    if tok_id == Id.ShNumber_Hex:
        # Skip the 2-char prefix
        ok, magnitude = mops.FromStr2(text[2:], 16)

    elif tok_id == Id.ShNumber_Oct:
        # Skip the 1-char prefix
        ok, magnitude = mops.FromStr2(text[1:], 8)

    elif tok_id == Id.ShNumber_Dec:
        ok, magnitude = mops.FromStr2(text)

    else:
        # Id.ShNumber_BaseN or Id.Unknown_Tok.  Unlike $(( )), printf doesn't
        # support 64#a
        return (False, mops.BigInt(0))

    if not ok:
        return (False, mops.BigInt(0))

    if is_negative:
        magnitude = mops.Negate(magnitude)
    return (True, magnitude)
101
102
class _FormatStringParser(object):
    """Parser that turns a printf format string into a list of parts.

    Grammar:

        width         = Num | Star
        precision     = Dot (Num | Star | Zero)?
        fmt           = Percent (Flag | Zero)* width? precision? (Type | Time)
        part          = Char_* | Format_EscapedPercent | fmt
        printf_format = part* Eof_Real   # we're using the main lexer

    'Time' is the %(strftime)T conversion that bash also supports.
    """

    def __init__(self, lexer):
        # type: (lexer.Lexer) -> None
        self.lexer = lexer

        # Placeholders until the first _Next() call
        self.cur_token = None  # type: Token
        self.token_type = Id.Undefined_Tok  # type: Id_t
        self.token_kind = Kind.Undefined  # type: Kind_t

    def _Next(self, lex_mode):
        # type: (lex_mode_t) -> None
        """Read one token in the given mode and make it the current token."""
        tok = self.lexer.Read(lex_mode)
        self.cur_token = tok
        self.token_type = tok.id
        self.token_kind = consts.GetKind(tok.id)

    def _ParseFormatStr(self):
        # type: () -> printf_part_t
        """Parse one 'fmt' production; the current token is the '%'.

        Dies with a parse error (p_die) on unsupported flags, floating point
        conversions, or a missing / invalid conversion character.
        """
        percent_mode = lex_mode_e.PrintfPercent
        self._Next(percent_mode)  # step over the %

        spec = printf_part.Percent.CreateNull(alloc_lists=True)

        # (Flag | Zero)*
        while self.token_type in (Id.Format_Flag, Id.Format_Zero):
            # space and + could be implemented
            flag_str = lexer.TokenVal(
                self.cur_token)  # allocation will be cached
            if flag_str in '# +':
                p_die("osh printf doesn't support the %r flag" % flag_str,
                      self.cur_token)

            spec.flags.append(self.cur_token)
            self._Next(percent_mode)

        # width?
        if self.token_type in (Id.Format_Num, Id.Format_Star):
            spec.width = self.cur_token
            self._Next(percent_mode)

        # precision?  A bare dot is recorded as the precision token itself,
        # and is replaced if a number or star follows it.
        if self.token_type == Id.Format_Dot:
            spec.precision = self.cur_token
            self._Next(percent_mode)  # step over the dot
            if self.token_type in (Id.Format_Num, Id.Format_Star,
                                   Id.Format_Zero):
                spec.precision = self.cur_token
                self._Next(percent_mode)

        # (Type | Time)
        if self.token_type == Id.Unknown_Tok:
            p_die('Invalid printf format character', self.cur_token)

        elif self.token_type not in (Id.Format_Type, Id.Format_Time):
            p_die('Expected a printf format character', self.cur_token)

        else:
            spec.type = self.cur_token

            # Extra validation that the grammar itself doesn't express
            conversion = lexer.TokenVal(
                spec.type)  # allocation will be cached
            if conversion in 'eEfFgG':
                p_die("osh printf doesn't support floating point", spec.type)

            # TODO: printf %c should not be a YSH operation, since it doesn't
            # support Unicode

        return spec

    def Parse(self):
        # type: () -> List[printf_part_t]
        """Parse the whole format string into literal tokens and % specs."""
        outer_mode = lex_mode_e.PrintfOuter
        parts = []  # type: List[printf_part_t]

        self._Next(outer_mode)
        while True:
            is_literal = (self.token_kind in (Kind.Lit, Kind.Char) or
                          self.token_type in (Id.Format_EscapedPercent,
                                              Id.Unknown_Backslash))
            if is_literal:
                # Like echo -e, Unknown_Backslash is not an error here even
                # with shopt --set no_parse_backslash, because this happens at
                # runtime rather than parse time.
                # Users should use $'' or the future static printf ${x %.3f}.
                parts.append(self.cur_token)

            elif self.token_type == Id.Format_Percent:
                parts.append(self._ParseFormatStr())

            elif self.token_type in (Id.Eof_Real, Id.Eol_Tok):
                # Id.Eol_Tok: special case for format string of '\x00'.
                break

            else:
                raise AssertionError(Id_str(self.token_type))

            self._Next(outer_mode)

        return parts
208
209
class _PrintfState(object):
    """Mutable bookkeeping for formatting during one printf invocation."""

    def __init__(self):
        # type: () -> None

        # Exit status to report; starts at 0 and is set to 1 before
        # returning on a formatting error.
        self.status = 0

        # Becomes True once \c is seen while processing a %b argument.
        self.backslash_c = False

        # Index of the next argument that has not been consumed yet.
        self.arg_index = 0
217
218
class Printf(vm._Builtin):
    """The printf builtin: printf [-v var] format [argument ...]

    The format string is parsed into parts once and cached by its text.  The
    parts are then applied to the arguments, re-using the format for as long
    as there are unconsumed arguments ("arg recycling").
    """

    def __init__(
            self,
            mem,  # type: state.Mem
            parse_ctx,  # type: parse_lib.ParseContext
            unsafe_arith,  # type: sh_expr_eval.UnsafeArith
            errfmt,  # type: ui.ErrorFormatter
    ):
        # type: (...) -> None
        self.mem = mem
        self.parse_ctx = parse_ctx
        self.unsafe_arith = unsafe_arith
        self.errfmt = errfmt
        # Parsed format strings, keyed by the format string text
        self.parse_cache = {}  # type: Dict[str, List[printf_part_t]]

        # this object initialized in main()
        self.shell_start_time = time_.time()

    def _Percent(
            self,
            pr,  # type: _PrintfState
            part,  # type: printf_part.Percent
            varargs,  # type: List[str]
            locs,  # type: List[CompoundWord]
    ):
        # type: (...) -> Optional[str]
        """Format a single % directive.

        Consumes arguments from varargs (advancing pr.arg_index) for a '*'
        width, a '*' precision, and the value itself, in that order.

        Returns:
          The formatted string, or None after printing an error and setting
          pr.status = 1.  May also set pr.backslash_c for %b.
        """

        num_args = len(varargs)

        # TODO: Cache this?
        flags = []  # type: List[str]
        if len(part.flags) > 0:
            for flag_token in part.flags:
                flags.append(lexer.TokenVal(flag_token))

        width = -1  # nonexistent
        if part.width:
            if part.width.id in (Id.Format_Num, Id.Format_Zero):
                width_str = lexer.TokenVal(part.width)
                width_loc = part.width  # type: loc_t
            elif part.width.id == Id.Format_Star:  # depends on data
                if pr.arg_index < num_args:
                    width_str = varargs[pr.arg_index]
                    width_loc = locs[pr.arg_index]
                    pr.arg_index += 1
                else:
                    width_str = ''  # invalid
                    width_loc = loc.Missing
            else:
                raise AssertionError()

            try:
                width = int(width_str)
            except ValueError:
                # Blame the '*' itself when there was no argument to blame
                if width_loc.tag() == loc_e.Missing:
                    width_loc = part.width
                self.errfmt.Print_("printf got invalid width %r" % width_str,
                                   blame_loc=width_loc)
                pr.status = 1
                return None

        precision = -1  # nonexistent
        if part.precision:
            if part.precision.id == Id.Format_Dot:
                # A bare '.' means precision 0
                precision_str = '0'
                precision_loc = part.precision  # type: loc_t
            elif part.precision.id in (Id.Format_Num, Id.Format_Zero):
                precision_str = lexer.TokenVal(part.precision)
                precision_loc = part.precision
            elif part.precision.id == Id.Format_Star:
                if pr.arg_index < num_args:
                    precision_str = varargs[pr.arg_index]
                    precision_loc = locs[pr.arg_index]
                    pr.arg_index += 1
                else:
                    precision_str = ''
                    precision_loc = loc.Missing
            else:
                raise AssertionError()

            try:
                precision = int(precision_str)
            except ValueError:
                if precision_loc.tag() == loc_e.Missing:
                    precision_loc = part.precision
                self.errfmt.Print_('printf got invalid precision %r' %
                                   precision_str,
                                   blame_loc=precision_loc)
                pr.status = 1
                return None

        if pr.arg_index < num_args:
            s = varargs[pr.arg_index]
            word_loc = locs[pr.arg_index]  # type: loc_t
            pr.arg_index += 1
            has_arg = True
        else:
            # A missing argument is formatted like the empty string
            s = ''
            word_loc = loc.Missing
            has_arg = False

        # Note: %s could be lexed into Id.Percent_S.  Although small string
        # optimization would remove the allocation as well.
        typ = lexer.TokenVal(part.type)
        if typ == 's':
            if precision >= 0:
                s = s[:precision]  # truncate

        elif typ == 'q':
            # Most shells give \' for single quote, while OSH gives
            # $'\''  this could matter when SSH'ing.
            # Ditto for $'\\' vs. '\'

            s = j8_lite.MaybeShellEncode(s)

        elif typ == 'b':
            # Process just like echo -e, except \c handling is simpler.

            c_parts = []  # type: List[str]
            lex = match.PrintfBLexer(s)
            while True:
                id_, tok_val = lex.Next()
                if id_ == Id.Eol_Tok:  # Note: This is really a NUL terminator
                    break

                p = word_compile.EvalCStringToken(id_, tok_val)

                # Unusual behavior: '\c' aborts processing!
                if p is None:
                    pr.backslash_c = True
                    break

                c_parts.append(p)
            s = ''.join(c_parts)

        elif typ == 'c':
            # printf %c simply prints the first BYTE.  It doesn't decode
            # UTF-8.  Slice rather than index, so that an empty or missing
            # argument yields '' instead of raising IndexError.
            s = s[:1]

        elif part.type.id == Id.Format_Time or typ in 'diouxX':
            # %(...)T and %d share this complex integer conversion logic

            ok, d = _ParsePrintfInteger(s)
            if not ok:
                # Check for 'a and "a
                # These are interpreted as the numeric ASCII value of 'a'
                num_bytes = len(s)
                if num_bytes > 0 and s[0] in '\'"':
                    if num_bytes == 1:
                        # NUL after quote
                        d = mops.ZERO
                    elif num_bytes == 2:
                        # Allow invalid UTF-8, because all shells do
                        d = mops.IntWiden(ord(s[1]))
                    else:
                        try:
                            small_i = string_ops.DecodeUtf8Char(s, 1)
                        except error.Expr as e:
                            # Take the numeric value of first char, ignoring
                            # the rest of the bytes.
                            # Something like strict_arith or strict_printf
                            # could throw an error in this case.
                            self.errfmt.Print_(
                                'Warning: %s' % e.UserErrorString(), word_loc)
                            small_i = ord(s[1])

                        d = mops.IntWiden(small_i)

                # No argument means -1 for %(...)T as in Bash Reference Manual
                # 4.2 - "If no argument is specified, conversion behaves as if
                # -1 had been given."
                elif not has_arg and part.type.id == Id.Format_Time:
                    d = mops.MINUS_ONE

                else:
                    if has_arg:
                        blame_loc = word_loc  # type: loc_t
                    else:
                        blame_loc = part.type
                    self.errfmt.Print_(
                        'printf expected an integer, got %r' % s, blame_loc)
                    pr.status = 1
                    return None

            if part.type.id == Id.Format_Time:
                # Initialize timezone:
                #   `localtime' uses the current timezone information
                #   initialized by `tzset'.  The function `tzset' refers to
                #   the environment variable `TZ'.  When the exported variable
                #   `TZ' is present, its value should be reflected in the real
                #   environment variable `TZ' before call of `tzset'.
                #
                # Note: unlike LANG, TZ doesn't seem to change behavior if it's
                # not exported.
                #
                # TODO: In YSH, provide an API that doesn't rely on libc's
                # global state.

                tzcell = self.mem.GetCell('TZ')
                if (tzcell and tzcell.exported and
                        tzcell.val.tag() == value_e.Str):
                    tzval = cast(value.Str, tzcell.val)
                    posix.putenv('TZ', tzval.s)

                time_.tzset()

                # Handle special values:
                #   User can specify two special values -1 and -2 as in Bash
                #   Reference Manual 4.2: "Two special argument values may be
                #   used: -1 represents the current time, and -2 represents
                #   the time the shell was invoked." from
                #   https://www.gnu.org/software/bash/manual/html_node/Bash-Builtins.html#index-printf
                if mops.Equal(d, mops.MINUS_ONE):  # -1 is current time
                    # TODO: 2038 problem
                    ts = time_.time()
                elif mops.Equal(d, mops.MINUS_TWO):  # -2 is shell start time
                    ts = self.shell_start_time
                else:
                    ts = mops.BigTruncate(d)

                # typ looks like '(...)T'; slice out the strftime format
                s = time_.strftime(typ[1:-2], time_.localtime(ts))
                if precision >= 0:
                    s = s[:precision]  # truncate

            else:  # typ in 'diouxX'
                # Disallowed because it depends on 32- or 64- bit
                if mops.Greater(mops.ZERO, d) and typ in 'ouxX':
                    # TODO: Don't truncate it
                    self.errfmt.Print_(
                        "Can't format negative number with %%%s: %d" %
                        (typ, mops.BigTruncate(d)), part.type)
                    pr.status = 1
                    return None

                if typ == 'o':
                    s = mops.ToOctal(d)
                elif typ == 'x':
                    s = mops.ToHexLower(d)
                elif typ == 'X':
                    s = mops.ToHexUpper(d)
                else:  # diu
                    s = mops.ToStr(d)  # without spaces like ' -42 '

                # There are TWO different ways to ZERO PAD, and they differ on
                # the negative sign!  See spec/builtin-printf

                zero_pad = 0  # no zero padding
                if width >= 0 and '0' in flags:
                    zero_pad = 1  # style 1
                elif precision > 0:
                    # Precision is the minimum number of DIGITS, so the sign
                    # must not count toward it: [%.3d] -42 becomes [-042].
                    # rjust() below is a no-op when there are enough digits.
                    zero_pad = 2  # style 2

                if zero_pad:
                    negative = (s[0] == '-')
                    if negative:
                        digits = s[1:]
                        sign = '-'
                        if zero_pad == 1:
                            # [%06d] -42 becomes [-00042] (6 TOTAL)
                            n = width - 1
                        else:
                            # [%6.6d] -42 becomes [-000042] (1 for '-' + 6)
                            n = precision
                    else:
                        digits = s
                        sign = ''
                        if zero_pad == 1:
                            n = width
                        else:
                            n = precision
                    s = sign + digits.rjust(n, '0')

        else:
            raise AssertionError()

        # Pad with spaces to the field width; the '-' flag left-justifies
        if width >= 0:
            if '-' in flags:
                s = s.ljust(width, ' ')
            else:
                s = s.rjust(width, ' ')
        return s

    def _Format(self, parts, varargs, locs, out):
        # type: (List[printf_part_t], List[str], List[CompoundWord], List[str]) -> int
        """Apply the parsed format to the arguments, appending to 'out'.

        The format is re-used until all arguments are consumed.  A \\c in a
        %b argument stops ALL further output: the rest of the format and any
        remaining arguments are ignored.

        Returns:
          0 on success, or the non-zero pr.status as soon as a % directive
          fails.
        """

        pr = _PrintfState()
        num_args = len(varargs)

        while True:  # loop over arguments
            for part in parts:  # loop over parsed format string
                UP_part = part
                if part.tag() == printf_part_e.Literal:
                    part = cast(Token, UP_part)
                    if part.id == Id.Format_EscapedPercent:
                        s = '%'
                    else:
                        s = word_compile.EvalCStringToken(
                            part.id, lexer.LazyStr(part))

                elif part.tag() == printf_part_e.Percent:
                    part = cast(printf_part.Percent, UP_part)

                    s = self._Percent(pr, part, varargs, locs)
                    if pr.status != 0:
                        return pr.status

                else:
                    raise AssertionError()

                out.append(s)

                if pr.backslash_c:
                    # 'printf %b a\cb xx' - \c terminates processing!
                    break

            if pr.backslash_c:
                # Also stop recycling the format: the remaining arguments
                # must not be printed after \c.
                break
            if pr.arg_index == 0:
                # We went through ALL parts and didn't consume ANY arg.
                # Example: print x y
                break
            if pr.arg_index >= num_args:
                # We printed all args
                break
            # If there are more args, keep going.  This implements 'arg
            # recycling' behavior
            #   printf '%s ' 1 2 3 => 1 2 3

        return 0

    def Run(self, cmd_val):
        # type: (cmd_value.Argv) -> int
        """
        printf: printf [-v var] format [argument ...]

        Returns 0 on success, 2 if the format string fails to parse, or the
        non-zero status from formatting.  With -v, the result is assigned to
        the named variable instead of being written to stdout.
        """
        attrs, arg_r = flag_util.ParseCmdVal('printf', cmd_val)
        arg = arg_types.printf(attrs.attrs)

        fmt, fmt_loc = arg_r.ReadRequired2('requires a format string')
        varargs, locs = arg_r.Rest2()

        arena = self.parse_ctx.arena
        if fmt in self.parse_cache:
            parts = self.parse_cache[fmt]
        else:
            line_reader = reader.StringLineReader(fmt, arena)
            # TODO: Make public
            # Named fmt_lexer so the local doesn't shadow the 'lexer' module
            fmt_lexer = self.parse_ctx.MakeLexer(line_reader)
            parser = _FormatStringParser(fmt_lexer)

            with alloc.ctx_SourceCode(arena,
                                      source.Dynamic('printf arg', fmt_loc)):
                try:
                    parts = parser.Parse()
                except error.Parse as e:
                    self.errfmt.PrettyPrintError(e)
                    return 2  # parse error

            self.parse_cache[fmt] = parts

        if 0:  # flip to 1 to dump the parsed format when debugging
            print()
            for part in parts:
                part.PrettyPrint()
                print()

        out = []  # type: List[str]
        status = self._Format(parts, varargs, locs, out)
        if status != 0:
            return status  # failure

        result = ''.join(out)
        if arg.v is not None:
            # TODO: get the location for arg.v!
            v_loc = loc.Missing
            lval = self.unsafe_arith.ParseLValue(arg.v, v_loc)
            state.BuiltinSetValue(self.mem, lval, value.Str(result))
        else:
            mylib.Stdout().write(result)
        return 0