Commit 57e400f0 authored by Mahyar Vahabi's avatar Mahyar Vahabi

made a lot of changes

parent 6ef4dbf2
#
# REMEMBER TO INSTALL THESE
# pip install datasets
# pip install -U langchain-community
# pip install -U langchain-openai
# pip install transformers chromadb faiss-cpu torch
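# (optional) pip install sentence-transformers  # only needed for the local-embedding alternative noted below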
#
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
from langchain_community.vectorstores import FAISS
from langchain_openai import OpenAIEmbeddings
import torch

def load_model():
    # Use the sequence checkpoint to match RagSequenceForGeneration
    model_name = "facebook/rag-sequence-base"
    tokenizer = RagTokenizer.from_pretrained(model_name)
    # Initialize the HF retriever with its "exact" dummy index, as in the
    # transformers RAG examples; our own documents are retrieved from the
    # FAISS store built in load_database() inside rag() instead.
    retriever = RagRetriever.from_pretrained(
        model_name, index_name="exact", use_dummy_dataset=True
    )
    # Attach the retriever so model.generate() can run end to end
    model = RagSequenceForGeneration.from_pretrained(model_name, retriever=retriever)
    return tokenizer, model

def load_database():
    documents = [
        {"text": "Oblivious RAM (ORAM) prevents adversaries from learning access patterns."},
        {"text": "RAG models improve LLMs by retrieving external information before generation."},
        {"text": "Model inversion attacks can reconstruct training data from model outputs."},
    ]
    Mahyars_key = "HEHE"  # placeholder -- swap in a real OpenAI API key before running
    # Create embeddings and store them in FAISS
    vectorstore = FAISS.from_texts(
        texts=[doc["text"] for doc in documents],
        embedding=OpenAIEmbeddings(model="text-embedding-ada-002", openai_api_key=Mahyars_key),
    )
    return vectorstore
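
# Optional, an assumption rather than part of the original commit: to run the
# demo without an OpenAI key, the embedding above can be swapped for a local
# model served through langchain_community, e.g.:
#
#   from langchain_community.embeddings import HuggingFaceEmbeddings
#   embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")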

def rag(vectorstore, tokenizer, model):
    query = "How does ORAM help secure RAG models?"
    # Retrieve the most relevant documents from the FAISS store
    retrieved_docs = vectorstore.similarity_search(query, k=2)
    context = " ".join([doc.page_content for doc in retrieved_docs])
    inputs = tokenizer(f"question: {query} context: {context}", return_tensors="pt")
    with torch.no_grad():
        output_ids = model.generate(**inputs)
    response = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
    print("RAG Model Response:", response)
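
# Optional, an assumption rather than something in the original commit: the
# FAISS index can be persisted between runs with LangChain's helpers, e.g.:
#
#   vectorstore.save_local("faiss_index")
#   vectorstore = FAISS.load_local("faiss_index", embeddings,
#                                  allow_dangerous_deserialization=True)
#
# (allow_dangerous_deserialization is required by recent langchain_community
# releases when loading a pickled index.)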

def main():
    vectorstore = load_database()
    tokenizer, model = load_model()  # the HF retriever is attached to the model
    rag(vectorstore, tokenizer, model)


if __name__ == "__main__":
    main()