OILS / builtin / trap_osh.py View on Github | oils.pub

459 lines, 269 significant
1#!/usr/bin/env python2
2from __future__ import print_function
3
4from signal import SIG_DFL, SIG_IGN, SIGINT, SIGWINCH
5
6from _devbuild.gen import arg_types
7from _devbuild.gen.runtime_asdl import cmd_value, trap_action, trap_action_e, trap_action_t
8from _devbuild.gen.syntax_asdl import loc, loc_t, source, command_e, command
9from core import alloc
10from core import dev
11from core import error
12from core import main_loop
13from core import vm
14from frontend import flag_util
15from frontend import reader
16from frontend import signal_def
17from frontend import typed_args
18from builtin import process_osh # for PrintSignals()
19from data_lang import j8_lite
20from mycpp import iolib
21from mycpp import mylib
22from mycpp.mylib import iteritems, print_stderr, log, tagswitch
23
24from typing import Dict, List, Optional, TYPE_CHECKING, cast
25if TYPE_CHECKING:
26 from _devbuild.gen.syntax_asdl import command_t
27 from core import optview
28 from display import ui
29 from frontend import args
30 from frontend.parse_lib import ParseContext
31
32_ = log
33
34
class TrapState(object):
    """Traps are shell callbacks that the user wants to run on certain events.

    There are 2 categories:
    1. Signals like SIGUSR1
    2. Hooks like EXIT

    Signal handlers execute in the main loop, and within blocking syscalls.

    EXIT, DEBUG, ERR, RETURN execute in specific places in the interpreter.
    """

    def __init__(self, signal_safe):
        # type: (iolib.SignalSafe) -> None
        self.signal_safe = signal_safe
        # Hook name (one of _HOOK_NAMES, e.g. 'EXIT') -> action
        self.hooks = {}  # type: Dict[str, trap_action_t]
        # Signal number (e.g. SIGUSR1) -> action
        self.traps = {}  # type: Dict[int, trap_action_t]

    def ClearForSubProgram(self, inherit_errtrace):
        # type: (bool) -> None
        """SubProgramThunk uses this because traps aren't inherited.

        Args:
          inherit_errtrace: if true, the ERR hook survives (set -o errtrace).
            Every other hook and every signal trap is dropped.
        """
        # bash clears hooks like DEBUG in subshells.
        # The ERR hook can be preserved if set -o errtrace
        hook_err = self.hooks.get('ERR')
        self.hooks.clear()
        if hook_err is not None and inherit_errtrace:
            self.hooks['ERR'] = hook_err

        # NOTE: this only forgets the handlers.  It makes no signal_safe or
        # sigaction() calls, unlike _RemoveUserTrap().
        self.traps.clear()

    def _GetCommand(self, action):
        # type: (Optional[trap_action_t]) -> Optional[command_t]
        """Return the command an action runs.

        Returns None if there is no action, or if the action is Ignored.
        """
        if action is None:
            return None
        if action.tag() == trap_action_e.Ignored:
            return None
        assert action.tag() == trap_action_e.Command
        return cast(trap_action.Command, action).c

    def GetHook(self, hook_name):
        # type: (str) -> Optional[command_t]
        """Return the command for a hook like EXIT, or None if it isn't set."""
        action = self.hooks.get(hook_name)
        return self._GetCommand(action)

    def GetTrap(self, sig_num):
        # type: (int) -> Optional[command_t]
        """Return the command for a signal number, or None if it isn't set."""
        action = self.traps.get(sig_num)
        return self._GetCommand(action)

    def _AddUserTrap(self, sig_num, handler):
        # type: (int, trap_action_t) -> None
        """Register a handler for a signal, e.g. SIGUSR1.

        Records it in self.traps and updates how the process reacts to the
        signal.
        """
        self.traps[sig_num] = handler

        if handler.tag() == trap_action_e.Ignored:
            # This is the case:
            #   trap '' SIGINT SIGWINCH
            # It's handled the same as removing a trap:
            #   trap - SIGINT SIGWINCH
            #
            # That is, the signal_safe calls are the same.  This seems right
            # because the shell interpreter itself cares about SIGINT and
            # SIGWINCH too -- not just user traps.
            if sig_num == SIGINT:
                self.signal_safe.SetSigIntTrapped(False)
            elif sig_num == SIGWINCH:
                self.signal_safe.SetSigWinchCode(iolib.UNTRAPPED_SIGWINCH)
            else:
                iolib.sigaction(sig_num, SIG_IGN)
        else:
            if sig_num == SIGINT:
                # Don't disturb the underlying runtime's SIGINT handlers
                # 1. CPython has one for KeyboardInterrupt
                # 2. mycpp runtime simulates KeyboardInterrupt:
                #    pyos::InitSignalSafe() calls RegisterSignalInterest(SIGINT),
                #    then we PollSigInt() in the osh/cmd_eval.py main loop
                self.signal_safe.SetSigIntTrapped(True)
            elif sig_num == SIGWINCH:
                self.signal_safe.SetSigWinchCode(SIGWINCH)
            else:
                iolib.RegisterSignalInterest(sig_num)

    def _RemoveUserTrap(self, sig_num):
        # type: (int) -> None
        """Forget the handler for a signal and restore its disposition.

        SIGINT and SIGWINCH go back to being untrapped in signal_safe.  Any
        other signal is reset to SIG_DFL.
        """
        mylib.dict_erase(self.traps, sig_num)

        if sig_num == SIGINT:
            self.signal_safe.SetSigIntTrapped(False)
        elif sig_num == SIGWINCH:
            self.signal_safe.SetSigWinchCode(iolib.UNTRAPPED_SIGWINCH)
        else:
            # TODO: In process.InitInteractiveShell(), 4 signals are set to
            # SIG_IGN, not SIG_DFL:
            #
            #   SIGQUIT SIGTSTP SIGTTOU SIGTTIN
            #
            # Should we restore them?  It's rare that you type 'trap' in
            # interactive shells, but it might be more correct.  See what
            # other shells do.
            iolib.sigaction(sig_num, SIG_DFL)

    def AddItem(self, parsed_id, handler):
        # type: (str, trap_action_t) -> None
        """Add trap or hook, parsed to EXIT or INT (not 0 or SIGINT)."""
        if parsed_id in _HOOK_NAMES:
            self.hooks[parsed_id] = handler
        else:
            sig_num = signal_def.GetNumber(parsed_id)
            # Should have already been validated.  Compare by value: an int
            # sentinel must not be compared by identity.
            assert sig_num != signal_def.NO_SIGNAL

            self._AddUserTrap(sig_num, handler)

    def RemoveItem(self, parsed_id):
        # type: (str) -> None
        """Remove trap or hook, parsed to EXIT or INT (not 0 or SIGINT)."""
        if parsed_id in _HOOK_NAMES:
            mylib.dict_erase(self.hooks, parsed_id)
        else:
            sig_num = signal_def.GetNumber(parsed_id)
            # Should have already been validated.  Compare by value: an int
            # sentinel must not be compared by identity.
            assert sig_num != signal_def.NO_SIGNAL

            self._RemoveUserTrap(sig_num)

    def GetPendingTraps(self):
        # type: () -> Optional[List[command_t]]
        """Transfer ownership of queue of pending trap handlers to caller.

        Returns:
          None if no signals are pending.  Otherwise, a list of the trap
          commands to run, in the order TakePendingSignals() gave the
          signals.  Signals with no trap or an Ignored trap are skipped, so
          the list may be empty.
        """
        signals = self.signal_safe.TakePendingSignals()
        if 0:
            log('*** GetPendingTraps')
            for si in signals:
                log('SIGNAL %d', si)
            #import traceback
            #traceback.print_stack()

        # Optimization for the common case: do not allocate a list.  This
        # function is called in the interpreter loop.
        if len(signals) == 0:
            self.signal_safe.ReuseEmptyList(signals)
            return None

        run_list = []  # type: List[command_t]
        for sig_num in signals:
            action = self.traps.get(sig_num, None)
            if action is None:
                continue
            if action.tag() == trap_action_e.Ignored:
                continue
            a = cast(trap_action.Command, action)
            run_list.append(a.c)

        # Optimization to avoid allocation in the main loop.
        del signals[:]
        self.signal_safe.ReuseEmptyList(signals)

        return run_list

    def ThisProcessHasTraps(self):
        # type: () -> bool
        """Return True if any trap or hook is registered.

        noforklast optimizations are not enabled when the process has code
        to run after fork!
        """
        if 0:
            log('traps %d', len(self.traps))
            log('hooks %d', len(self.hooks))
        return len(self.traps) != 0 or len(self.hooks) != 0
206
207
208_HOOK_NAMES = ['EXIT', 'ERR', 'RETURN', 'DEBUG']
209
210
def _ParseSignalOrHook(user_str, blame_loc, allow_legacy=True):
    # type: (str, loc_t, bool) -> Optional[str]
    """Convert user string to a parsed/normalized string.

    The result can be passed to TrapState.AddItem() and RemoveItem().

    Args:
      user_str: a hook name, or a signal name with or without the 'SIG'
        prefix.  Matching ignores case.  If allow_legacy is true, a string
        of digits is also accepted as a signal number, with 0 meaning EXIT.
      blame_loc: location reported if a numeric string can't be converted.
      allow_legacy: whether numeric specs like '2' are accepted.

    Returns:
      A hook name like 'EXIT', or a signal name without the 'SIG' prefix
      like 'INT'.  None if user_str isn't a known hook or signal.

    Raises:
      error.Usage if allow_legacy is true and int() rejects a string of
      digits (reported as an overflowing integer).

    See unit tests in builtin/trap_osh_test.py
      '0' -> 'EXIT'
      'EXIT' -> 'EXIT'
      'eXIT' -> 'EXIT'

      '2' -> 'INT'
      'iNT' -> 'INT'
      'sIGINT' -> 'INT'

      'zz' -> error
      '-150' -> error
      '10000' -> error
    """
    if allow_legacy and user_str.isdigit():
        try:
            sig_num = int(user_str)
        except ValueError:
            raise error.Usage("got overflowing integer: %s" % user_str,
                              blame_loc)

        if sig_num == 0:  # Special case
            return 'EXIT'

        name = signal_def.GetName(sig_num)
        if name is None:
            return None
        return name[3:]  # Remove SIG

    user_str = user_str.upper()  # Ignore case

    if user_str in _HOOK_NAMES:
        return user_str

    if user_str.startswith('SIG'):
        user_str = user_str[3:]

    n = signal_def.GetNumber(user_str)
    if n == signal_def.NO_SIGNAL:
        return None

    return user_str
258
259
def ParseSignalOrHook(user_str, blame_loc, allow_legacy=True):
    # type: (str, loc_t, bool) -> str
    """Like _ParseSignalOrHook(), but an unrecognized string is an error.

    Returns:
      The normalized hook or signal name, e.g. 'EXIT' or 'INT'.

    Raises:
      error.Usage, blamed on blame_loc, when user_str names no signal or
      hook.
    """
    normalized = _ParseSignalOrHook(user_str, blame_loc,
                                    allow_legacy=allow_legacy)
    if normalized is not None:
        return normalized
    raise error.Usage('expected signal or hook, got %r' % user_str,
                      blame_loc)
270
271
class Trap(vm._Builtin):
    """The 'trap' builtin: add, remove, and print traps and hooks."""

    def __init__(self, trap_state, parse_ctx, exec_opts, tracer, errfmt):
        # type: (TrapState, ParseContext, optview.Exec, dev.Tracer, ui.ErrorFormatter) -> None
        self.trap_state = trap_state
        self.parse_ctx = parse_ctx
        self.arena = parse_ctx.arena
        self.exec_opts = exec_opts
        self.tracer = tracer
        self.errfmt = errfmt

    def _ParseTrapCode(self, code_str):
        # type: (str) -> Optional[command_t]
        """Parse the CODE argument of 'trap CODE SIGNAL+'.

        Returns:
          A node, or None if the code is invalid.  In that case the parse
          error has already been printed with errfmt.
        """
        line_reader = reader.StringLineReader(code_str, self.arena)
        c_parser = self.parse_ctx.MakeOshParser(line_reader)

        # TODO: the SPID should be passed through argv.
        src = source.Dynamic('trap arg', loc.Missing)
        with alloc.ctx_SourceCode(self.arena, src):
            try:
                node = main_loop.ParseWholeFile(c_parser)
            except error.Parse as e:
                self.errfmt.PrettyPrintError(e)
                return None

        return node

    def _GetCommandSourceCode(self, body):
        # type: (command_t) -> str
        """Return source text for a trap body, for printing with 'trap -p'.

        Only command.Simple is supported.  For it we return the content of
        the source line holding its blame token.  Anything else gives
        '<unknown>'.
        """
        # TODO: Print ANY command_t variant
        handler_string = '<unknown>'  # type: str

        if body.tag() == command_e.Simple:
            simple_cmd = cast(command.Simple, body)
            if simple_cmd.blame_tok:
                handler_string = simple_cmd.blame_tok.line.content
        return handler_string

    def _PrintTrapEntry(self, handler, name):
        # type: (trap_action_t, str) -> None
        """Print one handler as a re-runnable 'trap -- CODE NAME' line."""
        with tagswitch(handler) as case:
            if case(trap_action_e.Ignored):
                print("trap -- '' %s" % name)
            elif case(trap_action_e.Command):
                c = cast(trap_action.Command, handler).c
                code = self._GetCommandSourceCode(c)
                print("trap -- %s %s" % (j8_lite.ShellEncode(code), name))
            else:
                raise AssertionError()

    def _PrintState(self):
        # type: () -> None
        """Print all hooks, then all signal traps, for 'trap' and 'trap -p'."""
        for name, handler in iteritems(self.trap_state.hooks):
            self._PrintTrapEntry(handler, name)

        # Print in order of signal number
        n = signal_def.MaxSigNumber() + 1
        for sig_num in xrange(n):
            action = self.trap_state.traps.get(sig_num)
            if action is None:
                continue

            sig_name = signal_def.GetName(sig_num)
            assert sig_name is not None

            self._PrintTrapEntry(action, sig_name)

    def _PrintNames(self):
        # type: () -> None
        """Print hook names, then signal names, for 'trap -l'."""
        for hook_name in _HOOK_NAMES:
            # EXIT is 0, but we hide that
            print('   %s' % hook_name)

        process_osh.PrintSignals()

    def _AddTheRest(self, arg_r, node, allow_legacy=True):
        # type: (args.Reader, trap_action_t, bool) -> int
        """Add a handler for all remaining args.

        Returns:
          0 on success.  2 if a STOP or KILL signal is named.  Handlers for
          the args before it have already been added by then.
        """
        while not arg_r.AtEnd():
            arg_str, arg_loc = arg_r.Peek2()
            parsed_id = ParseSignalOrHook(arg_str,
                                          arg_loc,
                                          allow_legacy=allow_legacy)

            if parsed_id == 'RETURN':
                print_stderr("osh warning: The %r hook isn't implemented" %
                             arg_str)
            if parsed_id == 'STOP' or parsed_id == 'KILL':
                self.errfmt.Print_("Signal %r can't be handled" % arg_str,
                                   blame_loc=arg_loc)
                # Other shells return 0, but this seems like an obvious error
                return 2

            self.trap_state.AddItem(parsed_id, node)

            arg_r.Next()
        return 0

    def _RemoveTheRest(self, arg_r, allow_legacy=True):
        # type: (args.Reader, bool) -> None
        """Remove handlers for all remaining args."""
        while not arg_r.AtEnd():
            arg_str, arg_loc = arg_r.Peek2()
            parsed_id = ParseSignalOrHook(arg_str,
                                          arg_loc,
                                          allow_legacy=allow_legacy)
            self.trap_state.RemoveItem(parsed_id)
            arg_r.Next()

    def Run(self, cmd_val):
        # type: (cmd_value.Argv) -> int
        """Run 'trap'.

        Returns:
          The exit status: 0 on success, 1 if the trap code doesn't parse,
          and 2 if a STOP or KILL handler is requested.  Unknown signal or
          hook names raise error.Usage.
        """
        attrs, arg_r = flag_util.ParseCmdVal('trap',
                                             cmd_val,
                                             accept_typed_args=True)
        arg = arg_types.trap(attrs.attrs)

        if arg.add:  # trap --add
            cmd_frag = typed_args.RequiredBlockAsFrag(cmd_val)
            return self._AddTheRest(arg_r,
                                    trap_action.Command(cmd_frag),
                                    allow_legacy=False)

        if arg.ignore:  # trap --ignore
            return self._AddTheRest(arg_r,
                                    trap_action.Ignored,
                                    allow_legacy=False)

        if arg.remove:  # trap --remove
            self._RemoveTheRest(arg_r, allow_legacy=False)
            return 0

        if arg.p:  # trap -p prints handlers
            self._PrintState()
            return 0

        if arg.l:  # List valid signals and hooks
            self._PrintNames()
            return 0

        # With simple_trap_builtin (YSH), only the flag forms above are
        # allowed; the legacy positional forms below are rejected.
        if self.exec_opts.simple_trap_builtin():
            raise error.Usage(
                'expected --add, --ignore, --remove, -l, or -p (simple_trap_builtin)',
                cmd_val.arg_locs[0])

        # 'trap' with no arguments is equivalent to 'trap -p'
        if arg_r.AtEnd():
            self._PrintState()
            return 0

        first_arg, first_loc = arg_r.Peek2()

        # If the first arg is '-' or an unsigned integer, then remove the
        # handlers.  For example, 'trap 0 2' or 'trap 0 SIGINT'
        #
        # https://pubs.opengroup.org/onlinepubs/9699919799.2018edition/utilities/V3_chap02.html#tag_18_28
        first_is_dash = (first_arg == '-')
        if first_is_dash or first_arg.isdigit():
            if first_is_dash:
                arg_r.Next()

            self._RemoveTheRest(arg_r)
            return 0

        arg_r.Next()

        # If first arg is empty string '', ignore the specified signals
        if len(first_arg) == 0:
            return self._AddTheRest(arg_r, trap_action.Ignored)

        # Legacy behavior for only one arg: 'trap SIGNAL' removes the handler
        if arg_r.AtEnd():
            parsed_id = ParseSignalOrHook(first_arg, first_loc)
            self.trap_state.RemoveItem(parsed_id)
            return 0

        # Unlike other shells, we parse the code upon registration
        node = self._ParseTrapCode(first_arg)
        if node is None:
            return 1  # _ParseTrapCode() prints an error for us.

        # trap COMMAND SIGNAL+
        return self._AddTheRest(arg_r, trap_action.Command(node))
457
458 # trap COMMAND SIGNAL+
459 return self._AddTheRest(arg_r, trap_action.Command(node))