Skip to content
Snippets Groups Projects
Commit fe035a77 authored by Danial Hezarkhani's avatar Danial Hezarkhani
Browse files

added bugfix for detr hpc. removed inference method from it.

parent 1be959e3
No related branches found
No related tags found
1 merge request!12Detr hpc merge to main
Showing
with 27 additions and 223 deletions
......@@ -10,6 +10,7 @@ RUN pip install --no-cache-dir -r requirements.docker.txt
ENV PRODUCTION true
ENV CUSTOM_MODEL_DIR custom_model
ENV DEFAULT_MODEL_DIR default_model
ENV DATASET_DIR dataset
ENV TENSOR_BOARD_LOG_DIR tensorboard_logs/
ENV ANNOTATION_FILE_NAME _annotations.coco.json
......@@ -18,6 +19,7 @@ ENV MAX_NU_METRIC_IMAGES 10
ENV DETECTION_THRESHOLD 0.8
WORKDIR /home/objectdetection
RUN ls
#RUN pip install --no-cache-dir jupyter
......
......@@ -8,6 +8,7 @@ class Config:
# get path to the models saving location
if PROD_FLAG:
self.default_model_path = os.path.join(SHARED_FOLDER, os.environ['DEFAULT_MODEL_DIR'])
self.custom_model_path = os.path.join(SHARED_FOLDER, os.environ['CUSTOM_MODEL_DIR'])
self.dataset_path = os.path.join(SHARED_FOLDER, os.environ['DATASET_DIR'])
self.lightning_logs_dir = os.path.join(SHARED_FOLDER, os.environ['TENSOR_BOARD_LOG_DIR'])
......@@ -34,7 +35,7 @@ class Config:
self.TENSORBOARD_FOLDER = os.getenv(self.lightning_logs_dir, "./detr/dev/tensorboard_logs")
self.create_not_existing_dirs([SHARED_FOLDER,self.custom_model_path,self.dataset_path,self.lightning_logs_dir,self.input_path,self.output_path,self.status_folder])
self.create_not_existing_dirs([SHARED_FOLDER,self.custom_model_path,self.default_model_path,self.dataset_path,self.lightning_logs_dir,self.input_path,self.output_path,self.status_folder])
def create_not_existing_dirs(self, list_of_paths):
......
......@@ -26,10 +26,6 @@ def run_inference(input_path, output_path):
ui_request.output = output_path
response = stub.detect_objects(ui_request)
# training
# ui_request = model_pb2.TrainingConfig(config="test")
# response = stub.startTraining(ui_request)
logger.debug("Inference finished")
print(response)
......
......@@ -14,17 +14,7 @@ message TrainStatus {
int32 status = 1 ;
}
message ObjectDetectionInputFile {
string input = 1;
string output = 2;
}
message ObjectDetectionOutputFile {
string path = 1;
}
service DetrModelServicer {
rpc detect_objects(ObjectDetectionInputFile) returns (ObjectDetectionOutputFile);
rpc startTraining(TrainingConfig) returns (TrainStatus);
}
......@@ -14,15 +14,13 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0bmodel.proto\"\x07\n\x05\x45mpty\" \n\x0eTrainingConfig\x12\x0e\n\x06\x63onfig\x18\x01 \x01(\t\"\x1d\n\x0bTrainStatus\x12\x0e\n\x06status\x18\x01 \x01(\x05\"9\n\x18ObjectDetectionInputFile\x12\r\n\x05input\x18\x01 \x01(\t\x12\x0e\n\x06output\x18\x02 \x01(\t\")\n\x19ObjectDetectionOutputFile\x12\x0c\n\x04path\x18\x01 \x01(\t2\x8c\x01\n\x11\x44\x65trModelServicer\x12G\n\x0e\x64\x65tect_objects\x12\x19.ObjectDetectionInputFile\x1a\x1a.ObjectDetectionOutputFile\x12.\n\rstartTraining\x12\x0f.TrainingConfig\x1a\x0c.TrainStatusb\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0bmodel.proto\"\x07\n\x05\x45mpty\" \n\x0eTrainingConfig\x12\x0e\n\x06\x63onfig\x18\x01 \x01(\t\"\x1d\n\x0bTrainStatus\x12\x0e\n\x06status\x18\x01 \x01(\x05\x32\x43\n\x11\x44\x65trModelServicer\x12.\n\rstartTraining\x12\x0f.TrainingConfig\x1a\x0c.TrainStatusb\x06proto3')
_EMPTY = DESCRIPTOR.message_types_by_name['Empty']
_TRAININGCONFIG = DESCRIPTOR.message_types_by_name['TrainingConfig']
_TRAINSTATUS = DESCRIPTOR.message_types_by_name['TrainStatus']
_OBJECTDETECTIONINPUTFILE = DESCRIPTOR.message_types_by_name['ObjectDetectionInputFile']
_OBJECTDETECTIONOUTPUTFILE = DESCRIPTOR.message_types_by_name['ObjectDetectionOutputFile']
Empty = _reflection.GeneratedProtocolMessageType('Empty', (_message.Message,), {
'DESCRIPTOR' : _EMPTY,
'__module__' : 'model_pb2'
......@@ -44,20 +42,6 @@ TrainStatus = _reflection.GeneratedProtocolMessageType('TrainStatus', (_message.
})
_sym_db.RegisterMessage(TrainStatus)
ObjectDetectionInputFile = _reflection.GeneratedProtocolMessageType('ObjectDetectionInputFile', (_message.Message,), {
'DESCRIPTOR' : _OBJECTDETECTIONINPUTFILE,
'__module__' : 'model_pb2'
# @@protoc_insertion_point(class_scope:ObjectDetectionInputFile)
})
_sym_db.RegisterMessage(ObjectDetectionInputFile)
ObjectDetectionOutputFile = _reflection.GeneratedProtocolMessageType('ObjectDetectionOutputFile', (_message.Message,), {
'DESCRIPTOR' : _OBJECTDETECTIONOUTPUTFILE,
'__module__' : 'model_pb2'
# @@protoc_insertion_point(class_scope:ObjectDetectionOutputFile)
})
_sym_db.RegisterMessage(ObjectDetectionOutputFile)
_DETRMODELSERVICER = DESCRIPTOR.services_by_name['DetrModelServicer']
if _descriptor._USE_C_DESCRIPTORS == False:
......@@ -68,10 +52,6 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_TRAININGCONFIG._serialized_end=56
_TRAINSTATUS._serialized_start=58
_TRAINSTATUS._serialized_end=87
_OBJECTDETECTIONINPUTFILE._serialized_start=89
_OBJECTDETECTIONINPUTFILE._serialized_end=146
_OBJECTDETECTIONOUTPUTFILE._serialized_start=148
_OBJECTDETECTIONOUTPUTFILE._serialized_end=189
_DETRMODELSERVICER._serialized_start=192
_DETRMODELSERVICER._serialized_end=332
_DETRMODELSERVICER._serialized_start=89
_DETRMODELSERVICER._serialized_end=156
# @@protoc_insertion_point(module_scope)
......@@ -14,11 +14,6 @@ class DetrModelServicerStub(object):
Args:
channel: A grpc.Channel.
"""
self.detect_objects = channel.unary_unary(
'/DetrModelServicer/detect_objects',
request_serializer=model__pb2.ObjectDetectionInputFile.SerializeToString,
response_deserializer=model__pb2.ObjectDetectionOutputFile.FromString,
)
self.startTraining = channel.unary_unary(
'/DetrModelServicer/startTraining',
request_serializer=model__pb2.TrainingConfig.SerializeToString,
......@@ -29,12 +24,6 @@ class DetrModelServicerStub(object):
class DetrModelServicerServicer(object):
"""Missing associated documentation comment in .proto file."""
def detect_objects(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def startTraining(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
......@@ -44,11 +33,6 @@ class DetrModelServicerServicer(object):
def add_DetrModelServicerServicer_to_server(servicer, server):
rpc_method_handlers = {
'detect_objects': grpc.unary_unary_rpc_method_handler(
servicer.detect_objects,
request_deserializer=model__pb2.ObjectDetectionInputFile.FromString,
response_serializer=model__pb2.ObjectDetectionOutputFile.SerializeToString,
),
'startTraining': grpc.unary_unary_rpc_method_handler(
servicer.startTraining,
request_deserializer=model__pb2.TrainingConfig.FromString,
......@@ -64,23 +48,6 @@ def add_DetrModelServicerServicer_to_server(servicer, server):
class DetrModelServicer(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def detect_objects(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/DetrModelServicer/detect_objects',
model__pb2.ObjectDetectionInputFile.SerializeToString,
model__pb2.ObjectDetectionOutputFile.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def startTraining(request,
target,
......
......@@ -23,14 +23,15 @@ class DetrOD():
self.config = config
# get path to folder with scripts
self.current_dir_path = os.path.dirname(os.path.abspath(__file__))
self.default_model_path = os.path.join(self.current_dir_path, "default_model")
#self.default_model_path = os.path.join(self.current_dir_path, "default_model")
self.default_model_path = config.default_model_path
model_path = self.default_model_path
extractor_path = self.default_model_path
try:
# loading default model & extractor on the first stop always
self.model = DetrForObjectDetection.from_pretrained(model_path, local_files_only=config.PROD_FLAG)
self.extractor = DetrImageProcessor.from_pretrained(extractor_path, local_files_only=config.PROD_FLAG)
self.model = DetrForObjectDetection.from_pretrained(model_path)
self.extractor = DetrImageProcessor.from_pretrained(extractor_path)
except Exception:
logger.info("Could not find the default model. Downloading it now")
self.download_default_model()
......@@ -38,10 +39,11 @@ class DetrOD():
def download_default_model(self):
checkpoint = "ciasimbaya/ObjectDetection"
model = DetrForObjectDetection.from_pretrained(checkpoint)
extractor = DetrImageProcessor.from_pretrained(checkpoint)
extractor.save_pretrained(self.default_model_path)
model.model.save_pretrained(self.default_model_path)
self.model = DetrForObjectDetection.from_pretrained(checkpoint)
self.extractor = DetrImageProcessor.from_pretrained(checkpoint)
# In HPC files cannot be saved under apptainer folders.
self.extractor.save_pretrained(self.default_model_path)
self.model.model.save_pretrained(self.default_model_path)
def load_custom_model(self):
......
......@@ -15,6 +15,8 @@ ENV ANNOTATION_FILE_NAME _annotations.coco.json
WORKDIR /home/databroker
RUN ls
COPY main.py app.py config.py logger.py s3_func.py model.proto model_pb2.py model_pb2_grpc.py detr_databroker_server.py ./
COPY static ./static
COPY templates ./templates
......
......@@ -4,7 +4,6 @@ import model_pb2_grpc
from concurrent import futures
import threading
import os
import app as webui
from logger import Logger
logger = Logger(__name__)
......@@ -19,48 +18,11 @@ class DetrDataBrokerServicer(model_pb2_grpc.DetrDatabroker):
def send_training_signal(self, request, context):
logger.info("Preparing training signal to be sent to the model.")
out_message = model_pb2.TrainingConfig()
# checking if training message is already being sent
if not webui.is_training_startable():
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details("No new data for training has been detected.")
logger.info("Training not startable. Not sending training signal to the model.")
return out_message
training_status = "starting"
out_message.config=training_status
logger.info("Sending training signal to the model.")
webui.change_new_train_flag(False)
return out_message
def send_detect_signal(self, request, context):
logger.info("Preparing inference signal to be sent to the model.")
image_path = webui.get_image_path()
out_message = model_pb2.ObjectDetectionInputFile()
# check if the request is already sent
if not webui.is_inference_startable():
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details("No new data for inference has been detected.")
logger.info("Inference not startable. Not sending inference signal to the model.")
return out_message
if not image_path:
logger.error("No image to be found under the path")
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details("No image to be found under the path")
return out_message
if os.path.exists(image_path):
out_message.path = image_path
logger.info("Sending inference signal to the model.")
webui.change_new_inference_flag(False)
return out_message
else:
logger.error("Inference image is not saved and cannot be found. No request was sent")
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details("Inference image is not saved.")
def serve_grpc(config,port):
......
......@@ -3,22 +3,10 @@ syntax = "proto3";
message Empty {
}
message DatasetFeatues{
string datasetname = 1;
string description = 2;
string size = 3;
string DOI_ID = 4;
}
message ObjectDetectionInputFile {
string path = 1;
}
message TrainingConfig {
string config = 1;
}
service DetrDatabroker {
rpc send_detect_signal(Empty) returns (ObjectDetectionInputFile);
rpc send_training_signal(Empty) returns (TrainingConfig);
}
\ No newline at end of file
......@@ -14,13 +14,11 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0bmodel.proto\"\x07\n\x05\x45mpty\"X\n\x0e\x44\x61tasetFeatues\x12\x13\n\x0b\x64\x61tasetname\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x0c\n\x04size\x18\x03 \x01(\t\x12\x0e\n\x06\x44OI_ID\x18\x04 \x01(\t\"(\n\x18ObjectDetectionInputFile\x12\x0c\n\x04path\x18\x01 \x01(\t\" \n\x0eTrainingConfig\x12\x0e\n\x06\x63onfig\x18\x01 \x01(\t2\xab\x01\n\x0e\x44\x65trDatabroker\x12\x37\n\x12send_detect_signal\x12\x06.Empty\x1a\x19.ObjectDetectionInputFile\x12/\n\x14get_dataset_metadata\x12\x06.Empty\x1a\x0f.DatasetFeatues\x12/\n\x14send_training_signal\x12\x06.Empty\x1a\x0f.TrainingConfigb\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0bmodel.proto\"\x07\n\x05\x45mpty\" \n\x0eTrainingConfig\x12\x0e\n\x06\x63onfig\x18\x01 \x01(\t2A\n\x0e\x44\x65trDatabroker\x12/\n\x14send_training_signal\x12\x06.Empty\x1a\x0f.TrainingConfigb\x06proto3')
_EMPTY = DESCRIPTOR.message_types_by_name['Empty']
_DATASETFEATUES = DESCRIPTOR.message_types_by_name['DatasetFeatues']
_OBJECTDETECTIONINPUTFILE = DESCRIPTOR.message_types_by_name['ObjectDetectionInputFile']
_TRAININGCONFIG = DESCRIPTOR.message_types_by_name['TrainingConfig']
Empty = _reflection.GeneratedProtocolMessageType('Empty', (_message.Message,), {
'DESCRIPTOR' : _EMPTY,
......@@ -29,20 +27,6 @@ Empty = _reflection.GeneratedProtocolMessageType('Empty', (_message.Message,), {
})
_sym_db.RegisterMessage(Empty)
DatasetFeatues = _reflection.GeneratedProtocolMessageType('DatasetFeatues', (_message.Message,), {
'DESCRIPTOR' : _DATASETFEATUES,
'__module__' : 'model_pb2'
# @@protoc_insertion_point(class_scope:DatasetFeatues)
})
_sym_db.RegisterMessage(DatasetFeatues)
ObjectDetectionInputFile = _reflection.GeneratedProtocolMessageType('ObjectDetectionInputFile', (_message.Message,), {
'DESCRIPTOR' : _OBJECTDETECTIONINPUTFILE,
'__module__' : 'model_pb2'
# @@protoc_insertion_point(class_scope:ObjectDetectionInputFile)
})
_sym_db.RegisterMessage(ObjectDetectionInputFile)
TrainingConfig = _reflection.GeneratedProtocolMessageType('TrainingConfig', (_message.Message,), {
'DESCRIPTOR' : _TRAININGCONFIG,
'__module__' : 'model_pb2'
......@@ -56,12 +40,8 @@ if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_EMPTY._serialized_start=15
_EMPTY._serialized_end=22
_DATASETFEATUES._serialized_start=24
_DATASETFEATUES._serialized_end=112
_OBJECTDETECTIONINPUTFILE._serialized_start=114
_OBJECTDETECTIONINPUTFILE._serialized_end=154
_TRAININGCONFIG._serialized_start=156
_TRAININGCONFIG._serialized_end=188
_DETRDATABROKER._serialized_start=191
_DETRDATABROKER._serialized_end=362
_TRAININGCONFIG._serialized_start=24
_TRAININGCONFIG._serialized_end=56
_DETRDATABROKER._serialized_start=58
_DETRDATABROKER._serialized_end=123
# @@protoc_insertion_point(module_scope)
......@@ -14,16 +14,6 @@ class DetrDatabrokerStub(object):
Args:
channel: A grpc.Channel.
"""
self.send_detect_signal = channel.unary_unary(
'/DetrDatabroker/send_detect_signal',
request_serializer=model__pb2.Empty.SerializeToString,
response_deserializer=model__pb2.ObjectDetectionInputFile.FromString,
)
self.get_dataset_metadata = channel.unary_unary(
'/DetrDatabroker/get_dataset_metadata',
request_serializer=model__pb2.Empty.SerializeToString,
response_deserializer=model__pb2.DatasetFeatues.FromString,
)
self.send_training_signal = channel.unary_unary(
'/DetrDatabroker/send_training_signal',
request_serializer=model__pb2.Empty.SerializeToString,
......@@ -34,18 +24,6 @@ class DetrDatabrokerStub(object):
class DetrDatabrokerServicer(object):
"""Missing associated documentation comment in .proto file."""
def send_detect_signal(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def get_dataset_metadata(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def send_training_signal(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
......@@ -55,16 +33,6 @@ class DetrDatabrokerServicer(object):
def add_DetrDatabrokerServicer_to_server(servicer, server):
rpc_method_handlers = {
'send_detect_signal': grpc.unary_unary_rpc_method_handler(
servicer.send_detect_signal,
request_deserializer=model__pb2.Empty.FromString,
response_serializer=model__pb2.ObjectDetectionInputFile.SerializeToString,
),
'get_dataset_metadata': grpc.unary_unary_rpc_method_handler(
servicer.get_dataset_metadata,
request_deserializer=model__pb2.Empty.FromString,
response_serializer=model__pb2.DatasetFeatues.SerializeToString,
),
'send_training_signal': grpc.unary_unary_rpc_method_handler(
servicer.send_training_signal,
request_deserializer=model__pb2.Empty.FromString,
......@@ -80,40 +48,6 @@ def add_DetrDatabrokerServicer_to_server(servicer, server):
class DetrDatabroker(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def send_detect_signal(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/DetrDatabroker/send_detect_signal',
model__pb2.Empty.SerializeToString,
model__pb2.ObjectDetectionInputFile.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def get_dataset_metadata(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/DetrDatabroker/get_dataset_metadata',
model__pb2.Empty.SerializeToString,
model__pb2.DatasetFeatues.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def send_training_signal(request,
target,
......
......@@ -38,10 +38,10 @@ class DetrOD():
def download_default_model(self):
checkpoint = "ciasimbaya/ObjectDetection"
model = DetrForObjectDetection.from_pretrained(checkpoint)
extractor = DetrImageProcessor.from_pretrained(checkpoint)
extractor.save_pretrained(self.default_model_path)
model.model.save_pretrained(self.default_model_path)
self.model = DetrForObjectDetection.from_pretrained(checkpoint)
self.extractor = DetrImageProcessor.from_pretrained(checkpoint)
self.extractor.save_pretrained(self.default_model_path)
self.model.model.save_pretrained(self.default_model_path)
def load_custom_model(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment