# world-ecoregion/train.py
#
# 68 lines
# 1.8 KiB
# Python

import fire
import ray
import pandas as pd
import tensorflow as tf
from ray import tune
from tensorflow import keras
from utils import logger
from model import Model
# Hyperparameter search space for model "b". Every entry is wrapped in
# tune.grid_search, and each grid currently lists a single candidate value.
B_params = dict(
    batch_size=tune.grid_search([256]),
    layers=tune.grid_search([[512, 512]]),
    lr=tune.grid_search([1e-4]),
    # The optimizer is stored as a class; TuneB._setup instantiates it with `lr`.
    optimizer=tune.grid_search([tf.keras.optimizers.Adam]),
)

# Loaded once at import time; TuneB._setup passes this object to the model.
# NOTE(review): read_pickle unpickles the file, so 'data.p' should only ever
# come from a trusted source.
df = pd.read_pickle('data.p')
class TuneB(tune.Trainable):
    """Ray Tune trainable that wraps model 'b'.

    The model is created with ``epochs=1``; each ``_train`` call invokes
    ``Model.train`` once and reports the first entry of every history series.
    """

    def _setup(self, config):
        """Create the model and prepare it with this trial's hyperparameters.

        Args:
            config: Trial configuration. The keys read here are
                ``'optimizer'`` (a callable accepting ``lr``), ``'lr'``,
                ``'batch_size'`` and ``'layers'``.
        """
        logger.debug('Ray Tune model configuration %s', config)
        self.model = Model('b', epochs=1)

        # config['optimizer'] holds the optimizer class, not an instance, so
        # it is instantiated here with the trial's learning rate.
        optimizer = config['optimizer'](lr=config['lr'])

        self.model.prepare_for_use(
            df=df,
            batch_size=config['batch_size'],
            layers=config['layers'],
            optimizer=optimizer,
        )

    def _train(self):
        """Run one training call and return its metrics.

        Returns:
            dict with ``mean_accuracy``, ``loss``, ``val_accuracy`` and
            ``val_loss``, each taken from index 0 of the matching history list.
        """
        logs = self.model.train(self.config)
        # NOTE(review): presumably a Keras History object; the 'acc'/'val_acc'
        # keys depend on how Model compiles its metrics - confirm in model.py.
        history = logs.history
        return {
            'mean_accuracy': history['acc'][0],
            'loss': history['loss'][0],
            'val_accuracy': history['val_acc'][0],
            'val_loss': history['val_loss'][0],
        }

    def _save(self, checkpoint_dir):
        """Delegate checkpointing to the model and return whatever it returns."""
        return self.model.save(checkpoint_dir)

    def _restore(self, path):
        """Delegate checkpoint restoration to the model."""
        return self.model.restore(path)
def start_tuning(cpu=1, gpu=2, checkpoint_freq=1, checkpoint_at_end=True, resume=False, restore=None, stop=500):
    """Initialise Ray and run the TuneB trainable over the B_params space.

    Args:
        cpu: Value for the "cpu" entry of ``resources_per_trial``.
        gpu: Value for the "gpu" entry of ``resources_per_trial``.
        checkpoint_freq: Forwarded unchanged to ``tune.run``.
        checkpoint_at_end: Forwarded unchanged to ``tune.run``.
        resume: Forwarded unchanged to ``tune.run``.
        restore: Forwarded unchanged to ``tune.run``.
        stop: Value for the 'training_iteration' stopping criterion.
    """
    ray.init()

    trial_resources = {"cpu": cpu, "gpu": gpu}
    stop_criteria = {'training_iteration': stop}

    tune.run(
        TuneB,
        config=B_params,
        resources_per_trial=trial_resources,
        resume=resume,
        checkpoint_at_end=checkpoint_at_end,
        checkpoint_freq=checkpoint_freq,
        restore=restore,
        stop=stop_criteria,
    )
# Script entry point: hand start_tuning to python-fire, which builds the
# command-line interface from it.
if __name__ == '__main__':
    fire.Fire(start_tuning)