diff --git a/tests/shuffle/basic_shuffle.py b/tests/shuffle/basic_shuffle.py index 68c98ed5669d85a1a19e8b886f3cd77ea016216d..3272be2739ab33eeb1780c7246b20bc5b25f926b 100644 --- a/tests/shuffle/basic_shuffle.py +++ b/tests/shuffle/basic_shuffle.py @@ -101,14 +101,54 @@ def basic_shuffle(conductor: ClusterConductor, dir, log: Logger): return True, "ok" -def partitioned_shards(conductor: ClusterConductor, dir, log: Logger): - with KVSTestFixture(conductor, dir, log, node_count=4) as fx: - ### - # test 2 - # partition the shards - # put a bunch of keys - # we MUST probablistically encounter some hanging there. - # have a time out where if it doesnt hang after like 50 keys, then its just wrong. +def basic_shuffle_2(conductor: ClusterConductor, dir, log: Logger): + with KVSTestFixture(conductor, dir, log, node_count=3) as fx: + c = KVSMultiClient(fx.clients, "client", log) + conductor.add_shard("shard1", conductor.get_nodes([0, 1])) + fx.broadcast_view(conductor.get_shard_view()) + + ## basic shuffle 2 + # view= 1 shard with 2 nodes + # put 50 keys + # get_all keys from shard 1 + # add shard with 1 node + # get_all keys from shard 2 + # get_all keys from shard 1 + # check both returned sets are disjoint and that their union makes the original get_all results + + node_to_put = 0 + base_key = "key" + for i in range(0, 50): + r = c.put(node_to_put, f"{base_key}{i}", f"{i}", timeout=10) + assert r.ok, f"expected ok for new key, got {r.status_code}" + node_to_put += 1 + node_to_put = node_to_put % 3 + + r = c.get_all(0, timeout=10) # should get all of shard 1's keys + assert r.ok, f"expected ok for get, got {r.status_code}" + original_get_all = r.json()["items"] + + conductor.add_shard("shard2", conductor.get_nodes([2])) + fx.broadcast_view(conductor.get_shard_view()) + + r = c.get_all(2, timeout=10) # should get all of shard 2's keys + assert r.ok, f"expected ok for get, got {r.status_code}" + get_all_1 = r.json()["items"] + keys1 = get_all_1.keys() + + r = c.get_all(1, timeout=10) # should get all of shard 1's keys + assert r.ok, f"expected ok for get, got {r.status_code}" + get_all_2 = r.json()["items"] + keys2 = get_all_2.keys() + + for key in keys1: + assert not (key in keys2) + for key in keys2: + assert not (key in keys1) + + assert original_get_all.keys() == keys1 + keys2 + assert len(original_get_all) == len(keys1) + len(keys2) + return True, "ok" -SHUFFLE_TESTS = [TestCase("shuffle_basic", basic_shuffle)] \ No newline at end of file +SHUFFLE_TESTS = [TestCase("basic_shuffle", basic_shuffle), TestCase("basic_shuffle_2", basic_shuffle_2)] \ No newline at end of file