# --- user configuration ------------------------------------------------------
# Trained Keras model file, loaded with tf.keras.models.load_model.
path_model = '/home/aneumann/Schreibtisch/Net Models/my_model_db_smalllabels1.h5'
# Raw input folder; folder the preprocessing step writes into (and which is
# then read back for prediction); folder the prediction results are saved to.
path_data = '/media/aneumann/Harddisk/Datenbanken/PythonTest'
path_data_save = '/media/aneumann/Harddisk/Datenbanken/PythonTest_2'
path_save = '/media/aneumann/Harddisk/Datenbanken/PythonTest_5'
# Preprocessing parameters; framelength is also used to pick the tail of each
# input row during prediction.
skipcount, framelength = 2, 4
# option: 'createdata' (preprocess first) or 'getdata' (read path_data as-is).
# option2 is forwarded to prep_flatten_1weiter (meaning not visible here).
option, option2 = 'createdata', 'music'

import ctypes
import os
import numpy as np
import scipy.io
import math
from tqdm import tqdm
# TensorFlow runtime switches, exported before tensorflow/keras are imported
# (presumably because TF reads them at import/initialization time).
# NOTE(review): TF_DISABLE_CUDNN_TENSOR_OP_MATH=1 appears to work against the
# *_TENSOR_OP_MATH_FP32 flags — confirm which behavior is intended.
os.environ.update({
    'TF_CPP_MIN_LOG_LEVEL': '2',  # hide TF C++ INFO and WARNING log lines
    'TF_ENABLE_XLA': '1',
    'TF_ENABLE_AUTO_MIXED_PRECISION': '1',
    'TF_ENABLE_CUDNN_RNN_TENSOR_OP_MATH_FP32': '1',
    'TF_DISABLE_CUDNN_TENSOR_OP_MATH': '1',
    'TF_ENABLE_CUBLAS_TENSOR_OP_MATH_FP32': '1',
})

from keras.models import load_model
import tensorflow as tf

# Cap TensorFlow at a 4096 MB virtual device on the first visible GPU, if any.
# Per the TensorFlow GPU guide this must happen before the GPU is initialized;
# a RuntimeError from a too-late call is printed rather than aborting.
gpus = tf.config.experimental.list_physical_devices('GPU')
if len(gpus) > 0:
    first_gpu = gpus[0]
    try:
        tf.config.experimental.set_virtual_device_configuration(
            first_gpu,
            [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4096)],
        )
    except RuntimeError as err:
        print(err)

# Trained network used for every prediction below.
model = tf.keras.models.load_model(path_model)

# create (or collect) data files
# -------------------------------------------------------------------------------------------------

# Decide where the .npz files to predict on come from.
if option == 'createdata':
    print('Creating files: 1weiter ...')
    from Preprocessing_flatten_1weiter import prep_flatten_1weiter
    # Preprocess path_data into path_data_save. The returned diffcount is the
    # number of rows per chunk used by the prediction loop.
    diffcount = prep_flatten_1weiter(path_data, path_data_save, skipcount, framelength, option2)
    path_load = path_data_save
elif option == 'getdata':
    print('Getting files ...')
    # NOTE(review): diffcount is NOT assigned on this path, but the prediction
    # loop needs it — set it explicitly before using 'getdata'.
    path_load = path_data
else:
    # Fail fast: continuing would only crash later with a NameError on
    # files_all / path_load.
    raise ValueError('Option nicht zulaessig! (option=%r)' % (option,))

files_all = sorted(os.listdir(path_load))

# path = '/media/aneumann/Harddisk/Datenbanken/PythonTest_2/Processed_0_2_flatten_1weiter.npz'
# data = np.load(path)
#
# x, y = data['x'], data['y']
# print(x.shape)
# print(y.shape)

# load each data file, predict, and save the results
# -------------------------------------------------------------------------------------------------

def _chunk_matrix(rows, n_chunks, chunk_len):
    """Regroup a 2-D (n_chunks * chunk_len, F) array into (n_chunks, F, chunk_len).

    Element ``[m, f, n]`` is column ``f`` of row ``m * chunk_len + n``. This is
    the same layout the original per-chunk ``np.transpose`` + stacking produced.
    """
    rows = np.asarray(rows)
    return rows.reshape(n_chunks, chunk_len, rows.shape[1]).transpose(0, 2, 1)


def _predict_file(model, X_raw, Y_raw, diffcount, framelength):
    """Predict every complete chunk of ``diffcount`` rows of one file.

    Trailing rows that do not fill a whole chunk are ignored, as before.

    Returns:
        A tuple ``(X_predicted_matrix, predictionError_matrix, dbmixed_matrix,
        dblabel_matrix)``, each shaped (n_chunks, features, diffcount), or
        ``None`` when the file has fewer than ``diffcount`` rows.
    """
    n_chunks = X_raw.shape[0] // diffcount
    if n_chunks == 0:
        return None
    n_rows = n_chunks * diffcount
    X = X_raw[:n_rows, :]
    Y = Y_raw[:n_rows, :]

    # One batched predict() call instead of one call per row. The per-row
    # results are the same, without the huge per-call overhead.
    # Assumes the model outputs a 2-D (rows, features) array — TODO confirm.
    X_predicted = np.asarray(model.predict(X, verbose=0))
    predictionError = Y - X_predicted

    # Last 1/framelength of each input row's columns (presumably the newest
    # frame — TODO confirm). Same slice expression as the original code.
    tail = int(X.shape[1] / framelength)
    dbmixed = X[:, -tail:]

    return tuple(
        _chunk_matrix(arr, n_chunks, diffcount)
        for arr in (X_predicted, predictionError, dbmixed, Y)
    )


print('Predicting ...')

for file in tqdm(files_all):
    # Only .npz archives have the 'x'/'y' entries read below.
    if not file.endswith('.npz'):
        continue

    path_loadcontents = os.path.join(path_load, file)
    # Context manager closes the NpzFile; the arrays are read into memory first.
    with np.load(path_loadcontents) as contents:
        X_raw, Y_raw = contents['x'], contents['y']

    result = _predict_file(model, X_raw, Y_raw, diffcount, framelength)
    if result is None:
        print('Skipping %s: fewer than %d rows' % (file, diffcount))
        continue
    X_predicted_matrix, predictionError_matrix, dbmixed_matrix, dblabel_matrix = result

    # saving
    # -------------------------------------------------------------------------------------------------

    index = file.split("_", 2)
    path_savecontents = os.path.join(
        path_save,
        'Predicted_' + index[1] + '_' + str(X_predicted_matrix.shape[0]) + '_Dense_1weiter.npz',
    )
    # Save the full accumulated error matrix (the original saved only the last row's error).
    np.savez(path_savecontents,
             X_predicted_matrix=X_predicted_matrix,
             predictionError_matrix=predictionError_matrix,
             dbmixed_matrix=dbmixed_matrix,
             dblabel_matrix=dblabel_matrix)
    print(X_predicted_matrix.shape)