Skip to content
Snippets Groups Projects
Commit 0757f92f authored by Swetha Lakshmana Murthy's avatar Swetha Lakshmana Murthy
Browse files

Update metrics in training pipeline to support automatic comparability

parent 0d1237a4
Branches
Tags
No related merge requests found
......@@ -9,7 +9,7 @@ shared_folder = '/home/slakshmana/dataset'
def run_classifier():
logging.info("Calling NewsClassifier_Stub...")
with grpc.insecure_channel('localhost:10061') as channel:
with grpc.insecure_channel('localhost:8061') as channel:
classifier_stub = news_classifier_pb2_grpc.NewsClassifierStub(channel)
......
......@@ -11,9 +11,19 @@ message TrainingConfig {
message TrainingStatus {
string type = 1;
double accuracy = 2;
double validation_loss = 3;
string status_text = 4;
message Lessisbetter{
double validation_loss = 1;
}
message Moreisbetter{
double accuracy = 1;
}
string status_text = 2;
Lessisbetter less_is_better = 3;
Moreisbetter more_is_better = 4;
}
message NewsText {
......
......@@ -20,10 +20,6 @@ from wtforms.validators import DataRequired, ValidationError
from wtforms.fields import StringField, IntegerField, FloatField, SubmitField
import datetime
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
port = 8061
shared_folder = os.getenv("SHARED_FOLDER_PATH")
......@@ -89,7 +85,7 @@ class NewsClassifier(news_classifier_pb2_grpc.NewsClassifierServicer):
def __init__(self):
self.has_metrics = True # Flag to indicate the presence of metrics in this node and print a message accordingly.
if self.has_metrics:
logging.info('MetricsAvailable')
print('MetricsAvailable')
self.create_model()
def create_model(self):
......@@ -135,8 +131,8 @@ class NewsClassifier(news_classifier_pb2_grpc.NewsClassifierServicer):
response = news_classifier_pb2.TrainingStatus()
response.type = 'classification-metrics/v1' # Type-in the type of metrics here
response.accuracy = history.history['accuracy'][-1]
response.validation_loss = history.history['val_loss'][-1]
response.more_is_better.accuracy = history.history['accuracy'][-1]
response.less_is_better.validation_loss = history.history['val_loss'][-1]
response.status_text = 'success'
# Call the gRPC routine immediately after the training process concludes and the metrics have been recorded
......@@ -172,13 +168,19 @@ class NewsClassifier(news_classifier_pb2_grpc.NewsClassifierServicer):
for field, value in request.ListFields():
field_name = field.name
final_metrics_dict['metrics'][field_name] = value
if field.message_type is not None:
nested_dict = {}
for nested_field, nested_value in value.ListFields():
nested_dict[nested_field.name] = nested_value
final_metrics_dict['metrics'][field_name] = nested_dict
else:
final_metrics_dict['metrics'][field_name] = value
print(final_metrics_dict)
logging.info(final_metrics_dict)
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO)
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
classifier = NewsClassifier()
news_classifier_pb2_grpc.add_NewsClassifierServicer_to_server(classifier, server)
......
......@@ -3,6 +3,7 @@
# source: news_classifier.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
......@@ -13,214 +14,16 @@ _sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='news_classifier.proto',
package='',
syntax='proto3',
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n\x15news_classifier.proto\"\xa8\x01\n\x0eTrainingConfig\x12\x1e\n\x16training_data_filename\x18\x01 \x01(\t\x12 \n\x18training_labels_filename\x18\x02 \x01(\t\x12\x0e\n\x06\x65pochs\x18\x03 \x01(\x05\x12\x12\n\nbatch_size\x18\x04 \x01(\x05\x12\x18\n\x10validation_ratio\x18\x05 \x01(\x01\x12\x16\n\x0emodel_filename\x18\x06 \x01(\t\"^\n\x0eTrainingStatus\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x10\n\x08\x61\x63\x63uracy\x18\x02 \x01(\x01\x12\x17\n\x0fvalidation_loss\x18\x03 \x01(\x01\x12\x13\n\x0bstatus_text\x18\x04 \x01(\t\"\x18\n\x08NewsText\x12\x0c\n\x04text\x18\x01 \x01(\t\"<\n\x0cNewsCategory\x12\x15\n\rcategory_code\x18\x01 \x01(\x05\x12\x15\n\rcategory_text\x18\x02 \x01(\t2\xa3\x01\n\x0eNewsClassifier\x12\x31\n\rstartTraining\x12\x0f.TrainingConfig\x1a\x0f.TrainingStatus\x12$\n\x08\x63lassify\x12\t.NewsText\x1a\r.NewsCategory\x12\x38\n\x14get_metrics_metadata\x12\x0f.TrainingStatus\x1a\x0f.TrainingStatusb\x06proto3'
)
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15news_classifier.proto\"\xa8\x01\n\x0eTrainingConfig\x12\x1e\n\x16training_data_filename\x18\x01 \x01(\t\x12 \n\x18training_labels_filename\x18\x02 \x01(\t\x12\x0e\n\x06\x65pochs\x18\x03 \x01(\x05\x12\x12\n\nbatch_size\x18\x04 \x01(\x05\x12\x18\n\x10validation_ratio\x18\x05 \x01(\x01\x12\x16\n\x0emodel_filename\x18\x06 \x01(\t\"\xea\x01\n\x0eTrainingStatus\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x13\n\x0bstatus_text\x18\x02 \x01(\t\x12\x34\n\x0eless_is_better\x18\x03 \x01(\x0b\x32\x1c.TrainingStatus.Lessisbetter\x12\x34\n\x0emore_is_better\x18\x04 \x01(\x0b\x32\x1c.TrainingStatus.Moreisbetter\x1a\'\n\x0cLessisbetter\x12\x17\n\x0fvalidation_loss\x18\x01 \x01(\x01\x1a \n\x0cMoreisbetter\x12\x10\n\x08\x61\x63\x63uracy\x18\x01 \x01(\x01\"\x18\n\x08NewsText\x12\x0c\n\x04text\x18\x01 \x01(\t\"<\n\x0cNewsCategory\x12\x15\n\rcategory_code\x18\x01 \x01(\x05\x12\x15\n\rcategory_text\x18\x02 \x01(\t2\xa3\x01\n\x0eNewsClassifier\x12\x31\n\rstartTraining\x12\x0f.TrainingConfig\x1a\x0f.TrainingStatus\x12$\n\x08\x63lassify\x12\t.NewsText\x1a\r.NewsCategory\x12\x38\n\x14get_metrics_metadata\x12\x0f.TrainingStatus\x1a\x0f.TrainingStatusb\x06proto3')
_TRAININGCONFIG = _descriptor.Descriptor(
name='TrainingConfig',
full_name='TrainingConfig',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='training_data_filename', full_name='TrainingConfig.training_data_filename', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='training_labels_filename', full_name='TrainingConfig.training_labels_filename', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='epochs', full_name='TrainingConfig.epochs', index=2,
number=3, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='batch_size', full_name='TrainingConfig.batch_size', index=3,
number=4, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='validation_ratio', full_name='TrainingConfig.validation_ratio', index=4,
number=5, type=1, cpp_type=5, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='model_filename', full_name='TrainingConfig.model_filename', index=5,
number=6, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=26,
serialized_end=194,
)
_TRAININGSTATUS = _descriptor.Descriptor(
name='TrainingStatus',
full_name='TrainingStatus',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='type', full_name='TrainingStatus.type', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='accuracy', full_name='TrainingStatus.accuracy', index=1,
number=2, type=1, cpp_type=5, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='validation_loss', full_name='TrainingStatus.validation_loss', index=2,
number=3, type=1, cpp_type=5, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='status_text', full_name='TrainingStatus.status_text', index=3,
number=4, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=196,
serialized_end=290,
)
_NEWSTEXT = _descriptor.Descriptor(
name='NewsText',
full_name='NewsText',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='text', full_name='NewsText.text', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=292,
serialized_end=316,
)
_NEWSCATEGORY = _descriptor.Descriptor(
name='NewsCategory',
full_name='NewsCategory',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='category_code', full_name='NewsCategory.category_code', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='category_text', full_name='NewsCategory.category_text', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=318,
serialized_end=378,
)
DESCRIPTOR.message_types_by_name['TrainingConfig'] = _TRAININGCONFIG
DESCRIPTOR.message_types_by_name['TrainingStatus'] = _TRAININGSTATUS
DESCRIPTOR.message_types_by_name['NewsText'] = _NEWSTEXT
DESCRIPTOR.message_types_by_name['NewsCategory'] = _NEWSCATEGORY
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
_TRAININGCONFIG = DESCRIPTOR.message_types_by_name['TrainingConfig']
_TRAININGSTATUS = DESCRIPTOR.message_types_by_name['TrainingStatus']
_TRAININGSTATUS_LESSISBETTER = _TRAININGSTATUS.nested_types_by_name['Lessisbetter']
_TRAININGSTATUS_MOREISBETTER = _TRAININGSTATUS.nested_types_by_name['Moreisbetter']
_NEWSTEXT = DESCRIPTOR.message_types_by_name['NewsText']
_NEWSCATEGORY = DESCRIPTOR.message_types_by_name['NewsCategory']
TrainingConfig = _reflection.GeneratedProtocolMessageType('TrainingConfig', (_message.Message,), {
'DESCRIPTOR' : _TRAININGCONFIG,
'__module__' : 'news_classifier_pb2'
......@@ -229,11 +32,27 @@ TrainingConfig = _reflection.GeneratedProtocolMessageType('TrainingConfig', (_me
_sym_db.RegisterMessage(TrainingConfig)
TrainingStatus = _reflection.GeneratedProtocolMessageType('TrainingStatus', (_message.Message,), {
'Lessisbetter' : _reflection.GeneratedProtocolMessageType('Lessisbetter', (_message.Message,), {
'DESCRIPTOR' : _TRAININGSTATUS_LESSISBETTER,
'__module__' : 'news_classifier_pb2'
# @@protoc_insertion_point(class_scope:TrainingStatus.Lessisbetter)
})
,
'Moreisbetter' : _reflection.GeneratedProtocolMessageType('Moreisbetter', (_message.Message,), {
'DESCRIPTOR' : _TRAININGSTATUS_MOREISBETTER,
'__module__' : 'news_classifier_pb2'
# @@protoc_insertion_point(class_scope:TrainingStatus.Moreisbetter)
})
,
'DESCRIPTOR' : _TRAININGSTATUS,
'__module__' : 'news_classifier_pb2'
# @@protoc_insertion_point(class_scope:TrainingStatus)
})
_sym_db.RegisterMessage(TrainingStatus)
_sym_db.RegisterMessage(TrainingStatus.Lessisbetter)
_sym_db.RegisterMessage(TrainingStatus.Moreisbetter)
NewsText = _reflection.GeneratedProtocolMessageType('NewsText', (_message.Message,), {
'DESCRIPTOR' : _NEWSTEXT,
......@@ -249,51 +68,22 @@ NewsCategory = _reflection.GeneratedProtocolMessageType('NewsCategory', (_messag
})
_sym_db.RegisterMessage(NewsCategory)
_NEWSCLASSIFIER = _descriptor.ServiceDescriptor(
name='NewsClassifier',
full_name='NewsClassifier',
file=DESCRIPTOR,
index=0,
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_start=381,
serialized_end=544,
methods=[
_descriptor.MethodDescriptor(
name='startTraining',
full_name='NewsClassifier.startTraining',
index=0,
containing_service=None,
input_type=_TRAININGCONFIG,
output_type=_TRAININGSTATUS,
serialized_options=None,
create_key=_descriptor._internal_create_key,
),
_descriptor.MethodDescriptor(
name='classify',
full_name='NewsClassifier.classify',
index=1,
containing_service=None,
input_type=_NEWSTEXT,
output_type=_NEWSCATEGORY,
serialized_options=None,
create_key=_descriptor._internal_create_key,
),
_descriptor.MethodDescriptor(
name='get_metrics_metadata',
full_name='NewsClassifier.get_metrics_metadata',
index=2,
containing_service=None,
input_type=_TRAININGSTATUS,
output_type=_TRAININGSTATUS,
serialized_options=None,
create_key=_descriptor._internal_create_key,
),
])
_sym_db.RegisterServiceDescriptor(_NEWSCLASSIFIER)
DESCRIPTOR.services_by_name['NewsClassifier'] = _NEWSCLASSIFIER
_NEWSCLASSIFIER = DESCRIPTOR.services_by_name['NewsClassifier']
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
_TRAININGCONFIG._serialized_start=26
_TRAININGCONFIG._serialized_end=194
_TRAININGSTATUS._serialized_start=197
_TRAININGSTATUS._serialized_end=431
_TRAININGSTATUS_LESSISBETTER._serialized_start=358
_TRAININGSTATUS_LESSISBETTER._serialized_end=397
_TRAININGSTATUS_MOREISBETTER._serialized_start=399
_TRAININGSTATUS_MOREISBETTER._serialized_end=431
_NEWSTEXT._serialized_start=433
_NEWSTEXT._serialized_end=457
_NEWSCATEGORY._serialized_start=459
_NEWSCATEGORY._serialized_end=519
_NEWSCLASSIFIER._serialized_start=522
_NEWSCLASSIFIER._serialized_end=685
# @@protoc_insertion_point(module_scope)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment