tensorflow2
This commit is contained in:
@@ -1,15 +1,16 @@
|
||||
from keras.callbacks import TensorBoard, ModelCheckpoint
|
||||
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
import warnings
|
||||
|
||||
class CustomTensorBoard(TensorBoard):
|
||||
""" to log the loss after each batch
|
||||
"""
|
||||
"""
|
||||
def __init__(self, log_every=1, **kwargs):
|
||||
super(CustomTensorBoard, self).__init__(**kwargs)
|
||||
self.log_every = log_every
|
||||
self.counter = 0
|
||||
|
||||
|
||||
def on_batch_end(self, batch, logs=None):
|
||||
self.counter+=1
|
||||
if self.counter%self.log_every==0:
|
||||
@@ -22,7 +23,7 @@ class CustomTensorBoard(TensorBoard):
|
||||
summary_value.tag = name
|
||||
self.writer.add_summary(summary, self.counter)
|
||||
self.writer.flush()
|
||||
|
||||
|
||||
super(CustomTensorBoard, self).on_batch_end(batch, logs)
|
||||
|
||||
class CustomModelCheckpoint(ModelCheckpoint):
|
||||
@@ -67,4 +68,4 @@ class CustomModelCheckpoint(ModelCheckpoint):
|
||||
else:
|
||||
self.model_to_save.save(filepath, overwrite=True)
|
||||
|
||||
super(CustomModelCheckpoint, self).on_batch_end(epoch, logs)
|
||||
super(CustomModelCheckpoint, self).on_batch_end(epoch, logs)
|
||||
|
||||
Reference in New Issue
Block a user