fix(model): checkpoints in h5 format

This commit is contained in:
Mahdi Dibaiee 2019-04-26 13:14:43 +04:30
parent b377c6dd5f
commit b192531a2a
5 changed files with 12 additions and 12 deletions

View File

@ -99,8 +99,8 @@ class Model():
def save(self, path): def save(self, path):
logger.debug('Saving model weights to path: %s', path) logger.debug('Saving model weights to path: %s', path)
self.model.save_weights(path + '/checkpoint.hd5') self.model.save_weights(path + '/checkpoint.h5', save_format='h5')
return path + '/checkpoint' return path + '/checkpoint.h5'
def evaluate(self): def evaluate(self):
return self.model.evaluate( return self.model.evaluate(

View File

@ -10,7 +10,7 @@ from model import Model
B_params = { B_params = {
'batch_size': tune.grid_search([256]), 'batch_size': tune.grid_search([256]),
'layers': tune.grid_search([[512, 512]]), 'layers': tune.grid_search([[512, 512]]),
'lr': tune.grid_search([1e-4]), 'lr': tune.grid_search([3e-4]),
'optimizer': tune.grid_search([tf.keras.optimizers.Adam]), 'optimizer': tune.grid_search([tf.keras.optimizers.Adam]),
} }
@ -46,8 +46,8 @@ class TuneB(tune.Trainable):
return self.model.restore(path) return self.model.restore(path)
A_params = { A_params = {
'batch_size': tune.grid_search([32]), 'batch_size': tune.grid_search([128]),
'layers': tune.grid_search([[32, 32]]), 'layers': tune.grid_search([[64, 64]]),
'lr': tune.grid_search([1e-4]), 'lr': tune.grid_search([1e-4]),
'optimizer': tune.grid_search([tf.keras.optimizers.Adam]), 'optimizer': tune.grid_search([tf.keras.optimizers.Adam]),
} }

View File

@ -163,8 +163,8 @@ def continent_agent(ground, position, size):
while True: while True:
# if trials > CONTINENT_MAX_TRIALS: # if trials > CONTINENT_MAX_TRIALS:
# print('couldnt proceed') # print('couldnt proceed')
# if size <= 0 or trials > CONTINENT_MAX_TRIALS: break if size <= 0 or trials > CONTINENT_MAX_TRIALS: break
if size <= 0: break # if size <= 0: break
dx = np.random.randint(2) or -1 dx = np.random.randint(2) or -1
dy = np.random.randint(2) or -1 dy = np.random.randint(2) or -1
@ -184,8 +184,8 @@ def continent_agent(ground, position, size):
trials = 0 trials = 0
size -= 1 size -= 1
ground[x, y] = np.random.randint(1, p['ground_noise']) ground[x, y] = np.random.randint(1, p['ground_noise'])
# else: else:
# trials += 1 trials += 1
def neighbours(ground, position, radius): def neighbours(ground, position, radius):
x, y = position x, y = position

View File

@ -12,7 +12,7 @@ function generate() {
const queryString = new URLSearchParams(formData).toString() const queryString = new URLSearchParams(formData).toString()
map.src = '/map?' + queryString; map.src = '/map?' + queryString;
map.classList.add('d-none'); map.classList.add('d-none');
map.width = formData.get('width');
} }
mapSettings.addEventListener('submit', (e) => { mapSettings.addEventListener('submit', (e) => {

View File

@ -21,7 +21,7 @@
<aside class='col-3 px-4 bg-dark text-light text-center py-3'> <aside class='col-3 px-4 bg-dark text-light text-center py-3'>
<h3>World Map Generator</h3> <h3>World Map Generator</h3>
<div class='panel px-4'> <div class='panel px-4 pb-3'>
<form class='mt-5' id='map-settings'> <form class='mt-5' id='map-settings'>
{% for k, v in parameters.items() %} {% for k, v in parameters.items() %}
<div class='form-group'> <div class='form-group'>