diff --git a/tests/bench/benchmark.py b/tests/bench/benchmark.py
index 3ec2f1ec3ed98b3ba2c18c786b0f0991bb3d5dcc..ed17ee77fd46b7201874f6f66a78cf07150b006d 100644
--- a/tests/bench/benchmark.py
+++ b/tests/bench/benchmark.py
@@ -7,16 +7,17 @@ import time
 import asyncio
 import matplotlib.pyplot as plt
 
-
+NUM_SHARDS = 16
+NUM_KEYS = 10000
 def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger):
-    with KVSTestFixture(conductor, dir, log, node_count=16) as fx:
+    with KVSTestFixture(conductor, dir, log, node_count=NUM_SHARDS) as fx:
         conductor.add_shard("shard1", conductor.get_nodes([0]))
         fx.broadcast_view(conductor.get_shard_view())
 
         log("putting 100 keys\n")
         put_times = []
-        c = KVSMultiClient(fx.clients, "client", log, persist_metadata=False)
-        for i in range(20000):
+        for i in range(NUM_SHARDS):
+            c = KVSMultiClient(fx.clients, "client", log)
             start_time = time.time()
             r = c.put(0, f"key{i}", f"value{i}", timeout=10)
             end_time = time.time()
@@ -25,7 +26,7 @@ def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger):
 
         log("Starting benchmark\n")
         reshard_times = []
-        for shard in range(2, 17):
+        for shard in range(2, NUM_SHARDS+1):
             start_time = time.time()
             log(f"adding shard{shard}\n")
             conductor.add_shard(f"shard{shard}", conductor.get_nodes([shard - 1]))
@@ -39,11 +40,11 @@ def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger):
         log(f"shard count: {shard}, reshard time: {time_taken}")
 
     # Generate plot
-    plt.figure(figsize=(10, 6))
-    plt.plot(range(2, 17), reshard_times, marker="o")
-    plt.title("Reshard Times")
-    plt.xlabel("Number of Shards")
-    plt.ylabel("Time (seconds)")
+    plt.figure(figsize=(NUM_SHARDS, 10))
+    plt.plot(range(2, NUM_SHARDS+1), reshard_times, marker='o')
+    plt.title('Reshard Times')
+    plt.xlabel('Number of Shards')
+    plt.ylabel('Time (seconds)')
     plt.grid(True)
     plt.savefig(f"{dir}/reshard_times.png")