72 lines
3.1 KiB
Python
Executable File
72 lines
3.1 KiB
Python
Executable File
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint
|
|
import tensorflow as tf
|
|
import numpy as np
|
|
import warnings
|
|
|
|
class CustomTensorBoard(TensorBoard):
    """TensorBoard callback that additionally logs metrics after every N batches.

    Every ``log_every`` batches, each entry of the batch ``logs`` dict
    (except the bookkeeping keys ``'batch'`` and ``'size'``) is written as a
    scalar summary. The step is a running batch counter that is never reset
    between epochs, so batch-level curves do not overlap across epochs.

    NOTE(review): this relies on the TF1-style ``tf.Summary`` protocol buffer
    and on a ``self.writer`` file writer that this class never creates. It is
    expected to come from the parent ``TensorBoard`` (TF 1.x ``tf.keras``).
    Neither exists in TF 2.x, so confirm the TensorFlow version in use.
    """

    def __init__(self, log_every=1, **kwargs):
        """
        Args:
            log_every: write batch-level summaries every ``log_every``
                batches; must be >= 1.
            **kwargs: forwarded unchanged to ``TensorBoard.__init__``.

        Raises:
            ValueError: if ``log_every`` is smaller than 1.
        """
        # Fail fast here instead of raising ZeroDivisionError on the
        # first batch.
        if log_every < 1:
            raise ValueError(
                'log_every must be a positive integer, got %r' % (log_every,))
        super(CustomTensorBoard, self).__init__(**kwargs)
        self.log_every = log_every
        # Total number of batches seen so far; used as the summary step.
        self.counter = 0

    def on_batch_end(self, batch, logs=None):
        """Write the batch metrics (every ``log_every`` batches), then defer to the parent.

        Args:
            batch: batch index, passed through to the parent callback.
            logs: dict of metric name -> value for this batch; may be None.
        """
        logs = logs or {}
        self.counter += 1
        if self.counter % self.log_every == 0:
            # float() accepts NumPy scalars and plain Python numbers alike;
            # the previous .item() call only worked for NumPy scalars.
            values = [
                tf.Summary.Value(tag=name, simple_value=float(value))
                for name, value in logs.items()
                if name not in ('batch', 'size')
            ]
            if values:
                # One summary and one flush per logging step, rather than
                # one of each per metric.
                self.writer.add_summary(tf.Summary(value=values), self.counter)
                self.writer.flush()

        super(CustomTensorBoard, self).on_batch_end(batch, logs)
|
|
|
|
class CustomModelCheckpoint(ModelCheckpoint):
    """ModelCheckpoint that saves the template model, not the multi-GPU model.

    The callback is attached to the model being trained. For example, a
    multi-GPU wrapper. Checkpoints, however, are written from
    ``model_to_save`` rather than ``self.model``.

    Apart from which model is saved, the epoch-end logic follows the usual
    ModelCheckpoint options: save every ``period`` epochs, optionally only
    when the monitored value improves.

    The class relies on attributes initialised by
    ``ModelCheckpoint.__init__``:
    ``filepath``, ``monitor``, ``monitor_op``, ``best``, ``period``,
    ``epochs_since_last_save``, ``save_best_only``, ``save_weights_only``,
    ``verbose``.
    """

    def __init__(self, model_to_save, **kwargs):
        """
        Args:
            model_to_save: model whose ``save`` / ``save_weights`` is called
                whenever a checkpoint is due.
            **kwargs: forwarded unchanged to ``ModelCheckpoint.__init__``.
        """
        super(CustomModelCheckpoint, self).__init__(**kwargs)
        self.model_to_save = model_to_save

    def _save_template_model(self, filepath):
        """Save ``model_to_save`` to ``filepath``; weights only if ``save_weights_only``."""
        if self.save_weights_only:
            self.model_to_save.save_weights(filepath, overwrite=True)
        else:
            self.model_to_save.save(filepath, overwrite=True)

    def on_epoch_end(self, epoch, logs=None):
        """Checkpoint ``model_to_save`` once every ``period`` epochs.

        When ``save_best_only`` is set, the model is saved only if
        ``logs[self.monitor]`` beats ``self.best`` according to
        ``self.monitor_op``. If the monitored value is missing, a
        ``RuntimeWarning`` is issued and nothing is saved.

        The parent's epoch/batch hooks are deliberately not called, because
        they would checkpoint ``self.model`` instead of the template model.

        Args:
            epoch: epoch index as passed by Keras.
            logs: dict of metric name -> value for this epoch; may be None.
        """
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save < self.period:
            return
        self.epochs_since_last_save = 0
        # The +1 gives 1-based epoch numbers in filenames and messages,
        # assuming Keras's 0-based ``epoch`` index.
        filepath = self.filepath.format(epoch=epoch + 1, **logs)

        if not self.save_best_only:
            if self.verbose > 0:
                print('\nEpoch %05d: saving model to %s' % (epoch + 1, filepath))
            self._save_template_model(filepath)
            return

        current = logs.get(self.monitor)
        if current is None:
            warnings.warn('Can save best model only with %s available, '
                          'skipping.' % (self.monitor), RuntimeWarning)
        elif self.monitor_op(current, self.best):
            if self.verbose > 0:
                print('\nEpoch %05d: %s improved from %0.5f to %0.5f,'
                      ' saving model to %s'
                      % (epoch + 1, self.monitor, self.best,
                         current, filepath))
            self.best = current
            self._save_template_model(filepath)
        elif self.verbose > 0:
            print('\nEpoch %05d: %s did not improve from %0.5f' %
                  (epoch + 1, self.monitor, self.best))
|