[feat][assistant][I3J6V0] add new audio operator TrebleBiquad

This commit is contained in:
zx 2021-08-05 15:37:17 +08:00
parent 5debb583e2
commit 3da0f0f487
15 changed files with 537 additions and 1 deletions

View File

@ -39,6 +39,7 @@
#include "minddata/dataset/audio/ir/kernels/riaa_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
#include "minddata/dataset/audio/ir/kernels/treble_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/vol_ir.h"
namespace mindspore {
@ -385,6 +386,23 @@ std::shared_ptr<TensorOperation> TimeStretch::Parse() {
return std::make_shared<TimeStretchOperation>(data_->hop_length_, data_->n_freq_, data_->fixed_rate_);
}
// TrebleBiquad Transform Operation.
// Private parameter bundle for the TrebleBiquad transform; filled once by the
// TrebleBiquad constructor and read back in TrebleBiquad::Parse().
struct TrebleBiquad::Data {
  Data(int32_t sample_rate, float gain, float central_freq, float Q)
      : sample_rate_{sample_rate}, gain_{gain}, central_freq_{central_freq}, Q_{Q} {}

  int32_t sample_rate_;  // sampling rate of the waveform (Hz)
  float gain_;           // desired boost or attenuation (dB)
  float central_freq_;   // central frequency (Hz)
  float Q_;              // quality factor
};
TrebleBiquad::TrebleBiquad(int32_t sample_rate, float gain, float central_freq, float Q)
: data_(std::make_shared<Data>(sample_rate, gain, central_freq, Q)) {}
std::shared_ptr<TensorOperation> TrebleBiquad::Parse() {
return std::make_shared<TrebleBiquadOperation>(data_->sample_rate_, data_->gain_, data_->central_freq_, data_->Q_);
}
// Vol Transform Operation.
struct Vol::Data {
Data(float gain, GainType gain_type) : gain_(gain), gain_type_(gain_type) {}

View File

@ -43,6 +43,7 @@
#include "minddata/dataset/audio/ir/kernels/riaa_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
#include "minddata/dataset/audio/ir/kernels/treble_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/vol_ir.h"
namespace mindspore {
@ -317,6 +318,17 @@ PYBIND_REGISTER(
}));
}));
// Expose audio::TrebleBiquadOperation to Python under the class name "TrebleBiquadOperation".
PYBIND_REGISTER(
TrebleBiquadOperation, 1, ([](const py::module *m) {
(void)py::class_<audio::TrebleBiquadOperation, TensorOperation, std::shared_ptr<audio::TrebleBiquadOperation>>(
*m, "TrebleBiquadOperation")
// Python-side constructor: build the IR node, then validate it before handing it back.
.def(py::init([](int32_t sample_rate, float gain, float central_freq, float Q) {
auto treble_biquad = std::make_shared<audio::TrebleBiquadOperation>(sample_rate, gain, central_freq, Q);
// NOTE(review): THROW_IF_ERROR presumably raises when ValidateParams() returns a non-OK Status - confirm.
THROW_IF_ERROR(treble_biquad->ValidateParams());
return treble_biquad;
}));
}));
PYBIND_REGISTER(VolOperation, 1, ([](const py::module *m) {
(void)py::class_<audio::VolOperation, TensorOperation, std::shared_ptr<audio::VolOperation>>(
*m, "VolOperation")

View File

@ -25,6 +25,7 @@ add_library(audio-ir-kernels OBJECT
riaa_biquad_ir.cc
time_masking_ir.cc
time_stretch_ir.cc
treble_biquad_ir.cc
vol_ir.cc
)

View File

@ -0,0 +1,49 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/audio/ir/kernels/treble_biquad_ir.h"
#include "minddata/dataset/audio/ir/validators.h"
#include "minddata/dataset/audio/kernels/treble_biquad_op.h"
namespace mindspore {
namespace dataset {
namespace audio {
TrebleBiquadOperation::TrebleBiquadOperation(int32_t sample_rate, float gain, float central_freq, float Q)
: sample_rate_(sample_rate), gain_(gain), central_freq_(central_freq), Q_(Q) {}
// Check the stored parameters: Q against a numeric range and sample_rate against zero.
// NOTE(review): RETURN_IF_NOT_OK presumably returns early with the failing Status - confirm.
Status TrebleBiquadOperation::ValidateParams() {
// NOTE(review): the {0, 1.0} range with flags (true, false) presumably means left-open, right-closed,
// i.e. Q in (0, 1] as stated in the public API docs - confirm against ValidateScalar.
RETURN_IF_NOT_OK(ValidateScalar("TrebleBiquad", "Q", Q_, {0, 1.0}, true, false));
// sample_rate is used as a divisor when the kernel computes the filter coefficients.
RETURN_IF_NOT_OK(ValidateScalarNotZero("TrebleBiquad", "sample_rate", sample_rate_));
return Status::OK();
}
// Create the runtime kernel (TrebleBiquadOp) from the stored parameters.
std::shared_ptr<TensorOp> TrebleBiquadOperation::Build() {
  return std::make_shared<TrebleBiquadOp>(sample_rate_, gain_, central_freq_, Q_);
}
// Serialize the four operation parameters into a JSON object written to *out_json.
Status TrebleBiquadOperation::to_json(nlohmann::json *out_json) {
  // Every entry is a {key, value} pair, so the initializer list yields a JSON object.
  nlohmann::json args = {
    {"sample_rate", sample_rate_}, {"gain", gain_}, {"central_freq", central_freq_}, {"Q", Q_}};
  *out_json = args;
  return Status::OK();
}
} // namespace audio
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_TREBLE_BIQUAD_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_TREBLE_BIQUAD_IR_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/api/status.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
namespace audio {
constexpr char kTrebleBiquadOperation[] = "TrebleBiquad";
// IR node for the TrebleBiquad audio transform: holds the filter parameters,
// validates them, serializes them to JSON and builds the runtime kernel.
class TrebleBiquadOperation : public TensorOperation {
public:
// Store the filter parameters; no validation happens in the constructor.
TrebleBiquadOperation(int32_t sample_rate, float gain, float central_freq, float Q);
~TrebleBiquadOperation() = default;
// Create the runtime TensorOp for this operation.
std::shared_ptr<TensorOp> Build() override;
// Check the stored parameters and report the result as a Status.
Status ValidateParams() override;
// Name of this operation ("TrebleBiquad").
std::string Name() const override { return kTrebleBiquadOperation; }
// Write the parameters to *out_json.
Status to_json(nlohmann::json *out_json) override;
private:
int32_t sample_rate_;
float gain_;
float central_freq_;
float Q_;
};
} // namespace audio
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_TREBLE_BIQUAD_IR_H_

View File

@ -26,6 +26,7 @@ add_library(audio-kernels OBJECT
riaa_biquad_op.cc
time_masking_op.cc
time_stretch_op.cc
treble_biquad_op.cc
vol_op.cc
)

View File

@ -0,0 +1,68 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/audio/kernels/treble_biquad_op.h"
#include "minddata/dataset/audio/kernels/audio_utils.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
TrebleBiquadOp::TrebleBiquadOp(int32_t sample_rate, float gain, float central_freq, float Q)
: sample_rate_(sample_rate), gain_(gain), central_freq_(central_freq), Q_(Q) {}
// Apply the treble biquad filter to a floating-point waveform tensor.
//
// \param[in] input Waveform tensor of type float16, float32 or float64 with at least one dimension.
// \param[out] output Filtered tensor produced by Biquad().
// \return Status from IO_CHECK / the shape and type checks / Biquad().
//
// The six biquad coefficients are computed in double precision and narrowed exactly once to the
// element type of the input. Previously they were held in float, so float64 inputs were filtered
// with single-precision coefficients that had merely been cast up to double.
// NOTE(review): the formulas look like the high-shelf design used by SoX / torchaudio's
// treble_biquad - confirm against that reference.
Status TrebleBiquadOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);
  TensorShape input_shape = input->shape();
  // check input tensor dimension, it should be greater than 0.
  CHECK_FAIL_RETURN_UNEXPECTED(input_shape.Size() > 0, "TrebleBiquad: input tensor is not in shape of <..., time>.");
  // check input type, it should be DE_FLOAT32 or DE_FLOAT16 or DE_FLOAT64
  CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType(DataType::DE_FLOAT32) ||
                                 input->type() == DataType(DataType::DE_FLOAT16) ||
                                 input->type() == DataType(DataType::DE_FLOAT64),
                               "TrebleBiquad: input tensor type should be float, but got: " + input->type().ToString());

  // compute a0, a1, a2, b0, b1, b2 in double precision
  const double w0 = 2 * PI * static_cast<double>(central_freq_) / sample_rate_;
  const double cos_w0 = cos(w0);
  const double alpha = sin(w0) / 2 / static_cast<double>(Q_);
  // for peaking and shelving EQ filters only: amplitude = 10^(gain / 40)
  const double amplitude = exp(static_cast<double>(gain_) / 40 * log(10.0));

  // temp1, temp2, temp3 are the intermediate terms shared by the a and b coefficients.
  const double temp1 = 2 * sqrt(amplitude) * alpha;
  const double temp2 = (amplitude - 1) * cos_w0;
  const double temp3 = (amplitude + 1) * cos_w0;

  const double b0 = amplitude * ((amplitude + 1) + temp2 + temp1);
  const double b1 = -2 * amplitude * ((amplitude - 1) + temp3);
  const double b2 = amplitude * ((amplitude + 1) + temp2 - temp1);
  const double a0 = (amplitude + 1) - temp2 + temp1;
  const double a1 = 2 * ((amplitude - 1) - temp3);
  const double a2 = (amplitude + 1) - temp2 - temp1;

  // dispatch to Biquad with coefficients narrowed to the tensor's element type
  if (input->type() == DataType(DataType::DE_FLOAT32)) {
    return Biquad(input, output, static_cast<float>(b0), static_cast<float>(b1), static_cast<float>(b2),
                  static_cast<float>(a0), static_cast<float>(a1), static_cast<float>(a2));
  } else if (input->type() == DataType(DataType::DE_FLOAT64)) {
    return Biquad(input, output, b0, b1, b2, a0, a1, a2);
  } else {
    // float16: narrow through float, the same conversion path the float-based code used.
    auto to_half = [](double value) { return static_cast<float16>(static_cast<float>(value)); };
    return Biquad(input, output, to_half(b0), to_half(b1), to_half(b2), to_half(a0), to_half(a1), to_half(a2));
  }
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_TREBLE_BIQUAD_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_TREBLE_BIQUAD_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
// Runtime kernel for the TrebleBiquad audio transform.
class TrebleBiquadOp : public TensorOp {
public:
// Store the filter parameters used by Compute().
TrebleBiquadOp(int32_t sample_rate, float gain, float central_freq, float Q);
~TrebleBiquadOp() override = default;
// Filter the input waveform tensor and write the result to *output.
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
// Name of this kernel (kTrebleBiquadOp).
std::string Name() const override { return kTrebleBiquadOp; }
private:
int32_t sample_rate_;
float gain_;
float central_freq_;
float Q_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_TREBLE_BIQUAD_OP_H_

View File

@ -556,6 +556,29 @@ class TimeStretch final : public TensorTransform {
std::shared_ptr<Data> data_;
};
/// \brief Design a treble tone-control effect. Similar to SoX implementation.
class TrebleBiquad final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero.
/// \param[in] gain Desired gain at the boost (or attenuation) in dB.
/// \param[in] central_freq Central frequency (in Hz) (Default: 3000).
/// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1] (Default: 0.707).
/// \note The constructor only stores the arguments; sample_rate and Q are checked later, when the
///     parsed operation is validated.
TrebleBiquad(int32_t sample_rate, float gain, float central_freq = 3000, float Q = 0.707);
/// \brief Destructor.
~TrebleBiquad() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
// Opaque parameter holder; its definition lives in the implementation file.
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief Vol TensorTransform.
/// \notes Add a volume to an waveform.
class Vol final : public TensorTransform {

View File

@ -165,6 +165,7 @@ constexpr char kMuLawDecodingOp[] = "MuLawDecodingOp";
constexpr char kRiaaBiquadOp[] = "RiaaBiquadOp";
constexpr char kTimeMaskingOp[] = "TimeMaskingOp";
constexpr char kTimeStretchOp[] = "TimeStretchOp";
constexpr char kTrebleBiquadOp[] = "TrebleBiquadOp";
constexpr char kVolOp[] = "VolOp";
// data

View File

@ -28,7 +28,7 @@ from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_
check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_contrast, check_dc_shift, \
check_deemph_biquad, check_equalizer_biquad, check_fade, check_highpass_biquad, check_lfilter, \
check_lowpass_biquad, check_magphase, check_masking, check_mu_law_decoding, check_riaa_biquad, \
check_time_stretch, check_vol
check_time_stretch, check_treble_biquad, check_vol
class AudioTensorOperation(TensorOperation):
@ -714,6 +714,36 @@ class TimeStretch(AudioTensorOperation):
return cde.TimeStretchOperation(self.hop_length, self.n_freq, self.fixed_rate)
class TrebleBiquad(AudioTensorOperation):
    """
    Design a treble tone-control effect. Similar to SoX implementation.

    Args:
        sample_rate (int): Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero.
        gain (float): Desired gain at the boost (or attenuation) in dB.
        central_freq (float, optional): Central frequency (in Hz) (default=3000).
        Q (float, optional): Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1] (default=0.707).

    Examples:
        >>> import numpy as np
        >>>
        >>> waveform = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64)
        >>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
        >>> transforms = [audio.TrebleBiquad(44100, 200.0)]
        >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
    """

    @check_treble_biquad
    def __init__(self, sample_rate, gain, central_freq=3000, Q=0.707):
        # Arguments have already been checked by the decorator; just keep them.
        self.sample_rate, self.gain = sample_rate, gain
        self.central_freq, self.Q = central_freq, Q

    def parse(self):
        # Hand the stored parameters to the C++ operation in positional order.
        op_args = (self.sample_rate, self.gain, self.central_freq, self.Q)
        return cde.TrebleBiquadOperation(*op_args)
DE_C_GAINTYPE_TYPE = {GainType.AMPLITUDE: cde.GainType.DE_GAINTYPE_AMPLITUDE,
GainType.POWER: cde.GainType.DE_GAINTYPE_POWER,
GainType.DB: cde.GainType.DE_GAINTYPE_DB}

View File

@ -329,6 +329,22 @@ def check_time_stretch(method):
return new_method
def check_treble_biquad(method):
    """Wrapper method to check the parameters of TrebleBiquad."""

    @wraps(method)
    def checked(self, *args, **kwargs):
        # Resolve positional/keyword/default arguments into the declared parameter order.
        params, _ = parse_user_args(method, *args, **kwargs)
        sample_rate, gain, central_freq, Q = params
        # Each helper validates one argument; the first failing check propagates its exception.
        check_biquad_sample_rate(sample_rate)
        check_biquad_gain(gain)
        check_biquad_central_freq(central_freq)
        check_biquad_Q(Q)
        return method(self, *args, **kwargs)

    return checked
def check_masking(method):
"""Wrapper method to check the parameters of time_masking and FrequencyMasking"""

View File

@ -331,6 +331,74 @@ TEST_F(MindDataTestPipeline, TestTimeStretchPipelineWrongArgs) {
EXPECT_EQ(iter, nullptr);
}
// Pipeline test: run TrebleBiquad over 50 random float32 rows of shape {2, 200} and check that
// shape, rank and dtype of every output row are preserved.
TEST_F(MindDataTestPipeline, TestTrebleBiquadBasic) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTrebleBiquadBasic.";
  // Original waveform
  std::shared_ptr<SchemaObj> schema = Schema();
  ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 200}));
  std::shared_ptr<Dataset> ds = RandomData(50, schema);
  EXPECT_NE(ds, nullptr);
  ds = ds->SetNumWorkers(4);
  EXPECT_NE(ds, nullptr);

  // Named in snake_case so the local no longer shadows the TrebleBiquadOp kernel class name.
  auto treble_biquad_op = audio::TrebleBiquad(44100, 200.0, 2000, 0.604);
  ds = ds->Map({treble_biquad_op});
  EXPECT_NE(ds, nullptr);

  // Filtered waveform by treblebiquad
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Check the iterator itself (the old code re-checked ds) and stop here if it is null,
  // because iter is dereferenced right below.
  ASSERT_NE(iter, nullptr);

  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  std::vector<int64_t> expected = {2, 200};
  int i = 0;
  while (row.size() != 0) {
    auto col = row["waveform"];
    ASSERT_EQ(col.Shape(), expected);
    ASSERT_EQ(col.Shape().size(), 2);
    ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
    ASSERT_OK(iter->GetNextRow(&row));
    i++;
  }
  EXPECT_EQ(i, 50);
  iter->Stop();
}
// Pipeline test: with an invalid sample_rate or Q the Map() call still yields a dataset,
// but creating an iterator over it is expected to fail (nullptr).
TEST_F(MindDataTestPipeline, TestTrebleBiquadWrongArg) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTrebleBiquadWrongArg.";
  // Original waveform
  std::shared_ptr<SchemaObj> schema = Schema();
  ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 2}));
  std::shared_ptr<Dataset> ds = RandomData(50, schema);
  EXPECT_NE(ds, nullptr);

  // Case 1: sample_rate of zero
  MS_LOG(INFO) << "sample_rate is zero.";
  auto zero_rate_op = audio::TrebleBiquad(0, 200);
  std::shared_ptr<Dataset> zero_rate_ds = ds->Map({zero_rate_op});
  EXPECT_NE(zero_rate_ds, nullptr);
  std::shared_ptr<Iterator> zero_rate_iter = zero_rate_ds->CreateIterator();
  EXPECT_EQ(zero_rate_iter, nullptr);

  // Case 2: Q of zero
  MS_LOG(INFO) << "Q_ is zero.";
  auto zero_q_op = audio::TrebleBiquad(44100, 200.0, 3000.0, 0);
  std::shared_ptr<Dataset> zero_q_ds = ds->Map({zero_q_op});
  EXPECT_NE(zero_q_ds, nullptr);
  std::shared_ptr<Iterator> zero_q_iter = zero_q_ds->CreateIterator();
  EXPECT_EQ(zero_q_iter, nullptr);
}
TEST_F(MindDataTestPipeline, TestVolPipeline) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVolPipeline.";
// Original waveform

View File

@ -914,6 +914,48 @@ TEST_F(MindDataTestExecute, TestRiaaBiquadWithWrongArg) {
EXPECT_FALSE(s01.IsOk());
}
// Eager test: TrebleBiquad with default central_freq and Q must run successfully on a {2, 10}
// float waveform. Only the returned status is checked, not the filtered values.
TEST_F(MindDataTestExecute, TestTrebleBiquadWithEager) {
  MS_LOG(INFO) << "Doing MindDataTestExecute-TestTrebleBiquadWithEager.";
  // Original waveform
  std::vector<float> waveform = {3.156, 5.690, 1.362, 1.093, 5.782, 6.381, 5.982, 3.098, 1.222, 6.027,
                                 3.909, 7.993, 4.324, 1.092, 5.093, 0.991, 1.099, 4.092, 8.111, 6.666};
  std::shared_ptr<Tensor> tensor;
  ASSERT_OK(Tensor::CreateFromVector(waveform, TensorShape({2, 10}), &tensor));
  auto ms_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(tensor));

  std::shared_ptr<TensorTransform> treble_biquad = std::make_shared<audio::TrebleBiquad>(44100, 200);
  mindspore::dataset::Execute transform({treble_biquad});
  // Filtered waveform by treblebiquad (written back into the same MSTensor)
  EXPECT_OK(transform(ms_tensor, &ms_tensor));
}
// Eager test: executing TrebleBiquad with sample_rate == 0 or Q == 0 is expected to return an error.
TEST_F(MindDataTestExecute, TestTrebleBiquadWithWrongArg) {
  MS_LOG(INFO) << "Doing MindDataTestExecute-TestTrebleBiquadWithWrongArg.";
  std::vector<double> labels = {
    2.716064453125000000e-03, 6.347656250000000000e-03, 9.246826171875000000e-03, 1.089477539062500000e-02,
    1.138305664062500000e-02, 1.156616210937500000e-02, 1.394653320312500000e-02, 1.550292968750000000e-02,
    1.614379882812500000e-02, 1.840209960937500000e-02, 1.718139648437500000e-02, 1.599121093750000000e-02,
    1.647949218750000000e-02, 1.510620117187500000e-02, 1.385498046875000000e-02, 1.345825195312500000e-02,
    1.419067382812500000e-02, 1.284790039062500000e-02, 1.052856445312500000e-02, 9.368896484375000000e-03};
  std::shared_ptr<Tensor> input;
  ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 10}), &input));
  auto input01 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
  auto input02 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));

  // Check sample_rate
  MS_LOG(INFO) << "sample_rate is zero.";
  // sample_rate is an int32_t parameter, so pass an integer literal rather than the double 0.0.
  std::shared_ptr<TensorTransform> treble_biquad_op01 = std::make_shared<audio::TrebleBiquad>(0, 200.0);
  mindspore::dataset::Execute Transform01({treble_biquad_op01});
  EXPECT_ERROR(Transform01(input01, &input01));

  // Check Q
  MS_LOG(INFO) << "Q is zero.";
  std::shared_ptr<TensorTransform> treble_biquad_op02 =
    std::make_shared<audio::TrebleBiquad>(44100, 200.0, 3000.0, 0.0);
  mindspore::dataset::Execute Transform02({treble_biquad_op02});
  EXPECT_ERROR(Transform02(input02, &input02));
}
TEST_F(MindDataTestExecute, TestLFilterWithEager) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestLFilterWithEager.";
// Original waveform

View File

@ -0,0 +1,100 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.audio.transforms as audio
from mindspore import log as logger
def count_unequal_element(data_expected, data_me, rtol, atol):
    """
    Assert that two arrays of equal shape agree within tolerance.

    An element is out of tolerance when |expected - actual| > atol + |expected| * rtol;
    the assertion fails unless the fraction of such elements is below rtol.
    """
    assert data_expected.shape == data_me.shape
    total_count = data_expected.size
    error = np.abs(data_expected - data_me)
    tolerance = atol + np.abs(data_expected) * rtol
    greater = error > tolerance
    loss_count = np.count_nonzero(greater)
    assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
        data_expected[greater], data_me[greater], error[greater])
def test_treble_biquad_eager():
    """Eager mode: call TrebleBiquad directly on a float64 waveform and compare with the expected output."""
    source = np.array([[0.234, 1.873, 0.786], [-2.673, 0.886, 1.666]], dtype=np.float64)
    target = np.array([[1., 1., -1.], [-1., 1., -1.]], dtype=np.float64)
    # Build the op and apply it in one step; the result is the filtered waveform.
    filtered = audio.TrebleBiquad(44100, 200.0)(source)
    count_unequal_element(target, filtered, 0.0001, 0.0001)
def test_treble_biquad_pipeline():
    """Pipeline mode: map TrebleBiquad over a NumpySlicesDataset and compare each row with the expected output."""
    source = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64)
    target = np.array([[1., -1., 1.], [1., -1., 1.]], dtype=np.float64)
    # shuffle=False keeps rows in input order so they line up with the rows of `target`.
    pipeline = ds.NumpySlicesDataset(source, ["waveform"], shuffle=False)
    pipeline = pipeline.map(input_columns=["waveform"], operations=audio.TrebleBiquad(44100, 200.0))
    rows = pipeline.create_dict_iterator(num_epochs=1, output_numpy=True)
    for row_idx, row in enumerate(rows):
        count_unequal_element(target[row_idx, :], row['waveform'], 0.0001, 0.0001)
def test_treble_biquad_invalid_input():
    """Constructing TrebleBiquad with a bad argument must raise the expected error type and message."""

    def expect_error(test_name, sample_rate, gain, central_freq, Q, error, error_msg):
        logger.info("Test TrebleBiquad with bad input: {0}".format(test_name))
        with pytest.raises(error) as error_info:
            audio.TrebleBiquad(sample_rate, gain, central_freq, Q)
        assert error_msg in str(error_info.value)

    # Each entry: (test_name, sample_rate, gain, central_freq, Q, expected error, expected message fragment)
    bad_cases = [
        ("invalid sample_rate parameter type as a float", 44100.5, 0.2, 3000, 0.707, TypeError,
         "Argument sample_rate with value 44100.5 is not of type [<class 'int'>],"
         " but got <class 'float'>."),
        ("invalid sample_rate parameter type as a String", "44100", 0.2, 3000, 0.707, TypeError,
         "Argument sample_rate with value 44100 is not of type [<class 'int'>], "
         "but got <class 'str'>."),
        ("invalid gain parameter type as a String", 4410, "0", 3000, 0.707, TypeError,
         "Argument gain with value 0 is not of type [<class 'float'>, <class 'int'>],"
         + " but got <class 'str'>."),
        ("invalid central_rate parameter value", 4410, 0.2, None, 0.707, TypeError,
         "Argument central_freq with value None is not of type [<class 'float'>, <class 'int'>]," +
         " but got <class 'NoneType'>."),
        ("invalid Q parameter type as a String", 4410, 0.2, 3000, "0", TypeError,
         "Argument Q with value 0 is not of type [<class 'float'>, <class 'int'>]," +
         " but got <class 'str'>."),
        ("invalid sample_rate parameter value", 0, 0.2, 3000, 0.707, ValueError,
         "Input sample_rate is not within the required interval of [-2147483648, 0) and (0, 2147483647]."),
        ("invalid sample_rate parameter value", 441324343243242342345300, 0.2, 3000, 0.707, ValueError,
         "Input sample_rate is not within the required interval of [-2147483648, 0) and (0, 2147483647]."),
        ("invalid gain parameter value", 44100, 32434324324234321, 3000, 0.707, ValueError,
         "Input gain is not within the required interval of [-16777216, 16777216]."),
        ("invalid central_freq parameter value", 44100, 0.2, 32434324324234321, 0.707, ValueError,
         "Input central_freq is not within the required interval of [-16777216, 16777216]."),
        ("invalid Q parameter value", 44100, 0.2, 3000, 1.707, ValueError,
         "Input Q is not within the required interval of (0, 1]."),
        ("invalid Q parameter value", 44100, 0.2, 3000, 0, ValueError,
         "Input Q is not within the required interval of (0, 1]."),
    ]
    for case in bad_cases:
        expect_error(*case)
if __name__ == "__main__":
    # Run every test in this file, in declaration order, when executed as a script.
    for _test_case in (test_treble_biquad_eager, test_treble_biquad_pipeline, test_treble_biquad_invalid_input):
        _test_case()