From 9e62623b4a075fa7defb23757b9e3527c2aad95a Mon Sep 17 00:00:00 2001
From: chenx2ovo <417994009@qq.com>
Date: Tue, 31 Aug 2021 23:24:29 +0800
Subject: [PATCH] [feat][assistant][I3CEGI] add op vol

---
 mindspore/ccsrc/minddata/dataset/api/audio.cc        |  14 ++
 .../dataset/audio/kernels/ir/bindings.cc             |  18 ++
 .../dataset/audio/ir/kernels/CMakeLists.txt          |   1 +
 .../dataset/audio/ir/kernels/vol_ir.cc               |  57 ++++++
 .../dataset/audio/ir/kernels/vol_ir.h                |  55 ++++++
 .../dataset/audio/kernels/CMakeLists.txt             |   1 +
 .../dataset/audio/kernels/audio_utils.h              |  50 +++++-
 .../minddata/dataset/audio/kernels/vol_op.cc         |  55 ++++++
 .../minddata/dataset/audio/kernels/vol_op.h          |  47 +++++
 .../minddata/dataset/include/dataset/audio.h         |  23 +++
 .../dataset/include/dataset/constants.h              |   7 +
 .../minddata/dataset/kernels/tensor_op.h             |   1 +
 mindspore/dataset/audio/transforms.py                |  54 ++++--
 mindspore/dataset/audio/utils.py                     |  19 ++-
 mindspore/dataset/audio/validators.py                |  23 ++-
 .../ut/cpp/dataset/c_api_audio_r_to_z_test.cc        |  62 ++++++-
 tests/ut/cpp/dataset/execute_test.cc                 |  27 ++-
 tests/ut/python/dataset/test_vol.py                  | 160 ++++++++++++++++++
 18 files changed, 648 insertions(+), 26 deletions(-)
 create mode 100644 mindspore/ccsrc/minddata/dataset/audio/ir/kernels/vol_ir.cc
 create mode 100644 mindspore/ccsrc/minddata/dataset/audio/ir/kernels/vol_ir.h
 create mode 100644 mindspore/ccsrc/minddata/dataset/audio/kernels/vol_op.cc
 create mode 100644 mindspore/ccsrc/minddata/dataset/audio/kernels/vol_op.h
 create mode 100644 tests/ut/python/dataset/test_vol.py

diff --git a/mindspore/ccsrc/minddata/dataset/api/audio.cc b/mindspore/ccsrc/minddata/dataset/api/audio.cc
index d7c5606748c..ec1af151891 100755
--- a/mindspore/ccsrc/minddata/dataset/api/audio.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/audio.cc
@@ -37,6 +37,7 @@
 #include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h"
 #include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
 #include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
+#include "minddata/dataset/audio/ir/kernels/vol_ir.h"
 
 namespace mindspore {
 namespace dataset {
@@ -360,6 +361,19 @@ TimeStretch::TimeStretch(float hop_length, int n_freq, float fixed_rate)
 std::shared_ptr<TensorOperation> TimeStretch::Parse() {
   return std::make_shared<TimeStretchOperation>(data_->hop_length_, data_->n_freq_, data_->fixed_rate_);
 }
+
+// Vol Transform Operation.
+struct Vol::Data {
+  Data(float gain, GainType gain_type) : gain_(gain), gain_type_(gain_type) {}
+  float gain_;
+  GainType gain_type_;
+};
+
+Vol::Vol(float gain, GainType gain_type) : data_(std::make_shared<Data>(gain, gain_type)) {}
+
+std::shared_ptr<TensorOperation> Vol::Parse() {
+  return std::make_shared<VolOperation>(data_->gain_, data_->gain_type_);
+}
 }  // namespace audio
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc
index 33ff7134758..0aa8149eb14 100644
--- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc
+++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc
@@ -41,6 +41,7 @@
 #include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h"
 #include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
 #include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
+#include "minddata/dataset/audio/ir/kernels/vol_ir.h"
 
 namespace mindspore {
 namespace dataset {
@@ -292,5 +293,22 @@
                   }));
                 }));
 
+PYBIND_REGISTER(VolOperation, 1, ([](const py::module *m) {
+                  (void)py::class_<audio::VolOperation, TensorOperation, std::shared_ptr<audio::VolOperation>>(
+                    *m, "VolOperation")
+                    .def(py::init([](float gain, GainType gain_type) {
+                      auto vol = std::make_shared<audio::VolOperation>(gain, gain_type);
+                      THROW_IF_ERROR(vol->ValidateParams());
+                      return vol;
+                    }));
+                }));
+
+PYBIND_REGISTER(GainType, 0, ([](const py::module *m) {
+                  (void)py::enum_<GainType>(*m, "GainType", py::arithmetic())
+                    .value("DE_GAINTYPE_AMPLITUDE", GainType::kAmplitude)
+                    .value("DE_GAINTYPE_POWER", GainType::kPower)
+                    .value("DE_GAINTYPE_DB", GainType::kDb)
+                    .export_values();
+                }));
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt
index a052b32b01c..cd54df25de4 100755
--- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt
+++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt
@@ -23,5 +23,6 @@ add_library(audio-ir-kernels OBJECT
     mu_law_decoding_ir.cc
     time_masking_ir.cc
     time_stretch_ir.cc
+    vol_ir.cc
     )
diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/vol_ir.cc b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/vol_ir.cc
new file mode 100644
index 00000000000..6fe751ffd8d
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/vol_ir.cc
@@ -0,0 +1,57 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "minddata/dataset/audio/ir/kernels/vol_ir.h"
+
+#include "minddata/dataset/audio/kernels/vol_op.h"
+
+namespace mindspore {
+namespace dataset {
+namespace audio {
+
+// Vol
+VolOperation::VolOperation(float gain, GainType gain_type) : gain_(gain), gain_type_(gain_type) {}
+
+VolOperation::~VolOperation() = default;
+
+std::string VolOperation::Name() const { return kVolOperation; }
+
+Status VolOperation::ValidateParams() {
+  CHECK_FAIL_RETURN_UNEXPECTED(
+    !(gain_type_ == GainType::kPower && gain_ <= 0),
+    "Vol: gain must be greater than 0 when gain_type is Power, but got: " + std::to_string(gain_));
+
+  CHECK_FAIL_RETURN_UNEXPECTED(
+    !(gain_type_ == GainType::kAmplitude && gain_ < 0),
+    "Vol: gain must be greater than or equal to 0 when gain_type is Amplitude, but got: " + std::to_string(gain_));
+  return Status::OK();
+}
+
+std::shared_ptr<TensorOp> VolOperation::Build() {
+  std::shared_ptr<VolOp> tensor_op = std::make_shared<VolOp>(gain_, gain_type_);
+  return tensor_op;
+}
+
+Status VolOperation::to_json(nlohmann::json *out_json) {
+  nlohmann::json args;
+  args["gain"] = gain_;
+  args["gain_type"] = gain_type_;
+  *out_json = args;
+  return Status::OK();
+}
+}  // namespace audio
+}  // namespace dataset
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/vol_ir.h b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/vol_ir.h
new file mode 100644
index 00000000000..63949404011
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/vol_ir.h
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_VOL_IR_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_VOL_IR_H_
+
+#include <memory>
+#include <string>
+
+#include "include/api/status.h"
+#include "minddata/dataset/include/dataset/constants.h"
+#include "minddata/dataset/kernels/ir/tensor_operation.h"
+
+namespace mindspore {
+namespace dataset {
+namespace audio {
+
+constexpr char kVolOperation[] = "Vol";
+
+class VolOperation : public TensorOperation {
+ public:
+  VolOperation(float gain, GainType gain_type);
+
+  ~VolOperation();
+
+  std::shared_ptr<TensorOp> Build() override;
+
+  Status ValidateParams() override;
+
+  std::string Name() const override;
+
+  Status to_json(nlohmann::json *out_json) override;
+
+ private:
+  float gain_;
+  GainType gain_type_;
+};
+
+}  // namespace audio
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_VOL_IR_H_
diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt
index 7e6833f1bb3..3cd09cec7e5 100644
--- a/mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt
+++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt
@@ -24,5 +24,6 @@ add_library(audio-kernels OBJECT
    mu_law_decoding_op.cc
    time_masking_op.cc
    time_stretch_op.cc
+    vol_op.cc
    )
diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h
index 9a54b86d2f9..db0c8f8f2a0 100644
--- a/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h
+++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h
@@ -272,11 +272,11 @@ Status LFilter(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *ou
 }
 
 /// \brief Stretch STFT in time at a given rate, without changing the pitch.
-/// \param[in] input - Tensor of shape <..., freq, time>.
-/// \param[in] rate - Stretch factor.
-/// \param[in] phase_advance - Expected phase advance in each bin.
-/// \param[out] output - Tensor after stretch in time domain.
-/// \return Status return code.
+/// \param input: Tensor of shape <..., freq, time>.
+/// \param rate: Stretch factor.
+/// \param phase_advance: Expected phase advance in each bin.
+/// \param output: Tensor after stretch in time domain.
+/// \return Status code.
 Status TimeStretch(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, float rate, float hop_length,
                    float n_freq);
 
@@ -325,6 +325,46 @@ Status MuLawDecoding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
 Status Fade(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t fade_in_len,
             int32_t fade_out_len, FadeShape fade_shape);
+
+/// \brief Apply amplification or attenuation to a waveform.
+/// \param input/output: Tensor of shape <..., time>.
+/// \param gain: Gain value; its meaning depends on gain_type.
+/// \param gain_type: Type of gain, should be one of [GainType::kAmplitude, GainType::kDb, GainType::kPower].
+/// \return Status code.
+template <typename T>
+Status Vol(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, T gain, GainType gain_type) {
+  const T lower_bound = -1;
+  const T upper_bound = 1;
+
+  // The decibel is a logarithmic unit: an amplitude ratio expressed in decibels is
+  // A(dB) = 20 * log10(A in amplitude), while a ratio of power quantities is
+  // A(dB) = 10 * log10(A in power). Both conversions below therefore use base-10
+  // logarithms with the factors 10 and 20.
+  const int power_factor_div = 20;
+  const int power_factor_mul = 10;
+  const int base = 10;
+
+  if (gain_type == GainType::kDb) {
+    if (gain != 0) {
+      gain = std::pow(base, (gain / power_factor_div));
+    }
+  } else if (gain_type == GainType::kPower) {
+    gain = power_factor_mul * std::log10(gain);
+    gain = std::pow(base, (gain / power_factor_div));
+  }
+
+  for (auto itr = input->begin<T>(); itr != input->end<T>(); itr++) {
+    if (gain != 0 || gain_type == GainType::kAmplitude) {
+      *itr = (*itr) * gain;
+    }
+    *itr = std::min(std::max((*itr), lower_bound), upper_bound);
+  }
+
+  *output = input;
+
+  return Status::OK();
+}
 }  // namespace dataset
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_
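For reference, both the dB and power branches above collapse to a single amplitude multiplier before the clamp to [-1, 1]: a dB gain maps to 10^(gain/20), and a power gain maps through 10 * log10(gain) back into the same formula, which is just sqrt(gain). A minimal NumPy sketch of the same arithmetic; the helper name vol_reference is hypothetical and not part of this patch:

import numpy as np

def vol_reference(waveform, gain, gain_type="amplitude"):
    # Convert the gain into a plain amplitude ratio, as the C++ kernel does.
    if gain_type == "db":
        ratio = 10 ** (gain / 20)                 # A(dB) = 20 * log10(ratio)
    elif gain_type == "power":
        ratio = 10 ** (10 * np.log10(gain) / 20)  # equivalent to sqrt(gain)
    else:
        ratio = gain                              # amplitude: used directly
    # Scale, then clamp every sample to [-1.0, 1.0].
    return np.clip(waveform * ratio, -1.0, 1.0)

print(vol_reference(np.array([0.5, -0.8]), 2.0, "power"))  # sqrt(2) * samples, clamped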
diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/vol_op.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/vol_op.cc
new file mode 100644
index 00000000000..e3f5cfb4b6b
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/vol_op.cc
@@ -0,0 +1,55 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "minddata/dataset/audio/kernels/vol_op.h"
+
+#include "minddata/dataset/audio/kernels/audio_utils.h"
+#include "minddata/dataset/kernels/data/data_utils.h"
+#include "minddata/dataset/util/status.h"
+
+namespace mindspore {
+namespace dataset {
+
+Status VolOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
+  IO_CHECK(input, output);
+  std::shared_ptr<Tensor> input_tensor;
+  TensorShape input_shape = input->shape();
+  CHECK_FAIL_RETURN_UNEXPECTED(input_shape.Size() > 0, "Vol: input tensor is not in shape of <..., time>.");
+  CHECK_FAIL_RETURN_UNEXPECTED(
+    input->type().IsNumeric(),
+    "Vol: input tensor type should be int, float or double, but got: " + input->type().ToString());
+  if (input->type() != DataType::DE_FLOAT64) {
+    RETURN_IF_NOT_OK(TypeCast(input, &input_tensor, DataType(DataType::DE_FLOAT32)));
+    return Vol(input_tensor, output, gain_, gain_type_);
+  } else {
+    input_tensor = input;
+    return Vol(input_tensor, output, static_cast<double>(gain_), gain_type_);
+  }
+}
+
+Status VolOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
+  RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
+  if (!inputs[0].IsNumeric()) {
+    RETURN_STATUS_UNEXPECTED("Vol: input tensor type should be int, float or double, but got: " +
+                             inputs[0].ToString());
+  } else if (inputs[0] == DataType(DataType::DE_FLOAT64)) {
+    outputs[0] = DataType(DataType::DE_FLOAT64);
+  } else {
+    outputs[0] = DataType(DataType::DE_FLOAT32);
+  }
+  return Status::OK();
+}
+
+}  // namespace dataset
+}  // namespace mindspore
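Note the dtype rule encoded in Compute and OutputType: float64 input keeps double precision end to end, while every other numeric dtype is cast to float32 before the gain is applied. A hedged eager-mode sketch of that behavior, using the Python API added later in this patch (the dtype outcomes are inferred from the code above, not from a documented guarantee):

import numpy as np
import mindspore.dataset.audio.transforms as audio
from mindspore.dataset.audio.utils import GainType

op = audio.Vol(gain=0.5, gain_type=GainType.AMPLITUDE)

# float32 (and any other non-double numeric) input comes back as float32 ...
print(op(np.array([0.2, -0.4], dtype=np.float32)).dtype)  # float32

# ... while float64 input stays double precision through the kernel.
print(op(np.array([0.2, -0.4], dtype=np.float64)).dtype)  # float64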
diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/vol_op.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/vol_op.h
new file mode 100644
index 00000000000..74a023367b4
--- /dev/null
+++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/vol_op.h
@@ -0,0 +1,47 @@
+/**
+ * Copyright 2021 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_VOL_OP_H_
+#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_VOL_OP_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "minddata/dataset/include/dataset/constants.h"
+#include "minddata/dataset/kernels/tensor_op.h"
+
+namespace mindspore {
+namespace dataset {
+class VolOp : public TensorOp {
+ public:
+  explicit VolOp(float gain, GainType gain_type = GainType::kAmplitude) : gain_(gain), gain_type_(gain_type) {}
+
+  ~VolOp() override = default;
+
+  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
+
+  std::string Name() const override { return kVolOp; }
+
+  Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
+
+ private:
+  float gain_;
+  GainType gain_type_;
+};
+}  // namespace dataset
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_VOL_OP_H_
diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h b/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h
index c5cd597f1f8..1aadb74923b 100755
--- a/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h
+++ b/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h
@@ -515,6 +515,29 @@ class TimeStretch final : public TensorTransform {
   std::shared_ptr<Data> data_;
 };
 
+/// \brief Vol TensorTransform.
+/// \notes Apply amplification or attenuation to a waveform.
+class Vol final : public TensorTransform {
+ public:
+  /// \brief Constructor.
+  /// \param[in] gain Gain value; its meaning depends on gain_type. If gain_type is GainType::kAmplitude,
+  ///     gain must be greater than or equal to zero. If gain_type is GainType::kPower, gain must be greater
+  ///     than zero. If gain_type is GainType::kDb, there is no restriction on gain.
+  /// \param[in] gain_type Type of gain, should be one of [GainType::kAmplitude, GainType::kDb, GainType::kPower].
+  explicit Vol(float gain, GainType gain_type = GainType::kAmplitude);
+
+  /// \brief Destructor.
+  ~Vol() = default;
+
+ protected:
+  /// \brief Function to convert TensorTransform object into a TensorOperation object.
+  /// \return Shared pointer to TensorOperation object.
+  std::shared_ptr<TensorOperation> Parse() override;
+
+ private:
+  struct Data;
+  std::shared_ptr<Data> data_;
+};
 }  // namespace audio
 }  // namespace dataset
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/constants.h b/mindspore/ccsrc/minddata/dataset/include/dataset/constants.h
index 7140d5c3a61..f73d9d14bd0 100644
--- a/mindspore/ccsrc/minddata/dataset/include/dataset/constants.h
+++ b/mindspore/ccsrc/minddata/dataset/include/dataset/constants.h
@@ -79,6 +79,13 @@ enum class ScaleType {
   kPower = 1,      ///< Audio scale is power.
 };
 
+/// \brief The scale for gain type.
+enum class GainType {
+  kAmplitude = 0,  ///< Audio gain type is amplitude.
+  kPower = 1,      ///< Audio gain type is power.
+  kDb = 2,         ///< Audio gain type is decibel.
+};
+
 /// \brief The method of padding.
 enum class BorderType {
   kConstant = 0,   ///< Fill the border with constant values.
diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h
index d9a8b63dd88..e32a8cece33 100644
--- a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h
+++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h
@@ -160,6 +160,7 @@ constexpr char kLowpassBiquadOp[] = "LowpassBiquadOp";
 constexpr char kMuLawDecodingOp[] = "MuLawDecodingOp";
 constexpr char kTimeMaskingOp[] = "TimeMaskingOp";
 constexpr char kTimeStretchOp[] = "TimeStretchOp";
+constexpr char kVolOp[] = "VolOp";
 
 // data
 constexpr char kConcatenateOp[] = "ConcatenateOp";
diff --git a/mindspore/dataset/audio/transforms.py b/mindspore/dataset/audio/transforms.py
index 30640d98cd4..a78dca34d64 100755
--- a/mindspore/dataset/audio/transforms.py
+++ b/mindspore/dataset/audio/transforms.py
@@ -23,11 +23,11 @@ import numpy as np
 import mindspore._c_dataengine as cde
 
 from ..transforms.c_transforms import TensorOperation
-from .utils import FadeShape, ScaleType
+from .utils import FadeShape, GainType, ScaleType
 from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
     check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_contrast, check_dc_shift, \
     check_deemph_biquad, check_equalizer_biquad, check_fade, check_highpass_biquad, check_lfilter, \
-    check_lowpass_biquad, check_masking, check_mu_law_decoding, check_time_stretch
+    check_lowpass_biquad, check_masking, check_mu_law_decoding, check_time_stretch, check_vol
 
 
 class AudioTensorOperation(TensorOperation):
@@ -644,15 +644,12 @@ class TimeStretch(AudioTensorOperation):
         fixed_rate (float, optional): Rate to speed up or slow down the input in time (default=None).
 
     Examples:
-        >>> freq = 44100
-        >>> num_frame = 30
-        >>> def gen():
-        ...     np.random.seed(0)
-        ...     data = np.random.random([freq, num_frame])
-        ...     yield (np.array(data, dtype=np.float32), )
-        >>> data1 = ds.GeneratorDataset(source=gen, column_names=["multi_dimensional_data"])
-        >>> transforms = [py_audio.TimeStretch()]
-        >>> data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
+        >>> import numpy as np
+        >>>
+        >>> waveform = np.random.random([1, 30])
+        >>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
+        >>> transforms = [audio.TimeStretch()]
+        >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
     """
 
     @check_time_stretch
     def __init__(self, hop_length=None, n_freq=201, fixed_rate=None):
@@ -665,3 +662,38 @@ class TimeStretch(AudioTensorOperation):
 
     def parse(self):
         return cde.TimeStretchOperation(self.hop_length, self.n_freq, self.fixed_rate)
+
+
+DE_C_GAINTYPE_TYPE = {GainType.AMPLITUDE: cde.GainType.DE_GAINTYPE_AMPLITUDE,
+                      GainType.POWER: cde.GainType.DE_GAINTYPE_POWER,
+                      GainType.DB: cde.GainType.DE_GAINTYPE_DB}
+
+
+class Vol(AudioTensorOperation):
+    """
+    Apply amplification or attenuation to the whole waveform.
+
+    Args:
+        gain (float): Value of gain adjustment.
+            If gain_type = amplitude, gain stands for a nonnegative amplitude ratio.
+            If gain_type = power, gain stands for power.
+            If gain_type = db, gain stands for decibels.
+        gain_type (GainType, optional): Type of gain, one of GainType.AMPLITUDE, GainType.POWER
+            or GainType.DB (default=GainType.AMPLITUDE).
+
+    Examples:
+        >>> import numpy as np
+        >>>
+        >>> waveform = np.random.random([20, 30])
+        >>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
+        >>> transforms = [audio.Vol(gain=10, gain_type=GainType.DB)]
+        >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
+    """
+
+    @check_vol
+    def __init__(self, gain, gain_type=GainType.AMPLITUDE):
+        self.gain = gain
+        self.gain_type = gain_type
+
+    def parse(self):
+        return cde.VolOperation(self.gain, DE_C_GAINTYPE_TYPE[self.gain_type])
diff --git a/mindspore/dataset/audio/utils.py b/mindspore/dataset/audio/utils.py
index 1773b134d45..946504bf110 100644
--- a/mindspore/dataset/audio/utils.py
+++ b/mindspore/dataset/audio/utils.py
@@ -19,12 +19,6 @@ enum for audio ops
 from enum import Enum
 
 
-class ScaleType(str, Enum):
-    """Scale Type"""
-    POWER: str = "power"
-    MAGNITUDE: str = "magnitude"
-
-
 class FadeShape(str, Enum):
     """Fade Shape"""
     LINEAR: str = "linear"
@@ -32,3 +26,16 @@ class FadeShape(str, Enum):
     LOGARITHMIC: str = "logarithmic"
     QUARTERSINE: str = "quarter_sine"
     HALFSINE: str = "half_sine"
+
+
+class GainType(str, Enum):
+    """Gain Type"""
+    POWER: str = "power"
+    AMPLITUDE: str = "amplitude"
+    DB: str = "db"
+
+
+class ScaleType(str, Enum):
+    """Scale Type"""
+    POWER: str = "power"
+    MAGNITUDE: str = "magnitude"
diff --git a/mindspore/dataset/audio/validators.py b/mindspore/dataset/audio/validators.py
index a0ec6fd9e8b..7cf70783252 100755
--- a/mindspore/dataset/audio/validators.py
+++ b/mindspore/dataset/audio/validators.py
@@ -21,7 +21,7 @@ from functools import wraps
 from mindspore.dataset.core.validator_helpers import check_float32, check_float32_not_zero, check_int32_not_zero, \
     check_list_same_size, check_non_negative_float32, check_non_negative_int32, check_pos_float32, check_pos_int32, \
     check_value, parse_user_args, type_check
-from .utils import FadeShape, ScaleType
+from .utils import FadeShape, GainType, ScaleType
 
 
 def check_amplitude_to_db(method):
@@ -384,3 +384,24 @@ def check_fade(method):
         return method(self, *args, **kwargs)
 
     return new_method
+
+
+def check_vol(method):
+    """Wrapper method to check the parameters of Vol."""
+
+    @wraps(method)
+    def new_method(self, *args, **kwargs):
+        [gain, gain_type], _ = parse_user_args(method, *args, **kwargs)
+        # type check gain
+        type_check(gain, (int, float), "gain")
+        # type check gain_type and value check gain
+        type_check(gain_type, (GainType,), "gain_type")
+        if gain_type == GainType.AMPLITUDE:
+            check_non_negative_float32(gain, "gain")
+        elif gain_type == GainType.POWER:
+            check_pos_float32(gain, "gain")
+        else:
+            check_float32(gain, "gain")
+        return method(self, *args, **kwargs)
+
+    return new_method
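check_vol mirrors the C++ ValidateParams bounds: gain must be non-negative for amplitude, strictly positive for power, and any float32 for dB. A short sketch of what a caller sees on bad input (the interval text comes from the expected messages in test_vol.py below):

import mindspore.dataset.audio.transforms as audio
from mindspore.dataset.audio.utils import GainType

try:
    audio.Vol(gain=-1.5, gain_type=GainType.POWER)  # power gain must be > 0
except ValueError as e:
    print(e)  # "Input gain is not within the required interval of (0, 16777216]."

vol = audio.Vol(gain=-1.5, gain_type=GainType.DB)   # fine: a dB gain may be negative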
diff --git a/tests/ut/cpp/dataset/c_api_audio_r_to_z_test.cc b/tests/ut/cpp/dataset/c_api_audio_r_to_z_test.cc
index cc833a53654..dd810f8f3e9 100644
--- a/tests/ut/cpp/dataset/c_api_audio_r_to_z_test.cc
+++ b/tests/ut/cpp/dataset/c_api_audio_r_to_z_test.cc
@@ -44,7 +44,7 @@ TEST_F(MindDataTestPipeline, TestTimeMaskingPipeline) {
   ds = ds->Map({timemasking});
   EXPECT_NE(ds, nullptr);
 
-  // Filtered waveform by bandbiquad
+  // Mask the waveform in the time domain
   std::shared_ptr<Iterator> iter = ds->CreateIterator();
   EXPECT_NE(ds, nullptr);
 
@@ -83,7 +83,6 @@ TEST_F(MindDataTestPipeline, TestTimeMaskingWrongArgs) {
   ds = ds->Map({timemasking});
   EXPECT_NE(ds, nullptr);
 
-  // Filtered waveform by bandbiquad
   std::shared_ptr<Iterator> iter = ds->CreateIterator();
   // Expect failure
   EXPECT_EQ(iter, nullptr);
@@ -156,3 +155,62 @@ TEST_F(MindDataTestPipeline, TestTimeStretchPipelineWrongArgs) {
   // Expect failure
   EXPECT_EQ(iter, nullptr);
 }
+
+TEST_F(MindDataTestPipeline, TestVolPipeline) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVolPipeline.";
+  // Original waveform
+  std::shared_ptr<SchemaObj> schema = Schema();
+  ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {2, 200}));
+  std::shared_ptr<Dataset> ds = RandomData(50, schema);
+  EXPECT_NE(ds, nullptr);
+
+  ds = ds->SetNumWorkers(4);
+  EXPECT_NE(ds, nullptr);
+
+  auto vol = audio::Vol(0.3);
+
+  ds = ds->Map({vol});
+  EXPECT_NE(ds, nullptr);
+
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  EXPECT_NE(ds, nullptr);
+
+  std::unordered_map<std::string, mindspore::MSTensor> row;
+  ASSERT_OK(iter->GetNextRow(&row));
+
+  std::vector<int64_t> expected = {2, 200};
+
+  int i = 0;
+  while (row.size() != 0) {
+    auto col = row["inputData"];
+    ASSERT_EQ(col.Shape(), expected);
+    ASSERT_EQ(col.Shape().size(), 2);
+    ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
+    ASSERT_OK(iter->GetNextRow(&row));
+    i++;
+  }
+  EXPECT_EQ(i, 50);
+
+  iter->Stop();
+}
+
+TEST_F(MindDataTestPipeline, TestVolWrongArgs) {
+  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVolWrongArgs.";
+  // Original waveform
+  std::shared_ptr<SchemaObj> schema = Schema();
+  ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {2, 200}));
+  std::shared_ptr<Dataset> ds = RandomData(50, schema);
+  EXPECT_NE(ds, nullptr);
+
+  ds = ds->SetNumWorkers(4);
+  EXPECT_NE(ds, nullptr);
+
+  auto vol_op = audio::Vol(-1.5, GainType::kPower);
+
+  ds = ds->Map({vol_op});
+  EXPECT_NE(ds, nullptr);
+
+  std::shared_ptr<Iterator> iter = ds->CreateIterator();
+  // Expect failure
+  EXPECT_EQ(iter, nullptr);
+}
diff --git a/tests/ut/cpp/dataset/execute_test.cc b/tests/ut/cpp/dataset/execute_test.cc
index d17e63ded42..61a826443d9 100644
--- a/tests/ut/cpp/dataset/execute_test.cc
+++ b/tests/ut/cpp/dataset/execute_test.cc
@@ -1056,4 +1056,29 @@ TEST_F(MindDataTestExecute, TestFadeWithInvalidArg) {
   mindspore::dataset::Execute Transform04({fade4});
   Status s04 = Transform04(input_04, &input_04);
   EXPECT_FALSE(s04.IsOk());
-}
\ No newline at end of file
+}
+
+TEST_F(MindDataTestExecute, TestVolDefaultValue) {
+  MS_LOG(INFO) << "Doing MindDataTestExecute-TestVolDefaultValue.";
+  std::shared_ptr<Tensor> input_tensor_;
+  TensorShape s = TensorShape({2, 6});
+  ASSERT_OK(Tensor::CreateFromVector(
+    std::vector<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}), s, &input_tensor_));
+  auto input_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input_tensor_));
+  std::shared_ptr<TensorTransform> vol_op = std::make_shared<audio::Vol>(0.333);
+  mindspore::dataset::Execute transform({vol_op});
+  Status status = transform(input_tensor, &input_tensor);
+  EXPECT_TRUE(status.IsOk());
+}
+
+TEST_F(MindDataTestExecute, TestVolGainTypePower) {
+  MS_LOG(INFO) << "Doing MindDataTestExecute-TestVolGainTypePower.";
+  std::shared_ptr<Tensor> input_tensor_;
+  TensorShape s = TensorShape({4, 3});
+  ASSERT_OK(Tensor::CreateFromVector(
+    std::vector<float>({4.0f, 5.0f, 3.0f, 5.0f, 4.0f, 6.0f, 6.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f}), s, &input_tensor_));
+  auto input_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input_tensor_));
+  std::shared_ptr<TensorTransform> vol_op = std::make_shared<audio::Vol>(0.2, GainType::kPower);
+  mindspore::dataset::Execute transform({vol_op});
+  Status status = transform(input_tensor, &input_tensor);
+  EXPECT_TRUE(status.IsOk());
+}
diff --git a/tests/ut/python/dataset/test_vol.py b/tests/ut/python/dataset/test_vol.py
new file mode 100644
index 00000000000..ea81181343e
--- /dev/null
+++ b/tests/ut/python/dataset/test_vol.py
@@ -0,0 +1,160 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Testing Vol op in DE
+"""
+import numpy as np
+import pytest
+
+import mindspore.dataset as ds
+import mindspore.dataset.audio.transforms as c_audio
+from mindspore import log as logger
+from mindspore.dataset.audio import utils
+
+
+def count_unequal_element(data_expected, data_me, rtol, atol):
+    """Precision calculation func"""
+    assert data_expected.shape == data_me.shape
+    total_count = len(data_expected.flatten())
+    error = np.abs(data_expected - data_me)
+    greater = np.greater(error, atol + np.abs(data_expected) * rtol)
+    loss_count = np.count_nonzero(greater)
+    assert (loss_count / total_count) < rtol, \
+        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
+        format(data_expected[greater], data_me[greater], error[greater])
+
+
+def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
+    """Precision comparison: fall back to an element-wise check on mismatch"""
+    if np.any(np.isnan(data_expected)):
+        assert np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan)
+    elif not np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan):
+        count_unequal_element(data_expected, data_me, rtol, atol)
+
+
+def test_func_vol_eager():
+    """Eager mode normal testcase: vol op"""
+    logger.info("check vol op output")
+    ndarr_in = np.array([[0.3667, 0.5295, 0.2949, 0.4508, 0.6457, 0.3625, 0.4377, 0.3568],
+                         [0.6488, 0.6525, 0.5140, 0.6725, 0.9261, 0.0609, 0.3910, 0.4608],
+                         [0.0454, 0.0487, 0.6990, 0.1637, 0.5763, 0.1086, 0.5343, 0.4699],
+                         [0.9993, 0.0776, 0.3498, 0.0429, 0.1588, 0.3061, 0.1166, 0.3716],
+                         [0.7625, 0.2410, 0.8888, 0.5027, 0.0913, 0.2520, 0.5625, 0.9873]]).astype(np.float32)
+    # Expected output calculated from benchmark
+    out_expect = np.array([[0.0733, 0.1059, 0.0590, 0.0902, 0.1291, 0.0725, 0.0875, 0.0714],
+                           [0.1298, 0.1305, 0.1028, 0.1345, 0.1852, 0.0122, 0.0782, 0.0922],
+                           [0.0091, 0.0097, 0.1398, 0.0327, 0.1153, 0.0217, 0.1069, 0.0940],
+                           [0.1999, 0.0155, 0.0700, 0.0086, 0.0318, 0.0612, 0.0233, 0.0743],
+                           [0.1525, 0.0482, 0.1778, 0.1005, 0.0183, 0.0504, 0.1125, 0.1975]])
+    op = c_audio.Vol(gain=0.2, gain_type=utils.GainType.AMPLITUDE)
+    out_mindspore = op(ndarr_in)
+    allclose_nparray(out_mindspore, out_expect, 0.0001, 0.0001)
+
+    ndarr_in = np.array([[[-0.5794799327850342, 0.19526369869709015],
+                          [-0.5935744047164917, 0.2948109209537506],
+                          [-0.42077431082725525, 0.04923877865076065]],
+                         [[0.5497273802757263, -0.22815021872520447],
+                          [-0.05891447141766548, -0.16206198930740356],
+                          [-1.4782767295837402, -1.3815662860870361]]]).astype(np.float32)
+    # Expected output calculated from benchmark
+    out_expect = np.array([[[-0.5761537551879883, 0.1941428929567337],
+                            [-0.5901673436164856, 0.2931187152862549],
+                            [-0.41835910081863403, 0.04895615205168724]],
+                           [[0.5465719699859619, -0.22684065997600555],
+                            [-0.0585763081908226, -0.16113176941871643],
+                            [-1.0, -1.0]]])
+    op = c_audio.Vol(gain=-0.05, gain_type=utils.GainType.DB)
+    out_mindspore = op(ndarr_in)
+    allclose_nparray(out_mindspore, out_expect, 0.0001, 0.0001)
+
+    ndarr_in = np.array([[[0.09491927176713943, 0.11639882624149323, -0.1725238710641861, -0.18556903302669525],
+                          [-0.7140364646911621, 1.6223102807998657, 1.6710518598556519, 0.6019048094749451]],
+                         [[-0.8635917901992798, -0.31538113951683044, -0.2209240198135376, 1.3067045211791992],
+                          [-2.0922982692718506, 0.6822009682655334, 0.20066820085048676, 0.006392406765371561]]])
+    # Expected output calculated from benchmark
+    out_expect = np.array([[[0.042449187487363815, 0.05205513536930084, -0.07715501636266708, -0.08298899233341217],
+                            [-0.31932681798934937, 0.7255191802978516, 0.7473170757293701, 0.2691799998283386]],
+                           [[-0.38620999455451965, -0.14104272425174713, -0.09880022704601288, 0.5843760371208191],
+                            [-0.935704231262207, 0.30508953332901, 0.0897415429353714, 0.0028587712440639734]]])
+    op = c_audio.Vol(gain=0.2, gain_type=utils.GainType.POWER)
+    out_mindspore = op(ndarr_in)
+    allclose_nparray(out_mindspore, out_expect, 0.0001, 0.0001)
+
+
+def test_func_vol_pipeline():
+    """Pipeline mode normal testcase: vol op"""
+    logger.info("test vol op with gain_type='power'")
+    data = np.array([[[0.7012, 0.2500, 0.0108],
+                      [0.3617, 0.6367, 0.6096]]]).astype(np.float32)
+    out_expect = np.array([[1.0000, 0.7906, 0.0342],
+                           [1.0000, 1.0000, 1.0000]])
+    data1 = ds.NumpySlicesDataset(data, column_names=["multi_dimensional_data"])
+    transforms = [c_audio.Vol(gain=10, gain_type=utils.GainType.POWER)]
+    data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
+    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
+        out_put = item["multi_dimensional_data"]
+        allclose_nparray(out_put, out_expect, 0.0001, 0.0001)
+
+    logger.info("test vol op with gain_type='amplitude' and datatype='float64'")
+    data = np.array([[[0.9342139979247938, 0.613965955965896, 0.5356328030249583, 0.589909976354571],
+                      [0.7301220295167696, 0.31194499547960186, 0.3982210622160919, 0.20984374897512215],
+                      [0.18619300588033616, 0.9443723899839336, 0.7395507950492876, 0.4904588086175671]]])
+    data = data.astype(np.float64)
+    out_expect = np.array([[0.18684279918670654, 0.12279318571090699, 0.10712655782699586, 0.1179819941520691],
+                           [0.1460244059562683, 0.062388998270034794, 0.07964421510696412, 0.04196875095367432],
+                           [0.03723860085010529, 0.1888744831085205, 0.14791015386581421, 0.09809176325798036]])
+    data1 = ds.NumpySlicesDataset(data, column_names=["multi_dimensional_data"])
+    transforms = [c_audio.Vol(gain=0.2, gain_type=utils.GainType.AMPLITUDE)]
+    data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
+    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
+        out_put = item["multi_dimensional_data"]
+        allclose_nparray(out_put, out_expect, 0.0001, 0.0001)
+
+    logger.info("test vol op with gain_type='db'")
+    data = np.array([[[0.1302, 0.5908, 0.1225, 0.7044],
+                      [0.6405, 0.6540, 0.9908, 0.8605],
+                      [0.7023, 0.0115, 0.8790, 0.5806]]]).astype(np.float32)
+    out_expect = np.array([[0.1096, 0.4971, 0.1031, 0.5927],
+                           [0.5389, 0.5503, 0.8336, 0.7240],
+                           [0.5909, 0.0097, 0.7396, 0.4885]])
+    data1 = ds.NumpySlicesDataset(data, column_names=["multi_dimensional_data"])
+    transforms = [c_audio.Vol(gain=-1.5, gain_type=utils.GainType.DB)]
+    data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
+    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
+        out_put = item["multi_dimensional_data"]
+        allclose_nparray(out_put, out_expect, 0.0001, 0.0001)
+
+
+def test_vol_invalid_input():
+    def test_invalid_input(test_name, gain, gain_type, error, error_msg):
+        logger.info("Test Vol with invalid input: {0}".format(test_name))
+        with pytest.raises(error) as error_info:
+            c_audio.Vol(gain, gain_type)
+        assert error_msg in str(error_info.value)
+
+    test_invalid_input("invalid gain value when gain_type equals 'power'", -1.5, utils.GainType.POWER, ValueError,
+                       "Input gain is not within the required interval of (0, 16777216].")
+    test_invalid_input("invalid gain value when gain_type equals 'amplitude'", -1.5, utils.GainType.AMPLITUDE,
+                       ValueError, "Input gain is not within the required interval of [0, 16777216].")
+    test_invalid_input("invalid gain_type value", 1.5, "TEST", TypeError,
+                       "Argument gain_type with value TEST is not of type [<enum 'GainType'>], but got <class 'str'>.")
+
+
+if __name__ == "__main__":
+    test_func_vol_eager()
+    test_func_vol_pipeline()
+    test_vol_invalid_input()
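As a final sanity check, the dB pipeline expectations above can be reproduced with a few lines of NumPy (an editor-added reference sketch, not part of the patch):

import numpy as np

# First row of the gain_type='db' case in test_func_vol_pipeline.
waveform = np.array([0.1302, 0.5908, 0.1225, 0.7044], dtype=np.float32)
ratio = 10 ** (-1.5 / 20)                    # -1.5 dB -> multiplier of ~0.8414
print(np.clip(waveform * ratio, -1.0, 1.0))  # ~[0.1096, 0.4971, 0.1031, 0.5927]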