Eclipse Graphene / tutorials

Commit 4b165da7
Authored 11 months ago by Sangamithra Panneer Selvam
Conversation memory in rag chain
Parent: 3a57eb09
Showing 1 changed file: llm_docker_generator/templates/rag_chain.py.j2 (+61 additions, −16 deletions)
@@ -3,12 +3,19 @@ from langchain_core.output_parsers import StrOutputParser
 from langchain_community.vectorstores import FAISS
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain.prompts import PromptTemplate
+from langchain.chains import create_history_aware_retriever, create_retrieval_chain
+from langchain.chains.combine_documents import create_stuff_documents_chain
+from langchain_community.chat_message_histories import ChatMessageHistory
+from langchain_core.chat_history import BaseChatMessageHistory
+from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.runnables.history import RunnableWithMessageHistory
 import os
 # FAISS folder path
 local_directory = os.getenv("SHARED_FOLDER_PATH")
 faiss_folder = os.path.join(local_directory, "faiss_index")
+# History management
+store = {}
 def faiss_db():
     embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")
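The new module-level store dict together with ChatMessageHistory forms a simple in-process session registry: one mutable message list per session_id, kept only for the lifetime of the running container. A minimal sketch of what that registry holds after one exchange (the session id and messages below are made up for illustration):

from langchain_community.chat_message_histories import ChatMessageHistory

store = {}
store["demo-session"] = ChatMessageHistory()
store["demo-session"].add_user_message("What is the European AI Act?")
store["demo-session"].add_ai_message("A proposed EU regulation for artificial intelligence systems.")

# RunnableWithMessageHistory reads and appends to this same list on every turn
print(store["demo-session"].messages)  # [HumanMessage(...), AIMessage(...)]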
@@ -35,36 +42,74 @@ def get_context(retriever, question):
     context = "\n".join([doc.page_content for doc in docs])
     return context
+def get_session_history(session_id: str) -> BaseChatMessageHistory:
+    global store
+    if session_id not in store:
+        store[session_id] = ChatMessageHistory()
+    print(f"Store : {store}")
+    return store[session_id]
 def rag_chain(retriever, system_instruction, llm_service, question):
-    prompt_template_generic = """
+    ### Contextualize question ###
+    contextualize_q_system_prompt = """Given a chat history and the latest user question \
+    which might reference context in the chat history, formulate a standalone question \
+    which can be understood without the chat history. Do NOT answer the question, \
+    just reformulate it if needed and otherwise return it as is."""
+    contextualize_q_prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", contextualize_q_system_prompt),
+            MessagesPlaceholder("chat_history"),
+            ("human", "{input}"),
+        ]
+    )
+    # The history-aware retriever returns the reformulated question together with the retrieved context
+    history_aware_retriever = create_history_aware_retriever(
+        llm_service, retriever, contextualize_q_prompt
+    )
+    qa_system_prompt = """
     <|start_header_id|>user<|end_header_id|>
     You are an assistant for answering questions about the European AI Act.
     You are given the extracted parts of a long document and a question. Provide a conversational answer.
     If you don't know the answer, just say "I do not know." Don't make up an answer.
     Question: {question}
     Context: {context}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
     """
-    prompt = PromptTemplate(
-        input_variables=["context", "question"],
-        template=prompt_template_generic,
+    qa_prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", qa_system_prompt),
+            MessagesPlaceholder("chat_history"),
+            ("human", "{input}"),
+        ]
     )
+    question_answer_chain = create_stuff_documents_chain(llm_service, qa_prompt)
-    def format_docs(docs):
-        return "\n\n".join(doc.page_content for doc in docs)
-    rag_chain = (
-        {"context": retriever | format_docs, "question": RunnablePassthrough()}
-        | prompt
-        | llm_service
-        | StrOutputParser()
+    # create_retrieval_chain implicitly feeds the reformulated question and the retrieved context into qa_prompt
+    rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain)
+    conversational_rag_chain = RunnableWithMessageHistory(
+        rag_chain,
+        get_session_history,
+        input_messages_key="input",
+        history_messages_key="chat_history",
+        output_messages_key="answer",
     )
-    # Invoke the chain with the provided question
-    rag_ans = rag_chain.invoke(question)
+    rag_ans = conversational_rag_chain.invoke(
+        {"input": question},
+        config={"configurable": {"session_id": "abc123"}},
+    )["answer"]
     print(f"rag answer: {rag_ans}")
     # Get relevant context
     relevant_context = get_context(retriever, question)
     return rag_ans, relevant_context
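Because every call routes through RunnableWithMessageHistory with the hard-coded session_id "abc123", consecutive calls to rag_chain share one chat history, and the history-aware retriever can rewrite follow-up questions into standalone queries before retrieval. A rough usage sketch, assuming a retriever over the FAISS index, an llm_service chat model, and a system_instruction string are constructed elsewhere in the generated service (all placeholders here):

# Placeholders: how retriever, llm_service and system_instruction are built
# depends on the rest of the template.
# retriever = <FAISS retriever over faiss_folder>
# llm_service = <chat model used by the generated service>

answer_1, context_1 = rag_chain(retriever, system_instruction, llm_service,
                                "What does the AI Act consider a high-risk system?")

# The follow-up only works with memory: since both turns share session_id "abc123",
# the contextualize prompt rewrites it into a standalone question
# (roughly "What obligations apply to providers of high-risk systems?") before retrieval.
answer_2, context_2 = rag_chain(retriever, system_instruction, llm_service,
                                "And what obligations do their providers have?")

Note that with the session id fixed in the template, all callers currently share a single history; exposing it as a parameter would be needed for true per-user memory.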