Compare revisions

Changes are shown as if the source revision was being merged into the target revision.

Commits on Source (71)
@@ -2,6 +2,7 @@
import os
import sys
+import re
import json
import subprocess
import time
@@ -20,8 +21,15 @@ import logging
CONTAINER_IMAGE_ID = "kvstore-hw3-test"
TEST_GROUP_ID = "hw3"
class TestRunner:
-def __init__(self, project_dir: str, debug_output_dir: str):
+def __init__(
+self,
+project_dir: str,
+debug_output_dir: str,
+group_id=TEST_GROUP_ID,
+thread_id="0",
+):
self.project_dir = project_dir
self.debug_output_dir = debug_output_dir
# builder to build container image
@@ -30,7 +38,8 @@ class TestRunner:
)
# network manager to mess with container networking
self.conductor = ClusterConductor(
-group_id=TEST_GROUP_ID,
+group_id=group_id,
+thread_id=thread_id,
base_image=CONTAINER_IMAGE_ID,
external_port_base=9000,
log=global_logger(),
@@ -46,7 +55,7 @@ class TestRunner:
# aggressively clean up anything kvs-related
# NOTE: this disallows parallel run processes, so turn it off for that
-self.conductor.cleanup_hanging(group_only=False)
+self.conductor.cleanup_hanging(group_only=True)
def cleanup_environment(self) -> None:
log("\n-- cleanup_environment --")
@@ -80,12 +89,27 @@ from .tests.asgn3.availability.availability_basic import AVAILABILITY_TESTS
from .tests.asgn3.causal_consistency.causal_basic import CAUSAL_TESTS
from .tests.asgn3.eventual_consistency.convergence_basic import CONVERGENCE_TESTS
+# from .tests.asgn3.view_change.view_change_basic import VIEW_CHANGE_TESTS
+from .tests.proxy.basic_proxy import PROXY_TESTS
+from .tests.shuffle.basic_shuffle import SHUFFLE_TESTS
+from .tests.bench.benchmark import BENCHMARKS
+from .tests.stress.stress_tests import STRESS_TESTS
TEST_SET = []
+# tests from here...
TEST_SET.append(TestCase("hello_cluster", hello_cluster))
TEST_SET.extend(BASIC_TESTS)
TEST_SET.extend(AVAILABILITY_TESTS)
TEST_SET.extend(CAUSAL_TESTS)
TEST_SET.extend(CONVERGENCE_TESTS)
+# ... to here can be run in parallel.
+# Below here should be run synchronously:
+TEST_SET.extend(PROXY_TESTS)
+TEST_SET.extend(SHUFFLE_TESTS)
+TEST_SET.extend(STRESS_TESTS)
+# This one is less of a test and more a graphing program
+# (absolutely CANNOT be run in parallel thanks to matplotlib):
+TEST_SET.extend(BENCHMARKS)
# set to True to stop at the first failing test
FAIL_FAST = True
@@ -107,6 +131,11 @@ def main():
parser.add_argument(
"--num-threads", type=int, default=1, help="number of threads to run tests in"
)
+parser.add_argument(
+"--group-id",
+default=TEST_GROUP_ID,
+help="Group Id (prepended to docker containers & networks) (useful for running two versions of the test suite in parallel)",
+)
parser.add_argument(
"--port-offset", type=int, default=1000, help="port offset for each test"
)
@@ -114,14 +143,19 @@ def main():
args = parser.parse_args()
project_dir = os.getcwd()
-runner = TestRunner(project_dir=project_dir, debug_output_dir=DEBUG_OUTPUT_DIR)
+runner = TestRunner(
+project_dir=project_dir,
+debug_output_dir=DEBUG_OUTPUT_DIR,
+group_id=args.group_id,
+thread_id="0",
+)
runner.prepare_environment(build=args.build)
if args.filter is not None:
test_filter = args.filter
log(f"filtering tests by: {test_filter}")
global TEST_SET
-TEST_SET = [t for t in TEST_SET if test_filter in t.name]
+TEST_SET = [t for t in TEST_SET if re.compile(test_filter).match(t.name)]
if args.run_all:
global FAIL_FAST
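Note that the filter change above is behavioral, not just cosmetic: the old code kept any test whose name contained the filter string, while the new code treats --filter as a regular expression matched from the start of the test name (re.match). A small illustration of the difference, using test names that appear in this suite:

import re
names = ["basic_kv_1", "basic_shuffle_2", "benchmark_add_shard"]
old = [n for n in names if "shuffle" in n]                       # ['basic_shuffle_2']
new = [n for n in names if re.compile("shuffle").match(n)]       # [] -- no name *starts* with "shuffle"
new_ok = [n for n in names if re.compile(".*shuffle").match(n)]  # ['basic_shuffle_2']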
@@ -130,7 +164,7 @@ def main():
log("\n== RUNNING TESTS ==")
run_tests = []
-def run_test(test: TestCase, gid: str, port_offset: int):
+def run_test(test: TestCase, gid: str, thread_id: str, port_offset: int):
log(f"\n== TEST: [{test.name}] ==\n")
test_set_name = test.name.lower().split("_")[0]
test_dir = create_test_dir(DEBUG_OUTPUT_DIR, test_set_name, test.name)
@@ -142,6 +176,7 @@ def main():
logger = Logger(files=(log_file, sys.stderr))
conductor = ClusterConductor(
group_id=gid,
+thread_id=f"{thread_id}",
base_image=CONTAINER_IMAGE_ID,
external_port_base=9000 + port_offset,
log=logger,
@@ -160,7 +195,7 @@ def main():
if args.num_threads == 1:
print("Running tests sequentially")
for test in TEST_SET:
-if not run_test(test, gid="0", port_offset=0):
+if not run_test(test, gid=args.group_id, thread_id="0", port_offset=0):
if not args.run_all:
print("--run-all not set, stopping at first failure")
break
@@ -169,7 +204,10 @@ def main():
pool = ThreadPool(processes=args.num_threads)
pool.map(
lambda a: run_test(
-a[1], gid=f"{a[0]}", port_offset=a[0] * args.port_offset
+a[1],
+gid=args.group_id,
+thread_id=f"{a[0]}",
+port_offset=a[0] * args.port_offset,
),
enumerate(TEST_SET),
)
...
[project]
-name = "cse138-asgn3-tests"
+name = "cse138-asgn4-tests"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"requests>=2.32.3",
+"matplotlib>=3.4.3",
+"aiohttp>=3.8.1",
]
[project.scripts]
-cse138-asgn3-tests = "cse138_asgn3_tests.__main__:main"
+cse138-asgn4-tests = "cse138_asgn4_tests.__main__:main"
@@ -451,7 +451,7 @@ def causal_basic_tiebreak(conductor, dir, log):
b = david.get(1, "x")
assert b.ok, f"expected ok for get, got {b.status_code}"
-assert a.json()["value"] == b.json()["value"], f"expected {a} == {b}"
+assert a.json()["value"] == b.json()["value"], f"expected {a.json()['value']} == {b.json()['value']}"
if a.json()["value"] == "1":
assert alice_hang, f"expected alice to hang"
...
@@ -66,5 +66,70 @@ def basic_kv_1(conductor: ClusterConductor, dir, log: Logger):
return True, 0
def basic_kv_verify_proxy(conductor: ClusterConductor, dir, log: Logger):
with KVSTestFixture(conductor, dir, log, node_count=4) as fx:
c1 = KVSMultiClient(fx.clients, "c1", log)
c2 = KVSMultiClient(fx.clients, "c2", log)
c3 = KVSMultiClient(fx.clients, "c3", log)
c4 = KVSMultiClient(fx.clients, "c4", log)
conductor.add_shard("shard1", conductor.get_nodes([0, 1]))
conductor.add_shard("shard2", conductor.get_nodes([2, 3]))
fx.broadcast_view(conductor.get_shard_view())
keys = ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k"]
values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
for i in range(len(keys)):
key = keys[i]
value = str(values[i])
r = c1.put(0, key, value)
assert r.ok, f"expected ok for new key, got {r.status_code}"
conductor.create_partition([0,1], "p0")
conductor.create_partition([2,3], "p1")
r = c2.get_all(0)
assert r.ok, f"expected ok for new key, got {r.status_code}"
shard1_keys = r.json()["items"]
r = c3.get_all(2)
assert r.ok, f"expected ok for new key, got {r.status_code}"
shard2_keys = r.json()["items"]
print(shard1_keys)
print(shard2_keys)
assert ((len(shard1_keys) > 0) and (len(shard2_keys) > 0)), "One of the shards has no keys, this is extremely unlikely (1/2^11) and probably means something is wrong"
rk1 = list(shard1_keys.keys())[0]
rk2 = list(shard2_keys.keys())[0]
r = c4.put(0, rk2, "This should fail")
assert r.status_code == 408, f"expected 408 for new key, got {r.status_code}"
r = c4.put(2, rk1, "This should also fail")
assert r.status_code == 408, f"expected 408 for new key, got {r.status_code}"
conductor.create_partition([0, 1, 2, 3], "base")
r = c4.put(0, rk2, "This should work")
assert r.ok, f"expected ok for new key, got {r.status_code}"
r = c4.put(2, rk1, "This should also work")
assert r.ok, f"expected ok for new key, got {r.status_code}"
r = c2.get_all(0)
assert r.ok, f"expected ok for new key, got {r.status_code}"
shard1_keys = r.json()["items"]
r = c3.get_all(2)
assert r.ok, f"expected ok for new key, got {r.status_code}"
shard2_keys = r.json()["items"]
print(shard1_keys)
print(shard2_keys)
assert (len(shard1_keys) > 0 and len(shard2_keys) > 0), "One of the shards has no keys, this is extremely unlikely (1/2^11) and probably means something is wrong"
return True, 0
-BASIC_TESTS = [TestCase("basic_kv_1", basic_kv_1), TestCase("basic_kv_view_accept", basic_kv_view_accept)]
+BASIC_TESTS = [TestCase("basic_kv_1", basic_kv_1), TestCase("basic_kv_view_accept", basic_kv_view_accept), TestCase("basic_kv_verify_proxy", basic_kv_verify_proxy)]
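A quick sanity check of the "(1/2^11)" comment in basic_kv_verify_proxy above, assuming each of the 11 keys lands on one of the two shards independently and uniformly at random:

p_specific_shard_empty = (1 / 2) ** 11             # ~0.00049
p_either_shard_empty = 2 * p_specific_shard_empty  # ~0.00098, i.e. about 1 in 1024

so the event the assertion guards against really is on the order of one in a thousand, which is what the test relies on.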
from ...utils.containers import ClusterConductor
from ...utils.testcase import TestCase
from ...utils.util import Logger
from ..helper import KVSMultiClient, KVSTestFixture
from ...utils.kvs_api import DEFAULT_TIMEOUT
import time
import asyncio
import matplotlib.pyplot as plt
NUM_SHARDS = 16
NUM_KEYS = 5000
NUM_NODES = 8
def benchmark_add_shard(conductor: ClusterConductor, dir, log: Logger):
with KVSTestFixture(conductor, dir, log, node_count=NUM_SHARDS) as fx:
conductor.add_shard("shard1", conductor.get_nodes([0]))
fx.broadcast_view(conductor.get_shard_view())
log(f"putting {NUM_KEYS} keys\n")
put_times = []
for i in range(NUM_KEYS):
c = KVSMultiClient(fx.clients, "client", log)
start_time = time.time()
r = c.put(0, f"key{i}", f"value{i}", timeout=10)
end_time = time.time()
assert r.ok, f"expected ok for new key, got {r.status_code}"
put_times.append(end_time - start_time)
log("Starting benchmark\n")
reshard_times = []
for shard in range(2, NUM_SHARDS+1):
start_time = time.time()
log(f"adding shard{shard}\n")
conductor.add_shard(f"shard{shard}", conductor.get_nodes([shard - 1]))
asyncio.run(fx.parallel_broadcast_view(conductor.get_shard_view()))
end_time = time.time()
reshard_times.append(end_time - start_time)
log(f"reshard time with {shard} shards: {reshard_times[-1]}\n")
log("Average put time: ", sum(put_times) / len(put_times))
for shard, time_taken in enumerate(reshard_times, start=2):
log(f"shard count: {shard}, reshard time: {time_taken}")
# Generate plot
plt.figure(figsize=(NUM_SHARDS, 10))
plt.plot(range(2, NUM_SHARDS+1), reshard_times, marker='o')
plt.title('Reshard Times')
plt.xlabel('Number of Shards')
plt.ylabel('Time (seconds)')
plt.grid(True)
plt.savefig(f"{dir}/reshard_times.png")
return True, "ok"
def benchmark_add_shard_two_nodes(conductor: ClusterConductor, dir, log: Logger):
with KVSTestFixture(conductor, dir, log, node_count=NUM_SHARDS*NUM_NODES) as fx:
conductor.add_shard("shard1", conductor.get_nodes([0, 1]))
fx.broadcast_view(conductor.get_shard_view())
log(f"putting {NUM_KEYS} keys\n")
put_times = []
for i in range(NUM_KEYS):
c = KVSMultiClient(fx.clients, "client", log)
start_time = time.time()
r = c.put(i%NUM_NODES, f"key{i}", f"value{i}", timeout=10)
end_time = time.time()
assert r.ok, f"expected ok for new key, got {r.status_code}"
put_times.append(end_time - start_time)
log("Starting benchmark\n")
reshard_times = []
for shard in range(2, NUM_SHARDS+1):
start_time = time.time()
log(f"adding shard{shard}\n")
conductor.add_shard(f"shard{shard}", conductor.get_nodes([2*(shard - 1), 2*(shard - 1) + 1]))
asyncio.run(fx.parallel_broadcast_view(conductor.get_shard_view()))
end_time = time.time()
reshard_times.append(end_time - start_time)
log(f"reshard time with {shard} shards: {reshard_times[-1]}\n")
log("Average put time: ", sum(put_times) / len(put_times))
for shard, time_taken in enumerate(reshard_times, start=2):
log(f"shard count: {shard}, reshard time: {time_taken}")
# Generate plot
plt.figure(figsize=(NUM_SHARDS, 10))
plt.plot(range(2, NUM_SHARDS+1), reshard_times, marker='o')
plt.title('Reshard Times')
plt.xlabel('Number of Shards')
plt.ylabel('Time (seconds)')
plt.grid(True)
plt.savefig(f"{dir}/reshard_times.png")
return True, "ok"
BENCHMARKS = [TestCase("benchmark_add_shard", benchmark_add_shard),
TestCase("benchmark_add_shard_two_nodes", benchmark_add_shard_two_nodes)]
@@ -5,6 +5,8 @@ from ..utils.containers import ClusterConductor
from ..utils.kvs_api import KVSClient
from ..utils.util import Logger
+import asyncio
class KVSTestFixture:
def __init__(self, conductor: ClusterConductor, dir, log: Logger, node_count: int):
@@ -36,6 +38,28 @@ class KVSTestFixture:
)
self.log(f"view sent to node {i}: {r.status_code} {r.text}")
async def parallel_broadcast_view(self, view: Dict[str, List[Dict[str, Any]]]):
self.log(f"\n> SEND VIEW: {view}")
async def send_view(client: KVSClient, i: int):
r = await client.async_send_view(view)
assert r.status == 200, f"expected 200 to ack view, got {r.status}"
self.log(f"view sent to node {i}: {r.status} {r.text}")
tasks = [send_view(client, i) for i, client in enumerate(self.clients)]
await asyncio.gather(*tasks)
def rebroadcast_view(self, new_view: Dict[str, List[Dict[str, Any]]]):
for i, client in enumerate(self.clients):
self.log(f"rebroadcasting view for node {i}")
r = client.resend_last_view_with_ips_from_new_view(new_view, self.log)
if r is None:
return
assert r.status_code == 200, (
f"expected 200 to ack view, got {r.status_code}"
)
self.log(f"view resent to node {i}: {r.status_code} {r.text}")
def send_view(self, node_id: int, view: Dict[str, List[Dict[str, Any]]]):
r = self.clients[node_id].send_view(view)
assert r.status_code == 200, f"expected 200 to ack view, got {r.status_code}"
@@ -55,16 +79,28 @@ class KVSTestFixture:
class KVSMultiClient:
-def __init__(self, clients: List[KVSClient], name: str, log: Logger):
+def __init__(
+self, clients: List[KVSClient], name: str, log: Logger, persist_metadata=True
+):
self.clients = clients
-self.metadata = None
+self._metadata = None
self.name = name
self.req = 0
self.log = log
+self.persist_metadata = persist_metadata
# internal model of kvs
self._kvs_model = {}
@property
def metadata(self):
"""Causal metadata carried between requests; discarded when persist_metadata is False."""
return self._metadata
@metadata.setter
def metadata(self, value):
self._metadata = value if self.persist_metadata else None
def reset_model(self):
self._kvs_model = {}
self.metadata = None
...
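The persist_metadata flag introduced above changes how KVSMultiClient threads causal metadata through successive requests: with the default (True) the metadata returned by one response is attached to the next request, while with False the setter silently discards it, so every request behaves like it came from a fresh client. A minimal sketch of the intended difference (client construction mirrors the shuffle tests below; the exact request plumbing is assumed):

causal = KVSMultiClient(fx.clients, "causal", log)                        # persist_metadata=True (default)
fresh = KVSMultiClient(fx.clients, "fresh", log, persist_metadata=False)

causal.put(0, "x", "1")   # response metadata is remembered by the property setter
causal.get(1, "x")        # next request carries the causal dependency on the put
fresh.put(0, "x", "1")    # response metadata is dropped (the setter stores None)
fresh.get(1, "x")         # carries no causal history between requests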
from ...utils.containers import ClusterConductor
from ...utils.testcase import TestCase
from ...utils.util import Logger
from ..helper import KVSMultiClient, KVSTestFixture
from ...utils.kvs_api import DEFAULT_TIMEOUT
def basic_proxy_one_client(conductor: ClusterConductor, dir, log: Logger):
with KVSTestFixture(conductor, dir, log, node_count=4) as fx:
c = KVSMultiClient(fx.clients, "client", log)
conductor.add_shard("shard1", conductor.get_nodes([0, 1]))
conductor.add_shard("shard2", conductor.get_nodes([2, 3]))
fx.broadcast_view(conductor.get_shard_view())
# test 1
# put 300 keys (at least one proxy expected here)
# get_all() on one shard
# then ask the other shard for that key (proxy MUST happen here)
node_to_put = 0
base_key = "key"
for i in range(0, 300):
r = c.put(node_to_put, f"{base_key}{i}", f"{i}", timeout=10)
assert r.ok, f"expected ok for new key, got {r.status_code}"
node_to_put += 1
node_to_put = node_to_put % 4
r = c.get_all(0, timeout=20) # should get all of shard 1's keys
assert r.ok, f"expected ok for get, got {r.status_code}"
items = r.json()["items"]
keys = items.keys()
for key in keys:
r = c.get(2, key, timeout=30)
assert r.ok, f"expected ok for get, got {r.status_code}"
assert r.json()["value"] == items[key], f"wrong value returned: {r.json()}"
return True, "ok"
def basic_proxy_many_clients(conductor: ClusterConductor, dir, log: Logger):
with KVSTestFixture(conductor, dir, log, node_count=7) as fx:
c = KVSMultiClient(fx.clients, "client", log)
conductor.add_shard("shard1", conductor.get_nodes([0, 1]))
conductor.add_shard("shard2", conductor.get_nodes([2, 3]))
conductor.add_shard("shard3", conductor.get_nodes([4, 5, 6]))
fx.broadcast_view(conductor.get_shard_view())
# test 1
# put 10000 keys, each from a fresh client (at least one proxy expected here)
# get_all() on one shard
# then ask the other shard for that key (proxy MUST happen here)
node_to_put = 0
base_key = "key"
for i in range(0, 10000):
c1 = KVSMultiClient(fx.clients, "client", log)
r = c1.put(node_to_put, f"{base_key}{i}", f"{i}", timeout=10)
assert r.ok, f"expected ok for new key, got {r.status_code}"
node_to_put += 1
node_to_put = node_to_put % 7
r = c.get_all(0, timeout=20) # should get all of shard 1's keys
assert r.ok, f"expected ok for get, got {r.status_code}"
items = r.json()["items"]
keys = items.keys()
for key in keys:
r = c.get(2, key, timeout=30)
assert r.ok, f"expected ok for get, got {r.status_code}"
assert r.json()["value"] == items[key], f"wrong value returned: {r.json()}"
return True, "ok"
def basic_proxy_partitioned_shards(
conductor: ClusterConductor, dir, log: Logger, timeout=5 * DEFAULT_TIMEOUT
):
with KVSTestFixture(conductor, dir, log, node_count=4) as fx:
c = KVSMultiClient(fx.clients, "client", log)
conductor.add_shard("shard1", conductor.get_nodes([0, 1]))
conductor.add_shard("shard2", conductor.get_nodes([2, 3]))
fx.broadcast_view(conductor.get_shard_view())
conductor.create_partition([2, 3], "secondshard")
helper(c, timeout=timeout)
return True, "ok"
def helper(c: KVSMultiClient, timeout=5 * DEFAULT_TIMEOUT):
###
# test 2
# partition the shards
# put a bunch of keys
# we MUST probabilistically encounter some hanging here
# (see the back-of-the-envelope sketch after this helper).
# Use a timeout: if nothing hangs after ~50 keys, the implementation is just wrong.
node_to_put = 0
base_key = "key"
for i in range(0, 50):
r = c.put(node_to_put, f"{base_key}{i}", f"{i}", timeout=10)
assert r.ok or r.status_code == 408, (
f"expected ok for new key, got {r.status_code}"
)
node_to_put += 1
node_to_put = node_to_put % 4
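To make the "probabilistically encounter some hanging" comment concrete, assume keys are distributed uniformly across the two shards and the partitioned shards cannot reach each other: each put has roughly a 1/2 chance of arriving at a node that must proxy to the unreachable shard, so the chance that none of the 50 puts needs a proxy (and therefore nothing ever hangs or times out) is vanishingly small:

p_no_proxy_single_put = 1 / 2
p_never_hangs_in_50_puts = p_no_proxy_single_put ** 50   # ~8.9e-16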
PROXY_TESTS = [
TestCase("basic_proxy_one_client", basic_proxy_one_client),
TestCase("basic_proxy_many_clients", basic_proxy_many_clients),
TestCase("basic_proxy_partitioned_shards", basic_proxy_partitioned_shards),
]
from ...utils.containers import ClusterConductor
from ...utils.testcase import TestCase
from ...utils.util import Logger
from ..helper import KVSMultiClient, KVSTestFixture
from ...utils.kvs_api import DEFAULT_TIMEOUT
def basic_shuffle_add_remove(conductor: ClusterConductor, dir, log: Logger):
with KVSTestFixture(conductor, dir, log, node_count=3) as fx:
c = KVSMultiClient(fx.clients, "client", log, persist_metadata=False)
conductor.add_shard("shard1", conductor.get_nodes([0]))
conductor.add_shard("shard2", conductor.get_nodes([1]))
fx.broadcast_view(conductor.get_shard_view())
node_to_put = 0
base_key = "key"
# Put 15 keys
for i in range(15):
log(f"Putting key {i}\n")
c.reset_model()
r = c.put(node_to_put, f"{base_key}{i}", f"{i}", timeout=10)
assert r.ok, f"expected ok for new key, got {r.status_code}"
node_to_put += 1
node_to_put = node_to_put % 2
c.reset_model()
# Get all keys
r = c.get_all(0, timeout=10)
c.reset_model()
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard1_keys = res
c.reset_model()
r = c.get_all(1, timeout=10)
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard2_keys = res
c.reset_model()
log(f"Shard 1 keys: {shard1_keys}\n")
log(f"Shard 2 keys: {shard2_keys}\n")
# Total number of keys should match the number of keys put
assert len(shard1_keys) + len(shard2_keys) == 15, (
f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys)}"
)
# Add a 3rd shard, causing a shuffle. There should still be 15 keys at the end.
log("Adding 3rd shard\n")
conductor.add_shard("shard3", conductor.get_nodes([2]))
fx.broadcast_view(conductor.get_shard_view())
# Get the keys on shard 1
r = c.get_all(0, timeout=10)
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard1_keys = res
c.reset_model()
log(f"Shard 1 keys: {shard1_keys}\n")
# get the keys on shard 2
r = c.get_all(1, timeout=10)
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard2_keys = res
c.reset_model()
log(f"Shard 2 keys: {shard2_keys}\n")
# get the keys on shard 3
r = c.get_all(2, timeout=10)
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard3_keys = res
c.reset_model()
log(f"Shard 3 keys: {shard3_keys}\n")
assert len(shard1_keys) + len(shard2_keys) + len(shard3_keys) == 15, (
f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys) + len(shard3_keys)}"
)
# Remove shard 3, causing a shuffle. Its keys should be redistributed to the remaining shards, so all 15 should still exist.
conductor.remove_shard("shard3")
fx.broadcast_view(conductor.get_shard_view())
r = c.get_all(0, timeout=10) # should get all of shard 1's keys
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard1_keys = res
c.reset_model()
r = c.get_all(1, timeout=10) # should get all of shard 2's keys
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard2_keys = res
assert len(shard1_keys) + len(shard2_keys) == 15, (
f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys)}"
)
# Remove shard 2. This loses keys.
conductor.remove_shard("shard2")
fx.broadcast_view(conductor.get_shard_view())
r = c.get_all(0, timeout=10)
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard1_keys_after_delete = res
c.reset_model()
assert len(shard1_keys_after_delete) == 15, (
f"expected 15 keys, got {len(shard1_keys_after_delete)}"
)
return True, "ok"
def basic_shuffle_1(conductor: ClusterConductor, dir, log: Logger):
with KVSTestFixture(conductor, dir, log, node_count=3) as fx:
c = KVSMultiClient(fx.clients, "client", log, persist_metadata=False)
conductor.add_shard("shard1", conductor.get_nodes([0]))
conductor.add_shard("shard2", conductor.get_nodes([1]))
fx.broadcast_view(conductor.get_shard_view())
node_to_put = 0
base_key = "key"
# Put 15 keys
for i in range(15):
log(f"Putting key {i}\n")
r = c.put(node_to_put, f"{base_key}{i}", f"{i}", timeout=10)
assert r.ok, f"expected ok for new key, got {r.status_code}"
node_to_put += 1
node_to_put = node_to_put % 2
# Get all keys
r = c.get_all(0, timeout=10)
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard1_keys = res
r = c.get_all(1, timeout=10)
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard2_keys = res
log(f"Shard 1 keys: {shard1_keys}\n")
log(f"Shard 2 keys: {shard2_keys}\n")
# Total number of keys should match the number of keys put
assert len(shard1_keys) + len(shard2_keys) == 15, (
f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys)}"
)
# Add a 3rd shard, causing a shuffle. There should still be 15 keys at the end.
log("Adding 3rd shard\n")
conductor.add_shard("shard3", conductor.get_nodes([2]))
fx.broadcast_view(conductor.get_shard_view())
# Get the keys on shard 1
r = c.get_all(0, timeout=10)
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard1_keys = res
log(f"Shard 1 keys: {shard1_keys}\n")
# get the keys on shard 2
r = c.get_all(1, timeout=10)
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard2_keys = res
log(f"Shard 2 keys: {shard2_keys}\n")
# get the keys on shard 3
r = c.get_all(2, timeout=10)
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard3_keys = res
log(f"Shard 3 keys: {shard3_keys}\n")
assert len(shard1_keys) + len(shard2_keys) + len(shard3_keys) == 15, (
f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys) + len(shard3_keys)}"
)
# Remove shard 3, causing a shuffle. Move Node 2 to shard 1 so the keys should still exist, and be shuffled
conductor.remove_shard("shard3")
conductor.add_node_to_shard("shard1", conductor.get_node(2))
fx.broadcast_view(conductor.get_shard_view())
r = c.get_all(0, timeout=10) # should get all of shard 1's keys
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard1_keys = res
r = c.get_all(1, timeout=10) # should get all of shard 2's keys
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard2_keys = res
assert len(shard1_keys) + len(shard2_keys) == 15, (
f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys)}"
)
# Remove shard 2. This loses keys.
conductor.remove_shard("shard2")
conductor.remove_node_from_shard("shard1", conductor.get_node(2))
fx.broadcast_view(conductor.get_shard_view())
r = c.get_all(0, timeout=10)
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard1_keys_after_delete = res
assert len(shard1_keys_after_delete) == 15, (
f"expected 15 keys, got {len(shard1_keys_after_delete)}"
)
return True, "ok"
def basic_shuffle_2(conductor: ClusterConductor, dir, log: Logger):
with KVSTestFixture(conductor, dir, log, node_count=3) as fx:
c = KVSMultiClient(fx.clients, "client", log)
conductor.add_shard("shard1", conductor.get_nodes([0, 1]))
fx.broadcast_view(conductor.get_shard_view())
## basic shuffle 2
# view= 1 shard with 2 nodes
# put 50 keys
# get_all keys from shard 1
# add shard with 1 node
# get_all keys from shard 2
# get_all keys from shard 1
# check both returned sets are disjoint and that their union makes the original get_all results
node_to_put = 0
base_key = "key"
for i in range(0, 300):
r = c.put(node_to_put, f"{base_key}{i}", f"{i}", timeout=10)
assert r.ok, f"expected ok for new key, got {r.status_code}"
node_to_put += 1
node_to_put = node_to_put % 2
r = c.get_all(0, timeout=30) # should get all of shard 1's keys
assert r.ok, f"expected ok for get, got {r.status_code}"
original_get_all = r.json()["items"]
assert len(original_get_all.keys()) == 300, (
f"original_get_all doesn't have 300 keys, instead has {len(original_get_all.keys())} keys"
)
conductor.add_shard("shard2", conductor.get_nodes([2]))
fx.broadcast_view(conductor.get_shard_view())
r = c.get_all(2, timeout=30) # should get all of shard 2's keys
assert r.ok, f"expected ok for get, got {r.status_code}"
get_all_1 = r.json()["items"]
keys1 = set(get_all_1.keys())
r = c.get_all(1, timeout=30) # should get all of shard 1's keys
assert r.ok, f"expected ok for get, got {r.status_code}"
get_all_2 = r.json()["items"]
keys2 = set(get_all_2.keys())
for key in keys1:
assert key not in keys2, f"key {key} is present in both shards"
for key in keys2:
assert key not in keys1, f"key {key} is present in both shards"
assert original_get_all.keys() == keys1.union(keys2), (
f"get all keys does not equal key1 joined with keys2. diff one way: \n{keys1.union(keys2).difference(original_get_all.keys())}\n diff other way:\n{set(original_get_all.keys()).difference(keys1.union(keys2))}"
)
assert len(original_get_all) == len(keys1) + len(keys2), "lengths do not match"
return True, "ok"
def basic_shuffle_3(conductor: ClusterConductor, dir, log: Logger):
with KVSTestFixture(conductor, dir, log, node_count=3) as fx:
c = KVSMultiClient(fx.clients, "client", log, persist_metadata=False)
conductor.add_shard("shard1", conductor.get_nodes([0]))
conductor.add_shard("shard2", conductor.get_nodes([1]))
fx.broadcast_view(conductor.get_shard_view())
node_to_put = 0
base_key = "key"
# Put 15 keys
for i in range(15):
log(f"Putting key {i}\n")
r = c.put(node_to_put, f"{base_key}{i}", f"{i}", timeout=10)
assert r.ok, f"expected ok for new key, got {r.status_code}"
node_to_put += 1
node_to_put = node_to_put % 2
# Get all keys
r = c.get_all(0, timeout=10)
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard1_keys = res
r = c.get_all(1, timeout=10)
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard2_keys = res
log(f"Shard 1 keys: {shard1_keys}\n")
log(f"Shard 2 keys: {shard2_keys}\n")
# Total number of keys should match the number of keys put
assert len(shard1_keys) + len(shard2_keys) == 15, (
f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys)}"
)
# Add a 3rd shard, causing a shuffle. There should still be 15 keys at the end.
log("Adding 3rd shard\n")
conductor.add_shard("shard3", conductor.get_nodes([2]))
fx.broadcast_view(conductor.get_shard_view())
# Get the keys on shard 1
r = c.get_all(0, timeout=10)
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard1_keys = res
log(f"Shard 1 keys: {shard1_keys}\n")
# get the keys on shard 2
r = c.get_all(1, timeout=10)
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard2_keys = res
log(f"Shard 2 keys: {shard2_keys}\n")
# get the keys on shard 3
r = c.get_all(2, timeout=10)
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard3_keys = res
log(f"Shard 3 keys: {shard3_keys}\n")
assert len(shard1_keys) + len(shard2_keys) + len(shard3_keys) == 15, (
f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys) + len(shard3_keys)}"
)
# Remove shard 3, causing a shuffle. Move Node 2 to shard 1 so the keys should still exist, and be shuffled
conductor.remove_shard("shard3")
conductor.add_node_to_shard("shard1", conductor.get_node(2))
fx.broadcast_view(conductor.get_shard_view())
r = c.get_all(0, timeout=10) # should get all of shard 1's keys
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard1_keys = res
r = c.get_all(1, timeout=10) # should get all of shard 2's keys
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard2_keys = res
assert len(shard1_keys) + len(shard2_keys) == 15, (
f"expected 15 keys, got {len(shard1_keys) + len(shard2_keys)}"
)
# Remove shard 2. This loses keys.
conductor.remove_shard("shard2")
fx.broadcast_view(conductor.get_shard_view())
r = c.get_all(0, timeout=10)
assert r.ok, f"expected ok for get, got {r.status_code}"
res = r.json()["items"]
shard1_keys_after_delete = res
assert len(shard1_keys_after_delete) == 15, (
f"expected 15 keys, got {len(shard1_keys_after_delete)}"
)
return True, "ok"
SHUFFLE_TESTS = [
TestCase("basic_shuffle_add_remove", basic_shuffle_add_remove),
TestCase("basic_shuffle_1", basic_shuffle_1),
TestCase("basic_shuffle_2", basic_shuffle_2),
TestCase("basic_shuffle_3", basic_shuffle_3),
]
import time
import random
import string
from ...utils.containers import ClusterConductor
from ...utils.testcase import TestCase
from ...utils.util import Logger
from ..helper import KVSMultiClient, KVSTestFixture
from ...utils.kvs_api import DEFAULT_TIMEOUT
def make_random_key(prefix="key", length=8):
chars = string.ascii_lowercase + string.digits
random_part = "".join(random.choice(chars) for _ in range(length))
return f"{prefix}{random_part}"
def shard_key_distribution(conductor: ClusterConductor, dir, log: Logger):
NUM_KEYS = 5000
NODE_COUNT = 8
SHARD_COUNT = 4
with KVSTestFixture(conductor, dir, log, node_count=NODE_COUNT) as fx:
c = KVSMultiClient(fx.clients, "client", log)
for i in range(SHARD_COUNT):
conductor.add_shard(f"shard{i}", conductor.get_nodes([i * 2, i * 2 + 1]))
fx.broadcast_view(conductor.get_shard_view())
log(f"\n> ADDING {NUM_KEYS} RANDOM KEYS")
keys = []
for i in range(NUM_KEYS):
key = make_random_key()
value = key
node = random.randint(0, NODE_COUNT - 1)
c.metadata = None
r = c.put(node, key, value)
assert r.ok, f"expected ok for {key}, got {r.status_code}"
keys.append(key)
if (i + 1) % 50 == 0:
log(f"Added {i + 1} keys")
log("\n> CHECKING KEY DISTRIBUTION")
shard_key_counts = {}
for shard_idx in range(SHARD_COUNT):
node_id = shard_idx * 2
c.metadata = None
r = c.get_all(node_id)
assert r.ok, (
f"expected ok for get_all from shard {shard_idx}, got {r.status_code}"
)
shard_keys = r.json().get("items", {})
shard_key_counts[shard_idx] = len(shard_keys)
log(f"Shard {shard_idx} has {len(shard_keys)} keys")
# randomly sample keys to verify
for key in random.sample(list(shard_keys.keys()), min(10, len(shard_keys))):
c.metadata = None
r = c.get(node_id, key)
assert r.ok, f"expected ok for get {key}, got {r.status_code}"
assert r.json()["value"] == shard_keys[key], (
f"wrong value returned for {key}"
)
if shard_key_counts:
avg_keys = sum(shard_key_counts.values()) / len(shard_key_counts)
min_keys = min(shard_key_counts.values())
max_keys = max(shard_key_counts.values())
deviation = max_keys - min_keys
deviation_percent = (deviation / avg_keys) * 100 if avg_keys > 0 else 0
log(f"Key distribution: min={min_keys}, max={max_keys}, avg={avg_keys:.1f}")
log(f"Max deviation: {deviation} keys ({deviation_percent:.1f}%)")
is_good_distribution = deviation_percent < 40
return (
is_good_distribution,
f"Key distribution test completed. Deviation: {deviation_percent:.1f}%",
)
return False, "Could not collect key distribution data"
def shard_addition_performance(conductor: ClusterConductor, dir, log: Logger):
NUM_KEYS = 5000
NODE_COUNT = 22
INITIAL_SHARDS = 10
with KVSTestFixture(conductor, dir, log, node_count=NODE_COUNT) as fx:
c = KVSMultiClient(fx.clients, "client", log)
for i in range(INITIAL_SHARDS):
conductor.add_shard(f"shard{i}", conductor.get_nodes([i * 2, i * 2 + 1]))
fx.broadcast_view(conductor.get_shard_view())
log(f"\n> ADDING {NUM_KEYS} RANDOM KEYS")
key_values = {}
for i in range(NUM_KEYS):
key = make_random_key()
value = key
# node = random.randint(0, NODE_COUNT-3)
node = random.randint(0, NODE_COUNT - 1)
c.metadata = None
r = c.put(node, key, value)
assert r.ok, f"expected ok for {key}, got {r.status_code}"
key_values[key] = value
if (i + 1) % 50 == 0:
log(f"Added {i + 1} keys")
log("\n> CHECKING INITIAL KEY DISTRIBUTION")
initial_distribution = {}
for i in range(INITIAL_SHARDS):
node_id = i * 2
c.metadata = None
r = c.get_all(node_id)
assert r.ok, f"expected ok for get_all from shard {i}, got {r.status_code}"
shard_keys = r.json().get("items", {})
initial_distribution[i] = set(shard_keys.keys())
log(f"Shard {i} initially has {len(shard_keys)} keys")
log("\n> ADDING NEW SHARD")
conductor.shards = dict(conductor.shards)
conductor.add_shard("newShard", conductor.get_nodes([20, 21]))
fx.broadcast_view(conductor.get_shard_view())
log("Waiting for resharding to complete...")
time.sleep(DEFAULT_TIMEOUT)
log("\n> CHECKING KEY DISTRIBUTION AFTER RESHARDING")
final_distribution = {}
total_keys_after = 0
for i in range(INITIAL_SHARDS):
node_id = i * 2
c.metadata = None
r = c.get_all(node_id)
assert r.ok, f"expected ok for get_all from shard {i}, got {r.status_code}"
shard_keys = r.json().get("items", {})
final_distribution[i] = set(shard_keys.keys())
total_keys_after += len(shard_keys)
log(f"Shard {i} now has {len(shard_keys)} keys")
c.metadata = None
r = c.get_all(20)
assert r.ok, f"expected ok for get_all from new shard, got {r.status_code}"
new_shard_keys = r.json().get("items", {})
final_distribution["new"] = set(new_shard_keys.keys())
total_keys_after += len(new_shard_keys)
log(f"New shard has {len(new_shard_keys)} keys")
keys_moved = 0
for shard_idx, initial_keys in initial_distribution.items():
final_keys = final_distribution[shard_idx]
moved_from_this_shard = len(initial_keys - final_keys)
keys_moved += moved_from_this_shard
log(f"Shard {shard_idx} lost {moved_from_this_shard} keys")
# should move roughly 1/(N+1) of the keys (see the worked numbers after this function)
expected_keys_moved = NUM_KEYS / (INITIAL_SHARDS + 1)
actual_moved_ratio = keys_moved / NUM_KEYS
expected_moved_ratio = 1 / (INITIAL_SHARDS + 1)
log(
f"Expected keys to move: ~{expected_keys_moved:.1f} ({expected_moved_ratio * 100:.1f}%)"
)
log(f"Actual keys moved: {keys_moved} ({actual_moved_ratio * 100:.1f}%)")
is_efficient = actual_moved_ratio <= expected_moved_ratio * 1.5
return (
is_efficient,
f"Shard addition test completed. Keys moved: {keys_moved}/{NUM_KEYS} ({actual_moved_ratio * 100:.1f}%), "
+ f"Efficiency: {'good' if is_efficient else 'poor'}",
)
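The "roughly 1/(N+1) of the keys" expectation comes from consistent-hashing-style resharding: when an (N+1)-th shard is added, only the keys that now map to the new shard should move. With the numbers used above, the threshold works out as follows (a hedged restatement of the efficiency check, not additional test logic):

NUM_KEYS = 5000
INITIAL_SHARDS = 10
expected_moved = NUM_KEYS / (INITIAL_SHARDS + 1)   # ~454.5 keys, i.e. ~9.1% of all keys
allowed_moved = 1.5 * expected_moved               # the test passes if at most ~13.6% of keys move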
def shard_removal_performance(conductor: ClusterConductor, dir, log: Logger):
NUM_KEYS = 5000
NODE_COUNT = 22
INITIAL_SHARDS = 11
with KVSTestFixture(conductor, dir, log, node_count=NODE_COUNT) as fx:
c = KVSMultiClient(fx.clients, "client", log)
for i in range(INITIAL_SHARDS):
conductor.add_shard(f"shard{i}", conductor.get_nodes([i * 2, i * 2 + 1]))
fx.broadcast_view(conductor.get_shard_view())
log(f"\n> ADDING {NUM_KEYS} RANDOM KEYS")
key_values = {}
for i in range(NUM_KEYS):
key = make_random_key()
value = key
node = random.randint(0, NODE_COUNT - 1)
c.metadata = None
r = c.put(node, key, value)
assert r.ok, f"expected ok for {key}, got {r.status_code}"
key_values[key] = value
if (i + 1) % 50 == 0:
log(f"Added {i + 1} keys")
log("\n> CHECKING INITIAL KEY DISTRIBUTION")
initial_distribution = {}
total_keys_before = 0
for i in range(INITIAL_SHARDS):
node_id = i * 2
c.metadata = None
r = c.get_all(node_id)
assert r.ok, f"expected ok for get_all from shard {i}, got {r.status_code}"
shard_keys = r.json().get("items", {})
initial_distribution[i] = set(shard_keys.keys())
total_keys_before += len(shard_keys)
log(f"Shard {i} initially has {len(shard_keys)} keys")
shard_to_remove = "shard5"
shard_idx_to_remove = 5
log(f"\n> REMOVING SHARD: {shard_to_remove}")
removed_shard_keys = initial_distribution[shard_idx_to_remove]
log(f"The removed shard had {len(removed_shard_keys)} keys")
nodes_to_reassign = conductor.shards[shard_to_remove]
log(f"Moving node {nodes_to_reassign[0]} to shard0")
conductor.shards = dict(conductor.shards)
conductor.shards["shard0"] = conductor.shards["shard0"] + [nodes_to_reassign[0]]
del conductor.shards[shard_to_remove]
fx.broadcast_view(conductor.get_shard_view())
log("Waiting for resharding to complete...")
time.sleep(DEFAULT_TIMEOUT)
log("\n> CHECKING KEY DISTRIBUTION AFTER RESHARDING")
final_distribution = {}
total_keys_after = 0
for i in range(INITIAL_SHARDS):
if i == shard_idx_to_remove:
continue
node_id = i * 2
c.metadata = None
r = c.get_all(node_id)
assert r.ok, f"expected ok for get_all from shard {i}, got {r.status_code}"
shard_keys = r.json().get("items", {})
final_distribution[i] = set(shard_keys.keys())
total_keys_after += len(shard_keys)
log(f"Shard {i} now has {len(shard_keys)} keys")
keys_redistributed = len(removed_shard_keys)
if total_keys_after < total_keys_before:
log(
f"WARNING: Some keys may have been lost. Before: {total_keys_before}, After: {total_keys_after}"
)
redistributed_keys_per_shard = {}
for i in final_distribution:
if i in initial_distribution:
new_keys = len(final_distribution[i] - initial_distribution[i])
redistributed_keys_per_shard[i] = new_keys
log(f"Shard {i} received {new_keys} new keys")
values = list(redistributed_keys_per_shard.values())
if values:
avg_keys_received = sum(values) / len(values)
max_deviation = max(abs(v - avg_keys_received) for v in values)
deviation_percent = (
(max_deviation / avg_keys_received) * 100
if avg_keys_received > 0
else 0
)
log(
f"Expected keys redistributed per shard: ~{keys_redistributed / (INITIAL_SHARDS - 1):.1f}"
)
log(
f"Actual redistribution: max deviation {deviation_percent:.1f}% from average"
)
is_efficient = deviation_percent < 150
return (
is_efficient,
f"Shard removal test completed. {keys_redistributed} keys redistributed, "
+ f"Efficiency: {'good' if is_efficient else 'poor'}",
)
return (False, "Could not properly analyze key redistribution")
STRESS_TESTS = [
TestCase("shard_key_distribution", shard_key_distribution),
TestCase("shard_addition_performance", shard_addition_performance),
TestCase("shard_removal_performance", shard_removal_performance),
]
@@ -13,6 +13,7 @@ import requests
from .util import run_cmd_bg, Logger
CONTAINER_ENGINE = os.getenv("ENGINE", "docker")
+REBROADCAST_VIEW = os.getenv("REBROADCAST_VIEW", "false")
class ContainerBuilder:
@@ -44,6 +45,9 @@ class ClusterNode:
)
networks: List[str]  # networks the container is attached to
+def get_view(self) -> str:
+return {"address": f"{self.ip}:{self.port}", "id": self.index}
def internal_endpoint(self) -> str:
return f"http://{self.ip}:{self.port}"
@@ -57,19 +61,21 @@ class ClusterConductor:
def __init__(
self,
group_id: str,
+thread_id: str,
base_image: str,
log: Logger,
external_port_base: int = 8081,
):
self.group_id = group_id
+self.thread_id = thread_id
self.base_image = base_image
self.base_port = external_port_base
self.nodes: List[ClusterNode] = []
self.shards: dict[str, List[ClusterNode]] = {}
# naming patterns
-self.group_ctr_prefix = f"kvs_{group_id}_node"
+self.group_ctr_prefix = f"kvs_{group_id}_{thread_id}_node"
-self.group_net_prefix = f"kvs_{group_id}_net"
+self.group_net_prefix = f"kvs_{group_id}_{group_id}_net"
# base network
self.base_net_name = f"{self.group_net_prefix}_base"
@@ -108,15 +114,17 @@ class ClusterConductor:
def dump_all_container_logs(self, dir):
self.log("dumping logs of kvs containers")
-container_pattern = f"^kvs_{self.group_id}.*"
+container_pattern = f"^kvs_{self.group_id}_{self.thread_id}_.*"
container_regex = re.compile(container_pattern)
containers = self._list_containers()
for container in containers:
if container and container_regex.match(container):
self._dump_container_logs(dir, container)
def get_view(self) -> str:
return {"address": f"{self.ip}:{self.port}", "id": self.index}
def _dump_container_logs(self, dir, name: str) -> None:
log_file = os.path.join(dir, f"{name}.log")
self.log(f"Dumping logs for container {name} to file {log_file}")
@@ -182,8 +190,8 @@ class ClusterConductor:
# otherwise clean up anything kvs related
if group_only:
self.log(f"cleaning up group {self.group_id}")
-container_pattern = f"^kvs_{self.group_id}_.*"
+container_pattern = f"^kvs_{self.group_id}_{self.thread_id}_.*"
-network_pattern = f"^kvs_{self.group_id}_net_.*"
+network_pattern = f"^kvs_{self.group_id}_{self.thread_id}_net_.*"
else:
self.log("cleaning up all kvs containers and networks")
container_pattern = "^kvs_.*"
@@ -202,9 +210,6 @@ class ClusterConductor:
if container and container_regex.match(container)
]
self._remove_containers(containers_to_remove)
-# for container in containers:
-# if container and container_regex.match(container):
-# self._remove_container(container)
# cleanup networks
self.log(f" cleaning up {'group' if group_only else 'all'} networks")
@@ -223,7 +228,7 @@ class ClusterConductor:
return False
def _node_name(self, index: int) -> str:
-return f"kvs_{self.group_id}_node_{index}"
+return f"kvs_{self.group_id}_{self.thread_id}_node_{index}"
def node_external_endpoint(self, index: int) -> str:
return self.nodes[index].external_endpoint()
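With thread_id folded into the naming scheme, each worker thread (and each --group-id) gets its own container and network namespace, which is what lets parallel runs clean up only their own resources. For example (group and thread values below are illustrative):

group_id, thread_id, index = "hw3", "2", 0
name = f"kvs_{group_id}_{thread_id}_node_{index}"   # "kvs_hw3_2_node_0"
# a second suite started with --group-id mygrp on thread 0 would instead create "kvs_mygrp_0_node_0"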
@@ -257,19 +262,22 @@ class ClusterConductor:
self.log(f" starting container {node_name} (ext_port={external_port})")
# start container detached from networks
run_cmd = [
CONTAINER_ENGINE,
"run",
"-d",
"--name",
node_name,
"--env",
f"NODE_IDENTIFIER={i}",
"-p",
f"{external_port}:{port}",
self.base_image,
]
if CONTAINER_ENGINE == "podman":
run_cmd.insert(2, "--log-driver=k8s-file")
run_cmd_bg(
-[
-CONTAINER_ENGINE,
-"run",
-"-d",
-"--name",
-node_name,
-"--env",
-f"NODE_IDENTIFIER={i}",
-"-p",
-f"{external_port}:{port}",
-self.base_image,
-],
+run_cmd,
verbose=True,
error_prefix=f"failed to start container {node_name}",
log=self.log,
@@ -368,7 +376,7 @@ class ClusterConductor:
self.log(f" {part_name}: {nodes}") self.log(f" {part_name}: {nodes}")
def my_partition(self, node_ids: List[int], partition_id: str) -> None: def my_partition(self, node_ids: List[int], partition_id: str) -> None:
net_name = f"kvs_{self.group_id}_net_{partition_id}" net_name = f"kvs_{self.group_id}_{self.thread_id}_net_{partition_id}"
self.log(f"creating partition {partition_id} with nodes {node_ids}") self.log(f"creating partition {partition_id} with nodes {node_ids}")
# create partition network if it doesn't exist # create partition network if it doesn't exist
...@@ -430,17 +438,21 @@ class ClusterConductor: ...@@ -430,17 +438,21 @@ class ClusterConductor:
# update node ip
if container_ip != node.ip:
+self.log(
+f"Warning: Node {i} IP addr changed from {node.ip} to {container_ip}"
+)
node.ip = container_ip
-if hasattr(self, "_parent"):
-self._parent.clients[
-node.index
-].base_url = self.node_external_endpoint(node.index)
-view_changed = True
-if view_changed and hasattr(self, "_parent"):
-self._parent.broadcast_view(self.get_full_view())
+if CONTAINER_ENGINE == "podman":
+if hasattr(self, "_parent"):
+self._parent.clients[
+node.index
+].base_url = self.node_external_endpoint(node.index)
+view_changed = True
+if CONTAINER_ENGINE == "podman" and view_changed and hasattr(self, "_parent"):
+self._parent.rebroadcast_view(self.get_shard_view())
def create_partition(self, node_ids: List[int], partition_id: str) -> None:
-net_name = f"kvs_{self.group_id}_net_{partition_id}"
+net_name = f"kvs_{self.group_id}_{self.thread_id}_net_{partition_id}"
self.log(f"creating partition {partition_id} with nodes {node_ids}")
@@ -496,16 +508,25 @@ class ClusterConductor:
# update node ip
if container_ip != node.ip:
+self.log(
+f"Warning: Node {i} IP addr changed from {node.ip} to {container_ip}"
+)
node.ip = container_ip
-if hasattr(self, "_parent"):
-self._parent.clients[
-node.index
-].base_url = self.node_external_endpoint(node.index)
-view_changed = True
-if view_changed and hasattr(self, "_parent"):
-self._parent.broadcast_view(self.get_full_view())
-def get_node(self, index):
-return self.nodes[index]
+if CONTAINER_ENGINE == "podman" or REBROADCAST_VIEW == "true":
+if hasattr(self, "_parent"):
+self._parent.clients[
+node.index
+].base_url = self.node_external_endpoint(node.index)
+view_changed = True
+if (
+(CONTAINER_ENGINE == "podman" or REBROADCAST_VIEW == "true")
+and view_changed
+and hasattr(self, "_parent")
+):
+self._parent.rebroadcast_view(self.get_shard_view())
+DeprecationWarning("View is in updated format")
def get_full_view(self):
view = []
for node in self.nodes:
@@ -560,7 +581,7 @@ class ClusterConductor:
}
def get_partition_view(self, partition_id: str):
-net_name = f"kvs_{self.group_id}_net_{partition_id}"
+net_name = f"kvs_{self.group_id}_{self.thread_id}_net_{partition_id}"
view = []
for node in self.nodes:
if net_name in node.networks:
...
+from logging import Logger
import requests
from typing import Dict, Any, List
+import aiohttp
"""
Request Timeout status code.
@@ -109,7 +111,7 @@ class KVSClient:
r.status_code = REQUEST_TIMEOUT_STATUS_CODE
return r
else:
-return requests.get(f"{self.base_url}/data")
+return requests.get(f"{self.base_url}/data", json=create_json(metadata))
def clear(self, timeout: float = DEFAULT_TIMEOUT) -> None:
response = self.get_all(timeout=timeout)
@@ -129,5 +131,47 @@ class KVSClient:
if not isinstance(view, dict):
raise ValueError("view must be a dict")
+self.last_view = view
request_body = {"view": view}
return requests.put(f"{self.base_url}/view", json=request_body, timeout=timeout)
async def async_send_view(
self, view: dict[str, List[Dict[str, Any]]], timeout: float = None
) -> aiohttp.ClientResponse:
if not isinstance(view, dict):
raise ValueError("view must be a dict")
self.last_view = view
request_body = {"view": view}
async with aiohttp.ClientSession() as session:
async with session.put(
f"{self.base_url}/view", json=request_body, timeout=timeout
) as response:
# check the status before returning; with the check after the return it would never run
if response.status != 200:
raise RuntimeError(f"failed to send view: {response.status}")
return response
def resend_last_view_with_ips_from_new_view(
self,
current_view: dict[str, List[Dict[str, Any]]],
log: Logger,
timeout: float = DEFAULT_TIMEOUT,
) -> requests.Response:
if not isinstance(current_view, dict):
raise ValueError("view must be a dict")
if not hasattr(self, "last_view"):
return
raise LookupError("Must have sent at least one view before calling resend.")
flattened_current_view = {}
for shard_key in current_view:
for node in current_view[shard_key]:
flattened_current_view[node["id"]] = node["address"]
for shard_key in self.last_view:
for node in self.last_view[shard_key]:
node["address"] = flattened_current_view[node["id"]]
request_body = {"view": self.last_view}
log(f"Sending new view: {self.last_view}")
return requests.put(f"{self.base_url}/view", json=request_body, timeout=timeout)
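As a concrete illustration of resend_last_view_with_ips_from_new_view (addresses below are made up): the client keeps the shard layout of the last view it sent, but patches in each node's address from the current view, which is what the fixture's rebroadcast_view relies on after container IPs change.

last_view = {"shard1": [{"id": 0, "address": "10.0.0.2:8081"}],
"shard2": [{"id": 1, "address": "10.0.0.3:8081"}]}
current_view = {"all": [{"id": 0, "address": "10.0.9.2:8081"},
{"id": 1, "address": "10.0.9.3:8081"}]}
# after remapping, the resent body is:
# {"view": {"shard1": [{"id": 0, "address": "10.0.9.2:8081"}],
#           "shard2": [{"id": 1, "address": "10.0.9.3:8081"}]}}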