fix(model): checkpoints in h5 format
This commit is contained in:
parent
b377c6dd5f
commit
b192531a2a
@ -99,8 +99,8 @@ class Model():
|
||||
|
||||
def save(self, path):
|
||||
logger.debug('Saving model weights to path: %s', path)
|
||||
self.model.save_weights(path + '/checkpoint.hd5')
|
||||
return path + '/checkpoint'
|
||||
self.model.save_weights(path + '/checkpoint.h5', save_format='h5')
|
||||
return path + '/checkpoint.h5'
|
||||
|
||||
def evaluate(self):
|
||||
return self.model.evaluate(
|
||||
|
@ -10,7 +10,7 @@ from model import Model
|
||||
B_params = {
|
||||
'batch_size': tune.grid_search([256]),
|
||||
'layers': tune.grid_search([[512, 512]]),
|
||||
'lr': tune.grid_search([1e-4]),
|
||||
'lr': tune.grid_search([3e-4]),
|
||||
'optimizer': tune.grid_search([tf.keras.optimizers.Adam]),
|
||||
}
|
||||
|
||||
@ -46,8 +46,8 @@ class TuneB(tune.Trainable):
|
||||
return self.model.restore(path)
|
||||
|
||||
A_params = {
|
||||
'batch_size': tune.grid_search([32]),
|
||||
'layers': tune.grid_search([[32, 32]]),
|
||||
'batch_size': tune.grid_search([128]),
|
||||
'layers': tune.grid_search([[64, 64]]),
|
||||
'lr': tune.grid_search([1e-4]),
|
||||
'optimizer': tune.grid_search([tf.keras.optimizers.Adam]),
|
||||
}
|
||||
|
@ -163,8 +163,8 @@ def continent_agent(ground, position, size):
|
||||
while True:
|
||||
# if trials > CONTINENT_MAX_TRIALS:
|
||||
# print('couldnt proceed')
|
||||
# if size <= 0 or trials > CONTINENT_MAX_TRIALS: break
|
||||
if size <= 0: break
|
||||
if size <= 0 or trials > CONTINENT_MAX_TRIALS: break
|
||||
# if size <= 0: break
|
||||
|
||||
dx = np.random.randint(2) or -1
|
||||
dy = np.random.randint(2) or -1
|
||||
@ -184,8 +184,8 @@ def continent_agent(ground, position, size):
|
||||
trials = 0
|
||||
size -= 1
|
||||
ground[x, y] = np.random.randint(1, p['ground_noise'])
|
||||
# else:
|
||||
# trials += 1
|
||||
else:
|
||||
trials += 1
|
||||
|
||||
def neighbours(ground, position, radius):
|
||||
x, y = position
|
||||
@ -271,7 +271,7 @@ def generate_map(**kwargs):
|
||||
|
||||
ym = ym * -1
|
||||
|
||||
random_size = ground_size / continents
|
||||
random_size = ground_size / continents
|
||||
continent_agent(ground, position, size=random_size)
|
||||
|
||||
ground = ndimage.gaussian_filter(ground, sigma=(1 - p['sharpness']) * 20)
|
||||
|
@ -12,7 +12,7 @@ function generate() {
|
||||
const queryString = new URLSearchParams(formData).toString()
|
||||
map.src = '/map?' + queryString;
|
||||
map.classList.add('d-none');
|
||||
|
||||
map.width = formData.get('width');
|
||||
}
|
||||
|
||||
mapSettings.addEventListener('submit', (e) => {
|
||||
|
@ -21,7 +21,7 @@
|
||||
<aside class='col-3 px-4 bg-dark text-light text-center py-3'>
|
||||
<h3>World Map Generator</h3>
|
||||
|
||||
<div class='panel px-4'>
|
||||
<div class='panel px-4 pb-3'>
|
||||
<form class='mt-5' id='map-settings'>
|
||||
{% for k, v in parameters.items() %}
|
||||
<div class='form-group'>
|
||||
|
Loading…
x
Reference in New Issue
Block a user