!19093 [assistant][ops] Add new audio operator Biquad

Merge pull request !19093 from 杨旭华/Biquad
This commit is contained in:
i-robot 2021-09-14 01:52:01 +00:00 committed by Gitee
commit 32c2b77595
16 changed files with 554 additions and 13 deletions

View File

@ -23,6 +23,7 @@
#include "minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/bandreject_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/bass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/complex_norm_ir.h"
#include "minddata/dataset/audio/ir/kernels/contrast_ir.h"
#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
@ -145,6 +146,25 @@ std::shared_ptr<TensorOperation> BassBiquad::Parse() {
return std::make_shared<BassBiquadOperation>(data_->sample_rate_, data_->gain_, data_->central_freq_, data_->Q_);
}
// Biquad Transform Operation.
struct Biquad::Data {
Data(float b0, float b1, float b2, float a0, float a1, float a2)
: b0_(b0), b1_(b1), b2_(b2), a0_(a0), a1_(a1), a2_(a2) {}
float b0_;
float b1_;
float b2_;
float a0_;
float a1_;
float a2_;
};
Biquad::Biquad(float b0, float b1, float b2, float a0, float a1, float a2)
: data_(std::make_shared<Data>(b0, b1, b2, a0, a1, a2)) {}
std::shared_ptr<TensorOperation> Biquad::Parse() {
return std::make_shared<BiquadOperation>(data_->b0_, data_->b1_, data_->b2_, data_->a0_, data_->a1_, data_->a1_);
}
// ComplexNorm Transform Operation.
struct ComplexNorm::Data {
explicit Data(float power) : power_(power) {}

View File

@ -27,6 +27,7 @@
#include "minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/bandreject_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/bass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/complex_norm_ir.h"
#include "minddata/dataset/audio/ir/kernels/contrast_ir.h"
#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
@ -127,6 +128,16 @@ PYBIND_REGISTER(
}));
}));
PYBIND_REGISTER(BiquadOperation, 1, ([](const py::module *m) {
(void)py::class_<audio::BiquadOperation, TensorOperation, std::shared_ptr<audio::BiquadOperation>>(
*m, "BiquadOperation")
.def(py::init([](float b0, float b1, float b2, float a0, float a1, float a2) {
auto biquad = std::make_shared<audio::BiquadOperation>(b0, b1, b2, a0, a1, a2);
THROW_IF_ERROR(biquad->ValidateParams());
return biquad;
}));
}));
PYBIND_REGISTER(
ComplexNormOperation, 1, ([](const py::module *m) {
(void)py::class_<audio::ComplexNormOperation, TensorOperation, std::shared_ptr<audio::ComplexNormOperation>>(
@ -259,5 +270,6 @@ PYBIND_REGISTER(
return timestretch;
}));
}));
} // namespace dataset
} // namespace mindspore

View File

@ -9,6 +9,7 @@ add_library(audio-ir-kernels OBJECT
bandpass_biquad_ir.cc
bandreject_biquad_ir.cc
bass_biquad_ir.cc
biquad_ir.cc
complex_norm_ir.cc
contrast_ir.cc
dc_shift_ir.cc

View File

@ -0,0 +1,53 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/audio/ir/kernels/biquad_ir.h"
#include "minddata/dataset/audio/ir/validators.h"
#include "minddata/dataset/audio/kernels/biquad_op.h"
namespace mindspore {
namespace dataset {
namespace audio {
// BiquadOperation
BiquadOperation::BiquadOperation(float b0, float b1, float b2, float a0, float a1, float a2)
: b0_(b0), b1_(b1), b2_(b2), a0_(a0), a1_(a1), a2_(a2) {}
Status BiquadOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateScalarNotZero("Biquad", "a0", a0_));
return Status::OK();
}
std::shared_ptr<TensorOp> BiquadOperation::Build() {
std::shared_ptr<BiquadOp> tensor_op = std::make_shared<BiquadOp>(b0_, b1_, b2_, a0_, a1_, a2_);
return tensor_op;
}
Status BiquadOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["b0"] = b0_;
args["b1"] = b1_;
args["b2"] = b2_;
args["a0"] = a0_;
args["a1"] = a1_;
args["a2"] = a2_;
*out_json = args;
return Status::OK();
}
} // namespace audio
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,60 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BIQUAD_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BIQUAD_IR_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/api/status.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
namespace audio {
constexpr char kBiquadOperation[] = "Biquad";
class BiquadOperation : public TensorOperation {
public:
BiquadOperation(float b0, float b1, float b2, float a0, float a1, float a2);
~BiquadOperation() = default;
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override { return kBiquadOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
float b0_;
float b1_;
float b2_;
float a0_;
float a1_;
float a2_;
};
} // namespace audio
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_BIQUAD_IR_H_

View File

@ -10,6 +10,7 @@ add_library(audio-kernels OBJECT
bandpass_biquad_op.cc
bandreject_biquad_op.cc
bass_biquad_op.cc
biquad_op.cc
complex_norm_op.cc
contrast_op.cc
dc_shift_op.cc

View File

@ -0,0 +1,45 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/audio/kernels/biquad_op.h"
#include "minddata/dataset/audio/kernels/audio_utils.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
Status BiquadOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
TensorShape input_shape = input->shape();
// check input tensor dimension, it should be greater than 0.
CHECK_FAIL_RETURN_UNEXPECTED(input_shape.Size() > 0, "Biquad: input tensor is not in shape of <..., time>.");
// check input type, it should be DE_FLOAT32 or DE_FLOAT16 or DE_FLOAT64
CHECK_FAIL_RETURN_UNEXPECTED(
input->type() == DataType(DataType::DE_FLOAT32) || input->type() == DataType(DataType::DE_FLOAT16) ||
input->type() == DataType(DataType::DE_FLOAT64),
"Biquad: input tensor type should be float or double, but got: " + input->type().ToString());
if (input->type() == DataType(DataType::DE_FLOAT32)) {
return Biquad(input, output, static_cast<float>(b0_), static_cast<float>(b1_), static_cast<float>(b2_),
static_cast<float>(a0_), static_cast<float>(a1_), static_cast<float>(a2_));
} else if (input->type() == DataType(DataType::DE_FLOAT64)) {
return Biquad(input, output, static_cast<double>(b0_), static_cast<double>(b1_), static_cast<double>(b2_),
static_cast<double>(a0_), static_cast<double>(a1_), static_cast<double>(a2_));
} else {
return Biquad(input, output, static_cast<float16>(b0_), static_cast<float16>(b1_), static_cast<float16>(b2_),
static_cast<float16>(a0_), static_cast<float16>(a1_), static_cast<float16>(a2_));
}
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BIQUAD_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BIQUAD_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class BiquadOp : public TensorOp {
public:
BiquadOp(float b0, float b1, float b2, float a0, float a1, float a2)
: b0_(b0), b1_(b1), b2_(b2), a0_(a0), a1_(a1), a2_(a2) {}
~BiquadOp() override = default;
void Print(std::ostream &out) const override {
out << Name() << ": b0: " << b0_ << ", b1: " << b1_ << ", b2: " << b2_ << ", a0: " << a0_ << ", a1: " << a1_
<< ", a2: " << a2_ << std::endl;
}
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kBiquadOp; }
private:
float b0_;
float b1_;
float b2_;
float a0_;
float a1_;
float a2_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_BIQUAD_OP_H_

View File

@ -189,6 +189,31 @@ class BassBiquad final : public TensorTransform {
std::shared_ptr<Data> data_;
};
/// \brief Perform a biquad filter of input tensor.
class Biquad final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] b0 Numerator coefficient of current input, x[n].
/// \param[in] b1 Numerator coefficient of input one time step ago x[n-1].
/// \param[in] b2 Numerator coefficient of input two time steps ago x[n-2].
/// \param[in] a0 Denominator coefficient of current output y[n], the value can't be zero, typically 1.
/// \param[in] a1 Denominator coefficient of current output y[n-1].
/// \param[in] a2 Denominator coefficient of current output y[n-2].
explicit Biquad(float b0, float b1, float b2, float a0, float a1, float a2);
/// \brief Destructor.
~Biquad() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief ComplexNorm TensorTransform.
/// \notes Compute the norm of complex tensor input.
class ComplexNorm final : public TensorTransform {
@ -465,6 +490,7 @@ class TimeStretch final : public TensorTransform {
struct Data;
std::shared_ptr<Data> data_;
};
} // namespace audio
} // namespace dataset
} // namespace mindspore

View File

@ -146,6 +146,7 @@ constexpr char kBandBiquadOp[] = "BandBiquadOp";
constexpr char kBandpassBiquadOp[] = "BandpassBiquadOp";
constexpr char kBandrejectBiquadOp[] = "BandrejectBiquadOp";
constexpr char kBassBiquadOp[] = "BassBiquadOp";
constexpr char kBiquadOp[] = "BiquadOp";
constexpr char kComplexNormOp[] = "ComplexNormOp";
constexpr char kContrastOp[] = "ContrastOp";
constexpr char kDCShiftOp[] = "DCShiftOp";

View File

@ -25,7 +25,7 @@ import mindspore._c_dataengine as cde
from ..transforms.c_transforms import TensorOperation
from .utils import ScaleType
from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
check_bandreject_biquad, check_bass_biquad, check_complex_norm, check_contrast, check_dc_shift, \
check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_contrast, check_dc_shift, \
check_deemph_biquad, check_equalizer_biquad, check_highpass_biquad, check_lfilter, check_lowpass_biquad, \
check_masking, check_mu_law_decoding, check_time_stretch
@ -246,6 +246,38 @@ class BassBiquad(AudioTensorOperation):
return cde.BassBiquadOperation(self.sample_rate, self.gain, self.central_freq, self.Q)
class Biquad(TensorOperation):
"""
Perform a biquad filter of input tensor.
Args:
b0 (float): Numerator coefficient of current input, x[n].
b1 (float): Numerator coefficient of input one time step ago x[n-1].
b2 (float): Numerator coefficient of input two time steps ago x[n-2].
a0 (float): Denominator coefficient of current output y[n], the value can't be zero, typically 1.
a1 (float): Denominator coefficient of current output y[n-1].
a2 (float): Denominator coefficient of current output y[n-2].
Examples:
>>> import numpy as np
>>>
>>> waveform = np.array([[2.716064453125e-03, 6.34765625e-03], [9.246826171875e-03, 1.0894775390625e-02]])
>>> biquad_op = audio.Biquad(0.01, 0.02, 0.13, 1, 0.12, 0.3)
>>> waveform_filtered = biquad_op(waveform)
"""
@check_biquad
def __init__(self, b0, b1, b2, a0, a1, a2):
self.b0 = b0
self.b1 = b1
self.b2 = b2
self.a0 = a0
self.a1 = a1
self.a2 = a2
def parse(self):
return cde.BiquadOperation(self.b0, self.b1, self.b2, self.a0, self.a1, self.a2)
class ComplexNorm(AudioTensorOperation):
"""
Compute the norm of complex tensor input.

View File

@ -18,8 +18,9 @@ Validators for TensorOps.
from functools import wraps
from mindspore.dataset.core.validator_helpers import check_float32, check_int32_not_zero, check_list_same_size, \
check_non_negative_float32, check_pos_float32, check_pos_int32, check_value, parse_user_args, type_check
from mindspore.dataset.core.validator_helpers import check_float32, check_float32_not_zero, \
check_int32_not_zero, check_list_same_size, check_non_negative_float32, check_pos_float32, \
check_pos_int32, check_value, parse_user_args, type_check
from .utils import ScaleType
@ -342,3 +343,28 @@ def check_complex_norm(method):
return method(self, *args, **kwargs)
return new_method
def check_biquad_coeff(coeff, arg_name):
"""Wrapper method to check the parameters of coeff."""
type_check(coeff, (float, int), arg_name)
check_float32(coeff, arg_name)
def check_biquad(method):
"""Wrapper method to check the parameters of Biquad."""
@wraps(method)
def new_method(self, *args, **kwargs):
[b0, b1, b2, a0, a1, a2], _ = parse_user_args(
method, *args, **kwargs)
check_biquad_coeff(b0, "b0")
check_biquad_coeff(b1, "b1")
check_biquad_coeff(b2, "b2")
type_check(a0, (float, int), "a0")
check_float32_not_zero(a0, "a0")
check_biquad_coeff(a1, "a1")
check_biquad_coeff(a2, "a2")
return method(self, *args, **kwargs)
return new_method

View File

@ -368,6 +368,14 @@ def check_non_negative_float64(value, arg_name=""):
check_value(value, [UINT32_MIN, DOUBLE_MAX_INTEGER], arg_name)
def check_float32_not_zero(value, arg_name=""):
arg_name = pad_arg_name(arg_name)
type_check(value, (int,), arg_name)
if value < FLOAT_MIN_INTEGER or value > FLOAT_MAX_INTEGER or value == 0:
raise ValueError(
"Input {0}is not within the required interval of [-16777216, 0) and (0, 16777216].".format(arg_name))
def check_valid_detype(type_):
"""
Validates if a type is a DE Type.

View File

@ -1067,3 +1067,61 @@ TEST_F(MindDataTestPipeline, TestDCShiftPipelineError) {
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_EQ(iter, nullptr);
}
TEST_F(MindDataTestPipeline, TestBiquadBasic) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestBiquadBasic.";
// Original waveform
std::shared_ptr<SchemaObj> schema = Schema();
ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {2, 200}));
std::shared_ptr<Dataset> ds = RandomData(50, schema);
EXPECT_NE(ds, nullptr);
ds = ds->SetNumWorkers(4);
EXPECT_NE(ds, nullptr);
auto BiquadOp = audio::Biquad(0.01, 0.02, 0.13, 1, 0.12, 0.3);
ds = ds->Map({BiquadOp});
EXPECT_NE(ds, nullptr);
// Filtered waveform by biquad
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(ds, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
std::vector<int64_t> expected = {2, 200};
int i = 0;
while (row.size() != 0) {
auto col = row["inputData"];
ASSERT_EQ(col.Shape(), expected);
ASSERT_EQ(col.Shape().size(), 2);
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 50);
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestBiquadParamCheck) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestBiquadParamCheck.";
std::shared_ptr<SchemaObj> schema = Schema();
// Original waveform
ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {2, 2}));
std::shared_ptr<Dataset> ds = RandomData(50, schema);
std::shared_ptr<Dataset> ds01;
EXPECT_NE(ds, nullptr);
// Check a0
MS_LOG(INFO) << "a0 is zero.";
auto biquad_op_01 = audio::Biquad(0.01, 0.02, 0.13, 0, 0.12, 0.3);
ds01 = ds->Map({biquad_op_01});
EXPECT_NE(ds01, nullptr);
std::shared_ptr<Iterator> iter01 = ds01->CreateIterator();
EXPECT_EQ(iter01, nullptr);
}

View File

@ -646,13 +646,13 @@ TEST_F(MindDataTestExecute, TestEqualizerBiquadEager) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestEqualizerBiquadEager.";
int sample_rate = 44100;
float center_freq = 3.5;
float gain =5.5;
float gain = 5.5;
float Q = 0.707;
std::vector<mindspore::MSTensor> output;
std::shared_ptr<Tensor> test;
std::vector<double> test_vector = {0.8236, 0.2049, 0.3335, 0.5933, 0.9911, 0.2482,
0.3007, 0.9054, 0.7598, 0.5394, 0.2842, 0.5634, 0.6363, 0.2226, 0.2288};
Tensor::CreateFromVector(test_vector, TensorShape({5,3}), &test);
std::vector<double> test_vector = {0.8236, 0.2049, 0.3335, 0.5933, 0.9911, 0.2482, 0.3007, 0.9054,
0.7598, 0.5394, 0.2842, 0.5634, 0.6363, 0.2226, 0.2288};
Tensor::CreateFromVector(test_vector, TensorShape({5, 3}), &test);
auto input = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(test));
std::shared_ptr<TensorTransform> equalizer_biquad(new audio::EqualizerBiquad({sample_rate, center_freq, gain, Q}));
auto transform = Execute({equalizer_biquad});
@ -664,9 +664,9 @@ TEST_F(MindDataTestExecute, TestEqualizerBiquadParamCheckQ) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestEqualizerBiquadParamCheckQ.";
std::vector<mindspore::MSTensor> output;
std::shared_ptr<Tensor> test;
std::vector<double> test_vector = {0.1129, 0.3899, 0.7762, 0.2437, 0.9911, 0.8764,
0.4524, 0.9034, 0.3277, 0.8904, 0.1852, 0.6721, 0.1325, 0.2345, 0.5538};
Tensor::CreateFromVector(test_vector, TensorShape({3,5}), &test);
std::vector<double> test_vector = {0.1129, 0.3899, 0.7762, 0.2437, 0.9911, 0.8764, 0.4524, 0.9034,
0.3277, 0.8904, 0.1852, 0.6721, 0.1325, 0.2345, 0.5538};
Tensor::CreateFromVector(test_vector, TensorShape({3, 5}), &test);
auto input = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(test));
// Check Q
std::shared_ptr<TensorTransform> equalizer_biquad_op = std::make_shared<audio::EqualizerBiquad>(44100, 3.5, 5.5, 0);
@ -679,9 +679,9 @@ TEST_F(MindDataTestExecute, TestEqualizerBiquadParamCheckSampleRate) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestEqualizerBiquadParamCheckSampleRate.";
std::vector<mindspore::MSTensor> output;
std::shared_ptr<Tensor> test;
std::vector<double> test_vector = {0.5236, 0.7049, 0.4335, 0.4533, 0.0911, 0.3482,
0.3407, 0.9054, 0.7598, 0.5394, 0.2842, 0.5634, 0.6363, 0.2226, 0.2288,0.6743};
Tensor::CreateFromVector(test_vector, TensorShape({4,4}), &test);
std::vector<double> test_vector = {0.5236, 0.7049, 0.4335, 0.4533, 0.0911, 0.3482, 0.3407, 0.9054,
0.7598, 0.5394, 0.2842, 0.5634, 0.6363, 0.2226, 0.2288, 0.6743};
Tensor::CreateFromVector(test_vector, TensorShape({4, 4}), &test);
auto input = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(test));
// Check sample_rate
std::shared_ptr<TensorTransform> equalizer_biquad_op = std::make_shared<audio::EqualizerBiquad>(0, 3.5, 5.5, 0.7);
@ -927,3 +927,37 @@ TEST_F(MindDataTestExecute, TestDCShiftEager) {
Status s = Transform(input, &input);
ASSERT_TRUE(s.IsOk());
}
TEST_F(MindDataTestExecute, TestBiquadWithEager) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestBiquadWithEager.";
// Original waveform
std::vector<float> labels = {3.716064453125, 12.34765625, 5.246826171875, 1.0894775390625,
1.1383056640625, 2.1566162109375, 1.3946533203125, 3.55029296875};
std::shared_ptr<Tensor> input;
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 4}), &input));
auto input_01 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
std::shared_ptr<TensorTransform> biquad_01 = std::make_shared<audio::Biquad>(1, 0.02, 0.13, 1, 0.12, 0.3);
mindspore::dataset::Execute Transform01({biquad_01});
// Filtered waveform by biquad
Status s01 = Transform01(input_01, &input_01);
EXPECT_TRUE(s01.IsOk());
}
TEST_F(MindDataTestExecute, TestBiquadWithWrongArg) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestBiquadWithWrongArg.";
std::vector<double> labels = {
2.716064453125000000e-03,
6.347656250000000000e-03,
9.246826171875000000e-03,
1.089477539062500000e-02,
};
std::shared_ptr<Tensor> input;
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({1, 4}), &input));
auto input_02 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
// Check a0
MS_LOG(INFO) << "a0 is zero.";
std::shared_ptr<TensorTransform> biquad_op = std::make_shared<audio::Biquad>(1, 0.02, 0.13, 0, 0.12, 0.3);
mindspore::dataset::Execute Transform01({biquad_op});
Status s01 = Transform01(input_02, &input_02);
EXPECT_FALSE(s01.IsOk());
}

View File

@ -0,0 +1,109 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.audio.transforms as audio
from mindspore import log as logger
def count_unequal_element(data_expected, data_me, rtol, atol):
assert data_expected.shape == data_me.shape
total_count = len(data_expected.flatten())
error = np.abs(data_expected - data_me)
greater = np.greater(error, atol + np.abs(data_expected) * rtol)
loss_count = np.count_nonzero(greater)
assert (loss_count / total_count) < rtol, \
"\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
format(data_expected[greater], data_me[greater], error[greater])
def test_func_biquad_eager():
""" mindspore eager mode normal testcase:biquad op"""
# Original waveform
waveform = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64)
# Expect waveform
expect_waveform = np.array([[0.0100, 0.0388, 0.1923],
[0.0400, 0.1252, 0.6530]], dtype=np.float64)
biquad_op = audio.Biquad(0.01, 0.02, 0.13, 1, 0.12, 0.3)
# Filtered waveform by biquad
output = biquad_op(waveform)
count_unequal_element(expect_waveform, output, 0.0001, 0.0001)
def test_func_biquad_pipeline():
""" mindspore pipeline mode normal testcase:biquad op"""
# Original waveform
waveform = np.array([[3.2, 2.1, 1.3], [6.2, 5.3, 6]], dtype=np.float64)
# Expect waveform
expect_waveform = np.array([[1.0000, 1.0000, 0.5844],
[1.0000, 1.0000, 1.0000]], dtype=np.float64)
dataset = ds.NumpySlicesDataset(waveform, ["audio"], shuffle=False)
biquad_op = audio.Biquad(1, 0.02, 0.13, 1, 0.12, 0.3)
# Filtered waveform by biquad
dataset = dataset.map(input_columns=["audio"], operations=biquad_op, num_parallel_workers=8)
i = 0
for item in dataset.create_dict_iterator(output_numpy=True):
count_unequal_element(expect_waveform[i, :],
item['audio'], 0.0001, 0.0001)
i += 1
def test_biquad_invalid_input():
def test_invalid_input(test_name, b0, b1, b2, a0, a1, a2, error, error_msg):
logger.info("Test Biquad with bad input: {0}".format(test_name))
with pytest.raises(error) as error_info:
audio.Biquad(b0, b1, b2, a0, a1, a2)
assert error_msg in str(error_info.value)
test_invalid_input("invalid b0 parameter type as a String", "0.01", 0.02, 0.13, 1, 0.12, 0.3, TypeError,
"Argument b0 with value 0.01 is not of type [<class 'float'>, <class 'int'>],"
" but got <class 'str'>.")
test_invalid_input("invalid b0 parameter value", 441324343243242342345300, 0.02, 0.13, 1, 0.12, 0.3, ValueError,
"Input b0 is not within the required interval of [-16777216, 16777216].")
test_invalid_input("invalid b1 parameter type as a String", 0.01, "0.02", 0.13, 0, 0.12, 0.3, TypeError,
"Argument b1 with value 0.02 is not of type [<class 'float'>, <class 'int'>],"
" but got <class 'str'>.")
test_invalid_input("invalid b1 parameter value", 0.01, 441324343243242342345300, 0.13, 1, 0.12, 0.3, ValueError,
"Input b1 is not within the required interval of [-16777216, 16777216].")
test_invalid_input("invalid b2 parameter type as a String", 0.01, 0.02, "0.13", 0, 0.12, 0.3, TypeError,
"Argument b2 with value 0.13 is not of type [<class 'float'>, <class 'int'>],"
" but got <class 'str'>.")
test_invalid_input("invalid b2 parameter value", 0.01, 0.02, 441324343243242342345300, 1, 0.12, 0.3, ValueError,
"Input b2 is not within the required interval of [-16777216, 16777216].")
test_invalid_input("invalid a0 parameter type as a String", 0.01, 0.02, 0.13, '1', 0.12, 0.3, TypeError,
"Argument a0 with value 1 is not of type [<class 'float'>, <class 'int'>],"
" but got <class 'str'>.")
test_invalid_input("invalid a0 parameter value", 0.01, 0.02, 0.13, 0, 0.12, 0.3, ValueError,
"Input a0 is not within the required interval of [-16777216, 0) and (0, 16777216].")
test_invalid_input("invalid a0 parameter value", 0.01, 0.02, 0.13, 441324343243242342345300, 0.12, 0.3, ValueError,
"Input a0 is not within the required interval of [-16777216, 0) and (0, 16777216].")
test_invalid_input("invalid a1 parameter type as a String", 0.01, 0.02, 0.13, 1, '0.12', 0.3, TypeError,
"Argument a1 with value 0.12 is not of type [<class 'float'>, <class 'int'>],"
" but got <class 'str'>.")
test_invalid_input("invalid a1 parameter value", 0.01, 0.02, 0.13, 1, 441324343243242342345300, 0.3, ValueError,
"Input a1 is not within the required interval of [-16777216, 16777216].")
test_invalid_input("invalid a2 parameter type as a String", 0.01, 0.02, 0.13, 1, 0.12, '0.3', TypeError,
"Argument a2 with value 0.3 is not of type [<class 'float'>, <class 'int'>],"
" but got <class 'str'>.")
test_invalid_input("invalid a1 parameter value", 0.01, 0.02, 0.13, 1, 0.12, 441324343243242342345300, ValueError,
"Input a2 is not within the required interval of [-16777216, 16777216].")
if __name__ == '__main__':
test_func_biquad_eager()
test_func_biquad_pipeline()
test_biquad_invalid_input()