diff --git a/__main__.py b/__main__.py
index f9e72a897a76a2b004c5d458ac01ec7cbc05d0d6..40d59b944498dd44f873c58e5c61ecbaa36c9784 100755
--- a/__main__.py
+++ b/__main__.py
@@ -81,6 +81,7 @@ from .tests.asgn3.causal_consistency.causal_basic import CAUSAL_TESTS
 from .tests.asgn3.eventual_consistency.convergence_basic import CONVERGENCE_TESTS
 # from .tests.asgn3.view_change.view_change_basic import VIEW_CHANGE_TESTS
 from .tests.proxy.basic_proxy import PROXY_TESTS
+from .tests.shuffle.basic_shuffle import SHUFFLE_TESTS
 
 TEST_SET = []
 TEST_SET.append(TestCase("hello_cluster", hello_cluster))
@@ -89,6 +90,7 @@ TEST_SET.extend(AVAILABILITY_TESTS)
 TEST_SET.extend(CAUSAL_TESTS)
 TEST_SET.extend(CONVERGENCE_TESTS)
 TEST_SET.extend(PROXY_TESTS)
+TEST_SET.extend(SHUFFLE_TESTS)
 # TEST_SET.extend(VIEW_CHANGE_TESTS)
 
 # set to True to stop at the first failing test
diff --git a/tests/helper.py b/tests/helper.py
index 18b31dbc6a88b3e062918ee66b7e887e45888bba..e8057e4db2aa4a900886ad5a1445815e674a7e37 100644
--- a/tests/helper.py
+++ b/tests/helper.py
@@ -7,6 +7,7 @@ from ..utils.util import Logger
 
 class KVSTestFixture:  
     def __init__(self, conductor: ClusterConductor, dir, log: Logger, node_count: int):
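+        # back-reference so the conductor can update client endpoints and
+        # re-broadcast the view when a partition changes node IPs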
+        conductor._parent = self
         self.conductor = conductor
         self.dir = dir
         self.node_count = node_count
diff --git a/tests/proxy/basic_proxy.py b/tests/proxy/basic_proxy.py
index c3c97d79c4fafdf9730e3000511e46e86edec299..62e0bb3e03bb23709a50b4bf010d5513b411dc15 100644
--- a/tests/proxy/basic_proxy.py
+++ b/tests/proxy/basic_proxy.py
@@ -67,14 +67,30 @@ def basic_proxy_many_clients(conductor: ClusterConductor, dir, log: Logger):
 
         return True, "ok"
     
-def partitioned_shards(conductor: ClusterConductor, dir, log: Logger):
+def basic_proxy_partitioned_shards(conductor: ClusterConductor, dir, log: Logger, timeout=5 * DEFAULT_TIMEOUT):
     with KVSTestFixture(conductor, dir, log, node_count=4) as fx:
+        c = KVSMultiClient(fx.clients, "client", log)
+        conductor.add_shard("shard1", conductor.get_nodes([0, 1]))
+        conductor.add_shard("shard2", conductor.get_nodes([2, 3]))
+        conductor.create_partition([2, 3], "secondshard")
+        fx.broadcast_view(conductor.get_shard_view())
+
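+        # shard2 (nodes 2 and 3) now sits on an isolated network, so proxied
+        # requests for keys owned by shard2 should hang and time out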
+        helper(c, timeout=timeout)
+        return True, "ok"
+
+
+def helper(c: KVSMultiClient, timeout=5 * DEFAULT_TIMEOUT):
         ###
         # test 2
         # partition the shards
         # put a bunch of keys
         # we MUST probabilistically encounter some hanging here.
         # have a timeout: if it doesn't hang after ~50 keys, then it's just wrong.
-        return True, "ok"
+        node_to_put = 0
+        base_key = "key"
+        saw_hang = False
+        for i in range(50):
+            r = c.put(node_to_put, f"{base_key}{i}", f"{i}", timeout=timeout)
+            assert r.ok or r.status_code == 408, f"expected ok or timeout for key {i}, got {r.status_code}"
+            if r.status_code == 408:
+                saw_hang = True
+            node_to_put = (node_to_put + 1) % 4
+        assert saw_hang, "expected at least one put to hang across the partition within 50 keys"
 
-PROXY_TESTS = [TestCase("basic_proxy_one_client", basic_proxy_one_client), TestCase("basic_proxy_many_clients", basic_proxy_many_clients)]
+PROXY_TESTS = [
+    TestCase("basic_proxy_one_client", basic_proxy_one_client),
+    TestCase("basic_proxy_many_clients", basic_proxy_many_clients),
+    TestCase("basic_proxy_partitioned_shards", basic_proxy_partitioned_shards),
+]
diff --git a/tests/shuffle/basic_shuffle.py b/tests/shuffle/basic_shuffle.py
new file mode 100644
index 0000000000000000000000000000000000000000..c23cda659d3788ff92a33921dfb0982722612120
--- /dev/null
+++ b/tests/shuffle/basic_shuffle.py
@@ -0,0 +1,113 @@
+from ...utils.containers import ClusterConductor
+from ...utils.testcase import TestCase
+from ...utils.util import Logger
+from ..helper import KVSMultiClient, KVSTestFixture
+from ...utils.kvs_api import DEFAULT_TIMEOUT
+
+def basic_shuffle(conductor: ClusterConductor, dir, log: Logger):
+    with KVSTestFixture(conductor, dir, log, node_count=3) as fx:
+        c = KVSMultiClient(fx.clients, "client", log)
+        conductor.add_shard("shard1", conductor.get_nodes([0]))
+        conductor.add_shard("shard2", conductor.get_nodes([1]))
+        
+        fx.broadcast_view(conductor.get_shard_view())
+
+        node_to_put = 0
+        base_key = "key"
+        # Put 15 keys
+        for i in range(15):
+            log(f"Putting key {i}\n")
+            r = c.put(node_to_put, f"{base_key}{i}", f"{i}", timeout=10)
+            assert r.ok, f"expected ok for new key, got {r.status_code}"
+            node_to_put = (node_to_put + 1) % 2
+
+        # Get all keys
+        r = c.get_all(0, timeout=10)
+        assert r.ok, f"expected ok for get, got {r.status_code}"
+        shard1_keys = r.json()["items"]
+
+        r = c.get_all(1, timeout=10)
+        assert r.ok, f"expected ok for get, got {r.status_code}"
+        shard2_keys = r.json()["items"]
+
+        log(f"Shard 1 keys: {shard1_keys}\n")
+        log(f"Shard 2 keys: {shard2_keys}\n")
+
+        # Total number of keys should match the number of keys put
+        assert len(shard1_keys) + len(shard2_keys) == 15, f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys)}"
+
+        # Add a 3rd shard, causing a shuffle. There should still be 15 keys at the end.
+        log("Adding 3rd shard\n")
+        conductor.add_shard("shard3", conductor.get_nodes([2]))
+        fx.broadcast_view(conductor.get_shard_view())
+
+        # Get the keys on shard 1
+        r = c.get_all(0, timeout=10)
+        assert r.ok, f"expected ok for get, got {r.status_code}"
+        shard1_keys = r.json()["items"]
+
+        log(f"Shard 1 keys: {shard1_keys}\n")
+
+        # Get the keys on shard 2
+        r = c.get_all(1, timeout=10)
+        assert r.ok, f"expected ok for get, got {r.status_code}"
+        shard2_keys = r.json()["items"]
+
+        log(f"Shard 2 keys: {shard2_keys}\n")
+
+        # Get the keys on shard 3
+        r = c.get_all(2, timeout=10)
+        assert r.ok, f"expected ok for get, got {r.status_code}"
+        shard3_keys = r.json()["items"]
+
+        log(f"Shard 3 keys: {shard3_keys}\n")
+
+        assert len(shard1_keys) + len(shard2_keys) + len(shard3_keys) == 15, f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys) + len(shard3_keys)}"
+
+        # Remove shard 3 and move node 2 into shard 1, causing another shuffle.
+        # All 15 keys should survive the reshuffle.
+        conductor.remove_shard("shard3")
+        conductor.add_node_to_shard("shard1", conductor.get_nodes([2]))
+        fx.broadcast_view(conductor.get_shard_view())
+
+        r = c.get_all(0, timeout=10)  # should get all of shard 1's keys
+        assert r.ok, f"expected ok for get, got {r.status_code}"
+        shard1_keys = r.json()["items"]
+
+        r = c.get_all(1, timeout=10)  # should get all of shard 2's keys
+        assert r.ok, f"expected ok for get, got {r.status_code}"
+        shard2_keys = r.json()["items"]
+
+        assert len(shard1_keys) + len(shard2_keys) == 15, f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys)}"
+
+        # Remove shard 2 without relocating its node; shard 2's keys are lost,
+        # but shard 1's keys should be unaffected.
+        conductor.remove_shard("shard2")
+        fx.broadcast_view(conductor.get_shard_view())
+
+        r = c.get_all(0, timeout=10)
+        assert r.ok, f"expected ok for get, got {r.status_code}"
+        shard1_keys_after_delete = r.json()["items"]
+
+        assert len(shard1_keys) == len(shard1_keys_after_delete), f"expected {len(shard1_keys)} keys, got {len(shard1_keys_after_delete)}"
+
+        return True, "ok"
+
+
+SHUFFLE_TESTS = [TestCase("basic_shuffle", basic_shuffle)]
\ No newline at end of file
diff --git a/utils/containers.py b/utils/containers.py
index 894dbd281f25808892f6ba59dd3900f949b644ea..54a3bb4bd1260f42b2c4573f967a2c61ba6a7e98 100644
--- a/utils/containers.py
+++ b/utils/containers.py
@@ -384,6 +384,7 @@ class ClusterConductor:
 
         # connect nodes to partition network, and update node ip
         self.log(f"  connecting nodes to partition network {net_name}")
+        view_changed = False
         for i in node_ids:
             node = self.nodes[i]
 
@@ -412,7 +413,16 @@ class ClusterConductor:
             self.log(f"    node {node.name} ip in network {net_name}: {container_ip}")
 
             # update node ip
-            node.ip = container_ip
+            if container_ip != node.ip:
+                node.ip = container_ip
+                # repoint the fixture's client at this node's new external endpoint
+                if hasattr(self, "_parent"):
+                    self._parent.clients[node.index].base_url = self.node_external_endpoint(node.index)
+                view_changed = True
+
+        if view_changed and hasattr(self, "_parent"):
+            self._parent.broadcast_view(self.get_shard_view())
 
     def create_partition(self, node_ids: List[int], partition_id: str) -> None:
         net_name = f"kvs_{self.group_id}_net_{partition_id}"
@@ -441,6 +451,7 @@ class ClusterConductor:
 
         # connect nodes to partition network, and update node ip
         self.log(f"  connecting nodes to partition network {net_name}")
+        view_changed = False
         for i in node_ids:
             node = self.nodes[i]
             self.log(f"    connecting {node.name} to network {net_name}")
@@ -463,7 +474,15 @@ class ClusterConductor:
             self.log(f"    node {node.name} ip in network {net_name}: {container_ip}")
 
             # update node ip
-            node.ip = container_ip
+            if container_ip != node.ip:
+                node.ip = container_ip
+                # repoint the fixture's client at this node's new external endpoint
+                if hasattr(self, "_parent"):
+                    self._parent.clients[node.index].base_url = self.node_external_endpoint(node.index)
+                view_changed = True
+
+        if view_changed and hasattr(self, "_parent"):
+            self._parent.broadcast_view(self.get_shard_view())
 
     DeprecationWarning("View is in updated format")