diff --git a/mindspore/ccsrc/minddata/dataset/api/audio.cc b/mindspore/ccsrc/minddata/dataset/api/audio.cc
index a02610af6c6..9838e3cd9dc 100644
--- a/mindspore/ccsrc/minddata/dataset/api/audio.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/audio.cc
@@ -26,6 +26,7 @@
 #include "minddata/dataset/audio/ir/kernels/biquad_ir.h"
 #include "minddata/dataset/audio/ir/kernels/complex_norm_ir.h"
 #include "minddata/dataset/audio/ir/kernels/contrast_ir.h"
+#include "minddata/dataset/audio/ir/kernels/db_to_amplitude_ir.h"
 #include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
 #include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h"
 #include "minddata/dataset/audio/ir/kernels/detect_pitch_frequency_ir.h"
@@ -197,6 +198,19 @@ std::shared_ptr<TensorOperation> Contrast::Parse() {
   return std::make_shared<ContrastOperation>(data_->enhancement_amount_);
 }
 
+// DBToAmplitude Transform Operation.
+struct DBToAmplitude::Data {
+  explicit Data(float ref, float power) : ref_(ref), power_(power) {}
+  float ref_;
+  float power_;
+};
+
+DBToAmplitude::DBToAmplitude(float ref, float power) : data_(std::make_shared<Data>(ref, power)) {}
+
+std::shared_ptr<TensorOperation> DBToAmplitude::Parse() {
+  return std::make_shared<DBToAmplitudeOperation>(data_->ref_, data_->power_);
+}
+
 // DCShift Transform Operation.
 struct DCShift::Data {
   Data(float shift, float limiter_gain) : shift_(shift), limiter_gain_(limiter_gain) {}
diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc
index d9886a579b9..fdb9ce3abb6 100644
--- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc
@@ -30,6 +30,7 @@
 #include "minddata/dataset/audio/ir/kernels/biquad_ir.h"
 #include "minddata/dataset/audio/ir/kernels/complex_norm_ir.h"
 #include "minddata/dataset/audio/ir/kernels/contrast_ir.h"
+#include "minddata/dataset/audio/ir/kernels/db_to_amplitude_ir.h"
 #include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
 #include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h"
 #include "minddata/dataset/audio/ir/kernels/detect_pitch_frequency_ir.h"
@@ -170,6 +171,17 @@ PYBIND_REGISTER(ContrastOperation, 1, ([](const py::module *m) {
                     }));
                 }));
 
+PYBIND_REGISTER(
+  DBToAmplitudeOperation, 1, ([](const py::module *m) {
+    (void)py::class_<audio::DBToAmplitudeOperation, TensorOperation, std::shared_ptr<audio::DBToAmplitudeOperation>>(
+      *m, "DBToAmplitudeOperation")
+      .def(py::init([](float ref, float power) {
+        auto db_to_amplitude = std::make_shared<audio::DBToAmplitudeOperation>(ref, power);
+        THROW_IF_ERROR(db_to_amplitude->ValidateParams());
+        return db_to_amplitude;
+      }));
+  }));
+
 PYBIND_REGISTER(DCShiftOperation, 1, ([](const py::module *m) {
                   (void)py::class_<audio::DCShiftOperation, TensorOperation, std::shared_ptr<audio::DCShiftOperation>>(
                     *m, "DCShiftOperation")
diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt
index 17b4057003e..806daef0d33 100644
--- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt
+++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt
@@ -12,6 +12,7 @@ add_library(audio-ir-kernels OBJECT
     biquad_ir.cc
     complex_norm_ir.cc
     contrast_ir.cc
+    db_to_amplitude_ir.cc
     dc_shift_ir.cc
     deemph_biquad_ir.cc
     detect_pitch_frequency_ir.cc
diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/db_to_amplitude_ir.cc b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/db_to_amplitude_ir.cc
new file mode 100644
index 00000000000..a5af6db30ed
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/db_to_amplitude_ir.cc
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/dataset/audio/ir/kernels/db_to_amplitude_ir.h"
+
+#include "minddata/dataset/audio/kernels/db_to_amplitude_op.h"
+
+namespace mindspore {
+namespace dataset {
+namespace audio {
+// DBToAmplitudeOperation
+DBToAmplitudeOperation::DBToAmplitudeOperation(float ref, float power) : ref_(ref), power_(power) {}
+
+Status DBToAmplitudeOperation::ValidateParams() { return Status::OK(); }
+
+std::shared_ptr<TensorOp> DBToAmplitudeOperation::Build() {
+  std::shared_ptr<DBToAmplitudeOp> tensor_op = std::make_shared<DBToAmplitudeOp>(ref_, power_);
+  return tensor_op;
+}
+
+Status DBToAmplitudeOperation::to_json(nlohmann::json *out_json) {
+  nlohmann::json args;
+  args["ref"] = ref_;
+  args["power"] = power_;
+  *out_json = args;
+  return Status::OK();
+}
+}  // namespace audio
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/db_to_amplitude_ir.h b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/db_to_amplitude_ir.h
new file mode 100644
index 00000000000..ff17bdab1ac
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/db_to_amplitude_ir.h
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DB_TO_AMPLITUDE_IR_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DB_TO_AMPLITUDE_IR_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "include/api/status.h"
+#include "minddata/dataset/include/dataset/constants.h"
+#include "minddata/dataset/include/dataset/transforms.h"
+#include "minddata/dataset/kernels/ir/tensor_operation.h"
+
+namespace mindspore {
+namespace dataset {
+namespace audio {
+constexpr char kDBToAmplitudeOperation[] = "DBToAmplitude";
+
+class DBToAmplitudeOperation : public TensorOperation {
+ public:
+  DBToAmplitudeOperation(float ref, float power);
+
+  ~DBToAmplitudeOperation() = default;
+
+  std::shared_ptr<TensorOp> Build() override;
+
+  Status ValidateParams() override;
+
+  std::string Name() const override { return kDBToAmplitudeOperation; }
+
+  Status to_json(nlohmann::json *out_json) override;
+
+ private:
+  float ref_;
+  float power_;
+};
+}  // namespace audio
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DB_TO_AMPLITUDE_IR_H_
diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt
index e946b705e2b..75671138d15 100644
--- a/mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt
+++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt
@@ -13,6 +13,7 @@ add_library(audio-kernels OBJECT
     biquad_op.cc
     complex_norm_op.cc
     contrast_op.cc
+    db_to_amplitude_op.cc
     dc_shift_op.cc
     deemph_biquad_op.cc
     detect_pitch_frequency_op.cc
diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h
index 570b83cc4aa..53e65803362 100644
--- a/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h
+++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h
@@ -152,6 +152,24 @@ Status Contrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
   return Status::OK();
 }
 
+/// \brief Apply DBToAmplitude effect.
+/// \param input/output: Tensor of shape <...,time>.
+/// \param ref: Reference which the output will be scaled by.
+/// \param power: If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude.
+/// \return Status code.
+template <typename T>
+Status DBToAmplitude(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, T ref, T power) {
+  std::shared_ptr<Tensor> out;
+  RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), input->type(), &out));
+  auto itr_out = out->begin<T>();
+  for (auto itr_in = input->begin<T>(); itr_in != input->end<T>(); itr_in++) {
+    *itr_out = ref * pow(pow(10, (*itr_in) * 0.1), power);
+    itr_out++;
+  }
+  *output = out;
+  return Status::OK();
+}
+
 /// \brief Apply a DC shift to the audio.
 /// \param input/output: Tensor of shape <...,time>.
 /// \param shift: the amount to shift the audio.
diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/db_to_amplitude_op.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/db_to_amplitude_op.cc
new file mode 100644
index 00000000000..3e35079110e
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/db_to_amplitude_op.cc
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/dataset/audio/kernels/db_to_amplitude_op.h"
+
+#include "minddata/dataset/audio/kernels/audio_utils.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+Status DBToAmplitudeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  IO_CHECK(input, output);
+  TensorShape input_shape = input->shape();
+  // check input tensor dimension, it should be greater than 0.
+  CHECK_FAIL_RETURN_UNEXPECTED(input_shape.Size() > 0, "DBToAmplitude: input tensor is not in shape of <..., time>.");
+  CHECK_FAIL_RETURN_UNEXPECTED(
+    input->type().IsNumeric(),
+    "DBToAmplitude: input tensor type should be int, float or double, but got " + input->type().ToString());
+
+  std::shared_ptr<Tensor> input_tensor;
+  if (input->type() != DataType(DataType::DE_FLOAT64)) {
+    RETURN_IF_NOT_OK(TypeCast(input, &input_tensor, DataType(DataType::DE_FLOAT32)));
+    return DBToAmplitude<float>(input_tensor, output, ref_, power_);
+  } else {
+    input_tensor = input;
+    return DBToAmplitude<double>(input_tensor, output, ref_, power_);
+  }
+}
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/db_to_amplitude_op.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/db_to_amplitude_op.h
new file mode 100644
index 00000000000..22784dd2aa7
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/db_to_amplitude_op.h
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DB_TO_AMPLITUDE_OP_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DB_TO_AMPLITUDE_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/core/tensor.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+class DBToAmplitudeOp : public TensorOp {
+ public:
+  DBToAmplitudeOp(float ref, float power) : ref_(ref), power_(power) {}
+
+  ~DBToAmplitudeOp() override = default;
+
+  void Print(std::ostream &out) const override {
+    out << Name() << " ref: " << ref_ << ", power: " << power_ << std::endl;
+  }
+
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+
+  std::string Name() const override { return kDBToAmplitudeOp; }
+
+ private:
+  float ref_;
+  float power_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DB_TO_AMPLITUDE_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h b/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h
index 85e7e49adcb..7f2240611f2 100644
--- a/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h
+++ b/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h
@@ -256,6 +256,27 @@ class Contrast final : public TensorTransform {
   std::shared_ptr<Data> data_;
 };
 
+/// \brief Turn a waveform from the decibel scale to the power/amplitude scale.
+class DBToAmplitude final : public TensorTransform {
+ public:
+  /// \brief Constructor.
+  /// \param[in] ref Reference which the output will be scaled by.
+  /// \param[in] power If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude.
+  explicit DBToAmplitude(float ref, float power);
+
+  /// \brief Destructor.
+  ~DBToAmplitude() = default;
+
+ protected:
+  /// \brief Function to convert TensorTransform object into a TensorOperation object.
+  /// \return Shared pointer to TensorOperation object.
+  std::shared_ptr<TensorOperation> Parse() override;
+
+ private:
+  struct Data;
+  std::shared_ptr<Data> data_;
+};
+
 /// \brief Apply a DC shift to the audio.
 class DCShift : public TensorTransform {
  public:
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h
index 6709c57de90..e803ec6df7b 100644
--- a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h
+++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h
@@ -154,6 +154,7 @@ constexpr char kBassBiquadOp[] = "BassBiquadOp";
 constexpr char kBiquadOp[] = "BiquadOp";
 constexpr char kComplexNormOp[] = "ComplexNormOp";
 constexpr char kContrastOp[] = "ContrastOp";
+constexpr char kDBToAmplitudeOp[] = "DBToAmplitudeOp";
 constexpr char kDCShiftOp[] = "DCShiftOp";
 constexpr char kDeemphBiquadOp[] = "DeemphBiquadOp";
 constexpr char kDetectPitchFrequencyOp[] = "DetectPitchFrequencyOp";
diff --git a/mindspore/dataset/audio/transforms.py b/mindspore/dataset/audio/transforms.py
index 3337689174a..7a62ba0841b 100644
--- a/mindspore/dataset/audio/transforms.py
+++ b/mindspore/dataset/audio/transforms.py
@@ -25,10 +25,11 @@ import mindspore._c_dataengine as cde
 from ..transforms.c_transforms import TensorOperation
 from .utils import FadeShape, GainType, Interpolation, Modulation, ScaleType
 from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
-    check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_contrast, check_dc_shift, \
-    check_deemph_biquad, check_detect_pitch_frequency, check_equalizer_biquad, check_fade, check_flanger, \
-    check_highpass_biquad, check_lfilter, check_lowpass_biquad, check_magphase, check_masking, check_mu_law_coding, \
-    check_overdrive, check_phaser, check_riaa_biquad, check_time_stretch, check_treble_biquad, check_vol
+    check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_contrast, \
+    check_db_to_amplitude, check_dc_shift, check_deemph_biquad, check_detect_pitch_frequency, check_equalizer_biquad, \
+    check_fade, check_flanger, check_highpass_biquad, check_lfilter, check_lowpass_biquad, check_magphase, \
+    check_masking, check_mu_law_coding, check_overdrive, check_phaser, check_riaa_biquad, check_time_stretch, \
+    check_treble_biquad, check_vol
 
 
 class AudioTensorOperation(TensorOperation):
@@ -330,6 +331,32 @@ class Contrast(AudioTensorOperation):
         return cde.ContrastOperation(self.enhancement_amount)
 
 
+class DBToAmplitude(AudioTensorOperation):
+    """
+    Turn a waveform from the decibel scale to the power/amplitude scale.
+
+    Args:
+        ref (float): Reference which the output will be scaled by.
+        power (float): If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude.
+
+    Examples:
+        >>> import numpy as np
+        >>>
+        >>> waveform = np.array([[2.716064453125e-03, 6.34765625e-03], [9.246826171875e-03, 1.0894775390625e-02]])
+        >>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
+        >>> transforms = [audio.DBToAmplitude(0.5, 0.5)]
+        >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
+    """
+
+    @check_db_to_amplitude
+    def __init__(self, ref, power):
+        self.ref = ref
+        self.power = power
+
+    def parse(self):
+        return cde.DBToAmplitudeOperation(self.ref, self.power)
+
+
 class DCShift(AudioTensorOperation):
     """
     Apply a DC shift to the audio.
diff --git a/mindspore/dataset/audio/validators.py b/mindspore/dataset/audio/validators.py
index f5f04bb5d94..eb108f09d05 100644
--- a/mindspore/dataset/audio/validators.py
+++ b/mindspore/dataset/audio/validators.py
@@ -195,6 +195,21 @@ def check_contrast(method):
     return new_method
 
 
+def check_db_to_amplitude(method):
+    """Wrapper method to check the parameters of db_to_amplitude."""
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        [ref, power], _ = parse_user_args(method, *args, **kwargs)
+        type_check(ref, (float, int), "ref")
+        check_float32(ref, "ref")
+        type_check(power, (float, int), "power")
+        check_float32(power, "power")
+        return method(self, *args, **kwargs)
+
+    return new_method
+
+
 def check_dc_shift(method):
     """Wrapper method to check the parameters of DCShift."""
diff --git a/tests/ut/cpp/dataset/c_api_audio_a_to_q_test.cc b/tests/ut/cpp/dataset/c_api_audio_a_to_q_test.cc
index df85e9de8d9..795d106353b 100644
--- a/tests/ut/cpp/dataset/c_api_audio_a_to_q_test.cc
+++ b/tests/ut/cpp/dataset/c_api_audio_a_to_q_test.cc
@@ -1916,3 +1916,44 @@ TEST_F(MindDataTestPipeline, TestCreateDctWrongArg) {
   Status s04 = audio::CreateDct(&output, 200, -400, NormMode::kOrtho);
   EXPECT_FALSE(s04.IsOk());
 }
+
+/// Feature: DBToAmplitude
+/// Description: test DBToAmplitude in pipeline mode
+/// Expectation: the data is processed successfully
+TEST_F(MindDataTestPipeline, TestDBToAmplitudePipeline) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDBToAmplitudePipeline.";
+  // Original waveform
+  std::shared_ptr<SchemaObj> schema = Schema();
+  ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 200}));
+  std::shared_ptr<Dataset> ds = RandomData(50, schema);
+  EXPECT_NE(ds, nullptr);
+
+  ds = ds->SetNumWorkers(4);
+  EXPECT_NE(ds, nullptr);
+
+  auto db_to_amplitude_op = audio::DBToAmplitude(2, 2);
+
+  ds = ds->Map({db_to_amplitude_op});
+  EXPECT_NE(ds, nullptr);
+
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(ds, nullptr);
+
+  std::unordered_map<std::string, mindspore::MSTensor> row;
+  ASSERT_OK(iter->GetNextRow(&row));
+
+  std::vector<int64_t> expected = {2, 200};
+
+  int i = 0;
+  while (row.size() != 0) {
+    auto col = row["waveform"];
+    ASSERT_EQ(col.Shape(), expected);
+    ASSERT_EQ(col.Shape().size(), 2);
+    ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
+    ASSERT_OK(iter->GetNextRow(&row));
+    i++;
+  }
+  EXPECT_EQ(i, 50);
+
+  iter->Stop();
+}
diff --git a/tests/ut/cpp/dataset/execute_test.cc b/tests/ut/cpp/dataset/execute_test.cc
index 772e0fa194f..c7368614450 100644
--- a/tests/ut/cpp/dataset/execute_test.cc
+++ b/tests/ut/cpp/dataset/execute_test.cc
@@ -1480,3 +1480,25 @@ TEST_F(MindDataTestExecute, TestFlangerWithWrongArg) {
   Status s01 = Transform01(input_02, &input_02);
   EXPECT_FALSE(s01.IsOk());
 }
+
+// Feature: DBToAmplitude
+// Description: test DBToAmplitude in eager mode
+// Expectation: the data is processed successfully
+TEST_F(MindDataTestExecute, TestDBToAmplitudeWithEager) {
+  MS_LOG(INFO) << "Doing MindDataTestExecute-TestDBToAmplitudeWithEager.";
+  // Original waveform
+  std::vector<float> labels = {
+    2.716064453125000000e-03, 6.347656250000000000e-03, 9.246826171875000000e-03, 1.089477539062500000e-02,
+    1.138305664062500000e-02, 1.156616210937500000e-02, 1.394653320312500000e-02, 1.550292968750000000e-02,
+    1.614379882812500000e-02, 1.840209960937500000e-02, 1.718139648437500000e-02, 1.599121093750000000e-02,
+    1.647949218750000000e-02, 1.510620117187500000e-02, 1.385498046875000000e-02, 1.345825195312500000e-02,
+    1.419067382812500000e-02, 1.284790039062500000e-02, 1.052856445312500000e-02, 9.368896484375000000e-03};
+  std::shared_ptr<Tensor> input;
+  ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 10}), &input));
+  auto input_02 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
+  std::shared_ptr<TensorTransform> DBToAmplitude_01 = std::make_shared<audio::DBToAmplitude>(2, 2);
+  mindspore::dataset::Execute Transform01({DBToAmplitude_01});
+  // Filtered waveform by DBToAmplitude
+  Status s01 = Transform01(input_02, &input_02);
+  EXPECT_TRUE(s01.IsOk());
+}
diff --git a/tests/ut/python/dataset/test_db_to_amplitude.py b/tests/ut/python/dataset/test_db_to_amplitude.py
new file mode 100644
index 00000000000..9fa57dcb2d3
--- /dev/null
+++ b/tests/ut/python/dataset/test_db_to_amplitude.py
@@ -0,0 +1,105 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import numpy as np
+import pytest
+
+import mindspore.dataset as ds
+import mindspore.dataset.audio.transforms as audio
+from mindspore import log as logger
+
+
+def count_unequal_element(data_expected, data_me, rtol, atol):
+    assert data_expected.shape == data_me.shape
+    total_count = len(data_expected.flatten())
+    error = np.abs(data_expected - data_me)
+    greater = np.greater(error, atol + np.abs(data_expected) * rtol)
+    loss_count = np.count_nonzero(greater)
+    assert (loss_count / total_count) < rtol, \
+        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
+        format(data_expected[greater], data_me[greater], error[greater])
+
+
+def test_db_to_amplitude_eager():
+    """
+    Feature: DBToAmplitude
+    Description: test DBToAmplitude in eager mode
+    Expectation: the data is processed successfully
+    """
+    logger.info("mindspore eager mode normal testcase:DBToAmplitude op")
+
+    # Original waveform
+    waveform = np.array([1, 2, 3, 4, 5, 6], dtype=np.float64)
+    # Expect waveform
+    expect_waveform = np.array([3.1698, 5.0238, 7.9621, 12.6191, 20.0000, 31.6979], dtype=np.float64)
+    DBToAmplitude_op = audio.DBToAmplitude(2, 2)
+    # Filtered waveform by DBToAmplitude
+    output = DBToAmplitude_op(waveform)
+    count_unequal_element(expect_waveform, output, 0.0001, 0.0001)
+
+
+def test_db_to_amplitude_pipeline():
+    """
+    Feature: DBToAmplitude
+    Description: test DBToAmplitude in pipeline mode
+    Expectation: the data is processed successfully
+    """
+    logger.info("mindspore pipeline mode normal testcase:DBToAmplitude op")
+
+    # Original waveform
+    waveform = np.array([[2, 2, 3], [0.1, 0.2, 0.3]], dtype=np.float64)
+    # Expect waveform
+    expect_waveform = np.array([[2.5119, 2.5119, 3.9811],
+                                [1.0471, 1.0965, 1.1482]], dtype=np.float64)
+
+    dataset = ds.NumpySlicesDataset(waveform, ["audio"], shuffle=False)
+    DBToAmplitude_op = audio.DBToAmplitude(1, 2)
+    # Filtered waveform by DBToAmplitude
+    dataset = dataset.map(input_columns=["audio"], operations=DBToAmplitude_op, num_parallel_workers=8)
+    i = 0
+    for item in dataset.create_dict_iterator(output_numpy=True):
+        count_unequal_element(expect_waveform[i, :], item['audio'], 0.0001, 0.0001)
+        i += 1
+
+
+def test_db_to_amplitude_invalid_input():
+    """
+    Feature: DBToAmplitude
+    Description: test param check of DBToAmplitude
+    Expectation: throw correct error and message
+    """
+    logger.info("mindspore eager mode invalid input testcase:DBToAmplitude op")
+
+    def test_invalid_input(test_name, ref, power, error, error_msg):
+        logger.info("Test DBToAmplitude with bad input: {0}".format(test_name))
+        with pytest.raises(error) as error_info:
+            audio.DBToAmplitude(ref, power)
+        assert error_msg in str(error_info.value)
+
+    test_invalid_input("invalid ref parameter type as a String", "1.0", 1.0, TypeError,
+                       "Argument ref with value 1.0 is not of type [<class 'float'>, <class 'int'>],"
+                       " but got <class 'str'>.")
+    test_invalid_input("invalid ref parameter value", 122323242445423534543, 1.0, ValueError,
+                       "Input ref is not within the required interval of [-16777216, 16777216].")
+    test_invalid_input("invalid power parameter type as a String", 1.0, "1.0", TypeError,
+                       "Argument power with value 1.0 is not of type [<class 'float'>, <class 'int'>],"
+                       " but got <class 'str'>.")
+    test_invalid_input("invalid power parameter value", 1.0, 1343454254325445, ValueError,
+                       "Input power is not within the required interval of [-16777216, 16777216].")
+
+
+if __name__ == "__main__":
+    test_db_to_amplitude_eager()
+    test_db_to_amplitude_pipeline()
+    test_db_to_amplitude_invalid_input()
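
For reference, the kernel added in audio_utils.h applies output = ref * (10 ** (0.1 * x)) ** power element-wise. Below is a minimal NumPy sketch of that formula, not part of the patch; the helper name db_to_amplitude_reference is made up here, and it is only meant to reproduce the expected values used in test_db_to_amplitude_eager.

import numpy as np


def db_to_amplitude_reference(waveform, ref, power):
    # Element-wise dB to power/amplitude conversion: ref * (10^(0.1 * x))^power.
    waveform = np.asarray(waveform, dtype=np.float64)
    return ref * np.power(np.power(10.0, 0.1 * waveform), power)


# With ref=2 and power=2 this matches the eager test expectation,
# [3.1698, 5.0238, 7.9621, 12.6191, 20.0000, 31.6979] (to 4 decimal places).
print(np.round(db_to_amplitude_reference([1, 2, 3, 4, 5, 6], 2, 2), 4))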