#!/usr/bin/env python3
# -*- coding: utf8 -*-
import os

# TensorFlow environment flags; these must be set before tensorflow is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'  # suppress INFO log messages
os.environ['TF_ENABLE_XLA'] = '1'
os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
os.environ['TF_ENABLE_CUDNN_RNN_TENSOR_OP_MATH_FP32'] = '1'
os.environ['TF_DISABLE_CUDNN_TENSOR_OP_MATH'] = '1'
os.environ['TF_ENABLE_CUBLAS_TENSOR_OP_MATH_FP32'] = '1'

import numpy as np
import tensorflow as tf

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Cap TensorFlow to a 4096 MB virtual device on the first GPU.
        tf.config.experimental.set_virtual_device_configuration(
            gpus[0],
            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4096)])
    except RuntimeError as e:
        print(e)
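
# Note: an alternative to the fixed 4096 MB virtual-device limit above is to
# enable on-demand memory growth (do not combine both for the same GPU); a
# minimal sketch:
#
#     for gpu in gpus:
#         tf.config.experimental.set_memory_growth(gpu, True)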

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Dropout, Conv2D, MaxPooling2D, Conv2DTranspose, LeakyReLU, Reshape, Activation, BatchNormalization, UpSampling2D
from tensorflow.keras.activations import selu
from tensorflow.keras.constraints import max_norm
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.utils import to_categorical, normalize
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.datasets import mnist
from tensorflow.keras import losses
from tensorflow.keras import backend as K
import matplotlib.pyplot as plt


# GET DATA
# Datasets
path_train = '/media/aneumann/Harddisk/Datenbanken/PythonTest/test/Preprocessing/Train'
path_valid = '/media/aneumann/Harddisk/Datenbanken/PythonTest/test/Preprocessing/Validation'

train = sorted(os.listdir(path_train))
print(len(train), "training files")
valid = sorted(os.listdir(path_valid))
print(len(valid), "validation files")
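
# Optional sanity check; this assumes each file is a NumPy .npz archive whose
# 'x' and 'y' arrays have shape (samples, 256, 128, 1), matching the model
# input defined below:
#
#     sample = np.load(path_train + '/' + train[0])
#     print(sample['x'].shape, sample['y'].shape)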

batch_size = 10

def batch_generator(files, path, batch_size):
    """Yield (X, Y) batches indefinitely, cycling through all data files."""
    while True:
        for file in files:
            print(file)
            loadfile = np.load(path + '/' + file)
            # Shuffle the sample indices of this file before batching.
            ids = np.arange(np.size(loadfile['x'], 0))
            np.random.shuffle(ids)

            batch = []
            for i in ids:
                batch.append(i)
                if len(batch) == batch_size:
                    yield load_data(loadfile, batch)
                    batch = []


def load_data(loadfile, ids):
    """Collect the samples with the given indices into input and target arrays."""
    X = []
    Y = []

    for i in ids:
        x = loadfile['x'][i, :, :, :]
        y = loadfile['y'][i, :, :, :]

        X.append(x)
        Y.append(y)

    return np.array(X), np.array(Y)

train_generator = batch_generator(train, path_train, batch_size)
valid_generator = batch_generator(valid, path_valid, batch_size)
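
# Quick smoke test of the generator (a fresh instance is used so no training
# batches are consumed); the expected shapes are (batch_size, 256, 128, 1) if
# the stored arrays match the model input:
#
#     Xb, Yb = next(batch_generator(train, path_train, batch_size))
#     print(Xb.shape, Yb.shape)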


# BUILD MODEL

max_norm_value = 100.0
model = Sequential([
    Conv2D(filters=32, kernel_size=(3,3), strides=1, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal", input_shape=(256, 128, 1), name="enc_conv1"),
    LeakyReLU(),
    Dropout(0.3, name="enc_drop1"),
    MaxPooling2D(pool_size=(2,2), name="enc_pool1"),
    Conv2D(filters=32, kernel_size=(3,3), strides=1, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal", name="enc_conv2"),
    LeakyReLU(),
    MaxPooling2D(pool_size=(2,2), name="enc_pool2"),
    Conv2D(filters=32, kernel_size=(3,3), strides=1, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal", name="enc_conv3"),
    LeakyReLU(),
    MaxPooling2D(pool_size=(2,2), name="enc_pool3"),
    Conv2D(filters=32, kernel_size=(3,3), strides=1, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal", name="enc_conv4"),
    LeakyReLU(),
    Flatten(),
    Dense(units = 8192),
    Dense(units = 8192),
    Reshape((32,16,16)),
    UpSampling2D(size=(2,2), name="dec_up1"),
    Conv2D(filters=32, kernel_size=(3,3), strides=1, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal", name="dec_conv1"),
    LeakyReLU(),
    UpSampling2D(size=(2,2), name="dec_up2"),
    Conv2D(filters=32, kernel_size=(3,3), strides=1, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal", name="dec_conv2"),
    LeakyReLU(),
    Dropout(0.3, name="dec_drop1"),
    UpSampling2D(size=(2,2), name="dec_up3"),
    Conv2D(filters=1, kernel_size=(3,3), strides=1, padding="same", activation='linear', kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal", name="dec_conv3")
])
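
# The decoder mirrors the three pooling steps of the encoder, so the
# reconstruction should match the input shape; quick sanity check:
assert model.output_shape == (None, 256, 128, 1)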
# opt = Adam(lr=0.000001)
model.compile(loss='mse', optimizer='adam', metrics=['mape', 'accuracy'])

# TRAIN MODEL

if not os.path.exists("weights_cnn"):
    try:
        os.mkdir("weights_cnn")
    except Exception as e:
        print("Could not create the directory for the weights: " + str(e))

filepath = "weights_cnn/weights-{epoch:02d}-{loss:.4f}.hdf5"
checkpoint = ModelCheckpoint(
    filepath,
    monitor='loss',
    verbose=0,
    save_best_only=True,
    mode='min'
)
model.summary()
# model.fit accepts Python generators directly in TF 2.x (fit_generator is deprecated);
# the ModelCheckpoint callback saves the best weights during training.
history = model.fit(train_generator, steps_per_epoch=1, epochs=5,
                    validation_data=valid_generator, validation_steps=1,
                    callbacks=[checkpoint])
model.save('cnn_autoencoder_model.h5')
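
# Training curves: matplotlib is imported above but otherwise unused; a minimal
# sketch for inspecting the loss history (the keys follow from the loss and
# validation data configured above):
plt.plot(history.history['loss'], label='training loss')
plt.plot(history.history['val_loss'], label='validation loss')
plt.xlabel('epoch')
plt.ylabel('MSE loss')
plt.legend()
plt.savefig('training_loss.png')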