!22893 [assistant][ops] Add new data operator CreateFbMatrix

Merge pull request !22893 from TR-nbu/CreateFbMatrix
This commit is contained in:
i-robot 2022-01-17 08:33:48 +00:00 committed by Gitee
commit 19a9ad6f38
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 506 additions and 1 deletions

View File

@ -464,6 +464,30 @@ Magphase::Magphase(float power) : data_(std::make_shared<Data>(power)) {}
// Create the MagphaseOperation from the power value stored in data_.
std::shared_ptr<TensorOperation> Magphase::Parse() { return std::make_shared<MagphaseOperation>(data_->power_); }
// MelscaleFbanks Function.
// Validates the arguments, builds the filterbank with CreateFbanks and stores it in *output as an MSTensor.
Status MelscaleFbanks(MSTensor *output, int32_t n_freqs, float f_min, float f_max, int32_t n_mels, int32_t sample_rate,
                      NormType norm, MelType mel_type) {
  RETURN_UNEXPECTED_IF_NULL(output);
  // Argument checks are evaluated in order, so the first violated constraint decides the reported message.
  CHECK_FAIL_RETURN_UNEXPECTED(n_freqs > 0,
                               "MelscaleFbanks: n_freqs must be greater than 0, got: " + std::to_string(n_freqs));
  CHECK_FAIL_RETURN_UNEXPECTED(f_min >= 0, "MelscaleFbanks: f_min must be non negative, got: " + std::to_string(f_min));
  CHECK_FAIL_RETURN_UNEXPECTED(f_max > 0,
                               "MelscaleFbanks: f_max must be greater than 0, got: " + std::to_string(f_max));
  CHECK_FAIL_RETURN_UNEXPECTED(n_mels > 0,
                               "MelscaleFbanks: n_mels must be greater than 0, got: " + std::to_string(n_mels));
  CHECK_FAIL_RETURN_UNEXPECTED(
    sample_rate > 0, "MelscaleFbanks: sample_rate must be greater than 0, got: " + std::to_string(sample_rate));
  // The range must be non-empty: f_max has to be strictly greater than f_min.
  CHECK_FAIL_RETURN_UNEXPECTED(f_max > f_min, "MelscaleFbanks: f_max must be greater than f_min, got: f_min = " +
                                                std::to_string(f_min) + ", while f_max = " + std::to_string(f_max));
  // Compute the filterbank as a dataset tensor.
  std::shared_ptr<dataset::Tensor> fb;
  RETURN_IF_NOT_OK(CreateFbanks(&fb, n_freqs, f_min, f_max, n_mels, sample_rate, norm, mel_type));
  // Reject a result that carries no data before handing it to the caller.
  CHECK_FAIL_RETURN_UNEXPECTED(fb->HasData(),
                               "MelscaleFbanks: get an empty tensor with shape " + fb->shape().ToString());
  // Wrap the dataset tensor in a DETensor so it can be returned as an MSTensor.
  *output = mindspore::MSTensor(std::make_shared<DETensor>(fb));
  return Status::OK();
}
// MuLawDecoding Transform Operation.
struct MuLawDecoding::Data {
explicit Data(int32_t quantization_channels) : quantization_channels_(quantization_channels) {}

View File

@ -30,6 +30,30 @@ PYBIND_REGISTER(CreateDct, 1, ([](py::module *m) {
}));
}));
// Register the Python-visible function "MelscaleFbanks": it forwards its arguments to CreateFbanks and
// returns the resulting tensor.
// NOTE(review): no range checks are done in this binding; presumably the Python caller validates - confirm.
PYBIND_REGISTER(MelscaleFbanks, 1, ([](py::module *m) {
                  (void)m->def(
                    "MelscaleFbanks", ([](int32_t n_freqs, float f_min, float f_max, int32_t n_mels,
                                          int32_t sample_rate, NormType norm, MelType mel_type) {
                      std::shared_ptr<Tensor> fb;
                      // A failing Status is handed to THROW_IF_ERROR (presumably surfaced as a Python exception).
                      THROW_IF_ERROR(CreateFbanks(&fb, n_freqs, f_min, f_max, n_mels, sample_rate, norm, mel_type));
                      return fb;
                    }));
                }));
// Expose the C++ enum MelType to Python as "MelType" with the values DE_MELTYPE_HTK and DE_MELTYPE_SLANEY.
PYBIND_REGISTER(MelType, 0, ([](const py::module *m) {
                  (void)py::enum_<MelType>(*m, "MelType", py::arithmetic())
                    .value("DE_MELTYPE_HTK", MelType::kHtk)
                    .value("DE_MELTYPE_SLANEY", MelType::kSlaney)
                    .export_values();
                }));
// Expose the C++ enum NormType to Python as "NormType" with the values DE_NORMTYPE_NONE and DE_NORMTYPE_SLANEY.
PYBIND_REGISTER(NormType, 0, ([](const py::module *m) {
                  (void)py::enum_<NormType>(*m, "NormType", py::arithmetic())
                    .value("DE_NORMTYPE_NONE", NormType::kNone)
                    .value("DE_NORMTYPE_SLANEY", NormType::kSlaney)
                    .export_values();
                }));
PYBIND_REGISTER(NormMode, 0, ([](const py::module *m) {
(void)py::enum_<NormMode>(*m, "NormMode", py::arithmetic())
.value("DE_NORMMODE_NONE", NormMode::kNone)

View File

@ -222,6 +222,170 @@ Status Phase(const std::shared_ptr<Tensor> &angle_0, const std::shared_ptr<Tenso
return Status::OK();
}
/// \brief Create a bank of overlapping triangular filters.
/// \tparam T Element type used to read all_freqs and f_pts and to store the filter weights.
/// \param[out] output Tensor of shape (all_freqs->Size(), f_pts->Size() - 2) holding the filter weights.
/// \param[in] all_freqs Frequency of every bin the filters are evaluated at.
/// \param[in] f_pts Edge frequencies of the filters; filter k rises between f_pts[k] and f_pts[k + 1] and falls
///     between f_pts[k + 1] and f_pts[k + 2], so at least 3 points are required.
/// \return Status code.
template <typename T>
Status CreateTriangularFilterbank(std::shared_ptr<Tensor> *output, const std::shared_ptr<Tensor> &all_freqs,
                                  const std::shared_ptr<Tensor> &f_pts) {
  CHECK_FAIL_RETURN_UNEXPECTED(output != nullptr && all_freqs != nullptr && f_pts != nullptr,
                               "CreateTriangularFilterbank: input or output tensor pointer is null.");
  // Each filter spans three consecutive edge points; fewer points would make "Size() - 2" wrap around below.
  CHECK_FAIL_RETURN_UNEXPECTED(
    f_pts->Size() > 2,
    "CreateTriangularFilterbank: f_pts must hold at least 3 points, got: " + std::to_string(f_pts->Size()));

  // Copy both tensors into flat vectors so they can be indexed freely.
  std::vector<T> edges;
  edges.reserve(static_cast<size_t>(f_pts->Size()));
  for (auto it = f_pts->begin<T>(); it != f_pts->end<T>(); ++it) {
    edges.push_back(*it);
  }
  std::vector<T> freqs;
  freqs.reserve(static_cast<size_t>(all_freqs->Size()));
  for (auto it = all_freqs->begin<T>(); it != all_freqs->end<T>(); ++it) {
    freqs.push_back(*it);
  }

  // Distance in hertz between neighbouring edge points.
  std::vector<T> widths(edges.size() - 1);
  for (size_t i = 0; i + 1 < edges.size(); ++i) {
    widths[i] = edges[i + 1] - edges[i];
  }

  // For every (frequency, filter) pair take the lower of the rising and the falling line, clipped at zero.
  const size_t n_filters = edges.size() - 2;
  const T zero = static_cast<T>(0);
  std::vector<T> fb;
  fb.reserve(freqs.size() * n_filters);
  for (const T &freq : freqs) {
    for (size_t k = 0; k < n_filters; ++k) {
      const T down_slope = -(edges[k] - freq) / widths[k];
      const T up_slope = (edges[k + 2] - freq) / widths[k + 1];
      fb.push_back(std::max(zero, std::min(down_slope, up_slope)));
    }
  }

  TensorShape fb_shape({all_freqs->Size(), f_pts->Size() - 2});
  std::shared_ptr<Tensor> fb_tensor;
  RETURN_IF_NOT_OK(Tensor::CreateFromVector(fb, fb_shape, &fb_tensor));
  *output = fb_tensor;
  return Status::OK();
}
/// \brief Create a frequency transformation matrix with shape (n_freqs, n_mels).
/// \param[out] output Tensor of the frequency transformation matrix.
/// \param[in] n_freqs Number of frequency bins, spread evenly over [0, sample_rate / 2].
/// \param[in] f_min Minimum frequency in Hz.
/// \param[in] f_max Maximum frequency in Hz.
/// \param[in] n_mels Number of mel filterbanks.
/// \param[in] sample_rate Sample rate.
/// \param[in] norm Norm to use; with NormType::kSlaney every filter is scaled by 2 / (width of its band).
/// \param[in] mel_type Mel scale to use, MelType::kHtk or MelType::kSlaney.
/// \return Status code; an error is returned when one of the filters is all zeros.
Status CreateFbanks(std::shared_ptr<Tensor> *output, int32_t n_freqs, float f_min, float f_max, int32_t n_mels,
                    int32_t sample_rate, NormType norm, MelType mel_type) {
  // Constants of the Slaney mel scale: linear (slope 1 / f_sp) below min_log_hz, logarithmic above it.
  const double f_sp = 200.0 / 3;
  const double min_log_hz = 1000.0;
  const double min_log_mel = min_log_hz / f_sp;
  // Natural logarithm on purpose: logstep is fed to exp() when converting mel back to hertz.
  const double logstep = std::log(6.4) / 27.0;
  // Constants of the HTK mel scale: mel = 2595 * log10(1 + hz / 700).
  const double hz_to_mel_c = 2595.0;
  const double mel_to_hz_c = 700.0;
  const double htk_mel_c = 10.0;

  // Convert a frequency in hertz to the selected mel scale.
  const auto hz_to_mel = [=](double hz) -> double {
    if (mel_type == MelType::kHtk) {
      return hz_to_mel_c * std::log10(1.0 + (hz / mel_to_hz_c));
    }
    // The segment is selected by the frequency itself, not by its mel value.
    if (hz >= min_log_hz) {
      return min_log_mel + std::log(hz / min_log_hz) / logstep;
    }
    return hz / f_sp;
  };
  // Convert a value on the selected mel scale back to hertz (inverse of hz_to_mel).
  const auto mel_to_hz = [=](double mel) -> double {
    if (mel_type == MelType::kHtk) {
      return mel_to_hz_c * (std::pow(htk_mel_c, mel / hz_to_mel_c) - 1.0);
    }
    if (mel >= min_log_mel) {
      return min_log_hz * std::exp(logstep * (mel - min_log_mel));
    }
    return f_sp * mel;
  };

  // all_freqs holds the frequency of every bin, from 0 up to half of the sample rate.
  std::shared_ptr<Tensor> all_freqs;
  // the sampling frequency is at least twice the highest frequency of the signal.
  const double signal_times = 2;
  RETURN_IF_NOT_OK(Linspace<float>(&all_freqs, 0, sample_rate / signal_times, n_freqs));

  // m_pts is the mel value sequence in linspace of (m_min, m_max); two extra points close the outer filters.
  const double m_min = hz_to_mel(f_min);
  const double m_max = hz_to_mel(f_max);
  std::shared_ptr<Tensor> m_pts;
  const int32_t bias = 2;
  RETURN_IF_NOT_OK(Linspace<float>(&m_pts, m_min, m_max, n_mels + bias));

  // f_pts holds the same points converted back to hertz.
  std::shared_ptr<Tensor> f_pts;
  RETURN_IF_NOT_OK(Tensor::CreateEmpty(m_pts->shape(), DataType(DataType::DE_FLOAT32), &f_pts));
  auto iter_f = f_pts->begin<float>();
  for (auto iter_m = m_pts->begin<float>(); iter_m != m_pts->end<float>(); ++iter_m, ++iter_f) {
    *iter_f = static_cast<float>(mel_to_hz(*iter_m));
  }

  // create filterbank
  TensorShape fb_shape({all_freqs->Size(), f_pts->Size() - 2});
  const size_t n_rows = static_cast<size_t>(fb_shape[0]);
  const size_t n_cols = static_cast<size_t>(fb_shape[1]);
  std::shared_ptr<Tensor> fb;
  RETURN_IF_NOT_OK(CreateTriangularFilterbank<float>(&fb, all_freqs, f_pts));

  // normalize with Slaney: scale every filter (column) by 2 / (f_pts[col + 2] - f_pts[col]).
  if (norm == NormType::kSlaney) {
    std::vector<float> enorm;
    enorm.reserve(n_cols);
    auto iter_band_lo = f_pts->begin<float>();
    auto iter_band_hi = f_pts->begin<float>();
    ++iter_band_hi;
    ++iter_band_hi;
    for (; iter_band_hi != f_pts->end<float>(); ++iter_band_lo, ++iter_band_hi) {
      enorm.push_back(2.0f / (*iter_band_hi - *iter_band_lo));
    }
    auto iter_scaled = fb->begin<float>();
    for (size_t row = 0; row < n_rows; ++row) {
      for (size_t col = 0; col < n_cols; ++col) {
        *iter_scaled = (*iter_scaled) * enorm[col];
        ++iter_scaled;
      }
    }
  }

  // anomaly detection: a filter whose largest weight is (almost) zero covers no frequency bin.
  std::vector<float> max_val(n_cols, 0.0f);
  auto iter_fb = fb->begin<float>();
  for (size_t row = 0; row < n_rows; ++row) {
    for (size_t col = 0; col < n_cols; ++col) {
      max_val[col] = std::max(max_val[col], *iter_fb);
      ++iter_fb;
    }
  }
  for (size_t col = 0; col < n_cols; ++col) {
    CHECK_FAIL_RETURN_UNEXPECTED(
      max_val[col] >= 1e-8,
      "MelscaleFbanks: at least one mel filterbank is all zeros, check if the value for 'n_mels' " +
        std::to_string(n_mels) + " is set too high or the value for 'n_freqs' " + std::to_string(n_freqs) +
        " is set too low.");
  }
  *output = fb;
  return Status::OK();
}
/// \brief Calculate magnitude.
/// \param[in] alphas - The alphas.
/// \param[in] abs_0 - The norm.

View File

@ -380,6 +380,19 @@ Status RandomMaskAlongAxis(const std::shared_ptr<Tensor> &input, std::shared_ptr
Status MaskAlongAxis(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t mask_width,
int32_t mask_start, float mask_value, int32_t axis);
/// \brief Create a frequency transformation matrix with shape (n_freqs, n_mels).
/// \param output: Output tensor that receives the frequency transformation matrix.
/// \param n_freqs: Number of frequency.
/// \param f_min: Minimum of frequency in Hz.
/// \param f_max: Maximum of frequency in Hz.
/// \param n_mels: Number of mel filterbanks.
/// \param sample_rate: Sample rate.
/// \param norm: Norm to use, can be NormType::kSlaney or NormType::kNone.
/// \param mel_type: Scale to use, can be MelType::kSlaney or MelType::kHtk.
/// \return Status code.
Status CreateFbanks(std::shared_ptr<Tensor> *output, int32_t n_freqs, float f_min, float f_max, int32_t n_mels,
                    int32_t sample_rate, NormType norm, MelType mel_type);
/// \brief Create a DCT transformation matrix with shape (n_mels, n_mfcc), normalized depending on norm.
/// \param n_mfcc: Number of mfc coefficients to retain, the value must be greater than 0.
/// \param n_mels: Number of mel filterbanks, the value must be greater than 0.

View File

@ -616,6 +616,19 @@ class MS_API Magphase final : public TensorTransform {
std::shared_ptr<Data> data_;
};
/// \brief Create a frequency transformation matrix with shape (n_freqs, n_mels).
/// \param[out] output Tensor of the frequency transformation matrix.
/// \param[in] n_freqs Number of frequencies to highlight/apply.
/// \param[in] f_min Minimum frequency (Hz).
/// \param[in] f_max Maximum frequency (Hz).
/// \param[in] n_mels Number of mel filterbanks.
/// \param[in] sample_rate Sample rate of the audio waveform.
/// \param[in] norm Norm to use, can be NormType::kNone or NormType::kSlaney (Default: NormType::kNone).
/// \param[in] mel_type Scale to use, can be MelType::kHtk or MelType::kSlaney (Default: MelType::kHtk).
/// \return Status code.
Status MS_API MelscaleFbanks(MSTensor *output, int32_t n_freqs, float f_min, float f_max, int32_t n_mels,
                             int32_t sample_rate, NormType norm = NormType::kNone, MelType mel_type = MelType::kHtk);
/// \brief MuLawDecoding TensorTransform.
/// \note Decode mu-law encoded signal.
class MS_API MuLawDecoding final : public TensorTransform {

View File

@ -84,6 +84,12 @@ enum class MS_API NormMode {
kOrtho = 1 ///< Ortho type norm.
};
/// \brief Possible options for norm in MelscaleFbanks.
enum class MS_API NormType {
  kNone = 0,    ///< None type norm, the filterbank is left unscaled.
  kSlaney = 1,  ///< Slaney type norm.
};
/// \brief The mode for manual offload.
enum class MS_API ManualOffloadMode {
kUnspecified, ///< Not set, will use auto_offload setting instead.
@ -98,6 +104,12 @@ enum class MS_API MapTargetDevice {
kAscend310 ///< Ascend310 Device.
};
/// \brief Possible options for mel_type in MelscaleFbanks.
enum class MS_API MelType {
  kHtk = 0,     ///< Htk scale type.
  kSlaney = 1,  ///< Slaney scale type.
};
/// \brief The initial type of tensor implementation.
enum class MS_API TensorImpl {
kNone, ///< None type tensor.

View File

@ -18,7 +18,8 @@ Enum for audio ops.
from enum import Enum
import mindspore._c_dataengine as cde
from mindspore.dataset.core.validator_helpers import check_non_negative_float32, check_non_negative_int32, check_pos_float32, check_pos_int32, \
type_check
class DensityFunction(str, Enum):
"""
@ -110,6 +111,84 @@ class ScaleType(str, Enum):
MAGNITUDE: str = "magnitude"
class NormType(str, Enum):
    """
    Norm Types.

    Possible enumeration values are: NormType.NONE, NormType.SLANEY.

    - NormType.NONE: do not normalize the input data.
    - NormType.SLANEY: normalize the input data with the Slaney norm.
    """
    NONE: str = "none"
    SLANEY: str = "slaney"
# Maps every Python NormType member to the matching enum value of the C++ backend (cde).
DE_C_NORMTYPE_TYPE = {NormType.NONE: cde.NormType.DE_NORMTYPE_NONE,
                      NormType.SLANEY: cde.NormType.DE_NORMTYPE_SLANEY}
class MelType(str, Enum):
    """
    Mel Types.

    Possible enumeration values are: MelType.HTK, MelType.SLANEY.

    - MelType.HTK: scale the input data with htk.
    - MelType.SLANEY: scale the input data with slaney.
    """
    HTK: str = "htk"
    SLANEY: str = "slaney"
# Maps every Python MelType member to the matching enum value of the C++ backend (cde).
DE_C_MELTYPE_TYPE = {MelType.HTK: cde.MelType.DE_MELTYPE_HTK,
                     MelType.SLANEY: cde.MelType.DE_MELTYPE_SLANEY}
def melscale_fbanks(n_freqs, f_min, f_max, n_mels, sample_rate, norm=NormType.NONE, mel_type=MelType.HTK):
    """
    Create a frequency transformation matrix with shape (n_freqs, n_mels).

    Args:
        n_freqs (int): Number of frequency, must be greater than 0.
        f_min (float): Minimum of frequency in Hz, must be non negative.
        f_max (float): Maximum of frequency in Hz, must be greater than f_min.
        n_mels (int): Number of mel filterbanks, must be greater than 0.
        sample_rate (int): Sample rate, must be greater than 0.
        norm (NormType, optional): Norm to use, can be NormType.NONE or NormType.SLANEY (Default: NormType.NONE).
        mel_type (MelType, optional): Scale to use, can be MelType.HTK or MelType.SLANEY (Default: MelType.HTK).

    Returns:
        numpy.ndarray, the frequency transformation matrix.

    Raises:
        TypeError: If an argument is not of the type listed above.
        ValueError: If an argument is out of range, or if f_min is not less than f_max.

    Examples:
        >>> melscale_fbanks = audio.melscale_fbanks(n_freqs=4096, f_min=0, f_max=8000, n_mels=40, sample_rate=16000)
    """

    # The checks mirror the C++ API MelscaleFbanks, because the binding called below does no validation itself.
    type_check(n_freqs, (int,), "n_freqs")
    check_pos_int32(n_freqs, "n_freqs")

    type_check(f_min, (int, float,), "f_min")
    check_non_negative_float32(f_min, "f_min")
    type_check(f_max, (int, float,), "f_max")
    check_pos_float32(f_max, "f_max")
    # Equal bounds would give every filter a width of zero, so the range has to be strictly increasing.
    if f_min >= f_max:
        raise ValueError(
            "Input f_min should be less than f_max, but got f_min: {0} and f_max: {1}.".format(f_min, f_max))

    type_check(n_mels, (int,), "n_mels")
    check_pos_int32(n_mels, "n_mels")

    type_check(sample_rate, (int,), "sample_rate")
    check_pos_int32(sample_rate, "sample_rate")

    type_check(norm, (NormType,), "norm")
    type_check(mel_type, (MelType,), "mel_type")

    # Translate the Python enums to their backend values and convert the returned tensor to a numpy array.
    return cde.MelscaleFbanks(n_freqs, f_min, f_max, n_mels, sample_rate, DE_C_NORMTYPE_TYPE[norm],
                              DE_C_MELTYPE_TYPE[mel_type]).as_array()
class NormMode(str, Enum):
"""
Norm Types.

View File

@ -936,6 +936,34 @@ TEST_F(MindDataTestPipeline, TestHighpassBiquadWrongArgs) {
EXPECT_EQ(iter02, nullptr);
}
/// Feature: MelscaleFbanks.
/// Description: Build a filterbank from valid arguments (Slaney norm, HTK mel scale).
/// Expectation: The returned status is OK.
TEST_F(MindDataTestPipeline, TestMelscaleFbanksNormal) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-MelscaleFbanksNormal.";
  mindspore::MSTensor fbanks;
  const Status rc = audio::MelscaleFbanks(&fbanks, 1024, 0, 1000, 40, 16000, NormType::kSlaney, MelType::kHtk);
  EXPECT_TRUE(rc.IsOk());
}
/// Feature: MelscaleFbanks.
/// Description: Call the operation with argument combinations that cannot produce a valid filterbank.
/// Expectation: Every call returns a failing status.
TEST_F(MindDataTestPipeline, TestMelscaleFbanksWithInvalidInput) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMelscaleFbanksWithInvalidInput.";
  mindspore::MSTensor fbanks;
  const NormType no_norm = NormType::kNone;
  const MelType htk_scale = MelType::kHtk;

  // A single frequency bin cannot feed 20 mel filters.
  MS_LOG(INFO) << "n_freqs is too low.";
  const Status rc_few_freqs = audio::MelscaleFbanks(&fbanks, 1, 50, 1000, 20, 16000, no_norm, htk_scale);
  EXPECT_FALSE(rc_few_freqs.IsOk());

  // 40 mel filters are too many for the 100 frequency bins used here.
  MS_LOG(INFO) << "n_mels is too high.";
  const Status rc_many_mels = audio::MelscaleFbanks(&fbanks, 100, 50, 1000, 40, 16000, no_norm, htk_scale);
  EXPECT_FALSE(rc_many_mels.IsOk());
}
TEST_F(MindDataTestPipeline, TestMuLawDecodingBasic) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMuLawDecodingBasic.";

View File

@ -0,0 +1,148 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
import mindspore.dataset.audio.utils as audio
from mindspore import log as logger
def count_unequal_element(data_expected, data_me, rtol, atol):
    """
    Assert that two arrays have the same shape and agree within tolerance.

    An element counts as a mismatch when its absolute error exceeds
    atol + abs(expected) * rtol; the fraction of mismatching elements must stay below rtol.
    """
    assert data_expected.shape == data_me.shape
    n_total = data_expected.size
    abs_err = np.abs(data_expected - data_me)
    tolerance = atol + np.abs(data_expected) * rtol
    mismatch = abs_err > tolerance
    n_mismatch = np.count_nonzero(mismatch)
    assert (n_mismatch / n_total) < rtol, \
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
        format(data_expected[mismatch], data_me[mismatch], abs_err[mismatch])
def test_melscale_fbanks_normal():
    """
    Feature: melscale_fbanks.
    Description: Build a filterbank with NormType.NONE and MelType.HTK.
    Expectation: The result matches the reference values from torchaudio.functional.melscale_fbanks.
    """
    fbanks = audio.melscale_fbanks(n_freqs=8, f_min=2, f_max=50, n_mels=4, sample_rate=100,
                                   norm=audio.NormType.NONE, mel_type=audio.MelType.HTK)
    reference = np.array([[0.0000, 0.0000, 0.0000, 0.0000],
                          [0.5502, 0.0000, 0.0000, 0.0000],
                          [0.6898, 0.3102, 0.0000, 0.0000],
                          [0.0000, 0.9366, 0.0634, 0.0000],
                          [0.0000, 0.1924, 0.8076, 0.0000],
                          [0.0000, 0.0000, 0.4555, 0.5445],
                          [0.0000, 0.0000, 0.0000, 0.7247],
                          [0.0000, 0.0000, 0.0000, 0.0000]], dtype=np.float64)
    count_unequal_element(reference, fbanks, 0.0001, 0.0001)
def test_melscale_fbanks_none_slaney():
    """
    Feature: melscale_fbanks.
    Description: Build a filterbank with NormType.NONE and MelType.SLANEY.
    Expectation: The result matches the reference values from torchaudio.functional.melscale_fbanks.
    """
    fbanks = audio.melscale_fbanks(n_freqs=8, f_min=2, f_max=50, n_mels=4, sample_rate=100,
                                   norm=audio.NormType.NONE, mel_type=audio.MelType.SLANEY)
    reference = np.array([[0.0000, 0.0000, 0.0000, 0.0000],
                          [0.5357, 0.0000, 0.0000, 0.0000],
                          [0.7202, 0.2798, 0.0000, 0.0000],
                          [0.0000, 0.9762, 0.0238, 0.0000],
                          [0.0000, 0.2321, 0.7679, 0.0000],
                          [0.0000, 0.0000, 0.4881, 0.5119],
                          [0.0000, 0.0000, 0.0000, 0.7440],
                          [0.0000, 0.0000, 0.0000, 0.0000]], dtype=np.float64)
    count_unequal_element(reference, fbanks, 0.0001, 0.0001)
def test_melscale_fbanks_with_slaney_htk():
    """
    Feature: melscale_fbanks.
    Description: Build a filterbank with NormType.SLANEY and MelType.HTK.
    Expectation: The result matches the reference values from torchaudio.functional.melscale_fbanks.
    """
    fbanks = audio.melscale_fbanks(n_freqs=10, f_min=0, f_max=50, n_mels=5, sample_rate=100,
                                   norm=audio.NormType.SLANEY, mel_type=audio.MelType.HTK)
    reference = np.array([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                          [0.0843, 0.0000, 0.0000, 0.0000, 0.0000],
                          [0.0776, 0.0447, 0.0000, 0.0000, 0.0000],
                          [0.0000, 0.1158, 0.0055, 0.0000, 0.0000],
                          [0.0000, 0.0344, 0.0860, 0.0000, 0.0000],
                          [0.0000, 0.0000, 0.0741, 0.0454, 0.0000],
                          [0.0000, 0.0000, 0.0000, 0.1133, 0.0053],
                          [0.0000, 0.0000, 0.0000, 0.0355, 0.0822],
                          [0.0000, 0.0000, 0.0000, 0.0000, 0.0760],
                          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], dtype=np.float64)
    count_unequal_element(reference, fbanks, 0.0001, 0.0001)
def test_melscale_fbanks_with_slaney_slaney():
    """
    Feature: melscale_fbanks.
    Description: Build a filterbank with NormType.SLANEY and MelType.SLANEY.
    Expectation: The result matches the reference values from torchaudio.functional.melscale_fbanks.
    """
    fbanks = audio.melscale_fbanks(n_freqs=8, f_min=2, f_max=50, n_mels=4, sample_rate=100,
                                   norm=audio.NormType.SLANEY, mel_type=audio.MelType.SLANEY)
    reference = np.array([[0.0000, 0.0000, 0.0000, 0.0000],
                          [0.0558, 0.0000, 0.0000, 0.0000],
                          [0.0750, 0.0291, 0.0000, 0.0000],
                          [0.0000, 0.1017, 0.0025, 0.0000],
                          [0.0000, 0.0242, 0.0800, 0.0000],
                          [0.0000, 0.0000, 0.0508, 0.0533],
                          [0.0000, 0.0000, 0.0000, 0.0775],
                          [0.0000, 0.0000, 0.0000, 0.0000]], dtype=np.float64)
    count_unequal_element(reference, fbanks, 0.0001, 0.0001)
def test_melscale_fbanks_invalid_input():
    """
    Feature: melscale_fbanks.
    Description: Test operation with invalid input.
    Expectation: Throw exception as expected.
    """
    def test_invalid_input(test_name, n_freqs, f_min, f_max, n_mels, sample_rate, norm, mel_type, error, error_msg):
        # Run one bad-argument case and check both the exception type and that the message names the argument.
        logger.info("Test melscale_fbanks with bad input: {0}".format(test_name))
        with pytest.raises(error) as error_info:
            audio.melscale_fbanks(n_freqs, f_min, f_max, n_mels, sample_rate, norm, mel_type)
        print(error_info)
        assert error_msg in str(error_info.value)

    no_norm = audio.NormType.NONE
    htk = audio.MelType.HTK
    # One row per case:
    # (test_name, n_freqs, f_min, f_max, n_mels, sample_rate, norm, mel_type, error, error_msg)
    bad_input_cases = [
        ("invalid n_freqs parameter Value", 99999999999, 0, 50, 5, 100, no_norm, htk, ValueError, "n_freqs"),
        ("invalid n_freqs parameter type", 10.5, 0, 50, 5, 100, no_norm, htk, TypeError, "n_freqs"),
        ("invalid f_min parameter type", 10, None, 50, 5, 100, no_norm, htk, TypeError, "f_min"),
        ("invalid f_max parameter type", 10, 0, None, 5, 100, no_norm, htk, TypeError, "f_max"),
        ("invalid n_mels parameter type", 10, 0, 50, 10.1, 100, no_norm, htk, TypeError, "n_mels"),
        ("invalid n_mels parameter Value", 20, 0, 50, 999999999999, 100, no_norm, htk, ValueError, "n_mels"),
        ("invalid sample_rate parameter type", 10, 0, 50, 5, 100.1, no_norm, htk, TypeError, "sample_rate"),
        ("invalid sample_rate parameter Value", 20, 0, 50, 5, 999999999999, no_norm, htk, ValueError,
         "sample_rate"),
        ("invalid norm parameter type", 10, 0, 50, 5, 100, None, htk, TypeError, "norm"),
        # This case exercises mel_type, so it is labelled accordingly.
        ("invalid mel_type parameter type", 10, 0, 50, 5, 100, audio.NormType.SLANEY, None, TypeError,
         "mel_type"),
    ]
    for case in bad_input_cases:
        test_invalid_input(*case)
if __name__ == "__main__":
    # Run every test of this module in order when the file is executed directly.
    for run_case in (test_melscale_fbanks_normal,
                     test_melscale_fbanks_none_slaney,
                     test_melscale_fbanks_with_slaney_htk,
                     test_melscale_fbanks_with_slaney_slaney,
                     test_melscale_fbanks_invalid_input):
        run_case()