# Mask_2Channel.py
def masktraining_skip_2chan(path_train, path_valid, batch_size, epochs, path_save_model, path_weights_model, plot_name, option, reduction_divisor):

    # initializing

    #!/usr/bin/env python3
    # -*- coding: utf8 -*-
    import ctypes
    import os
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
    os.environ['TF_ENABLE_XLA'] = '1'
    os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
    os.environ['TF_ENABLE_CUDNN_RNN_TENSOR_OP_MATH_FP32'] = '1'
    os.environ['TF_DISABLE_CUDNN_TENSOR_OP_MATH'] = '1'
    os.environ['TF_ENABLE_CUBLAS_TENSOR_OP_MATH_FP32'] = '1'

    import numpy as np
    import tensorflow as tf
    from tensorflow import keras
    # import the Keras API through tensorflow.keras so all objects come from the same backend
    from tensorflow.keras.layers import Add, Multiply, Input, Dense, Flatten, Dropout, Conv2D, MaxPooling2D, Conv2DTranspose, LeakyReLU, Reshape, Activation, BatchNormalization, UpSampling2D
    from tensorflow.keras.models import Model
    from tensorflow.keras.constraints import max_norm
    from tensorflow.keras.optimizers import Adam, SGD
    from tensorflow.keras.utils import to_categorical, normalize
    from tensorflow.keras.callbacks import ModelCheckpoint
    from tensorflow.keras import losses
    from tensorflow.keras import backend as K
    import matplotlib.pyplot as plt
    import random

    # check for GPUs and, if one is present, limit memory usage on the first one

    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            tf.config.experimental.set_virtual_device_configuration(gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4096)])
        except RuntimeError as e:
            print(e)
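    # note: the VirtualDeviceConfiguration above caps TensorFlow at 4096 MB on the
    # first GPU only; any further GPUs in `gpus` are left unconfigured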

    # custom generator import
    from Generators.MaskGenerator2 import DataGenerator

    # REPLACE! -- quick helper that counts the files in a directory
    def count_files(directory):
        return sum(1 for entry in os.scandir(directory) if entry.is_file())
    len_train = count_files(path_train)
    len_valid = count_files(path_valid)

    # generators for train and validation data

    train_generator = DataGenerator(path_train, option, reduction_divisor, len_train, batch_size, True)
    valid_generator = DataGenerator(path_valid, option, reduction_divisor, len_valid, batch_size, True)
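    # assumption: DataGenerator yields batches of the form
    # ([noisy, noise, speech], [noise_target, speech_target]), matching the three
    # inputs and two outputs of the model built below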


    # building model

    def build_model():
        max_norm_value = 100.0

        # define the inputs and normalize only the noisy input

        input_noisy = Input(shape=(260,5,2))
        input_noise = Input(shape=(260,1,2))
        input_speech = Input(shape=(260,1,2))
        normalized_noisy = BatchNormalization()(input_noisy)
        # normalized_noise = BatchNormalization()(input_noise)
        # normalized_speech = BatchNormalization()(input_speech)

        # encoder of net

        conv_1 = Conv2D(filters=64, kernel_size=(1,3),strides=1,padding="valid",kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(normalized_noisy)
        leakyrelu_1 = LeakyReLU()(conv_1)
        conv_2 = Conv2D(filters=64, kernel_size=(1,3),strides=1,padding="valid",kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(leakyrelu_1)
        leakyrelu_2 = LeakyReLU()(conv_2)
        conv_3 = Conv2D(filters=64, kernel_size=(16,1),strides=1,padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(leakyrelu_2)
        leakyrelu_3 = LeakyReLU()(conv_3)
        maxpool_1 = MaxPooling2D(pool_size=(2,1))(leakyrelu_3)
        conv_4 = Conv2D(filters=64, kernel_size=(16,1),strides=1,padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(maxpool_1)
        leakyrelu_4 = LeakyReLU()(conv_4)
        maxpool_2 = MaxPooling2D(pool_size=(2,1))(leakyrelu_4)
        conv_5 = Conv2D(filters=64, kernel_size=(16,1),strides=1,padding="valid", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(maxpool_2)
        leakyrelu_5 = LeakyReLU()(conv_5)
        conv_6 = Conv2D(filters=64, kernel_size=(16,1),strides=1,padding="valid", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(leakyrelu_5)
        leakyrelu_6 = LeakyReLU()(conv_6)
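        # shape trace of the encoder (derived from the layer parameters above):
        # input (260,5,2) -> conv_1/conv_2 with (1,3) "valid" kernels -> (260,1,64)
        # -> maxpool_1 -> (130,1,64) -> maxpool_2 -> (65,1,64)
        # -> conv_5/conv_6 with (16,1) "valid" kernels -> (50,1,64) -> (35,1,64)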

        # decoder of net

        convtrans_1 = Conv2DTranspose(filters=64, kernel_size=(16,1),strides=1,padding="valid", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(leakyrelu_6)
        leakyrelu_7 = LeakyReLU()(convtrans_1)
        convtrans_2 = Conv2DTranspose(filters=64, kernel_size=(16,1),strides=1, padding="valid", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(leakyrelu_7)
        leakyrelu_8 = LeakyReLU()(convtrans_2)
        skip_1 = Add()([maxpool_2,leakyrelu_8])
        up_1 = UpSampling2D(size=(2,1))(skip_1)
        conv_7 = Conv2D(filters=64, kernel_size=(16,1), strides=1, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(up_1)
        leakyrelu_9 = LeakyReLU()(conv_7)
        skip_2 = Add()([leakyrelu_4,leakyrelu_9])
        up_2 = UpSampling2D(size=(2,1))(skip_2)
        conv_8 = Conv2D(filters=64, kernel_size=(16,1), strides=1, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(up_2)
        leakyrelu_10 = LeakyReLU()(conv_8)
        skip_3 = Add()([leakyrelu_3,leakyrelu_10])
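        # the transposed convolutions and upsampling mirror the encoder, so each Add()
        # skip connection lines up: skip_1 at (65,1,64), skip_2 at (130,1,64),
        # skip_3 at (260,1,64)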

        # mask from noisy input

        mask = Conv2D(filters=2, kernel_size=(16,1),strides=1, padding="same", activation='linear', kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(skip_3)
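        # the 2-filter convolution reduces the features to a mask of shape (260,1,2),
        # matching input_noise and input_speech so the element-wise Multiply below works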

        # filtered noise and speech components

        n_tilde = Multiply()([mask,input_noise])
        s_tilde = Multiply()([mask,input_speech])

        # defining model

        model = Model(inputs=[input_noisy,input_noise,input_speech], outputs=[n_tilde,s_tilde])

        return model

    # if __name__ == "__main__":

    # build model and compile it

    model = build_model()
    model.compile(loss='mse', loss_weights=[0.5, 0.5], optimizer='adam', metrics=['mape','accuracy'])
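    # with loss_weights=[0.5, 0.5] the total loss is the average of the MSE on the two
    # outputs (n_tilde and s_tilde); 'mape' and 'accuracy' are tracked per output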

    # making directory for weights

    if not os.path.exists(path_weights_model):
        try:
            os.makedirs(path_weights_model)
        except Exception as e:
            print("Konnte Ordner fuer Gewichte nicht erstellen" + str(e))

    # defining how weights are saved

    filepath = path_weights_model + "/weights-{epoch:04d}.hdf5"
    checkpoint = ModelCheckpoint(
        filepath,
        monitor='loss',
        verbose=0,
        save_best_only=True,
        mode='min'
    )
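    # note: the checkpoint monitors the training loss ('loss'), not 'val_loss', and only
    # writes a new weights file when that loss improves (save_best_only=True)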
    model.summary()

    # train model

    history = model.fit(
        train_generator,
        steps_per_epoch=int(np.floor(len_train // reduction_divisor / batch_size)),
        epochs=epochs,
        validation_data=valid_generator,
        validation_steps=int(np.floor(len_valid // reduction_divisor / batch_size)),
        callbacks=[checkpoint],
        use_multiprocessing=True
    )
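    # steps per epoch are floor(len / reduction_divisor / batch_size), i.e. only a
    # 1/reduction_divisor fraction of the available files is consumed per epoch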

    model.summary()

    # save model after training

    model.save(path_save_model)

    index = path_save_model.split("/")
    path_save_nomodel = '/'.join(index[0:-1])


    # collect the full training history (total loss, per-output losses, MAPE and
    # accuracy for both outputs, plus the corresponding val_* entries)
    dict_history = dict(history.history)
    # loss_history = history.history["loss"]
    # mape_history = history.history["mape"]
    # accuracy_history = history.history["accuracy"]
    # val_loss_history = history.history["val_loss"]
    # val_mape_history = history.history["val_mape"]
    # val_accuracy_history = history.history["val_accuracy"]
    #
    # dict_history = {'loss':loss_history, 'mape':mape_history, 'accuracy':accuracy_history, 'val_loss':val_loss_history, 'val_mape':val_mape_history, 'val_accuracy':val_accuracy_history}
    np.save(path_save_nomodel + '/' + plot_name, dict_history)
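    # the history dict is pickled into a .npy file next to the saved model; it can be
    # loaded back with np.load(path, allow_pickle=True).item()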

    # plot train loss and validation loss after training

    plt.style.use("ggplot")
    plt.figure()
    plt.plot(history.history["loss"], label="train_loss")
    plt.plot(history.history["val_loss"], label="val_loss")
    plt.title("Training and Validation Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend(loc="lower left")
    plt.savefig(path_save_nomodel + '/' + plot_name + '.png', dpi=400)

    plt.show()

    return
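

# Example call (a sketch only: the paths and hyperparameters below are placeholders
# and have to be adapted to the local dataset layout expected by
# Generators.MaskGenerator2):
#
# if __name__ == "__main__":
#     masktraining_skip_2chan(
#         path_train="data/train",                  # hypothetical training-data directory
#         path_valid="data/valid",                  # hypothetical validation-data directory
#         batch_size=16,
#         epochs=100,
#         path_save_model="results/mask_2chan_model.h5",
#         path_weights_model="results/mask_2chan_weights",
#         plot_name="mask_2chan_history",
#         option=1,                                 # passed straight to DataGenerator; meaning defined there
#         reduction_divisor=1                       # 1 = use every training file per epoch
#     )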