world-ecoregion/biomes/train.py
2019-05-02 10:55:14 +04:30

121 lines
3.3 KiB
Python

import fire
import ray
import pandas as pd
import tensorflow as tf
from ray import tune
from tensorflow import keras
from utils import *
from model import Model
# Ray Tune search space for model 'b' (trained by TuneB). Every grid_search
# holds a single value, so the grid currently expands to one configuration.
B_params = dict(
    batch_size=tune.grid_search([256]),
    layers=tune.grid_search([[512, 512]]),
    lr=tune.grid_search([3e-4]),
    # An optimizer class, not an instance; the trainable instantiates it with lr.
    optimizer=tune.grid_search([tf.keras.optimizers.Adam]),
)

# Training data handed to every trainable in this file; read once at import.
# NOTE: read_pickle unpickles arbitrary objects, so data.p must be trusted.
df = pd.read_pickle('data.p')
class TuneB(tune.Trainable):
    """Ray Tune trainable wrapping ``Model('b')``.

    Each Tune training iteration calls ``Model.train`` once (the model is
    created with ``epochs=1``) and reports training/validation accuracy and
    loss.
    """

    def _setup(self, config):
        """Build the model from one sampled Tune ``config``.

        ``config`` must provide ``batch_size``, ``layers``, ``lr`` and
        ``optimizer``; the latter is an optimizer class that is instantiated
        here with ``lr``.
        """
        logger.debug('Ray Tune model configuration %s', config)
        self.model = Model('b', epochs=1)
        optimizer = config['optimizer'](lr=config['lr'])
        self.model.prepare_for_use(
            df=df,
            batch_size=config['batch_size'],
            layers=config['layers'],
            optimizer=optimizer,
        )

    def _train(self):
        """Run one training round and return the metrics Tune records."""
        logs = self.model.train(self.config)
        history = logs.history
        return {
            'mean_accuracy': self._last_metric(history, 'acc', 'accuracy'),
            'loss': self._last_metric(history, 'loss'),
            'val_accuracy': self._last_metric(history, 'val_acc', 'val_accuracy'),
            'val_loss': self._last_metric(history, 'val_loss'),
        }

    def _save(self, checkpoint_dir):
        """Checkpoint via ``Model.save``; its return value is handed back to Tune."""
        return self.model.save(checkpoint_dir)

    def _restore(self, path):
        """Restore the wrapped model from a checkpoint ``path`` produced by ``_save``."""
        return self.model.restore(path)

    @staticmethod
    def _last_metric(history, *names):
        """Return the most recent value of the first of ``names`` found in ``history``.

        Keras history key names differ between releases (e.g. ``acc`` vs
        ``accuracy``), so alternative spellings are tried in order. The last
        entry is used so the value stays meaningful if a round records more
        than one epoch.

        Raises:
            KeyError: if none of ``names`` is present in ``history``.
        """
        for name in names:
            if name in history:
                return history[name][-1]
        raise KeyError('none of %s found in training history (keys: %s)'
                       % (names, sorted(history)))
# Ray Tune search space for model 'a' (trained by TuneA). Each grid_search
# holds a single value, so the grid currently expands to one configuration.
# tf.keras.optimizers.RMSprop has also been tried as the optimizer; put it in
# the optimizer grid to compare against Adam.
A_params = dict(
    batch_size=tune.grid_search([256]),
    layers=tune.grid_search([[64, 64]]),
    lr=tune.grid_search([3e-4]),
    # An optimizer class, not an instance; the trainable instantiates it with lr.
    optimizer=tune.grid_search([tf.keras.optimizers.Adam]),
)
class TuneA(tune.Trainable):
    """Ray Tune trainable wrapping ``Model('a')``.

    Model 'a' is configured for regression: no output activation, MSE loss
    and an MAE metric, with its dataset built by
    ``dataframe_to_dataset_temp_precip``. Each Tune training iteration calls
    ``Model.train`` once (the model is created with ``epochs=1``).
    """

    def _setup(self, config):
        """Build the model from one sampled Tune ``config``.

        ``config`` must provide ``batch_size``, ``layers``, ``lr`` and
        ``optimizer``; the latter is an optimizer class that is instantiated
        here with ``lr``.
        """
        logger.debug('Ray Tune model configuration %s', config)
        self.model = Model('a', epochs=1)
        optimizer = config['optimizer'](lr=config['lr'])
        self.model.prepare_for_use(
            df=df,
            batch_size=config['batch_size'],
            layers=config['layers'],
            optimizer=optimizer,
            out_activation=None,
            dataset_fn=dataframe_to_dataset_temp_precip,
            loss='mse',
            metrics=['mae'],
        )

    def _train(self):
        """Run one training round and return the metrics Tune records."""
        logs = self.model.train(self.config)
        history = logs.history
        # Use the module logger rather than stdout so verbosity follows logging config.
        logger.debug('Training history %s', history)
        return {
            'loss': self._last_metric(history, 'loss'),
            'mae': self._last_metric(history, 'mean_absolute_error', 'mae'),
            'val_loss': self._last_metric(history, 'val_loss'),
            'val_mae': self._last_metric(history, 'val_mean_absolute_error', 'val_mae'),
        }

    def _save(self, checkpoint_dir):
        """Checkpoint via ``Model.save``; its return value is handed back to Tune."""
        return self.model.save(checkpoint_dir)

    def _restore(self, path):
        """Restore the wrapped model from a checkpoint ``path`` produced by ``_save``."""
        return self.model.restore(path)

    @staticmethod
    def _last_metric(history, *names):
        """Return the most recent value of the first of ``names`` found in ``history``.

        Keras history key names differ between releases (e.g.
        ``mean_absolute_error`` vs ``mae``), so alternative spellings are
        tried in order. The last entry is used so the value stays meaningful
        if a round records more than one epoch.

        Raises:
            KeyError: if none of ``names`` is present in ``history``.
        """
        for name in names:
            if name in history:
                return history[name][-1]
        raise KeyError('none of %s found in training history (keys: %s)'
                       % (names, sorted(history)))
def start_tuning(model, cpu=1, gpu=2, checkpoint_freq=1, checkpoint_at_end=True, resume=False, restore=None, stop=500):
    """Launch a Ray Tune grid search for one of the biome models.

    Args:
        model: ``'a'`` to tune ``TuneA`` over ``A_params`` or ``'b'`` to tune
            ``TuneB`` over ``B_params`` (case-insensitive).
        cpu: CPUs reserved per trial.
        gpu: GPUs reserved per trial.
        checkpoint_freq: passed to ``tune.run(checkpoint_freq=...)``.
        checkpoint_at_end: passed to ``tune.run(checkpoint_at_end=...)``.
        resume: passed to ``tune.run(resume=...)``.
        restore: passed to ``tune.run(restore=...)``.
        stop: each trial stops after this many training iterations.

    Raises:
        ValueError: if ``model`` is neither ``'a'`` nor ``'b'``.
    """
    key = str(model).lower()
    # Validate before starting Ray: any unrecognised value (e.g. a typo) used
    # to fall through to the else branch and silently tune model 'b'.
    if key not in ('a', 'b'):
        raise ValueError("model must be 'a' or 'b', got %r" % (model,))
    if key == 'a':
        trainable, params = TuneA, A_params
    else:
        trainable, params = TuneB, B_params

    ray.init()
    try:
        tune.run(
            trainable,
            config=params,
            resources_per_trial={"cpu": cpu, "gpu": gpu},
            resume=resume,
            checkpoint_at_end=checkpoint_at_end,
            checkpoint_freq=checkpoint_freq,
            restore=restore,
            stop={'training_iteration': stop},
        )
    finally:
        # Release the Ray runtime started above so start_tuning can be called
        # again in the same process; Ray rejects a second ray.init() otherwise.
        ray.shutdown()
if __name__ == "__main__":
    # Expose start_tuning as a command-line interface via python-fire; its
    # parameters become positional arguments / --flags.
    fire.Fire(component=start_tuning)