# Mask_test.py — training script for a skip-connection mask-estimation network.
# (Non-code web-page residue from the original snippet export removed.)
  • def masktraining_skip_auto(path_train, path_valid, batch_size, epochs, path_save_model, path_weights_model, option):
    
        #!/usr/bin/env python3
        # -*- coding: utf8 -*-
        import ctypes
        import os
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
        os.environ['TF_ENABLE_XLA'] = '1'
        os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
        os.environ['TF_ENABLE_CUDNN_RNN_TENSOR_OP_MATH_FP32'] = '1'
        os.environ['TF_DISABLE_CUDNN_TENSOR_OP_MATH'] = '1'
        os.environ['TF_ENABLE_CUBLAS_TENSOR_OP_MATH_FP32'] = '1'
    
        import numpy as np
        import tensorflow as tf
        from tensorflow import keras
        from keras.layers import Add, Multiply, Input, Dense, Flatten, Dropout, Conv2D, MaxPooling2D, Conv2DTranspose, LeakyReLU, Reshape, Activation, BatchNormalization, UpSampling2D
        from keras.models import Model
        from keras.constraints import max_norm
        from keras.optimizers import Adam, SGD
        from keras.utils import to_categorical, normalize
        from keras.callbacks import ModelCheckpoint
        from keras import losses
        from keras import backend as K
        import matplotlib.pyplot as plt
        import random
    
        gpus = tf.config.experimental.list_physical_devices('GPU')
        if gpus:
            try:
                tf.config.experimental.set_virtual_device_configuration(gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=4096)])
            except RuntimeError as e:
                print(e)
    
        def build_model(self):
            max_norm_value = 100.0
            # x = keras.Input((self.seq_len, len(self.alphabet)))
            input_noisy = Input(shape=(260,5,1))
            input_noise = Input(shape=(260,5,1))
            input_speech = Input(shape=(260,5,1))
            normalized = BatchNormalization()(input)
            conv_1 = Conv2D(filters=32, kernel_size=(1,3),strides=1,padding="valid",kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(input_noisy)
            leakyrelu_1 = LeakyReLU()(conv_1)
            conv_2 = Conv2D(filters=32, kernel_size=(1,3),strides=1,padding="valid",kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(leakyrelu_1)
            leakyrelu_2 = LeakyReLU()(conv_2)
            conv_3 = Conv2D(filters=32, kernel_size=(16,1),strides=1,padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(leakyrelu_2)
            leakyrelu_3 = LeakyReLU()(conv_3)
            maxpool_1 = MaxPooling2D(pool_size=(2,1))(leakyrelu_3)
            conv_4 = Conv2D(filters=32, kernel_size=(16,1),strides=1,padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(maxpool_1)
            leakyrelu_4 = LeakyReLU()(conv_4)
            maxpool_2 = MaxPooling2D(pool_size=(2,1))(leakyrelu_4)
            conv_5 = Conv2D(filters=32, kernel_size=(16,1),strides=1,padding="valid", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(maxpool_2)
            leakyrelu_5 = LeakyReLU()(conv_5)
            conv_6 = Conv2D(filters=32, kernel_size=(16,1),strides=1,padding="valid", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(leakyrelu_5)
            leakyrelu_6 = LeakyReLU()(conv_6)
            convtrans_1 = Conv2DTranspose(filters=32, kernel_size=(16,1),strides=1,padding="valid", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(leakyrelu_6)
            leakyrelu_7 = LeakyReLU()(convtrans_1)
            convtrans_2 = Conv2DTranspose(filters=32, kernel_size=(16,1),strides=1, padding="valid", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(leakyrelu_7)
            leakyrelu_8 = LeakyReLU()(convtrans_2)
            skip_1 = Add()([maxpool_2,leakyrelu_8])
            up_1 = UpSampling2D(size=(2,1))(skip_1)
            conv_7 = Conv2D(filters=32, kernel_size=(16,1), strides=1, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(up_1)
            leakyrelu_9 = LeakyReLU()(conv_7)
            skip_2 = Add()([leakyrelu_4,leakyrelu_9])
            up_2 = UpSampling2D(size=(2,1))(skip_2)
            conv_8 = Conv2D(filters=32, kernel_size=(16,1), strides=1, padding="same", kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(up_2)
            leakyrelu_10 = LeakyReLU()(conv_8)
            skip_3 = Add()([leakyrelu_3,leakyrelu_10])
            # mask from noisy input
            mask = Conv2D(filters=1, kernel_size=(16,1),strides=1, padding="same", activation='linear', kernel_constraint=max_norm(max_norm_value), kernel_initializer="he_normal")(skip_3)
            # filtered speech and noise component
            n_tilde = Multiply()([mask,input_noise])
            s_tilde = Multiply()([mask,input_speech])
            model = Model(inputs=[input_noisy,input_noise,input_speech], outputs=[n_tilde,s_tilde])
            return model
    
        if __name__ == "__main__":
            model = build_model()
            model.compile(loss='mse', loss_weights=[0.5, 0.5], optimizer='adam', metrics=['metrics','accuracy'])
            if not os.path.exists(path_weights_model):
                try:
                    os.mkdir(path_weights_model)
                except Exception as e:
                    print("Konnte Ordner für Gewichte nicht erstellen" + str(e))
            filepath = path_weights_model + "/weights-{epoch:02d}-{loss:.4f}.hdf5"
            checkpoint = ModelCheckpoint(
                filepath,
                monitor='loss',
                verbose=0,
                save_best_only=True,
                mode='min'
            )
            model.summary()
            history = model.fit(train_generator, steps_per_epoch = int(np.floor(len(train) / batch_size)), epochs = epochs, validation_data = valid_generator, validation_steps = int(np.floor(len(valid) / batch_size)), callbacks=[checkpoint], use_multiprocessing=True)
            # what generator must Returns
            # model.fit([train_input_noisy, train_input_noise, train_input_speech], [train_residual_noise_power_target, train_clean_speech], ..., validation_data = [[val_input_noisy, val_input_noise, val_input_speech], [val_residual_noise_power_target, vali_clean_speech]])
            model.summary()
            model.save(path_save_model)
    
            plt.style.use("ggplot")
            plt.figure()
            plt.plot(history.history["loss"], label="train_loss")
            plt.plot(history.history["val_loss"], label="val_loss")
            plt.title("Training Loss and Accuracy")
            plt.xlabel("Epochs")
            plt.ylabel("Loss/Accuracy")
            plt.legend(loc="lower left")
            plt.savefig(path_save_model + '/testrun.png', dpi=400)
    
            plt.show()
    
            return