[feat][assistant][I3J6UY] add new audio operator RiaaBiquad

This commit is contained in:
zx 2021-08-05 19:42:27 +08:00
parent edad96be95
commit 1d95e6f480
16 changed files with 644 additions and 1 deletions

View File

@ -36,6 +36,7 @@
#include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/magphase_ir.h"
#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h"
#include "minddata/dataset/audio/ir/kernels/riaa_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
#include "minddata/dataset/audio/ir/kernels/vol_ir.h"
@ -338,6 +339,18 @@ std::shared_ptr<TensorOperation> MuLawDecoding::Parse() {
return std::make_shared<MuLawDecodingOperation>(data_->quantization_channels_);
}
// RiaaBiquad Transform Operation.
struct RiaaBiquad::Data {
explicit Data(int32_t sample_rate) : sample_rate_(sample_rate) {}
int32_t sample_rate_;
};
RiaaBiquad::RiaaBiquad(int32_t sample_rate) : data_(std::make_shared<Data>(sample_rate)) {}
std::shared_ptr<TensorOperation> RiaaBiquad::Parse() {
return std::make_shared<RiaaBiquadOperation>(data_->sample_rate_);
}
// TimeMasking Transform Operation.
struct TimeMasking::Data {
Data(bool iid_masks, int32_t time_mask_param, int32_t mask_start, float mask_value)

View File

@ -40,6 +40,7 @@
#include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/magphase_ir.h"
#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h"
#include "minddata/dataset/audio/ir/kernels/riaa_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
#include "minddata/dataset/audio/ir/kernels/vol_ir.h"
@ -282,6 +283,17 @@ PYBIND_REGISTER(
}));
}));
PYBIND_REGISTER(
RiaaBiquadOperation, 1, ([](const py::module *m) {
(void)py::class_<audio::RiaaBiquadOperation, TensorOperation, std::shared_ptr<audio::RiaaBiquadOperation>>(
*m, "RiaaBiquadOperation")
.def(py::init([](int32_t sample_rate) {
auto riaa_biquad = std::make_shared<audio::RiaaBiquadOperation>(sample_rate);
THROW_IF_ERROR(riaa_biquad->ValidateParams());
return riaa_biquad;
}));
}));
PYBIND_REGISTER(
TimeMaskingOperation, 1, ([](const py::module *m) {
(void)py::class_<audio::TimeMaskingOperation, TensorOperation, std::shared_ptr<audio::TimeMaskingOperation>>(

View File

@ -22,6 +22,7 @@ add_library(audio-ir-kernels OBJECT
lowpass_biquad_ir.cc
magphase_ir.cc
mu_law_decoding_ir.cc
riaa_biquad_ir.cc
time_masking_ir.cc
time_stretch_ir.cc
vol_ir.cc

View File

@ -0,0 +1,45 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/audio/ir/kernels/riaa_biquad_ir.h"
#include "minddata/dataset/audio/kernels/riaa_biquad_op.h"
#include "minddata/dataset/audio/ir/validators.h"
namespace mindspore {
namespace dataset {
namespace audio {
RiaaBiquadOperation::RiaaBiquadOperation(int32_t sample_rate) : sample_rate_(sample_rate) {}
Status RiaaBiquadOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateScalarValue("RiaaBiquad", "sample_rate", sample_rate_, {44100, 48000, 88200, 96000}));
return Status::OK();
}
std::shared_ptr<TensorOp> RiaaBiquadOperation::Build() {
std::shared_ptr<RiaaBiquadOp> tensor_op = std::make_shared<RiaaBiquadOp>(sample_rate_);
return tensor_op;
}
Status RiaaBiquadOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sample_rate"] = sample_rate_;
*out_json = args;
return Status::OK();
}
} // namespace audio
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_RIAA_BIQUAD_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_RIAA_BIQUAD_IR_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/api/status.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
namespace audio {
constexpr char kRiaaBiquadOperation[] = "RiaaBiquad";
class RiaaBiquadOperation : public TensorOperation {
public:
explicit RiaaBiquadOperation(int32_t sample_rate);
~RiaaBiquadOperation() = default;
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override { return kRiaaBiquadOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
int32_t sample_rate_;
}; // class RiaaBiquadOperation
} // namespace audio
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_RIAA_BIQUAD_IR_H_

View File

@ -34,6 +34,27 @@ Status ValidateIntScalarNonNegative(const std::string &op_name, const std::strin
// Helper function to non-nan float scalar
Status ValidateFloatScalarNotNan(const std::string &op_name, const std::string &scalar_name, float scalar);
// Helper function to validate scalar value
template <typename T>
Status ValidateScalarValue(const std::string &op_name, const std::string &scalar_name, T scalar,
const std::vector<T> &values) {
if (std::find(values.begin(), values.end(), scalar) == values.end()) {
std::string init;
std::string mode = std::accumulate(values.begin(), values.end(), init, [](const std::string &str, T val) {
if (str.empty()) {
return std::to_string(val);
} else {
return str + ", " + std::to_string(val);
}
});
std::string err_msg =
op_name + ": " + scalar_name + " must be one of [" + mode + "], but got: " + std::to_string(scalar);
MS_LOG(ERROR) << err_msg;
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
}
return Status::OK();
}
// Helper function to check scalar is not equal to zero
template <typename T>
Status ValidateScalarNotZero(const std::string &op_name, const std::string &scalar_name, const T scalar) {

View File

@ -23,6 +23,7 @@ add_library(audio-kernels OBJECT
lowpass_biquad_op.cc
magphase_op.cc
mu_law_decoding_op.cc
riaa_biquad_op.cc
time_masking_op.cc
time_stretch_op.cc
vol_op.cc

View File

@ -0,0 +1,87 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/audio/kernels/riaa_biquad_op.h"
#include <map>
#include "minddata/dataset/audio/kernels/audio_utils.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
RiaaBiquadOp::RiaaBiquadOp(int32_t sample_rate) : sample_rate_(sample_rate) {}
Status RiaaBiquadOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
TensorShape input_shape = input->shape();
// check input tensor dimension, it should be greater than 0.
CHECK_FAIL_RETURN_UNEXPECTED(input_shape.Size() > 0, "RiaaBiquad: input tensor is not in shape of <..., time>.");
// check input type, it should be DE_FLOAT32 or DE_FLOAT16 or DE_FLOAT64.
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType(DataType::DE_FLOAT32) ||
input->type() == DataType(DataType::DE_FLOAT16) ||
input->type() == DataType(DataType::DE_FLOAT64),
"RiaaBiquad: input tensor type should be float, but got: " + input->type().ToString());
// indicate array zeros and poles.
const std::map<int32_t, std::vector<float>> kZeros = {
{44100, {-0.2014898, 0.9233820}},
{48000, {-0.1766069, 0.9321590}},
{88200, {-0.1168735, 0.9648312}},
{96000, {-0.1141486, 0.9676817}},
};
const std::map<int32_t, std::vector<float>> kPoles = {
{44100, {0.7083149, 0.9924091}},
{48000, {0.7396325, 0.9931330}},
{88200, {0.8590646, 0.9964002}},
{96000, {0.8699137, 0.9966946}},
};
const std::vector<float> &zeros = kZeros.at(sample_rate_);
const std::vector<float> &poles = kPoles.at(sample_rate_);
// computer a0, a1, a2, b0, b1, b2.
// polynomial coefficients with roots zeros[0] and zeros[1].
float b0 = 1.0;
float b1 = -(zeros[0] + zeros[1]);
float b2 = zeros[0] * zeros[1];
// polynomial coefficients with roots poles[0] and poles[1].
float a0 = 1.0;
float a1 = -(poles[0] + poles[1]);
float a2 = poles[0] * poles[1];
// normalize to 0dB at 1kHz.
float w0 = 2 * PI * 1000 / sample_rate_;
// re refers to the real part of the complex number.
float b_re = b0 + b1 * cos(-w0) + b2 * cos(-2 * w0);
float a_re = a0 + a1 * cos(-w0) + a2 * cos(-2 * w0);
// im refers to the imaginary part of the complex number.
float b_im = b1 * sin(-w0) + b2 * sin(-2 * w0);
float a_im = a1 * sin(-w0) + a2 * sin(-2 * w0);
// temp is the intermediate variable used to solve for b0, b1, b2.
float temp = 1 / sqrt((b_re * b_re + b_im * b_im) / (a_re * a_re + a_im * a_im));
b0 *= temp;
b1 *= temp;
b2 *= temp;
// use Biquad function.
if (input->type() == DataType(DataType::DE_FLOAT32)) {
return Biquad(input, output, static_cast<float>(b0), static_cast<float>(b1), static_cast<float>(b2),
static_cast<float>(a0), static_cast<float>(a1), static_cast<float>(a2));
} else if (input->type() == DataType(DataType::DE_FLOAT64)) {
return Biquad(input, output, static_cast<double>(b0), static_cast<double>(b1), static_cast<double>(b2),
static_cast<double>(a0), static_cast<double>(a1), static_cast<double>(a2));
} else {
return Biquad(input, output, static_cast<float16>(b0), static_cast<float16>(b1), static_cast<float16>(b2),
static_cast<float16>(a0), static_cast<float16>(a1), static_cast<float16>(a2));
}
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_RIAA_BIQUAD_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_RIAA_BIQUAD_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class RiaaBiquadOp : public TensorOp {
public:
explicit RiaaBiquadOp(int32_t sample_rate);
~RiaaBiquadOp() override = default;
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kRiaaBiquadOp; }
private:
int32_t sample_rate_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_RIAA_BIQUAD_OP_H_

View File

@ -484,6 +484,27 @@ class MuLawDecoding final : public TensorTransform {
std::shared_ptr<Data> data_;
};
/// \brief Apply RIAA vinyl playback equalization.
class RiaaBiquad final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz),
/// can only be one of 44100, 48000, 88200, 96000.
explicit RiaaBiquad(int32_t sample_rate);
/// \brief Destructor.
~RiaaBiquad() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief TimeMasking TensorTransform.
/// \notes Apply masking to a spectrogram in the time domain.
class TimeMasking final : public TensorTransform {

View File

@ -159,6 +159,7 @@ constexpr char kLFilterOp[] = "LFilterOp";
constexpr char kLowpassBiquadOp[] = "LowpassBiquadOp";
constexpr char kMagphaseOp[] = "MagphaseOp";
constexpr char kMuLawDecodingOp[] = "MuLawDecodingOp";
constexpr char kRiaaBiquadOp[] = "RiaaBiquadOp";
constexpr char kTimeMaskingOp[] = "TimeMaskingOp";
constexpr char kTimeStretchOp[] = "TimeStretchOp";
constexpr char kVolOp[] = "VolOp";

View File

@ -27,7 +27,8 @@ from .utils import FadeShape, GainType, ScaleType
from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_contrast, check_dc_shift, \
check_deemph_biquad, check_equalizer_biquad, check_fade, check_highpass_biquad, check_lfilter, \
check_lowpass_biquad, check_magphase, check_masking, check_mu_law_decoding, check_time_stretch, check_vol
check_lowpass_biquad, check_magphase, check_masking, check_mu_law_decoding, check_riaa_biquad, \
check_time_stretch, check_vol
class AudioTensorOperation(TensorOperation):
@ -626,6 +627,31 @@ class MuLawDecoding(AudioTensorOperation):
return cde.MuLawDecodingOperation(self.quantization_channels)
class RiaaBiquad(AudioTensorOperation):
"""
Apply RIAA vinyl playback equalization. Similar to SoX implementation.
Args:
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz),
can only be one of 44100, 48000, 88200, 96000.
Examples:
>>> import numpy as np
>>>
>>> waveform = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64)
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
>>> transforms = [audio.RiaaBiquad(44100)]
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
"""
@check_riaa_biquad
def __init__(self, sample_rate):
self.sample_rate = sample_rate
def parse(self):
return cde.RiaaBiquadOperation(self.sample_rate)
class TimeMasking(AudioTensorOperation):
"""
Apply masking to a spectrogram in the time domain.

View File

@ -292,6 +292,21 @@ def check_mu_law_decoding(method):
return new_method
def check_riaa_biquad(method):
"""Wrapper method to check the parameters of RiaaBiquad."""
@wraps(method)
def new_method(self, *args, **kwargs):
[sample_rate], _ = parse_user_args(method, *args, **kwargs)
type_check(sample_rate, (int,), "sample_rate")
if sample_rate not in (44100, 48000, 88200, 96000):
raise ValueError("sample_rate should be one of [44100, 48000, 88200, 96000], but got {0}.".format(
sample_rate))
return method(self, *args, **kwargs)
return new_method
def check_time_stretch(method):
"""Wrapper method to check the parameters of TimeStretch."""

View File

@ -28,6 +28,181 @@ class MindDataTestPipeline : public UT::DatasetOpTesting {
protected:
};
TEST_F(MindDataTestPipeline, TestRiaaBiquadBasicSampleRate44100) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRiaaBiquadBasicSampleRate44100.";
// Original waveform
std::shared_ptr<SchemaObj> schema = Schema();
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 200}));
std::shared_ptr<Dataset> ds = RandomData(50, schema);
EXPECT_NE(ds, nullptr);
ds = ds->SetNumWorkers(4);
EXPECT_NE(ds, nullptr);
auto RiaaBiquadOp = audio::RiaaBiquad(44100);
ds = ds->Map({RiaaBiquadOp});
EXPECT_NE(ds, nullptr);
// Filtered waveform by riaabiquad
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(ds, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
std::vector<int64_t> expected = {2, 200};
int i = 0;
while (row.size() != 0) {
auto col = row["waveform"];
ASSERT_EQ(col.Shape(), expected);
ASSERT_EQ(col.Shape().size(), 2);
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 50);
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestRiaaBiquadBasicSampleRate48000) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRiaaBiquadBasicSampleRate48000.";
// Original waveform
std::shared_ptr<SchemaObj> schema = Schema();
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {30, 40}));
std::shared_ptr<Dataset> ds = RandomData(50, schema);
EXPECT_NE(ds, nullptr);
ds = ds->SetNumWorkers(4);
EXPECT_NE(ds, nullptr);
auto RiaaBiquadOp = audio::RiaaBiquad(48000);
ds = ds->Map({RiaaBiquadOp});
EXPECT_NE(ds, nullptr);
// Filtered waveform by riaabiquad
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(ds, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
std::vector<int64_t> expected = {30, 40};
int i = 0;
while (row.size() != 0) {
auto col = row["waveform"];
ASSERT_EQ(col.Shape(), expected);
ASSERT_EQ(col.Shape().size(), 2);
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 50);
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestRiaaBiquadBasicSampleRate88200) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRiaaBiquadBasicSampleRate88200.";
// Original waveform
std::shared_ptr<SchemaObj> schema = Schema();
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {5, 4}));
std::shared_ptr<Dataset> ds = RandomData(50, schema);
EXPECT_NE(ds, nullptr);
ds = ds->SetNumWorkers(4);
EXPECT_NE(ds, nullptr);
auto RiaaBiquadOp = audio::RiaaBiquad(88200);
ds = ds->Map({RiaaBiquadOp});
EXPECT_NE(ds, nullptr);
// Filtered waveform by riaabiquad
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(ds, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
std::vector<int64_t> expected = {5, 4};
int i = 0;
while (row.size() != 0) {
auto col = row["waveform"];
ASSERT_EQ(col.Shape(), expected);
ASSERT_EQ(col.Shape().size(), 2);
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 50);
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestRiaaBiquadBasicSampleRate96000) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRiaaBiquadBasicSampleRate96000.";
// Original waveform
std::shared_ptr<SchemaObj> schema = Schema();
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 3}));
std::shared_ptr<Dataset> ds = RandomData(50, schema);
EXPECT_NE(ds, nullptr);
ds = ds->SetNumWorkers(4);
EXPECT_NE(ds, nullptr);
auto RiaaBiquadOp = audio::RiaaBiquad(96000);
ds = ds->Map({RiaaBiquadOp});
EXPECT_NE(ds, nullptr);
// Filtered waveform by riaabiquad
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(ds, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
std::vector<int64_t> expected = {2, 3};
int i = 0;
while (row.size() != 0) {
auto col = row["waveform"];
ASSERT_EQ(col.Shape(), expected);
ASSERT_EQ(col.Shape().size(), 2);
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 50);
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestRiaaBiquadWrongArg) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRiaaBiquadWrongArg.";
std::shared_ptr<SchemaObj> schema = Schema();
// Original waveform
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 2}));
std::shared_ptr<Dataset> ds = RandomData(50, schema);
std::shared_ptr<Dataset> ds01;
EXPECT_NE(ds, nullptr);
// Check sample_rate
MS_LOG(INFO) << "sample_rate is zero.";
auto riaa_biquad_op_01 = audio::RiaaBiquad(0);
ds01 = ds->Map({riaa_biquad_op_01});
EXPECT_NE(ds01, nullptr);
std::shared_ptr<Iterator> iter01 = ds01->CreateIterator();
EXPECT_EQ(iter01, nullptr);
}
TEST_F(MindDataTestPipeline, TestTimeMaskingPipeline) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTimeMaskingPipeline.";
// Original waveform

View File

@ -876,6 +876,44 @@ TEST_F(MindDataTestExecute, TestMuLawDecodingEager) {
EXPECT_TRUE(s01.IsOk());
}
TEST_F(MindDataTestExecute, TestRiaaBiquadWithEager) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestRiaaBiquadWithEager.";
// Original waveform
std::vector<float> labels = {
2.716064453125000000e-03, 6.347656250000000000e-03, 9.246826171875000000e-03, 1.089477539062500000e-02,
1.138305664062500000e-02, 1.156616210937500000e-02, 1.394653320312500000e-02, 1.550292968750000000e-02,
1.614379882812500000e-02, 1.840209960937500000e-02, 1.718139648437500000e-02, 1.599121093750000000e-02,
1.647949218750000000e-02, 1.510620117187500000e-02, 1.385498046875000000e-02, 1.345825195312500000e-02,
1.419067382812500000e-02, 1.284790039062500000e-02, 1.052856445312500000e-02, 9.368896484375000000e-03};
std::shared_ptr<Tensor> input;
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 10}), &input));
auto input_02 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
std::shared_ptr<TensorTransform> riaa_biquad_01 = std::make_shared<audio::RiaaBiquad>(44100);
mindspore::dataset::Execute Transform01({riaa_biquad_01});
// Filtered waveform by riaabiquad
Status s01 = Transform01(input_02, &input_02);
EXPECT_TRUE(s01.IsOk());
}
TEST_F(MindDataTestExecute, TestRiaaBiquadWithWrongArg) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestRiaaBiquadWithWrongArg.";
std::vector<float> labels = {
3.156, 5.690, 1.362, 1.093,
5.782, 6.381, 5.982, 3.098,
1.222, 6.027, 3.909, 7.993,
4.324, 1.092, 5.093, 0.991,
1.099, 4.092, 8.111, 6.666};
std::shared_ptr<Tensor> input;
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({4, 5}), &input));
auto input01 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
// Check sample_rate
MS_LOG(INFO) << "sample_rate is zero.";
std::shared_ptr<TensorTransform> riaa_biquad_op01 = std::make_shared<audio::RiaaBiquad>(0);
mindspore::dataset::Execute Transform01({riaa_biquad_op01});
Status s01 = Transform01(input01, &input01);
EXPECT_FALSE(s01.IsOk());
}
TEST_F(MindDataTestExecute, TestLFilterWithEager) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestLFilterWithEager.";
// Original waveform

View File

@ -0,0 +1,87 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing RiaaBiquad op in DE.
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.audio.transforms as audio
from mindspore import log as logger
def count_unequal_element(data_expected, data_me, rtol, atol):
assert data_expected.shape == data_me.shape
total_count = len(data_expected.flatten())
error = np.abs(data_expected - data_me)
greater = np.greater(error, atol + np.abs(data_expected) * rtol)
loss_count = np.count_nonzero(greater)
assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
data_expected[greater], data_me[greater], error[greater])
def test_riaa_biquad_eager():
""" mindspore eager mode normal testcase:riaa_biquad op"""
# Original waveform
waveform = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64)
# Expect waveform
expect_waveform = np.array([[0.23806122, 0.70914434, 1.],
[0.95224489, 1., 1.]], dtype=np.float64)
riaa_biquad_op = audio.RiaaBiquad(44100)
# Filtered waveform by riaabiquad
output = riaa_biquad_op(waveform)
count_unequal_element(expect_waveform, output, 0.0001, 0.0001)
def test_riaa_biquad_pipeline():
""" mindspore pipeline mode normal testcase:riaa_biquad op"""
# Original waveform
waveform = np.array([[1.47, 4.722, 5.863], [0.492, 0.235, 0.56]], dtype=np.float32)
# Expect waveform
expect_waveform = np.array([[0.18626465, 0.7859906, 1.],
[0.06234163, 0.09258664, 0.15710703]], dtype=np.float64)
dataset = ds.NumpySlicesDataset(waveform, ["waveform"], shuffle=False)
riaa_biquad_op = audio.RiaaBiquad(88200)
# Filtered waveform by riaabiquad
dataset = dataset.map(input_columns=["waveform"], operations=riaa_biquad_op)
i = 0
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
count_unequal_element(expect_waveform[i, :], item['waveform'], 0.0001, 0.0001)
i += 1
def test_riaa_biquad_invalid_parameter():
def test_invalid_input(test_name, sample_rate, error, error_msg):
logger.info("Test RiaaBiquad with bad input: {0}".format(test_name))
with pytest.raises(error) as error_info:
audio.RiaaBiquad(sample_rate)
assert error_msg in str(error_info.value)
test_invalid_input("invalid sample_rate parameter type as a float", 44100.5, TypeError,
"Argument sample_rate with value 44100.5 is not of type [<class 'int'>],"
" but got <class 'float'>.")
test_invalid_input("invalid sample_rate parameter type as a String", "44100", TypeError,
"Argument sample_rate with value 44100 is not of type [<class 'int'>],"
+ " but got <class 'str'>.")
test_invalid_input("invalid sample_rate parameter value", 45670, ValueError,
"sample_rate should be one of [44100, 48000, 88200, 96000], but got 45670.")
if __name__ == "__main__":
test_riaa_biquad_eager()
test_riaa_biquad_pipeline()
test_riaa_biquad_invalid_parameter()