from __future__ import absolute_import, division, print_function

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras

# Helper libraries
import numpy as np
import pandas as pd

# utils is expected to provide `logger` and the `dataframe_to_dataset_*`
# helpers referenced below.
from utils import *

RANDOM_SEED = 1

logger.debug('Tensorflow version: %s', tf.__version__)
logger.debug('Random Seed: %s', RANDOM_SEED)

tf.set_random_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

DEFAULT_BATCH_SIZE = 256
DEFAULT_LAYERS = [512, 512]
DEFAULT_BUFFER_SIZE = 500
DEFAULT_OUT_ACTIVATION = tf.nn.softmax
DEFAULT_LOSS = 'sparse_categorical_crossentropy'
DEFAULT_OPTIMIZER = tf.keras.optimizers.Adam(lr=0.001)
DEFAULT_METRICS = ['accuracy']
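

# Any dataset function passed to `prepare_dataset` below must return the tuple
# (dataset_size, features, output_size, class_weight, dataset), matching the
# dataframe_to_dataset_* helpers from utils. A minimal hypothetical sketch of
# such a function, assuming the dataframe's last column holds integer class
# labels and the remaining columns are float features:
def example_dataset_fn(df):
    values = df.values.astype('float32')
    features = values.shape[1] - 1  # every column except the label
    labels = values[:, -1].astype('int64')
    dataset = tf.data.Dataset.from_tensor_slices((values[:, :features], labels))
    output_size = int(labels.max()) + 1  # labels assumed to be 0..N-1
    return (len(df), features, output_size, None, dataset)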


class Model:
    def __init__(self, name, epochs=1):
        self.name = name
        self.path = "checkpoints/{}.hdf5".format(name)

        self.epochs = epochs

    def prepare_dataset(self, df, fn, **kwargs):
        # Remember which function produced this model's dataset.
        self.dataset_fn = fn

        self.set_dataset(*fn(df), **kwargs)

    def set_dataset(self, dataset_size, features, output_size, class_weight,
                    dataset, shuffle_buffer_size=DEFAULT_BUFFER_SIZE,
                    batch_size=DEFAULT_BATCH_SIZE):
        self.shuffle_buffer_size = shuffle_buffer_size

        self.class_weight = class_weight
        # Shuffle once with a fixed seed so the take/skip split below stays
        # disjoint; reshuffling on every iteration would let test samples
        # leak into the training set across epochs.
        self.dataset = dataset.shuffle(self.shuffle_buffer_size,
                                       seed=RANDOM_SEED,
                                       reshuffle_each_iteration=False)
        self.TRAIN_SIZE = int(dataset_size * 0.85)
        self.TEST_SIZE = dataset_size - self.TRAIN_SIZE
        (training, test) = (self.dataset.take(self.TRAIN_SIZE),
                            self.dataset.skip(self.TRAIN_SIZE))

        logger.debug('Model dataset info: size=%s, train=%s, test=%s',
                     dataset_size, self.TRAIN_SIZE, self.TEST_SIZE)

        self.dataset_size = dataset_size
        self.features = features
        self.output_size = output_size
        self.training = training
        self.test = test

        logger.debug('Model input size: %s', self.features)
        logger.debug('Model output size: %s', self.output_size)

        self.batch_size = batch_size
        self.training_batched = self.training.batch(self.batch_size).repeat()
        self.test_batched = self.test.batch(self.batch_size).repeat()

    def create_model(self, layers=DEFAULT_LAYERS, out_activation=DEFAULT_OUT_ACTIVATION):
        params = {
            'kernel_initializer': 'lecun_uniform',
            'bias_initializer': 'zeros',
            # 'kernel_regularizer': keras.regularizers.l2(l=0.01)
            # Keras only honors input_shape on the first layer of a
            # Sequential model; passing it to the others is harmless.
            'input_shape': [self.features]
        }

        activation = tf.nn.elu

        logger.debug('Model layer parameters: %s', params)
        logger.debug('Model layer sizes: %s', layers)
        logger.debug('Model layer activation function: %s', activation)
        logger.debug('Model out activation function: %s', out_activation)

        self.model = keras.Sequential([
            keras.layers.Dense(n, activation=activation, **params) for n in layers
        ] + [
            keras.layers.Dense(self.output_size, activation=out_activation, **params)
        ])

    def compile(self, loss=DEFAULT_LOSS, metrics=DEFAULT_METRICS, optimizer=DEFAULT_OPTIMIZER):
        logger.debug('Model loss function: %s', loss)
        logger.debug('Model optimizer: %s', optimizer)
        logger.debug('Model metrics: %s', metrics)

        self.model.compile(loss=loss,
                           optimizer=optimizer,
                           metrics=metrics)

    def restore(self, path):
        logger.debug('Restoring model weights from path: %s', path)
        return self.model.load_weights(path)

    def save(self, path):
        logger.debug('Saving model weights to path: %s', path)
        self.model.save_weights(path + '/checkpoint.h5', save_format='h5')
        return path + '/checkpoint.h5'

    def evaluate(self):
        # Evaluate on the held-out test split; the dataset is already batched,
        # so only the step count is needed.
        return self.model.evaluate(
            self.test_batched,
            steps=int(self.TEST_SIZE / self.batch_size),
            verbose=1
        )

    def evaluate_print(self):
        loss, accuracy = self.evaluate()
        print('Test evaluation: loss: {}, accuracy: {}'.format(loss, accuracy))

    def train(self, config):
        self.model.summary()

        # map_callback = MapHistory()

        extra_params = {}
        if self.class_weight:
            extra_params['class_weight'] = self.class_weight

        # The datasets are pre-batched, so fit() takes step counts rather
        # than a batch_size argument.
        out = self.model.fit(
            self.training_batched,
            epochs=self.epochs,
            steps_per_epoch=int(self.TRAIN_SIZE / self.batch_size),
            validation_data=self.test_batched,
            validation_steps=int(self.TEST_SIZE / self.batch_size),
            verbose=1,
            **extra_params
        )

        return out

    def predict_class(self, a):
        # Index of the highest-probability class per sample.
        return np.argmax(self.model.predict(a), axis=1)

    def predict(self, a):
        return self.model.predict(a)

    def prepare_for_use(self, df=None, batch_size=DEFAULT_BATCH_SIZE,
                        layers=DEFAULT_LAYERS,
                        out_activation=DEFAULT_OUT_ACTIVATION,
                        loss=DEFAULT_LOSS, optimizer=DEFAULT_OPTIMIZER,
                        dataset_fn=dataframe_to_dataset_biomes,
                        metrics=DEFAULT_METRICS):
        if df is None:
            df = pd.read_pickle('data.p')
        self.prepare_dataset(df, dataset_fn, batch_size=batch_size)
        self.create_model(layers=layers, out_activation=out_activation)
        self.compile(loss=loss, optimizer=optimizer, metrics=metrics)
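

# A minimal usage sketch, illustrative rather than part of the original
# module: build, train, and evaluate the default biome model. Assumes a
# pickled dataframe exists at 'data.p', which prepare_for_use loads when
# no dataframe is given.
if __name__ == '__main__':
    model = Model('example', epochs=2)
    model.prepare_for_use()
    model.train(None)  # `config` is currently unused by train()
    model.evaluate_print()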