From f3a50fcfa9c34839aaa95d1ca2c74aff6f2973bf Mon Sep 17 00:00:00 2001
From: Thomas Dillow <tdillow@ucsc.edu>
Date: Fri, 14 Mar 2025 21:20:47 -0700
Subject: [PATCH] can now change bench size stuff easily

---
 tests/bench/benchmark.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/tests/bench/benchmark.py b/tests/bench/benchmark.py
index b5f664e..5b86587 100644
--- a/tests/bench/benchmark.py
+++ b/tests/bench/benchmark.py
@@ -7,14 +7,16 @@ import time
 import asyncio
 import matplotlib.pyplot as plt
 
+NUM_SHARDS = 16
+NUM_KEYS = 10000
 def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger):
-    with KVSTestFixture(conductor, dir, log, node_count=16) as fx:
+    with KVSTestFixture(conductor, dir, log, node_count=NUM_SHARDS) as fx:
         conductor.add_shard("shard1", conductor.get_nodes([0]))
         fx.broadcast_view(conductor.get_shard_view())
 
         log("putting 100 keys\n")
         put_times = []
-        for i in range(1000):
+        for i in range(NUM_SHARDS):
             c = KVSMultiClient(fx.clients, "client", log)
             start_time = time.time()
             r = c.put(0, f"key{i}", f"value{i}", timeout=10)
@@ -24,7 +26,7 @@ def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger):
 
         log("Starting benchmark\n")
         reshard_times = []
-        for shard in range(2, 17):
+        for shard in range(2, NUM_SHARDS+1):
             start_time = time.time()
             log(f"adding shard{shard}\n")
             conductor.add_shard(f"shard{shard}", conductor.get_nodes([shard - 1]))
@@ -38,8 +40,8 @@ def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger):
         log(f"shard count: {shard}, reshard time: {time_taken}")
 
     # Generate plot
-    plt.figure(figsize=(10, 6))
-    plt.plot(range(2, 17), reshard_times, marker='o')
+    plt.figure(figsize=(NUM_SHARDS, 10))
+    plt.plot(range(2, NUM_SHARDS+1), reshard_times, marker='o')
     plt.title('Reshard Times')
     plt.xlabel('Number of Shards')
     plt.ylabel('Time (seconds)')
-- 
GitLab