Skip to content
Snippets Groups Projects
Commit 3047aa07 authored by Mahyar Vahabi's avatar Mahyar Vahabi
Browse files

added LAST STUFF

parent 8c927e5c
No related branches found
No related tags found
No related merge requests found
Pipeline #617533 failed
......@@ -6,11 +6,15 @@ from langchain_openai import ChatOpenAI
import time
import random
import hashlib # For hashing keys
from cryptography.fernet import Fernet
# Generate a key for encryption (should be securely stored in production)
ENCRYPTION_KEY = Fernet.generate_key()
cipher = Fernet(ENCRYPTION_KEY)
def compile_oram():
"""Compiles the ORAM C++ code if not already compiled."""
print("Compiling ORAM C++ code...")
subprocess.run(["make", "rebuild"], check=True)
print("Compilation complete.")
......@@ -22,15 +26,26 @@ def hash_key(key):
return hashlib.sha256(key_str.encode()).hexdigest()
def encrypt_data(data):
"""Encrypts serialized data using Fernet symmetric encryption."""
return cipher.encrypt(data)
def decrypt_data(data):
"""Decrypts data using Fernet symmetric encryption."""
return cipher.decrypt(data)
def store(key, tensor):
"""Stores a PyTorch tensor securely using ORAM with a hashed key."""
"""Stores a PyTorch tensor securely using ORAM with a hashed key and encryption."""
hashed_key = hash_key(key) # Hash the key
tensor_bytes = pickle.dumps(tensor) # Serialize tensor
subprocess.run(["./oram", "store", hashed_key, tensor_bytes.hex()], check=True)
encrypted_data = encrypt_data(tensor_bytes) # Encrypt data
subprocess.run(["./oram", "store", hashed_key, encrypted_data.hex()], check=True)
def retrieve(key):
"""Retrieves a value securely from ORAM using a hashed key."""
"""Retrieves a value securely from ORAM using a hashed key and decryption."""
hashed_key = hash_key(key) # Hash the key
result = subprocess.run(["./oram", "retrieve", hashed_key], capture_output=True, text=True, check=True)
raw_output = result.stdout.strip()
......@@ -41,8 +56,9 @@ def retrieve(key):
raise ValueError(f"Key {key} not found in ORAM.")
try:
tensor_bytes = bytes.fromhex(raw_output)
return pickle.loads(tensor_bytes)
encrypted_bytes = bytes.fromhex(raw_output)
decrypted_bytes = decrypt_data(encrypted_bytes) # Decrypt the retrieved data
return pickle.loads(decrypted_bytes) # Deserialize back to tensor
except ValueError:
raise ValueError(f"Received invalid hex data: {raw_output}")
......@@ -75,14 +91,14 @@ def main():
api_key = "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"
query_key = random.randint(5, 10) # Assign a random key
print("Storing initial model parameters...")
store(query_key, user_query)
print("Storing initial model parameters securely...")
store(query_key, user_query) # Securely store user query
for key, tensor in parameters.items():
store(key, tensor)
time.sleep(10)
# time.sleep(10) # Simulate delay for ORAM processing
print("Retrieving stored parameters...")
print("Retrieving stored parameters securely...")
for key in parameters.keys():
print(f"Retrieving key {key}:", retrieve(key))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment