From 0d17d66985d49f72b368ed7af28824ddadc74ed5 Mon Sep 17 00:00:00 2001
From: Christian Knab <christiantknab@gmail.com>
Date: Tue, 11 Mar 2025 01:19:55 -0700
Subject: [PATCH] view change helper functions to modify shards

---
 tests/basic/basic.py | 14 +++++++++++++-
 utils/containers.py  | 10 +++++++++-
 2 files changed, 22 insertions(+), 2 deletions(-)

diff --git a/tests/basic/basic.py b/tests/basic/basic.py
index 84d607a..c407b95 100644
--- a/tests/basic/basic.py
+++ b/tests/basic/basic.py
@@ -9,9 +9,21 @@ def basic_kv_1(conductor: ClusterConductor, dir, log: Logger):
     with KVSTestFixture(conductor, dir, log, node_count=4) as fx:
         c = KVSMultiClient(fx.clients, "client", log)
         conductor.add_shard("shard1", conductor.get_nodes(0, 2))
-        conductor.add_shard("shard2", conductor.get_nodes(2, 4))
         fx.broadcast_view(conductor.get_shard_view())
 
+        # examples of modifying shards
+        # conductor.add_shard("shard2", conductor.get_nodes(2, 3))
+        # fx.broadcast_view(conductor.get_shard_view())
+        #
+        # conductor.add_node_to_shard("shard2", conductor.get_node(3))
+        # fx.broadcast_view(conductor.get_shard_view())
+        #
+        # conductor.remove_node_from_shard("shard1", conductor.get_node(0))
+        # fx.broadcast_view(conductor.get_shard_view())
+        #
+        # conductor.remove_shard("shard1")
+        # fx.broadcast_view(conductor.get_shard_view())
+
         r = c.put(0, "x", "1")
         assert r.ok, f"expected ok for new key, got {r.status_code}"
 
diff --git a/utils/containers.py b/utils/containers.py
index 9d8a960..95af928 100644
--- a/utils/containers.py
+++ b/utils/containers.py
@@ -463,6 +463,7 @@ class ClusterConductor:
     def get_node(self, index):
         return self.nodes[index]
 
+    # from `start`` to `end`. Note that `end` is not inclusive
     def get_nodes(self, start=0, end=None):
         if start < 0:
             raise ValueError("Start index cannot be negative.")
@@ -477,16 +478,23 @@ class ClusterConductor:
         return self.nodes[start:end]
 
     def add_shard(self, shard_name: str, nodes: List[ClusterNode]):
-        # add the nodes
         if shard_name not in self.shards:
             self.shards[shard_name] = []
         self.shards[shard_name] = nodes
 
+    def remove_shard(self, shard_name: str):
+        del self.shards[shard_name]
+
     def add_node_to_shard(self, shard_name: str, node: ClusterNode):
         if shard_name not in self.shards:
             self.shards[shard_name] = []
         self.shards[shard_name].append(node)
 
+    def remove_node_from_shard(self, shard_name: str, node: ClusterNode):
+        if shard_name not in self.shards:
+            return
+        self.shards[shard_name].remove(node)
+
     def get_shard_view(self) -> dict:
         return {
             shard: [node.get_view() for node in nodes]
-- 
GitLab