Skip to content
Snippets Groups Projects
Commit 2412dae7 authored by Mahyar Vahabi's avatar Mahyar Vahabi
Browse files

changes

parent fc588148
No related branches found
No related tags found
No related merge requests found
model.py 0 → 100644
import subprocess
import os
import pickle
import torch # For handling model parameters
from langchain_openai import ChatOpenAI
import random
def compile_oram():
"""Compiles the ORAM C++ code if not already compiled."""
if not os.path.exists("path_oram"): # Check if executable exists
print("Compiling ORAM C++ code...")
compile_command = ["g++", "path_oram.cpp", "-o", "path_oram"]
subprocess.run(compile_command, check=True)
print("Compilation complete.")
def store(key, tensor):
"""Stores a PyTorch tensor securely using ORAM."""
tensor_bytes = pickle.dumps(tensor) # Serialize tensor
subprocess.run(["./path_oram", "store", str(key), tensor_bytes.hex()], check=True)
def retrieve(key):
"""Retrieves a value securely from ORAM."""
result = subprocess.run(["./path_oram", "retrieve", str(key)], capture_output=True, text=True, check=True)
tensor_bytes = bytes.fromhex(result.stdout.strip())
return pickle.loads(tensor_bytes)
'''
def shuffle():
"""Calls the ORAM executable to shuffle access patterns."""
subprocess.run(["./path_oram", "shuffle"], check=True)
'''
def generate_response(query, api_key):
"""Generate a response using GPT-4o-mini."""
chat_model = ChatOpenAI(model="gpt-4o-mini", openai_api_key=api_key)
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": query}
]
response = chat_model.invoke(messages)
return response
def main():
compile_oram() # Ensure ORAM is compiled
parameters = {
1: torch.randn(4, 4), # Simulated weight matrix
2: torch.randn(4), # Simulated bias vector
3: torch.randn(4, 4),
4: torch.randn(4),
}
# Store and retrieve user queries securely
user_query = input("Enter Prompt here: ")
api_key = "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"
query_key = random.randint(len(parameters), 10000) # Assign a random key
print("Storing initial model parameters...")
store(query_key, user_query)
for key, tensor in parameters.items():
store(key, tensor)
#shuffle() # Shuffle ORAM access patterns
print("Retrieving stored parameters...")
for key in parameters.keys():
print(f"Retrieving key {key}:", retrieve(key))
retrieved_query = retrieve(query_key)
print(f"Retrieved Query: {retrieved_query}")
response = generate_response(retrieved_query, api_key)
print("AI Response:", response.content)
if __name__ == "__main__":
main()
\ No newline at end of file
......@@ -17,91 +17,91 @@
class ORAM {
private:
unordered_map<int, string> storage; // Securely stores parameters & queries
vector<int> accessHistory; // Tracks access patterns
hash<string> str_hash;
const string filename = "oram_data.txt"; // Persistent storage file
void load_from_file() {
ifstream file(filename);
if (file.is_open()) {
int key;
string value;
while (file >> key) {
file.ignore();
getline(file, value);
storage[key] = value;
}
file.close();
}
}
unordered_map<int, string> storage; // Securely stores parameters & queries
vector<int> accessHistory; // Tracks access patterns
hash<string> str_hash;
const string filename = "oram_data.txt"; // Persistent storage file
void load_from_file() {
ifstream file(filename);
if (file.is_open()) {
int key;
string value;
while (file >> key) {
file.ignore();
getline(file, value);
storage[key] = value;
}
file.close();
}
}
void save_to_file() {
ofstream file(filename, ios::trunc);
if (file.is_open()) {
for (const auto& pair : storage) {
file << pair.first << " " << pair.second << "\n";
}
file << "\n";
file.close();
}
ofstream file(filename, ios::trunc);
if (file.is_open()) {
for (const auto& pair : storage) {
file << pair.first << " " << pair.second << "\n";
}
file << "\n";
file.close();
}
}
public:
ORAM() {
load_from_file(); // Load stored data at initialization
}
void store(int key, string value) {
storage[key] = value;
accessHistory.push_back(key);
ORAM() {
load_from_file(); // Load stored data at initialization
}
void store(int key, string value) {
storage[key] = value;
accessHistory.push_back(key);
size_t hashed_key = str_hash(to_string(key));
storage[hashed_key] = value;
accessHistory.push_back(hashed_key);
size_t hashed_key = str_hash(to_string(key));
storage[hashed_key] = value;
accessHistory.push_back(hashed_key);
save_to_file();
}
save_to_file();
}
/*
void perform_dummy_reads() {
random_device rd;
mt19937 gen(rd());
uniform_int_distribution<int> dist(1, 100);
for (int i = 0; i < rand() % 5 + 1; i++) {
int dummy_key = dist(gen);
storage.find(str_hash(to_string(dummy_key))); // Fake access
}
void perform_dummy_reads() {
random_device rd;
mt19937 gen(rd());
uniform_int_distribution<int> dist(1, 100);
for (int i = 0; i < rand() % 5 + 1; i++) {
int dummy_key = dist(gen);
storage.find(str_hash(to_string(dummy_key))); // Fake access
}
}
*/
void log_access(int key) {
ofstream log_file("oram_access_log.txt", ios::app);
if (log_file.is_open()) {
log_file << "Accessed Key: " << key << " at "
<< chrono::system_clock::to_time_t(chrono::system_clock::now()) << "\n";
log_file.close();
}
}
void log_access(int key) {
ofstream log_file("oram_access_log.txt", ios::app);
if (log_file.is_open()) {
log_file << "Accessed Key: " << key << " at "
<< chrono::system_clock::to_time_t(chrono::system_clock::now()) << "\n";
log_file.close();
}
}
string retrieve(int key) {
size_t hashed_key = str_hash(to_string(key));
//perform_dummy_reads(); // Add noise but its not working i don't think
log_access(key);
if (storage.find(hashed_key) != storage.end()) {
return storage[hashed_key];
}
return "Data not found";
}
string retrieve(int key) {
size_t hashed_key = str_hash(to_string(key));
//perform_dummy_reads(); // Add noise but its not working i don't think
log_access(key);
if (storage.find(hashed_key) != storage.end()) {
return storage[hashed_key];
}
return "Data not found";
}
void shuffle() {
random_device rd;
mt19937 g(rd());
std::shuffle(accessHistory.begin(), accessHistory.end(), g);
}
void shuffle() {
random_device rd;
mt19937 g(rd());
std::shuffle(accessHistory.begin(), accessHistory.end(), g);
}
void debug_display() {
void debug_display() {
ofstream shuffle_file("shuffling.txt", ios::trunc);
if (shuffle_file.is_open()) {
shuffle_file << "Stored Parameters:\n";
......
......@@ -25,10 +25,11 @@ def retrieve(key):
tensor_bytes = bytes.fromhex(result.stdout.strip())
return pickle.loads(tensor_bytes)
'''
def shuffle():
"""Calls the ORAM executable to shuffle access patterns."""
subprocess.run(["./test", "shuffle"], check=True)
'''
def generate_response(query, api_key):
"""Generate a response using GPT-4o-mini."""
chat_model = ChatOpenAI(model="gpt-4o-mini", openai_api_key=api_key)
......@@ -62,7 +63,7 @@ def main():
for key, tensor in parameters.items():
store(key, tensor)
shuffle() # Shuffle ORAM access patterns
#shuffle() # Shuffle ORAM access patterns
print("Retrieving stored parameters...")
for key in parameters.keys():
......
/*
* Path ORAM Secure Storage for Model Parameters & Queries
* Implements a tree-based ORAM with path eviction.
*/
#include <iostream>
#include <vector>
#include <unordered_map>
#include <random>
#include <fstream>
#include <cmath>
using namespace std;
const int BUCKET_SIZE = 2; // Max blocks per bucket
const int TREE_HEIGHT = 4; // Determines ORAM tree depth
const int NUM_BUCKETS = (1 << (TREE_HEIGHT + 1)) - 1; // Total nodes in binary tree
struct Block {
int key;
string value;
bool valid;
};
struct Bucket {
Block blocks[BUCKET_SIZE];
};
class PathORAM {
private:
vector<Bucket> tree;
unordered_map<int, int> position_map; // Maps key to leaf position
vector<Block> stash;
random_device rd;
mt19937 gen;
int get_random_leaf() {
return uniform_int_distribution<int>(1, (1 << TREE_HEIGHT) - 1)(gen);
}
void read_path(int leaf, vector<Block>& retrieved_blocks) {
for (int i = 0; i <= TREE_HEIGHT; i++) {
int node = leaf >> i;
for (auto& block : tree[node].blocks) {
if (block.valid) {
retrieved_blocks.push_back(block);
block.valid = false; // Remove from ORAM
}
}
}
}
void write_path(int leaf) {
for (int i = 0; i <= TREE_HEIGHT; i++) {
int node = leaf >> i;
for (auto& block : tree[node].blocks) {
if (!stash.empty()) {
block = stash.back();
stash.pop_back();
} else {
block.valid = false;
}
}
}
}
public:
PathORAM() : gen(rd()) {
tree.resize(NUM_BUCKETS);
}
void store(int key, string value) {
int leaf = get_random_leaf();
position_map[key] = leaf;
stash.push_back({key, value, true});
write_path(leaf);
}
string retrieve(int key) {
if (position_map.find(key) == position_map.end()) {
return "Data not found";
}
int leaf = position_map[key];
vector<Block> retrieved_blocks;
read_path(leaf, retrieved_blocks);
for (auto& block : retrieved_blocks) {
if (block.key == key) {
return block.value;
}
}
return "Data not found";
}
};
int main(int argc, char* argv[]) {
PathORAM oram;
if (argc < 2) {
cout << "Usage: ./path_oram [store/retrieve] [key] [value]" << endl;
return 1;
}
string command = argv[1];
if (command == "store" && argc == 4) {
int key = stoi(argv[2]);
oram.store(key, argv[3]);
} else if (command == "retrieve" && argc == 3) {
int key = stoi(argv[2]);
cout << oram.retrieve(key) << endl;
} else {
cout << "Invalid command." << endl;
return 1;
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment