diff --git a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt index 6107952a89a..454d33ebb6d 100644 --- a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt @@ -93,6 +93,7 @@ add_dependencies(engine-gnn core) add_dependencies(engine core) add_dependencies(callback core) add_dependencies(audio-kernels core) +add_dependencies(audio-ir core) add_dependencies(audio-ir-kernels core) add_dependencies(text core) add_dependencies(text-kernels core) @@ -156,6 +157,7 @@ set(submodules $ $ $ + $ $ $ $ diff --git a/mindspore/ccsrc/minddata/dataset/api/audio.cc b/mindspore/ccsrc/minddata/dataset/api/audio.cc index 64a43bfaa2b..8cdf1d24869 100644 --- a/mindspore/ccsrc/minddata/dataset/api/audio.cc +++ b/mindspore/ccsrc/minddata/dataset/api/audio.cc @@ -17,6 +17,7 @@ #include "minddata/dataset/include/dataset/audio.h" #include "minddata/dataset/audio/ir/kernels/allpass_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/amplitude_to_db_ir.h" #include "minddata/dataset/audio/ir/kernels/angle_ir.h" #include "minddata/dataset/audio/ir/kernels/band_biquad_ir.h" #include "minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h" @@ -43,11 +44,27 @@ std::shared_ptr AllpassBiquad::Parse() { return std::make_shared(data_->sample_rate_, data_->central_freq_, data_->Q_); } +// AmplitudeToDB Operation. +struct AmplitudeToDB::Data { + Data(ScaleType stype, float ref_value, float amin, float top_db) + : stype_(stype), ref_value_(ref_value), amin_(amin), top_db_(top_db) {} + ScaleType stype_; + float ref_value_; + float amin_; + float top_db_; +}; + +AmplitudeToDB::AmplitudeToDB(ScaleType stype, float ref_value, float amin, float top_db) + : data_(std::make_shared(stype, ref_value, amin, top_db)) {} + +std::shared_ptr AmplitudeToDB::Parse() { + return std::make_shared(data_->stype_, data_->ref_value_, data_->amin_, data_->top_db_); +} + // Angle Transform Operation. 
Angle::Angle() {} std::shared_ptr Angle::Parse() { return std::make_shared(); } - // BandBiquad Transform Operation. struct BandBiquad::Data { Data(int32_t sample_rate, float central_freq, float Q, bool noise) diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc index 6ad774d8cf9..07a2a16ae67 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc @@ -18,6 +18,7 @@ #include "minddata/dataset/api/python/pybind_conversion.h" #include "minddata/dataset/api/python/pybind_register.h" #include "minddata/dataset/audio/ir/kernels/allpass_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/amplitude_to_db_ir.h" #include "minddata/dataset/audio/ir/kernels/angle_ir.h" #include "minddata/dataset/audio/ir/kernels/band_biquad_ir.h" #include "minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h" @@ -39,6 +40,24 @@ PYBIND_REGISTER( })); })); +PYBIND_REGISTER( + AmplitudeToDBOperation, 1, ([](const py::module *m) { + (void)py::class_>( + *m, "AmplitudeToDBOperation") + .def(py::init([](ScaleType stype, float ref_value, float amin, float top_db) { + auto amplitude_to_db = std::make_shared(stype, ref_value, amin, top_db); + THROW_IF_ERROR(amplitude_to_db->ValidateParams()); + return amplitude_to_db; + })); + })); + +PYBIND_REGISTER(ScaleType, 0, ([](const py::module *m) { + (void)py::enum_(*m, "ScaleType", py::arithmetic()) + .value("DE_SCALETYPE_MAGNITUDE", ScaleType::kMagnitude) + .value("DE_SCALETYPE_POWER", ScaleType::kPower) + .export_values(); + })); + PYBIND_REGISTER(AngleOperation, 1, ([](const py::module *m) { (void)py::class_>( *m, "AngleOperation") diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/audio/ir/CMakeLists.txt index 
ceebec399c9..f6f6040e52a 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/CMakeLists.txt @@ -2,3 +2,5 @@ add_subdirectory(kernels) file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) + +add_library(audio-ir OBJECT validators.cc) diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt index 6ebf43fe92b..ebc87851bc2 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt @@ -3,6 +3,7 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE add_library(audio-ir-kernels OBJECT allpass_biquad_ir.cc + amplitude_to_db_ir.cc angle_ir.cc band_biquad_ir.cc bandpass_biquad_ir.cc diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/amplitude_to_db_ir.cc b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/amplitude_to_db_ir.cc new file mode 100644 index 00000000000..80412b1c437 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/amplitude_to_db_ir.cc @@ -0,0 +1,58 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "minddata/dataset/audio/ir/kernels/amplitude_to_db_ir.h" +#include "minddata/dataset/audio/kernels/amplitude_to_db_op.h" + +#include "minddata/dataset/audio/ir/validators.h" + +namespace mindspore { +namespace dataset { +namespace audio { + +// AmplitudeToDB +AmplitudeToDBOperation::AmplitudeToDBOperation(ScaleType stype, float ref_value, float amin, float top_db) + : stype_(stype), ref_value_(ref_value), amin_(amin), top_db_(top_db) {} + +AmplitudeToDBOperation::~AmplitudeToDBOperation() = default; + +std::string AmplitudeToDBOperation::Name() const { return kAmplitudeToDBOperation; } + +Status AmplitudeToDBOperation::ValidateParams() { + RETURN_IF_NOT_OK(CheckFloatScalarNonNegative("AmplitudeToDB", "top_db", top_db_)); + RETURN_IF_NOT_OK(CheckFloatScalarPositive("AmplitudeToDB", "amin", amin_)); + RETURN_IF_NOT_OK(CheckFloatScalarPositive("AmplitudeToDB", "ref_value", ref_value_)); + + return Status::OK(); +} + +std::shared_ptr AmplitudeToDBOperation::Build() { + std::shared_ptr tensor_op = std::make_shared(stype_, ref_value_, amin_, top_db_); + return tensor_op; +} + +Status AmplitudeToDBOperation::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["stype"] = stype_; + args["ref_value"] = ref_value_; + args["amin"] = amin_; + args["top_db"] = top_db_; + *out_json = args; + return Status::OK(); +} +} // namespace audio +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/amplitude_to_db_ir.h b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/amplitude_to_db_ir.h new file mode 100644 index 00000000000..18e080e76b7 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/amplitude_to_db_ir.h @@ -0,0 +1,58 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_AMPLITUDE_TO_DB_IR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_AMPLITUDE_TO_DB_IR_H_ + +#include +#include +#include + +#include "include/api/status.h" +#include "minddata/dataset/include/dataset/constants.h" +#include "minddata/dataset/kernels/ir/tensor_operation.h" + +namespace mindspore { +namespace dataset { +namespace audio { + +constexpr char kAmplitudeToDBOperation[] = "AmplitudeToDB"; + +class AmplitudeToDBOperation : public TensorOperation { + public: + AmplitudeToDBOperation(ScaleType stype, float ref_value, float amin, float top_db); + + ~AmplitudeToDBOperation(); + + std::shared_ptr Build() override; + + Status ValidateParams() override; + + std::string Name() const override; + + Status to_json(nlohmann::json *out_json) override; + + private: + ScaleType stype_; + float ref_value_; + float amin_; + float top_db_; +}; + +} // namespace audio +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_AMPLITUDE_TO_DB_IR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/validators.cc b/mindspore/ccsrc/minddata/dataset/audio/ir/validators.cc new file mode 100644 index 00000000000..f1cdea91dc3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/validators.cc @@ -0,0 +1,82 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/audio/ir/validators.h" + +namespace mindspore { +namespace dataset { +/* ####################################### Validator Functions ############################################ */ +Status CheckFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar) { + RETURN_IF_NOT_OK(CheckScalar(op_name, scalar_name, scalar, {0}, true)); + return Status::OK(); +} + +Status CheckFloatScalarNonNegative(const std::string &op_name, const std::string &scalar_name, float scalar) { + RETURN_IF_NOT_OK(CheckScalar(op_name, scalar_name, scalar, {0}, false)); + return Status::OK(); +} + +Status CheckStringScalarInList(const std::string &op_name, const std::string &scalar_name, const std::string &scalar, + const std::vector &str_vec) { + auto ret = std::find(str_vec.begin(), str_vec.end(), scalar); + if (ret == str_vec.end()) { + std::string interval_description = "["; + for (size_t m = 0; m < str_vec.size(); m++) { + std::string word = str_vec[m]; + interval_description = interval_description + word; + if (m != str_vec.size() - 1) interval_description = interval_description + ", "; + } + interval_description = interval_description + "]"; + + std::string err_msg = op_name + ": " + scalar_name + " must be one of " + interval_description + ", got: " + scalar; + MS_LOG(ERROR) << err_msg; + return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg); + } + return Status::OK(); +} + +template +Status CheckScalar(const std::string &op_name, const std::string &scalar_name, const T scalar, + const 
std::vector &range, bool left_open_interval, bool right_open_interval) { + if (range.empty() || range.size() > 2) { + std::string err_msg = "Range check expecting size 1 or 2, but got: " + std::to_string(range.size()); + MS_LOG(ERROR) << err_msg; + return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg); + } + if ((left_open_interval && scalar <= range[0]) || (!left_open_interval && scalar < range[0])) { + std::string interval_description = left_open_interval ? " greater than " : " greater than or equal to "; + std::string err_msg = op_name + ":" + scalar_name + " must be" + interval_description + std::to_string(range[0]) + + ", got: " + std::to_string(scalar); + MS_LOG(ERROR) << err_msg; + return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg); + } + if (range.size() == 2) { + if ((right_open_interval && scalar >= range[1]) || (!right_open_interval && scalar > range[1])) { + std::string left_bracket = left_open_interval ? "(" : "["; + std::string right_bracket = right_open_interval ? 
")" : "]"; + std::string err_msg = op_name + ":" + scalar_name + " is out of range " + left_bracket + + std::to_string(range[0]) + ", " + std::to_string(range[1]) + right_bracket + + ", got: " + std::to_string(scalar); + MS_LOG(ERROR) << err_msg; + return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg); + } + } + return Status::OK(); +} +template Status CheckScalar(const std::string &op_name, const std::string &scalar_name, const float scalar, + const std::vector &range, bool left_open_interval, bool right_open_interval); + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/validators.h b/mindspore/ccsrc/minddata/dataset/audio/ir/validators.h index 837c3f0a0f4..d966bcee7fb 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/validators.h +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/validators.h @@ -18,8 +18,13 @@ #define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_VALIDATORS_H_ #include +#include #include "minddata/dataset/kernels/ir/validators.h" +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/ir/tensor_operation.h" +#include "minddata/dataset/util/status.h" + namespace mindspore { namespace dataset { @@ -34,6 +39,20 @@ Status CheckScalarNotZero(const std::string &op_name, const std::string &scalar_ return Status::OK(); } +// Helper function to check that a float scalar is positive (greater than 0) +Status CheckFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar); + +// Helper function to check that a float scalar is non-negative (greater than or equal to 0) +Status CheckFloatScalarNonNegative(const std::string &op_name, const std::string &scalar_name, float scalar); + +// Helper function to check that a string scalar is one of the values listed in str_vec +Status CheckStringScalarInList(const std::string &op_name, const std::string &scalar_name, const std::string &scalar, + const std::vector &str_vec); + +// Helper function to check that a scalar lies within range (a lower bound, or a lower and an upper bound) +template +Status CheckScalar(const std::string &op_name, const std::string &scalar_name, const T 
scalar, + const std::vector &range, bool left_open_interval = false, bool right_open_interval = false); } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ADUIO_IR_VALIDATORS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt index dc2b60ef11c..a14896118ac 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt @@ -3,7 +3,9 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE add_library(audio-kernels OBJECT allpass_biquad_op.cc + amplitude_to_db_op.cc angle_op.cc + audio_utils.cc band_biquad_op.cc bandpass_biquad_op.cc bandreject_biquad_op.cc diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/amplitude_to_db_op.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/amplitude_to_db_op.cc new file mode 100644 index 00000000000..dbebec42d39 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/amplitude_to_db_op.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include + +#include "minddata/dataset/audio/kernels/amplitude_to_db_op.h" +#include "minddata/dataset/audio/kernels/audio_utils.h" +#include "minddata/dataset/kernels/data/data_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +Status AmplitudeToDBOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + if (input->shape().Rank() < 2) { + std::string err_msg = "AmplitudeToDB: input tensor shape should be <..., freq, time>"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + + std::shared_ptr input_tensor; + + float top_db = top_db_; + float multiplier = stype_ == ScaleType::kPower ? 10.0 : 20.0; + float amin = amin_;  // clamp with the configured amin, not a fixed 1e-10 + float db_multiplier = std::log10(std::max(amin_, ref_value_)); + + // typecast: any non-float64 input is cast to float32, float64 input is used as is + CHECK_FAIL_RETURN_UNEXPECTED(input->type() != DataType::DE_STRING, + "AmplitudeToDB: input type should be float, but got string."); + if (input->type() != DataType::DE_FLOAT64) { + CHECK_FAIL_RETURN_UNEXPECTED(TypeCast(input, &input_tensor, DataType(DataType::DE_FLOAT32)), + "AmplitudeToDB: input type should be float, but got " + input->type().ToString()); + return AmplitudeToDB(input_tensor, output, multiplier, amin, db_multiplier, top_db); + + } else { + input_tensor = input; + return AmplitudeToDB(input_tensor, output, multiplier, amin, db_multiplier, top_db); + } +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/amplitude_to_db_op.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/amplitude_to_db_op.h new file mode 100644 index 00000000000..bd84e888f9e --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/amplitude_to_db_op.h @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AMPLITUDE_TO_DB_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AMPLITUDE_TO_DB_OP_H_ + +#include +#include +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/include/dataset/constants.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +class AmplitudeToDBOp : public TensorOp { + public: + AmplitudeToDBOp(ScaleType stype, float ref_value, float amin, float top_db) + : stype_(stype), ref_value_(ref_value), amin_(amin), top_db_(top_db) {} + + ~AmplitudeToDBOp() override = default; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + + std::string Name() const override { return kAmplitudeToDBOp; } + + private: + ScaleType stype_; + float ref_value_; + float amin_; + float top_db_; +}; +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AMPLITUDE_TO_DB_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.cc new file mode 100644 index 00000000000..6f23b475192 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/audio/kernels/audio_utils.h" + +namespace mindspore { +namespace dataset { + +template +Status AmplitudeToDB(const std::shared_ptr &input, std::shared_ptr *output, T multiplier, T amin, + T db_multiplier, T top_db) { + TensorShape input_shape = input->shape(); + TensorShape to_shape = input_shape.Rank() == 2 + ? TensorShape({1, 1, input_shape[-2], input_shape[-1]}) + : TensorShape({input->Size() / (input_shape[-3] * input_shape[-2] * input_shape[-1]), + input_shape[-3], input_shape[-2], input_shape[-1]}); + RETURN_IF_NOT_OK(input->Reshape(to_shape)); + + std::vector max_val; + int step = to_shape[-3] * input_shape[-2] * input_shape[-1]; + int cnt = 0; + T temp_max = std::numeric_limits::lowest(); + for (auto itr = input->begin(); itr != input->end(); itr++) { + // do clamp + *itr = *itr < amin ? 
log10(amin) * multiplier : log10(*itr) * multiplier; + *itr -= multiplier * db_multiplier; + // calculate max over each group of step consecutive elements + cnt++; + if ((*itr) > temp_max) temp_max = *itr; + if (cnt % step == 0) { + max_val.push_back(temp_max); + temp_max = std::numeric_limits::lowest(); + } + } + + if (!std::isnan(top_db)) { + int ind = 0; + for (auto itr = input->begin(); itr != input->end(); itr++, ind++) { + T lower_bound = max_val[ind / step] - top_db;  // keep the bound in T so double data is not truncated + *itr = std::max((*itr), static_cast(lower_bound)); + } + } + RETURN_IF_NOT_OK(input->Reshape(input_shape)); + *output = input; + return Status::OK(); +} +template Status AmplitudeToDB(const std::shared_ptr &input, std::shared_ptr *output, + float multiplier, float amin, float db_multiplier, float top_db); +template Status AmplitudeToDB(const std::shared_ptr &input, std::shared_ptr *output, + double multiplier, double amin, double db_multiplier, double top_db); +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h index d2dcd9ad938..755650d0a0c 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h @@ -17,8 +17,11 @@ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_ +#include #include +#include #include +#include #include #include "minddata/dataset/core/tensor.h" @@ -28,6 +31,17 @@ constexpr double PI = 3.141592653589793; namespace mindspore { namespace dataset { +/// \brief Turn a tensor from the power/amplitude scale to the decibel scale. 
+/// \param input/output: Tensor of shape <...,freq,time> +/// \param multiplier: 10 for the power scale, 20 for the magnitude scale +/// \param amin: lower bound; input values below amin are replaced by amin before log10 is taken +/// \param db_multiplier: offset term; multiplier * db_multiplier is subtracted from every value +/// \param top_db: decibels cut-off; values below (max - top_db) are raised to that bound, skipped if NaN +/// \return Status code +template +Status AmplitudeToDB(const std::shared_ptr &input, std::shared_ptr *output, T multiplier, T amin, + T db_multiplier, T top_db); + /// \brief Calculate the angles of the complex numbers /// \param input/output: Tensor of shape <...,time> template @@ -162,7 +176,6 @@ Status LFilter(const std::shared_ptr &input, std::shared_ptr *ou delete m_py; return Status::OK(); } - } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h b/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h index adceff4cff0..ded1501d851 100644 --- a/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h +++ b/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h @@ -17,6 +17,7 @@ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_AUDIO_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_AUDIO_H_ +#include #include #include #include @@ -92,6 +93,31 @@ class AllpassBiquad final : public TensorTransform { std::shared_ptr data_; }; +/// \brief AmplitudeToDB TensorTransform. +/// \notes Turn a tensor from the power/amplitude scale to the decibel scale. +class AmplitudeToDB final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] stype Scale of the input tensor, ScaleType::kPower or ScaleType::kMagnitude. + /// \param[in] ref_value Reference value used to calculate db_multiplier. + /// \param[in] amin Clamp the input waveform + /// \param[in] top_db Decibels cut-off value + explicit AmplitudeToDB(ScaleType stype = ScaleType::kPower, float ref_value = 1.0, float amin = 1e-10, + float top_db = 80.0); + + /// \brief Destructor. 
+ ~AmplitudeToDB() = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + /// \brief Design two-pole band-pass filter. class BandpassBiquad final : public TensorTransform { public: diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/constants.h b/mindspore/ccsrc/minddata/dataset/include/dataset/constants.h index 851ca5637e2..7af6fb81267 100644 --- a/mindspore/ccsrc/minddata/dataset/include/dataset/constants.h +++ b/mindspore/ccsrc/minddata/dataset/include/dataset/constants.h @@ -49,6 +49,12 @@ enum class ShuffleMode { kInfile = 3 ///< Shuffle data within each file. }; +/// \brief Possible scale for input audio. +enum class ScaleType { + kMagnitude = 0, ///< Audio scale is magnitude. + kPower = 1, ///< Audio scale is power. +}; + /// \brief The method of padding. enum class BorderType { kConstant = 0, ///< Fill the border with constant values. diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h index b345994b0d0..dff6016880d 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h @@ -139,6 +139,7 @@ constexpr char kSentencepieceTokenizerOp[] = "SentencepieceTokenizerOp"; // audio constexpr char kAllpassBiquadOp[] = "AllpassBiquadOp"; +constexpr char kAmplitudeToDBOp[] = "AmplitudeToDBOp"; constexpr char kAngleOp[] = "AngleOp"; constexpr char kBandBiquadOp[] = "BandBiquadOp"; constexpr char kBandpassBiquadOp[] = "BandpassBiquadOp"; diff --git a/mindspore/dataset/audio/transforms.py b/mindspore/dataset/audio/transforms.py index 88b9e760645..2589efe17bb 100644 --- a/mindspore/dataset/audio/transforms.py +++ b/mindspore/dataset/audio/transforms.py @@ -20,8 +20,9 @@ to improve their training models. 
import mindspore._c_dataengine as cde import numpy as np from ..transforms.c_transforms import TensorOperation -from .validators import check_allpass_biquad, check_band_biquad, check_bandpass_biquad, check_bandreject_biquad, \ - check_bass_biquad +from .utils import ScaleType +from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \ + check_bandreject_biquad, check_bass_biquad class AudioTensorOperation(TensorOperation): @@ -74,6 +75,42 @@ class AllpassBiquad(AudioTensorOperation): return cde.AllpassBiquadOperation(self.sample_rate, self.central_freq, self.Q) +DE_C_SCALETYPE_TYPE = {ScaleType.MAGNITUDE: cde.ScaleType.DE_SCALETYPE_MAGNITUDE, + ScaleType.POWER: cde.ScaleType.DE_SCALETYPE_POWER} + + +class AmplitudeToDB(AudioTensorOperation): + """ + Converts the input tensor from amplitude/power scale to decibel scale. + + Args: + stype (ScaleType, optional): Scale of the input tensor (Default=ScaleType.POWER). + It can be any of [ScaleType.MAGNITUDE, ScaleType.POWER]. + ref_value (float, optional): Reference value used to calculate db_multiplier (Default=1.0). + amin (float, optional): Lower bound to clamp the input waveform (Default=1e-10). + top_db (float, optional): Minimum cut-off decibels. The range of values is non-negative. Commonly set at 80. 
+ (Default=80.0) + Examples: + >>> channel = 1 + >>> n_fft = 400 + >>> n_frame = 30 + >>> spectrogram = np.random.random([channel, n_fft//2+1, n_frame]) + >>> numpy_slices_dataset = ds.NumpySlicesDataset(data=spectrogram, column_names=["audio"]) + >>> transforms = [audio.AmplitudeToDB(stype=ScaleType.POWER)] + >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"]) + """ + + @check_amplitude_to_db + def __init__(self, stype=ScaleType.POWER, ref_value=1.0, amin=1e-10, top_db=80.0): + self.stype = stype + self.ref_value = ref_value + self.amin = amin + self.top_db = top_db + + def parse(self): + return cde.AmplitudeToDBOperation(DE_C_SCALETYPE_TYPE[self.stype], self.ref_value, self.amin, self.top_db) + + class Angle(AudioTensorOperation): """ Calculate the angle of the complex number sequence of shape (..., 2). diff --git a/mindspore/dataset/audio/utils.py b/mindspore/dataset/audio/utils.py new file mode 100644 index 00000000000..1bf00f2da0d --- /dev/null +++ b/mindspore/dataset/audio/utils.py @@ -0,0 +1,23 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +enum for audio ops +""" +from enum import Enum + + +class ScaleType(str, Enum): + """Scale of the input audio: power or magnitude.""" + POWER: str = "power" + MAGNITUDE: str = "magnitude" diff --git a/mindspore/dataset/audio/validators.py b/mindspore/dataset/audio/validators.py index f70a465dc14..42dc5d88511 100644 --- a/mindspore/dataset/audio/validators.py +++ b/mindspore/dataset/audio/validators.py @@ -16,8 +16,41 @@ Validators for TensorOps. """ from functools import wraps +from mindspore.dataset.core.validator_helpers import check_not_zero, check_int32, check_float32, \ + check_value_normalize_std, check_value_ratio, FLOAT_MAX_INTEGER, parse_user_args, type_check +from .utils import ScaleType -from mindspore.dataset.core.validator_helpers import check_not_zero, check_int32, check_float32, check_value_normalize_std, parse_user_args, type_check + + +def check_amplitude_to_db(method): + """Wrapper method to check the parameters (stype, ref_value, amin, top_db) of amplitude_to_db.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [stype, ref_value, amin, top_db], _ = parse_user_args(method, *args, **kwargs) + + # type check stype + type_check(stype, (ScaleType,), "stype") + + # type check ref_value + type_check(ref_value, (int, float), "ref_value") + # value check ref_value + if ref_value is not None: + check_value_ratio(ref_value, (0, FLOAT_MAX_INTEGER), "ref_value") + + # type check amin + type_check(amin, (int, float), "amin") + # value check amin + if amin is not None: + check_value_ratio(amin, (0, FLOAT_MAX_INTEGER), "amin") + + # type check top_db + type_check(top_db, (int, float), "top_db") + # value check top_db + if top_db is not None: + check_value_ratio(top_db, (0, FLOAT_MAX_INTEGER), "top_db") + + return method(self, *args, **kwargs) + return new_method def check_biquad_sample_rate(sample_rate): diff --git a/tests/ut/cpp/dataset/c_api_audio_a_to_q_test.cc b/tests/ut/cpp/dataset/c_api_audio_a_to_q_test.cc index 2a265455da6..089029ffd13 100644 --- a/tests/ut/cpp/dataset/c_api_audio_a_to_q_test.cc +++ 
b/tests/ut/cpp/dataset/c_api_audio_a_to_q_test.cc @@ -30,6 +30,65 @@ class MindDataTestPipeline : public UT::DatasetOpTesting { protected: }; +TEST_F(MindDataTestPipeline, TestAmplitudeToDBPipeline) { + MS_LOG(INFO) << "Basic Function Test"; + // Random float32 input column "inputData" of shape {2, 200} + std::shared_ptr schema = Schema(); + ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {2, 200})); + std::shared_ptr ds = RandomData(50, schema); + EXPECT_NE(ds, nullptr); + + ds = ds->SetNumWorkers(4); + EXPECT_NE(ds, nullptr); + + auto amplitude_to_db_op = audio::AmplitudeToDB(); + + ds = ds->Map({amplitude_to_db_op}); + EXPECT_NE(ds, nullptr); + + std::shared_ptr iter = ds->CreateIterator(); + ASSERT_NE(iter, nullptr); + + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + + std::vector expected = {2, 200}; + + int i = 0; + while (row.size() != 0) { + auto col = row["inputData"]; + ASSERT_EQ(col.Shape(), expected); + ASSERT_EQ(col.Shape().size(), 2); + ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32); + ASSERT_OK(iter->GetNextRow(&row)); + i++; + } + EXPECT_EQ(i, 50); + + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestAmplitudeToDBWrongArgs) { + MS_LOG(INFO) << "Basic Function Test"; + // Random float32 input column "inputData" of shape {2, 200} + std::shared_ptr schema = Schema(); + ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {2, 200})); + std::shared_ptr ds = RandomData(50, schema); + EXPECT_NE(ds, nullptr); + + ds = ds->SetNumWorkers(4); + EXPECT_NE(ds, nullptr); + + auto amplitude_to_db_op = audio::AmplitudeToDB(ScaleType::kPower, 1.0, -1e-10, 80.0); + + ds = ds->Map({amplitude_to_db_op}); + EXPECT_NE(ds, nullptr); + + std::shared_ptr iter = ds->CreateIterator(); + // Expect failure: amin is negative (-1e-10) + EXPECT_EQ(iter, nullptr); +} + TEST_F(MindDataTestPipeline, Level0_TestBandBiquad001) { MS_LOG(INFO) << "Basic Function Test"; // Original waveform @@ -313,7 +372,7 @@ TEST_F(MindDataTestPipeline, Level0_TestBassBiquad001) { ds = ds->SetNumWorkers(4); 
EXPECT_NE(ds, nullptr); - auto BassBiquadOp = audio::BassBiquad(44100,50,200.0); + auto BassBiquadOp = audio::BassBiquad(44100, 50, 200.0); ds = ds->Map({BassBiquadOp}); EXPECT_NE(ds, nullptr); @@ -353,7 +412,7 @@ TEST_F(MindDataTestPipeline, Level0_TestBassBiquad002) { // Check sample_rate MS_LOG(INFO) << "sample_rate is zero."; - auto bass_biquad_op_01 = audio::BassBiquad(0,50,200.0); + auto bass_biquad_op_01 = audio::BassBiquad(0, 50, 200.0); ds01 = ds->Map({bass_biquad_op_01}); EXPECT_NE(ds01, nullptr); @@ -362,7 +421,7 @@ TEST_F(MindDataTestPipeline, Level0_TestBassBiquad002) { // Check Q_ MS_LOG(INFO) << "Q_ is zero."; - auto bass_biquad_op_02 = audio::BassBiquad(44100,50,200.0,0); + auto bass_biquad_op_02 = audio::BassBiquad(44100, 50, 200.0, 0); ds02 = ds->Map({bass_biquad_op_02}); EXPECT_NE(ds02, nullptr); diff --git a/tests/ut/cpp/dataset/execute_test.cc b/tests/ut/cpp/dataset/execute_test.cc index df63d30d05d..e98c8cd989f 100644 --- a/tests/ut/cpp/dataset/execute_test.cc +++ b/tests/ut/cpp/dataset/execute_test.cc @@ -20,6 +20,7 @@ #include "minddata/dataset/include/dataset/execute.h" #include "minddata/dataset/include/dataset/transforms.h" #include "minddata/dataset/include/dataset/vision.h" +#include "minddata/dataset/include/dataset/audio.h" #include "minddata/dataset/include/dataset/text.h" #include "utils/log_adapter.h" @@ -99,6 +100,65 @@ TEST_F(MindDataTestExecute, TestAdjustGammaEager2) { EXPECT_EQ(rc, Status::OK()); } +TEST_F(MindDataTestExecute, TestAmplitudeToDB) { + MS_LOG(INFO) << "Basic Function Test With Eager."; + // Original waveform + std::vector labels = { + 2.716064453125000000e-03, 6.347656250000000000e-03, 9.246826171875000000e-03, 1.089477539062500000e-02, + 1.138305664062500000e-02, 1.156616210937500000e-02, 1.394653320312500000e-02, 1.550292968750000000e-02, + 1.614379882812500000e-02, 1.840209960937500000e-02, 1.718139648437500000e-02, 1.599121093750000000e-02, + 1.647949218750000000e-02, 1.510620117187500000e-02, 
1.385498046875000000e-02, 1.345825195312500000e-02, + 1.419067382812500000e-02, 1.284790039062500000e-02, 1.052856445312500000e-02, 9.368896484375000000e-03, + 1.419067382812500000e-02, 1.284790039062500000e-02, 1.052856445312500000e-02, 9.368896484375000000e-03}; + std::shared_ptr input; + ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 2, 2, 3}), &input)); + auto input_ms = mindspore::MSTensor(std::make_shared(input)); + std::shared_ptr amplitude_to_db_op = std::make_shared(); + // apply amplitude_to_db + mindspore::dataset::Execute trans({amplitude_to_db_op}); + Status status = trans(input_ms, &input_ms); + EXPECT_TRUE(status.IsOk()); +} + +TEST_F(MindDataTestExecute, TestAmplitudeToDBWrongArgs) { + MS_LOG(INFO) << "Wrong Arg."; + // Original waveform + std::vector labels = { + 2.716064453125000000e-03, 6.347656250000000000e-03, 9.246826171875000000e-03, 1.089477539062500000e-02, + 1.138305664062500000e-02, 1.156616210937500000e-02, 1.394653320312500000e-02, 1.550292968750000000e-02, + 1.614379882812500000e-02, 1.840209960937500000e-02, 1.718139648437500000e-02, 1.599121093750000000e-02, + 1.647949218750000000e-02, 1.510620117187500000e-02, 1.385498046875000000e-02, 1.345825195312500000e-02, + 1.419067382812500000e-02, 1.284790039062500000e-02, 1.052856445312500000e-02, 9.368896484375000000e-03}; + std::shared_ptr input; + ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 10}), &input)); + auto input_ms = mindspore::MSTensor(std::make_shared(input)); + std::shared_ptr amplitude_to_db_op = + std::make_shared(ScaleType::kPower, 1.0, -1e-10, 80.0); + // apply amplitude_to_db + mindspore::dataset::Execute trans({amplitude_to_db_op}); + Status status = trans(input_ms, &input_ms); + EXPECT_FALSE(status.IsOk()); +} + +TEST_F(MindDataTestExecute, TestAmplitudeToDBWrongInput) { + MS_LOG(INFO) << "Wrong Input."; + // Original waveform + std::vector labels = { + 2.716064453125000000e-03, 6.347656250000000000e-03, 9.246826171875000000e-03, 
1.089477539062500000e-02, + 1.138305664062500000e-02, 1.156616210937500000e-02, 1.394653320312500000e-02, 1.550292968750000000e-02, + 1.614379882812500000e-02, 1.840209960937500000e-02, 1.718139648437500000e-02, 1.599121093750000000e-02, + 1.647949218750000000e-02, 1.510620117187500000e-02, 1.385498046875000000e-02, 1.345825195312500000e-02, + 1.419067382812500000e-02, 1.284790039062500000e-02, 1.052856445312500000e-02, 9.368896484375000000e-03}; + std::shared_ptr input; + ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({20}), &input)); + auto input_ms = mindspore::MSTensor(std::make_shared(input)); + std::shared_ptr amplitude_to_db_op = std::make_shared(); + // apply amplitude_to_db + mindspore::dataset::Execute trans({amplitude_to_db_op}); + Status status = trans(input_ms, &input_ms); + EXPECT_FALSE(status.IsOk()); +} + TEST_F(MindDataTestExecute, TestComposeTransforms) { MS_LOG(INFO) << "Doing TestComposeTransforms."; diff --git a/tests/ut/python/dataset/test_amplitude_to_db.py b/tests/ut/python/dataset/test_amplitude_to_db.py new file mode 100644 index 00000000000..448b8b09ef4 --- /dev/null +++ b/tests/ut/python/dataset/test_amplitude_to_db.py @@ -0,0 +1,137 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +""" +Testing AmplitudeToDB op in DE +""" +import numpy as np +import pytest + +import mindspore.dataset as ds +import mindspore.dataset.audio.transforms as c_audio +from mindspore import log as logger +from mindspore.dataset.audio.utils import ScaleType + + +CHANNEL = 1 +FREQ = 20 +TIME = 15 + + +def gen(shape): + np.random.seed(0) + data = np.random.random(shape) + yield(np.array(data, dtype=np.float32),) + + +def _count_unequal_element(data_expected, data_me, rtol, atol): + """ Precision calculation func """ + assert data_expected.shape == data_me.shape + total_count = len(data_expected.flatten()) + error = np.abs(data_expected - data_me) + greater = np.greater(error, atol + np.abs(data_expected) * rtol) + loss_count = np.count_nonzero(greater) + assert (loss_count / total_count) < rtol, \ + "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \ + format(data_expected[greater], data_me[greater], error[greater]) + + +def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True): + """ Precision calculation formula """ + if np.any(np.isnan(data_expected)): + assert np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan) + elif not np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan): + _count_unequal_element(data_expected, data_me, rtol, atol) + else: + assert True + + +def test_func_amplitude_to_db_eager(): + """ mindspore eager mode normal testcase:amplitude_to_db op""" + + logger.info("check amplitude_to_db op output") + ndarr_in = np.array([[[[-0.2197528, 0.3821656]]], + [[[0.57418776, 0.46741104]]], + [[[-0.20381108, -0.9303914]]], + [[[0.3693608, -0.2017813]]], + [[[-1.727381, -1.3708513]]], + [[[1.259975, 0.4981323]]], + [[[0.76986176, -0.5793846]]]]).astype(np.float32) + # cal from benchmark + out_expect = np.array([[[[-84.17748, -4.177484]]], + [[[-2.4094608, -3.3030105]]], + [[[-100., -100.]]], + [[[-4.325492, -84.32549]]], + 
[[[-100., -100.]]], + [[[1.0036192, -3.0265532]]], + [[[-1.1358725, -81.13587]]]]).astype(np.float32) + + amplitude_to_db_op = c_audio.AmplitudeToDB() + out_mindspore = amplitude_to_db_op(ndarr_in) + + allclose_nparray(out_mindspore, out_expect, 0.0001, 0.0001) + + +def test_func_amplitude_to_db_pipeline(): + """ mindspore pipeline mode normal testcase:amplitude_to_db op""" + + logger.info("test AmplitudeToDB op with default value") + generator = gen([CHANNEL, FREQ, TIME]) + + data1 = ds.GeneratorDataset(source=generator, column_names=["multi_dimensional_data"]) + + transforms = [ + c_audio.AmplitudeToDB() + ] + data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"]) + + for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True): + out_put = item["multi_dimensional_data"] + assert out_put.shape == (CHANNEL, FREQ, TIME) + + +def test_amplitude_to_db_invalid_input(): + + def test_invalid_input(test_name, stype, ref_value, amin, top_db, error, error_msg): + logger.info("Test AmplitudeToDB with bad input: {0}".format(test_name)) + with pytest.raises(error) as error_info: + c_audio.AmplitudeToDB(stype=stype, ref_value=ref_value, amin=amin, top_db=top_db) + assert error_msg in str(error_info.value) + + test_invalid_input("invalid stype parameter value", "test", 1.0, 1e-10, 80.0, TypeError, + "Argument stype with value test is not of type [], but got .") + test_invalid_input("invalid ref_value parameter value", ScaleType.POWER, -1.0, 1e-10, 80.0, ValueError, + "Input ref_value is not within the required interval of (0, 16777216]") + test_invalid_input("invalid amin parameter value", ScaleType.POWER, 1.0, -1e-10, 80.0, ValueError, + "Input amin is not within the required interval of (0, 16777216]") + test_invalid_input("invalid top_db parameter value", ScaleType.POWER, 1.0, 1e-10, -80.0, ValueError, + "Input top_db is not within the required interval of (0, 16777216]") + + test_invalid_input("invalid stype parameter value", True, 
1.0, 1e-10, 80.0, TypeError, + "Argument stype with value True is not of type [], but got .") + test_invalid_input("invalid ref_value parameter value", ScaleType.POWER, "value", 1e-10, 80.0, TypeError, + "Argument ref_value with value value is not of type [, ], " + + "but got ") + test_invalid_input("invalid amin parameter value", ScaleType.POWER, 1.0, "value", -80.0, TypeError, + "Argument amin with value value is not of type [, ], " + + "but got ") + test_invalid_input("invalid top_db parameter value", ScaleType.POWER, 1.0, 1e-10, "value", TypeError, + "Argument top_db with value value is not of type [, ], " + + "but got ") + + +if __name__ == "__main__": + test_func_amplitude_to_db_eager() + test_func_amplitude_to_db_pipeline() + test_amplitude_to_db_invalid_input()