Skip to content

Commit

Permalink
Better handler registration
Browse files Browse the repository at this point in the history
  • Loading branch information
weiiwang01 committed Aug 22, 2023
1 parent 32f5b19 commit 9be8ab9
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
14 changes: 11 additions & 3 deletions ops/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2503,12 +2503,19 @@ def __init__(self, backend: _TestingModelBackend, container_root: pathlib.Path):

def _handle_exec(self, command_prefix: List[str], handler: ExecHandler):
prefix = tuple(command_prefix)
for idx, registered_handler in enumerate(self._exec_handlers):
inserted = False
for idx in range(len(self._exec_handlers)):
if inserted:
idx = idx + 1
registered_handler = self._exec_handlers[idx]
if prefix == registered_handler[0]:
self._exec_handlers[idx] = (prefix, handler)
return
self._exec_handlers.append((prefix, handler))
self._exec_handlers.sort(key=lambda pair: len(pair[0]), reverse=True)
if not inserted and len(prefix) > len(registered_handler[0]):
self._exec_handlers.insert(idx, (prefix, handler))
inserted = True
if not inserted:
self._exec_handlers.append((prefix, handler))

def _check_connection(self):
if not self._backend._can_connect(self):
Expand Down Expand Up @@ -2871,6 +2878,7 @@ def remove_path(self, path: str, *, recursive: bool = False):
file_path.unlink()

def _find_exec_handler(self, command: List[str]) -> Optional[ExecHandler]:
print(self._exec_handlers)
for command_prefix, handler in self._exec_handlers:
if tuple(command[:len(command_prefix)]) == command_prefix:
return handler
Expand Down
17 changes: 17 additions & 0 deletions test/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import inspect
import io
import ipaddress
import itertools
import os
import pathlib
import platform
Expand Down Expand Up @@ -5092,3 +5093,19 @@ def handler(args):
self.assertEqual(args_history[-1].group, "test_group")
self.assertEqual(args_history[-1].group_id, 4)
self.assertDictEqual(args_history[-1].environment, {"foo": "hello", "foobar": "barfoo"})

def test_registration_order(self):
for n in range(7):
pebble = self.harness._backend._pebble_clients[self.container.name]
for prefix_lengths in itertools.product(range(1, n + 1), repeat=n):
pebble._exec_handlers = []
for idx, prefix_len in enumerate(prefix_lengths):
self.harness.handle_exec(self.container, [str(idx)] * prefix_len, result=0)
handlers = pebble._exec_handlers
self.assertTrue(all(
len(handlers[i][0]) >= len(handlers[i + 1][0])
for i in range(len(handlers) - 1)))
self.assertEqual(len(handlers), len(prefix_lengths))
for idx, handler in enumerate(handlers):
self.harness.handle_exec(self.container, handler[0], result=idx)
self.assertEqual(handlers[idx][1](None), idx)

0 comments on commit 9be8ab9

Please sign in to comment.