Skip to content
Snippets Groups Projects
Commit 2a95b16d authored by Mahyar Vahabi's avatar Mahyar Vahabi
Browse files

made changes to how the cpp main should work with the python wrapper

parent 9e079760
No related branches found
No related tags found
No related merge requests found
......@@ -5,7 +5,7 @@
*
* Author: Lily Faris
*/
/*
#include <iostream>
#include <unordered_map>
#include <vector>
......@@ -58,6 +58,7 @@
};
// Test cases for ORAM functionality
int main() {
ORAM oram;
......@@ -80,4 +81,116 @@
return 0;
}
\ No newline at end of file
*/
#include <iostream>
#include <unordered_map>
#include <vector>
#include <random>
#include <fstream>
#include <algorithm>
#include <chrono>
using namespace std;
class ORAM {
private:
unordered_map<int, string> storage;
vector<int> accessHistory;
hash<string> str_hash;
const string filename = "oram_data.txt"; // Persistent storage file
void load_from_file() {
ifstream file(filename);
if (file.is_open()) {
int key;
string value;
while (file >> key) {
file.ignore();
getline(file, value);
storage[key] = value;
}
file.close();
}
}
void save_to_file() {
ofstream file(filename, ios::trunc);
if (file.is_open()) {
for (const auto& pair : storage) {
file << pair.first << " " << pair.second << "\n";
}
file.close();
}
}
public:
ORAM() {
load_from_file(); // Load stored parameters at initialization
}
void store(int key, string value) {
size_t hashed_key = str_hash(to_string(key));
storage[hashed_key] = value;
accessHistory.push_back(hashed_key);
save_to_file(); // Save changes
}
string retrieve(int key) {
size_t hashed_key = str_hash(to_string(key));
if (storage.find(hashed_key) != storage.end()) {
accessHistory.push_back(hashed_key);
return storage[hashed_key];
}
return "Parameter not found";
}
// Simulate ORAM shuffle to obfuscate access patterns
void shuffle() {
random_device rd;
mt19937 g(rd());
std::shuffle(accessHistory.begin(), accessHistory.end(), g);
}
// Display stored parameters (for testing only, not secure in real use)
void debug_display() {
cout << "Stored Parameters (hashed keys):\n";
for (const auto& pair : storage) {
cout << "Key: " << pair.first << " | Value: " << pair.second << endl;
}
}
};
int main(int argc, char* argv[]) {
ORAM oram;
if (argc < 2) {
cout << "Usage: ./oram [store/retrieve/shuffle] [key] [value]" << endl;
return 1;
}
string command = argv[1];
if (command == "store" && argc == 4) {
int key = stoi(argv[2]);
oram.store(key, argv[3]);
cout << "Stored successfully!" << endl;
}
else if (command == "retrieve" && argc == 3) {
int key = stoi(argv[2]);
cout << oram.retrieve(key) << endl;
}
else if (command == "shuffle") {
cout << "Shuffling access patterns..." << endl;
oram.shuffle();
cout << "Displaying stored data (hashed keys)..." << endl;
oram.debug_display();
}
else {
cout << "Invalid command." << endl;
return 1;
}
return 0;
}
......@@ -13,12 +13,20 @@ def compile_oram():
def store_parameter(key, value):
"""Stores a model parameter securely using ORAM."""
subprocess.run(["./oram", "store", str(key), value], check=True)
'''
def retrieve_parameter(key):
"""Retrieves a model parameter securely using ORAM."""
result = subprocess.run(["./oram", "retrieve", str(key)], capture_output=True, text=True, check=True)
return result.stdout.strip()
'''
def retrieve_parameter(key):
result = subprocess.run(["./oram", "retrieve", str(key)], capture_output=True, text=True)
return result.stdout.strip() if result.returncode == 0 else "Error retrieving parameter"
def perform_shuffle():
"""Calls the ORAM executable to shuffle access patterns."""
subprocess.run(["./oram", "shuffle"], check=True)
def generate_response(query, api_key):
"""Generate a response using GPT-4o-mini without retrieval."""
chat_model = ChatOpenAI(model="gpt-4o-mini", openai_api_key=api_key)
......@@ -40,12 +48,16 @@ def main():
store_parameter(1, "weight_1: 0.345")
store_parameter(2, "weight_2: 0.678")
store_parameter(3, "bias_1: -0.123")
store_parameter(4, "weight_3: 0.597")
print("Retrieving stored parameters...")
print("Retrieved:", retrieve_parameter(1))
print("Retrieved:", retrieve_parameter(2))
print("Retrieved:", retrieve_parameter(3))
print("Retrieved:", retrieve_parameter(4)) # This should return "Parameter not found"
print("Retrieved:", retrieve_parameter(4))
print("Retrieved:", retrieve_parameter(5)) # This should return "Parameter not found"
perform_shuffle() # Perform Oram Shuffle to randomize the access pattern
query = "How does ORAM help secure models?"
response = generate_response(query, api_key)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment