diff --git a/__main__.py b/__main__.py
index 40d59b944498dd44f873c58e5c61ecbaa36c9784..b26e498a373d0b61fe8f97013cc05e29838b91fe 100755
--- a/__main__.py
+++ b/__main__.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3

-import os
+import os
 import sys
 import json
 import subprocess
@@ -33,7 +33,7 @@ class TestRunner:
             group_id=TEST_GROUP_ID,
             base_image=CONTAINER_IMAGE_ID,
             external_port_base=9000,
-            log=global_logger()
+            log=global_logger(),
         )

     def prepare_environment(self, build: bool = True) -> None:
@@ -96,6 +96,7 @@ TEST_SET.extend(SHUFFLE_TESTS)
 # set to True to stop at the first failing test
 FAIL_FAST = True

+
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -109,8 +110,12 @@ def main():
         action="store_true",
         help="run all tests instead of stopping at first failure. Note: this is currently broken due to issues with cleanup code",
     )
-    parser.add_argument("--num-threads", type=int, default=1, help="number of threads to run tests in")
-    parser.add_argument("--port-offset", type=int, default=1000, help="port offset for each test")
+    parser.add_argument(
+        "--num-threads", type=int, default=1, help="number of threads to run tests in"
+    )
+    parser.add_argument(
+        "--port-offset", type=int, default=1000, help="port offset for each test"
+    )
     parser.add_argument("filter", nargs="?", help="filter tests by name")

     args = parser.parse_args()
@@ -130,9 +135,10 @@ def main():
     log("\n== RUNNING TESTS ==")
     run_tests = []

+
     def run_test(test: TestCase, gid: str, port_offset: int):
         log(f"\n== TEST: [{test.name}] ==\n")
-        test_set_name = test.name.lower().split('_')[0]
+        test_set_name = test.name.lower().split("_")[0]
         test_dir = create_test_dir(DEBUG_OUTPUT_DIR, test_set_name, test.name)

         log_file_path = os.path.join(test_dir, f"{test.name}.log")
@@ -144,7 +150,7 @@ def main():
             group_id=gid,
             base_image=CONTAINER_IMAGE_ID,
             external_port_base=9000 + port_offset,
-            log=logger
+            log=logger,
         )

         score, reason = test.execute(conductor, test_dir, log=logger)
@@ -167,7 +173,12 @@ def main():
     else:
         print("Running tests in a threadpool ({args.num_threads} threads)")
         pool = ThreadPool(processes=args.num_threads)
-        pool.map(lambda a: run_test(a[1], gid=f"{a[0]}", port_offset=a[0]*args.port_offset), enumerate(TEST_SET))
+        pool.map(
+            lambda a: run_test(
+                a[1], gid=f"{a[0]}", port_offset=a[0] * args.port_offset
+            ),
+            enumerate(TEST_SET),
+        )

     summary_log = os.path.join(DEBUG_OUTPUT_DIR, "summary.log")
     with open(summary_log, "w") as log_file:
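
Note: the threadpool path above isolates concurrent tests by giving test i its
own container group id and a private slice of the host port space
(9000 + i * --port-offset). A minimal self-contained sketch of that fan-out,
with run_test stubbed out (the real one builds a ClusterConductor and calls
test.execute):

    from multiprocessing.pool import ThreadPool

    TEST_SET = ["basic_put_get", "shuffle_reads"]  # stand-in test names

    def run_test(test: str, gid: str, port_offset: int) -> None:
        # each worker gets a disjoint port window, so clusters cannot collide
        print(f"test={test} gid={gid} ports start at {9000 + port_offset}")

    pool = ThreadPool(processes=2)
    pool.map(
        lambda a: run_test(a[1], gid=f"{a[0]}", port_offset=a[0] * 1000),
        enumerate(TEST_SET),
    )
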
diff --git a/tests/helper.py b/tests/helper.py
index a1b77c2ec82225b5a200f1e8b7fb0d497ac6fda5..6840234df42772e22722dfca42dc41b1c5b5d833 100644
--- a/tests/helper.py
+++ b/tests/helper.py
@@ -5,7 +5,8 @@
 from ..utils.containers import ClusterConductor
 from ..utils.kvs_api import KVSClient
 from ..utils.util import Logger
-class KVSTestFixture:
+
+class KVSTestFixture:
     def __init__(self, conductor: ClusterConductor, dir, log: Logger, node_count: int):
         conductor._parent = self
         self.conductor = conductor
@@ -24,17 +25,16 @@ class KVSTestFixture:
             self.clients.append(KVSClient(ep))

             r = self.clients[i].ping()
-            assert r.status_code == 200, f"expected 200 for ping, got {
-                r.status_code}"
+            assert r.status_code == 200, f"expected 200 for ping, got {r.status_code}"
             self.log(f"  - node {i} is up: {r.text}")

-    def broadcast_view(self, view: Dict[str, List[Dict[str, Any]]]):
+    def broadcast_view(self, view: List[Dict[str, Any]]):
         self.log(f"\n> SEND VIEW: {view}")
         for i, client in enumerate(self.clients):
             r = client.send_view(view)
-            assert (
-                r.status_code == 200
-            ), f"expected 200 to ack view, got {r.status_code}"
+            assert r.status_code == 200, (
+                f"expected 200 to ack view, got {r.status_code}"
+            )
             self.log(f"view sent to node {i}: {r.status_code} {r.text}")

     def rebroadcast_view(self, new_view: Dict[str, List[Dict[str, Any]]]):
@@ -47,8 +47,7 @@ class KVSTestFixture:

     def send_view(self, node_id: int, view: Dict[str, List[Dict[str, Any]]]):
         r = self.clients[node_id].send_view(view)
-        assert r.status_code == 200, f"expected 200 to ack view, got {
-            r.status_code}"
+        assert r.status_code == 200, f"expected 200 to ack view, got {r.status_code}"
         self.log(f"view sent to node {node_id}: {r.status_code} {r.text}")

     def destroy_cluster(self):
@@ -80,11 +79,11 @@ class KVSMultiClient:
         self.metadata = None

     def put(self, node_id: int, key: str, value: str, timeout: float = DEFAULT_TIMEOUT):
-        self.log(f"  {self.name} req_id:{self.req} > {
-                 node_id} > kvs.put {key} <- {value}")
+        self.log(
+            f"  {self.name} req_id:{self.req} > {node_id} > kvs.put {key} <- {value}"
+        )

-        r = self.clients[node_id].put(
-            key, value, self.metadata, timeout=timeout)
+        r = self.clients[node_id].put(key, value, self.metadata, timeout=timeout)

         # update model if successful
         if r.status_code // 100 == 2:
@@ -96,32 +95,48 @@ class KVSMultiClient:
         return r

     def get(self, node_id: int, key: str, timeout: float = DEFAULT_TIMEOUT):
-        self.log(f"  {self.name} req_id:{self.req} > {node_id}> kvs.get {
-            key} request \"causal-metadata\": {self.metadata}")
+        self.log(
+            f'  {self.name} req_id:{self.req} > {node_id}> kvs.get {
+                key
+            } request "causal-metadata": {self.metadata}'
+        )
         r = self.clients[node_id].get(key, self.metadata, timeout=timeout)

         if r.status_code // 100 == 2:
-            self.log(f"  {self.name} req_id:{self.req} > {
-                     node_id}> kvs.get {key} -> {r.json()}")
+            self.log(
+                f"  {self.name} req_id:{self.req} > {node_id}> kvs.get {key} -> {
+                    r.json()
+                }"
+            )
             self.metadata = r.json()["causal-metadata"]
         else:
-            self.log(f"  {self.name} req_id:{self.req} > {
-                     node_id}> kvs.get {key} -> HTTP ERROR {r.status_code}")
+            self.log(
+                f"  {self.name} req_id:{self.req} > {node_id}> kvs.get {
+                    key
+                } -> HTTP ERROR {r.status_code}"
+            )

         self.req += 1
         return r

     def get_all(self, node_id: int, timeout: float = DEFAULT_TIMEOUT):
-        self.log(f"  {self.name} req_id:{self.req} > {
-                 node_id}> kvs.get_all request \"causal-metadata\": {self.metadata}")
+        self.log(
+            f'  {self.name} req_id:{self.req} > {
+                node_id
+            }> kvs.get_all request "causal-metadata": {self.metadata}'
+        )

         r = self.clients[node_id].get_all(self.metadata, timeout=timeout)
         if r.status_code // 100 == 2:
-            self.log(f"  {self.name} req_id:{self.req} > {
-                     node_id}> kvs.get_all -> {r.json()}")
+            self.log(
+                f"  {self.name} req_id:{self.req} > {node_id}> kvs.get_all -> {r.json()}"
+            )
             self.metadata = r.json()["causal-metadata"]
         else:
-            self.log(f"  {self.name} req_id:{self.req} > {
-                     node_id}> kvs.get_all -> HTTP ERROR {r.status_code}")
+            self.log(
+                f"  {self.name} req_id:{self.req} > {
+                    node_id
+                }> kvs.get_all -> HTTP ERROR {r.status_code}"
+            )

         self.req += 1
         return r
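
Note: KVSMultiClient is what makes the causality checks work. Every 2xx
response's "causal-metadata" is cached on the client and replayed with the
next request, so one logical session stays causally consistent even as it
hops between nodes. A sketch of a session; the KVSMultiClient constructor is
not shown in this patch, so its arguments here are assumptions, and
conductor, test_dir, and log are taken from the surrounding harness:

    # import path assumed: from tests.helper import KVSTestFixture, KVSMultiClient
    fixture = KVSTestFixture(conductor, test_dir, log, node_count=2)
    fixture.broadcast_view(conductor.get_shard_view())  # payload shape assumed

    alice = KVSMultiClient(fixture.clients, "alice", log)  # args assumed
    alice.put(0, "x", "1")  # a 2xx reply updates alice.metadata
    r = alice.get(1, "x")   # metadata rides along, so the read is causally after the put
    assert r.status_code == 200
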
{"address": f"{self.ip}:{self.port}", "id": self.index} - def internal_endpoint(self) -> str: return f"http://{self.ip}:{self.port}" @@ -53,6 +52,8 @@ class ClusterNode: class ClusterConductor: + # _parent: KVSTestFixture + def __init__( self, group_id: str, @@ -114,7 +115,8 @@ class ClusterConductor: for container in containers: if container and container_regex.match(container): self._dump_container_logs(dir, container) - + def get_view(self) -> str: + return {"address": f"{self.ip}:{self.port}", "id": self.index} def _dump_container_logs(self, dir, name: str) -> None: log_file = os.path.join(dir, f"{name}.log") self.log(f"Dumping logs for container {name} to file {log_file}") @@ -153,7 +155,6 @@ class ClusterConductor: error_prefix=f"failed to remove containers {names}", ) - def _remove_network(self, name: str) -> None: # remove a single network self.log(f"removing network {name}") @@ -201,6 +202,7 @@ class ClusterConductor: if container and container_regex.match(container) ] self._remove_containers(containers_to_remove) + # cleanup networks self.log(f" cleaning up {'group' if group_only else 'all'} networks") networks = self._list_networks() @@ -273,7 +275,13 @@ class ClusterConductor: # attach container to base network self.log(f" attaching container {node_name} to base network") run_cmd_bg( - [CONTAINER_ENGINE, "network", "connect", self.base_net_name, node_name], + [ + CONTAINER_ENGINE, + "network", + "connect", + self.base_net_name, + node_name, + ], verbose=True, error_prefix=f"failed to attach container {node_name} to base network", log=self.log, @@ -360,10 +368,8 @@ class ClusterConductor: net_name = f"kvs_{self.group_id}_net_{partition_id}" self.log(f"creating partition {partition_id} with nodes {node_ids}") - # create partition network if it doesn't exist if not self._network_exists(net_name): - self.log(f"creating network {net_name}") self._create_network(net_name) # disconnect specified nodes from all other networks @@ -395,7 +401,13 @@ class ClusterConductor: self.log(f" connecting {node.name} to network {net_name}") run_cmd_bg( - [CONTAINER_ENGINE, "network", "connect", net_name, node.name], + [ + CONTAINER_ENGINE, + "network", + "connect", + net_name, + node.name, + ], verbose=True, error_prefix=f"failed to connect {node.name} to network {net_name}", ) @@ -423,7 +435,6 @@ class ClusterConductor: view_changed = True if view_changed and hasattr(self, "_parent"): self._parent.rebroadcast_view(self.get_shard_view()) - def create_partition(self, node_ids: List[int], partition_id: str) -> None: net_name = f"kvs_{self.group_id}_net_{partition_id}" @@ -456,7 +467,13 @@ class ClusterConductor: node = self.nodes[i] self.log(f" connecting {node.name} to network {net_name}") run_cmd_bg( - [CONTAINER_ENGINE, "network", "connect", net_name, node.name], + [ + CONTAINER_ENGINE, + "network", + "connect", + net_name, + node.name, + ], verbose=True, error_prefix=f"failed to connect {node.name} to network {net_name}", ) diff --git a/utils/kvs_api.py b/utils/kvs_api.py index 0baf91d84fbbf45f007725c8035939873ce6f2b8..838662c4a9f1a9f9233e0e5b7d519b759fce364b 100644 --- a/utils/kvs_api.py +++ b/utils/kvs_api.py @@ -18,14 +18,14 @@ def create_json(metadata, value=None): result["value"] = value return result + # client for kvs api class KVSClient: def __init__(self, base_url: str): # set base url without trailing slash - self.base_url = base_url if not base_url.endswith( - "/") else base_url[:-1] + self.base_url = base_url if not base_url.endswith("/") else base_url[:-1] def ping(self, timeout: 
diff --git a/utils/kvs_api.py b/utils/kvs_api.py
index 0baf91d84fbbf45f007725c8035939873ce6f2b8..838662c4a9f1a9f9233e0e5b7d519b759fce364b 100644
--- a/utils/kvs_api.py
+++ b/utils/kvs_api.py
@@ -18,14 +18,14 @@ def create_json(metadata, value=None):
         result["value"] = value
     return result

+
 # client for kvs api
 class KVSClient:
     def __init__(self, base_url: str):
         # set base url without trailing slash
-        self.base_url = base_url if not base_url.endswith(
-            "/") else base_url[:-1]
+        self.base_url = base_url if not base_url.endswith("/") else base_url[:-1]

     def ping(self, timeout: float = DEFAULT_TIMEOUT) -> requests.Response:
         if timeout is not None:
@@ -38,33 +38,49 @@ class KVSClient:
         else:
             return requests.get(f"{self.base_url}/ping")

-    def get(self, key: str, metadata: str, timeout: float = DEFAULT_TIMEOUT) -> requests.Response:
+    def get(
+        self, key: str, metadata: str, timeout: float = DEFAULT_TIMEOUT
+    ) -> requests.Response:
         if not key:
             raise ValueError("key cannot be empty")

         if timeout is not None:
             try:
-                return requests.get(f"{self.base_url}/data/{key}", json=create_json(metadata), timeout=timeout)
+                return requests.get(
+                    f"{self.base_url}/data/{key}",
+                    json=create_json(metadata),
+                    timeout=timeout,
+                )
             except requests.exceptions.Timeout:
                 r = requests.Response()
                 r.status_code = REQUEST_TIMEOUT_STATUS_CODE
                 return r
         else:
-            return requests.get(f"{self.base_url}/data/{key}", json=create_json(metadata))
+            return requests.get(
+                f"{self.base_url}/data/{key}", json=create_json(metadata)
+            )

-    def put(self, key: str, value: str, metadata: str, timeout: float = DEFAULT_TIMEOUT) -> requests.Response:
+    def put(
+        self, key: str, value: str, metadata: str, timeout: float = DEFAULT_TIMEOUT
+    ) -> requests.Response:
         if not key:
             raise ValueError("key cannot be empty")

         if timeout is not None:
             try:
-                return requests.put(f"{self.base_url}/data/{key}", json=create_json(metadata, value), timeout=timeout)
+                return requests.put(
+                    f"{self.base_url}/data/{key}",
+                    json=create_json(metadata, value),
+                    timeout=timeout,
+                )
             except requests.exceptions.Timeout:
                 r = requests.Response()
                 r.status_code = REQUEST_TIMEOUT_STATUS_CODE
                 return r
         else:
-            return requests.put(f"{self.base_url}/data/{key}", json=create_json(metadata, value))
+            return requests.put(
+                f"{self.base_url}/data/{key}", json=create_json(metadata, value)
+            )

     def delete(self, key: str, timeout: float = DEFAULT_TIMEOUT) -> requests.Response:
         if not key:
@@ -80,10 +96,14 @@ class KVSClient:
         else:
             return requests.delete(f"{self.base_url}/data/{key}")

-    def get_all(self, metadata: str, timeout: float = DEFAULT_TIMEOUT) -> requests.Response:
+    def get_all(
+        self, metadata: str, timeout: float = DEFAULT_TIMEOUT
+    ) -> requests.Response:
         if timeout is not None:
             try:
-                return requests.get(f"{self.base_url}/data", json=create_json(metadata), timeout=timeout)
+                return requests.get(
+                    f"{self.base_url}/data", json=create_json(metadata), timeout=timeout
+                )
             except requests.exceptions.Timeout:
                 r = requests.Response()
                 r.status_code = REQUEST_TIMEOUT_STATUS_CODE
@@ -100,11 +120,12 @@ class KVSClient:
             delete_response = self.delete(key, timeout=timeout)
             if delete_response.status_code != 200:
                 raise RuntimeError(
-                    f"failed to delete key {key}: {
-                        delete_response.status_code}"
+                    f"failed to delete key {key}: {delete_response.status_code}"
                 )

-    def send_view(self, view: dict[str, List[Dict[str, Any]]], timeout: float = DEFAULT_TIMEOUT) -> requests.Response:
+    def send_view(
+        self, view: dict[str, List[Dict[str, Any]]], timeout: float = DEFAULT_TIMEOUT
+    ) -> requests.Response:
         if not isinstance(view, dict):
             raise ValueError("view must be a dict")

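
Note: every KVSClient method shares one convention: a client-side timeout is
converted into a synthetic requests.Response carrying
REQUEST_TIMEOUT_STATUS_CODE, so test code can branch on status_code alone
instead of wrapping every call in try/except. The pattern in isolation (the
constant is defined elsewhere in this file; 408 is an assumed value):

    import requests

    REQUEST_TIMEOUT_STATUS_CODE = 408  # assumed; defined elsewhere in kvs_api.py

    def get_with_timeout(url: str, timeout: float) -> requests.Response:
        # mirror KVSClient: swallow the Timeout and hand back a synthetic
        # Response whose status_code marks the timeout
        try:
            return requests.get(url, timeout=timeout)
        except requests.exceptions.Timeout:
            r = requests.Response()
            r.status_code = REQUEST_TIMEOUT_STATUS_CODE
            return r
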
diff --git a/utils/testcase.py b/utils/testcase.py
index 746b582767566b37b21b590b7ab02302da4aac97..a835939f18ef5a4ffd9bf866ffcdfe594be7c321 100644
--- a/utils/testcase.py
+++ b/utils/testcase.py
@@ -17,7 +17,7 @@ class TestCase:
         except Exception as e:
             self.score = False
             self.reason = f"FAIL: {e}"
-
+
         return self.score, self.reason

     def __str__(self):
diff --git a/utils/util.py b/utils/util.py
index e3da95eb62a5c79409d1feda9d8d83115588a7f1..0fa05f5133386ca76cad7c160be47f7bd34abc30 100644
--- a/utils/util.py
+++ b/utils/util.py
@@ -4,6 +4,7 @@ import subprocess
 from typing import TextIO, Optional
 from collections.abc import Collection

+
 @dataclass
 class Logger:
     files: Collection[TextIO]
@@ -17,16 +18,24 @@ class Logger:
         for file in self.files:
             print(*prefix, *args, file=file)

+
 _GLOBAL_LOGGER = Logger(prefix=None, files=(sys.stderr,))

+
 def log(*args):
     _GLOBAL_LOGGER(*args)

+
 def global_logger() -> Logger:
     return _GLOBAL_LOGGER

+
 def run_cmd_bg(
-    cmd: list[str], log: Logger = _GLOBAL_LOGGER, verbose=False, error_prefix: str = "command failed", **kwargs
+    cmd: list[str],
+    log: Logger = _GLOBAL_LOGGER,
+    verbose=False,
+    error_prefix: str = "command failed",
+    **kwargs,
 ) -> subprocess.CompletedProcess:
     # default capture opts
     kwargs.setdefault("stdout", subprocess.PIPE)
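
Note: util.py ends up with a small logging fan-out: Logger is a callable
dataclass that prints to every file in files, run_cmd_bg routes subprocess
failures through it, and global_logger() hands out the module-level instance.
A usage sketch (the log file name is arbitrary, and the docker CLI is assumed
to be on PATH):

    import sys
    # import path assumed: from utils.util import Logger, run_cmd_bg

    # fan one log call out to stderr and a per-test log file
    with open("example.log", "w") as f:
        logger = Logger(prefix=None, files=(sys.stderr, f))
        logger("starting containers")

        run_cmd_bg(
            ["docker", "ps", "-a"],
            log=logger,
            verbose=True,
            error_prefix="failed to list containers",
        )
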