Commit 3a57eb09 authored by Sangamithra Panneer Selvam
Rag-node folder name change

parent 060f7928
cicd.ai4eu-dev.eu:7444/tutorials/llms/test-ragpipeline
FROM python:3.10-slim
RUN apt-get update -y
RUN apt-get install -y python3-pip python3-dev
RUN pip3 install --upgrade pip
COPY requirements.txt .
RUN pip3 install -r requirements.txt
ARG TIMEZONE=Europe/Berlin
RUN ln -sf /usr/share/zoneinfo/$TIMEZONE /etc/localtime && \
echo "$TIMEZONE" > /etc/timezone
RUN mkdir /ragp
COPY . /ragp
WORKDIR /ragp
RUN python3 -m grpc_tools.protoc --python_out=. --proto_path=. --grpc_python_out=. rag.proto
ENTRYPOINT python3 -u server.py
### Current pipeline architecture
The refactored RAG pipeline (single node) handles the connection to the unified LLM interface.
![RAG node architecture](image.png)
Key features:
- Bi-directional streaming between the RAG node and the LLM node
- Two ingestion paths: PDF upload or pre-computed FAISS embeddings
- The workspace can be saved for future use
This pipeline is still under active development and will change frequently. Contributions, suggestions, and feedback are welcome.
Please refer to ticket eclipse/graphene/tutorials#45 for details on the pipeline structure.
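As a rough usage sketch, the node's Flask UI (port 8062 in `app.py`) can be driven directly over HTTP; the host, file name, and question below are placeholders:

```python
import requests

BASE = "http://localhost:8062/handle_request"  # placeholder host/port; adjust to your deployment

# 1. Upload a PDF to the node
with open("my_document.pdf", "rb") as f:
    r = requests.post(BASE, params={"action": "upload"}, files={"file": f})
print(r.json()["response"])

# 2. Chunk the uploaded PDF and build the FAISS index
r = requests.post(
    BASE,
    params={"action": "process"},
    data={"file_name": "my_document.pdf", "chunk_size": 500, "chunk_overlap": 50},
)
print(r.json()["response"])

# 3. Ask a question; the answer appears once the LLM node has streamed a response back
r = requests.post(BASE, params={"action": "chat"}, json={"question": "What is this document about?"})
print(r.json()["response"])
```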
from flask import Flask, request, jsonify, render_template, send_file
import os
from rag_utils import process_pdf, start_df, update_df
import pandas as pd
from sklearn.cluster import KMeans
import seaborn as sns
import matplotlib.pyplot as plt
from io import BytesIO
from visualization import plot_embeddings
from langchain.embeddings import HuggingFaceEmbeddings
import json
from datetime import datetime
import logging
app = Flask(__name__)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
user_question = []
user_answer = []
readme_ratings = {}
ratings_list = []
faithfulness = None
relevancy = None
local_directory = os.getenv("SHARED_FOLDER_PATH")
def save_uploaded_file(file):
file_path = os.path.join("/tmp", file.filename)
with open(file_path, "wb") as f:
f.write(file.read())
return file_path
@app.route("/")
def index():
global faithfulness
global relevancy
# metrics.txt holds one "|<faithfulness>|<relevancy>" record per answered question
with open("metrics.txt", mode="r") as f:
content = f.read()
# Split on '|' and keep the most recent pair of scores
metrics_list = [item.strip() for item in content.split('|') if item.strip()]
try:
faithfulness = metrics_list[-2]
relevancy = metrics_list[-1]
except IndexError:
# Fewer than two fields written so far
faithfulness = None
relevancy = metrics_list[-1] if metrics_list else None
return render_template("index.html", faithfulness=faithfulness, relevancy=relevancy)
@app.route("/handle_request", methods=["POST"])
def handle_request():
action = request.args.get("action")
# Upload the PDF
if action == "upload":
file = request.files.get("file")
if not file:
return jsonify({"response": "No file provided"}), 400
chunk_size = int(request.form.get("chunk_size", 500))
chunk_overlap = int(request.form.get("chunk_overlap", 50))
file_path = save_uploaded_file(file)
return jsonify(
{
"response": f"File uploaded successfully as {file_path}. Ready to process."
}
)
# Process the uploaded PDF with the given chunk size and overlap, and create the embeddings
elif action == "process":
file_name = request.form.get("file_name")
if not file_name:
return jsonify({"response": "No file specified for processing"}), 400
chunk_size = int(request.form.get("chunk_size", 500))
chunk_overlap = int(request.form.get("chunk_overlap", 50))
file_path = os.path.join("/tmp", file_name)
if not os.path.exists(file_path):
return jsonify({"response": f"File {file_name} does not exist"}), 400
response = process_pdf(
file_path, chunk_size, chunk_overlap
) # Create embeddings here
return jsonify({"response": response})
# Upload an existing FAISS index (not implemented yet)
elif action == "upload_index":
file = request.files.get("file")
if not file:
return jsonify({"response": "No file provided"}), 400
response = "Uploading an existing FAISS index is not implemented yet"
return jsonify({"response": response})
# Submit a question and return the latest Q&A from the results file
elif action == "chat":
request_data = request.get_json()
question = request_data.get("question", "")
if not question:
return jsonify({"response": "No question provided"}), 400
user_question.clear()
user_question.append(question)
# Check results file for the answer
question, answer = check_results_file()
if answer:
return jsonify({"response": [{"question": question, "answer": answer}]})
else:
return jsonify({"response": []})
else:
return jsonify({"response": "Invalid action"}), 400
def check_results_file():
# Read the stored contents (q_id, question and answer) created by the server
file_path = os.path.join(local_directory, "results.txt")
if not os.path.exists(file_path):
print("Results file not found.")
return "", ""
with open(file_path, "r") as file:
lines = file.readlines()
question, answer = None, None
for line in lines:
line = line.strip()
if line.startswith("Question:"):
question = line.split(":", 1)[1].strip()
elif line.startswith("Answer:"):
answer = line.split(":", 1)[1].strip()
if not question or not answer:
print("No question or answer found in the results file.")
return "", ""
return question, answer
def get_user_question():
# Called by the gRPC server to fetch the pending user question
return user_question
# Plotting routes
@app.route("/umap_plot")
def umap_plot():
embeddings_model = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")
df = start_df()
r_question, r_answer = check_results_file()
question_df = pd.DataFrame(
{
"id": "question",
"question": r_question,
"embedding": [embeddings_model.embed_query(r_question)],
}
)
answer_df = pd.DataFrame(
{
"id": "answer",
"answer": r_answer,
"embedding": [embeddings_model.embed_query(r_answer)],
}
)
df = pd.concat([question_df, answer_df, df], ignore_index=True)
df = update_df(df)
fig = plot_embeddings(df)
return send_plot_as_image(fig)
@app.route("/cluster_plot")
def cluster_plot():
n_clusters = 10
df = start_df()
df = update_df(df)
X = df[["UMAP_x", "UMAP_y"]]
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
df["cluster"] = kmeans.fit_predict(X)
plt.figure(figsize=(10, 8))
sns.scatterplot(
x="UMAP_x",
y="UMAP_y",
hue="cluster",
data=df,
palette="viridis",
s=100,
alpha=0.8,
)
plt.title("Clusters based on UMAP embeddings")
img_bytes = BytesIO()
plt.savefig(img_bytes, format="png")
img_bytes.seek(0)
plt.close()
return send_file(img_bytes, mimetype="image/png", as_attachment=False)
def send_plot_as_image(fig):
img_bytes = fig.to_image(format="png")
return send_file(BytesIO(img_bytes), mimetype="image/png")
@app.route('/rate_readme', methods=['POST'])
def rate_readme():
global faithfulness
global relevancy
faithfulness_score = float(faithfulness) if faithfulness else None
relevancy_score = float(relevancy) if relevancy else None
try:
data = request.json
rating = data['rating']
feedback = data.get('feedback', '') # Get the feedback string, default to empty string if not provided
readme_ratings.setdefault(local_directory, []).append({'rating': rating, 'feedback': feedback, 'faithfulness' : faithfulness_score, 'answer_Relevancy' : relevancy_score})
logger.info(readme_ratings)
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
try:
ratings_filename = os.path.join(local_directory, "ratings.json")
rating_dict={"Rating": rating, "Feedback": feedback, "Faithfulness": faithfulness_score, "Answer_Relevancy": relevancy_score, "Timestamp": timestamp}
ratings_list.append(rating_dict)
with open(ratings_filename, 'w') as file:
json.dump(ratings_list, file, indent=4)
logger.info(f"Rating: {rating}, Feedback: {feedback}, Faithfulness: {faithfulness_score}, Answer_Relevancy: {relevancy_score}, Timestamp: {timestamp}\n")
except Exception as e:
logger.exception("An error occurred")
return jsonify({'success': True})
except Exception as e:
logger.info("Exception:", e)
return jsonify({'success': False, 'error': str(e)})
def app_run():
app.run(host="0.0.0.0", port=8062, debug=False)
import grpc
import rag_pb2
import rag_pb2_grpc
def run():
# Dummy implementation
with grpc.insecure_channel("localhost:8061") as channel:
stub = rag_pb2_grpc.RAGServiceStub(channel)
responses = stub.instruct_llm_query(rag_pb2.Empty())
for response in responses:
print(f"Received LLM Query: {response.input.user}")
print(f"Question: {response.qa.question}")
if __name__ == "__main__":
run()
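# A fuller (still hypothetical) client sketch: the downstream LLM node consumes the query
# stream and pushes answers back through retrieve_llm_response. The placeholder answer text
# below stands in for the actual LLM call; relevant_context is the retrieved-context field
# declared in rag.proto.
def answer_queries():
    with grpc.insecure_channel("localhost:8061") as channel:
        stub = rag_pb2_grpc.RAGServiceStub(channel)

        def answer_iterator():
            for query in stub.instruct_llm_query(rag_pb2.Empty()):
                answer_text = f"(placeholder answer to: {query.qa.question})"
                yield rag_pb2.LLMAnswer(
                    text=answer_text,
                    id=query.id,
                    relevant_context="",
                )

        status = stub.retrieve_llm_response(answer_iterator())
        print(f"Server status: {status.message}")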
RAG-pipelines/RAG-Node/image.png


syntax = "proto3";
message Empty {
}
message LLMConfig {
double temp = 1;
}
message PromptInput {
string system = 1;
string user = 2;
string context = 3;
string prompt = 4;
}
message UserQuestion {
string question = 1;
}
message ConvoID {
string q_id = 1;
}
message LLMQuery {
LLMConfig config = 1;
PromptInput input = 2;
UserQuestion qa = 3;
ConvoID id = 4;
}
message LLMAnswer {
string text = 1;
ConvoID id = 2;
string relevant_context = 3;
}
message Status {
string message = 1;
}
service RAGService {
rpc instruct_llm_query(Empty) returns (stream LLMQuery);
rpc retrieve_llm_response(stream LLMAnswer) returns (Status);
}
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: rag.proto
# Protobuf Python Version: 5.26.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n\trag.proto"\x07\n\x05\x45mpty"\x19\n\tLLMConfig\x12\x0c\n\x04temp\x18\x01 \x01(\x01"L\n\x0bPromptInput\x12\x0e\n\x06system\x18\x01 \x01(\t\x12\x0c\n\x04user\x18\x02 \x01(\t\x12\x0f\n\x07\x63ontext\x18\x03 \x01(\t\x12\x0e\n\x06prompt\x18\x04 \x01(\t" \n\x0cUserQuestion\x12\x10\n\x08question\x18\x01 \x01(\t"\x17\n\x07\x43onvoID\x12\x0c\n\x04q_id\x18\x01 \x01(\t"t\n\x08LLMQuery\x12\x1a\n\x06\x63onfig\x18\x01 \x01(\x0b\x32\n.LLMConfig\x12\x1b\n\x05input\x18\x02 \x01(\x0b\x32\x0c.PromptInput\x12\x19\n\x02qa\x18\x03 \x01(\x0b\x32\r.UserQuestion\x12\x14\n\x02id\x18\x04 \x01(\x0b\x32\x08.ConvoID"/\n\tLLMAnswer\x12\x0c\n\x04text\x18\x01 \x01(\t\x12\x14\n\x02id\x18\x02 \x01(\x0b\x32\x08.ConvoID"\x19\n\x06Status\x12\x0f\n\x07message\x18\x01 \x01(\t2g\n\nRAGService\x12)\n\x12instruct_llm_query\x12\x06.Empty\x1a\t.LLMQuery0\x01\x12.\n\x15retrieve_llm_response\x12\n.LLMAnswer\x1a\x07.Status(\x01\x62\x06proto3'
)
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "rag_pb2", _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals["_EMPTY"]._serialized_start = 13
_globals["_EMPTY"]._serialized_end = 20
_globals["_LLMCONFIG"]._serialized_start = 22
_globals["_LLMCONFIG"]._serialized_end = 47
_globals["_PROMPTINPUT"]._serialized_start = 49
_globals["_PROMPTINPUT"]._serialized_end = 125
_globals["_USERQUESTION"]._serialized_start = 127
_globals["_USERQUESTION"]._serialized_end = 159
_globals["_CONVOID"]._serialized_start = 161
_globals["_CONVOID"]._serialized_end = 184
_globals["_LLMQUERY"]._serialized_start = 186
_globals["_LLMQUERY"]._serialized_end = 302
_globals["_LLMANSWER"]._serialized_start = 304
_globals["_LLMANSWER"]._serialized_end = 351
_globals["_STATUS"]._serialized_start = 353
_globals["_STATUS"]._serialized_end = 378
_globals["_RAGSERVICE"]._serialized_start = 380
_globals["_RAGSERVICE"]._serialized_end = 483
# @@protoc_insertion_point(module_scope)
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import warnings
import rag_pb2 as rag__pb2
GRPC_GENERATED_VERSION = "1.65.5"
GRPC_VERSION = grpc.__version__
EXPECTED_ERROR_RELEASE = "1.66.0"
SCHEDULED_RELEASE_DATE = "August 6, 2024"
_version_not_supported = False
try:
from grpc._utilities import first_version_is_lower
_version_not_supported = first_version_is_lower(
GRPC_VERSION, GRPC_GENERATED_VERSION
)
except ImportError:
_version_not_supported = True
if _version_not_supported:
warnings.warn(
f"The grpc package installed is at version {GRPC_VERSION},"
+ f" but the generated code in rag_pb2_grpc.py depends on"
+ f" grpcio>={GRPC_GENERATED_VERSION}."
+ f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}"
+ f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}."
+ f" This warning will become an error in {EXPECTED_ERROR_RELEASE},"
+ f" scheduled for release on {SCHEDULED_RELEASE_DATE}.",
RuntimeWarning,
)
class RAGServiceStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.instruct_llm_query = channel.unary_stream(
"/RAGService/instruct_llm_query",
request_serializer=rag__pb2.Empty.SerializeToString,
response_deserializer=rag__pb2.LLMQuery.FromString,
_registered_method=True,
)
self.retrieve_llm_response = channel.stream_unary(
"/RAGService/retrieve_llm_response",
request_serializer=rag__pb2.LLMAnswer.SerializeToString,
response_deserializer=rag__pb2.Status.FromString,
_registered_method=True,
)
class RAGServiceServicer(object):
"""Missing associated documentation comment in .proto file."""
def instruct_llm_query(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")
def retrieve_llm_response(self, request_iterator, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")
def add_RAGServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
"instruct_llm_query": grpc.unary_stream_rpc_method_handler(
servicer.instruct_llm_query,
request_deserializer=rag__pb2.Empty.FromString,
response_serializer=rag__pb2.LLMQuery.SerializeToString,
),
"retrieve_llm_response": grpc.stream_unary_rpc_method_handler(
servicer.retrieve_llm_response,
request_deserializer=rag__pb2.LLMAnswer.FromString,
response_serializer=rag__pb2.Status.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
"RAGService", rpc_method_handlers
)
server.add_generic_rpc_handlers((generic_handler,))
server.add_registered_method_handlers("RAGService", rpc_method_handlers)
# This class is part of an EXPERIMENTAL API.
class RAGService(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def instruct_llm_query(
request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.unary_stream(
request,
target,
"/RAGService/instruct_llm_query",
rag__pb2.Empty.SerializeToString,
rag__pb2.LLMQuery.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True,
)
@staticmethod
def retrieve_llm_response(
request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None,
):
return grpc.experimental.stream_unary(
request_iterator,
target,
"/RAGService/retrieve_llm_response",
rag__pb2.LLMAnswer.SerializeToString,
rag__pb2.Status.FromString,
options,
channel_credentials,
insecure,
call_credentials,
compression,
wait_for_ready,
timeout,
metadata,
_registered_method=True,
)
import os
import umap
import numpy as np
from langchain.vectorstores import FAISS
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import pandas as pd
# FAISS folder path
local_directory = os.getenv("SHARED_FOLDER_PATH")
faiss_folder = os.path.join(local_directory, "faiss_index")
def process_pdf(file_name, chunk_size=500, chunk_overlap=50):
# Load and process the PDF document
loader = PyPDFLoader(file_name)
raw_document = loader.load()
# Split the document into chunks
splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
chunks = splitter.split_documents(raw_document)
# Create and save FAISS index
embedding_model = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")
faiss_db = FAISS.from_documents(chunks, embedding=embedding_model)
faiss_db.save_local(faiss_folder)
return "FAISS index processed and saved successfully"
def faiss_db():
# Retriever contents can be obtained here
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-base-en-v1.5")
try:
saved_db = FAISS.load_local(
faiss_folder, embeddings, allow_dangerous_deserialization=True
)
retriever = saved_db.as_retriever(
search_type="similarity", search_kwargs={"k": 4}
)
print(retriever)
print(type(retriever))
return saved_db, retriever
except Exception as e:
print(f"Error loading FAISS database: {e}")
return None, None  # keep the two-value shape expected by start_df()
# Start creating the dataframe for visualization
def start_df():
saved_db, retriever = faiss_db()
vs = saved_db.__dict__.get("docstore")
index_list = saved_db.__dict__.get("index_to_docstore_id").values()
doc_cnt = saved_db.index.ntotal
embeddings_vec = saved_db.index.reconstruct_n()
doc_list = list()
for i, docid in enumerate(index_list):
a_doc = vs.search(docid)
doc_list.append(
[docid, a_doc.metadata.get("source"), a_doc.page_content, embeddings_vec[i]]
)
df = pd.DataFrame(doc_list, columns=["id", "metadata", "document", "embedding"])
return df
def update_df(df):
embeddings = np.vstack(df["embedding"].values)
umap_model = umap.UMAP(n_neighbors=15, min_dist=0.1, metric="cosine")
umap_embeddings = umap_model.fit_transform(embeddings)
df["UMAP_x"] = umap_embeddings[:, 0]
df["UMAP_y"] = umap_embeddings[:, 1]
return df
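# Hypothetical local usage of the helpers above (assumes SHARED_FOLDER_PATH is set and that a
# PDF exists at the given path): build the FAISS index, load it back, and prepare the UMAP
# dataframe that the /umap_plot and /cluster_plot routes visualize.
if __name__ == "__main__":
    print(process_pdf("/tmp/example.pdf", chunk_size=500, chunk_overlap=50))
    df = start_df()      # one row per stored chunk: id, metadata, document, embedding
    df = update_df(df)   # adds the UMAP_x / UMAP_y columns used by the plots
    print(df[["id", "UMAP_x", "UMAP_y"]].head())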
grpcio==1.38.0
grpcio-tools==1.38.0
grpc-interceptor
protobuf==3.16.0
Bootstrap-Flask
Flask
Flask-WTF
flask_session
WTForms
requests
python-dotenv
langchain
langchain_community
umap-learn
pandas
matplotlib
sentence-transformers
faiss-cpu
pypdf
numpy
seaborn
scikit-learn
plotly
kaleido
ragas==0.1.7
import grpc
from concurrent import futures
import threading
import time
import rag_pb2
import rag_pb2_grpc
from app import app_run, get_user_question
import json
import os
import uuid
from datasets import Dataset
from ragas.metrics import faithfulness, answer_relevancy
from ragas import evaluate
system_instruction = """
user
You are an assistant for answering questions.
You are given the extracted parts of a long document and a question. Provide a conversational answer.
If you don't know the answer, just say "I do not know." Don't make up an answer.
"""
local_directory = os.getenv("SHARED_FOLDER_PATH")
os.environ["OPENAI_API_KEY"] = "sk-ai-builder-rhzhJNo29KLDuN2gT4JsT3BlbkFJ7Vw7lOks4OqAm4wRw1YZ"
os.environ['OPENAI_ORGANIZATION'] = "org-XSbDbl5S9lfwrFTncFOBVOMx"
class RAGService(rag_pb2_grpc.RAGServiceServicer):
def __init__(self):
super().__init__()
self.responses = {}
self.current_question = None # Store the last processed question
def instruct_llm_query(self, request, context):
while True:
params = get_user_question()
if not params:
time.sleep(0.5)
continue
self.new_question = params[0]
if self.new_question == self.current_question:
time.sleep(0.5)
continue
self.current_question = self.new_question
question_id = str(uuid.uuid4())
input_text = f"You asked: {self.new_question}"
Pinput = rag_pb2.PromptInput(
system="System instruction here",
user=input_text,
context="",
prompt=input_text,
)
QAinput = rag_pb2.UserQuestion(question=self.new_question)
llm_query = rag_pb2.LLMQuery(
id=rag_pb2.ConvoID(q_id=question_id), input=Pinput, qa=QAinput
)
print(llm_query)
yield llm_query
while question_id not in self.responses:
time.sleep(0.1)
def retrieve_llm_response(self, request_iterator, context):
qa_pairs = []
for answer in request_iterator:
question_id = answer.id.q_id
print(f"Received LLMAnswer for ID {question_id}: {answer.text}")
self.responses[question_id] = answer.text
faithfulness, relevancy = self.calculate_metrics(answer.relevant_context, answer.text)
print(f'faithfulness, relevancy: {faithfulness}, {relevancy}')
qa_pairs.append(
{
"question_id": question_id,
"question": self.new_question,
"answer": answer.text,
}
)
with open(os.path.join(local_directory, "results.txt"), "w") as f:
f.write(f"Question ID: {question_id}\n")
f.write(f"Question: {self.new_question}\n")
f.write(f"Answer: {answer.text}\n\n")
with open("metrics.txt", mode="a+") as f:
# for e0, e1, e2, e3, e4, e5 in result:
f.write("|" + str(round(faithfulness,3)) + "|" + str(round(relevancy,3)) + "\n")
f.close()
return rag_pb2.Status(message="All answers received successfully.")
def calculate_metrics(self, relevant_info, rag_output):
data_samples = {
'question': [str(self.new_question)],
'answer': [str(rag_output)],
'contexts' : [[str(relevant_info)]]
}
dataset = Dataset.from_dict(data_samples)
faithfulness_score = evaluate(dataset,metrics=[faithfulness])
answer_relevancy_score = evaluate(dataset,metrics=[answer_relevancy])
return faithfulness_score["faithfulness"], answer_relevancy_score["answer_relevancy"]
def serve(port):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
rag_pb2_grpc.add_RAGServiceServicer_to_server(RAGService(), server)
server.add_insecure_port(f"[::]:{port}")
print(f"Starting server. Listening on port: {port}")
server.start()
threading.Thread(target=app_run).start()
server.wait_for_termination()
if __name__ == "__main__":
port = 8061
open('metrics.txt', 'w').close()
serve(port)
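# For reference, a minimal illustrative sketch of the two files this server shares with the
# Flask UI: results.txt holds the latest question/answer written by retrieve_llm_response, and
# metrics.txt accumulates one "|<faithfulness>|<relevancy>" record per answer. The helper below
# is hypothetical; index() and check_results_file() in app.py do the equivalent parsing.
def read_latest_metrics(path="metrics.txt"):
    fields = [f.strip() for f in open(path).read().split("|") if f.strip()]
    if len(fields) >= 2:
        return float(fields[-2]), float(fields[-1])
    return None, None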
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Visualizations</title>
<link rel="stylesheet" href="styles.css"> <!-- Optional: Link to your CSS file -->
</head>
<body>
<div class="container">
<h1>RAG Visualizations</h1>
<div id="visualizationContent">
<!-- Insert your visualizations here, e.g., charts, graphs, here, the UMAP and cluster plot here -->
<p>Visualization content will be displayed here.</p>
</div>
<a href="index.html" class="button">Back to Main Page</a>
</div>
</body>
</html>
import umap
import plotly.express as px
from scipy.spatial.distance import cdist, euclidean
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
def plot_embeddings(df):
question_coords = df[df["id"] == "question"][["UMAP_x", "UMAP_y"]]
answer_coords = df[df["id"] == "answer"][["UMAP_x", "UMAP_y"]]
if not question_coords.empty and not answer_coords.empty:
question_point = question_coords[["UMAP_x", "UMAP_y"]].values[0]
answer_point = answer_coords[["UMAP_x", "UMAP_y"]].values[0]
df["distance_to_question"] = cdist(df[["UMAP_x", "UMAP_y"]], [question_point])
df["distance_to_answer"] = cdist(df[["UMAP_x", "UMAP_y"]], [answer_point])
max_distance = 0.5
df_filtered = df[
(df["distance_to_question"] < max_distance)
| (df["distance_to_answer"] < max_distance)
]
df_filtered["document"] = df_filtered["document"].fillna("No content available")
df_filtered["hover_text"] = (
df_filtered["document"]
.astype(str)
.apply(lambda x: "<br>".join(x.split("\n")))
)
fig = px.scatter(
df_filtered,
x="UMAP_x",
y="UMAP_y",
color="id",
hover_name="id",
hover_data={"document": False, "hover_text": True},
title="UMAP Projection of Closely Related Documents, Question, and Answer",
labels={"UMAP_x": "UMAP Dimension 1", "UMAP_y": "UMAP Dimension 2"},
color_discrete_map={"question": "red", "answer": "blue"},
)
fig.add_scatter(
x=[question_point[0]],
y=[question_point[1]],
mode="markers+text",
text=["Question"],
marker=dict(color="red", size=12),
textposition="top center",
showlegend=False,
)
fig.add_scatter(
x=[answer_point[0]],
y=[answer_point[1]],
mode="markers+text",
text=["Answer"],
marker=dict(color="blue", size=12),
textposition="top center",
showlegend=False,
)
distance = euclidean(question_point, answer_point)
mid_x = (question_point[0] + answer_point[0]) / 2
mid_y = (question_point[1] + answer_point[1]) / 2
fig.add_annotation(
x=mid_x,
y=mid_y,
text=f"Distance: {distance:.2f}",
showarrow=True,
arrowhead=2,
ax=0,
ay=-40,
font=dict(size=12, color="black"),
align="center",
bgcolor="white",
borderpad=4,
)
fig.update_layout(
legend_title="Legend", legend=dict(x=0.8, y=0.1, traceorder="normal")
)
fig.update_layout(
legend=dict(
orientation="h",
yanchor="bottom",
y=1.02,
xanchor="right",
x=1,
title="Legend",
traceorder="normal",
font=dict(size=12),
bgcolor="rgba(255, 255, 255, 0)",
bordercolor="rgba(0, 0, 0, 0)",
borderwidth=0,
visible=False, # Hide the legend
)
)
# fig.show() is not needed here; the caller renders the figure to PNG via send_plot_as_image()
return fig
def cluster(df):
n_clusters = 10
X = df[["UMAP_x", "UMAP_y"]]
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
df["cluster"] = kmeans.fit_predict(X)
plt.figure(figsize=(10, 8))
sns.scatterplot(
x="UMAP_x",
y="UMAP_y",
hue="cluster",
data=df,
palette="viridis",
s=100,
alpha=0.8,
)
plt.title("Clusters based on UMAP embeddings")
plt.show()