Commit 26128cd4 authored by Danial Hezarkhani

fixed detr databroker for hpc

parent 4cee84cd
1 merge request: !12 Detr hpc merge to main
Showing 12311 additions and 0 deletions
FROM python:3.8
RUN apt-get update -y
RUN pip install --no-cache-dir -U pip \
&& python -m pip install --upgrade build
COPY requirements.txt ./requirements.txt
RUN pip install --no-cache-dir -r requirements.txt
ENV PRODUCTION true
ENV CUSTOM_MODEL_DIR custom_model
ENV DATASET_DIR dataset
ENV ANNOTATION_FILE_NAME _annotations.coco.json
WORKDIR /home/databroker
COPY main.py app.py config.py logger.py s3_func.py model.proto model_pb2.py model_pb2_grpc.py detr_databroker_server.py ./
COPY static ./static
COPY templates ./templates
ENTRYPOINT [ "python","main.py" ]
from flask import Flask, abort, render_template, request, redirect, url_for, Blueprint
#from flask_bootstrap import Bootstrap5
from flask_bootstrap import Bootstrap
from flask_wtf import FlaskForm
from wtforms.validators import DataRequired, ValidationError
from wtforms.fields import StringField, IntegerField, FloatField, SubmitField
import os
from logger import Logger
import zipfile
import s3_func
import shutil
logger = Logger(__name__)
config = None
app = Flask(__name__)
class S3_dataset_config(FlaskForm):
data_endpoint_url = StringField('Endpoint URL', validators=[DataRequired()])
data_s3_bucket = StringField('S3 Bucket', validators=[DataRequired()])
data_s3_key = StringField('S3 Key', validators=[DataRequired()])
data_access_key = StringField('S3 Access Key', validators=[DataRequired()])
data_secret_key = StringField('S3 Secret Key', validators=[DataRequired()])
data_save = SubmitField('Download from S3', render_kw={"onclick": "loading()"})
class S3_model_config(FlaskForm):
endpoint_url = StringField('Endpoint URL', validators=[DataRequired()])
s3_bucket = StringField('S3 Bucket', validators=[DataRequired()])
s3_key = StringField('S3 Key', validators=[DataRequired()])
access_key = StringField('S3 Access Key', validators=[DataRequired()])
secret_key = StringField('S3 Secret Key', validators=[DataRequired()])
save = SubmitField('Download from S3', render_kw={"onclick": "loading()"})
@app.route('/',methods = ['GET', 'POST'])
def view_home():
annotation_file_name = config.annotation_file_name
return render_template('index.html',annotation_file_name=annotation_file_name)
@app.route('/train',methods = ['GET', 'POST'])
def view_train():
s3_dataset_config = S3_dataset_config()
s3_model_config = S3_model_config()
if request.method == 'POST':
if s3_dataset_config.data_save.data and s3_dataset_config.validate_on_submit():
endpoint_url = s3_dataset_config.data_endpoint_url.data
s3_bucket = s3_dataset_config.data_s3_bucket.data
s3_key = s3_dataset_config.data_s3_key.data
access_key = s3_dataset_config.data_access_key.data
secret_key = s3_dataset_config.data_secret_key.data
clear_dir(config.dataset_path)
s3_func.download_dir(endpoint_url,
s3_bucket,
s3_key,
access_key,
secret_key,
config.dataset_path)
change_new_train_flag(True)
if s3_model_config.save.data and s3_model_config.validate_on_submit():
endpoint_url = s3_model_config.endpoint_url.data
s3_bucket = s3_model_config.s3_bucket.data
s3_key = s3_model_config.s3_key.data
access_key = s3_model_config.access_key.data
secret_key = s3_model_config.secret_key.data
clear_dir(config.custom_model_path)
dl_number = s3_func.download_dir(endpoint_url,
s3_bucket,
s3_key,
access_key,
secret_key,
config.custom_model_path)
change_new_train_flag(True)
return render_template('index_train.html',s3_dataset_config=s3_dataset_config, s3_model_config=s3_model_config)
@app.route('/inference',methods = ['GET', 'POST'])
def view_inference():
s3_model_config = S3_model_config()
if request.method == 'POST':
if s3_model_config.save.data and s3_model_config.validate_on_submit():
endpoint_url = s3_model_config.endpoint_url.data
s3_bucket = s3_model_config.s3_bucket.data
s3_key = s3_model_config.s3_key.data
access_key = s3_model_config.access_key.data
secret_key = s3_model_config.secret_key.data
dl_number = s3_func.download_dir(endpoint_url,
s3_bucket,
s3_key,
access_key,
secret_key,
config.custom_model_path)
change_new_inference_flag(True)
return render_template('index_inference.html', s3_model_config=s3_model_config)
@app.route('/uploadmodel',methods = ['POST'])
def upload_model():
try:
if request.files:
logger.debug("A file being uploaded")
clear_dir(config.custom_model_path)
unzip_file(request.files["directory"], config.custom_model_path)
change_new_inference_flag(True)
change_new_train_flag(True)
else:
logger.debug("request was empty. returning to home")
except Exception as e:
logger.error(e, exc_info=True)
logger.error("Upload failed")
return redirect('/')
@app.route('/uploaddataset',methods = ['POST'])
def upload_dataset():
try:
if request.files:
logger.debug("A file being uploaded")
clear_dir(config.dataset_path)
unzip_file(request.files["directory"], config.dataset_path)
change_new_train_flag(True)
else:
logger.debug("request was empty. returning to home")
except Exception as e:
logger.error(e, exc_info=True)
logger.error("Upload failed")
return redirect('/')
@app.route('/uploadimage',methods = ['POST'])
def upload_image():
try:
if request.files:
file = request.files['file']
if file.filename == '':
return "No selected file"
parts = file.filename.split('.')
# Take the last part (the file extension after the final dot)
file_format = parts[-1]
inference_image_path = f"{config.input_path}/image.{file_format}"
with open(config.status_folder+"/image_path.txt", mode="w") as f:
f.write(inference_image_path)
file.save(inference_image_path)
logger.info("Image was uploaded.")
change_new_inference_flag(True)
else:
logger.info("request was empty. returning to home")
except Exception as e:
logger.error(e, exc_info=True)
logger.error("Upload failed")
return redirect('/')
def get_image_path():
file_path = config.status_folder+"/image_path.txt"
path_line = None
if os.path.exists(file_path):
with open(file_path, mode="r") as f:
for line in f.readlines():
print(line)
path_line = line
return path_line
def change_new_inference_flag(flag:bool):
path_file = config.status_folder+"/inference_startable.txt"
if flag:
print("adding inference flag")
open(path_file, 'w').close()
else:
print("removing inference flag")
if os.path.exists(path_file):
os.remove(path_file)
def change_new_train_flag(flag:bool):
path_file = config.status_folder+"/train_startable.txt"
if flag:
print("adding train flag")
open(path_file, 'w').close()
else:
print("removing train flag")
if os.path.exists(path_file):
os.remove(path_file)
def is_training_startable():
if os.path.exists(config.status_folder+"/train_startable.txt"):
return True
return False
def is_inference_startable():
if os.path.exists(config.status_folder+"/inference_startable.txt"):
return True
return False
def unzip_file(file_name, path):
"""Unzips a file into a directory."""
with zipfile.ZipFile(file_name, "r") as archive:
archive.extractall(path)
def clear_dir(directory_path):
try:
# Remove the directory and its contents
shutil.rmtree(directory_path)
# Recreate the directory
os.makedirs(directory_path)
logger.info(f"Directory {directory_path} removed and recreated successfully.")
except Exception as e:
logger.error(f"An error occurred: {str(e)}")
def web_app(config_main, webui_port):
global config
config = config_main
logger.info("starting webUI")
app.secret_key = "detr"
bootstrap = Bootstrap(app)
app.run(host="0.0.0.0", port=webui_port)
import os
class Config:
def __init__(self, PROD_FLAG, SHARED_FOLDER):
self.PROD_FLAG = PROD_FLAG
self.SHARED_FOLDER = SHARED_FOLDER
# get path to the models saving location
if PROD_FLAG:
self.custom_model_path = os.path.join(SHARED_FOLDER, os.environ['CUSTOM_MODEL_DIR'])
self.dataset_path = os.path.join(SHARED_FOLDER, os.environ['DATASET_DIR'])
self.input_path = os.path.join(SHARED_FOLDER, "inputimage")
self.output_path = os.path.join(SHARED_FOLDER, "outputimage")
self.status_folder = os.path.join(SHARED_FOLDER, "status")
self.annotation_file_name = os.getenv("ANNOTATION_FILE_NAME")
else:
self.custom_model_path = "./dev/custom_model/"
self.dataset_path = "./dev/dataset"
self.input_path = "./dev/inputimage"
self.output_path = "./dev/outputimage"
self.status_folder = "./dev/status"
self.annotation_file_name = "_annotations.coco.json"
self.create_not_existing_dirs([SHARED_FOLDER,self.custom_model_path,self.dataset_path,self.input_path,self.output_path,self.status_folder])
def create_not_existing_dirs(self, list_of_paths):
for path in list_of_paths:
if not os.path.exists(path):
# Create the folder if it does not exist yet
os.makedirs(path)
print(f"Folder '{path}' created successfully.")
import grpc
from timeit import default_timer as timer
import logging
# import the generated classes
import model_pb2
import model_pb2_grpc
port = 8061
def run():
print("Calling HPP_Stub..")
with grpc.insecure_channel('172.17.0.2:{}'.format(port)) as channel:
stub = model_pb2_grpc.DetrDatabrokerStub(channel)
ui_request = model_pb2.Empty()
# inference
#response = stub.send_detect_signal(ui_request)
# training
response = stub.send_training_signal(ui_request)
print("Greeter client received: ")
print(response)
if __name__ == '__main__':
logging.basicConfig()
run()
import grpc
import model_pb2
import model_pb2_grpc
from concurrent import futures
import threading
import os
import app as webui
from logger import Logger
logger = Logger(__name__)
class DetrDataBrokerServicer(model_pb2_grpc.DetrDatabrokerServicer):
def __init__(self,config):
self.config = config
logger.info("server started")
def send_training_signal(self, request, context):
logger.info("Preparing training signal to be sent to the model.")
out_message = model_pb2.TrainingConfig()
# checking if training message is already being sent
if not webui.is_training_startable():
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details("No new data for training has been detected.")
logger.info("Training not startable. Not sending training signal to the model.")
return out_message
training_status = "starting"
out_message.config=training_status
logger.info("Sending training signal to the model.")
webui.change_new_train_flag(False)
return out_message
def send_detect_signal(self, request, context):
logger.info("Preparing inference signal to be sent to the model.")
image_path = webui.get_image_path()
out_message = model_pb2.ObjectDetectionInputFile()
# check if the request is already sent
if not webui.is_inference_startable():
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details("No new data for inference has been detected.")
logger.info("Inference not startable. Not sending inference signal to the model.")
return out_message
if not image_path:
logger.error("No image to be found under the path")
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details("No image to be found under the path")
return out_message
if os.path.exists(image_path):
out_message.path = image_path
logger.info("Sending inference signal to the model.")
webui.change_new_inference_flag(False)
return out_message
else:
logger.error("Inference image is not saved and cannot be found. No request was sent.")
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details("Inference image is not saved.")
return out_message
def serve_grpc(config,port):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
model_pb2_grpc.add_DetrDatabrokerServicer_to_server(DetrDataBrokerServicer(config), server)
server.add_insecure_port('[::]:{}'.format(port))
server.start()
return server
image
%% Cell type:code id: tags:
```
import boto3
import os
SHARED_FOLDER = os.getenv("SHARED_FOLDER_PATH", "./local_shared_folder")
dataset_path = os.path.join(SHARED_FOLDER, os.environ['DATASET_DIR'])
custom_model_path = os.path.join(SHARED_FOLDER, os.environ['CUSTOM_MODEL_DIR'])
endpoint_url = 'http://nm-baios-gpu-01.nm-baios.iais.fraunhofer.de:9000'
bucket = "selma-h2020"
key = "speech_recognition/HF_data/dataset_info.json"
access_key= ''
secret_key= ''
s3config = {
"endpoint_url": endpoint_url,
"aws_access_key_id": access_key,
"aws_secret_access_key": secret_key
}
# Initializing S3.ServiceResource object - http://boto3.readthedocs.io/en/latest/reference/services/s3.html#service-resource
s3resource = boto3.resource("s3", **s3config)
# Initializing S3.Client object - http://boto3.readthedocs.io/en/latest/reference/services/s3.html#client
s3client = boto3.client("s3", **s3config)
bucket_resource = s3resource.Bucket(bucket)
for my_bucket_object in bucket_resource.objects.all():
print(my_bucket_object.key)
objs = list(bucket_resource.objects.filter(Prefix=key))
print(objs)
for obj in objs:
#print(obj.key)
out_name = obj.key.split('/')[-1]
bucket_resource.download_file(obj.key, out_name)
def download_files(s3_dir_prefix, download_dir):
paginator = s3client.get_paginator('list_objects_v2')
# List all items that start with the prefix
for page in paginator.paginate(Bucket=bucket, Prefix=s3_dir_prefix):
for contents in page.get("Contents", []):
file_key = contents["Key"]
# Mirror the key's directory structure under download_dir
save_path = os.path.join(download_dir, file_key)
os.makedirs(os.path.dirname(save_path), exist_ok=True)
s3client.download_file(Filename=save_path, Bucket=bucket, Key=file_key)
```
{
"$schema": "https://raw.githubusercontent.com/acumos/license-manager/master/license-manager-client-library/src/main/resources/schema/1.0.0/license-profile.json",
"keyword": "Apache-2.0",
"licenseName": "Apache License 2.0",
"copyright": {
"year": 2019,
"company": "Company A",
"suffix": "All Rights Reserved"
},
"softwareType": "Machine Learning Model",
"companyName": "Company A",
"contact": {
"name": "Company A Team Member",
"URL": "http://companya.com",
"email": "support@companya.com"
},
"rtuRequired": false
}
import logging
class Logger(logging.Logger):
def __init__(self, name):
super().__init__(name)
formatter = logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s')
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
console_handler.setFormatter(formatter)
self.addHandler(console_handler)
import os
from config import Config
import detr_databroker_server
from app import web_app
import argparse
import threading
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('-gp','--grpc-port', type=int , help='grpc port',required=True)
parser.add_argument('-hp','--http-port', type=int , help='http port',required=True)
parser.add_argument('-f','--folder', type=str , help='Folder path',required=True)
args = parser.parse_args()
# run_inference(args.input, args.output)
prod_flag = os.getenv("PRODUCTION", "false").lower() == "true"
config = Config(prod_flag, args.folder)
grpc_server = detr_databroker_server.serve_grpc(config, args.grpc_port)
threading.Thread(target=web_app, args=(config, args.http_port)).start()
grpc_server.wait_for_termination()
syntax = "proto3";
message Empty {
}
message DatasetFeatues{
string datasetname = 1;
string description = 2;
string size = 3;
string DOI_ID = 4;
}
message ObjectDetectionInputFile {
string path = 1;
}
message TrainingConfig {
string config = 1;
}
service DetrDatabroker {
rpc send_detect_signal(Empty) returns (ObjectDetectionInputFile);
rpc send_training_signal(Empty) returns (TrainingConfig);
}
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: model.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0bmodel.proto\"\x07\n\x05\x45mpty\"X\n\x0e\x44\x61tasetFeatues\x12\x13\n\x0b\x64\x61tasetname\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65scription\x18\x02 \x01(\t\x12\x0c\n\x04size\x18\x03 \x01(\t\x12\x0e\n\x06\x44OI_ID\x18\x04 \x01(\t\"(\n\x18ObjectDetectionInputFile\x12\x0c\n\x04path\x18\x01 \x01(\t\" \n\x0eTrainingConfig\x12\x0e\n\x06\x63onfig\x18\x01 \x01(\t2\xab\x01\n\x0e\x44\x65trDatabroker\x12\x37\n\x12send_detect_signal\x12\x06.Empty\x1a\x19.ObjectDetectionInputFile\x12/\n\x14get_dataset_metadata\x12\x06.Empty\x1a\x0f.DatasetFeatues\x12/\n\x14send_training_signal\x12\x06.Empty\x1a\x0f.TrainingConfigb\x06proto3')
_EMPTY = DESCRIPTOR.message_types_by_name['Empty']
_DATASETFEATUES = DESCRIPTOR.message_types_by_name['DatasetFeatues']
_OBJECTDETECTIONINPUTFILE = DESCRIPTOR.message_types_by_name['ObjectDetectionInputFile']
_TRAININGCONFIG = DESCRIPTOR.message_types_by_name['TrainingConfig']
Empty = _reflection.GeneratedProtocolMessageType('Empty', (_message.Message,), {
'DESCRIPTOR' : _EMPTY,
'__module__' : 'model_pb2'
# @@protoc_insertion_point(class_scope:Empty)
})
_sym_db.RegisterMessage(Empty)
DatasetFeatues = _reflection.GeneratedProtocolMessageType('DatasetFeatues', (_message.Message,), {
'DESCRIPTOR' : _DATASETFEATUES,
'__module__' : 'model_pb2'
# @@protoc_insertion_point(class_scope:DatasetFeatues)
})
_sym_db.RegisterMessage(DatasetFeatues)
ObjectDetectionInputFile = _reflection.GeneratedProtocolMessageType('ObjectDetectionInputFile', (_message.Message,), {
'DESCRIPTOR' : _OBJECTDETECTIONINPUTFILE,
'__module__' : 'model_pb2'
# @@protoc_insertion_point(class_scope:ObjectDetectionInputFile)
})
_sym_db.RegisterMessage(ObjectDetectionInputFile)
TrainingConfig = _reflection.GeneratedProtocolMessageType('TrainingConfig', (_message.Message,), {
'DESCRIPTOR' : _TRAININGCONFIG,
'__module__' : 'model_pb2'
# @@protoc_insertion_point(class_scope:TrainingConfig)
})
_sym_db.RegisterMessage(TrainingConfig)
_DETRDATABROKER = DESCRIPTOR.services_by_name['DetrDatabroker']
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_EMPTY._serialized_start=15
_EMPTY._serialized_end=22
_DATASETFEATUES._serialized_start=24
_DATASETFEATUES._serialized_end=112
_OBJECTDETECTIONINPUTFILE._serialized_start=114
_OBJECTDETECTIONINPUTFILE._serialized_end=154
_TRAININGCONFIG._serialized_start=156
_TRAININGCONFIG._serialized_end=188
_DETRDATABROKER._serialized_start=191
_DETRDATABROKER._serialized_end=362
# @@protoc_insertion_point(module_scope)
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc
import model_pb2 as model__pb2
class DetrDatabrokerStub(object):
"""Missing associated documentation comment in .proto file."""
def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.send_detect_signal = channel.unary_unary(
'/DetrDatabroker/send_detect_signal',
request_serializer=model__pb2.Empty.SerializeToString,
response_deserializer=model__pb2.ObjectDetectionInputFile.FromString,
)
self.get_dataset_metadata = channel.unary_unary(
'/DetrDatabroker/get_dataset_metadata',
request_serializer=model__pb2.Empty.SerializeToString,
response_deserializer=model__pb2.DatasetFeatues.FromString,
)
self.send_training_signal = channel.unary_unary(
'/DetrDatabroker/send_training_signal',
request_serializer=model__pb2.Empty.SerializeToString,
response_deserializer=model__pb2.TrainingConfig.FromString,
)
class DetrDatabrokerServicer(object):
"""Missing associated documentation comment in .proto file."""
def send_detect_signal(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def get_dataset_metadata(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def send_training_signal(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')
def add_DetrDatabrokerServicer_to_server(servicer, server):
rpc_method_handlers = {
'send_detect_signal': grpc.unary_unary_rpc_method_handler(
servicer.send_detect_signal,
request_deserializer=model__pb2.Empty.FromString,
response_serializer=model__pb2.ObjectDetectionInputFile.SerializeToString,
),
'get_dataset_metadata': grpc.unary_unary_rpc_method_handler(
servicer.get_dataset_metadata,
request_deserializer=model__pb2.Empty.FromString,
response_serializer=model__pb2.DatasetFeatues.SerializeToString,
),
'send_training_signal': grpc.unary_unary_rpc_method_handler(
servicer.send_training_signal,
request_deserializer=model__pb2.Empty.FromString,
response_serializer=model__pb2.TrainingConfig.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'DetrDatabroker', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))
# This class is part of an EXPERIMENTAL API.
class DetrDatabroker(object):
"""Missing associated documentation comment in .proto file."""
@staticmethod
def send_detect_signal(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/DetrDatabroker/send_detect_signal',
model__pb2.Empty.SerializeToString,
model__pb2.ObjectDetectionInputFile.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def get_dataset_metadata(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/DetrDatabroker/get_dataset_metadata',
model__pb2.Empty.SerializeToString,
model__pb2.DatasetFeatues.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
@staticmethod
def send_training_signal(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/DetrDatabroker/send_training_signal',
model__pb2.Empty.SerializeToString,
model__pb2.TrainingConfig.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
# DETR Object detection
This model uses DETR to detect objects in an image. The model was trained on the COCO dataset and can detect the following classes: ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck",
"boat","traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat",
"dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack",
"umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
"sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard",
"tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
"banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut",
"cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet",
"tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven",
"toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear",
"hair drier", "toothbrush"]
It is also possible to train this model on a custom dataset. Details on how the dataset must be structured are explained in the S3 databroker web UI; a download sketch follows below.
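For reference, the same download that the training page triggers can be reproduced directly with the `s3_func.download_dir` helper shipped in this commit. The sketch below uses placeholder endpoint, bucket, prefix, and credentials; only the function signature and the dev-mode dataset path come from this repository. The dataset itself must be a COCO export whose annotation file is named `_annotations.coco.json`.
```python
import s3_func

# Minimal sketch: all S3 connection details below are placeholders.
downloaded = s3_func.download_dir(
    endpoint_url="http://my-s3-endpoint:9000",   # placeholder endpoint
    bucket="my-bucket",                          # placeholder bucket
    s3_dir_prefix="datasets/my_coco_export/",    # placeholder prefix
    access_key="<access-key>",
    secret_key="<secret-key>",
    download_dir="./dev/dataset",                # dev-mode dataset path from config.py
)
print(f"Downloaded {downloaded} files")
```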
# Changing the code
## Step 1: Apply changes
Apply your changes to the source code.
## Step 2: Generate gRPC classes for Python:
Open a terminal and change to the directory that contains the proto file.
To generate the gRPC classes, we first have to install the needed libraries:
* Install gRPC:
```cmd
python -m pip install grpcio
```
* To install gRPC tools, run:
```commandline
python -m pip install grpcio-tools googleapis-common-protos
```
* Now, run this command:
```commandline
python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. model.proto
```
This command uses the model.proto file to generate the stubs needed to create the client and server.
The generated files are:
model_pb2.py — contains the message classes
* model_pb2.Empty, model_pb2.DatasetFeatues, model_pb2.ObjectDetectionInputFile and model_pb2.TrainingConfig
model_pb2_grpc.py — contains the server and client classes
* model_pb2_grpc.DetrDatabrokerServicer is used by the server
* model_pb2_grpc.DetrDatabrokerStub is used by the client
Tested with Python 3.8.
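To verify that the regenerated stubs still match the running service, a minimal client such as the sketch below can be used. It assumes the databroker's gRPC server is reachable on localhost:8061; adjust the address to your deployment.
```python
import grpc

import model_pb2
import model_pb2_grpc

# Assumption: the databroker gRPC server is listening on localhost:8061.
with grpc.insecure_channel("localhost:8061") as channel:
    stub = model_pb2_grpc.DetrDatabrokerStub(channel)
    try:
        response = stub.send_training_signal(model_pb2.Empty())
        print("Training signal accepted:", response.config)
    except grpc.RpcError as err:
        # The server answers NOT_FOUND when no new dataset or model has been uploaded yet.
        print("Databroker refused the signal:", err.details())
```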
## Step 3: Build and push the Docker image
```commandline
docker build -t cicd.ai4eu-dev.eu:7444/tutorials/detr_objectdetection/detr_hpc_s3_databroker:latest .
docker push cicd.ai4eu-dev.eu:7444/tutorials/detr_objectdetection/detr_hpc_s3_databroker:latest
docker run cicd.ai4eu-dev.eu:7444/tutorials/detr_objectdetection/detr_hpc_s3_databroker:latest
```
Testing the image:
```commandline
docker run -p 8061:8061 -p 8062:8062 cicd.ai4eu-dev.eu:7444/tutorials/detr_objectdetection/detr_hpc_s3_databroker:latest -gp 8061 -hp 8062 -f /home/run1
docker run -p 8061:8061 -p 8062:8062 -it --entrypoint /bin/bash cicd.ai4eu-dev.eu:7444/tutorials/detr_objectdetection/detr_hpc_s3_databroker:latest
```
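With the first test command, the web UI should be reachable at http://localhost:8062 and the gRPC service at localhost:8061, matching the port mappings above; the small gRPC test client included in this commit (or the sketch from Step 2) can then be pointed at that address to send a training or detection signal.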
bootstrap-flask~=2.3.2
flask~=3.0.0
flask-wtf~=1.2.1
grpcio~=1.48.1
grpcio-tools~=1.48.1
googleapis-common-protos~=1.61.0
boto3~=1.28.84
import boto3
import os
def download_dir(endpoint_url, bucket, s3_dir_prefix, access_key, secret_key, download_dir):
print("download dir is: {}".format(download_dir))
s3config = {
"endpoint_url": endpoint_url,
"aws_access_key_id": access_key,
"aws_secret_access_key": secret_key
}
s3client = boto3.client("s3", **s3config)
paginator = s3client.get_paginator('list_objects_v2')
dl_number = 0
# List all items that start with the prefix
for page in paginator.paginate(Bucket=bucket, Prefix=s3_dir_prefix):
if page["Contents"]:
for contents in page["Contents"]:
file_key = contents["Key"]
saving_rel_folder = file_key
if file_key.startswith(s3_dir_prefix):
saving_rel_folder = file_key.replace(s3_dir_prefix, '.')
download_dir_current = os.path.normpath(os.path.join(download_dir,os.path.dirname(saving_rel_folder)))
if not os.path.exists(download_dir_current):
# If not, create the folder
os.makedirs(download_dir_current)
save_path = os.path.join(download_dir_current,os.path.basename(saving_rel_folder))
s3client.download_file(Filename=save_path,Bucket=bucket,Key=file_key)
dl_number+=1
return dl_number
.grid-container {
display: grid;
grid-column-gap: 50px;
grid-row-gap: 50px;
grid-template-columns: auto auto;
padding: 10px;
}
.grid-item {
padding: 20px;
font-size: 30px;
text-align: left;
}
This diff is collapsed.
detr_object_detection/hpc/detr_hpc_s3_databroker/static/img/news-trainer.png (16.7 KiB)