68 lines
1.8 KiB
Python
68 lines
1.8 KiB
Python
|
import fire
|
||
|
import ray
|
||
|
import pandas as pd
|
||
|
import tensorflow as tf
|
||
|
from ray import tune
|
||
|
from tensorflow import keras
|
||
|
from utils import logger
|
||
|
from model import Model
|
||
|
|
||
|
# Hyperparameter search space for the "b" model. Every entry is a
# single-value grid, so exactly one trial configuration is generated;
# widen any list to sweep that dimension.
B_params = dict(
    batch_size=tune.grid_search([256]),
    layers=tune.grid_search([[512, 512]]),
    lr=tune.grid_search([1e-4]),
    optimizer=tune.grid_search([tf.keras.optimizers.Adam]),
)
|
||
|
|
||
|
df = pd.read_pickle('data.p')
|
||
|
|
||
|
class TuneB(tune.Trainable):
    """Ray Tune trainable wrapping the project's ``Model`` for the 'b' task.

    Uses the legacy underscore Trainable API
    (``_setup`` / ``_train`` / ``_save`` / ``_restore``).
    """

    def _setup(self, config):
        """Build the model and its optimizer from a trial's config dict.

        Expects ``config`` to carry 'optimizer' (an optimizer class),
        'lr', 'batch_size', and 'layers' — the keys of ``B_params``.
        """
        logger.debug('Ray Tune model configuration %s', config)

        self.model = Model('b', epochs=1)

        # config['optimizer'] is an optimizer *class* (e.g. Adam);
        # instantiate it with the trial's learning rate. (The original
        # first bound the bare class to `optimizer` and immediately
        # overwrote it — that dead assignment is removed.)
        optimizer = config['optimizer'](lr=config['lr'])

        self.model.prepare_for_use(
            df=df,
            batch_size=config['batch_size'],
            layers=config['layers'],
            optimizer=optimizer,
        )

    def _train(self):
        """Run one training step and report metrics from the Keras History.

        Indexes [0] because the model is built with epochs=1, so each
        history series holds a single entry.
        """
        logs = self.model.train(self.config)

        return {
            'mean_accuracy': logs.history['acc'][0],
            'loss': logs.history['loss'][0],
            'val_accuracy': logs.history['val_acc'][0],
            'val_loss': logs.history['val_loss'][0],
        }

    def _save(self, checkpoint_dir):
        """Checkpoint the model; returns whatever Model.save returns
        (presumably the checkpoint path — Tune expects that contract)."""
        return self.model.save(checkpoint_dir)

    def _restore(self, path):
        """Restore model state from a checkpoint produced by _save."""
        return self.model.restore(path)
|
||
|
|
||
|
def start_tuning(cpu=1, gpu=2, checkpoint_freq=1, checkpoint_at_end=True, resume=False, restore=None, stop=500):
    """Launch a Ray Tune sweep of TuneB over the B_params search space.

    Args mirror tune.run knobs: per-trial CPU/GPU counts, checkpoint
    cadence, resume/restore of a previous run, and the iteration cap.
    """
    ray.init()

    trial_resources = {"cpu": cpu, "gpu": gpu}
    stop_criteria = {'training_iteration': stop}

    tune.run(
        TuneB,
        config=B_params,
        resources_per_trial=trial_resources,
        resume=resume,
        checkpoint_at_end=checkpoint_at_end,
        checkpoint_freq=checkpoint_freq,
        restore=restore,
        stop=stop_criteria,
    )
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # python-fire turns start_tuning's keyword arguments into CLI flags
    # (e.g. --gpu=1 --stop=100).
    fire.Fire(start_tuning)
|