diff --git a/tests/basic/basic.py b/tests/basic/basic.py index 52cb706c89cf822f2fd02166f6a2ef107f03bd24..c407b959003b5834a00d8191e88223697492b50a 100644 --- a/tests/basic/basic.py +++ b/tests/basic/basic.py @@ -1,5 +1,3 @@ -from time import sleep - from ...utils.containers import ClusterConductor from ...utils.testcase import TestCase from ...utils.util import Logger @@ -8,13 +6,23 @@ from ...utils.kvs_api import DEFAULT_TIMEOUT def basic_kv_1(conductor: ClusterConductor, dir, log: Logger): - with KVSTestFixture(conductor, dir, log, node_count=2) as fx: + with KVSTestFixture(conductor, dir, log, node_count=4) as fx: c = KVSMultiClient(fx.clients, "client", log) - view = { - "shard1": [conductor.get_node(0).get_view()], - "shard2": [conductor.get_node(1).get_view()], - } - fx.broadcast_view(view) + conductor.add_shard("shard1", conductor.get_nodes(0, 2)) + fx.broadcast_view(conductor.get_shard_view()) + + # examples of modifying shards + # conductor.add_shard("shard2", conductor.get_nodes(2, 3)) + # fx.broadcast_view(conductor.get_shard_view()) + # + # conductor.add_node_to_shard("shard2", conductor.get_node(3)) + # fx.broadcast_view(conductor.get_shard_view()) + # + # conductor.remove_node_from_shard("shard1", conductor.get_node(0)) + # fx.broadcast_view(conductor.get_shard_view()) + # + # conductor.remove_shard("shard1") + # fx.broadcast_view(conductor.get_shard_view()) r = c.put(0, "x", "1") assert r.ok, f"expected ok for new key, got {r.status_code}" diff --git a/utils/containers.py b/utils/containers.py index d598d445c484d2bff27b20b183f9238f18dee722..a1a26d05c7b29607f61b50c663594df718532753 100644 --- a/utils/containers.py +++ b/utils/containers.py @@ -65,6 +65,7 @@ class ClusterConductor: self.base_image = base_image self.base_port = external_port_base self.nodes: List[ClusterNode] = [] + self.shards: dict[str, List[ClusterNode]] = {} # naming patterns self.group_ctr_prefix = f"kvs_{group_id}_node" @@ -116,13 +117,6 @@ class ClusterConductor: self._dump_container_logs(dir, container) def get_view(self) -> str: return {"address": f"{self.ip}:{self.port}", "id": self.index} - - def get_nodes(self): - return self.nodes - - def get_node(self, index): - return self.nodes[index] - def _dump_container_logs(self, dir, name: str) -> None: log_file = os.path.join(dir, f"{name}.log") self.log(f"Dumping logs for container {name} to file {log_file}") @@ -373,8 +367,7 @@ class ClusterConductor: part_name = net[len(self.group_net_prefix) + 1 :] self.log(f" {part_name}: {nodes}") - # returns true if a view change is needed - def my_partition(self, node_ids: List[int], partition_id: str): + def my_partition(self, node_ids: List[int], partition_id: str) -> None: net_name = f"kvs_{self.group_id}_net_{partition_id}" self.log(f"creating partition {partition_id} with nodes {node_ids}") @@ -519,6 +512,47 @@ class ClusterConductor: view.append({"address": f"{node.ip}:{node.port}", "id": node.index}) return view + def get_node(self, index): + return self.nodes[index] + + # from `start`` to `end`. Note that `end` is not inclusive + def get_nodes(self, start=0, end=None): + if start < 0: + raise ValueError("Start index cannot be negative.") + if end is not None: + if end < 0 or end > len(self.nodes): + raise ValueError(f"End index must be between 0 and {len(self.nodes)}.") + if start > end: + raise ValueError("Start index cannot be greater than end index.") + + if end is None: + return self.nodes[start:] + return self.nodes[start:end] + + def add_shard(self, shard_name: str, nodes: List[ClusterNode]): + if shard_name not in self.shards: + self.shards[shard_name] = [] + self.shards[shard_name] = nodes + + def remove_shard(self, shard_name: str): + del self.shards[shard_name] + + def add_node_to_shard(self, shard_name: str, node: ClusterNode): + if shard_name not in self.shards: + self.shards[shard_name] = [] + self.shards[shard_name].append(node) + + def remove_node_from_shard(self, shard_name: str, node: ClusterNode): + if shard_name not in self.shards: + return + self.shards[shard_name].remove(node) + + def get_shard_view(self) -> dict: + return { + shard: [node.get_view() for node in nodes] + for shard, nodes in self.shards.items() + } + def get_partition_view(self, partition_id: str): net_name = f"kvs_{self.group_id}_net_{partition_id}" view = []