Skip to content
Snippets Groups Projects
Commit 302ebc36 authored by zphrs's avatar zphrs
Browse files

Made a rebroadcast_view fn to allow for changing the last sent view to only update IP addresses

parent f7708e72
No related branches found
No related tags found
No related merge requests found
......@@ -11,9 +11,10 @@ class KVSTestFixture:
self.conductor = conductor
self.dir = dir
self.node_count = node_count
self.clients = []
self.clients: list[KVSClient] = []
self.log = log
def spawn_cluster(self):
self.log("\n> SPAWN CLUSTER")
self.conductor.spawn_cluster(node_count=self.node_count)
......@@ -36,6 +37,14 @@ class KVSTestFixture:
), f"expected 200 to ack view, got {r.status_code}"
self.log(f"view sent to node {i}: {r.status_code} {r.text}")
def rebroadcast_view(self, new_view: Dict[str, List[Dict[str, Any]]]):
for i, client in enumerate(self.clients):
r = client.resend_last_view_with_ips_from_new_view(new_view)
assert (
r.status_code == 200
), f"expected 200 to ack view, got {r.status_code}"
self.log(f"view resent to node {i}: {r.status_code} {r.text}")
def send_view(self, node_id: int, view: Dict[str, List[Dict[str, Any]]]):
r = self.clients[node_id].send_view(view)
assert r.status_code == 200, f"expected 200 to ack view, got {
......
......@@ -422,7 +422,7 @@ class ClusterConductor:
].base_url = self.node_external_endpoint(node.index)
view_changed = True
if view_changed and hasattr(self, "_parent"):
self._parent.broadcast_view(self.get_shard_view())
self._parent.rebroadcast_view(self.get_shard_view())
def create_partition(self, node_ids: List[int], partition_id: str) -> None:
net_name = f"kvs_{self.group_id}_net_{partition_id}"
......@@ -482,7 +482,7 @@ class ClusterConductor:
].base_url = self.node_external_endpoint(node.index)
view_changed = True
if view_changed and hasattr(self, "_parent"):
self._parent.broadcast_view(self.get_shard_view())
self._parent.rebroadcast_view(self.get_shard_view())
DeprecationWarning("View is in updated format")
......
......@@ -108,5 +108,24 @@ class KVSClient:
if not isinstance(view, dict):
raise ValueError("view must be a dict")
self.last_view = view
request_body = {"view": view}
return requests.put(f"{self.base_url}/view", json=request_body, timeout=timeout)
def resend_last_view_with_ips_from_new_view(self, current_view: dict[str, List[Dict[str, Any]]], timeout: float = DEFAULT_TIMEOUT) -> requests.Response:
if not isinstance(current_view, dict):
raise ValueError("view must be a dict")
if not hasattr(self, "last_view"):
raise LookupError("Must have sent at least one view before calling resend.")
flattened_current_view = {}
for shard_key in current_view:
for node in current_view[shard_key]:
flattened_current_view[node["id"]] = node["address"]
for shard_key in self.last_view:
for node in self.last_view[shard_key]:
node['address'] = flattened_current_view[node["id"]]
request_body = {"view": self.last_view}
return requests.put(f"{self.base_url}/view", json=request_body, timeout=timeout)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment