Skip to content
Snippets Groups Projects
stoi.py 4.72 KiB
Newer Older
  • Learn to ignore specific revisions
  • Anna Neumann's avatar
    Anna Neumann committed
    import numpy as np
    import warnings
    import utils
    
    # Constant definitions used by stoi() below
    FS = 10000                          # Sampling frequency; stoi() resamples its inputs to this rate
    N_FRAME = 256                       # Window support (used for silent-frame removal and the STFT)
    NFFT = 512                          # FFT Size
    NUMBAND = 15                        # Number of 1/3 octave bands
    MINFREQ = 150                       # Center frequency of 1st octave band (Hz)
    # OBM is the band matrix that stoi() applies to the squared STFT magnitudes.
    # CF is presumably the vector of band center frequencies; it is not used by
    # stoi() -- confirm against utils.thirdoct.
    OBM, CF = utils.thirdoct(FS, NFFT, NUMBAND, MINFREQ)  # Get 1/3 octave band matrix
    N = 30                              # Number of frames per segment for intermediate intelligibility
    BETA = -15.                         # Lower SDR bound; stoi() clips with 10 ** (-BETA / 20)
    DYN_RANGE = 40                      # Speech dynamic range (passed to utils.remove_silent_frames)
    
    
    def stoi(x, y, fs_sig, extended=False):
        """ Short term objective intelligibility
        Computes the STOI (See [1][2]) of a denoised signal compared to a clean
        signal. The output is expected to have a monotonic relation with the
        subjective speech-intelligibility, where a higher score denotes better
        speech intelligibility.

        # Arguments
            x (np.ndarray): clean original speech
            y (np.ndarray): denoised speech
            fs_sig (int): sampling rate of x and y
            extended (bool): Boolean, whether to use the extended STOI described in [3]

        # Returns
            float: Short time objective intelligibility measure between clean and
            denoised speech. If fewer than N STFT frames remain after silent
            frame removal, a RuntimeWarning is issued and 1e-5 is returned.

        # Raises
            ValueError : if x and y have different shapes

        # Reference
            [1] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time
                Objective Intelligibility Measure for Time-Frequency Weighted Noisy
                Speech', ICASSP 2010, Texas, Dallas.
            [2] C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for
                Intelligibility Prediction of Time-Frequency Weighted Noisy Speech',
                IEEE Transactions on Audio, Speech, and Language Processing, 2011.
            [3] Jesper Jensen and Cees H. Taal, 'An Algorithm for Predicting the
                Intelligibility of Speech Masked by Modulated Noise Maskers',
                IEEE Transactions on Audio, Speech and Language Processing, 2016.
        """
        # ValueError is a subclass of Exception, so callers that caught the
        # generic Exception raised previously keep working.
        if x.shape != y.shape:
            raise ValueError('x and y should have the same length, '
                             'found {} and {}'.format(x.shape, y.shape))

        # Resample if fs_sig is different from FS
        if fs_sig != FS:
            x = utils.resample_oct(x, FS, fs_sig)
            y = utils.resample_oct(y, FS, fs_sig)

        # Remove silent frames
        x, y = utils.remove_silent_frames(x, y, DYN_RANGE, N_FRAME, N_FRAME // 2)

        # Take STFT; the transpose puts the frames on the last axis
        x_spec = utils.stft(x, N_FRAME, NFFT, overlap=2).transpose()
        y_spec = utils.stft(y, N_FRAME, NFFT, overlap=2).transpose()

        # Ensure at least N frames for intermediate intelligibility
        if x_spec.shape[-1] < N:
            warnings.warn('Not enough STFT frames to compute intermediate '
                          'intelligibility measure after removing silent '
                          'frames. Returning 1e-5. Please check you wav files',
                          RuntimeWarning)
            return 1e-5

        # Apply OB matrix to the spectrograms as in Eq. (1)
        x_tob = np.sqrt(np.matmul(OBM, np.square(np.abs(x_spec))))
        y_tob = np.sqrt(np.matmul(OBM, np.square(np.abs(y_spec))))

        # Take all length-N sliding segments of x_tob, y_tob. Both share the
        # same frame count, so a single range of segment end indices is used.
        n_frames = x_tob.shape[1]
        segment_ends = range(N, n_frames + 1)
        x_segments = np.array([x_tob[:, m - N:m] for m in segment_ends])
        y_segments = np.array([y_tob[:, m - N:m] for m in segment_ends])

        if extended:
            x_n = utils.row_col_normalize(x_segments)
            y_n = utils.row_col_normalize(y_segments)
            return np.sum(x_n * y_n / N) / x_n.shape[0]

        # Find normalization constants and normalize: scale each segment row of
        # y to the norm of the matching row of x (EPS guards a zero norm).
        normalization_consts = (
            np.linalg.norm(x_segments, axis=2, keepdims=True) /
            (np.linalg.norm(y_segments, axis=2, keepdims=True) + utils.EPS))
        y_segments_normalized = y_segments * normalization_consts

        # Clip as described in [1]
        clip_value = 10 ** (-BETA / 20)
        y_primes = np.minimum(
            y_segments_normalized, x_segments * (1 + clip_value))

        # Subtract mean vectors
        y_primes = y_primes - np.mean(y_primes, axis=2, keepdims=True)
        x_segments = x_segments - np.mean(x_segments, axis=2, keepdims=True)

        # Divide by their norms
        y_primes /= (np.linalg.norm(y_primes, axis=2, keepdims=True) + utils.EPS)
        x_segments /= (np.linalg.norm(x_segments, axis=2, keepdims=True) + utils.EPS)

        # Find a matrix with entries summing to sum of correlations of vectors
        correlations_components = y_primes * x_segments

        # J, M as in [1], eq.6
        J = x_segments.shape[0]
        M = x_segments.shape[1]

        # Find the mean of all correlations
        d = np.sum(correlations_components) / (J * M)
        return d