From 93fddce21c6e63b05f0cc75c83fbc58a85380397 Mon Sep 17 00:00:00 2001
From: Christian Knab <christiantknab@gmail.com>
Date: Mon, 10 Mar 2025 23:09:48 -0700
Subject: [PATCH] Add shard-based view management to ClusterConductor

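Replace the hand-built per-node view in the basic test with shard
management on ClusterConductor. The conductor now tracks a shards
mapping and exposes add_shard, add_node_to_shard, get_shard_view, and a
range-based get_nodes(start, end), so a test describes the topology once
and broadcasts it, as in the updated basic_kv_1:

    conductor.add_shard("shard1", conductor.get_nodes(0, 2))
    conductor.add_shard("shard2", conductor.get_nodes(2, 4))
    fx.broadcast_view(conductor.get_shard_view())

Also reformat utils/containers.py (call-argument wrapping and long
f-strings) in the code touched by this change.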
---
 tests/basic/basic.py |  14 ++---
 utils/containers.py  | 138 +++++++++++++++++++++++++------------------
 2 files changed, 85 insertions(+), 67 deletions(-)

diff --git a/tests/basic/basic.py b/tests/basic/basic.py
index 4bae246..84d607a 100644
--- a/tests/basic/basic.py
+++ b/tests/basic/basic.py
@@ -1,16 +1,16 @@
-from time import sleep
-
 from ...utils.containers import ClusterConductor
 from ...utils.testcase import TestCase
 from ...utils.util import Logger
 from ..helper import KVSMultiClient, KVSTestFixture
 from ...utils.kvs_api import DEFAULT_TIMEOUT
 
+
 def basic_kv_1(conductor: ClusterConductor, dir, log: Logger):
-    with KVSTestFixture(conductor, dir, log, node_count=2) as fx:
+    with KVSTestFixture(conductor, dir, log, node_count=4) as fx:
         c = KVSMultiClient(fx.clients, "client", log)
-        view = {'shard1': [conductor.get_node(0).get_view()], 'shard2': [conductor.get_node(1).get_view()]}
-        fx.broadcast_view(view)
+        conductor.add_shard("shard1", conductor.get_nodes(0, 2))
+        conductor.add_shard("shard2", conductor.get_nodes(2, 4))
+        fx.broadcast_view(conductor.get_shard_view())
 
         r = c.put(0, "x", "1")
         assert r.ok, f"expected ok for new key, got {r.status_code}"
@@ -37,6 +37,4 @@ def basic_kv_1(conductor: ClusterConductor, dir, log: Logger):
         return True, 0
 
 
-BASIC_TESTS = [
-        TestCase("basic_kv_1", basic_kv_1)
-]
+BASIC_TESTS = [TestCase("basic_kv_1", basic_kv_1)]
diff --git a/utils/containers.py b/utils/containers.py
index 3e2e362..9d8a960 100644
--- a/utils/containers.py
+++ b/utils/containers.py
@@ -22,16 +22,13 @@ class ContainerBuilder:
         # ensure we are able to build the container image
         log(f"building container image {self.image_id}...")
 
-        cmd = [CONTAINER_ENGINE, "build", "-t",
-               self.image_id, self.project_dir]
-        run_cmd_bg(cmd, verbose=True,
-                   error_prefix="failed to build container image")
+        cmd = [CONTAINER_ENGINE, "build", "-t", self.image_id, self.project_dir]
+        run_cmd_bg(cmd, verbose=True, error_prefix="failed to build container image")
 
         # ensure the image exists
         log(f"inspecting container image {self.image_id}...")
         cmd = [CONTAINER_ENGINE, "image", "inspect", self.image_id]
-        run_cmd_bg(cmd, verbose=True,
-                   error_prefix="failed to inspect container image")
+        run_cmd_bg(cmd, verbose=True, error_prefix="failed to inspect container image")
 
 
 @dataclass
@@ -56,11 +53,18 @@ class ClusterNode:
 
 
 class ClusterConductor:
-    def __init__(self, group_id: str, base_image: str, log: Logger, external_port_base: int = 8081):
+    def __init__(
+        self,
+        group_id: str,
+        base_image: str,
+        log: Logger,
+        external_port_base: int = 8081,
+    ):
         self.group_id = group_id
         self.base_image = base_image
         self.base_port = external_port_base
         self.nodes: List[ClusterNode] = []
+        self.shards: dict[str, List[ClusterNode]] = {}
 
         # naming patterns
         self.group_ctr_prefix = f"kvs_{group_id}_node"
@@ -115,22 +119,20 @@ class ClusterConductor:
         log_file = os.path.join(dir, f"{name}.log")
         self.log(f"Dumping logs for container {name} to file {log_file}")
 
-    # Construct the logs command. Docker and Podman both support the "logs" command.
+        # Construct the logs command. Docker and Podman both support the "logs" command.
         logs_cmd = [CONTAINER_ENGINE, "logs", name]
 
         try:
-            logs_output = subprocess.check_output(
-                logs_cmd, stderr=subprocess.STDOUT)
+            logs_output = subprocess.check_output(logs_cmd, stderr=subprocess.STDOUT)
             with open(log_file, "wb") as f:
                 f.write(logs_output)
-            self.log(f"Successfully wrote logs for container {
-                     name} to {log_file}")
+            self.log(f"Successfully wrote logs for container {name} to {log_file}")
         except subprocess.CalledProcessError as e:
-            self.log(f"Error dumping logs for container {
-                     name}: {e.output.decode().strip()}")
-        except Exception as e:
             self.log(
-                f"Unexpected error dumping logs for container {name}: {e}")
+                f"Error dumping logs for container {name}: {e.output.decode().strip()}"
+            )
+        except Exception as e:
+            self.log(f"Unexpected error dumping logs for container {name}: {e}")
 
     def _remove_container(self, name: str) -> None:
         # remove a single container
@@ -180,8 +182,7 @@ class ClusterConductor:
         network_regex = re.compile(network_pattern)
 
         # cleanup containers
-        self.log(f"  cleaning up {
-                 'group' if group_only else 'all'} containers")
+        self.log(f"  cleaning up {'group' if group_only else 'all'} containers")
         containers = self._list_containers()
         for container in containers:
             if container and container_regex.match(container):
@@ -214,8 +215,11 @@ class ClusterConductor:
         self.log(f"spawning cluster of {node_count} nodes")
 
         # delete base network if it exists
-        run_cmd_bg([CONTAINER_ENGINE, "network", "rm",
-                   self.base_net_name], check=False, log=self.log)
+        run_cmd_bg(
+            [CONTAINER_ENGINE, "network", "rm", self.base_net_name],
+            check=False,
+            log=self.log,
+        )
 
         # create base network
         run_cmd_bg(
@@ -232,8 +236,7 @@ class ClusterConductor:
             external_port = self.base_port + i
             port = 8081  # internal port
 
-            self.log(f"  starting container {
-                     node_name} (ext_port={external_port})")
+            self.log(f"  starting container {node_name} (ext_port={external_port})")
 
             # start container detached from networks
             run_cmd_bg(
@@ -257,11 +260,9 @@ class ClusterConductor:
             # attach container to base network
             self.log(f"    attaching container {node_name} to base network")
             run_cmd_bg(
-                [CONTAINER_ENGINE, "network", "connect",
-                    self.base_net_name, node_name],
+                [CONTAINER_ENGINE, "network", "connect", self.base_net_name, node_name],
                 verbose=True,
-                error_prefix=f"failed to attach container {
-                    node_name} to base network",
+                error_prefix=f"failed to attach container {node_name} to base network",
                 log=self.log,
             )
 
@@ -295,8 +296,7 @@ class ClusterConductor:
             )
             self.nodes.append(node)
 
-            self.log(f"    container {
-                     node_name} spawned, base_net_ip={container_ip}")
+            self.log(f"    container {node_name} spawned, base_net_ip={container_ip}")
 
         # wait for the nodes to come online (sequentially)
         self.log("waiting for nodes to come online...")
@@ -325,8 +325,9 @@ class ClusterConductor:
         self.log("nodes:")
         for node in self.nodes:
             self.log(
-                f"  {node.name}: {node.ip}:{
-                    node.port} <-> localhost:{node.external_port}"
+                f"  {node.name}: {node.ip}:{node.port} <-> localhost:{
+                    node.external_port
+                }"
             )
 
         # now log the partitions and the nodes they contain
@@ -339,11 +340,10 @@ class ClusterConductor:
 
         self.log("partitions:")
         for net, nodes in partitions.items():
-            part_name = net[len(self.group_net_prefix) + 1:]
+            part_name = net[len(self.group_net_prefix) + 1 :]
             self.log(f"  {part_name}: {nodes}")
 
     def my_partition(self, node_ids: List[int], partition_id: str) -> None:
-
         net_name = f"kvs_{self.group_id}_net_{partition_id}"
 
         self.log(f"creating partition {partition_id} with nodes {node_ids}")
@@ -359,14 +359,13 @@ class ClusterConductor:
             node = self.nodes[i]
             for network in node.networks:
                 if network != net_name:
-                    self.log(f"    disconnecting {
-                             node.name} from network {network}")
+                    self.log(f"    disconnecting {node.name} from network {network}")
                     run_cmd_bg(
-                        [CONTAINER_ENGINE, "network",
-                            "disconnect", network, node.name],
+                        [CONTAINER_ENGINE, "network", "disconnect", network, node.name],
                         verbose=True,
-                        error_prefix=f"failed to disconnect {
-                            node.name} from network {network}",
+                        error_prefix=f"failed to disconnect {node.name} from network {
+                            network
+                        }",
                     )
                     node.networks.remove(network)
 
@@ -384,8 +383,7 @@ class ClusterConductor:
             run_cmd_bg(
                 [CONTAINER_ENGINE, "network", "connect", net_name, node.name],
                 verbose=True,
-                error_prefix=f"failed to connect {
-                    node.name} to network {net_name}",
+                error_prefix=f"failed to connect {node.name} to network {net_name}",
             )
             node.networks.append(net_name)
 
@@ -398,8 +396,7 @@ class ClusterConductor:
             )
             info = json.loads(inspect.stdout)[0]
             container_ip = info["NetworkSettings"]["Networks"][net_name]["IPAddress"]
-            self.log(f"    node {node.name} ip in network {
-                     net_name}: {container_ip}")
+            self.log(f"    node {node.name} ip in network {net_name}: {container_ip}")
 
             # update node ip
             node.ip = container_ip
@@ -419,14 +416,13 @@ class ClusterConductor:
             node = self.nodes[i]
             for network in node.networks:
                 if network != net_name:
-                    self.log(f"    disconnecting {
-                             node.name} from network {network}")
+                    self.log(f"    disconnecting {node.name} from network {network}")
                     run_cmd_bg(
-                        [CONTAINER_ENGINE, "network",
-                            "disconnect", network, node.name],
+                        [CONTAINER_ENGINE, "network", "disconnect", network, node.name],
                         verbose=True,
-                        error_prefix=f"failed to disconnect {
-                            node.name} from network {network}",
+                        error_prefix=f"failed to disconnect {node.name} from network {
+                            network
+                        }",
                     )
                     node.networks.remove(network)
 
@@ -438,8 +434,7 @@ class ClusterConductor:
             run_cmd_bg(
                 [CONTAINER_ENGINE, "network", "connect", net_name, node.name],
                 verbose=True,
-                error_prefix=f"failed to connect {
-                    node.name} to network {net_name}",
+                error_prefix=f"failed to connect {node.name} to network {net_name}",
             )
             node.networks.append(net_name)
 
@@ -452,33 +447,58 @@ class ClusterConductor:
             )
             info = json.loads(inspect.stdout)[0]
             container_ip = info["NetworkSettings"]["Networks"][net_name]["IPAddress"]
-            self.log(f"    node {node.name} ip in network {
-                     net_name}: {container_ip}")
+            self.log(f"    node {node.name} ip in network {net_name}: {container_ip}")
 
             # update node ip
             node.ip = container_ip
 
     DeprecationWarning("View is in updated format")
+
     def get_full_view(self):
         view = []
         for node in self.nodes:
-            view.append({"address": f"{node.ip}:{
-                        node.port}", "id": node.index})
+            view.append({"address": f"{node.ip}:{node.port}", "id": node.index})
         return view
 
-    def get_nodes(self):
-        return self.nodes
-    
     def get_node(self, index):
         return self.nodes[index]
 
+    def get_nodes(self, start=0, end=None):
+        if start < 0:
+            raise ValueError("Start index cannot be negative.")
+        if end is not None:
+            if end < 0 or end > len(self.nodes):
+                raise ValueError(f"End index must be between 0 and {len(self.nodes)}.")
+            if start > end:
+                raise ValueError("Start index cannot be greater than end index.")
+
+        if end is None:
+            return self.nodes[start:]
+        return self.nodes[start:end]
+
+    def add_shard(self, shard_name: str, nodes: List[ClusterNode]):
+        # register the shard under its name; passing an existing name
+        # replaces that shard's node list rather than extending it
+        # (use add_node_to_shard to grow an existing shard instead)
+        self.shards[shard_name] = nodes
+
+    def add_node_to_shard(self, shard_name: str, node: ClusterNode):
+        if shard_name not in self.shards:
+            self.shards[shard_name] = []
+        self.shards[shard_name].append(node)
+
+    def get_shard_view(self) -> dict:
+        return {
+            shard: [node.get_view() for node in nodes]
+            for shard, nodes in self.shards.items()
+        }
+
     def get_partition_view(self, partition_id: str):
         net_name = f"kvs_{self.group_id}_net_{partition_id}"
         view = []
         for node in self.nodes:
             if net_name in node.networks:
-                view.append({"address": f"{node.ip}:{ \
-                            node.port}", "id": node.index})
+                view.append({"address": f"{node.ip}:{node.port}", "id": node.index})
         return view
 
     def __enter__(self):
-- 
GitLab