Skip to content
Snippets Groups Projects
Commit 6ec48b84 authored by Swetha Lakshmana Murthy's avatar Swetha Lakshmana Murthy
Browse files

Restructure: agentic workflow code

parent aa3b6f12
No related branches found
No related tags found
No related merge requests found
Showing
with 116 additions and 8 deletions
# Use a slim Python base image # Use a slim Python base image
FROM python:3.10-slim FROM python:3.12-slim
ENV PYTHONUNBUFFERED=1 ENV PYTHONUNBUFFERED=1
......
...@@ -18,7 +18,6 @@ enum AgentType { ...@@ -18,7 +18,6 @@ enum AgentType {
WEATHER = 1; WEATHER = 1;
PLACES = 2; PLACES = 2;
FOOD = 3; FOOD = 3;
MAP = 4;
} }
message AgentResponse { message AgentResponse {
......
grpcio==1.38.0 grpcio
grpcio-tools==1.38.0 grpcio-tools
grpc-interceptor grpc-interceptor
protobuf==3.16.0 protobuf
multithreading multithreading
openai openai
requests requests
......
# Use a slim Python base image # Use a slim Python base image
FROM python:3.10-slim FROM python:3.12-slim
ENV PYTHONUNBUFFERED=1 ENV PYTHONUNBUFFERED=1
......
...@@ -18,7 +18,6 @@ enum AgentType { ...@@ -18,7 +18,6 @@ enum AgentType {
WEATHER = 1; WEATHER = 1;
PLACES = 2; PLACES = 2;
FOOD = 3; FOOD = 3;
MAP = 4;
} }
message AgentResponse { message AgentResponse {
......
grpcio==1.38.0 grpcio
grpcio-tools==1.38.0 grpcio-tools
grpc-interceptor grpc-interceptor
protobuf==3.16.0 protobuf
multithreading multithreading
Flask Flask
torch torch
transformers transformers
\ No newline at end of file hf_xet
\ No newline at end of file
import grpc
from concurrent import futures
import logging
import time
import planner_pb2
import planner_pb2_grpc
from transformers import pipeline
# ------------------ Logger Setup ------------------
def setup_logger():
logger = logging.getLogger("PlannerService")
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
return logger
logger = setup_logger()
# ------------------ Classifier Setup ------------------
def initialize_classifier():
return pipeline(
"zero-shot-classification",
model="facebook/bart-large-mnli",
revision="d7645e1"
)
AGENT_LABELS = ["weather", "places", "food", "map"]
LABEL_TO_AGENT_TYPE = {
"weather": planner_pb2.WEATHER,
"places": planner_pb2.PLACES,
"food": planner_pb2.FOOD
}
def match_agent(text):
try:
result = classifier(text, AGENT_LABELS)
label = result["labels"][0]
logger.info(f"Classifier result: {label}")
return LABEL_TO_AGENT_TYPE.get(label, planner_pb2.UNKNOWN)
except Exception as e:
logger.error(f"Classifier error: {e}")
return planner_pb2.UNKNOWN
# ------------------ gRPC Service Implementation ------------------
class PlannerService(planner_pb2_grpc.PlannerServicer):
def __init__(self):
logger.info("PlannerService initialized.")
self.agent_type_stored = None
def evaluateUserQueriesFromChatbot(self, request_iterator, context):
logger.info("Evaluating user queries from chatbot...")
for user_query in request_iterator:
logger.info(f"User query: {user_query.text}")
agent_type = match_agent(user_query.text)
response = planner_pb2.AgentRequest(
text=user_query.text,
agent_type=agent_type
)
# Store agent type as a readable string
self.agent_type_stored = planner_pb2.AgentType.Name(agent_type)
logger.info(f"Matched agent type: {self.agent_type_stored}")
# Handle unknown queries
if agent_type == planner_pb2.UNKNOWN:
response.text = "Sorry, I couldn't understand the query."
yield response
def processAgentResponsesFromAgent(self, request_iterator, context):
logger.info("Processing responses from agent...")
for response in request_iterator:
logger.info(f"Agent response: {response.text}")
formatted_response = planner_pb2.AgentResponse(
text=f"[**Answering from {self.agent_type_stored} Agent**] \n{response.text}"
)
yield formatted_response
logger.info("Finished processing all agent responses.")
# ------------------ Server Setup ------------------
def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
planner_pb2_grpc.add_PlannerServicer_to_server(PlannerService(), server)
port = 8061
server.add_insecure_port(f"[::]:{port}")
server.start()
logger.info(f"Planner Service is running on port {port}")
server.wait_for_termination()
if __name__ == "__main__":
classifier = initialize_classifier()
serve()
# Use a slim Python base image # Use a slim Python base image
FROM python:3.10-slim FROM python:3.12-slim
ENV PYTHONUNBUFFERED=1 ENV PYTHONUNBUFFERED=1
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment