From d7a0e6a1e2504c6a40c640828d6be9e184b7a3c5 Mon Sep 17 00:00:00 2001
From: zphrs <z@zephiris.dev>
Date: Fri, 14 Mar 2025 20:28:12 -0700
Subject: [PATCH] fix to basic_shuffle

---
 tests/bench/benchmark.py       | 16 +++++++++-------
 tests/helper.py                |  7 ++++---
 tests/shuffle/basic_shuffle.py | 32 ++++++++++++++++++++------------
 utils/kvs_api.py               |  1 +
 4 files changed, 34 insertions(+), 22 deletions(-)

diff --git a/tests/bench/benchmark.py b/tests/bench/benchmark.py
index b5f664e..3ec2f1e 100644
--- a/tests/bench/benchmark.py
+++ b/tests/bench/benchmark.py
@@ -7,6 +7,7 @@ import time
 import asyncio
 import matplotlib.pyplot as plt
 
+
 def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger):
     with KVSTestFixture(conductor, dir, log, node_count=16) as fx:
         conductor.add_shard("shard1", conductor.get_nodes([0]))
@@ -14,8 +15,8 @@ def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger):
 
         log("putting 100 keys\n")
         put_times = []
-        for i in range(1000):
-            c = KVSMultiClient(fx.clients, "client", log)
+        c = KVSMultiClient(fx.clients, "client", log, persist_metadata=False)
+        for i in range(20000):
             start_time = time.time()
             r = c.put(0, f"key{i}", f"value{i}", timeout=10)
             end_time = time.time()
@@ -39,13 +40,14 @@ def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger):
 
     # Generate plot
     plt.figure(figsize=(10, 6))
-    plt.plot(range(2, 17), reshard_times, marker='o')
-    plt.title('Reshard Times')
-    plt.xlabel('Number of Shards')
-    plt.ylabel('Time (seconds)')
+    plt.plot(range(2, 17), reshard_times, marker="o")
+    plt.title("Reshard Times")
+    plt.xlabel("Number of Shards")
+    plt.ylabel("Time (seconds)")
     plt.grid(True)
     plt.savefig(f"{dir}/reshard_times.png")
 
     return True, "ok"
 
-BENCHMARKS = [TestCase("benchmark_add_shard", benchmark_add_shard)]
\ No newline at end of file
+
+BENCHMARKS = [TestCase("benchmark_add_shard", benchmark_add_shard)]
diff --git a/tests/helper.py b/tests/helper.py
index e33a803..957c6ca 100644
--- a/tests/helper.py
+++ b/tests/helper.py
@@ -7,6 +7,7 @@ from ..utils.util import Logger
 
 import asyncio
 
+
 class KVSTestFixture:
     def __init__(self, conductor: ClusterConductor, dir, log: Logger, node_count: int):
         conductor._parent = self
@@ -42,9 +43,7 @@ class KVSTestFixture:
 
         async def send_view(client: KVSClient, i: int):
             r = await client.async_send_view(view)
-            assert (
-                r.status == 200
-            ), f"expected 200 to ack view, got {r.status}"
+            assert r.status == 200, f"expected 200 to ack view, got {r.status}"
             self.log(f"view sent to node {i}: {r.status} {r.text}")
 
         tasks = [send_view(client, i) for i, client in enumerate(self.clients)]
@@ -53,6 +52,8 @@ class KVSTestFixture:
     def rebroadcast_view(self, new_view: Dict[str, List[Dict[str, Any]]]):
         for i, client in enumerate(self.clients):
             r = client.resend_last_view_with_ips_from_new_view(new_view)
+            if r is None:
+                return
             assert r.status_code == 200, (
                 f"expected 200 to ack view, got {r.status_code}"
             )
diff --git a/tests/shuffle/basic_shuffle.py b/tests/shuffle/basic_shuffle.py
index a99859c..af00040 100644
--- a/tests/shuffle/basic_shuffle.py
+++ b/tests/shuffle/basic_shuffle.py
@@ -102,7 +102,9 @@ def basic_shuffle_add_remove(conductor: ClusterConductor, dir, log: Logger):
         res = r.json()["items"]
         shard2_keys = res
 
-        assert len(shard1_keys) + len(shard2_keys) == 15, f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys)}"
+        assert len(shard1_keys) + len(shard2_keys) == 15, (
+            f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys)}"
+        )
 
         # Remove shard 2. This loses keys.
         conductor.remove_shard("shard2")
@@ -230,7 +232,7 @@ def basic_shuffle_1(conductor: ClusterConductor, dir, log: Logger):
 
 def basic_shuffle_2(conductor: ClusterConductor, dir, log: Logger):
     with KVSTestFixture(conductor, dir, log, node_count=3) as fx:
-        c = KVSMultiClient(fx.clients, "client", log, persist_metadata=False)
+        c = KVSMultiClient(fx.clients, "client", log)
         conductor.add_shard("shard1", conductor.get_nodes([0, 1]))
         fx.broadcast_view(conductor.get_shard_view())
 
@@ -249,32 +251,38 @@ def basic_shuffle_2(conductor: ClusterConductor, dir, log: Logger):
             r = c.put(node_to_put, f"{base_key}{i}", f"{i}", timeout=10)
             assert r.ok, f"expected ok for new key, got {r.status_code}"
             node_to_put += 1
-            node_to_put = node_to_put % 3
+            node_to_put = node_to_put % 2
 
-        r = c.get_all(0, timeout=10)  # should get all of shard 1's keys
+        r = c.get_all(0, timeout=30)  # should get all of shard 1's keys
         assert r.ok, f"expected ok for get, got {r.status_code}"
         original_get_all = r.json()["items"]
 
+        assert len(original_get_all.keys()) == 300, (
+            f"original_get_all doesn't have 300 keys, instead has {len(original_get_all.keys())} keys"
+        )
+
         conductor.add_shard("shard2", conductor.get_nodes([2]))
         fx.broadcast_view(conductor.get_shard_view())
 
-        r = c.get_all(2, timeout=10)  # should get all of shard 2's keys
+        r = c.get_all(2, timeout=30)  # should get all of shard 2's keys
         assert r.ok, f"expected ok for get, got {r.status_code}"
         get_all_1 = r.json()["items"]
-        keys1 = get_all_1.keys()
+        keys1 = set(get_all_1.keys())
 
-        r = c.get_all(1, timeout=10)  # should get all of shard 1's keys
+        r = c.get_all(1, timeout=30)  # should get all of shard 1's keys
         assert r.ok, f"expected ok for get, got {r.status_code}"
         get_all_2 = r.json()["items"]
-        keys2 = get_all_2.keys()
+        keys2 = set(get_all_2.keys())
 
         for key in keys1:
-            assert not (key in keys2)
+            assert key not in keys2, "key not in keys2"
         for key in keys2:
-            assert not (key in keys1)
+            assert key not in keys1, "key not in keys2"
 
-        assert original_get_all.keys() == keys1 | keys2
-        assert len(original_get_all) == len(keys1) + len(keys2)
+        assert original_get_all.keys() == keys1.union(keys2), (
+            f"get all keys does not equal key1 joined with keys2. diff one way: \n{keys1.union(keys2).difference(original_get_all.keys())}\n diff other way:\n{set(original_get_all.keys()).difference(keys1.union(keys2))}"
+        )
+        assert len(original_get_all) == len(keys1) + len(keys2), "lengths do not match"
 
         return True, "ok"
 
diff --git a/utils/kvs_api.py b/utils/kvs_api.py
index ab6eb04..c630a50 100644
--- a/utils/kvs_api.py
+++ b/utils/kvs_api.py
@@ -152,6 +152,7 @@ class KVSClient:
         if not isinstance(current_view, dict):
             raise ValueError("view must be a dict")
         if not hasattr(self, "last_view"):
+            return
             raise LookupError("Must have sent at least one view before calling resend.")
         flattened_current_view = {}
         for shard_key in current_view:
-- 
GitLab