Skip to content
Snippets Groups Projects
Commit 7f238543 authored by Thomas Dillow's avatar Thomas Dillow
Browse files

merge main

parents f3a50fcf d7a0e6a1
No related branches found
No related tags found
No related merge requests found
......@@ -50,4 +50,5 @@ def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger):
return True, "ok"
BENCHMARKS = [TestCase("benchmark_add_shard", benchmark_add_shard)]
\ No newline at end of file
BENCHMARKS = [TestCase("benchmark_add_shard", benchmark_add_shard)]
......@@ -7,6 +7,7 @@ from ..utils.util import Logger
import asyncio
class KVSTestFixture:
def __init__(self, conductor: ClusterConductor, dir, log: Logger, node_count: int):
conductor._parent = self
......@@ -42,9 +43,7 @@ class KVSTestFixture:
async def send_view(client: KVSClient, i: int):
r = await client.async_send_view(view)
assert (
r.status == 200
), f"expected 200 to ack view, got {r.status}"
assert r.status == 200, f"expected 200 to ack view, got {r.status}"
self.log(f"view sent to node {i}: {r.status} {r.text}")
tasks = [send_view(client, i) for i, client in enumerate(self.clients)]
......@@ -53,6 +52,8 @@ class KVSTestFixture:
def rebroadcast_view(self, new_view: Dict[str, List[Dict[str, Any]]]):
for i, client in enumerate(self.clients):
r = client.resend_last_view_with_ips_from_new_view(new_view)
if r is None:
return
assert r.status_code == 200, (
f"expected 200 to ack view, got {r.status_code}"
)
......
......@@ -102,7 +102,9 @@ def basic_shuffle_add_remove(conductor: ClusterConductor, dir, log: Logger):
res = r.json()["items"]
shard2_keys = res
assert len(shard1_keys) + len(shard2_keys) == 15, f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys)}"
assert len(shard1_keys) + len(shard2_keys) == 15, (
f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys)}"
)
# Remove shard 2. This loses keys.
conductor.remove_shard("shard2")
......@@ -230,7 +232,7 @@ def basic_shuffle_1(conductor: ClusterConductor, dir, log: Logger):
def basic_shuffle_2(conductor: ClusterConductor, dir, log: Logger):
with KVSTestFixture(conductor, dir, log, node_count=3) as fx:
c = KVSMultiClient(fx.clients, "client", log, persist_metadata=False)
c = KVSMultiClient(fx.clients, "client", log)
conductor.add_shard("shard1", conductor.get_nodes([0, 1]))
fx.broadcast_view(conductor.get_shard_view())
......@@ -249,32 +251,38 @@ def basic_shuffle_2(conductor: ClusterConductor, dir, log: Logger):
r = c.put(node_to_put, f"{base_key}{i}", f"{i}", timeout=10)
assert r.ok, f"expected ok for new key, got {r.status_code}"
node_to_put += 1
node_to_put = node_to_put % 3
node_to_put = node_to_put % 2
r = c.get_all(0, timeout=10) # should get all of shard 1's keys
r = c.get_all(0, timeout=30) # should get all of shard 1's keys
assert r.ok, f"expected ok for get, got {r.status_code}"
original_get_all = r.json()["items"]
assert len(original_get_all.keys()) == 300, (
f"original_get_all doesn't have 300 keys, instead has {len(original_get_all.keys())} keys"
)
conductor.add_shard("shard2", conductor.get_nodes([2]))
fx.broadcast_view(conductor.get_shard_view())
r = c.get_all(2, timeout=10) # should get all of shard 2's keys
r = c.get_all(2, timeout=30) # should get all of shard 2's keys
assert r.ok, f"expected ok for get, got {r.status_code}"
get_all_1 = r.json()["items"]
keys1 = get_all_1.keys()
keys1 = set(get_all_1.keys())
r = c.get_all(1, timeout=10) # should get all of shard 1's keys
r = c.get_all(1, timeout=30) # should get all of shard 1's keys
assert r.ok, f"expected ok for get, got {r.status_code}"
get_all_2 = r.json()["items"]
keys2 = get_all_2.keys()
keys2 = set(get_all_2.keys())
for key in keys1:
assert not (key in keys2)
assert key not in keys2, "key not in keys2"
for key in keys2:
assert not (key in keys1)
assert key not in keys1, "key not in keys2"
assert original_get_all.keys() == keys1 | keys2
assert len(original_get_all) == len(keys1) + len(keys2)
assert original_get_all.keys() == keys1.union(keys2), (
f"get all keys does not equal key1 joined with keys2. diff one way: \n{keys1.union(keys2).difference(original_get_all.keys())}\n diff other way:\n{set(original_get_all.keys()).difference(keys1.union(keys2))}"
)
assert len(original_get_all) == len(keys1) + len(keys2), "lengths do not match"
return True, "ok"
......
......@@ -152,6 +152,7 @@ class KVSClient:
if not isinstance(current_view, dict):
raise ValueError("view must be a dict")
if not hasattr(self, "last_view"):
return
raise LookupError("Must have sent at least one view before calling resend.")
flattened_current_view = {}
for shard_key in current_view:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment