diff --git a/__main__.py b/__main__.py
index e73bccfe693df2333447b924c880d00ce09aeaae..5837c275bcb48f7d97eca8be4bd7f5594a94100e 100755
--- a/__main__.py
+++ b/__main__.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 
-import os
+import os
 import sys
 import json
 import subprocess
@@ -33,7 +33,7 @@ class TestRunner:
             group_id=TEST_GROUP_ID,
             base_image=CONTAINER_IMAGE_ID,
             external_port_base=9000,
-            log=global_logger()
+            log=global_logger(),
         )
 
     def prepare_environment(self, build: bool = True) -> None:
@@ -84,6 +84,7 @@ TEST_SET.extend(BASIC_TESTS)
 
 # set to True to stop at the first failing test
 FAIL_FAST = True
+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -97,8 +98,12 @@ def main():
         action="store_true",
         help="run all tests instead of stopping at first failure. Note: this is currently broken due to issues with cleanup code",
     )
-    parser.add_argument("--num-threads", type=int, default=1, help="number of threads to run tests in")
-    parser.add_argument("--port-offset", type=int, default=1000, help="port offset for each test")
+    parser.add_argument(
+        "--num-threads", type=int, default=1, help="number of threads to run tests in"
+    )
+    parser.add_argument(
+        "--port-offset", type=int, default=1000, help="port offset for each test"
+    )
     parser.add_argument("filter", nargs="?", help="filter tests by name")
 
     args = parser.parse_args()
@@ -118,9 +123,10 @@ def main():
 
     log("\n== RUNNING TESTS ==")
     run_tests = []
+
     def run_test(test: TestCase, gid: str, port_offset: int):
         log(f"\n== TEST: [{test.name}] ==\n")
-        test_set_name = test.name.lower().split('_')[0]
+        test_set_name = test.name.lower().split("_")[0]
         test_dir = create_test_dir(DEBUG_OUTPUT_DIR, test_set_name, test.name)
         log_file_path = os.path.join(test_dir, f"{test.name}.log")
 
@@ -132,7 +138,7 @@ def main():
             group_id=gid,
             base_image=CONTAINER_IMAGE_ID,
             external_port_base=9000 + port_offset,
-            log=logger
+            log=logger,
         )
 
         score, reason = test.execute(conductor, test_dir, log=logger)
@@ -155,7 +161,12 @@ def main():
     else:
         print("Running tests in a threadpool ({args.num_threads} threads)")
         pool = ThreadPool(processes=args.num_threads)
-        pool.map(lambda a: run_test(a[1], gid=f"{a[0]}", port_offset=a[0]*args.port_offset), enumerate(TEST_SET))
+        pool.map(
+            lambda a: run_test(
+                a[1], gid=f"{a[0]}", port_offset=a[0] * args.port_offset
+            ),
+            enumerate(TEST_SET),
+        )
 
     summary_log = os.path.join(DEBUG_OUTPUT_DIR, "summary.log")
     with open(summary_log, "w") as log_file:
diff --git a/tests/helper.py b/tests/helper.py
index 18b31dbc6a88b3e062918ee66b7e887e45888bba..e6d3e6d549b2937f73ae6f5cc27801b9194fb150 100644
--- a/tests/helper.py
+++ b/tests/helper.py
@@ -5,12 +5,14 @@ from ..utils.containers import ClusterConductor
 from ..utils.kvs_api import KVSClient
 from ..utils.util import Logger
 
-class KVSTestFixture:
+
+class KVSTestFixture:
     def __init__(self, conductor: ClusterConductor, dir, log: Logger, node_count: int):
+        conductor._parent = self
         self.conductor = conductor
         self.dir = dir
         self.node_count = node_count
-        self.clients = []
+        self.clients: list[KVSClient] = []
         self.log = log
 
     def spawn_cluster(self):
@@ -22,23 +24,21 @@ class KVSTestFixture:
             self.clients.append(KVSClient(ep))
 
             r = self.clients[i].ping()
-            assert r.status_code == 200, f"expected 200 for ping, got {
-                r.status_code}"
+            assert r.status_code == 200, f"expected 200 for ping, got {r.status_code}"
             self.log(f"  - node {i} is up: {r.text}")
 
-    def broadcast_view(self, view: Dict[str, List[Dict[str, Any]]]):
+    def broadcast_view(self, view: List[Dict[str, Any]]):
         self.log(f"\n> SEND VIEW: {view}")
         for i, client in enumerate(self.clients):
             r = client.send_view(view)
-            assert (
-                r.status_code == 200
-            ), f"expected 200 to ack view, got {r.status_code}"
+            assert r.status_code == 200, (
+                f"expected 200 to ack view, got {r.status_code}"
+            )
             self.log(f"view sent to node {i}: {r.status_code} {r.text}")
 
     def send_view(self, node_id: int, view: Dict[str, List[Dict[str, Any]]]):
         r = self.clients[node_id].send_view(view)
-        assert r.status_code == 200, f"expected 200 to ack view, got {
-            r.status_code}"
+        assert r.status_code == 200, f"expected 200 to ack view, got {r.status_code}"
         self.log(f"view sent to node {node_id}: {r.status_code} {r.text}")
 
     def destroy_cluster(self):
@@ -70,11 +70,11 @@ class KVSMultiClient:
         self.metadata = None
 
     def put(self, node_id: int, key: str, value: str, timeout: float = DEFAULT_TIMEOUT):
-        self.log(f" {self.name} req_id:{self.req} > {
-                 node_id} > kvs.put {key} <- {value}")
+        self.log(
+            f" {self.name} req_id:{self.req} > {node_id} > kvs.put {key} <- {value}"
+        )
 
-        r = self.clients[node_id].put(
-            key, value, self.metadata, timeout=timeout)
+        r = self.clients[node_id].put(key, value, self.metadata, timeout=timeout)
 
         # update model if successful
         if r.status_code // 100 == 2:
@@ -86,32 +86,48 @@ class KVSMultiClient:
         return r
 
     def get(self, node_id: int, key: str, timeout: float = DEFAULT_TIMEOUT):
-        self.log(f" {self.name} req_id:{self.req} > {node_id}> kvs.get {
-                 key} request \"causal-metadata\": {self.metadata}")
+        self.log(
+            f' {self.name} req_id:{self.req} > {node_id}> kvs.get {
+                key
+            } request "causal-metadata": {self.metadata}'
+        )
         r = self.clients[node_id].get(key, self.metadata, timeout=timeout)
 
         if r.status_code // 100 == 2:
-            self.log(f" {self.name} req_id:{self.req} > {
-                     node_id}> kvs.get {key} -> {r.json()}")
+            self.log(
+                f" {self.name} req_id:{self.req} > {node_id}> kvs.get {key} -> {
+                    r.json()
+                }"
+            )
             self.metadata = r.json()["causal-metadata"]
         else:
-            self.log(f" {self.name} req_id:{self.req} > {
-                     node_id}> kvs.get {key} -> HTTP ERROR {r.status_code}")
+            self.log(
+                f" {self.name} req_id:{self.req} > {node_id}> kvs.get {
+                    key
+                } -> HTTP ERROR {r.status_code}"
+            )
         self.req += 1
         return r
 
     def get_all(self, node_id: int, timeout: float = DEFAULT_TIMEOUT):
-        self.log(f" {self.name} req_id:{self.req} > {
-                 node_id}> kvs.get_all request \"causal-metadata\": {self.metadata}")
+        self.log(
+            f' {self.name} req_id:{self.req} > {
+                node_id
+            }> kvs.get_all request "causal-metadata": {self.metadata}'
+        )
         r = self.clients[node_id].get_all(self.metadata, timeout=timeout)
 
         if r.status_code // 100 == 2:
-            self.log(f" {self.name} req_id:{self.req} > {
-                     node_id}> kvs.get_all -> {r.json()}")
+            self.log(
+                f" {self.name} req_id:{self.req} > {node_id}> kvs.get_all -> {r.json()}"
+            )
             self.metadata = r.json()["causal-metadata"]
         else:
-            self.log(f" {self.name} req_id:{self.req} > {
-                     node_id}> kvs.get_all -> HTTP ERROR {r.status_code}")
+            self.log(
+                f" {self.name} req_id:{self.req} > {
+                    node_id
+                }> kvs.get_all -> HTTP ERROR {r.status_code}"
+            )
         self.req += 1
         return r
 
diff --git a/utils/containers.py b/utils/containers.py
index dae84e411e8ef06906d47d7c8e99eea63d8a3ae8..5737800419df035d9aad76973da9f7302f7e3833 100644
--- a/utils/containers.py
+++ b/utils/containers.py
@@ -8,6 +8,8 @@ import re
 
 import requests
 
+# from ..tests.helper import KVSTestFixture
+
 from .util import run_cmd_bg, Logger
 
 CONTAINER_ENGINE = os.getenv("ENGINE", "docker")
@@ -42,9 +44,6 @@ class ClusterNode:
     )
     networks: List[str]  # networks the container is attached to
 
-    def get_view(self) -> str:
-        return {"address": f"{self.ip}:{self.port}", "id": self.index}
-
     def internal_endpoint(self) -> str:
         return f"http://{self.ip}:{self.port}"
 
@@ -53,6 +52,8 @@ class ClusterNode:
 
 
 class ClusterConductor:
+    # _parent: KVSTestFixture
+
     def __init__(
         self,
         group_id: str,
@@ -107,14 +108,15 @@ class ClusterConductor:
 
     def dump_all_container_logs(self, dir):
         self.log("dumping logs of kvs containers")
-        container_pattern = "^kvs_.*"
+        container_pattern = f"^kvs_{self.group_id}.*"
        container_regex = re.compile(container_pattern)
 
         containers = self._list_containers()
         for container in containers:
             if container and container_regex.match(container):
                 self._dump_container_logs(dir, container)
-
+    def get_view(self) -> str:
+        return {"address": f"{self.ip}:{self.port}", "id": self.index}
     def _dump_container_logs(self, dir, name: str) -> None:
         log_file = os.path.join(dir, f"{name}.log")
         self.log(f"Dumping logs for container {name} to file {log_file}")
@@ -143,6 +145,16 @@ class ClusterConductor:
             error_prefix=f"failed to remove container {name}",
         )
 
+    def _remove_containers(self, names: list[str]) -> None:
+        if len(names) == 0:
+            return
+        self.log(f"removing containers {names}")
+        run_cmd_bg(
+            self._make_remove_cmd(names[0]) + names[1:],
+            verbose=True,
+            error_prefix=f"failed to remove containers {names}",
+        )
+
     def _remove_network(self, name: str) -> None:
         # remove a single network
         self.log(f"removing network {name}")
@@ -184,9 +196,15 @@ class ClusterConductor:
         # cleanup containers
         self.log(f"  cleaning up {'group' if group_only else 'all'} containers")
         containers = self._list_containers()
-        for container in containers:
-            if container and container_regex.match(container):
-                self._remove_container(container)
+        containers_to_remove = [
+            container
+            for container in containers
+            if container and container_regex.match(container)
+        ]
+        self._remove_containers(containers_to_remove)
+        # for container in containers:
+        #     if container and container_regex.match(container):
+        #         self._remove_container(container)
 
         # cleanup networks
         self.log(f"  cleaning up {'group' if group_only else 'all'} networks")
@@ -260,7 +278,13 @@ class ClusterConductor:
         # attach container to base network
         self.log(f"  attaching container {node_name} to base network")
         run_cmd_bg(
-            [CONTAINER_ENGINE, "network", "connect", self.base_net_name, node_name],
+            [
+                CONTAINER_ENGINE,
+                "network",
+                "connect",
+                self.base_net_name,
+                node_name,
+            ],
             verbose=True,
             error_prefix=f"failed to attach container {node_name} to base network",
             log=self.log,
@@ -347,10 +371,8 @@ class ClusterConductor:
         net_name = f"kvs_{self.group_id}_net_{partition_id}"
         self.log(f"creating partition {partition_id} with nodes {node_ids}")
 
-
         # create partition network if it doesn't exist
         if not self._network_exists(net_name):
-            self.log(f"creating network {net_name}")
             self._create_network(net_name)
 
         # disconnect specified nodes from all other networks
@@ -371,6 +393,7 @@ class ClusterConductor:
 
         # connect nodes to partition network, and update node ip
         self.log(f"  connecting nodes to partition network {net_name}")
+        view_changed = False
         for i in node_ids:
             node = self.nodes[i]
 
@@ -381,7 +404,13 @@ class ClusterConductor:
 
             self.log(f"    connecting {node.name} to network {net_name}")
             run_cmd_bg(
-                [CONTAINER_ENGINE, "network", "connect", net_name, node.name],
+                [
+                    CONTAINER_ENGINE,
+                    "network",
+                    "connect",
+                    net_name,
+                    node.name,
+                ],
                 verbose=True,
                 error_prefix=f"failed to connect {node.name} to network {net_name}",
             )
@@ -399,7 +428,16 @@ class ClusterConductor:
             self.log(f"  node {node.name} ip in network {net_name}: {container_ip}")
 
             # update node ip
-            node.ip = container_ip
+
+            if container_ip != node.ip:
+                node.ip = container_ip
+                if hasattr(self, "_parent"):
+                    self._parent.clients[
+                        node.index
+                    ].base_url = self.node_external_endpoint(node.index)
+                    view_changed = True
+        if view_changed and hasattr(self, "_parent"):
+            self._parent.broadcast_view(self.get_full_view())
 
     def create_partition(self, node_ids: List[int], partition_id: str) -> None:
         net_name = f"kvs_{self.group_id}_net_{partition_id}"
@@ -428,11 +466,18 @@ class ClusterConductor:
 
         # connect nodes to partition network, and update node ip
         self.log(f"  connecting nodes to partition network {net_name}")
+        view_changed = False
         for i in node_ids:
             node = self.nodes[i]
             self.log(f"    connecting {node.name} to network {net_name}")
             run_cmd_bg(
-                [CONTAINER_ENGINE, "network", "connect", net_name, node.name],
+                [
+                    CONTAINER_ENGINE,
+                    "network",
+                    "connect",
+                    net_name,
+                    node.name,
+                ],
                 verbose=True,
                 error_prefix=f"failed to connect {node.name} to network {net_name}",
             )
@@ -450,10 +495,17 @@ class ClusterConductor:
             self.log(f"  node {node.name} ip in network {net_name}: {container_ip}")
 
             # update node ip
-            node.ip = container_ip
-
-            DeprecationWarning("View is in updated format")
-
+            if container_ip != node.ip:
+                node.ip = container_ip
+                if hasattr(self, "_parent"):
+                    self._parent.clients[
+                        node.index
+                    ].base_url = self.node_external_endpoint(node.index)
+                    view_changed = True
+        if view_changed and hasattr(self, "_parent"):
+            self._parent.broadcast_view(self.get_full_view())
+    def get_node(self, index):
+        return self.nodes[index]
     def get_full_view(self):
         view = []
         for node in self.nodes:
diff --git a/utils/kvs_api.py b/utils/kvs_api.py
index 6f6de12ef37fbdcb1a59673ca23be83d5c5434d6..8353f1bba508cc2fce99b85e44c20cc0b01183d4 100644
--- a/utils/kvs_api.py
+++ b/utils/kvs_api.py
@@ -18,14 +18,14 @@ def create_json(metadata, value=None):
         result["value"] = value
     return result
 
+
 # client for kvs api
 class KVSClient:
     def __init__(self, base_url: str):
         # set base url without trailing slash
-        self.base_url = base_url if not base_url.endswith(
-            "/") else base_url[:-1]
+        self.base_url = base_url if not base_url.endswith("/") else base_url[:-1]
 
     def ping(self, timeout: float = DEFAULT_TIMEOUT) -> requests.Response:
         if timeout is not None:
@@ -38,33 +38,49 @@ class KVSClient:
         else:
             return requests.get(f"{self.base_url}/ping")
 
-    def get(self, key: str, metadata: str, timeout: float = DEFAULT_TIMEOUT) -> requests.Response:
+    def get(
+        self, key: str, metadata: str, timeout: float = DEFAULT_TIMEOUT
+    ) -> requests.Response:
         if not key:
             raise ValueError("key cannot be empty")
 
         if timeout is not None:
             try:
-                return requests.get(f"{self.base_url}/data/{key}", json=create_json(metadata), timeout=timeout)
+                return requests.get(
+                    f"{self.base_url}/data/{key}",
+                    json=create_json(metadata),
+                    timeout=timeout,
+                )
             except requests.exceptions.Timeout:
                 r = requests.Response()
                 r.status_code = REQUEST_TIMEOUT_STATUS_CODE
                 return r
         else:
-            return requests.get(f"{self.base_url}/data/{key}", json=create_json(metadata))
+            return requests.get(
+                f"{self.base_url}/data/{key}", json=create_json(metadata)
+            )
 
-    def put(self, key: str, value: str, metadata: str, timeout: float = DEFAULT_TIMEOUT) -> requests.Response:
+    def put(
+        self, key: str, value: str, metadata: str, timeout: float = DEFAULT_TIMEOUT
+    ) -> requests.Response:
         if not key:
             raise ValueError("key cannot be empty")
 
         if timeout is not None:
             try:
-                return requests.put(f"{self.base_url}/data/{key}", json=create_json(metadata, value), timeout=timeout)
+                return requests.put(
+                    f"{self.base_url}/data/{key}",
+                    json=create_json(metadata, value),
+                    timeout=timeout,
+                )
             except requests.exceptions.Timeout:
                 r = requests.Response()
                 r.status_code = REQUEST_TIMEOUT_STATUS_CODE
                 return r
         else:
-            return requests.put(f"{self.base_url}/data/{key}", json=create_json(metadata, value))
+            return requests.put(
+                f"{self.base_url}/data/{key}", json=create_json(metadata, value)
+            )
 
     def delete(self, key: str, timeout: float = DEFAULT_TIMEOUT) -> requests.Response:
         if not key:
@@ -80,10 +96,14 @@ class KVSClient:
         else:
             return requests.delete(f"{self.base_url}/data/{key}")
 
-    def get_all(self, metadata: str, timeout: float = DEFAULT_TIMEOUT) -> requests.Response:
+    def get_all(
+        self, metadata: str, timeout: float = DEFAULT_TIMEOUT
+    ) -> requests.Response:
         if timeout is not None:
             try:
-                return requests.get(f"{self.base_url}/data", json=create_json(metadata), timeout=timeout)
+                return requests.get(
+                    f"{self.base_url}/data", json=create_json(metadata), timeout=timeout
+                )
             except requests.exceptions.Timeout:
                 r = requests.Response()
                 r.status_code = REQUEST_TIMEOUT_STATUS_CODE
@@ -100,11 +120,12 @@ class KVSClient:
             delete_response = self.delete(key, timeout=timeout)
             if delete_response.status_code != 200:
                 raise RuntimeError(
-                    f"failed to delete key {key}: {
-                        delete_response.status_code}"
+                    f"failed to delete key {key}: {delete_response.status_code}"
                 )
 
-    def send_view(self, view: dict[str, List[Dict[str, Any]]], timeout: float = DEFAULT_TIMEOUT) -> requests.Response:
+    def send_view(
+        self, view: dict[str, List[Dict[str, Any]]], timeout: float = DEFAULT_TIMEOUT
+    ) -> requests.Response:
         if not isinstance(view, dict):
             raise ValueError("view must be a dict")
 
diff --git a/utils/testcase.py b/utils/testcase.py
index 746b582767566b37b21b590b7ab02302da4aac97..a835939f18ef5a4ffd9bf866ffcdfe594be7c321 100644
--- a/utils/testcase.py
+++ b/utils/testcase.py
@@ -17,7 +17,7 @@ class TestCase:
         except Exception as e:
             self.score = False
             self.reason = f"FAIL: {e}"
-        
+
         return self.score, self.reason
 
     def __str__(self):
diff --git a/utils/util.py b/utils/util.py
index e3da95eb62a5c79409d1feda9d8d83115588a7f1..0fa05f5133386ca76cad7c160be47f7bd34abc30 100644
--- a/utils/util.py
+++ b/utils/util.py
@@ -4,6 +4,7 @@ import subprocess
 from typing import TextIO, Optional
 from collections.abc import Collection
 
+
 @dataclass
 class Logger:
     files: Collection[TextIO]
@@ -17,16 +18,24 @@ class Logger:
         for file in self.files:
             print(*prefix, *args, file=file)
 
+
 _GLOBAL_LOGGER = Logger(prefix=None, files=(sys.stderr,))
 
+
 def log(*args):
     _GLOBAL_LOGGER(*args)
 
+
 def global_logger() -> Logger:
     return _GLOBAL_LOGGER
 
+
 def run_cmd_bg(
-    cmd: list[str], log: Logger = _GLOBAL_LOGGER, verbose=False, error_prefix: str = "command failed", **kwargs
+    cmd: list[str],
+    log: Logger = _GLOBAL_LOGGER,
+    verbose=False,
+    error_prefix: str = "command failed",
+    **kwargs,
 ) -> subprocess.CompletedProcess:
     # default capture opts
     kwargs.setdefault("stdout", subprocess.PIPE)