From 7504efee864b0063e7e5b06ef06c8a080ab898bd Mon Sep 17 00:00:00 2001
From: Alexander Aghili <alexander.w.aghili@gmail.com>
Date: Fri, 14 Mar 2025 11:36:15 -0700
Subject: [PATCH] Basic shuffle 3 + fix to container

---
 tests/proxy/basic_proxy.py     |   2 +-
 tests/shuffle/basic_shuffle.py | 101 ++++++++++++++++++++++++++++++++-
 utils/containers.py            |   4 +-
 3 files changed, 103 insertions(+), 4 deletions(-)

diff --git a/tests/proxy/basic_proxy.py b/tests/proxy/basic_proxy.py
index 62e0bb3..6366192 100644
--- a/tests/proxy/basic_proxy.py
+++ b/tests/proxy/basic_proxy.py
@@ -77,7 +77,7 @@ def basic_proxy_partitioned_shards(conductor: ClusterConductor, dir, log: Logger
         
         helper(c, timeout=timeout)
         return True, "ok"
-    
+
 def helper(c: KVSMultiClient, timeout= 5*DEFAULT_TIMEOUT):
         ###
         # test 2
diff --git a/tests/shuffle/basic_shuffle.py b/tests/shuffle/basic_shuffle.py
index 6ba0e9d..de404cc 100644
--- a/tests/shuffle/basic_shuffle.py
+++ b/tests/shuffle/basic_shuffle.py
@@ -248,4 +248,103 @@ def basic_shuffle_2(conductor: ClusterConductor, dir, log: Logger):
 
         return True, "ok"
 
-SHUFFLE_TESTS = [TestCase("basic_shuffle_add_remove", basic_shuffle_add_remove), TestCase("basic_shuffle", basic_shuffle), TestCase("basic_shuffle_2", basic_shuffle_2)]
+
+def basic_shuffle_3(conductor: ClusterConductor, dir, log: Logger):
+    with KVSTestFixture(conductor, dir, log, node_count=3) as fx:
+        c = KVSMultiClient(fx.clients, "client", log)
+        conductor.add_shard("shard1", conductor.get_nodes([0]))
+        conductor.add_shard("shard2", conductor.get_nodes([1]))
+        
+        fx.broadcast_view(conductor.get_shard_view())
+
+        node_to_put = 0
+        base_key = "key"
+        # Put 15 keys
+        for i in range(15):
+            log(f"Putting key {i}\n")
+            r = c.put(node_to_put, f"{base_key}{i}", f"{i}", timeout=10)
+            assert r.ok, f"expected ok for new key, got {r.status_code}"
+            node_to_put += 1
+            node_to_put = node_to_put % 2
+
+        # Get all keys
+        r = c.get_all(0, timeout=10)
+        assert r.ok, f"expected ok for get, got {r.status_code}"
+        res = r.json()["items"]
+        shard1_keys = res
+
+        r = c.get_all(1, timeout=10)
+        assert r.ok, f"expected ok for get, got {r.status_code}"
+        res = r.json()["items"]
+        shard2_keys = res
+
+        log(f"Shard 1 keys: {shard1_keys}\n")
+        log(f"Shard 2 keys: {shard2_keys}\n")
+
+        # Total number of keys should matched number of keys put
+        assert len(shard1_keys) + len(shard2_keys) == 15, f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys)}"
+
+        # Add a 3rd shard, causing a shuffle. There should still be 15 keys at the end.
+        log("Adding 3rd shard\n")
+        conductor.add_shard("shard3", conductor.get_nodes([2]))
+        fx.broadcast_view(conductor.get_shard_view())
+
+        # Get the keys on shard 1
+        r = c.get_all(0, timeout=10)
+        assert r.ok, f"expected ok for get, got {r.status_code}"
+        res = r.json()["items"]
+        shard1_keys = res
+
+        log(f"Shard 1 keys: {shard1_keys}\n")
+
+        # get the keys on shard 2
+        r = c.get_all(1, timeout=10)
+        assert r.ok, f"expected ok for get, got {r.status_code}"
+        res = r.json()["items"]
+        shard2_keys = res
+
+        log(f"Shard 2 keys: {shard2_keys}\n")
+
+        # get the keys on shard 3
+        r = c.get_all(2, timeout=10)
+        assert r.ok, f"expected ok for get, got {r.status_code}"
+        res = r.json()["items"]
+        shard3_keys = res
+    
+        log(f"Shard 3 keys: {shard3_keys}\n")
+
+        assert len(shard1_keys) + len(shard2_keys) + len(shard3_keys) == 15, f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys) + len(shard3_keys)}"
+
+        # Remove shard 3, causing a shuffle. Move Node 2 to shard 1 so the keys should still exist, and be shuffled
+        conductor.remove_shard("shard3")
+        conductor.add_node_to_shard("shard1", conductor.get_node(2))
+        fx.broadcast_view(conductor.get_shard_view())
+
+        r = c.get_all(0, timeout=10) # should get all of shard 1's keys
+        assert r.ok, f"expected ok for get, got {r.status_code}"
+        res = r.json()["items"]
+        shard1_keys = res
+
+        r = c.get_all(1, timeout=10) # should get all of shard 2's keys
+        assert r.ok, f"expected ok for get, got {r.status_code}"
+        res = r.json()["items"]
+        shard2_keys = res
+
+        assert len(shard1_keys) + len(shard2_keys) == 15, f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys)}"
+
+        # Remove shard 2. This loses keys.
+        conductor.remove_shard("shard2")
+        fx.broadcast_view(conductor.get_shard_view())
+
+        r = c.get_all(0, timeout=10)
+        assert r.ok, f"expected ok for get, got {r.status_code}"
+        res = r.json()["items"]
+        shard1_keys_after_delete = res
+        
+        assert len(shard1_keys_after_delete) == 15, f"expected 15 keys, got {len(shard1_keys_after_delete)}"
+
+
+        return True, "ok"
+
+
+SHUFFLE_TESTS = [TestCase("basic_shuffle_add_remove", basic_shuffle_add_remove), TestCase("basic_shuffle", basic_shuffle), TestCase("basic_shuffle_2", basic_shuffle_2), TestCase("basic_shuffle_3", basic_shuffle_3)]
diff --git a/utils/containers.py b/utils/containers.py
index 7041cd3..d62efd6 100644
--- a/utils/containers.py
+++ b/utils/containers.py
@@ -481,8 +481,8 @@ class ClusterConductor:
                         node.index
                     ].base_url = self.node_external_endpoint(node.index)
                 view_changed = True
-        if view_changed and hasattr(self, "_parent"):
-            self._parent.rebroadcast_view(self.get_shard_view())
+        #if view_changed and hasattr(self, "_parent"):
+        #    self._parent.rebroadcast_view(self.get_shard_view())
 
     DeprecationWarning("View is in updated format")
 
-- 
GitLab