Skip to content

Commit

Permalink
Add new public APIs for header and command output (#383)
Browse files Browse the repository at this point in the history
* Add new public APIs for header and command output

Adds new APIs on the protocol class to expose a publicly supported way
to build the WSMan SOAP headers used internally and to get a single
output value. The original methods have been kept around for backwards
compatibility as while they were not set as public they have been used
in various libraries so removing them now will cause breakages.

* Use parametrized tests for compat check
  • Loading branch information
jborean93 authored Jun 6, 2024
1 parent e8bb574 commit 739df6f
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 22 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@
- Added `WSManFaultError` which contains WSManFault specific information when receiving a 500 WSMan fault response
- This contains pre-parsed values like the code, subcode, wsman fault code, wmi error code, and raw response
- It can be used by the caller to implement fallback behaviour based on specific error codes
- Added public API `protocol.build_wsman_header` that can create the standard WSMan header used by the protocol
- This can be used to craft custom WSMan messages that are not supported in the existing actions
- Added public API `protocol.get_command_output_raw`
- This can be used to send a single WSMan receive request and get the output
- Unlike `protocol.get_command_output`, it will not loop until the command is done and will not catch a timeout exception

### Version 0.4.3
- Fix invalid regex escape sequences.
Expand Down
85 changes: 63 additions & 22 deletions winrm/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def open_shell(
@rtype string
"""
req = {
"env:Envelope": self._get_soap_header(
"env:Envelope": self.build_wsman_header(
resource_uri="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd", # NOQA
action="http://schemas.xmlsoap.org/ws/2004/09/transfer/Create",
)
Expand Down Expand Up @@ -198,13 +198,26 @@ def open_shell(
return t.cast(str, next(node for node in root.findall(".//*") if node.get("Name") == "ShellId").text)

# Helper method for building SOAP Header
def _get_soap_header(
def build_wsman_header(
self,
action: str | None = None,
resource_uri: str | None = None,
action: str,
resource_uri: str,
shell_id: str | None = None,
message_id: uuid.UUID | None = None,
message_id: str | uuid.UUID | None = None,
) -> dict[str, t.Any]:
"""
Builds the standard header needed for WSMan operations. The return
value is a dictionary that can be used by xmltodict to generate the
WSMan envelope when sending custom requests.
@param string action: The WSMan action to perform.
@param string resource_uri: The WSMan resource URI the request is for.
@param string shell_id: The optional shell UUID the request is for.
@param string message_id: A unique message UUID, if unset a random UUID
is used.
@returns The WSMan header as a dictionary.
@rtype dict[str, t.Any]
"""
if not message_id:
message_id = uuid.uuid4()
header: dict[str, t.Any] = {
Expand Down Expand Up @@ -238,6 +251,11 @@ def _get_soap_header(
header["env:Header"]["w:SelectorSet"] = {"w:Selector": {"@Name": "ShellId", "#text": shell_id}}
return header

# For backwards compatibility with Ansible. This should not be removed
# until all supported releases of Ansible has been updated to use the new
# method.
_get_soap_header = build_wsman_header

def send_message(self, message: str) -> bytes:
# TODO add message_id vs relates_to checking
# TODO port error handling code
Expand Down Expand Up @@ -311,7 +329,7 @@ def close_shell(self, shell_id: str, close_session: bool = True) -> None:
try:
message_id = uuid.uuid4()
req = {
"env:Envelope": self._get_soap_header(
"env:Envelope": self.build_wsman_header(
resource_uri="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd", # NOQA
action="http://schemas.xmlsoap.org/ws/2004/09/transfer/Delete",
shell_id=shell_id,
Expand Down Expand Up @@ -356,7 +374,7 @@ def run_command(
@rtype string
"""
req = {
"env:Envelope": self._get_soap_header(
"env:Envelope": self.build_wsman_header(
resource_uri="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd", # NOQA
action="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/Command", # NOQA
shell_id=shell_id,
Expand Down Expand Up @@ -393,7 +411,7 @@ def cleanup_command(self, shell_id: str, command_id: str) -> None:
"""
message_id = uuid.uuid4()
req = {
"env:Envelope": self._get_soap_header(
"env:Envelope": self.build_wsman_header(
resource_uri="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd", # NOQA
action="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/Signal", # NOQA
shell_id=shell_id,
Expand Down Expand Up @@ -430,7 +448,7 @@ def send_command_input(self, shell_id: str, command_id: str, stdin_input: str |
if isinstance(stdin_input, str):
stdin_input = stdin_input.encode("437")
req = {
"env:Envelope": self._get_soap_header(
"env:Envelope": self.build_wsman_header(
resource_uri="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd", # NOQA
action="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/Send", # NOQA
shell_id=shell_id,
Expand All @@ -449,32 +467,48 @@ def send_command_input(self, shell_id: str, command_id: str, stdin_input: str |

def get_command_output(self, shell_id: str, command_id: str) -> tuple[bytes, bytes, int]:
"""
Get the Output of the given shell and command
Get the Output of the given shell and command. This will wait until the
command is finished before returning the output.
@param string shell_id: The shell id on the remote machine.
See #open_shell
@param string command_id: The command id on the remote machine.
See #run_command
#@return [Hash] Returns a Hash with a key :exitcode and :data.
Data is an Array of Hashes where the corresponding key
# is either :stdout or :stderr. The reason it is in an Array so so
we can get the output in the order it occurs on
# the console.
@return tuple[bytes, bytes, int]: Returns a tuple with the stdout,
stderr, and the return code of the command. The stdout and stderr
value is a byte string and not a normal string.
"""
stdout_buffer, stderr_buffer = [], []
command_done = False
while not command_done:
try:
stdout, stderr, return_code, command_done = self._raw_get_command_output(shell_id, command_id)
stdout, stderr, return_code, command_done = self.get_command_output_raw(shell_id, command_id)
stdout_buffer.append(stdout)
stderr_buffer.append(stderr)
except WinRMOperationTimeoutError:
# this is an expected error when waiting for a long-running process, just silently retry
pass
return b"".join(stdout_buffer), b"".join(stderr_buffer), return_code

def _raw_get_command_output(self, shell_id: str, command_id: str) -> tuple[bytes, bytes, int, bool]:
def get_command_output_raw(self, shell_id: str, command_id: str) -> tuple[bytes, bytes, int, bool]:
"""
Get the next available output of the given shell and command. This
will wait until the issued WSMan Receive action returns data or times
out with WinRMOperationTimeoutError.
@param string shell_id: The shell id on the remote machine.
See #open_shell
@param string command_id: The command id on the remote machine.
See #run_command
@return tuple[bytes, bytes, int, bool]: Returns a tuple with the stdout,
stderr, the return code of the command, and whether it has finished
or not. The stdout and stderr value is a byte string and not a
normal string.
@raises WinRMOperationTimeoutError: Raised when there has been no
output from the command
"""
req = {
"env:Envelope": self._get_soap_header(
"env:Envelope": self.build_wsman_header(
resource_uri="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/cmd", # NOQA
action="http://schemas.microsoft.com/wbem/wsman/1/windows/shell/Receive", # NOQA
shell_id=shell_id,
Expand All @@ -488,15 +522,16 @@ def _raw_get_command_output(self, shell_id: str, command_id: str) -> tuple[bytes
res = self.send_message(xmltodict.unparse(req))
root = ET.fromstring(res)
stream_nodes = [node for node in root.findall(".//*") if node.tag.endswith("Stream")]
stdout = stderr = b""
stdout = []
stderr = []
return_code = -1
for stream_node in stream_nodes:
if not stream_node.text:
continue
if stream_node.attrib["Name"] == "stdout":
stdout += base64.b64decode(stream_node.text.encode("ascii"))
stdout.append(base64.b64decode(stream_node.text.encode("ascii")))
elif stream_node.attrib["Name"] == "stderr":
stderr += base64.b64decode(stream_node.text.encode("ascii"))
stderr.append(base64.b64decode(stream_node.text.encode("ascii")))

# We may need to get additional output if the stream has not finished.
# The CommandState will change from Running to Done like so:
Expand All @@ -511,4 +546,10 @@ def _raw_get_command_output(self, shell_id: str, command_id: str) -> tuple[bytes
if command_done:
return_code = int(next(node for node in root.findall(".//*") if node.tag.endswith("ExitCode")).text or -1)

return stdout, stderr, return_code, command_done
return b"".join(stdout), b"".join(stderr), return_code, command_done

# While it was meant to be private it has been treated as a public API.
# This might be removed in a future version but for now keep it as an
# alias for the now public API method 'get_command_output_raw'.
# https://github.com/search?q=_raw_get_command_output+language%3APython&type=code&l=Python
_raw_get_command_output = get_command_output_raw
27 changes: 27 additions & 0 deletions winrm/tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,17 @@
from winrm.protocol import Protocol


@pytest.mark.parametrize("func_name", ["build_wsman_header", "_get_soap_header"])
def test_build_wsman_header(func_name, protocol_fake):
func = getattr(protocol_fake, func_name)
actual = func("my action", "resource uri", "shell id", "message id")

assert actual["env:Header"]["a:Action"]["#text"] == "my action"
assert actual["env:Header"]["w:ResourceURI"]["#text"] == "resource uri"
assert actual["env:Header"]["a:MessageID"] == "uuid:message id"
assert actual["env:Header"]["w:SelectorSet"]["w:Selector"]["#text"] == "shell id"


def test_open_shell_and_close_shell(protocol_fake):
shell_id = protocol_fake.open_shell()
assert shell_id == "11111111-1111-1111-1111-111111111113"
Expand Down Expand Up @@ -40,6 +51,22 @@ def test_get_command_output(protocol_fake):
protocol_fake.close_shell(shell_id)


@pytest.mark.parametrize("func_name", ["get_command_output_raw", "_raw_get_command_output"])
def test_get_command_output_raw(func_name, protocol_fake):
func = getattr(protocol_fake, func_name)
shell_id = protocol_fake.open_shell()
command_id = protocol_fake.run_command(shell_id, "ipconfig", ["/all"])

std_out, std_err, status_code, done = func(shell_id, command_id)
assert status_code == 0
assert b"Windows IP Configuration" in std_out
assert len(std_err) == 0
assert done is True

protocol_fake.cleanup_command(shell_id, command_id)
protocol_fake.close_shell(shell_id)


def test_send_command_input(protocol_fake):
shell_id = protocol_fake.open_shell()
command_id = protocol_fake.run_command(shell_id, "cmd")
Expand Down

0 comments on commit 739df6f

Please sign in to comment.