From 9b7d7be65cba909e713b2cd95fa0d3a4f40ae3bd Mon Sep 17 00:00:00 2001 From: zx <2597798649@qq.com> Date: Sat, 31 Jul 2021 12:07:25 +0800 Subject: [PATCH] [feat][assistant][I3J6UU] add new audio operator Overdrive --- mindspore/ccsrc/minddata/dataset/api/audio.cc | 14 +++ .../dataset/audio/kernels/ir/bindings.cc | 12 +++ .../dataset/audio/ir/kernels/CMakeLists.txt | 1 + .../dataset/audio/ir/kernels/overdrive_ir.cc | 47 ++++++++ .../dataset/audio/ir/kernels/overdrive_ir.h | 57 ++++++++++ .../dataset/audio/kernels/CMakeLists.txt | 1 + .../dataset/audio/kernels/audio_utils.h | 74 +++++++++++++ .../dataset/audio/kernels/overdrive_op.cc | 58 ++++++++++ .../dataset/audio/kernels/overdrive_op.h | 48 +++++++++ .../minddata/dataset/include/dataset/audio.h | 21 ++++ .../minddata/dataset/kernels/tensor_op.h | 1 + mindspore/dataset/audio/transforms.py | 29 ++++- mindspore/dataset/audio/validators.py | 15 +++ .../ut/cpp/dataset/c_api_audio_a_to_q_test.cc | 71 +++++++++++++ tests/ut/cpp/dataset/execute_test.cc | 50 +++++++++ tests/ut/python/dataset/test_overdrive.py | 100 ++++++++++++++++++ 16 files changed, 598 insertions(+), 1 deletion(-) create mode 100644 mindspore/ccsrc/minddata/dataset/audio/ir/kernels/overdrive_ir.cc create mode 100644 mindspore/ccsrc/minddata/dataset/audio/ir/kernels/overdrive_ir.h create mode 100644 mindspore/ccsrc/minddata/dataset/audio/kernels/overdrive_op.cc create mode 100644 mindspore/ccsrc/minddata/dataset/audio/kernels/overdrive_op.h create mode 100644 tests/ut/python/dataset/test_overdrive.py diff --git a/mindspore/ccsrc/minddata/dataset/api/audio.cc b/mindspore/ccsrc/minddata/dataset/api/audio.cc index ce4393ad6bb..ee6053f57fb 100644 --- a/mindspore/ccsrc/minddata/dataset/api/audio.cc +++ b/mindspore/ccsrc/minddata/dataset/api/audio.cc @@ -38,6 +38,7 @@ #include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h" #include "minddata/dataset/audio/ir/kernels/magphase_ir.h" #include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h" 
+#include "minddata/dataset/audio/ir/kernels/overdrive_ir.h" #include "minddata/dataset/audio/ir/kernels/riaa_biquad_ir.h" #include "minddata/dataset/audio/ir/kernels/time_masking_ir.h" #include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h" @@ -400,6 +401,19 @@ std::shared_ptr MuLawDecoding::Parse() { return std::make_shared(data_->quantization_channels_); } +// Overdrive Transform Operation. +struct Overdrive::Data { + Data(float gain, float color) : gain_(gain), color_(color) {} + float gain_; + float color_; +}; + +Overdrive::Overdrive(float gain, float color) : data_(std::make_shared(gain, color)) {} + +std::shared_ptr Overdrive::Parse() { + return std::make_shared(data_->gain_, data_->color_); +} + // RiaaBiquad Transform Operation. struct RiaaBiquad::Data { explicit Data(int32_t sample_rate) : sample_rate_(sample_rate) {} diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc index 3b36c63f88a..01a7ff4d1d1 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc @@ -42,6 +42,7 @@ #include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h" #include "minddata/dataset/audio/ir/kernels/magphase_ir.h" #include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h" +#include "minddata/dataset/audio/ir/kernels/overdrive_ir.h" #include "minddata/dataset/audio/ir/kernels/riaa_biquad_ir.h" #include "minddata/dataset/audio/ir/kernels/time_masking_ir.h" #include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h" @@ -325,6 +326,17 @@ PYBIND_REGISTER( })); })); +PYBIND_REGISTER(OverdriveOperation, 1, ([](const py::module *m) { + (void) + py::class_>( + *m, "OverdriveOperation") + .def(py::init([](float gain, float color) { + auto overdrive = std::make_shared(gain, color); + 
THROW_IF_ERROR(overdrive->ValidateParams()); + return overdrive; + })); + })); + PYBIND_REGISTER( RiaaBiquadOperation, 1, ([](const py::module *m) { (void)py::class_>( diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt index a8ba9a6d4ad..18bf2821d28 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt @@ -24,6 +24,7 @@ add_library(audio-ir-kernels OBJECT lowpass_biquad_ir.cc magphase_ir.cc mu_law_decoding_ir.cc + overdrive_ir.cc riaa_biquad_ir.cc time_masking_ir.cc time_stretch_ir.cc diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/overdrive_ir.cc b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/overdrive_ir.cc new file mode 100644 index 00000000000..456bc0964f5 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/overdrive_ir.cc @@ -0,0 +1,47 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "minddata/dataset/audio/ir/kernels/overdrive_ir.h" + +#include "minddata/dataset/audio/kernels/overdrive_op.h" +#include "minddata/dataset/kernels/ir/validators.h" + +namespace mindspore { +namespace dataset { +namespace audio { +OverdriveOperation::OverdriveOperation(float gain, float color) : gain_(gain), color_(color) {} + +Status OverdriveOperation::ValidateParams() { + RETURN_IF_NOT_OK(ValidateScalar("Overdrive", "gain", gain_, {0.0f, 100.0f}, false, false)); + RETURN_IF_NOT_OK(ValidateScalar("Overdrive", "color", color_, {0.0f, 100.0f}, false, false)); + return Status::OK(); +} + +std::shared_ptr OverdriveOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(gain_, color_); + return tensor_op; +} + +Status OverdriveOperation::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["gain"] = gain_; + args["color"] = color_; + *out_json = args; + return Status::OK(); +} +} // namespace audio +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/overdrive_ir.h b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/overdrive_ir.h new file mode 100644 index 00000000000..65ef1bed1dc --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/overdrive_ir.h @@ -0,0 +1,57 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_OVERDRIVE_IR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_OVERDRIVE_IR_H_ + +#include +#include +#include +#include +#include + +#include "include/api/status.h" +#include "minddata/dataset/include/dataset/constants.h" +#include "minddata/dataset/include/dataset/transforms.h" +#include "minddata/dataset/kernels/ir/tensor_operation.h" + +namespace mindspore { +namespace dataset { +namespace audio { +constexpr char kOverdriveOperation[] = "Overdrive"; + +class OverdriveOperation : public TensorOperation { + public: + explicit OverdriveOperation(float gain, float color); + + ~OverdriveOperation() = default; + + std::shared_ptr Build() override; + + Status ValidateParams() override; + + std::string Name() const override { return kOverdriveOperation; } + + Status to_json(nlohmann::json *out_json) override; + + private: + float gain_; + float color_; +}; +} // namespace audio +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_OVERDRIVE_IR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt index 9ab8ce91234..231afdd4570 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt @@ -25,6 +25,7 @@ add_library(audio-kernels OBJECT lowpass_biquad_op.cc magphase_op.cc mu_law_decoding_op.cc + overdrive_op.cc riaa_biquad_op.cc time_masking_op.cc time_stretch_op.cc diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h index 6a514157046..43a9e584049 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h @@ -318,6 +318,80 @@ Status ComplexNorm(const std::shared_ptr &input, std::shared_ptr /// \return 
Status code. Status MuLawDecoding(const std::shared_ptr &input, std::shared_ptr *output, int quantization_channels); +/// \brief Apply an overdrive effect to the audio. +/// \param input Tensor of shape <..., time>; all leading dimensions are processed as independent rows. +/// \param output Tensor with the same shape as input, every value clamped to [-1, 1]. +/// \param gain Overdrive gain in dB, converted to a linear factor before use. +/// \param color Offset added to the amplified signal after being divided by 200. +/// \return Status code. +template <typename T> +Status Overdrive(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, float gain, float color) { + TensorShape input_shape = input->shape(); + // input->2D. NOTE(review): Reshape modifies the caller's tensor in place - confirm that is intended. + auto cols = input_shape[-1]; + CHECK_FAIL_RETURN_UNEXPECTED(cols > 0, "Overdrive: the time dimension of input tensor should not be 0."); + auto rows = input->Size() / cols; + RETURN_IF_NOT_OK(input->Reshape(TensorShape({rows, cols}))); + // convert gain from dB to a linear factor: 10^(gain / 20). + float gain_ex = exp(gain * log(10) / 20.0); + color = color / 200; + // declare the array used to store the input. + std::vector<T> input_vec; + // out_vec is used to save the result of applying overdrive. + std::vector<T> out_vec; + // temp holds the soft-clipped samples that feed the core loop. + std::vector<T> temp; + // amplify and offset every sample, then soft-clip it. + for (auto itr = input->begin<T>(); itr != input->end<T>(); itr++) { + // keep the original sample for the final mix. + T temp_fp = *itr; + input_vec.push_back(temp_fp); + // use 0 to initialize out_vec. + out_vec.push_back(0); + T temp_fp2 = temp_fp * gain_ex + color; + // soft clip: x - x^3 / 3 inside [-1, 1], saturating at -2/3 and 2/3 outside. + if (temp_fp2 < -1) { + temp.push_back(-2.0 / 3.0); + } else if (temp_fp2 > 1) { + temp.push_back(2.0 / 3.0); + } else { + temp.push_back(temp_fp2 - temp_fp2 * temp_fp2 * temp_fp2 / 3.0); + } + } + // last_in and last_out hold the filter state of each row, so they need one slot per row (not per sample). + std::vector<T> last_in; + std::vector<T> last_out; + for (size_t i = 0; i < rows; i++) { + last_in.push_back(0.0); + last_out.push_back(0.0); + } + // overdrive core loop.
+ for (size_t i = 0; i < cols; i++) { + size_t index = 0; + // index walks the rows; j addresses sample i of row index. + for (size_t j = i; j < rows * cols; j += cols, index++) { + // recurrence per row: out = in - previous_in + 0.995 * previous_out. + last_out[index] = temp[j] - last_in[index] + last_out[index] * 0.995; + last_in[index] = temp[j]; + // mix the original and the processed sample; 0.5 + 2/3 * 0.75 = 1. + T temp_fp = input_vec[j] * 0.5 + last_out[index] * 0.75; + // clamp min=-1, max=1. + if (temp_fp < -1) { + out_vec[j] = -1.0; + } else if (temp_fp > 1) { + out_vec[j] = 1.0; + } else { + out_vec[j] = temp_fp; + } + } + } + // move data to output tensor. + std::shared_ptr<Tensor> out; + RETURN_IF_NOT_OK(Tensor::CreateFromVector(out_vec, input_shape, &out)); + *output = out; + return Status::OK(); +} + /// \brief Add a fade in and/or fade out to an input. /// \param[in] input: The input tensor. /// \param[out] output: Added fade in and/or fade out audio with the same shape. diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/overdrive_op.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/overdrive_op.cc new file mode 100644 index 00000000000..c2659dec697 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/overdrive_op.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#include "minddata/dataset/audio/kernels/overdrive_op.h" + +#include "minddata/dataset/audio/kernels/audio_utils.h" +#include "minddata/dataset/kernels/data/data_utils.h" + +namespace mindspore { +namespace dataset { +OverdriveOp::OverdriveOp(float gain, float color) : gain_(gain), color_(color) {} + +Status OverdriveOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + TensorShape input_shape = input->shape(); + // check input tensor dimension, it should be greater than 0. + CHECK_FAIL_RETURN_UNEXPECTED(input_shape.Size() > 0, "Overdrive: input tensor is not in shape of <..., time>."); + // check input type, it should be numeric. + CHECK_FAIL_RETURN_UNEXPECTED( + input->type().IsNumeric(), + "Overdrive: input tensor type should be int, float or double, but got: " + input->type().ToString()); + std::shared_ptr input_tensor; + if (input->type() != DataType::DE_FLOAT64) { + RETURN_IF_NOT_OK(TypeCast(input, &input_tensor, DataType(DataType::DE_FLOAT32))); + return Overdrive(input_tensor, output, gain_, color_); + } else { + input_tensor = input; + return Overdrive(input_tensor, output, gain_, color_); + } +} + +Status OverdriveOp::OutputType(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); + if (!inputs[0].IsNumeric()) { + RETURN_STATUS_UNEXPECTED("Overdrive: input tensor type should be int, float or double, but got: " + + inputs[0].ToString()); + } else if (inputs[0] == DataType(DataType::DE_FLOAT64)) { + outputs[0] = DataType(DataType::DE_FLOAT64); + } else { + outputs[0] = DataType(DataType::DE_FLOAT32); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/overdrive_op.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/overdrive_op.h new file mode 100644 index 00000000000..acb2a25db49 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/overdrive_op.h @@ 
-0,0 +1,48 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_OVERDRIVE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_OVERDRIVE_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +class OverdriveOp : public TensorOp { + public: + explicit OverdriveOp(float gain, float color); + + ~OverdriveOp() override = default; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kOverdriveOp; } + + Status OutputType(const std::vector &inputs, std::vector &outputs) override; + + private: + float gain_; + float color_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_OVERDRIVE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h b/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h index db021f24374..43413684df9 100644 --- a/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h +++ b/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h @@ -543,6 +543,27 @@ class MuLawDecoding final : public TensorTransform { std::shared_ptr data_; }; +/// \brief Overdrive TensorTransform. 
+class Overdrive final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] gain Coefficient of overload in dB, in range of [0, 100] (Default: 20.0). + /// \param[in] color Coefficient of translation, in range of [0, 100] (Default: 20.0). + explicit Overdrive(float gain = 20.0f, float color = 20.0f); + + /// \brief Destructor. + ~Overdrive() = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + /// \brief Apply RIAA vinyl playback equalization. class RiaaBiquad final : public TensorTransform { public: diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h index 106e2ba4bf8..aac5958b163 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h @@ -165,6 +165,7 @@ constexpr char kLFilterOp[] = "LFilterOp"; constexpr char kLowpassBiquadOp[] = "LowpassBiquadOp"; constexpr char kMagphaseOp[] = "MagphaseOp"; constexpr char kMuLawDecodingOp[] = "MuLawDecodingOp"; +constexpr char kOverdriveOp[] = "OverdriveOp"; constexpr char kRiaaBiquadOp[] = "RiaaBiquadOp"; constexpr char kTimeMaskingOp[] = "TimeMaskingOp"; constexpr char kTimeStretchOp[] = "TimeStretchOp"; diff --git a/mindspore/dataset/audio/transforms.py b/mindspore/dataset/audio/transforms.py index db96dc4cb1a..6539346f615 100644 --- a/mindspore/dataset/audio/transforms.py +++ b/mindspore/dataset/audio/transforms.py @@ -28,7 +28,7 @@ from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_ check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_contrast, check_dc_shift, \ check_deemph_biquad, check_detect_pitch_frequency, check_equalizer_biquad, check_fade, check_flanger, \ check_highpass_biquad, 
check_lfilter, check_lowpass_biquad, check_magphase, check_masking, check_mu_law_decoding, \ - check_riaa_biquad, check_time_stretch, check_treble_biquad, check_vol + check_overdrive, check_riaa_biquad, check_time_stretch, check_treble_biquad, check_vol class AudioTensorOperation(TensorOperation): @@ -719,6 +719,33 @@ class MuLawDecoding(AudioTensorOperation): return cde.MuLawDecodingOperation(self.quantization_channels) +class Overdrive(AudioTensorOperation): + """ + Apply overdrive on input audio. + + Args: + gain (float): Desired gain at the boost (or attenuation) in dB, in range of [0, 100] (default=20.0). + color (float): Controls the amount of even harmonic content in the over-driven output, + in range of [0, 100] (default=20.0). + + Examples: + >>> import numpy as np + >>> + >>> waveform = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) + >>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"]) + >>> transforms = [audio.Overdrive()] + >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"]) + """ + + @check_overdrive + def __init__(self, gain=20.0, color=20.0): + self.gain = gain + self.color = color + + def parse(self): + return cde.OverdriveOperation(self.gain, self.color) + + class RiaaBiquad(AudioTensorOperation): """ Apply RIAA vinyl playback equalization. Similar to SoX implementation. 
diff --git a/mindspore/dataset/audio/validators.py b/mindspore/dataset/audio/validators.py index 282e443cd0b..3cd3e091cdf 100644 --- a/mindspore/dataset/audio/validators.py +++ b/mindspore/dataset/audio/validators.py @@ -292,6 +292,21 @@ def check_mu_law_decoding(method): return new_method +def check_overdrive(method): + """Wrapper method to check the parameters of Overdrive.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [gain, color], _ = parse_user_args(method, *args, **kwargs) + type_check(gain, (float, int), "gain") + check_value(gain, [0, 100], "gain") + type_check(color, (float, int), "color") + check_value(color, [0, 100], "color") + return method(self, *args, **kwargs) + + return new_method + + def check_riaa_biquad(method): """Wrapper method to check the parameters of RiaaBiquad.""" diff --git a/tests/ut/cpp/dataset/c_api_audio_a_to_q_test.cc b/tests/ut/cpp/dataset/c_api_audio_a_to_q_test.cc index 11ac3498a27..7156ec22954 100644 --- a/tests/ut/cpp/dataset/c_api_audio_a_to_q_test.cc +++ b/tests/ut/cpp/dataset/c_api_audio_a_to_q_test.cc @@ -955,6 +955,77 @@ TEST_F(MindDataTestPipeline, TestMuLawDecodingWrongArgs) { EXPECT_EQ(iter1, nullptr); } +/// Feature: Overdrive +/// Description: test basic usage of Overdrive +/// Expectation: get correct number of data +TEST_F(MindDataTestPipeline, TestOverdriveBasic) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestOverdriveBasic."; + // Original waveform + std::shared_ptr schema = Schema(); + ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 200})); + std::shared_ptr ds = RandomData(50, schema); + EXPECT_NE(ds, nullptr); + + ds = ds->SetNumWorkers(4); + EXPECT_NE(ds, nullptr); + + auto OverdriveOp = audio::Overdrive(); + + ds = ds->Map({OverdriveOp}); + EXPECT_NE(ds, nullptr); + + // Create an iterator over the mapped dataset; it is dereferenced below, so a null iterator must stop the test. + std::shared_ptr iter = ds->CreateIterator(); + ASSERT_NE(iter, nullptr); + + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + +
std::vector expected = {2, 200}; + + int i = 0; + while (row.size() != 0) { + auto col = row["waveform"]; + ASSERT_EQ(col.Shape(), expected); + ASSERT_EQ(col.Shape().size(), 2); + ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32); + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + EXPECT_EQ(i, 50); + iter->Stop(); +} + +/// Feature: Overdrive +/// Description: test invalid parameter of Overdrive +/// Expectation: throw exception correctly +TEST_F(MindDataTestPipeline, TestOverdriveWrongArg) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestOverdriveWrongArg."; + std::shared_ptr schema = Schema(); + // Original waveform + ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 2})); + std::shared_ptr ds = RandomData(50, schema); + std::shared_ptr ds01; + std::shared_ptr ds02; + EXPECT_NE(ds, nullptr); + + // Check gain out of range [0,100] + MS_LOG(INFO) << "gain is less than 0."; + auto overdrive_op_01 = audio::Overdrive(-0.2, 20.0); + ds01 = ds->Map({overdrive_op_01}); + EXPECT_NE(ds01, nullptr); + std::shared_ptr iter01 = ds01->CreateIterator(); + EXPECT_EQ(iter01, nullptr); + + // Check color out of range [0,100] + MS_LOG(INFO) << "color is greater than 100."; + auto overdrive_op_02 = audio::Overdrive(20.0, 102.3); + ds02 = ds->Map({overdrive_op_02}); + EXPECT_NE(ds02, nullptr); + std::shared_ptr iter02 = ds02->CreateIterator(); + EXPECT_EQ(iter02, nullptr); +} + TEST_F(MindDataTestPipeline, TestLfilterPipeline) { MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLfilterPipeline."; // Original waveform diff --git a/tests/ut/cpp/dataset/execute_test.cc b/tests/ut/cpp/dataset/execute_test.cc index 9e5070e679f..e9713f0527e 100644 --- a/tests/ut/cpp/dataset/execute_test.cc +++ b/tests/ut/cpp/dataset/execute_test.cc @@ -876,6 +876,56 @@ TEST_F(MindDataTestExecute, TestMuLawDecodingEager) { EXPECT_TRUE(s01.IsOk()); } +/// Feature: Overdrive +/// Description: test basic usage of Overdrive +/// Expectation: get correct
number of data +TEST_F(MindDataTestExecute, TestOverdriveBasicWithEager) { + MS_LOG(INFO) << "Doing MindDataTestExecute-TestOverdriveBasicWithEager."; + // Original waveform + std::vector labels = { + 2.716064453125000000e-03, 6.347656250000000000e-03, 9.246826171875000000e-03, 1.089477539062500000e-02, + 1.138305664062500000e-02, 1.156616210937500000e-02, 1.394653320312500000e-02, 1.550292968750000000e-02, + 1.614379882812500000e-02, 1.840209960937500000e-02, 1.718139648437500000e-02, 1.599121093750000000e-02, + 1.647949218750000000e-02, 1.510620117187500000e-02, 1.385498046875000000e-02, 1.345825195312500000e-02, + 1.419067382812500000e-02, 1.284790039062500000e-02, 1.052856445312500000e-02, 9.368896484375000000e-03}; + std::shared_ptr input; + ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 10}), &input)); + auto input_02 = mindspore::MSTensor(std::make_shared(input)); + std::shared_ptr overdrive_op_01 = std::make_shared(5.0, 3.0); + mindspore::dataset::Execute Transform01({overdrive_op_01}); + Status s01 = Transform01(input_02, &input_02); + EXPECT_TRUE(s01.IsOk()); +} + +/// Feature: Overdrive +/// Description: test invalid parameter of Overdrive +/// Expectation: throw exception correctly +TEST_F(MindDataTestExecute, TestOverdriveWrongArgWithEager) { + MS_LOG(INFO) << "Doing MindDataTestExecute-TestOverdriveWrongArgWithEager"; + std::vector labels = { + 0.271, 1.634, 9.246, 0.108, + 1.138, 1.156, 3.394, 1.55, + 3.614, 1.8402, 0.718, 4.599, + 5.64, 2.510620117187500000e-02, 1.38, 5.825, + 4.1906, 5.28, 1.052, 9.36}; + std::shared_ptr input; + ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({4, 5}), &input)); + + // verify the gain range from 0 to 100 + auto input_01 = mindspore::MSTensor(std::make_shared(input)); + std::shared_ptr overdrive_op1 = std::make_shared(100.1); + mindspore::dataset::Execute Transform01({overdrive_op1}); + Status s01 = Transform01(input_01, &input_01); + EXPECT_FALSE(s01.IsOk()); + + // verify the color range from 0 to 100
+ auto input_02 = mindspore::MSTensor(std::make_shared(input)); + std::shared_ptr overdrive_op2 = std::make_shared(5.0, 100.1); + mindspore::dataset::Execute Transform02({overdrive_op2}); + Status s02 = Transform02(input_02, &input_02); + EXPECT_FALSE(s02.IsOk()); +} + TEST_F(MindDataTestExecute, TestRiaaBiquadWithEager) { MS_LOG(INFO) << "Doing MindDataTestExecute-TestRiaaBiquadWithEager."; // Original waveform diff --git a/tests/ut/python/dataset/test_overdrive.py b/tests/ut/python/dataset/test_overdrive.py new file mode 100644 index 00000000000..b1cfa61801f --- /dev/null +++ b/tests/ut/python/dataset/test_overdrive.py @@ -0,0 +1,100 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import numpy as np +import pytest + +import mindspore.dataset as ds +import mindspore.dataset.audio.transforms as audio +from mindspore import log as logger + + +def count_unequal_element(data_expected, data_me, rtol, atol): + assert data_expected.shape == data_me.shape + total_count = len(data_expected.flatten()) + error = np.abs(data_expected - data_me) + greater = np.greater(error, atol + np.abs(data_expected) * rtol) + loss_count = np.count_nonzero(greater) + assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format( + data_expected[greater], data_me[greater], error[greater]) + + +def test_overdrive_eager(): + """ + Feature: Overdrive + Description: test Overdrive in eager mode + Expectation: the results are as expected + """ + # Original waveform + waveform = np.array([[1.47, 4.722, 5.863], [0.492, 0.235, 0.56]], dtype=np.float32) + # Expect waveform + expect_waveform = np.array([[1., 1., 1.], + [0.74600005, 0.615, 0.77501255]], dtype=np.float32) + overdrive_op = audio.Overdrive() + # Filtered waveform by overdrive + output = overdrive_op(waveform) + count_unequal_element(expect_waveform, output, 0.0001, 0.0001) + + +def test_overdrive_pipeline(): + """ + Feature: Overdrive + Description: test Overdrive in pipeline mode + Expectation: the results are as expected + """ + # Original waveform + waveform = np.array([[0.1, 0.2], [0.4, 2.6]], dtype=np.float32) + # Expect waveform + expect_waveform = np.array([[0.29598799, 0.52081579], + [0.7, 1.]], dtype=np.float32) + dataset = ds.NumpySlicesDataset(waveform, ["waveform"], shuffle=False) + overdrive_op = audio.Overdrive(10.0, 5.0) + # Filtered waveform by overdrive + dataset = dataset.map( + input_columns=["waveform"], operations=overdrive_op) + i = 0 + for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True): + count_unequal_element(expect_waveform[i, :], + 
item['waveform'], 0.0001, 0.0001) + i += 1 + + +def test_overdrive_invalid_input(): + """ + Feature: Overdrive + Description: test invalid parameter of Overdrive + Expectation: catch exceptions correctly + """ + def test_invalid_input(test_name, gain, color, error, error_msg): + logger.info("Test Overdrive with bad input: {0}".format(test_name)) + with pytest.raises(error) as error_info: + audio.Overdrive(gain, color) + assert error_msg in str(error_info.value) + + test_invalid_input("invalid gain parameter type as a str", "20", 20, TypeError, + "Argument gain with value 20 is not of type [<class 'float'>, <class 'int'>]," + + " but got <class 'str'>.") + test_invalid_input("invalid color parameter type as a str", 10, "5", TypeError, + "Argument color with value 5 is not of type [<class 'float'>, <class 'int'>]," + + " but got <class 'str'>.") + test_invalid_input("invalid gain out of range [0, 100]", 100.23, 5.0, ValueError, + "Input gain is not within the required interval of [0, 100].") + test_invalid_input("invalid color out of range [0, 100]", 30, -0.333, ValueError, + "Input color is not within the required interval of [0, 100].") + + +if __name__ == "__main__": + test_overdrive_eager() + test_overdrive_pipeline() + test_overdrive_invalid_input()