# world-ecoregion/biomes/train.py
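
"""Ray Tune training entry point for the biome ('b'), temperature ('temp'),
and precipitation ('precip') models."""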

import fire
import ray
import pandas as pd
import tensorflow as tf
from ray import tune
from tensorflow import keras

from utils import *  # star import presumably supplies `logger` and the `dataframe_to_dataset_*` helpers used below
from model import Model

# Each grid_search below lists a single value, so the "grid" pins one
# configuration while keeping the Ray Tune search plumbing in place.
B_params = {
    'batch_size': tune.grid_search([256]),
    'layers': tune.grid_search([[512, 512]]),
    'lr': tune.grid_search([1e-4]),
    'optimizer': tune.grid_search([tf.keras.optimizers.Adam]),
}

# Training data is loaded once at module import and shared by all Trainables.
df = pd.read_pickle('data.p')


class TuneB(tune.Trainable):
    def _setup(self, config):
        logger.debug('Ray Tune model configuration %s', config)
        self.model = Model('b', epochs=1)
        # Instantiate the optimizer class from the config with the sampled
        # learning rate.
        optimizer = config['optimizer'](lr=config['lr'])
        self.model.prepare_for_use(df=df, batch_size=config['batch_size'],
                                   layers=config['layers'], optimizer=optimizer)

    def _train(self):
        logs = self.model.train(self.config)
        # Each history list has a single entry, since the model runs one
        # epoch per training iteration.
        metrics = {
            'mean_accuracy': logs.history['acc'][0],
            'loss': logs.history['loss'][0],
            'val_accuracy': logs.history['val_acc'][0],
            'val_loss': logs.history['val_loss'][0],
        }
        return metrics

    def _save(self, checkpoint_dir):
        return self.model.save(checkpoint_dir)

    def _restore(self, path):
        return self.model.restore(path)


A_params = {
    'batch_size': tune.grid_search([256]),
    'layers': tune.grid_search([[64, 64]]),
    'lr': tune.grid_search([3e-4]),
    'optimizer': tune.grid_search([tf.keras.optimizers.Adam]),
    # 'optimizer': tune.grid_search([tf.keras.optimizers.RMSprop]),
}
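
# A_params is shared by the two regression models, 'temp' and 'precip'
# (see start_tuning below).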


class TuneTemp(tune.Trainable):
    def _setup(self, config):
        logger.debug('Ray Tune model configuration %s', config)
        self.model = Model('temp', epochs=1)
        optimizer = config['optimizer'](lr=config['lr'])
        # Temperature prediction is a regression task: linear output,
        # MSE loss, MAE as the reported metric.
        self.model.prepare_for_use(
            df=df,
            batch_size=config['batch_size'],
            layers=config['layers'],
            optimizer=optimizer,
            out_activation=None,
            dataset_fn=dataframe_to_dataset_temp,
            loss='mse',
            metrics=['mae'],
        )

    def _train(self):
        logs = self.model.train(self.config)
        print(logs.history)
        metrics = {
            'loss': logs.history['loss'][0],
            'mae': logs.history['mean_absolute_error'][0],
            'val_loss': logs.history['val_loss'][0],
            'val_mae': logs.history['val_mean_absolute_error'][0],
        }
        return metrics

    def _save(self, checkpoint_dir):
        return self.model.save(checkpoint_dir)

    def _restore(self, path):
        return self.model.restore(path)


class TunePrecip(tune.Trainable):
    def _setup(self, config):
        logger.debug('Ray Tune model configuration %s', config)
        self.model = Model('precip', epochs=1)
        optimizer = config['optimizer'](lr=config['lr'])
        # Precipitation is likewise treated as a regression problem.
        self.model.prepare_for_use(
            df=df,
            batch_size=config['batch_size'],
            layers=config['layers'],
            optimizer=optimizer,
            out_activation=None,
            dataset_fn=dataframe_to_dataset_precip,
            loss='mse',
            metrics=['mae'],
        )

    def _train(self):
        logs = self.model.train(self.config)
        print(logs.history)
        metrics = {
            'loss': logs.history['loss'][0],
            'mae': logs.history['mean_absolute_error'][0],
            'val_loss': logs.history['val_loss'][0],
            'val_mae': logs.history['val_mean_absolute_error'][0],
        }
        return metrics

    def _save(self, checkpoint_dir):
        return self.model.save(checkpoint_dir)

    def _restore(self, path):
        return self.model.restore(path)


def start_tuning(model, cpu=1, gpu=2, checkpoint_freq=1, checkpoint_at_end=True,
                 resume=False, restore=None, stop=500):
    ray.init()

    # Pick the Trainable and hyperparameter grid for the requested model.
    if model == 'temp':
        t = TuneTemp
        params = A_params
    elif model == 'precip':
        t = TunePrecip
        params = A_params
    else:
        t = TuneB
        params = B_params

    tune.run(t,
             config=params,
             resources_per_trial={
                 "cpu": cpu,
                 "gpu": gpu
             },
             resume=resume,
             checkpoint_at_end=checkpoint_at_end,
             checkpoint_freq=checkpoint_freq,
             restore=restore,
             max_failures=-1,  # tolerate repeated trial failures
             stop={
                 'training_iteration': stop
             })


if __name__ == "__main__":
    fire.Fire(start_tuning)
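
# Example invocations (python-fire derives the CLI flags from start_tuning's
# signature; the flag values below are illustrative):
#
#   python train.py --model=temp --cpu=4 --gpu=1 --stop=100
#   python train.py --model=b --restore=/path/to/checkpoint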