!17846 [assistant][ops]Add new operator HighpassBiquad

Merge pull request !17846 from YJfuel123/HighpassBiquad
This commit is contained in:
i-robot 2021-09-09 02:06:41 +00:00 committed by Gitee
commit 9ad56e86ab
15 changed files with 538 additions and 8 deletions

16
mindspore/ccsrc/minddata/dataset/api/audio.cc Normal file → Executable file
View File

@ -26,6 +26,7 @@
#include "minddata/dataset/audio/ir/kernels/complex_norm_ir.h"
#include "minddata/dataset/audio/ir/kernels/contrast_ir.h"
#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h"
#include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
@ -182,6 +183,21 @@ std::shared_ptr<TensorOperation> FrequencyMasking::Parse() {
data_->mask_start_, data_->mask_value_);
}
// HighpassBiquad Transform Operation.
struct HighpassBiquad::Data {
Data(int32_t sample_rate, float cutoff_freq, float Q) : sample_rate_(sample_rate), cutoff_freq_(cutoff_freq), Q_(Q) {}
int32_t sample_rate_;
float cutoff_freq_;
float Q_;
};
HighpassBiquad::HighpassBiquad(int32_t sample_rate, float cutoff_freq, float Q)
: data_(std::make_shared<Data>(sample_rate, cutoff_freq, Q)) {}
std::shared_ptr<TensorOperation> HighpassBiquad::Parse() {
return std::make_shared<HighpassBiquadOperation>(data_->sample_rate_, data_->cutoff_freq_, data_->Q_);
}
// LowpassBiquad Transform Operation.
struct LowpassBiquad::Data {
Data(int32_t sample_rate, float cutoff_freq, float Q) : sample_rate_(sample_rate), cutoff_freq_(cutoff_freq), Q_(Q) {}

View File

@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "minddata/dataset/api/python/pybind_conversion.h"
@ -29,6 +30,7 @@
#include "minddata/dataset/audio/ir/kernels/complex_norm_ir.h"
#include "minddata/dataset/audio/ir/kernels/contrast_ir.h"
#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h"
#include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
@ -155,6 +157,17 @@ PYBIND_REGISTER(
}));
}));
PYBIND_REGISTER(
HighpassBiquadOperation, 1, ([](const py::module *m) {
(void)py::class_<audio::HighpassBiquadOperation, TensorOperation, std::shared_ptr<audio::HighpassBiquadOperation>>(
*m, "HighpassBiquadOperation")
.def(py::init([](float sample_rate, float cutoff_freq, float Q) {
auto highpass_biquad = std::make_shared<audio::HighpassBiquadOperation>(sample_rate, cutoff_freq, Q);
THROW_IF_ERROR(highpass_biquad->ValidateParams());
return highpass_biquad;
}));
}));
PYBIND_REGISTER(
LowpassBiquadOperation, 1, ([](const py::module *m) {
(void)py::class_<audio::LowpassBiquadOperation, TensorOperation, std::shared_ptr<audio::LowpassBiquadOperation>>(

View File

@ -12,7 +12,9 @@ add_library(audio-ir-kernels OBJECT
complex_norm_ir.cc
contrast_ir.cc
frequency_masking_ir.cc
highpass_biquad_ir.cc
lowpass_biquad_ir.cc
time_masking_ir.cc
time_stretch_ir.cc
)

View File

@ -0,0 +1,51 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h"
#include "minddata/dataset/audio/ir/validators.h"
#include "minddata/dataset/audio/kernels/highpass_biquad_op.h"
namespace mindspore {
namespace dataset {
namespace audio {
// HighpassBiquadOperation
HighpassBiquadOperation::HighpassBiquadOperation(int32_t sample_rate, float cutoff_freq, float Q)
: sample_rate_(sample_rate), cutoff_freq_(cutoff_freq), Q_(Q) {}
Status HighpassBiquadOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateScalarNotZero("HighpassBiquad", "sample_rate", sample_rate_));
RETURN_IF_NOT_OK(ValidateScalar("HighpassBiquad", "Q", Q_, {0, 1.0}, true, false));
return Status::OK();
}
std::shared_ptr<TensorOp> HighpassBiquadOperation::Build() {
std::shared_ptr<HighpassBiquadOp> tensor_op = std::make_shared<HighpassBiquadOp>(sample_rate_, cutoff_freq_, Q_);
return tensor_op;
}
Status HighpassBiquadOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sample_rate"] = sample_rate_;
args["cutoff_freq"] = cutoff_freq_;
args["Q"] = Q_;
*out_json = args;
return Status::OK();
}
} // namespace audio
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,59 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_HIGHPASS_BIQUAD_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_HIGHPASS_BIQUAD_IR_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/api/status.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
namespace audio {
constexpr char kHighpassBiquadOperation[] = "HighpassBiquad";
class HighpassBiquadOperation : public TensorOperation {
public:
HighpassBiquadOperation(int32_t sample_rate, float cutoff_freq, float Q);
~HighpassBiquadOperation() = default;
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override { return kHighpassBiquadOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
int32_t sample_rate_;
float cutoff_freq_;
float Q_;
};
} // namespace audio
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_HIGHPASS_BIQUAD_IR_H_

View File

@ -13,6 +13,7 @@ add_library(audio-kernels OBJECT
complex_norm_op.cc
contrast_op.cc
frequency_masking_op.cc
highpass_biquad_op.cc
lowpass_biquad_op.cc
time_masking_op.cc
time_stretch_op.cc

View File

@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/audio/kernels/highpass_biquad_op.h"
#include "minddata/dataset/audio/kernels/audio_utils.h"
namespace mindspore {
namespace dataset {
Status HighpassBiquadOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
TensorShape input_shape = input->shape();
// check input tensor dimension, it should be greater than 0.
CHECK_FAIL_RETURN_UNEXPECTED(input_shape.Size() > 0, "HighpassBiquad: input tensor in not in shape of <..., time>.");
// check input tensor type, it should be DE_FLOAT32 or DE_FLOAT16 or DE_FLOAT64
CHECK_FAIL_RETURN_UNEXPECTED(
input->type() == DataType(DataType::DE_FLOAT32) || input->type() == DataType(DataType::DE_FLOAT16) ||
input->type() == DataType(DataType::DE_FLOAT64),
"HighpassBiquad: input tensor type should be float, but got: " + input->type().ToString());
double w0 = 2 * PI * cutoff_freq_ / sample_rate_;
double alpha = sin(w0) / 2 / Q_;
double b0 = (1 + cos(w0)) / 2;
double b1 = -1 - cos(w0);
double b2 = b0;
double a0 = 1 + alpha;
double a1 = -2 * cos(w0);
double a2 = 1 - alpha;
if (input->type() == DataType(DataType::DE_FLOAT32)) {
return Biquad(input, output, static_cast<float>(b0), static_cast<float>(b1), static_cast<float>(b2),
static_cast<float>(a0), static_cast<float>(a1), static_cast<float>(a2));
} else if (input->type() == DataType(DataType::DE_FLOAT64)) {
return Biquad(input, output, static_cast<double>(b0), static_cast<double>(b1), static_cast<double>(b2),
static_cast<double>(a0), static_cast<double>(a1), static_cast<double>(a2));
} else {
return Biquad(input, output, static_cast<float16>(b0), static_cast<float16>(b1), static_cast<float16>(b2),
static_cast<float16>(a0), static_cast<float16>(a1), static_cast<float16>(a2));
}
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,48 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_HIGHPASS_BIQUAD_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_HIGHPASS_BIQUAD_OP_H_
#include <memory>
#include <string>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class HighpassBiquadOp : public TensorOp {
public:
HighpassBiquadOp(int32_t sample_rate, float cutoff_freq, float Q)
: sample_rate_(sample_rate), cutoff_freq_(cutoff_freq), Q_(Q) {}
~HighpassBiquadOp() override = default;
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kHighpassBiquadOp; };
protected:
int32_t sample_rate_;
float cutoff_freq_;
float Q_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_HIGHPASS_BIQUAD_OP_H_

View File

@ -257,6 +257,28 @@ class FrequencyMasking final : public TensorTransform {
std::shared_ptr<Data> data_;
};
/// \brief HighpassBiquad TensorTransform. Apply highpass biquad filter on audio.
class HighpassBiquad final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero.
/// \param[in] cutoff_freq Filter cutoff frequency (in Hz).
/// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1] (Default: 0.707).
HighpassBiquad(int32_t sample_rate, float cutoff_freq, float Q = 0.707);
/// \brief Destructor.
~HighpassBiquad() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief Design biquad lowpass filter and perform filtering. Similar to SoX implementation.
class LowpassBiquad final : public TensorTransform {
public:

View File

@ -149,6 +149,7 @@ constexpr char kBassBiquadOp[] = "BassBiquadOp";
constexpr char kComplexNormOp[] = "ComplexNormOp";
constexpr char kContrastOp[] = "ContrastOp";
constexpr char kFrequencyMaskingOp[] = "FrequencyMaskingOp";
constexpr char kHighpassBiquadOp[] = "HighpassBiquadOp";
constexpr char kLowpassBiquadOp[] = "LowpassBiquadOp";
constexpr char kTimeMaskingOp[] = "TimeMaskingOp";
constexpr char kTimeStretchOp[] = "TimeStretchOp";

32
mindspore/dataset/audio/transforms.py Normal file → Executable file
View File

@ -25,8 +25,8 @@ import mindspore._c_dataengine as cde
from ..transforms.c_transforms import TensorOperation
from .utils import ScaleType
from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
check_bandreject_biquad, check_bass_biquad, check_complex_norm, check_contrast, check_lowpass_biquad, \
check_masking, check_time_stretch
check_bandreject_biquad, check_bass_biquad, check_complex_norm, check_contrast, \
check_highpass_biquad, check_lowpass_biquad, check_masking, check_time_stretch
class AudioTensorOperation(TensorOperation):
@ -326,6 +326,34 @@ class FrequencyMasking(AudioTensorOperation):
self.mask_value)
class HighpassBiquad(AudioTensorOperation):
"""
Design biquad highpass filter and perform filtering. Similar to SoX implementation.
Args:
sample_rate (int): Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero.
cutoff_freq (float): Filter cutoff frequency (in Hz).
Q (float, optional): Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1] (default=0.707).
Examples:
>>> import numpy as np
>>>
>>> waveform = np.array([[2.716064453125e-03, 6.34765625e-03], [9.246826171875e-03, 1.0894775390625e-02]])
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
>>> transforms = [audio.HighpassBiquad(44100, 1500, 0.7)]
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
"""
@check_highpass_biquad
def __init__(self, sample_rate, cutoff_freq, Q=0.707):
self.sample_rate = sample_rate
self.cutoff_freq = cutoff_freq
self.Q = Q
def parse(self):
return cde.HighpassBiquadOperation(self.sample_rate, self.cutoff_freq, self.Q)
class LowpassBiquad(AudioTensorOperation):
"""
Design biquad lowpass filter and perform filtering. Similar to SoX implementation.

26
mindspore/dataset/audio/validators.py Normal file → Executable file
View File

@ -106,6 +106,26 @@ def check_band_biquad(method):
return new_method
def check_biquad_cutoff_freq(cutoff_freq):
"""Wrapper method to check the parameters of cutoff_freq."""
type_check(cutoff_freq, (float, int), "cutoff_freq")
check_float32(cutoff_freq, "cutoff_freq")
def check_highpass_biquad(method):
"""Wrapper method to check the parameters of HighpassBiquad."""
@wraps(method)
def new_method(self, *args, **kwargs):
[sample_rate, cutoff_freq, Q], _ = parse_user_args(method, *args, **kwargs)
check_biquad_sample_rate(sample_rate)
check_biquad_cutoff_freq(cutoff_freq)
check_biquad_Q(Q)
return method(self, *args, **kwargs)
return new_method
def check_allpass_biquad(method):
"""Wrapper method to check the parameters of AllpassBiquad."""
@ -168,12 +188,6 @@ def check_bass_biquad(method):
return new_method
def check_biquad_cutoff_freq(cutoff_freq):
"""Wrapper method to check the parameters of cutoff_freq."""
type_check(cutoff_freq, (float, int), "cutoff_freq")
check_float32(cutoff_freq, "cutoff_freq")
def check_contrast(method):
"""Wrapper method to check the parameters of Contrast."""

View File

@ -728,3 +728,58 @@ TEST_F(MindDataTestPipeline, TestContrastParamCheck) {
std::shared_ptr<Iterator> iter02 = ds02->CreateIterator();
EXPECT_EQ(iter02, nullptr);
}
TEST_F(MindDataTestPipeline, TestHighpassBiquadSuccess) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestHighpassBiquadSuccess.";
// Create an input tensor
std::shared_ptr<SchemaObj> schema = Schema();
ASSERT_OK(schema->add_column("col1", mindspore::DataType::kNumberTypeFloat32, {1, 200}));
std::shared_ptr<Dataset> ds = RandomData(8, schema);
EXPECT_NE(ds, nullptr);
// Create a filter object
std::shared_ptr<TensorTransform> highpass_biquad(new audio::HighpassBiquad(44100, 3000.5, 0.707));
auto ds1 = ds->Map({highpass_biquad}, {"col1"}, {"audio"});
EXPECT_NE(ds1, nullptr);
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
EXPECT_NE(iter, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
uint64_t i = 0;
while (row.size() != 0) {
i++;
ASSERT_OK(iter->GetNextRow(&row));
}
EXPECT_EQ(i, 8);
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestHighpassBiquadWrongArgs) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestHighpassBiquadWrongArgs.";
std::shared_ptr<SchemaObj> schema = Schema();
// Original waveform
ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {2, 10}));
std::shared_ptr<Dataset> ds = RandomData(50, schema);
std::shared_ptr<Dataset> ds01;
std::shared_ptr<Dataset> ds02;
EXPECT_NE(ds, nullptr);
// Check sample_rate
MS_LOG(INFO) << "sample_rate is zero.";
auto highpass_biquad_op_01 = audio::HighpassBiquad(0,200.0,0.7);
ds01 = ds->Map({highpass_biquad_op_01});
EXPECT_NE(ds01, nullptr);
std::shared_ptr<Iterator> iter01 = ds01->CreateIterator();
EXPECT_EQ(iter01, nullptr);
// Check Q
MS_LOG(INFO) << "Q is zero.";
auto highpass_biquad_op_02 = audio::HighpassBiquad(44100,2000.0,0);
ds02 = ds->Map({highpass_biquad_op_02});
EXPECT_NE(ds02, nullptr);
std::shared_ptr<Iterator> iter02 = ds02->CreateIterator();
EXPECT_EQ(iter02, nullptr);
}

View File

@ -733,3 +733,52 @@ TEST_F(MindDataTestExecute, TestContrastWithWrongArg) {
Status s01 = Transform01(input_02, &input_02);
EXPECT_FALSE(s01.IsOk());
}
TEST_F(MindDataTestExecute, TestHighpassBiquadEager) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestHighpassBiquadEager.";
int sample_rate = 44100;
float cutoff_freq = 3000.5;
float Q = 0.707;
std::vector<mindspore::MSTensor> output;
std::shared_ptr<Tensor> test;
std::vector<double> test_vector = {0.8236, 0.2049, 0.3335, 0.5933, 0.9911, 0.2482,
0.3007, 0.9054, 0.7598, 0.5394, 0.2842, 0.5634, 0.6363, 0.2226, 0.2288};
Tensor::CreateFromVector(test_vector, TensorShape({5,3}), &test);
auto input = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(test));
std::shared_ptr<TensorTransform> highpass_biquad(new audio::HighpassBiquad({sample_rate, cutoff_freq, Q}));
auto transform = Execute({highpass_biquad});
Status rc = transform({input}, &output);
ASSERT_TRUE(rc.IsOk());
}
TEST_F(MindDataTestExecute, TestHighpassBiquadParamCheckQ) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestHighpassBiquadParamCheckQ.";
std::vector<mindspore::MSTensor> output;
std::shared_ptr<Tensor> test;
std::vector<float> test_vector = {0.6013, 0.8081, 0.6600, 0.4278, 0.4049, 0.0541, 0.8800, 0.7143, 0.0926,
0.3502, 0.6148, 0.8738, 0.1869, 0.9023, 0.4293, 0.2175, 0.5132, 0.2622,
0.6490, 0.0741, 0.7903, 0.3428, 0.1598, 0.4841, 0.8128, 0.7409, 0.7226,
0.4951, 0.5589, 0.9210};
Tensor::CreateFromVector(test_vector, TensorShape({5,3,2}), &test);
auto input = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(test));
// Check Q
std::shared_ptr<TensorTransform> highpass_biquad_op = std::make_shared<audio::HighpassBiquad>(44100, 3000.5, 0);
mindspore::dataset::Execute transform({highpass_biquad_op});
Status rc = transform({input}, &output);
ASSERT_FALSE(rc.IsOk());
}
TEST_F(MindDataTestExecute, TestHighpassBiquadParamCheckSampleRate) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestHighpassBiquadParamCheckSampleRate.";
std::vector<mindspore::MSTensor> output;
std::shared_ptr<Tensor> test;
std::vector<double> test_vector = {0.0237, 0.6026, 0.3801, 0.1978, 0.8672,
0.0095, 0.5166, 0.2641, 0.5485, 0.5144};
Tensor::CreateFromVector(test_vector, TensorShape({1,10}), &test);
auto input = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(test));
// Check sample_rate
std::shared_ptr<TensorTransform> highpass_biquad_op = std::make_shared<audio::HighpassBiquad>(0, 3000.5, 0.7);
mindspore::dataset::Execute transform({highpass_biquad_op});
Status rc = transform({input}, &output);
ASSERT_FALSE(rc.IsOk());
}

View File

@ -0,0 +1,116 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing HighpassBiquad op in DE
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.audio.transforms as audio
from mindspore import log as logger
def count_unequal_element(data_expected, data_me, rtol, atol):
assert data_expected.shape == data_me.shape
total_count = len(data_expected.flatten())
error = np.abs(data_expected - data_me)
greater = np.greater(error, atol + np.abs(data_expected) * rtol)
loss_count = np.count_nonzero(greater)
assert (loss_count / total_count) < rtol, \
"\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
format(data_expected[greater], data_me[greater], error[greater])
def test_highpass_biquad_eager():
""" mindspore eager mode normal testcase:highpass_biquad op"""
# Original waveform
waveform = np.array([[0.8236, 0.2049, 0.3335], [0.5933, 0.9911, 0.2482],
[0.3007, 0.9054, 0.7598], [0.5394, 0.2842, 0.5634], [0.6363, 0.2226, 0.2288]])
# Expect waveform
expect_waveform = np.array([[0.2745, -0.4808, 0.1576], [0.1978, -0.0652, -0.4462],
[0.1002, 0.1013, -0.2835], [0.1798, -0.2649, 0.1182], [0.2121, -0.3500, 0.0693]])
highpass_biquad_op = audio.HighpassBiquad(4000, 1000.0, 1)
# Filtered waveform by highpass_biquad
output = highpass_biquad_op(waveform)
count_unequal_element(expect_waveform, output, 0.0001, 0.0001)
def test_highpass_biquad_pipeline():
""" mindspore pipeline mode normal testcase:highpass_biquad op"""
# Original waveform
waveform = np.array([[0.4063, 0.7729, 0.2325], [0.2687, 0.1426, 0.8987],
[0.6914, 0.6681, 0.1783], [0.2704, 0.2680, 0.7975], [0.5880, 0.1776, 0.6323]])
# Expect waveform
expect_waveform = np.array([[0.1354, -0.0133, -0.3474], [0.0896, -0.1316, 0.2642],
[0.2305, -0.2382, -0.2323], [0.0901, -0.0909, 0.1473], [0.1960, -0.3328, 0.2230]])
dataset = ds.NumpySlicesDataset(waveform, ["col1"], shuffle=False)
highpass_biquad_op = audio.HighpassBiquad(4000, 1000.0, 1)
# Filtered waveform by highpass_biquad
dataset = dataset.map(
input_columns=["col1"], operations=highpass_biquad_op, num_parallel_workers=4)
i = 0
for item in dataset.create_dict_iterator(output_numpy=True):
count_unequal_element(expect_waveform[i, :],
item["col1"], 0.0001, 0.0001)
i += 1
def test_highpass_biquad_invalid_input():
"""
Test invalid input of HighpassBiquad
"""
def test_invalid_input(test_name, sample_rate, cutoff_freq, Q, error, error_msg):
logger.info("Test HighpassBiquad with bad input: {0}".format(test_name))
with pytest.raises(error) as error_info:
audio.HighpassBiquad(sample_rate, cutoff_freq, Q)
assert error_msg in str(error_info.value)
test_invalid_input("invalid sample_rate parameter type as a float", 44100.5, 1000, 0.707, TypeError,
"Argument sample_rate with value 44100.5 is not of type [<class 'int'>],"
" but got <class 'float'>.")
test_invalid_input("invalid sample_rate parameter type as a String", "44100", 1000, 0.707, TypeError,
"Argument sample_rate with value 44100 is not of type [<class 'int'>],"
" but got <class 'str'>.")
test_invalid_input("invalid cutoff_freq parameter type as a String", 44100, "1000", 0.707, TypeError,
"Argument cutoff_freq with value 1000 is not of type [<class 'float'>, <class 'int'>],"
" but got <class 'str'>.")
test_invalid_input("invalid Q parameter type as a String", 44100, 1000, "0.707", TypeError,
"Argument Q with value 0.707 is not of type [<class 'float'>, <class 'int'>],"
" but got <class 'str'>.")
test_invalid_input("invalid sample_rate parameter value", 441324343243242342345300, 1000, 0.707, ValueError,
"Input sample_rate is not within the required interval of [-2147483648, 0) and (0, 2147483647].")
test_invalid_input("invalid cutoff_freq parameter value", 44100, 32434324324234321, 0.707, ValueError,
"Input cutoff_freq is not within the required interval of [-16777216, 16777216].")
test_invalid_input("invalid sample_rate parameter value", None, 1000, 0.707, TypeError,
"Argument sample_rate with value None is not of type [<class 'int'>], "
"but got <class 'NoneType'>.")
test_invalid_input("invalid cutoff_rate parameter value", 44100, None, 0.707, TypeError,
"Argument cutoff_freq with value None is not of type [<class 'float'>, <class 'int'>],"
" but got <class 'NoneType'>.")
test_invalid_input("invalid sample_rate parameter value", 0, 1000, 0.707, ValueError,
"Input sample_rate is not within the required interval of [-2147483648, 0) and (0, 2147483647].")
test_invalid_input("invalid Q parameter value", 44100, 1000, 0, ValueError,
"Input Q is not within the required interval of (0, 1].")
if __name__ == "__main__":
test_highpass_biquad_eager()
test_highpass_biquad_pipeline()
test_highpass_biquad_invalid_input()