forked from mindspore-Ecosystem/mindspore
!17885 [assistant][ops]Add new operator EqualizerBiquad
Merge pull request !17885 from YJfuel123/EqualizerBiquad
This commit is contained in:
commit
a61c2c5cfe
|
@ -26,6 +26,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/complex_norm_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/contrast_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h"
|
||||
|
@ -176,6 +177,23 @@ std::shared_ptr<TensorOperation> DeemphBiquad::Parse() {
|
|||
return std::make_shared<DeemphBiquadOperation>(data_->sample_rate_);
|
||||
}
|
||||
|
||||
// EqualizerBiquad Transform Operation.
|
||||
// Parameter bundle held by the public EqualizerBiquad transform until Parse() forwards it to the IR node.
struct EqualizerBiquad::Data {
  Data(int32_t sample_rate, float center_freq, float gain, float Q)
      : sample_rate_{sample_rate}, center_freq_{center_freq}, gain_{gain}, Q_{Q} {}

  int32_t sample_rate_;  // sampling rate of the waveform
  float center_freq_;    // central frequency of the filter
  float gain_;           // gain at the boost (or attenuation)
  float Q_;              // quality factor
};
|
||||
|
||||
EqualizerBiquad::EqualizerBiquad(int32_t sample_rate, float center_freq, float gain, float Q)
|
||||
: data_(std::make_shared<Data>(sample_rate, center_freq, gain, Q)) {}
|
||||
|
||||
std::shared_ptr<TensorOperation> EqualizerBiquad::Parse() {
|
||||
return std::make_shared<EqualizerBiquadOperation>(data_->sample_rate_, data_->center_freq_, data_->gain_, data_->Q_);
|
||||
}
|
||||
|
||||
// FrequencyMasking Transform Operation.
|
||||
struct FrequencyMasking::Data {
|
||||
Data(bool iid_masks, int32_t frequency_mask_param, int32_t mask_start, float mask_value)
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/complex_norm_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/contrast_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h"
|
||||
|
@ -157,6 +158,17 @@ PYBIND_REGISTER(
|
|||
}));
|
||||
}));
|
||||
|
||||
// Python binding for the EqualizerBiquad IR node. The bound constructor validates the
// parameters immediately so that an invalid configuration raises while the object is being built.
PYBIND_REGISTER(EqualizerBiquadOperation, 1, ([](const py::module *m) {
                  (void)py::class_<audio::EqualizerBiquadOperation, TensorOperation,
                                   std::shared_ptr<audio::EqualizerBiquadOperation>>(*m, "EqualizerBiquadOperation")
                    .def(py::init([](int sample_rate, float center_freq, float gain, float Q) {
                      auto operation =
                        std::make_shared<audio::EqualizerBiquadOperation>(sample_rate, center_freq, gain, Q);
                      THROW_IF_ERROR(operation->ValidateParams());
                      return operation;
                    }));
                }));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
FrequencyMaskingOperation, 1, ([](const py::module *m) {
|
||||
(void)
|
||||
|
|
|
@ -12,6 +12,7 @@ add_library(audio-ir-kernels OBJECT
|
|||
complex_norm_ir.cc
|
||||
contrast_ir.cc
|
||||
deemph_biquad_ir.cc
|
||||
equalizer_biquad_ir.cc
|
||||
frequency_masking_ir.cc
|
||||
highpass_biquad_ir.cc
|
||||
lowpass_biquad_ir.cc
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h"
|
||||
|
||||
#include "minddata/dataset/audio/ir/validators.h"
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
#include "minddata/dataset/audio/kernels/equalizer_biquad_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
namespace audio {
|
||||
EqualizerBiquadOperation::EqualizerBiquadOperation(int32_t sample_rate, float center_freq, float gain, float Q)
|
||||
: sample_rate_(sample_rate), center_freq_(center_freq), gain_(gain), Q_(Q) {}
|
||||
|
||||
Status EqualizerBiquadOperation::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateScalarNotZero("EqualizerBiquad", "sample_rate", sample_rate_));
|
||||
RETURN_IF_NOT_OK(ValidateScalar("EqualizerBiquad", "Q", Q_, {0, 1.0}, true, false));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> EqualizerBiquadOperation::Build() {
|
||||
std::shared_ptr<EqualizerBiquadOp> tensor_op =
|
||||
std::make_shared<EqualizerBiquadOp>(sample_rate_, center_freq_, gain_, Q_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
Status EqualizerBiquadOperation::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["sample_rate"] = sample_rate_;
|
||||
args["center_freq"] = center_freq_;
|
||||
args["gain"] = gain_;
|
||||
args["Q"] = Q_;
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_EQUALIZER_BIQUAD_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_EQUALIZER_BIQUAD_IR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/include/dataset/transforms.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
|
||||
constexpr char kEqualizerBiquadOperation[] = "EqualizerBiquad";
|
||||
|
||||
class EqualizerBiquadOperation : public TensorOperation {
|
||||
public:
|
||||
EqualizerBiquadOperation(int32_t sample_rate, float center_freq, float gain, float Q);
|
||||
|
||||
~EqualizerBiquadOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
std::string Name() const override { return kEqualizerBiquadOperation; }
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
int32_t sample_rate_;
|
||||
float center_freq_;
|
||||
float gain_;
|
||||
float Q_;
|
||||
};
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_EQUALIZER_BIQUAD_IR_H_
|
|
@ -13,6 +13,7 @@ add_library(audio-kernels OBJECT
|
|||
complex_norm_op.cc
|
||||
contrast_op.cc
|
||||
deemph_biquad_op.cc
|
||||
equalizer_biquad_op.cc
|
||||
frequency_masking_op.cc
|
||||
highpass_biquad_op.cc
|
||||
lowpass_biquad_op.cc
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/audio/kernels/equalizer_biquad_op.h"
|
||||
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
constexpr float_t EqualizerBiquadOp::kQ = 0.707;
|
||||
|
||||
Status EqualizerBiquadOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
// check input tensor dimension, it should be greater than 0.
|
||||
TensorShape input_shape = input->shape();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input_shape.Size() > 0, "EqualizerBiquad: input tensor in not in shape of <..., time>.");
|
||||
// check input tensor type, it should be DE_FLOAT32 or DE_FLOAT16 or DE_FLOAT64
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
input->type() == DataType(DataType::DE_FLOAT32) || input->type() == DataType(DataType::DE_FLOAT16) ||
|
||||
input->type() == DataType(DataType::DE_FLOAT64),
|
||||
"EqualizerBiquad: input tensor type should be float, but got: " + input->type().ToString());
|
||||
|
||||
double w0 = 2.0 * PI * center_freq_ / sample_rate_;
|
||||
double alpha = sin(w0) / 2.0 / Q_;
|
||||
double A = exp(gain_ / 40.0 * log(10));
|
||||
|
||||
double b0 = 1.0 + alpha * A;
|
||||
double b1 = -2.0 * cos(w0);
|
||||
double b2 = 1.0 - alpha * A;
|
||||
double a0 = 1.0 + alpha / A;
|
||||
double a1 = -2.0 * cos(w0);
|
||||
double a2 = 1.0 - alpha / A;
|
||||
if (input->type() == DataType(DataType::DE_FLOAT32)) {
|
||||
return Biquad(input, output, static_cast<float>(b0), static_cast<float>(b1), static_cast<float>(b2),
|
||||
static_cast<float>(a0), static_cast<float>(a1), static_cast<float>(a2));
|
||||
} else if (input->type() == DataType(DataType::DE_FLOAT64)) {
|
||||
return Biquad(input, output, static_cast<double>(b0), static_cast<double>(b1), static_cast<double>(b2),
|
||||
static_cast<double>(a0), static_cast<double>(a1), static_cast<double>(a2));
|
||||
} else {
|
||||
return Biquad(input, output, static_cast<float16>(b0), static_cast<float16>(b1), static_cast<float16>(b2),
|
||||
static_cast<float16>(a0), static_cast<float16>(a1), static_cast<float16>(a2));
|
||||
}
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_EQUALIZER_BIQUAD_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_EQUALIZER_BIQUAD_OP_H_
|
||||
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class EqualizerBiquadOp : public TensorOp {
|
||||
public:
|
||||
static const float kQ;
|
||||
|
||||
EqualizerBiquadOp(int32_t sample_rate, float center_freq, float gain, float Q)
|
||||
: sample_rate_(sample_rate), center_freq_(center_freq), gain_(gain), Q_(Q) {}
|
||||
|
||||
~EqualizerBiquadOp() override = default;
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
std::string Name() const override { return kEqualizerBiquadOp; }
|
||||
|
||||
protected:
|
||||
int32_t sample_rate_;
|
||||
float center_freq_;
|
||||
float gain_;
|
||||
float Q_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_EQUALIZER_BIQUAD_OP_H_
|
|
@ -249,6 +249,29 @@ class DeemphBiquad final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief EqualizerBiquad TensorTransform. Design a biquad equalizer filter and apply it to audio.
|
||||
class EqualizerBiquad final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero.
|
||||
/// \param[in] center_freq Filter's central frequency (in Hz).
|
||||
/// \param[in] gain Desired gain at the boost (or attenuation) in dB.
|
||||
/// \param[in] Q Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1] (Default: 0.707).
|
||||
EqualizerBiquad(int32_t sample_rate, float center_freq, float gain, float Q = 0.707);
|
||||
|
||||
/// \brief Destructor.
|
||||
~EqualizerBiquad() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
||||
private:
|
||||
/// Opaque holder of the constructor arguments; defined in the implementation file.
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief FrequencyMasking TensorTransform.
|
||||
/// \notes Apply masking to a spectrogram in the frequency domain.
|
||||
class FrequencyMasking final : public TensorTransform {
|
||||
|
|
|
@ -149,6 +149,7 @@ constexpr char kBassBiquadOp[] = "BassBiquadOp";
|
|||
constexpr char kComplexNormOp[] = "ComplexNormOp";
|
||||
constexpr char kContrastOp[] = "ContrastOp";
|
||||
constexpr char kDeemphBiquadOp[] = "DeemphBiquadOp";
|
||||
constexpr char kEqualizerBiquadOp[] = "EqualizerBiquadOp";
|
||||
constexpr char kFrequencyMaskingOp[] = "FrequencyMaskingOp";
|
||||
constexpr char kHighpassBiquadOp[] = "HighpassBiquadOp";
|
||||
constexpr char kLowpassBiquadOp[] = "LowpassBiquadOp";
|
||||
|
|
|
@ -26,7 +26,8 @@ from ..transforms.c_transforms import TensorOperation
|
|||
from .utils import ScaleType
|
||||
from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
|
||||
check_bandreject_biquad, check_bass_biquad, check_complex_norm, check_contrast, check_deemph_biquad, \
|
||||
check_highpass_biquad, check_lowpass_biquad, check_masking, check_mu_law_decoding, check_time_stretch
|
||||
check_equalizer_biquad, check_highpass_biquad, check_lowpass_biquad, check_masking, check_mu_law_decoding,\
|
||||
check_time_stretch
|
||||
|
||||
|
||||
class AudioTensorOperation(TensorOperation):
|
||||
|
@ -318,6 +319,36 @@ class DeemphBiquad(AudioTensorOperation):
|
|||
return cde.DeemphBiquadOperation(self.sample_rate)
|
||||
|
||||
|
||||
class EqualizerBiquad(AudioTensorOperation):
    """
    Design biquad equalizer filter and perform filtering. Similar to SoX implementation.

    Args:
        sample_rate (int): Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero.
        center_freq (float): Central frequency (in Hz).
        gain (float): Desired gain at the boost (or attenuation) in dB.
        Q (float, optional): Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1] (default=0.707).

    Examples:
        >>> import numpy as np
        >>>
        >>> waveform = np.array([[2.716064453125e-03, 6.34765625e-03], [9.246826171875e-03, 1.0894775390625e-02]])
        >>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
        >>> transforms = [audio.EqualizerBiquad(44100, 1500, 5.5, 0.7)]
        >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
    """

    @check_equalizer_biquad
    def __init__(self, sample_rate, center_freq, gain, Q=0.707):
        # The decorator has already validated the arguments; they are only stored here.
        self.sample_rate, self.center_freq = sample_rate, center_freq
        self.gain, self.Q = gain, Q

    def parse(self):
        """Build the C++ EqualizerBiquadOperation from the stored parameters."""
        return cde.EqualizerBiquadOperation(self.sample_rate, self.center_freq, self.gain, self.Q)
|
||||
|
||||
|
||||
class FrequencyMasking(AudioTensorOperation):
|
||||
"""
|
||||
Apply masking to a spectrogram in the frequency domain.
|
||||
|
|
|
@ -215,6 +215,21 @@ def check_deemph_biquad(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_equalizer_biquad(method):
    """Wrapper method to check the parameters of EqualizerBiquad."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        # Resolve positional/keyword arguments (including defaults) into declaration order.
        params, _ = parse_user_args(method, *args, **kwargs)
        sample_rate, center_freq, gain, Q = params
        # Checks run in declaration order, so the first invalid argument is the one reported.
        check_biquad_sample_rate(sample_rate)
        check_biquad_central_freq(center_freq)
        check_biquad_gain(gain)
        check_biquad_Q(Q)
        return method(self, *args, **kwargs)

    return new_method
|
||||
|
||||
|
||||
def check_lowpass_biquad(method):
|
||||
"""Wrapper method to check the parameters of LowpassBiquad."""
|
||||
|
||||
|
|
|
@ -490,6 +490,61 @@ TEST_F(MindDataTestPipeline, TestAnglePipelineError) {
|
|||
EXPECT_ERROR(iter->GetNextRow(&row));
|
||||
}
|
||||
|
||||
// Runs EqualizerBiquad with valid parameters inside a pipeline over 8 random rows and
// checks that every row comes out of the iterator.
TEST_F(MindDataTestPipeline, TestEqualizerBiquadSuccess) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEqualizerBiquadSuccess.";

  // Create an input tensor
  std::shared_ptr<SchemaObj> schema = Schema();
  ASSERT_OK(schema->add_column("col1", mindspore::DataType::kNumberTypeFloat32, {1, 200}));
  std::shared_ptr<Dataset> ds = RandomData(8, schema);
  // Fatal assertion: ds is dereferenced right below.
  ASSERT_NE(ds, nullptr);

  // Create a filter object
  std::shared_ptr<TensorTransform> equalizer_biquad =
    std::make_shared<audio::EqualizerBiquad>(44100, 3.5, 5.5, 0.707);
  auto ds1 = ds->Map({equalizer_biquad}, {"col1"}, {"audio"});
  ASSERT_NE(ds1, nullptr);
  std::shared_ptr<Iterator> iter = ds1->CreateIterator();
  ASSERT_NE(iter, nullptr);

  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  uint64_t i = 0;
  // An empty row marks the end of the data.
  while (row.size() != 0) {
    i++;
    ASSERT_OK(iter->GetNextRow(&row));
  }
  EXPECT_EQ(i, 8);
  iter->Stop();
}
|
||||
|
||||
// Checks that a pipeline containing EqualizerBiquad with an invalid sample_rate (0) or
// an invalid Q (0) fails to create an iterator.
TEST_F(MindDataTestPipeline, TestEqualizerBiquadWrongArgs) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestEqualizerBiquadWrongArgs.";
  std::shared_ptr<SchemaObj> schema = Schema();
  // Original waveform
  ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {2, 10}));
  std::shared_ptr<Dataset> ds = RandomData(50, schema);
  std::shared_ptr<Dataset> ds01;
  std::shared_ptr<Dataset> ds02;
  // Fatal assertions are used for every pointer that is dereferenced afterwards.
  ASSERT_NE(ds, nullptr);

  // Check sample_rate
  MS_LOG(INFO) << "sample_rate is zero.";
  auto equalizer_biquad_op_01 = audio::EqualizerBiquad(0, 200.0, 5.5, 0.7);
  ds01 = ds->Map({equalizer_biquad_op_01});
  ASSERT_NE(ds01, nullptr);

  std::shared_ptr<Iterator> iter01 = ds01->CreateIterator();
  EXPECT_EQ(iter01, nullptr);

  // Check Q
  MS_LOG(INFO) << "Q is zero.";
  auto equalizer_biquad_op_02 = audio::EqualizerBiquad(44100, 2000.0, 5.5, 0);
  ds02 = ds->Map({equalizer_biquad_op_02});
  ASSERT_NE(ds02, nullptr);

  std::shared_ptr<Iterator> iter02 = ds02->CreateIterator();
  EXPECT_EQ(iter02, nullptr);
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestLowpassBiquadSuccess) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLowpassBiquadSuccess.";
|
||||
|
||||
|
|
|
@ -642,6 +642,54 @@ TEST_F(MindDataTestExecute, TestRGB2BGREager) {
|
|||
EXPECT_EQ(rc, Status::OK());
|
||||
}
|
||||
|
||||
// Applies EqualizerBiquad with valid parameters to a 5x3 tensor in eager mode and expects success.
TEST_F(MindDataTestExecute, TestEqualizerBiquadEager) {
  MS_LOG(INFO) << "Doing MindDataTestExecute-TestEqualizerBiquadEager.";
  int sample_rate = 44100;
  float center_freq = 3.5;
  float gain = 5.5;
  float Q = 0.707;
  std::vector<mindspore::MSTensor> output;
  std::shared_ptr<Tensor> test;
  std::vector<double> test_vector = {0.8236, 0.2049, 0.3335, 0.5933, 0.9911, 0.2482, 0.3007, 0.9054,
                                     0.7598, 0.5394, 0.2842, 0.5634, 0.6363, 0.2226, 0.2288};
  // Stop here if the input could not be built; otherwise `test` would be used while still null.
  ASSERT_TRUE(Tensor::CreateFromVector(test_vector, TensorShape({5, 3}), &test).IsOk());
  auto input = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(test));
  // Construct the transform directly instead of through a braced temporary that is then copied.
  std::shared_ptr<TensorTransform> equalizer_biquad =
    std::make_shared<audio::EqualizerBiquad>(sample_rate, center_freq, gain, Q);
  auto transform = Execute({equalizer_biquad});
  Status rc = transform({input}, &output);
  ASSERT_TRUE(rc.IsOk());
}
|
||||
|
||||
// Eager execution must fail when Q is 0, which is outside the accepted range.
TEST_F(MindDataTestExecute, TestEqualizerBiquadParamCheckQ) {
  MS_LOG(INFO) << "Doing MindDataTestExecute-TestEqualizerBiquadParamCheckQ.";
  std::vector<mindspore::MSTensor> output;
  std::shared_ptr<Tensor> test;
  std::vector<double> test_vector = {0.1129, 0.3899, 0.7762, 0.2437, 0.9911, 0.8764, 0.4524, 0.9034,
                                     0.3277, 0.8904, 0.1852, 0.6721, 0.1325, 0.2345, 0.5538};
  // Make sure the failure asserted below comes from the parameter check, not from a missing input.
  ASSERT_TRUE(Tensor::CreateFromVector(test_vector, TensorShape({3, 5}), &test).IsOk());
  auto input = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(test));
  // Check Q
  std::shared_ptr<TensorTransform> equalizer_biquad_op = std::make_shared<audio::EqualizerBiquad>(44100, 3.5, 5.5, 0);
  mindspore::dataset::Execute transform({equalizer_biquad_op});
  Status rc = transform({input}, &output);
  ASSERT_FALSE(rc.IsOk());
}
|
||||
|
||||
// Eager execution must fail when sample_rate is 0.
TEST_F(MindDataTestExecute, TestEqualizerBiquadParamCheckSampleRate) {
  MS_LOG(INFO) << "Doing MindDataTestExecute-TestEqualizerBiquadParamCheckSampleRate.";
  std::vector<mindspore::MSTensor> output;
  std::shared_ptr<Tensor> test;
  std::vector<double> test_vector = {0.5236, 0.7049, 0.4335, 0.4533, 0.0911, 0.3482, 0.3407, 0.9054,
                                     0.7598, 0.5394, 0.2842, 0.5634, 0.6363, 0.2226, 0.2288, 0.6743};
  // Make sure the failure asserted below comes from the parameter check, not from a missing input.
  ASSERT_TRUE(Tensor::CreateFromVector(test_vector, TensorShape({4, 4}), &test).IsOk());
  auto input = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(test));
  // Check sample_rate
  std::shared_ptr<TensorTransform> equalizer_biquad_op = std::make_shared<audio::EqualizerBiquad>(0, 3.5, 5.5, 0.7);
  mindspore::dataset::Execute transform({equalizer_biquad_op});
  Status rc = transform({input}, &output);
  ASSERT_FALSE(rc.IsOk());
}
|
||||
|
||||
TEST_F(MindDataTestExecute, TestLowpassBiquadEager) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestLowpassBiquadEager.";
|
||||
int sample_rate = 44100;
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing EqualizerBiquad op in DE
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.audio.transforms as audio
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
def count_unequal_element(data_expected, data_me, rtol, atol):
    """
    Assert that two arrays agree within tolerance.

    An element counts as a mismatch when |expected - actual| > atol + |expected| * rtol.
    The check passes while the fraction of mismatching elements stays below rtol.

    Args:
        data_expected (numpy.ndarray): Reference values.
        data_me (numpy.ndarray): Values under test, same shape as data_expected.
        rtol (float): Relative tolerance, also used as the allowed mismatch ratio.
        atol (float): Absolute tolerance.

    Raises:
        AssertionError: If the shapes differ or too many elements mismatch.
    """
    assert data_expected.shape == data_me.shape
    total_count = data_expected.size
    if total_count == 0:
        # Nothing to compare; also avoids dividing by zero below.
        return
    error = np.abs(data_expected - data_me)
    greater = np.greater(error, atol + np.abs(data_expected) * rtol)
    loss_count = np.count_nonzero(greater)
    assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
        data_expected[greater], data_me[greater], error[greater])
|
||||
|
||||
|
||||
def test_equalizer_biquad_eager():
    """Eager mode: EqualizerBiquad called directly on a NumPy waveform matches the expected output."""
    # Input waveform
    source = np.array([[0.8236, 0.2049, 0.3335], [0.5933, 0.9911, 0.2482],
                       [0.3007, 0.9054, 0.7598], [0.5394, 0.2842, 0.5634], [0.6363, 0.2226, 0.2288]])
    # Expected result of the filter
    expected = np.array([[1.0000, 0.2532, 0.1273], [0.7333, 1.0000, 0.1015],
                         [0.3717, 1.0000, 0.8351], [0.6667, 0.3513, 0.5098], [0.7864, 0.2751, 0.0627]])
    # Waveform filtered by equalizer_biquad
    filtered = audio.EqualizerBiquad(4000, 1000.0, 5.5, 1)(source)
    count_unequal_element(expected, filtered, 0.0001, 0.0001)
|
||||
|
||||
|
||||
def test_equalizer_biquad_pipeline():
    """Pipeline mode: EqualizerBiquad applied through dataset.map matches the expected output row by row."""
    # Input waveform, one dataset row per array row
    source = np.array([[0.4063, 0.7729, 0.2325], [0.2687, 0.1426, 0.8987],
                       [0.6914, 0.6681, 0.1783], [0.2704, 0.2680, 0.7975], [0.5880, 0.1776, 0.6323]])
    # Expected result of the filter
    expected = np.array([[0.5022, 0.9553, 0.1468], [0.3321, 0.1762, 1.0000],
                         [0.8545, 0.8257, -0.0188], [0.3342, 0.3312, 0.8921], [0.7267, 0.2195, 0.5781]])
    # shuffle=False keeps the rows aligned with `expected`.
    dataset = ds.NumpySlicesDataset(source, ["col1"], shuffle=False)
    equalizer_biquad_op = audio.EqualizerBiquad(4000, 1000.0, 5.5, 1)
    # Waveform filtered by equalizer_biquad
    dataset = dataset.map(input_columns=["col1"], operations=equalizer_biquad_op, num_parallel_workers=4)
    for row_index, item in enumerate(dataset.create_dict_iterator(output_numpy=True)):
        count_unequal_element(expected[row_index, :], item["col1"], 0.0001, 0.0001)
|
||||
|
||||
|
||||
def test_equalizer_biquad_invalid_input():
    """
    Test invalid input of EqualizerBiquad
    """
    # Each case: (test name, sample_rate, center_freq, gain, Q, expected error, expected message fragment).
    cases = [
        ("invalid sample_rate parameter type as a float", 44100.5, 1000, 5.5, 0.707, TypeError,
         "Argument sample_rate with value 44100.5 is not of type [<class 'int'>],"
         " but got <class 'float'>."),
        ("invalid sample_rate parameter type as a String", "44100", 1000, 5.5, 0.707, TypeError,
         "Argument sample_rate with value 44100 is not of type [<class 'int'>],"
         " but got <class 'str'>."),
        ("invalid central_freq parameter type as a String", 44100, "1000", 5.5, 0.707, TypeError,
         "Argument central_freq with value 1000 is not of type [<class 'float'>, <class 'int'>],"
         " but got <class 'str'>."),
        ("invalid gain parameter type as a String", 44100, 1000, "5.5", 0.707, TypeError,
         "Argument gain with value 5.5 is not of type [<class 'float'>, <class 'int'>],"
         " but got <class 'str'>."),
        ("invalid Q parameter type as a String", 44100, 1000, 5.5, "0.707", TypeError,
         "Argument Q with value 0.707 is not of type [<class 'float'>, <class 'int'>],"
         " but got <class 'str'>."),

        ("invalid sample_rate parameter value", 441324343243242342345300, 1000, 5.5, 0.707, ValueError,
         "Input sample_rate is not within the required interval of [-2147483648, 0) and (0, 2147483647]."),
        ("invalid central_freq parameter value", 44100, 3243432434, 5.5, 0.707, ValueError,
         "Input central_freq is not within the required interval of [-16777216, 16777216]."),
        ("invalid sample_rate parameter value", 0, 1000, 5.5, 0.707, ValueError,
         "Input sample_rate is not within the required interval of [-2147483648, 0) and (0, 2147483647]."),
        ("invalid Q parameter value", 44100, 1000, 5.5, 0, ValueError,
         "Input Q is not within the required interval of (0, 1]."),

        ("invalid sample_rate parameter value", None, 1000, 5.5, 0.707, TypeError,
         "Argument sample_rate with value None is not of type [<class 'int'>], "
         "but got <class 'NoneType'>."),
        ("invalid central_freq parameter value", 44100, None, 5.5, 0.707, TypeError,
         "Argument central_freq with value None is not of type [<class 'float'>, <class 'int'>],"
         " but got <class 'NoneType'>."),
        ("invalid gain parameter value", 44100, 200, None, 0.707, TypeError,
         "Argument gain with value None is not of type [<class 'float'>, <class 'int'>], "
         "but got <class 'NoneType'>."),
    ]

    for test_name, sample_rate, center_freq, gain, Q, error, error_msg in cases:
        logger.info("Test EqualizerBiquad with bad input: {0}".format(test_name))
        # Construction alone must raise; the validator runs in __init__.
        with pytest.raises(error) as error_info:
            audio.EqualizerBiquad(sample_rate, center_freq, gain, Q)
        assert error_msg in str(error_info.value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run every test in this file, in definition order, when executed as a script.
    for test_case in (test_equalizer_biquad_eager,
                      test_equalizer_biquad_pipeline,
                      test_equalizer_biquad_invalid_input):
        test_case()
|
Loading…
Reference in New Issue