diff --git a/tests/basic/basic.py b/tests/basic/basic.py index 84d607a1d5fb426caed6e1a5fe83e4f0b8673273..c407b959003b5834a00d8191e88223697492b50a 100644 --- a/tests/basic/basic.py +++ b/tests/basic/basic.py @@ -9,9 +9,21 @@ def basic_kv_1(conductor: ClusterConductor, dir, log: Logger): with KVSTestFixture(conductor, dir, log, node_count=4) as fx: c = KVSMultiClient(fx.clients, "client", log) conductor.add_shard("shard1", conductor.get_nodes(0, 2)) - conductor.add_shard("shard2", conductor.get_nodes(2, 4)) fx.broadcast_view(conductor.get_shard_view()) + # examples of modifying shards + # conductor.add_shard("shard2", conductor.get_nodes(2, 3)) + # fx.broadcast_view(conductor.get_shard_view()) + # + # conductor.add_node_to_shard("shard2", conductor.get_node(3)) + # fx.broadcast_view(conductor.get_shard_view()) + # + # conductor.remove_node_from_shard("shard1", conductor.get_node(0)) + # fx.broadcast_view(conductor.get_shard_view()) + # + # conductor.remove_shard("shard1") + # fx.broadcast_view(conductor.get_shard_view()) + r = c.put(0, "x", "1") assert r.ok, f"expected ok for new key, got {r.status_code}" diff --git a/utils/containers.py b/utils/containers.py index 9d8a960da80179adc03fb81dd981560bb1855166..95af9281a62cfc9f3e683c3aca579c8d421d9878 100644 --- a/utils/containers.py +++ b/utils/containers.py @@ -463,6 +463,7 @@ class ClusterConductor: def get_node(self, index): return self.nodes[index] + # from `start`` to `end`. Note that `end` is not inclusive def get_nodes(self, start=0, end=None): if start < 0: raise ValueError("Start index cannot be negative.") @@ -477,16 +478,23 @@ class ClusterConductor: return self.nodes[start:end] def add_shard(self, shard_name: str, nodes: List[ClusterNode]): - # add the nodes if shard_name not in self.shards: self.shards[shard_name] = [] self.shards[shard_name] = nodes + def remove_shard(self, shard_name: str): + del self.shards[shard_name] + def add_node_to_shard(self, shard_name: str, node: ClusterNode): if shard_name not in self.shards: self.shards[shard_name] = [] self.shards[shard_name].append(node) + def remove_node_from_shard(self, shard_name: str, node: ClusterNode): + if shard_name not in self.shards: + return + self.shards[shard_name].remove(node) + def get_shard_view(self) -> dict: return { shard: [node.get_view() for node in nodes]