Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

handle timeouts more gracefully #6

Merged
merged 1 commit into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 48 additions & 36 deletions src/shellinspector/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
LOGGER = logging.getLogger(Path(__file__).name)


class TimeoutException(Exception):
def __init__(self, output_so_far: str):
self.output_so_far = output_so_far
super().__init__()


@dataclasses.dataclass
class ShellinspectorPyContext:
applied_example: dict
Expand Down Expand Up @@ -90,21 +96,18 @@ def run_command(self, line):
actual_output = actual_output.replace("\r\n", "\n")

if found_prompt:
return True, actual_output
return actual_output
else:
self.close()
return False, actual_output
raise TimeoutException(actual_output)

def set_environment(self, context):
for k, v in context.items():
self.sendline(f"export {k}={shlex.quote(str(v))}")
assert self.prompt()

def get_environment(self):
success, output = self.run_command("export")

if not success:
raise NotImplementedError()
output = self.run_command("export")

env = {}

Expand Down Expand Up @@ -137,7 +140,7 @@ def push_state(self):

def pop_state(self):
if self.closed:
raise Exception("Session is closed")
return

self.sendline("echo $SHELLINSPECTOR_PROMPT_STATE")
assert self.prompt()
Expand Down Expand Up @@ -264,6 +267,27 @@ def _close_session(self, cmd):
f"Session could not be closed, because it doesn't exist, command: {cmd}"
)

def _make_session(self, key, cmd, timeout_seconds):
LOGGER.debug("creating session: %s", key)
if cmd.host == "local":
LOGGER.debug("new local shell session")
session = self.sessions[key] = get_localshell(timeout_seconds)
else:
ssh_config = {
**self.ssh_config,
"username": cmd.user,
"server": self.ssh_config["server"],
"port": self.ssh_config["port"],
}
LOGGER.debug("connecting via SSH: %s", ssh_config)
session = get_ssh_session(ssh_config, timeout_seconds)

if logging.root.level == logging.DEBUG:
# use .buffer here, because pexpect wants to write bytes, not strs
session.logfile = sys.stdout.buffer

return session

def _get_session(self, cmd, timeout_seconds):
"""
Create or reuse a shell session used to run the given command.
Expand All @@ -289,31 +313,17 @@ def _get_session(self, cmd, timeout_seconds):

if key not in self.sessions:
# connect, if there is no session
LOGGER.debug("creating session: %s", key)
if cmd.host == "local":
LOGGER.debug("new local shell session")
session = self.sessions[key] = get_localshell(timeout_seconds)
else:
ssh_config = {
**self.ssh_config,
"username": cmd.user,
"server": self.ssh_config["server"],
"port": self.ssh_config["port"],
}
LOGGER.debug("connecting via SSH: %s", ssh_config)
session = self.sessions[key] = get_ssh_session(
ssh_config, timeout_seconds
)

if logging.root.level == logging.DEBUG:
# use .buffer here, because pexpect wants to write bytes, not strs
session.logfile = sys.stdout.buffer
self.sessions[key] = self._make_session(key, cmd, timeout_seconds)
elif self.sessions[key].closed:
# destroy and reconnect, if there is a broken session
LOGGER.debug("closing failed session: %s", key)
self._close_session(cmd)
self.sessions[key] = self._make_session(key, cmd, timeout_seconds)
else:
# reuse, if we're already connected
LOGGER.debug("reusing session: %s", key)
session = self.sessions[key]

return session
return self.sessions[key]

def add_reporter(self, reporter):
self.reporters.append(reporter)
Expand Down Expand Up @@ -364,26 +374,28 @@ def _check_result(self, cmd, command_output, returncode):
return False

def _run_command(self, session, cmd):
success, command_output = session.run_command(cmd.command)
if not success:
try:
command_output = session.run_command(cmd.command)
except TimeoutException as ex:
self.report(
RunnerEvent.ERROR,
cmd,
{
"message": "could not find prompt for command",
"actual": command_output,
"message": "timeout, could not find prompt for command",
"actual": ex.output_so_far,
},
)
return False

success, rc_output = session.run_command("echo $?")
if not success:
try:
rc_output = session.run_command("echo $?")
except TimeoutException as ex:
self.report(
RunnerEvent.ERROR,
cmd,
{
"message": "could not find prompt for return code",
"actual": rc_output,
"message": "timeout, could not find prompt for return code",
"actual": ex.output_so_far,
},
)
return False
Expand Down
10 changes: 8 additions & 2 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,10 @@ def set_environment(self, env):
[
(
RunnerEvent.ERROR,
{"message": "could not find prompt for command", "actual": "a"},
{
"message": "timeout, could not find prompt for command",
"actual": "a",
},
),
],
),
Expand All @@ -806,7 +809,10 @@ def set_environment(self, env):
[
(
RunnerEvent.ERROR,
{"message": "could not find prompt for return code", "actual": "0"},
{
"message": "timeout, could not find prompt for return code",
"actual": "0",
},
)
],
),
Expand Down
Loading