tensorflow2
This commit is contained in:
44
keras-yolo3-master/train.py
Executable file → Normal file
44
keras-yolo3-master/train.py
Executable file → Normal file
@@ -8,13 +8,16 @@ from voc import parse_voc_annotation
|
||||
from yolo import create_yolov3_model, dummy_loss
|
||||
from generator import BatchGenerator
|
||||
from utils.utils import normalize, evaluate, makedirs
|
||||
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
|
||||
from keras.optimizers import Adam
|
||||
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
|
||||
from tensorflow.keras.optimizers import Adam
|
||||
from callbacks import CustomModelCheckpoint, CustomTensorBoard
|
||||
from utils.multi_gpu_model import multi_gpu_model
|
||||
import tensorflow as tf
|
||||
import keras
|
||||
from keras.models import load_model
|
||||
from tensorflow import keras
|
||||
from tensorflow.keras.models import load_model
|
||||
|
||||
tf.keras.backend.clear_session()
|
||||
tf.config.experimental_run_functions_eagerly(True)
|
||||
|
||||
def create_training_instances(
|
||||
train_annot_folder,
|
||||
@@ -66,28 +69,34 @@ def create_callbacks(saved_weights_name, tensorboard_logs, model_to_save):
|
||||
makedirs(tensorboard_logs)
|
||||
|
||||
early_stop = EarlyStopping(
|
||||
monitor = 'loss',
|
||||
monitor = 'val_loss',
|
||||
min_delta = 0.01,
|
||||
patience = 25,
|
||||
mode = 'min',
|
||||
verbose = 1
|
||||
)
|
||||
checkpoint = CustomModelCheckpoint(
|
||||
"""checkpoint = CustomModelCheckpoint(
|
||||
model_to_save = model_to_save,
|
||||
filepath = saved_weights_name,# + '{epoch:02d}.h5',
|
||||
monitor = 'loss',
|
||||
verbose = 1,
|
||||
save_best_only = True,
|
||||
mode = 'min',
|
||||
period = 1
|
||||
)
|
||||
save_freq = 1
|
||||
)"""
|
||||
checkpoint = ModelCheckpoint(filepath=saved_weights_name,
|
||||
monitor='val_loss',
|
||||
save_best_only=True,
|
||||
save_weights_only=True,
|
||||
verbose=1)
|
||||
|
||||
reduce_on_plateau = ReduceLROnPlateau(
|
||||
monitor = 'loss',
|
||||
monitor = 'val_loss',
|
||||
factor = 0.5,
|
||||
patience = 15,
|
||||
verbose = 1,
|
||||
mode = 'min',
|
||||
epsilon = 0.01,
|
||||
min_delta = 0.01,
|
||||
cooldown = 0,
|
||||
min_lr = 0
|
||||
)
|
||||
@@ -96,7 +105,7 @@ def create_callbacks(saved_weights_name, tensorboard_logs, model_to_save):
|
||||
write_graph = True,
|
||||
write_images = True,
|
||||
)
|
||||
return [early_stop, checkpoint, reduce_on_plateau, tensorboard]
|
||||
return [early_stop, checkpoint, reduce_on_plateau]
|
||||
|
||||
def create_model(
|
||||
nb_class,
|
||||
@@ -245,21 +254,24 @@ def _main_(args):
|
||||
backend = config['model']['backend']
|
||||
)
|
||||
|
||||
|
||||
###############################
|
||||
# Kick off the training
|
||||
###############################
|
||||
callbacks = create_callbacks(config['train']['saved_weights_name'], config['train']['tensorboard_dir'], infer_model)
|
||||
|
||||
train_model.fit_generator(
|
||||
generator = train_generator,
|
||||
train_model.fit(
|
||||
x = train_generator,
|
||||
validation_data = valid_generator,
|
||||
steps_per_epoch = len(train_generator) * config['train']['train_times'],
|
||||
epochs = config['train']['nb_epochs'] + config['train']['warmup_epochs'],
|
||||
verbose = 2 if config['train']['debug'] else 1,
|
||||
callbacks = callbacks,
|
||||
workers = 4,
|
||||
max_queue_size = 8
|
||||
max_queue_size = 8,
|
||||
callbacks = callbacks
|
||||
)
|
||||
|
||||
|
||||
# make a GPU version of infer_model for evaluation
|
||||
if multi_gpu > 1:
|
||||
infer_model = load_model(config['train']['saved_weights_name'])
|
||||
@@ -284,7 +296,7 @@ def _main_(args):
|
||||
return
|
||||
|
||||
print('mAP using the weighted average of precisions among classes: {:.4f}'.format(sum([a * b for a, b in zip(total_instances, precisions)]) / sum(total_instances)))
|
||||
print('mAP: {:.4f}'.format(sum(precisions) / sum(x > 0 for x in total_instances)))
|
||||
print('mAP: {:.4f}'.format(sum(precisions) / sum(x > 0 for x in total_instances)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
argparser = argparse.ArgumentParser(description='train and evaluate YOLO_v3 model on any dataset')
|
||||
|
||||
Reference in New Issue
Block a user