fix(model): checkpoints in h5 format
This commit is contained in:
@ -99,8 +99,8 @@ class Model():
|
||||
|
||||
def save(self, path):
|
||||
logger.debug('Saving model weights to path: %s', path)
|
||||
self.model.save_weights(path + '/checkpoint.hd5')
|
||||
return path + '/checkpoint'
|
||||
self.model.save_weights(path + '/checkpoint.h5', save_format='h5')
|
||||
return path + '/checkpoint.h5'
|
||||
|
||||
def evaluate(self):
|
||||
return self.model.evaluate(
|
||||
|
@ -10,7 +10,7 @@ from model import Model
|
||||
B_params = {
|
||||
'batch_size': tune.grid_search([256]),
|
||||
'layers': tune.grid_search([[512, 512]]),
|
||||
'lr': tune.grid_search([1e-4]),
|
||||
'lr': tune.grid_search([3e-4]),
|
||||
'optimizer': tune.grid_search([tf.keras.optimizers.Adam]),
|
||||
}
|
||||
|
||||
@ -46,8 +46,8 @@ class TuneB(tune.Trainable):
|
||||
return self.model.restore(path)
|
||||
|
||||
A_params = {
|
||||
'batch_size': tune.grid_search([32]),
|
||||
'layers': tune.grid_search([[32, 32]]),
|
||||
'batch_size': tune.grid_search([128]),
|
||||
'layers': tune.grid_search([[64, 64]]),
|
||||
'lr': tune.grid_search([1e-4]),
|
||||
'optimizer': tune.grid_search([tf.keras.optimizers.Adam]),
|
||||
}
|
||||
|
Reference in New Issue
Block a user