fix(model): checkpoints in h5 format
This commit is contained in:
parent
b377c6dd5f
commit
b192531a2a
@ -99,8 +99,8 @@ class Model():
|
|||||||
|
|
||||||
def save(self, path):
|
def save(self, path):
|
||||||
logger.debug('Saving model weights to path: %s', path)
|
logger.debug('Saving model weights to path: %s', path)
|
||||||
self.model.save_weights(path + '/checkpoint.hd5')
|
self.model.save_weights(path + '/checkpoint.h5', save_format='h5')
|
||||||
return path + '/checkpoint'
|
return path + '/checkpoint.h5'
|
||||||
|
|
||||||
def evaluate(self):
|
def evaluate(self):
|
||||||
return self.model.evaluate(
|
return self.model.evaluate(
|
||||||
|
@ -10,7 +10,7 @@ from model import Model
|
|||||||
B_params = {
|
B_params = {
|
||||||
'batch_size': tune.grid_search([256]),
|
'batch_size': tune.grid_search([256]),
|
||||||
'layers': tune.grid_search([[512, 512]]),
|
'layers': tune.grid_search([[512, 512]]),
|
||||||
'lr': tune.grid_search([1e-4]),
|
'lr': tune.grid_search([3e-4]),
|
||||||
'optimizer': tune.grid_search([tf.keras.optimizers.Adam]),
|
'optimizer': tune.grid_search([tf.keras.optimizers.Adam]),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,8 +46,8 @@ class TuneB(tune.Trainable):
|
|||||||
return self.model.restore(path)
|
return self.model.restore(path)
|
||||||
|
|
||||||
A_params = {
|
A_params = {
|
||||||
'batch_size': tune.grid_search([32]),
|
'batch_size': tune.grid_search([128]),
|
||||||
'layers': tune.grid_search([[32, 32]]),
|
'layers': tune.grid_search([[64, 64]]),
|
||||||
'lr': tune.grid_search([1e-4]),
|
'lr': tune.grid_search([1e-4]),
|
||||||
'optimizer': tune.grid_search([tf.keras.optimizers.Adam]),
|
'optimizer': tune.grid_search([tf.keras.optimizers.Adam]),
|
||||||
}
|
}
|
||||||
|
@ -163,8 +163,8 @@ def continent_agent(ground, position, size):
|
|||||||
while True:
|
while True:
|
||||||
# if trials > CONTINENT_MAX_TRIALS:
|
# if trials > CONTINENT_MAX_TRIALS:
|
||||||
# print('couldnt proceed')
|
# print('couldnt proceed')
|
||||||
# if size <= 0 or trials > CONTINENT_MAX_TRIALS: break
|
if size <= 0 or trials > CONTINENT_MAX_TRIALS: break
|
||||||
if size <= 0: break
|
# if size <= 0: break
|
||||||
|
|
||||||
dx = np.random.randint(2) or -1
|
dx = np.random.randint(2) or -1
|
||||||
dy = np.random.randint(2) or -1
|
dy = np.random.randint(2) or -1
|
||||||
@ -184,8 +184,8 @@ def continent_agent(ground, position, size):
|
|||||||
trials = 0
|
trials = 0
|
||||||
size -= 1
|
size -= 1
|
||||||
ground[x, y] = np.random.randint(1, p['ground_noise'])
|
ground[x, y] = np.random.randint(1, p['ground_noise'])
|
||||||
# else:
|
else:
|
||||||
# trials += 1
|
trials += 1
|
||||||
|
|
||||||
def neighbours(ground, position, radius):
|
def neighbours(ground, position, radius):
|
||||||
x, y = position
|
x, y = position
|
||||||
|
@ -12,7 +12,7 @@ function generate() {
|
|||||||
const queryString = new URLSearchParams(formData).toString()
|
const queryString = new URLSearchParams(formData).toString()
|
||||||
map.src = '/map?' + queryString;
|
map.src = '/map?' + queryString;
|
||||||
map.classList.add('d-none');
|
map.classList.add('d-none');
|
||||||
|
map.width = formData.get('width');
|
||||||
}
|
}
|
||||||
|
|
||||||
mapSettings.addEventListener('submit', (e) => {
|
mapSettings.addEventListener('submit', (e) => {
|
||||||
|
@ -21,7 +21,7 @@
|
|||||||
<aside class='col-3 px-4 bg-dark text-light text-center py-3'>
|
<aside class='col-3 px-4 bg-dark text-light text-center py-3'>
|
||||||
<h3>World Map Generator</h3>
|
<h3>World Map Generator</h3>
|
||||||
|
|
||||||
<div class='panel px-4'>
|
<div class='panel px-4 pb-3'>
|
||||||
<form class='mt-5' id='map-settings'>
|
<form class='mt-5' id='map-settings'>
|
||||||
{% for k, v in parameters.items() %}
|
{% for k, v in parameters.items() %}
|
||||||
<div class='form-group'>
|
<div class='form-group'>
|
||||||
|
Loading…
Reference in New Issue
Block a user