Skip to content
Snippets Groups Projects
Commit d7a0e6a1 authored by zphrs's avatar zphrs
Browse files

fix to basic_shuffle

parent c9790dd3
No related branches found
No related tags found
No related merge requests found
......@@ -7,6 +7,7 @@ import time
import asyncio
import matplotlib.pyplot as plt
def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger):
with KVSTestFixture(conductor, dir, log, node_count=16) as fx:
conductor.add_shard("shard1", conductor.get_nodes([0]))
......@@ -14,8 +15,8 @@ def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger):
log("putting 100 keys\n")
put_times = []
for i in range(1000):
c = KVSMultiClient(fx.clients, "client", log)
c = KVSMultiClient(fx.clients, "client", log, persist_metadata=False)
for i in range(20000):
start_time = time.time()
r = c.put(0, f"key{i}", f"value{i}", timeout=10)
end_time = time.time()
......@@ -39,13 +40,14 @@ def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger):
# Generate plot
plt.figure(figsize=(10, 6))
plt.plot(range(2, 17), reshard_times, marker='o')
plt.title('Reshard Times')
plt.xlabel('Number of Shards')
plt.ylabel('Time (seconds)')
plt.plot(range(2, 17), reshard_times, marker="o")
plt.title("Reshard Times")
plt.xlabel("Number of Shards")
plt.ylabel("Time (seconds)")
plt.grid(True)
plt.savefig(f"{dir}/reshard_times.png")
return True, "ok"
BENCHMARKS = [TestCase("benchmark_add_shard", benchmark_add_shard)]
\ No newline at end of file
BENCHMARKS = [TestCase("benchmark_add_shard", benchmark_add_shard)]
......@@ -7,6 +7,7 @@ from ..utils.util import Logger
import asyncio
class KVSTestFixture:
def __init__(self, conductor: ClusterConductor, dir, log: Logger, node_count: int):
conductor._parent = self
......@@ -42,9 +43,7 @@ class KVSTestFixture:
async def send_view(client: KVSClient, i: int):
r = await client.async_send_view(view)
assert (
r.status == 200
), f"expected 200 to ack view, got {r.status}"
assert r.status == 200, f"expected 200 to ack view, got {r.status}"
self.log(f"view sent to node {i}: {r.status} {r.text}")
tasks = [send_view(client, i) for i, client in enumerate(self.clients)]
......@@ -53,6 +52,8 @@ class KVSTestFixture:
def rebroadcast_view(self, new_view: Dict[str, List[Dict[str, Any]]]):
for i, client in enumerate(self.clients):
r = client.resend_last_view_with_ips_from_new_view(new_view)
if r is None:
return
assert r.status_code == 200, (
f"expected 200 to ack view, got {r.status_code}"
)
......
......@@ -102,7 +102,9 @@ def basic_shuffle_add_remove(conductor: ClusterConductor, dir, log: Logger):
res = r.json()["items"]
shard2_keys = res
assert len(shard1_keys) + len(shard2_keys) == 15, f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys)}"
assert len(shard1_keys) + len(shard2_keys) == 15, (
f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys)}"
)
# Remove shard 2. This loses keys.
conductor.remove_shard("shard2")
......@@ -230,7 +232,7 @@ def basic_shuffle_1(conductor: ClusterConductor, dir, log: Logger):
def basic_shuffle_2(conductor: ClusterConductor, dir, log: Logger):
with KVSTestFixture(conductor, dir, log, node_count=3) as fx:
c = KVSMultiClient(fx.clients, "client", log, persist_metadata=False)
c = KVSMultiClient(fx.clients, "client", log)
conductor.add_shard("shard1", conductor.get_nodes([0, 1]))
fx.broadcast_view(conductor.get_shard_view())
......@@ -249,32 +251,38 @@ def basic_shuffle_2(conductor: ClusterConductor, dir, log: Logger):
r = c.put(node_to_put, f"{base_key}{i}", f"{i}", timeout=10)
assert r.ok, f"expected ok for new key, got {r.status_code}"
node_to_put += 1
node_to_put = node_to_put % 3
node_to_put = node_to_put % 2
r = c.get_all(0, timeout=10) # should get all of shard 1's keys
r = c.get_all(0, timeout=30) # should get all of shard 1's keys
assert r.ok, f"expected ok for get, got {r.status_code}"
original_get_all = r.json()["items"]
assert len(original_get_all.keys()) == 300, (
f"original_get_all doesn't have 300 keys, instead has {len(original_get_all.keys())} keys"
)
conductor.add_shard("shard2", conductor.get_nodes([2]))
fx.broadcast_view(conductor.get_shard_view())
r = c.get_all(2, timeout=10) # should get all of shard 2's keys
r = c.get_all(2, timeout=30) # should get all of shard 2's keys
assert r.ok, f"expected ok for get, got {r.status_code}"
get_all_1 = r.json()["items"]
keys1 = get_all_1.keys()
keys1 = set(get_all_1.keys())
r = c.get_all(1, timeout=10) # should get all of shard 1's keys
r = c.get_all(1, timeout=30) # should get all of shard 1's keys
assert r.ok, f"expected ok for get, got {r.status_code}"
get_all_2 = r.json()["items"]
keys2 = get_all_2.keys()
keys2 = set(get_all_2.keys())
for key in keys1:
assert not (key in keys2)
assert key not in keys2, "key not in keys2"
for key in keys2:
assert not (key in keys1)
assert key not in keys1, "key not in keys2"
assert original_get_all.keys() == keys1 | keys2
assert len(original_get_all) == len(keys1) + len(keys2)
assert original_get_all.keys() == keys1.union(keys2), (
f"get all keys does not equal key1 joined with keys2. diff one way: \n{keys1.union(keys2).difference(original_get_all.keys())}\n diff other way:\n{set(original_get_all.keys()).difference(keys1.union(keys2))}"
)
assert len(original_get_all) == len(keys1) + len(keys2), "lengths do not match"
return True, "ok"
......
......@@ -152,6 +152,7 @@ class KVSClient:
if not isinstance(current_view, dict):
raise ValueError("view must be a dict")
if not hasattr(self, "last_view"):
return
raise LookupError("Must have sent at least one view before calling resend.")
flattened_current_view = {}
for shard_key in current_view:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment