tensorflow2

This commit is contained in:
Daniel Saavedra
2020-03-25 18:23:00 -03:00
parent 7010af8a58
commit 7cf0c577a1
25 changed files with 1016 additions and 309 deletions

View File

@@ -1,15 +1,16 @@
from keras.callbacks import TensorBoard, ModelCheckpoint
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint
import tensorflow as tf
import numpy as np
import warnings
class CustomTensorBoard(TensorBoard):
""" to log the loss after each batch
"""
"""
def __init__(self, log_every=1, **kwargs):
super(CustomTensorBoard, self).__init__(**kwargs)
self.log_every = log_every
self.counter = 0
def on_batch_end(self, batch, logs=None):
self.counter+=1
if self.counter%self.log_every==0:
@@ -22,7 +23,7 @@ class CustomTensorBoard(TensorBoard):
summary_value.tag = name
self.writer.add_summary(summary, self.counter)
self.writer.flush()
super(CustomTensorBoard, self).on_batch_end(batch, logs)
class CustomModelCheckpoint(ModelCheckpoint):
@@ -67,4 +68,4 @@ class CustomModelCheckpoint(ModelCheckpoint):
else:
self.model_to_save.save(filepath, overwrite=True)
super(CustomModelCheckpoint, self).on_batch_end(epoch, logs)
super(CustomModelCheckpoint, self).on_batch_end(epoch, logs)