diff --git a/tests/bench/benchmark.py b/tests/bench/benchmark.py
index 97ef0d14da2849c94caa7a22e6d053538e7e0fd6..61b5cfa66629cdfc321f183abd5e6d1ca0a009f1 100644
--- a/tests/bench/benchmark.py
+++ b/tests/bench/benchmark.py
@@ -8,14 +8,15 @@ import asyncio
 import matplotlib.pyplot as plt
 
 NUM_SHARDS = 16
-NUM_KEYS = 10000
+NUM_KEYS = 20000
+NUM_NODES = 8
 
 def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger):
     with KVSTestFixture(conductor, dir, log, node_count=NUM_SHARDS) as fx:
         conductor.add_shard("shard1", conductor.get_nodes([0]))
         fx.broadcast_view(conductor.get_shard_view())
 
-        log("putting 100 keys\n")
+        log(f"putting {NUM_KEYS} keys\n")
         put_times = []
         for i in range(NUM_KEYS):
             c = KVSMultiClient(fx.clients, "client", log)
@@ -51,5 +52,46 @@ def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger):
 
     return True, "ok"
 
+def benchmark_add_shard_two_nodes(conductor: ClusterConductor, dir, log: Logger):
+    with KVSTestFixture(conductor, dir, log, node_count=NUM_SHARDS*NUM_NODES) as fx:
+        conductor.add_shard("shard1", conductor.get_nodes([0, 1]))
+        fx.broadcast_view(conductor.get_shard_view())
+
+        log(f"putting {NUM_KEYS} keys\n")
+        put_times = []
+        for i in range(NUM_KEYS):
+            c = KVSMultiClient(fx.clients, "client", log)
+            start_time = time.time()
+            r = c.put(i%NUM_NODES, f"key{i}", f"value{i}", timeout=10)
+            end_time = time.time()
+            assert r.ok, f"expected ok for new key, got {r.status_code}"
+            put_times.append(end_time - start_time)
+
+        log("Starting benchmark\n")
+        reshard_times = []
+        for shard in range(2, NUM_SHARDS+1):
+            start_time = time.time()
+            log(f"adding shard{shard}\n")
+            conductor.add_shard(f"shard{shard}", conductor.get_nodes([2*(shard - 1), 2*(shard - 1) + 1]))
+            asyncio.run(fx.parallel_broadcast_view(conductor.get_shard_view()))
+            end_time = time.time()
+            reshard_times.append(end_time - start_time)
+            log(f"reshard time with {shard} shards: {reshard_times[-1]}\n")
+
+    log("Average put time: ", sum(put_times) / len(put_times))
+    for shard, time_taken in enumerate(reshard_times, start=2):
+        log(f"shard count: {shard}, reshard time: {time_taken}")
+
+    # Generate plot
+    plt.figure(figsize=(NUM_SHARDS, 10))
+    plt.plot(range(2, NUM_SHARDS+1), reshard_times, marker='o')
+    plt.title('Reshard Times')
+    plt.xlabel('Number of Shards')
+    plt.ylabel('Time (seconds)')
+    plt.grid(True)
+    plt.savefig(f"{dir}/reshard_times.png")
+
+    return True, "ok"
 
-BENCHMARKS = [TestCase("benchmark_add_shard", benchmark_add_shard)]
+BENCHMARKS = [TestCase("benchmark_add_shard", benchmark_add_shard),
+              TestCase("benchmark_add_shard_two_nodes", benchmark_add_shard_two_nodes)]