From 302ebc36b858fa4666a0c23f7d126273862cc1cd Mon Sep 17 00:00:00 2001 From: zphrs <z@zephiris.dev> Date: Fri, 14 Mar 2025 00:41:19 -0700 Subject: [PATCH] Made a rebroadcast_view fn to allow for changing the last sent view to only update IP addresses --- tests/helper.py | 11 ++++++++++- utils/containers.py | 4 ++-- utils/kvs_api.py | 19 +++++++++++++++++++ 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/tests/helper.py b/tests/helper.py index e8057e4..a1b77c2 100644 --- a/tests/helper.py +++ b/tests/helper.py @@ -11,9 +11,10 @@ class KVSTestFixture: self.conductor = conductor self.dir = dir self.node_count = node_count - self.clients = [] + self.clients: list[KVSClient] = [] self.log = log + def spawn_cluster(self): self.log("\n> SPAWN CLUSTER") self.conductor.spawn_cluster(node_count=self.node_count) @@ -36,6 +37,14 @@ class KVSTestFixture: ), f"expected 200 to ack view, got {r.status_code}" self.log(f"view sent to node {i}: {r.status_code} {r.text}") + def rebroadcast_view(self, new_view: Dict[str, List[Dict[str, Any]]]): + for i, client in enumerate(self.clients): + r = client.resend_last_view_with_ips_from_new_view(new_view) + assert ( + r.status_code == 200 + ), f"expected 200 to ack view, got {r.status_code}" + self.log(f"view resent to node {i}: {r.status_code} {r.text}") + def send_view(self, node_id: int, view: Dict[str, List[Dict[str, Any]]]): r = self.clients[node_id].send_view(view) assert r.status_code == 200, f"expected 200 to ack view, got { diff --git a/utils/containers.py b/utils/containers.py index 54a3bb4..7041cd3 100644 --- a/utils/containers.py +++ b/utils/containers.py @@ -422,7 +422,7 @@ class ClusterConductor: ].base_url = self.node_external_endpoint(node.index) view_changed = True if view_changed and hasattr(self, "_parent"): - self._parent.broadcast_view(self.get_shard_view()) + self._parent.rebroadcast_view(self.get_shard_view()) def create_partition(self, node_ids: List[int], partition_id: str) -> None: net_name = f"kvs_{self.group_id}_net_{partition_id}" @@ -482,7 +482,7 @@ class ClusterConductor: ].base_url = self.node_external_endpoint(node.index) view_changed = True if view_changed and hasattr(self, "_parent"): - self._parent.broadcast_view(self.get_shard_view()) + self._parent.rebroadcast_view(self.get_shard_view()) DeprecationWarning("View is in updated format") diff --git a/utils/kvs_api.py b/utils/kvs_api.py index 6f6de12..0baf91d 100644 --- a/utils/kvs_api.py +++ b/utils/kvs_api.py @@ -108,5 +108,24 @@ class KVSClient: if not isinstance(view, dict): raise ValueError("view must be a dict") + self.last_view = view request_body = {"view": view} return requests.put(f"{self.base_url}/view", json=request_body, timeout=timeout) + + def resend_last_view_with_ips_from_new_view(self, current_view: dict[str, List[Dict[str, Any]]], timeout: float = DEFAULT_TIMEOUT) -> requests.Response: + if not isinstance(current_view, dict): + raise ValueError("view must be a dict") + if not hasattr(self, "last_view"): + raise LookupError("Must have sent at least one view before calling resend.") + flattened_current_view = {} + for shard_key in current_view: + for node in current_view[shard_key]: + flattened_current_view[node["id"]] = node["address"] + + for shard_key in self.last_view: + for node in self.last_view[shard_key]: + node['address'] = flattened_current_view[node["id"]] + + request_body = {"view": self.last_view} + return requests.put(f"{self.base_url}/view", json=request_body, timeout=timeout) + \ No newline at end of file -- GitLab