world-ecoregion/biomes/predict.py

59 lines
1.6 KiB
Python
Raw Normal View History

import fire
2019-02-27 11:36:20 +00:00
import numpy as np
from utils import *
#from nn import compile_b
from constants import INPUTS
from model import Model
2019-02-27 11:36:20 +00:00
from draw import draw
def predicted_map(B, change=0, path=None):
    """Predict a biome for each row of ``data.p`` and draw the resulting map.

    Reads the pickled dataset ``data.p`` and builds the input column list from
    ``INPUTS`` plus ``temp_<season>_<year>`` / ``precip_<season>_<year>``
    for ``year = MAX_YEAR - 1``. Every seasonal temperature column is shifted
    by ``change``, the inputs are normalized, and ``B.predict`` runs batch by
    batch. Finally a ``latitude`` / ``longitude`` / ``biome_num`` frame is
    passed to ``draw``.

    Args:
        B: trained model; this function reads ``B.batch_size`` and calls
            ``B.predict(ndarray)`` on each full batch.
        change: value added to every ``temp_<season>_<year>`` column before
            prediction (the temperature change being simulated).
        path: forwarded unchanged to ``draw`` (presumably the output file,
            ``None`` letting ``draw`` choose — TODO confirm).
    """
    year = MAX_YEAR - 1

    df = pd.read_pickle('data.p')

    logger.info('temperature change of %s', change)

    inputs = list(INPUTS)
    for season in SEASONS:
        inputs += [
            'temp_{}_{}'.format(season, year),
            'precip_{}_{}'.format(season, year),
        ]

    # Explicit copies: ``frame`` is modified below, while ``frame_cp`` keeps the
    # unshifted values that normalize_ndarray receives as its second argument
    # (presumably the reference for the normalization statistics — confirm).
    # Copying also avoids pandas' SettingWithCopyWarning on the write below.
    frame = df[inputs + ['longitude']].copy()
    frame_cp = df[inputs + ['longitude']].copy()

    for season in SEASONS:
        temp_column = 'temp_{}_{}'.format(season, year)
        # Replace the column rather than set it in place so a dtype change
        # (e.g. int column + float change) is handled without warnings.
        frame[temp_column] = frame[temp_column] + change

    columns = ['latitude', 'longitude', 'biome_num']

    nframe = pd.DataFrame(
        columns=frame.columns,
        data=normalize_ndarray(frame.to_numpy(), frame_cp.to_numpy()),
    )

    pieces = []
    for chunk, chunk_original in zip(chunker(nframe, B.batch_size),
                                     chunker(frame, B.batch_size)):
        # Chunks shorter than B.batch_size are skipped (presumably the model
        # only accepts full batches), so up to batch_size - 1 trailing rows
        # never appear on the map.
        if chunk.shape[0] < B.batch_size:
            continue
        input_data = chunk.loc[:, inputs].values

        out = B.predict(input_data)
        # Coordinates come from the un-normalized chunk; 'latitude' must be
        # part of INPUTS since ``frame`` only adds 'longitude' explicitly.
        pieces.append(pd.DataFrame({
            'longitude': chunk_original.loc[:, 'longitude'],
            'latitude': chunk_original.loc[:, 'latitude'],
            'biome_num': out,
        }, columns=columns))

    # DataFrame.append was removed in pandas 2.0 and copied the accumulated
    # frame on every batch; concatenate once instead.
    if pieces:
        new_data = pd.concat(pieces)
    else:
        new_data = pd.DataFrame(columns=columns)

    draw(new_data, path=path)
def predicted_map_cmd(checkpoint='checkpoints/save.h5', change=0, path=None):
    """Restore model ``'b'`` from a checkpoint and draw its predicted biome map.

    Args:
        checkpoint: weights file handed to ``Model.restore``.
        change: temperature offset forwarded to ``predicted_map``.
        path: forwarded to ``predicted_map`` (and from there to ``draw``).
    """
    biome_model = Model('b', epochs=1)
    biome_model.prepare_for_use()
    biome_model.restore(checkpoint)

    predicted_map(biome_model, change=change, path=path)
2019-03-07 03:25:23 +00:00
if __name__ == "__main__":
    # Hand predicted_map_cmd to python-fire, which builds the command-line
    # interface from its parameters (checkpoint, change, path).
    fire.Fire(component=predicted_map_cmd)
2019-03-07 03:25:23 +00:00