From d7a0e6a1e2504c6a40c640828d6be9e184b7a3c5 Mon Sep 17 00:00:00 2001 From: zphrs <z@zephiris.dev> Date: Fri, 14 Mar 2025 20:28:12 -0700 Subject: [PATCH] fix to basic_shuffle --- tests/bench/benchmark.py | 16 +++++++++------- tests/helper.py | 7 ++++--- tests/shuffle/basic_shuffle.py | 32 ++++++++++++++++++++------------ utils/kvs_api.py | 1 + 4 files changed, 34 insertions(+), 22 deletions(-) diff --git a/tests/bench/benchmark.py b/tests/bench/benchmark.py index b5f664e..3ec2f1e 100644 --- a/tests/bench/benchmark.py +++ b/tests/bench/benchmark.py @@ -7,6 +7,7 @@ import time import asyncio import matplotlib.pyplot as plt + def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger): with KVSTestFixture(conductor, dir, log, node_count=16) as fx: conductor.add_shard("shard1", conductor.get_nodes([0])) @@ -14,8 +15,8 @@ def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger): log("putting 100 keys\n") put_times = [] - for i in range(1000): - c = KVSMultiClient(fx.clients, "client", log) + c = KVSMultiClient(fx.clients, "client", log, persist_metadata=False) + for i in range(20000): start_time = time.time() r = c.put(0, f"key{i}", f"value{i}", timeout=10) end_time = time.time() @@ -39,13 +40,14 @@ def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger): # Generate plot plt.figure(figsize=(10, 6)) - plt.plot(range(2, 17), reshard_times, marker='o') - plt.title('Reshard Times') - plt.xlabel('Number of Shards') - plt.ylabel('Time (seconds)') + plt.plot(range(2, 17), reshard_times, marker="o") + plt.title("Reshard Times") + plt.xlabel("Number of Shards") + plt.ylabel("Time (seconds)") plt.grid(True) plt.savefig(f"{dir}/reshard_times.png") return True, "ok" -BENCHMARKS = [TestCase("benchmark_add_shard", benchmark_add_shard)] \ No newline at end of file + +BENCHMARKS = [TestCase("benchmark_add_shard", benchmark_add_shard)] diff --git a/tests/helper.py b/tests/helper.py index e33a803..957c6ca 100644 --- a/tests/helper.py +++ 
b/tests/helper.py @@ -7,6 +7,7 @@ from ..utils.util import Logger import asyncio + class KVSTestFixture: def __init__(self, conductor: ClusterConductor, dir, log: Logger, node_count: int): conductor._parent = self @@ -42,9 +43,7 @@ class KVSTestFixture: async def send_view(client: KVSClient, i: int): r = await client.async_send_view(view) - assert ( - r.status == 200 - ), f"expected 200 to ack view, got {r.status}" + assert r.status == 200, f"expected 200 to ack view, got {r.status}" self.log(f"view sent to node {i}: {r.status} {r.text}") tasks = [send_view(client, i) for i, client in enumerate(self.clients)] @@ -53,6 +52,8 @@ class KVSTestFixture: def rebroadcast_view(self, new_view: Dict[str, List[Dict[str, Any]]]): for i, client in enumerate(self.clients): r = client.resend_last_view_with_ips_from_new_view(new_view) + if r is None: + return assert r.status_code == 200, ( f"expected 200 to ack view, got {r.status_code}" ) diff --git a/tests/shuffle/basic_shuffle.py b/tests/shuffle/basic_shuffle.py index a99859c..af00040 100644 --- a/tests/shuffle/basic_shuffle.py +++ b/tests/shuffle/basic_shuffle.py @@ -102,7 +102,9 @@ def basic_shuffle_add_remove(conductor: ClusterConductor, dir, log: Logger): res = r.json()["items"] shard2_keys = res - assert len(shard1_keys) + len(shard2_keys) == 15, f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys)}" + assert len(shard1_keys) + len(shard2_keys) == 15, ( + f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys)}" + ) # Remove shard 2. This loses keys. 
conductor.remove_shard("shard2") @@ -230,7 +232,7 @@ def basic_shuffle_1(conductor: ClusterConductor, dir, log: Logger): def basic_shuffle_2(conductor: ClusterConductor, dir, log: Logger): with KVSTestFixture(conductor, dir, log, node_count=3) as fx: - c = KVSMultiClient(fx.clients, "client", log, persist_metadata=False) + c = KVSMultiClient(fx.clients, "client", log) conductor.add_shard("shard1", conductor.get_nodes([0, 1])) fx.broadcast_view(conductor.get_shard_view()) @@ -249,32 +251,38 @@ def basic_shuffle_2(conductor: ClusterConductor, dir, log: Logger): r = c.put(node_to_put, f"{base_key}{i}", f"{i}", timeout=10) assert r.ok, f"expected ok for new key, got {r.status_code}" node_to_put += 1 - node_to_put = node_to_put % 3 + node_to_put = node_to_put % 2 - r = c.get_all(0, timeout=10) # should get all of shard 1's keys + r = c.get_all(0, timeout=30) # should get all of shard 1's keys assert r.ok, f"expected ok for get, got {r.status_code}" original_get_all = r.json()["items"] + assert len(original_get_all.keys()) == 300, ( + f"original_get_all doesn't have 300 keys, instead has {len(original_get_all.keys())} keys" + ) + conductor.add_shard("shard2", conductor.get_nodes([2])) fx.broadcast_view(conductor.get_shard_view()) - r = c.get_all(2, timeout=10) # should get all of shard 2's keys + r = c.get_all(2, timeout=30) # should get all of shard 2's keys assert r.ok, f"expected ok for get, got {r.status_code}" get_all_1 = r.json()["items"] - keys1 = get_all_1.keys() + keys1 = set(get_all_1.keys()) - r = c.get_all(1, timeout=10) # should get all of shard 1's keys + r = c.get_all(1, timeout=30) # should get all of shard 1's keys assert r.ok, f"expected ok for get, got {r.status_code}" get_all_2 = r.json()["items"] - keys2 = get_all_2.keys() + keys2 = set(get_all_2.keys()) for key in keys1: - assert not (key in keys2) + assert key not in keys2, "key not in keys2" for key in keys2: - assert not (key in keys1) + assert key not in keys1, "key not in keys1" - assert 
original_get_all.keys() == keys1 | keys2 - assert len(original_get_all) == len(keys1) + len(keys2) + assert original_get_all.keys() == keys1.union(keys2), ( + f"get all keys does not equal key1 joined with keys2. diff one way: \n{keys1.union(keys2).difference(original_get_all.keys())}\n diff other way:\n{set(original_get_all.keys()).difference(keys1.union(keys2))}" + ) + assert len(original_get_all) == len(keys1) + len(keys2), "lengths do not match" return True, "ok" diff --git a/utils/kvs_api.py b/utils/kvs_api.py index ab6eb04..c630a50 100644 --- a/utils/kvs_api.py +++ b/utils/kvs_api.py @@ -152,6 +152,7 @@ class KVSClient: if not isinstance(current_view, dict): raise ValueError("view must be a dict") if not hasattr(self, "last_view"): + return raise LookupError("Must have sent at least one view before calling resend.") flattened_current_view = {} for shard_key in current_view: -- GitLab