world-ecoregion/biomes/train_temp.py

91 lines
2.5 KiB
Python
Raw Normal View History

import math
import os

import fire
import ray
import pandas as pd
import tensorflow as tf
import numpy as np
from tensorflow import keras

from utils import *
from model import Model
from constants import *
# Training hyper-parameters.
batch_size = 5
epochs = 500

# HDF5 file used for the model's weights, and the seed for NumPy's global RNG.
CHECKPOINT = 'checkpoints/temp.h5'
SEED = 1
np.random.seed(seed=SEED)

# Load the cached dataframe and convert it into a dataset. The helper also
# returns the dataset size and x_columns / y_columns (presumably the input and
# output feature counts -- confirm against utils.dataframe_to_dataset_temp_precip).
df = pd.read_pickle('data.p')
dataset_size, x_columns, y_columns, dataset = dataframe_to_dataset_temp_precip(df)
def baseline_model():
    """Build and compile the baseline fully-connected regression network.

    Architecture: Dense(x_columns, elu) -> Dense(6, relu) -> Dense(y_columns).
    The output layer has no activation, i.e. it is linear. Layer widths come
    from the module-level ``x_columns`` and ``y_columns``.

    Returns:
        A ``keras.models.Sequential`` compiled with MSE loss, the Adam
        optimizer and an MAE metric.
    """
    shared = dict(kernel_initializer='lecun_uniform', bias_initializer='zeros')
    stack = [
        keras.layers.Dense(x_columns, input_dim=x_columns, activation='elu', **shared),
        keras.layers.Dense(6, activation='relu', **shared),
        keras.layers.Dense(y_columns, **shared),
    ]
    # Passing the layers to the constructor adds them in order, like .add().
    net = keras.models.Sequential(stack)
    net.compile(loss='mse', optimizer='adam', metrics=['mae'])
    return net
model = baseline_model()
model.summary()

# Shuffle the whole dataset exactly once, with a fixed seed, before splitting.
# - buffer_size=dataset_size gives a uniform shuffle. A 500-element buffer only
#   permutes locally, so the held-out tail would stay (mostly) the last rows of
#   the dataframe.
# - reshuffle_each_iteration=False freezes that order. With the default (True),
#   every new pass re-shuffles before take()/skip() are applied, so the
#   "training" and "test" subsets differ from pass to pass and test examples
#   leak into training.
dataset = dataset.shuffle(dataset_size, seed=SEED, reshuffle_each_iteration=False)

# 85% train / 15% test.
TRAIN_SIZE = int(dataset_size * 0.85)
TEST_SIZE = dataset_size - TRAIN_SIZE
(training, test) = (dataset.take(TRAIN_SIZE),
                    dataset.skip(TRAIN_SIZE))

# Batch, then repeat indefinitely. Consumers must bound each epoch with an
# explicit step count.
training_batched = training.batch(batch_size).repeat()
test_batched = test.batch(batch_size).repeat()
logger.debug('Model dataset info: size=%s, train=%s, test=%s', dataset_size, TRAIN_SIZE, TEST_SIZE)

# Disabled on purpose: if uncommented, the saved weights would be loaded at
# import time, i.e. before every CLI command, and train() would resume from them.
# model.load_weights(CHECKPOINT)
def predict(year=2000):
    """Print model predictions beside the ground truth for a quick sanity check.

    Loads the weights stored at CHECKPOINT first. The module builds a freshly
    initialised model at import time and never loads weights itself, so
    without this the predictions would come from random weights.

    Every ``print`` shows the first ``batch_size`` rows, so inputs, raw
    predictions, ground truth and denormalised predictions line up.

    Args:
        year: Year suffix of the ground-truth columns
            ``temp_<season>_<year>`` / ``precip_<season>_<year>`` to compare
            against. Defaults to 2000, the previously hard-coded value.
    """
    model.load_weights(CHECKPOINT)

    columns = INPUTS
    print(columns)
    print(df[0:batch_size])

    inputs = df[columns].to_numpy()
    # Deliberately passes a separate copy as the reference data, as before, in
    # case normalize_ndarray modifies its first argument in place.
    inputs = normalize_ndarray(inputs, df[columns].to_numpy())
    print(inputs[0:batch_size])

    out_columns = []
    for season in SEASONS:
        out_columns += ['temp_{}_{}'.format(season, year), 'precip_{}_{}'.format(season, year)]
    print(out_columns)

    # Predict on every row (normalisation above used whole columns), but only
    # display the same leading window as the other prints.
    out = model.predict(inputs)
    print(out[0:batch_size])
    print(df[out_columns][0:batch_size])
    print(denormalize(out, df[out_columns].to_numpy())[0:batch_size])
def train():
    """Fit the module-level model on the training split, validating on the test split.

    TensorBoard logs go to ``temp_logs``. After every epoch the weights are
    written to CHECKPOINT only if ``val_loss`` improved, so the file ends up
    holding the best epoch's weights.
    """
    # h5py does not create missing parent directories when saving.
    os.makedirs(os.path.dirname(CHECKPOINT) or '.', exist_ok=True)

    tfb_callback = tf.keras.callbacks.TensorBoard(batch_size=batch_size, log_dir='temp_logs')
    # `monitor` only takes effect together with `save_best_only`; without it a
    # checkpoint was written every epoch regardless of val_loss.
    # Weights-only keeps the file format compatible with model.load_weights().
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=CHECKPOINT,
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=True,
    )

    # The datasets are already batched, and recent tf.keras raises ValueError
    # when `batch_size` is passed with a Dataset input, so it is omitted.
    # ceil() makes one Keras epoch consume exactly one pass over each split,
    # including the final partial batch. floor() let epochs drift across passes
    # and validated on a different subset every epoch.
    model.fit(
        training_batched,
        epochs=epochs,
        steps_per_epoch=math.ceil(TRAIN_SIZE / batch_size),
        validation_data=test_batched,
        validation_steps=math.ceil(TEST_SIZE / batch_size),
        callbacks=[tfb_callback, checkpoint_callback],
        verbose=1,
    )
    # No trailing save_weights(CHECKPOINT): it would overwrite the best-epoch
    # weights kept by the checkpoint callback with the last epoch's weights.
if __name__ == "__main__":
    # Expose the CLI subcommands via python-fire:
    #   python train_temp.py train
    #   python train_temp.py predict
    commands = {'predict': predict, 'train': train}
    fire.Fire(commands)