fix(model): checkpoints in h5 format

This commit is contained in:
Mahdi Dibaiee
2019-04-26 13:14:43 +04:30
parent b377c6dd5f
commit b192531a2a
5 changed files with 12 additions and 12 deletions

View File

@ -99,8 +99,8 @@ class Model():
def save(self, path):
logger.debug('Saving model weights to path: %s', path)
self.model.save_weights(path + '/checkpoint.hd5')
return path + '/checkpoint'
self.model.save_weights(path + '/checkpoint.h5', save_format='h5')
return path + '/checkpoint.h5'
def evaluate(self):
return self.model.evaluate(

View File

@ -10,7 +10,7 @@ from model import Model
B_params = {
'batch_size': tune.grid_search([256]),
'layers': tune.grid_search([[512, 512]]),
'lr': tune.grid_search([1e-4]),
'lr': tune.grid_search([3e-4]),
'optimizer': tune.grid_search([tf.keras.optimizers.Adam]),
}
@ -46,8 +46,8 @@ class TuneB(tune.Trainable):
return self.model.restore(path)
A_params = {
'batch_size': tune.grid_search([32]),
'layers': tune.grid_search([[32, 32]]),
'batch_size': tune.grid_search([128]),
'layers': tune.grid_search([[64, 64]]),
'lr': tune.grid_search([1e-4]),
'optimizer': tune.grid_search([tf.keras.optimizers.Adam]),
}