Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
M
ML_Security_ORAM
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Mahyar Vahabi
ML_Security_ORAM
Commits
2412dae7
Commit
2412dae7
authored
2 months ago
by
Mahyar Vahabi
Browse files
Options
Downloads
Patches
Plain Diff
changes
parent
fc588148
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
model.py
+80
-0
80 additions, 0 deletions
model.py
old_code/test.cpp
+143
-0
143 additions, 0 deletions
old_code/test.cpp
old_code/test.py
+3
-2
3 additions, 2 deletions
old_code/test.py
path_oram.cpp
+113
-0
113 additions, 0 deletions
path_oram.cpp
with
339 additions
and
2 deletions
model.py
0 → 100644
+
80
−
0
View file @
2412dae7
import
subprocess
import
os
import
pickle
import
torch
# For handling model parameters
from
langchain_openai
import
ChatOpenAI
import
random
def
compile_oram
():
"""
Compiles the ORAM C++ code if not already compiled.
"""
if
not
os
.
path
.
exists
(
"
path_oram
"
):
# Check if executable exists
print
(
"
Compiling ORAM C++ code...
"
)
compile_command
=
[
"
g++
"
,
"
path_oram.cpp
"
,
"
-o
"
,
"
path_oram
"
]
subprocess
.
run
(
compile_command
,
check
=
True
)
print
(
"
Compilation complete.
"
)
def
store
(
key
,
tensor
):
"""
Stores a PyTorch tensor securely using ORAM.
"""
tensor_bytes
=
pickle
.
dumps
(
tensor
)
# Serialize tensor
subprocess
.
run
([
"
./path_oram
"
,
"
store
"
,
str
(
key
),
tensor_bytes
.
hex
()],
check
=
True
)
def
retrieve
(
key
):
"""
Retrieves a value securely from ORAM.
"""
result
=
subprocess
.
run
([
"
./path_oram
"
,
"
retrieve
"
,
str
(
key
)],
capture_output
=
True
,
text
=
True
,
check
=
True
)
tensor_bytes
=
bytes
.
fromhex
(
result
.
stdout
.
strip
())
return
pickle
.
loads
(
tensor_bytes
)
'''
def shuffle():
"""
Calls the ORAM executable to shuffle access patterns.
"""
subprocess.run([
"
./path_oram
"
,
"
shuffle
"
], check=True)
'''
def
generate_response
(
query
,
api_key
):
"""
Generate a response using GPT-4o-mini.
"""
chat_model
=
ChatOpenAI
(
model
=
"
gpt-4o-mini
"
,
openai_api_key
=
api_key
)
messages
=
[
{
"
role
"
:
"
system
"
,
"
content
"
:
"
You are a helpful assistant.
"
},
{
"
role
"
:
"
user
"
,
"
content
"
:
query
}
]
response
=
chat_model
.
invoke
(
messages
)
return
response
def
main
():
compile_oram
()
# Ensure ORAM is compiled
parameters
=
{
1
:
torch
.
randn
(
4
,
4
),
# Simulated weight matrix
2
:
torch
.
randn
(
4
),
# Simulated bias vector
3
:
torch
.
randn
(
4
,
4
),
4
:
torch
.
randn
(
4
),
}
# Store and retrieve user queries securely
user_query
=
input
(
"
Enter Prompt here:
"
)
api_key
=
"
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
"
query_key
=
random
.
randint
(
len
(
parameters
),
10000
)
# Assign a random key
print
(
"
Storing initial model parameters...
"
)
store
(
query_key
,
user_query
)
for
key
,
tensor
in
parameters
.
items
():
store
(
key
,
tensor
)
#shuffle() # Shuffle ORAM access patterns
print
(
"
Retrieving stored parameters...
"
)
for
key
in
parameters
.
keys
():
print
(
f
"
Retrieving key
{
key
}
:
"
,
retrieve
(
key
))
retrieved_query
=
retrieve
(
query_key
)
print
(
f
"
Retrieved Query:
{
retrieved_query
}
"
)
response
=
generate_response
(
retrieved_query
,
api_key
)
print
(
"
AI Response:
"
,
response
.
content
)
if
__name__
==
"
__main__
"
:
main
()
\ No newline at end of file
This diff is collapsed.
Click to expand it.
test.cpp
→
old_code/
test.cpp
+
143
−
0
View file @
2412dae7
...
...
@@ -17,91 +17,91 @@
class
ORAM
{
private:
unordered_map
<
int
,
string
>
storage
;
// Securely stores parameters & queries
vector
<
int
>
accessHistory
;
// Tracks access patterns
hash
<
string
>
str_hash
;
const
string
filename
=
"oram_data.txt"
;
// Persistent storage file
void
load_from_file
()
{
ifstream
file
(
filename
);
if
(
file
.
is_open
())
{
int
key
;
string
value
;
while
(
file
>>
key
)
{
file
.
ignore
();
getline
(
file
,
value
);
storage
[
key
]
=
value
;
}
file
.
close
();
}
}
unordered_map
<
int
,
string
>
storage
;
// Securely stores parameters & queries
vector
<
int
>
accessHistory
;
// Tracks access patterns
hash
<
string
>
str_hash
;
const
string
filename
=
"oram_data.txt"
;
// Persistent storage file
void
load_from_file
()
{
ifstream
file
(
filename
);
if
(
file
.
is_open
())
{
int
key
;
string
value
;
while
(
file
>>
key
)
{
file
.
ignore
();
getline
(
file
,
value
);
storage
[
key
]
=
value
;
}
file
.
close
();
}
}
void
save_to_file
()
{
ofstream
file
(
filename
,
ios
::
trunc
);
if
(
file
.
is_open
())
{
for
(
const
auto
&
pair
:
storage
)
{
file
<<
pair
.
first
<<
" "
<<
pair
.
second
<<
"
\n
"
;
}
file
<<
"
\n
"
;
file
.
close
();
}
ofstream
file
(
filename
,
ios
::
trunc
);
if
(
file
.
is_open
())
{
for
(
const
auto
&
pair
:
storage
)
{
file
<<
pair
.
first
<<
" "
<<
pair
.
second
<<
"
\n
"
;
}
file
<<
"
\n
"
;
file
.
close
();
}
}
public
:
ORAM
()
{
load_from_file
();
// Load stored data at initialization
}
void
store
(
int
key
,
string
value
)
{
storage
[
key
]
=
value
;
accessHistory
.
push_back
(
key
);
ORAM
()
{
load_from_file
();
// Load stored data at initialization
}
void
store
(
int
key
,
string
value
)
{
storage
[
key
]
=
value
;
accessHistory
.
push_back
(
key
);
size_t
hashed_key
=
str_hash
(
to_string
(
key
));
storage
[
hashed_key
]
=
value
;
accessHistory
.
push_back
(
hashed_key
);
size_t
hashed_key
=
str_hash
(
to_string
(
key
));
storage
[
hashed_key
]
=
value
;
accessHistory
.
push_back
(
hashed_key
);
save_to_file
();
}
save_to_file
();
}
/*
void perform_dummy_reads() {
random_device rd;
mt19937 gen(rd());
uniform_int_distribution<int> dist(1, 100);
for (int i = 0; i < rand() % 5 + 1; i++) {
int dummy_key = dist(gen);
storage.find(str_hash(to_string(dummy_key))); // Fake access
}
void perform_dummy_reads() {
random_device rd;
mt19937 gen(rd());
uniform_int_distribution<int> dist(1, 100);
for (int i = 0; i < rand() % 5 + 1; i++) {
int dummy_key = dist(gen);
storage.find(str_hash(to_string(dummy_key))); // Fake access
}
}
*/
void
log_access
(
int
key
)
{
ofstream
log_file
(
"oram_access_log.txt"
,
ios
::
app
);
if
(
log_file
.
is_open
())
{
log_file
<<
"Accessed Key: "
<<
key
<<
" at "
<<
chrono
::
system_clock
::
to_time_t
(
chrono
::
system_clock
::
now
())
<<
"
\n
"
;
log_file
.
close
();
}
}
void
log_access
(
int
key
)
{
ofstream
log_file
(
"oram_access_log.txt"
,
ios
::
app
);
if
(
log_file
.
is_open
())
{
log_file
<<
"Accessed Key: "
<<
key
<<
" at "
<<
chrono
::
system_clock
::
to_time_t
(
chrono
::
system_clock
::
now
())
<<
"
\n
"
;
log_file
.
close
();
}
}
string
retrieve
(
int
key
)
{
size_t
hashed_key
=
str_hash
(
to_string
(
key
));
//perform_dummy_reads(); // Add noise but its not working i don't think
log_access
(
key
);
if
(
storage
.
find
(
hashed_key
)
!=
storage
.
end
())
{
return
storage
[
hashed_key
];
}
return
"Data not found"
;
}
string
retrieve
(
int
key
)
{
size_t
hashed_key
=
str_hash
(
to_string
(
key
));
//perform_dummy_reads(); // Add noise but its not working i don't think
log_access
(
key
);
if
(
storage
.
find
(
hashed_key
)
!=
storage
.
end
())
{
return
storage
[
hashed_key
];
}
return
"Data not found"
;
}
void
shuffle
()
{
random_device
rd
;
mt19937
g
(
rd
());
std
::
shuffle
(
accessHistory
.
begin
(),
accessHistory
.
end
(),
g
);
}
void
shuffle
()
{
random_device
rd
;
mt19937
g
(
rd
());
std
::
shuffle
(
accessHistory
.
begin
(),
accessHistory
.
end
(),
g
);
}
void
debug_display
()
{
void
debug_display
()
{
ofstream
shuffle_file
(
"shuffling.txt"
,
ios
::
trunc
);
if
(
shuffle_file
.
is_open
())
{
shuffle_file
<<
"Stored Parameters:
\n
"
;
...
...
This diff is collapsed.
Click to expand it.
test.py
→
old_code/
test.py
+
3
−
2
View file @
2412dae7
...
...
@@ -25,10 +25,11 @@ def retrieve(key):
tensor_bytes
=
bytes
.
fromhex
(
result
.
stdout
.
strip
())
return
pickle
.
loads
(
tensor_bytes
)
'''
def shuffle():
"""
Calls the ORAM executable to shuffle access patterns.
"""
subprocess.run([
"
./test
"
,
"
shuffle
"
], check=True)
'''
def
generate_response
(
query
,
api_key
):
"""
Generate a response using GPT-4o-mini.
"""
chat_model
=
ChatOpenAI
(
model
=
"
gpt-4o-mini
"
,
openai_api_key
=
api_key
)
...
...
@@ -62,7 +63,7 @@ def main():
for
key
,
tensor
in
parameters
.
items
():
store
(
key
,
tensor
)
shuffle
()
# Shuffle ORAM access patterns
#
shuffle() # Shuffle ORAM access patterns
print
(
"
Retrieving stored parameters...
"
)
for
key
in
parameters
.
keys
():
...
...
This diff is collapsed.
Click to expand it.
path_oram.cpp
0 → 100644
+
113
−
0
View file @
2412dae7
/*
* Path ORAM Secure Storage for Model Parameters & Queries
* Implements a tree-based ORAM with path eviction.
*/
#include
<iostream>
#include
<vector>
#include
<unordered_map>
#include
<random>
#include
<fstream>
#include
<cmath>
using
namespace
std
;
const
int
BUCKET_SIZE
=
2
;
// Max blocks per bucket
const
int
TREE_HEIGHT
=
4
;
// Determines ORAM tree depth
const
int
NUM_BUCKETS
=
(
1
<<
(
TREE_HEIGHT
+
1
))
-
1
;
// Total nodes in binary tree
struct
Block
{
int
key
;
string
value
;
bool
valid
;
};
struct
Bucket
{
Block
blocks
[
BUCKET_SIZE
];
};
class
PathORAM
{
private:
vector
<
Bucket
>
tree
;
unordered_map
<
int
,
int
>
position_map
;
// Maps key to leaf position
vector
<
Block
>
stash
;
random_device
rd
;
mt19937
gen
;
int
get_random_leaf
()
{
return
uniform_int_distribution
<
int
>
(
1
,
(
1
<<
TREE_HEIGHT
)
-
1
)(
gen
);
}
void
read_path
(
int
leaf
,
vector
<
Block
>&
retrieved_blocks
)
{
for
(
int
i
=
0
;
i
<=
TREE_HEIGHT
;
i
++
)
{
int
node
=
leaf
>>
i
;
for
(
auto
&
block
:
tree
[
node
].
blocks
)
{
if
(
block
.
valid
)
{
retrieved_blocks
.
push_back
(
block
);
block
.
valid
=
false
;
// Remove from ORAM
}
}
}
}
void
write_path
(
int
leaf
)
{
for
(
int
i
=
0
;
i
<=
TREE_HEIGHT
;
i
++
)
{
int
node
=
leaf
>>
i
;
for
(
auto
&
block
:
tree
[
node
].
blocks
)
{
if
(
!
stash
.
empty
())
{
block
=
stash
.
back
();
stash
.
pop_back
();
}
else
{
block
.
valid
=
false
;
}
}
}
}
public
:
PathORAM
()
:
gen
(
rd
())
{
tree
.
resize
(
NUM_BUCKETS
);
}
void
store
(
int
key
,
string
value
)
{
int
leaf
=
get_random_leaf
();
position_map
[
key
]
=
leaf
;
stash
.
push_back
({
key
,
value
,
true
});
write_path
(
leaf
);
}
string
retrieve
(
int
key
)
{
if
(
position_map
.
find
(
key
)
==
position_map
.
end
())
{
return
"Data not found"
;
}
int
leaf
=
position_map
[
key
];
vector
<
Block
>
retrieved_blocks
;
read_path
(
leaf
,
retrieved_blocks
);
for
(
auto
&
block
:
retrieved_blocks
)
{
if
(
block
.
key
==
key
)
{
return
block
.
value
;
}
}
return
"Data not found"
;
}
};
int
main
(
int
argc
,
char
*
argv
[])
{
PathORAM
oram
;
if
(
argc
<
2
)
{
cout
<<
"Usage: ./path_oram [store/retrieve] [key] [value]"
<<
endl
;
return
1
;
}
string
command
=
argv
[
1
];
if
(
command
==
"store"
&&
argc
==
4
)
{
int
key
=
stoi
(
argv
[
2
]);
oram
.
store
(
key
,
argv
[
3
]);
}
else
if
(
command
==
"retrieve"
&&
argc
==
3
)
{
int
key
=
stoi
(
argv
[
2
]);
cout
<<
oram
.
retrieve
(
key
)
<<
endl
;
}
else
{
cout
<<
"Invalid command."
<<
endl
;
return
1
;
}
}
\ No newline at end of file
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment