import fire
import pandas as pd
import ray
import tensorflow as tf
from ray import tune

from model import Model
from utils import *  # assumed to provide logger and the dataframe_to_dataset_* helpers
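
# Hyperparameter search spaces. Each value is wrapped in tune.grid_search, so
# Ray Tune runs one trial per combination; widen the lists to search more points.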
B_params = {
    'batch_size': tune.grid_search([256]),
    'layers': tune.grid_search([[512, 512]]),
    'lr': tune.grid_search([1e-4]),
    'optimizer': tune.grid_search([tf.keras.optimizers.Adam]),
}

# Training data shared by every trial; loaded once when the module is imported.
df = pd.read_pickle('data.p')
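
# The classes below use Ray Tune's older, underscore-prefixed Trainable API:
# _setup builds the model from a sampled config, _train runs one epoch and
# returns the metrics Tune uses for reporting and stopping, and _save/_restore
# delegate checkpointing to the Model class.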
class TuneB(tune.Trainable):
    def _setup(self, config):
        logger.debug('Ray Tune model configuration %s', config)

        self.model = Model('b', epochs=1)

        # Instantiate the optimizer class from the config with the sampled
        # learning rate ('lr' is the older Keras spelling of learning_rate).
        optimizer = config['optimizer'](lr=config['lr'])

        self.model.prepare_for_use(df=df, batch_size=config['batch_size'],
                                   layers=config['layers'], optimizer=optimizer)

    def _train(self):
        logs = self.model.train(self.config)

        # History keys use the short 'acc' names from older Keras versions.
        metrics = {
            'mean_accuracy': logs.history['acc'][0],
            'loss': logs.history['loss'][0],
            'val_accuracy': logs.history['val_acc'][0],
            'val_loss': logs.history['val_loss'][0],
        }

        return metrics

    def _save(self, checkpoint_dir):
        return self.model.save(checkpoint_dir)

    def _restore(self, path):
        return self.model.restore(path)
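
# Search space shared by the 'temp' and 'precip' regression models.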
A_params = {
    'batch_size': tune.grid_search([256]),
    'layers': tune.grid_search([[64, 64]]),
    'lr': tune.grid_search([3e-4]),
    'optimizer': tune.grid_search([tf.keras.optimizers.Adam]),
    # 'optimizer': tune.grid_search([tf.keras.optimizers.RMSprop]),
}
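
# TuneTemp and TunePrecip differ only in the model name and the dataset
# function; both are regression heads (linear output, MSE loss, MAE tracked).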
class TuneTemp(tune.Trainable):
    def _setup(self, config):
        logger.debug('Ray Tune model configuration %s', config)

        self.model = Model('temp', epochs=1)

        optimizer = config['optimizer'](lr=config['lr'])

        self.model.prepare_for_use(
            df=df,
            batch_size=config['batch_size'],
            layers=config['layers'],
            optimizer=optimizer,
            out_activation=None,
            dataset_fn=dataframe_to_dataset_temp,
            loss='mse',
            metrics=['mae'],
        )

    def _train(self):
        logs = self.model.train(self.config)
        logger.debug('Training history %s', logs.history)

        metrics = {
            'loss': logs.history['loss'][0],
            'mae': logs.history['mean_absolute_error'][0],
            'val_loss': logs.history['val_loss'][0],
            'val_mae': logs.history['val_mean_absolute_error'][0],
        }

        return metrics

    def _save(self, checkpoint_dir):
        return self.model.save(checkpoint_dir)

    def _restore(self, path):
        return self.model.restore(path)

class TunePrecip(tune.Trainable):
    def _setup(self, config):
        logger.debug('Ray Tune model configuration %s', config)

        self.model = Model('precip', epochs=1)

        optimizer = config['optimizer'](lr=config['lr'])

        self.model.prepare_for_use(
            df=df,
            batch_size=config['batch_size'],
            layers=config['layers'],
            optimizer=optimizer,
            out_activation=None,
            dataset_fn=dataframe_to_dataset_precip,
            loss='mse',
            metrics=['mae'],
        )

    def _train(self):
        logs = self.model.train(self.config)
        logger.debug('Training history %s', logs.history)

        metrics = {
            'loss': logs.history['loss'][0],
            'mae': logs.history['mean_absolute_error'][0],
            'val_loss': logs.history['val_loss'][0],
            'val_mae': logs.history['val_mean_absolute_error'][0],
        }

        return metrics

    def _save(self, checkpoint_dir):
        return self.model.save(checkpoint_dir)

    def _restore(self, path):
        return self.model.restore(path)
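
# CLI entry point; python-fire exposes the keyword arguments below as flags.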
def start_tuning(model, cpu=1, gpu=2, checkpoint_freq=1, checkpoint_at_end=True, resume=False, restore=None, stop=500):
    ray.init()

    # Pick the Trainable and search space for the requested model; anything
    # other than 'temp' or 'precip' falls back to the 'b' model.
    if model == 'temp':
        t = TuneTemp
        params = A_params
    elif model == 'precip':
        t = TunePrecip
        params = A_params
    else:
        t = TuneB
        params = B_params

    tune.run(t,
             config=params,
             resources_per_trial={
                 'cpu': cpu,
                 'gpu': gpu,
             },
             resume=resume,
             checkpoint_at_end=checkpoint_at_end,
             checkpoint_freq=checkpoint_freq,
             restore=restore,
             max_failures=-1,  # keep retrying trials that crash
             stop={'training_iteration': stop})

if __name__ == "__main__":
    fire.Fire(start_tuning)
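
# Example invocation (assuming this module is saved as tune_models.py; the
# flag names come directly from start_tuning's signature):
#   python tune_models.py --model=temp --cpu=2 --gpu=1 --stop=100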