Skip to content
Snippets Groups Projects
plot_net.py 1.46 KiB
Newer Older
  • Learn to ignore specific revisions
  • Anna Neumann's avatar
    Anna Neumann committed
    def plot_net_performance(X_predict, Y_predict, X_predicted, predictionError):
        """Draw four arrays as heatmaps on a 2x2 subplot grid and show the figure.

        Panels are added to the current matplotlib figure (via ``plt.subplot``).
        Each one gets a title, an ``imshow`` rendering with nearest-neighbour
        interpolation and automatic aspect ratio, and its own colorbar. The
        figure is then displayed with ``plt.show()``.

        Args:
            X_predict: array for the top-left panel, titled 'dbmixed'.
            Y_predict: array for the top-right panel, titled 'dbmusic'.
            X_predicted: array for the bottom-left panel, titled 'prediction'.
            predictionError: array for the bottom-right panel, titled
                'predictionerror'.
        """
        import matplotlib.pyplot as plt

        panels = (
            ('dbmixed', X_predict),
            ('dbmusic', Y_predict),
            ('prediction', X_predicted),
            ('predictionerror', predictionError),
        )
        # Subplot indices 1..4 run left-to-right, top-to-bottom on the 2x2 grid.
        for position, (caption, image) in enumerate(panels, start=1):
            plt.subplot(2, 2, position)
            plt.title(caption)
            plt.imshow(image, interpolation='nearest', aspect='auto')
            plt.colorbar()

        plt.show()
    
    def plot_net_logerr(logerr_mean_whole, logerr_over_whole, logerr_under_whole,
                        logerr_var_whole, *, label='CNN_Noise'):
        """Plot the averages of four log-error measures as a single bar series.

        Each input is reduced to a scalar with ``np.mean``. The 'under' value is
        additionally passed through ``np.abs``, so its bar always has a
        non-negative height. The four values are drawn as one bar series with
        one bar per tick, and the figure is displayed with ``plt.show()``.

        Args:
            logerr_mean_whole: values averaged into the 'mean' bar.
            logerr_over_whole: values averaged into the 'over' bar.
            logerr_under_whole: values averaged into the 'under' bar; the
                absolute value of the mean is plotted.
            logerr_var_whole: values averaged into the 'var' bar.
            label: legend entry for the bar series (default 'CNN_Noise', the
                label previously hard-coded here).
        """
        import matplotlib.pyplot as plt
        import numpy as np

        labels = ['mean', 'over', 'under', 'var']
        means = [
            np.mean(logerr_mean_whole),
            np.mean(logerr_over_whole),
            # Plot the magnitude only; the 'under' mean is presumably negative
            # (NOTE(review): confirm against the code producing these values).
            np.abs(np.mean(logerr_under_whole)),
            np.mean(logerr_var_whole),
        ]
        x = np.arange(len(labels))
        width = 0.35

        fig, ax = plt.subplots()
        # Only one series is drawn, so centre each bar on its tick. The former
        # ``x - width/2`` offset is meant for a grouped chart with a second
        # series beside it, and here it shifted every bar off its label.
        ax.bar(x, means, width, label=label)
        ax.set_ylabel('Mittelwerte')  # "mean values"
        ax.set_title('Mittelwerte logerr-Auswertung')  # "mean values, logerr evaluation"
        ax.set_xticks(x)
        ax.set_xticklabels(labels)
        ax.legend()

        fig.tight_layout()

        plt.show()