diff --git a/tests/bench/benchmark.py b/tests/bench/benchmark.py
index 97ef0d14da2849c94caa7a22e6d053538e7e0fd6..61b5cfa66629cdfc321f183abd5e6d1ca0a009f1 100644
--- a/tests/bench/benchmark.py
+++ b/tests/bench/benchmark.py
@@ -8,14 +8,15 @@ import asyncio
 import matplotlib.pyplot as plt
 
 NUM_SHARDS = 16
-NUM_KEYS = 10000
+NUM_KEYS = 20000
+NUM_NODES = 8
 
 def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger):
     with KVSTestFixture(conductor, dir, log, node_count=NUM_SHARDS) as fx:
         conductor.add_shard("shard1", conductor.get_nodes([0]))
         fx.broadcast_view(conductor.get_shard_view())
 
-        log("putting 100 keys\n")
+        log(f"putting {NUM_KEYS} keys\n")
         put_times = []
         for i in range(NUM_KEYS):
             c = KVSMultiClient(fx.clients, "client", log)
@@ -51,5 +52,67 @@ def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger):
 
         return True, "ok"
 
+def benchmark_add_shard_two_nodes(conductor: ClusterConductor, dir, log: Logger):
+    """Benchmark reshard time when every shard is backed by two nodes.
+
+    Creates "shard1" from nodes 0 and 1, writes NUM_KEYS keys, then adds one
+    shard at a time up to NUM_SHARDS, each from the next two node indices,
+    timing every add_shard plus parallel view broadcast. Timings are logged
+    and the reshard times are plotted to "<dir>/reshard_times.png".
+
+    Returns:
+        (True, "ok") once the benchmark has run to completion.
+
+    Raises:
+        AssertionError: if a put response is not ok.
+    """
+    # NOTE(review): the shards below only use node indices 0..2*NUM_SHARDS-1,
+    # but the fixture is sized NUM_SHARDS*NUM_NODES, and puts are spread over
+    # i % NUM_NODES (presumably a node index) while only nodes 0 and 1 are in
+    # "shard1" -- confirm this is intended.
+    with KVSTestFixture(conductor, dir, log, node_count=NUM_SHARDS*NUM_NODES) as fx:
+        conductor.add_shard("shard1", conductor.get_nodes([0, 1]))
+        fx.broadcast_view(conductor.get_shard_view())
+
+        log(f"putting {NUM_KEYS} keys\n")
+        put_times = []
+        for i in range(NUM_KEYS):
+            c = KVSMultiClient(fx.clients, "client", log)
+            # perf_counter is monotonic and high-resolution, unlike time.time().
+            start_time = time.perf_counter()
+            r = c.put(i%NUM_NODES, f"key{i}", f"value{i}", timeout=10)
+            end_time = time.perf_counter()
+            assert r.ok, f"expected ok for new key, got {r.status_code}"
+            put_times.append(end_time - start_time)
+
+        log("Starting benchmark\n")
+        reshard_times = []
+        for shard in range(2, NUM_SHARDS+1):
+            first_node = 2*(shard - 1)  # shard k takes nodes 2(k-1) and 2(k-1)+1
+            start_time = time.perf_counter()
+            log(f"adding shard{shard}\n")
+            conductor.add_shard(f"shard{shard}", conductor.get_nodes([first_node, first_node + 1]))
+            asyncio.run(fx.parallel_broadcast_view(conductor.get_shard_view()))
+            end_time = time.perf_counter()
+            reshard_times.append(end_time - start_time)
+            log(f"reshard time with {shard} shards: {reshard_times[-1]}\n")
+
+        log("Average put time: ", sum(put_times) / len(put_times))
+        for shard, time_taken in enumerate(reshard_times, start=2):
+            log(f"shard count: {shard}, reshard time: {time_taken}")
+
+        # Generate plot
+        fig = plt.figure(figsize=(NUM_SHARDS, 10))
+        plt.plot(range(2, NUM_SHARDS+1), reshard_times, marker='o')
+        plt.title('Reshard Times')
+        plt.xlabel('Number of Shards')
+        plt.ylabel('Time (seconds)')
+        plt.grid(True)
+        plt.savefig(f"{dir}/reshard_times.png")
+        # pyplot keeps figures alive until closed; release it once saved.
+        plt.close(fig)
+
+        return True, "ok"
 
-BENCHMARKS = [TestCase("benchmark_add_shard", benchmark_add_shard)]
+BENCHMARKS = [TestCase("benchmark_add_shard", benchmark_add_shard),
+              TestCase("benchmark_add_shard_two_nodes", benchmark_add_shard_two_nodes)]