!17630 [assistant][AmplitudeToDB]

Merge pull request !17630 from QingfengLi/AmplitudeToDB
This commit is contained in:
i-robot 2021-08-06 07:14:59 +00:00 committed by Gitee
commit e18c277304
23 changed files with 835 additions and 8 deletions

View File

@ -93,6 +93,7 @@ add_dependencies(engine-gnn core)
add_dependencies(engine core) add_dependencies(engine core)
add_dependencies(callback core) add_dependencies(callback core)
add_dependencies(audio-kernels core) add_dependencies(audio-kernels core)
add_dependencies(audio-ir core)
add_dependencies(audio-ir-kernels core) add_dependencies(audio-ir-kernels core)
add_dependencies(text core) add_dependencies(text core)
add_dependencies(text-kernels core) add_dependencies(text-kernels core)
@ -156,6 +157,7 @@ set(submodules
$<TARGET_OBJECTS:engine-cache-client> $<TARGET_OBJECTS:engine-cache-client>
$<TARGET_OBJECTS:engine> $<TARGET_OBJECTS:engine>
$<TARGET_OBJECTS:audio-kernels> $<TARGET_OBJECTS:audio-kernels>
$<TARGET_OBJECTS:audio-ir>
$<TARGET_OBJECTS:audio-ir-kernels> $<TARGET_OBJECTS:audio-ir-kernels>
$<TARGET_OBJECTS:text> $<TARGET_OBJECTS:text>
$<TARGET_OBJECTS:text-kernels> $<TARGET_OBJECTS:text-kernels>

View File

@ -17,6 +17,7 @@
#include "minddata/dataset/include/dataset/audio.h" #include "minddata/dataset/include/dataset/audio.h"
#include "minddata/dataset/audio/ir/kernels/allpass_biquad_ir.h" #include "minddata/dataset/audio/ir/kernels/allpass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/amplitude_to_db_ir.h"
#include "minddata/dataset/audio/ir/kernels/angle_ir.h" #include "minddata/dataset/audio/ir/kernels/angle_ir.h"
#include "minddata/dataset/audio/ir/kernels/band_biquad_ir.h" #include "minddata/dataset/audio/ir/kernels/band_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h" #include "minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h"
@ -43,11 +44,27 @@ std::shared_ptr<TensorOperation> AllpassBiquad::Parse() {
return std::make_shared<AllpassBiquadOperation>(data_->sample_rate_, data_->central_freq_, data_->Q_); return std::make_shared<AllpassBiquadOperation>(data_->sample_rate_, data_->central_freq_, data_->Q_);
} }
// AmplitudeToDB Operation.
struct AmplitudeToDB::Data {
Data(ScaleType stype, float ref_value, float amin, float top_db)
: stype_(stype), ref_value_(ref_value), amin_(amin), top_db_(top_db) {}
ScaleType stype_;
float ref_value_;
float amin_;
float top_db_;
};
AmplitudeToDB::AmplitudeToDB(ScaleType stype, float ref_value, float amin, float top_db)
: data_(std::make_shared<Data>(stype, ref_value, amin, top_db)) {}
std::shared_ptr<TensorOperation> AmplitudeToDB::Parse() {
return std::make_shared<AmplitudeToDBOperation>(data_->stype_, data_->ref_value_, data_->amin_, data_->top_db_);
}
// Angle Transform Operation. // Angle Transform Operation.
Angle::Angle() {} Angle::Angle() {}
std::shared_ptr<TensorOperation> Angle::Parse() { return std::make_shared<AngleOperation>(); } std::shared_ptr<TensorOperation> Angle::Parse() { return std::make_shared<AngleOperation>(); }
// BandBiquad Transform Operation. // BandBiquad Transform Operation.
struct BandBiquad::Data { struct BandBiquad::Data {
Data(int32_t sample_rate, float central_freq, float Q, bool noise) Data(int32_t sample_rate, float central_freq, float Q, bool noise)

View File

@ -18,6 +18,7 @@
#include "minddata/dataset/api/python/pybind_conversion.h" #include "minddata/dataset/api/python/pybind_conversion.h"
#include "minddata/dataset/api/python/pybind_register.h" #include "minddata/dataset/api/python/pybind_register.h"
#include "minddata/dataset/audio/ir/kernels/allpass_biquad_ir.h" #include "minddata/dataset/audio/ir/kernels/allpass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/amplitude_to_db_ir.h"
#include "minddata/dataset/audio/ir/kernels/angle_ir.h" #include "minddata/dataset/audio/ir/kernels/angle_ir.h"
#include "minddata/dataset/audio/ir/kernels/band_biquad_ir.h" #include "minddata/dataset/audio/ir/kernels/band_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h" #include "minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h"
@ -39,6 +40,24 @@ PYBIND_REGISTER(
})); }));
})); }));
PYBIND_REGISTER(
AmplitudeToDBOperation, 1, ([](const py::module *m) {
(void)py::class_<audio::AmplitudeToDBOperation, TensorOperation, std::shared_ptr<audio::AmplitudeToDBOperation>>(
*m, "AmplitudeToDBOperation")
.def(py::init([](ScaleType stype, float ref_value, float amin, float top_db) {
auto amplitude_to_db = std::make_shared<audio::AmplitudeToDBOperation>(stype, ref_value, amin, top_db);
THROW_IF_ERROR(amplitude_to_db->ValidateParams());
return amplitude_to_db;
}));
}));
PYBIND_REGISTER(ScaleType, 0, ([](const py::module *m) {
(void)py::enum_<ScaleType>(*m, "ScaleType", py::arithmetic())
.value("DE_SCALETYPE_MAGNITUDE", ScaleType::kMagnitude)
.value("DE_SCALETYPE_POWER", ScaleType::kPower)
.export_values();
}));
PYBIND_REGISTER(AngleOperation, 1, ([](const py::module *m) { PYBIND_REGISTER(AngleOperation, 1, ([](const py::module *m) {
(void)py::class_<audio::AngleOperation, TensorOperation, std::shared_ptr<audio::AngleOperation>>( (void)py::class_<audio::AngleOperation, TensorOperation, std::shared_ptr<audio::AngleOperation>>(
*m, "AngleOperation") *m, "AngleOperation")

View File

@ -2,3 +2,5 @@ add_subdirectory(kernels)
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(audio-ir OBJECT validators.cc)

View File

@ -3,6 +3,7 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE
add_library(audio-ir-kernels OBJECT add_library(audio-ir-kernels OBJECT
allpass_biquad_ir.cc allpass_biquad_ir.cc
amplitude_to_db_ir.cc
angle_ir.cc angle_ir.cc
band_biquad_ir.cc band_biquad_ir.cc
bandpass_biquad_ir.cc bandpass_biquad_ir.cc

View File

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/audio/ir/kernels/amplitude_to_db_ir.h"
#include "minddata/dataset/audio/kernels/amplitude_to_db_op.h"
#include "minddata/dataset/audio/ir/validators.h"
namespace mindspore {
namespace dataset {
namespace audio {
// AmplitudeToDB
AmplitudeToDBOperation::AmplitudeToDBOperation(ScaleType stype, float ref_value, float amin, float top_db)
: stype_(stype), ref_value_(ref_value), amin_(amin), top_db_(top_db) {}
AmplitudeToDBOperation::~AmplitudeToDBOperation() = default;
std::string AmplitudeToDBOperation::Name() const { return kAmplitudeToDBOperation; }
Status AmplitudeToDBOperation::ValidateParams() {
RETURN_IF_NOT_OK(CheckFloatScalarNonNegative("AmplitudeToDB", "top_db", top_db_));
RETURN_IF_NOT_OK(CheckFloatScalarPositive("AmplitudeToDB", "amin", amin_));
RETURN_IF_NOT_OK(CheckFloatScalarPositive("AmplitudeToDB", "ref_value", ref_value_));
return Status::OK();
}
std::shared_ptr<TensorOp> AmplitudeToDBOperation::Build() {
std::shared_ptr<AmplitudeToDBOp> tensor_op = std::make_shared<AmplitudeToDBOp>(stype_, ref_value_, amin_, top_db_);
return tensor_op;
}
Status AmplitudeToDBOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["stype"] = stype_;
args["ref_value"] = ref_value_;
args["amin"] = amin_;
args["top_db"] = top_db_;
*out_json = args;
return Status::OK();
}
} // namespace audio
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_AMPLITUDE_TO_DB_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_AMPLITUDE_TO_DB_IR_H_
#include <memory>
#include <string>
#include <vector>
#include "include/api/status.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
namespace audio {
constexpr char kAmplitudeToDBOperation[] = "AmplitudeToDB";
class AmplitudeToDBOperation : public TensorOperation {
public:
AmplitudeToDBOperation(ScaleType stype, float ref_value, float amin, float top_db);
~AmplitudeToDBOperation();
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override;
Status to_json(nlohmann::json *out_json) override;
private:
ScaleType stype_;
float ref_value_;
float amin_;
float top_db_;
};
} // namespace audio
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_AMPLITUDE_TO_DB_IR_H_

View File

@ -0,0 +1,82 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/audio/ir/validators.h"
namespace mindspore {
namespace dataset {
/* ####################################### Validator Functions ############################################ */
Status CheckFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar) {
RETURN_IF_NOT_OK(CheckScalar(op_name, scalar_name, scalar, {0}, true));
return Status::OK();
}
Status CheckFloatScalarNonNegative(const std::string &op_name, const std::string &scalar_name, float scalar) {
RETURN_IF_NOT_OK(CheckScalar(op_name, scalar_name, scalar, {0}, false));
return Status::OK();
}
Status CheckStringScalarInList(const std::string &op_name, const std::string &scalar_name, const std::string &scalar,
const std::vector<std::string> &str_vec) {
auto ret = std::find(str_vec.begin(), str_vec.end(), scalar);
if (ret == str_vec.end()) {
std::string interval_description = "[";
for (int m = 0; m < str_vec.size(); m++) {
std::string word = str_vec[m];
interval_description = interval_description + word;
if (m != str_vec.size() - 1) interval_description = interval_description + ", ";
}
interval_description = interval_description + "]";
std::string err_msg = op_name + ": " + scalar_name + " must be one of " + interval_description + ", got: " + scalar;
MS_LOG(ERROR) << err_msg;
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
}
return Status::OK();
}
template <typename T>
Status CheckScalar(const std::string &op_name, const std::string &scalar_name, const T scalar,
const std::vector<T> &range, bool left_open_interval, bool right_open_interval) {
if (range.empty() || range.size() > 2) {
std::string err_msg = "Range check expecting size 1 or 2, but got: " + std::to_string(range.size());
MS_LOG(ERROR) << err_msg;
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
}
if ((left_open_interval && scalar <= range[0]) || (!left_open_interval && scalar < range[0])) {
std::string interval_description = left_open_interval ? " greater than " : " greater than or equal to ";
std::string err_msg = op_name + ":" + scalar_name + " must be" + interval_description + std::to_string(range[0]) +
", got: " + std::to_string(scalar);
MS_LOG(ERROR) << err_msg;
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
}
if (range.size() == 2) {
if ((right_open_interval && scalar >= range[1]) || (!right_open_interval && scalar > range[1])) {
std::string left_bracket = left_open_interval ? "(" : "[";
std::string right_bracket = right_open_interval ? ")" : "]";
std::string err_msg = op_name + ":" + scalar_name + " is out of range " + left_bracket +
std::to_string(range[0]) + ", " + std::to_string(range[1]) + right_bracket +
", got: " + std::to_string(scalar);
MS_LOG(ERROR) << err_msg;
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
}
}
return Status::OK();
}
template Status CheckScalar(const std::string &op_name, const std::string &scalar_name, const float scalar,
const std::vector<float> &range, bool left_open_interval, bool right_open_interval);
} // namespace dataset
} // namespace mindspore

View File

@ -18,8 +18,13 @@
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_VALIDATORS_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_VALIDATORS_H_
#include <string> #include <string>
#include <vector>
#include "minddata/dataset/kernels/ir/validators.h" #include "minddata/dataset/kernels/ir/validators.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
#include "minddata/dataset/util/status.h"
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
@ -34,6 +39,20 @@ Status CheckScalarNotZero(const std::string &op_name, const std::string &scalar_
return Status::OK(); return Status::OK();
} }
// Helper function to positive float scalar
Status CheckFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar);
// Helper function to non-negative float scalar
Status CheckFloatScalarNonNegative(const std::string &op_name, const std::string &scalar_name, float scalar);
// Helper function to check string scalar
Status CheckStringScalarInList(const std::string &op_name, const std::string &scalar_name, const std::string &scalar,
const std::vector<std::string> &str_vec);
// Helper function to validate scalar
template <typename T>
Status CheckScalar(const std::string &op_name, const std::string &scalar_name, const T scalar,
const std::vector<T> &range, bool left_open_interval = false, bool right_open_interval = false);
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ADUIO_IR_VALIDATORS_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ADUIO_IR_VALIDATORS_H_

View File

@ -3,7 +3,9 @@ set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE
add_library(audio-kernels OBJECT add_library(audio-kernels OBJECT
allpass_biquad_op.cc allpass_biquad_op.cc
amplitude_to_db_op.cc
angle_op.cc angle_op.cc
audio_utils.cc
band_biquad_op.cc band_biquad_op.cc
bandpass_biquad_op.cc bandpass_biquad_op.cc
bandreject_biquad_op.cc bandreject_biquad_op.cc

View File

@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <limits>
#include "minddata/dataset/audio/kernels/amplitude_to_db_op.h"
#include "minddata/dataset/audio/kernels/audio_utils.h"
#include "minddata/dataset/kernels/data/data_utils.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
Status AmplitudeToDBOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
if (input->shape().Rank() < 2) {
std::string err_msg = "AmplitudeToDB: input tensor shape should be <..., freq, time>";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
std::shared_ptr<Tensor> input_tensor;
float top_db = top_db_;
float multiplier = stype_ == ScaleType::kPower ? 10.0 : 20.0;
float amin = 1e-10;
float db_multiplier = std::log10(std::max(amin_, ref_value_));
// typecast
CHECK_FAIL_RETURN_UNEXPECTED(input->type() != DataType::DE_STRING,
"AmplitudeToDB: input type should be float, but got string.");
if (input->type() != DataType::DE_FLOAT64) {
CHECK_FAIL_RETURN_UNEXPECTED(TypeCast(input, &input_tensor, DataType(DataType::DE_FLOAT32)),
"AmplitudeToDB: input type should be float, but got " + input->type().ToString());
return AmplitudeToDB<float>(input_tensor, output, multiplier, amin, db_multiplier, top_db);
} else {
input_tensor = input;
return AmplitudeToDB<double>(input_tensor, output, multiplier, amin, db_multiplier, top_db);
}
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,51 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AMPLITUDE_TO_DB_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AMPLITUDE_TO_DB_OP_H_
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/kernels/tensor_op.h"
namespace mindspore {
namespace dataset {
class AmplitudeToDBOp : public TensorOp {
public:
AmplitudeToDBOp(ScaleType stype, float ref_value, float amin, float top_db)
: stype_(stype), ref_value_(ref_value), amin_(amin), top_db_(top_db) {}
~AmplitudeToDBOp() override = default;
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kAmplitudeToDBOp; }
private:
ScaleType stype_;
float ref_value_;
float amin_;
float top_db_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AMPLITUDE_TO_DB_OP_H_

View File

@ -0,0 +1,65 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/audio/kernels/audio_utils.h"
namespace mindspore {
namespace dataset {
template <typename T>
Status AmplitudeToDB(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, T multiplier, T amin,
T db_multiplier, T top_db) {
TensorShape input_shape = input->shape();
TensorShape to_shape = input_shape.Rank() == 2
? TensorShape({1, 1, input_shape[-2], input_shape[-1]})
: TensorShape({input->Size() / (input_shape[-3] * input_shape[-2] * input_shape[-1]),
input_shape[-3], input_shape[-2], input_shape[-1]});
RETURN_IF_NOT_OK(input->Reshape(to_shape));
std::vector<T> max_val;
int step = to_shape[-3] * input_shape[-2] * input_shape[-1];
int cnt = 0;
T temp_max = std::numeric_limits<T>::lowest();
for (auto itr = input->begin<T>(); itr != input->end<T>(); itr++) {
// do clamp
*itr = *itr < amin ? log10(amin) * multiplier : log10(*itr) * multiplier;
*itr -= multiplier * db_multiplier;
// calculate max by axis
cnt++;
if ((*itr) > temp_max) temp_max = *itr;
if (cnt % step == 0) {
max_val.push_back(temp_max);
temp_max = std::numeric_limits<T>::lowest();
}
}
if (!std::isnan(top_db)) {
int ind = 0;
for (auto itr = input->begin<T>(); itr != input->end<T>(); itr++, ind++) {
float lower_bound = max_val[ind / step] - top_db;
*itr = std::max((*itr), static_cast<T>(lower_bound));
}
}
RETURN_IF_NOT_OK(input->Reshape(input_shape));
*output = input;
return Status::OK();
}
template Status AmplitudeToDB<float>(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
float multiplier, float amin, float db_multiplier, float top_db);
template Status AmplitudeToDB<double>(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
double multiplier, double amin, double db_multiplier, double top_db);
} // namespace dataset
} // namespace mindspore

View File

@ -17,8 +17,11 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_
#include <algorithm>
#include <cmath> #include <cmath>
#include <limits>
#include <memory> #include <memory>
#include <string>
#include <vector> #include <vector>
#include "minddata/dataset/core/tensor.h" #include "minddata/dataset/core/tensor.h"
@ -28,6 +31,17 @@
constexpr double PI = 3.141592653589793; constexpr double PI = 3.141592653589793;
namespace mindspore { namespace mindspore {
namespace dataset { namespace dataset {
/// \brief Turn a tensor from the power/amplitude scale to the decibel scale.
/// \param input/output: Tensor of shape <...,freq,time>
/// \param multiplier: power - 10, amplitude - 20
/// \param amin: lower bound
/// \param db_multiplier: multiplier for decibels
/// \param top_db: the lower bound for decibels cut-off
/// \return Status code
template <typename T>
Status AmplitudeToDB(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, T multiplier, T amin,
T db_multiplier, T top_db);
/// \brief Calculate the angles of the complex numbers /// \brief Calculate the angles of the complex numbers
/// \param input/output: Tensor of shape <...,time> /// \param input/output: Tensor of shape <...,time>
template <typename T> template <typename T>
@ -162,7 +176,6 @@ Status LFilter(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *ou
delete m_py; delete m_py;
return Status::OK(); return Status::OK();
} }
} // namespace dataset } // namespace dataset
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_ #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_

View File

@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_AUDIO_H_ #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_AUDIO_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_AUDIO_H_ #define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_DATASET_AUDIO_H_
#include <limits>
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility> #include <utility>
@ -92,6 +93,31 @@ class AllpassBiquad final : public TensorTransform {
std::shared_ptr<Data> data_; std::shared_ptr<Data> data_;
}; };
/// \brief AmplitudeToDB TensorTransform.
/// \notes Turn a tensor from the power/amplitude scale to the decibel scale.
class AmplitudeToDB final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] stype ['kPower', 'kMagnitude']
/// \param[in] ref_value Calculate db_multiplier
/// \param[in] amin Clamp the input waveform
/// \param[in] top_db Decibels cut-off value
explicit AmplitudeToDB(ScaleType stype = ScaleType::kPower, float ref_value = 1.0, float amin = 1e-10,
float top_db = 80.0);
/// \brief Destructor.
~AmplitudeToDB() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief Design two-pole band-pass filter. /// \brief Design two-pole band-pass filter.
class BandpassBiquad final : public TensorTransform { class BandpassBiquad final : public TensorTransform {
public: public:

View File

@ -49,6 +49,12 @@ enum class ShuffleMode {
kInfile = 3 ///< Shuffle data within each file. kInfile = 3 ///< Shuffle data within each file.
}; };
/// \brief Possible scale for input audio.
enum class ScaleType {
kMagnitude = 0, ///< Audio scale is magnitude.
kPower = 1, ///< Audio scale is power.
};
/// \brief The method of padding. /// \brief The method of padding.
enum class BorderType { enum class BorderType {
kConstant = 0, ///< Fill the border with constant values. kConstant = 0, ///< Fill the border with constant values.

View File

@ -139,6 +139,7 @@ constexpr char kSentencepieceTokenizerOp[] = "SentencepieceTokenizerOp";
// audio // audio
constexpr char kAllpassBiquadOp[] = "AllpassBiquadOp"; constexpr char kAllpassBiquadOp[] = "AllpassBiquadOp";
constexpr char kAmplitudeToDBOp[] = "AmplitudeToDBOp";
constexpr char kAngleOp[] = "AngleOp"; constexpr char kAngleOp[] = "AngleOp";
constexpr char kBandBiquadOp[] = "BandBiquadOp"; constexpr char kBandBiquadOp[] = "BandBiquadOp";
constexpr char kBandpassBiquadOp[] = "BandpassBiquadOp"; constexpr char kBandpassBiquadOp[] = "BandpassBiquadOp";

View File

@ -20,8 +20,9 @@ to improve their training models.
import mindspore._c_dataengine as cde import mindspore._c_dataengine as cde
import numpy as np import numpy as np
from ..transforms.c_transforms import TensorOperation from ..transforms.c_transforms import TensorOperation
from .validators import check_allpass_biquad, check_band_biquad, check_bandpass_biquad, check_bandreject_biquad, \ from .utils import ScaleType
check_bass_biquad from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
check_bandreject_biquad, check_bass_biquad
class AudioTensorOperation(TensorOperation): class AudioTensorOperation(TensorOperation):
@ -74,6 +75,42 @@ class AllpassBiquad(AudioTensorOperation):
return cde.AllpassBiquadOperation(self.sample_rate, self.central_freq, self.Q) return cde.AllpassBiquadOperation(self.sample_rate, self.central_freq, self.Q)
DE_C_SCALETYPE_TYPE = {ScaleType.MAGNITUDE: cde.ScaleType.DE_SCALETYPE_MAGNITUDE,
ScaleType.POWER: cde.ScaleType.DE_SCALETYPE_POWER}
class AmplitudeToDB(AudioTensorOperation):
"""
Converts the input tensor from amplitude/power scale to decibel scale.
Args:
stype (ScaleType, optional): Scale of the input tensor. (Default="ScaleType.POWER").
It can be any of [ScaleType.MAGNITUDE, ScaleType.POWER].
ref_value (float, optional): Param for generate db_multiplier.
amin (float, optional): Lower bound to clamp the input waveform.
top_db (float, optional): Minimum cut-off decibels. The range of values is non-negative. Commonly set at 80.
(Default=80.0)
Examples:
>>> channel = 1
>>> n_fft = 400
>>> n_frame = 30
>>> specrogram = np.random.random([channel, n_fft//2+1, n_frame])
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=specrogram, column_names=["audio"])
>>> transforms = [audio.AmplitudeToDB(stype=ScaleType.POWER)]
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
"""
@ check_amplitude_to_db
def __init__(self, stype=ScaleType.POWER, ref_value=1.0, amin=1e-10, top_db=80.0):
self.stype = stype
self.ref_value = ref_value
self.amin = amin
self.top_db = top_db
def parse(self):
return cde.AmplitudeToDBOperation(DE_C_SCALETYPE_TYPE[self.stype], self.ref_value, self.amin, self.top_db)
class Angle(AudioTensorOperation): class Angle(AudioTensorOperation):
""" """
Calculate the angle of the complex number sequence of shape (..., 2). Calculate the angle of the complex number sequence of shape (..., 2).

View File

@ -0,0 +1,23 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
enum for audio ops
"""
from enum import Enum
class ScaleType(str, Enum):
"""Scale Type"""
POWER: str = "power"
MAGNITUDE: str = "magnitude"

View File

@ -16,8 +16,41 @@
Validators for TensorOps. Validators for TensorOps.
""" """
from functools import wraps from functools import wraps
from mindspore.dataset.core.validator_helpers import check_not_zero, check_int32, check_float32, \
check_value_normalize_std, check_value_ratio, FLOAT_MAX_INTEGER, parse_user_args, type_check
from .utils import ScaleType
from mindspore.dataset.core.validator_helpers import check_not_zero, check_int32, check_float32, check_value_normalize_std, parse_user_args, type_check
def check_amplitude_to_db(method):
"""Wrapper method to check the parameters of amplitude_to_db."""
@wraps(method)
def new_method(self, *args, **kwargs):
[stype, ref_value, amin, top_db], _ = parse_user_args(method, *args, **kwargs)
# type check stype
type_check(stype, (ScaleType,), "stype")
# type check ref_value
type_check(ref_value, (int, float), "ref_value")
# value check ref_value
if not ref_value is None:
check_value_ratio(ref_value, (0, FLOAT_MAX_INTEGER), "ref_value")
# type check amin
type_check(amin, (int, float), "amin")
# value check amin
if not amin is None:
check_value_ratio(amin, (0, FLOAT_MAX_INTEGER), "amin")
# type check top_db
type_check(top_db, (int, float), "top_db")
# value check top_db
if not top_db is None:
check_value_ratio(top_db, (0, FLOAT_MAX_INTEGER), "top_db")
return method(self, *args, **kwargs)
return new_method
def check_biquad_sample_rate(sample_rate): def check_biquad_sample_rate(sample_rate):

View File

@ -30,6 +30,65 @@ class MindDataTestPipeline : public UT::DatasetOpTesting {
protected: protected:
}; };
TEST_F(MindDataTestPipeline, TestAmplitudeToDBPipeline) {
MS_LOG(INFO) << "Basic Function Test";
// Original waveform
std::shared_ptr<SchemaObj> schema = Schema();
ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {2, 200}));
std::shared_ptr<Dataset> ds = RandomData(50, schema);
EXPECT_NE(ds, nullptr);
ds = ds->SetNumWorkers(4);
EXPECT_NE(ds, nullptr);
auto amplitude_to_db_op = audio::AmplitudeToDB();
ds = ds->Map({amplitude_to_db_op});
EXPECT_NE(ds, nullptr);
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(ds, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
std::vector<int64_t> expected = {2, 200};
int i = 0;
while (row.size() != 0) {
auto col = row["inputData"];
ASSERT_EQ(col.Shape(), expected);
ASSERT_EQ(col.Shape().size(), 2);
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 50);
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestAmplitudeToDBWrongArgs) {
MS_LOG(INFO) << "Basic Function Test";
// Original waveform
std::shared_ptr<SchemaObj> schema = Schema();
ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {2, 200}));
std::shared_ptr<Dataset> ds = RandomData(50, schema);
EXPECT_NE(ds, nullptr);
ds = ds->SetNumWorkers(4);
EXPECT_NE(ds, nullptr);
auto amplitude_to_db_op = audio::AmplitudeToDB(ScaleType::kPower, 1.0, -1e-10, 80.0);
ds = ds->Map({amplitude_to_db_op});
EXPECT_NE(ds, nullptr);
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure
EXPECT_EQ(iter, nullptr);
}
TEST_F(MindDataTestPipeline, Level0_TestBandBiquad001) { TEST_F(MindDataTestPipeline, Level0_TestBandBiquad001) {
MS_LOG(INFO) << "Basic Function Test"; MS_LOG(INFO) << "Basic Function Test";
// Original waveform // Original waveform
@ -313,7 +372,7 @@ TEST_F(MindDataTestPipeline, Level0_TestBassBiquad001) {
ds = ds->SetNumWorkers(4); ds = ds->SetNumWorkers(4);
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);
auto BassBiquadOp = audio::BassBiquad(44100,50,200.0); auto BassBiquadOp = audio::BassBiquad(44100, 50, 200.0);
ds = ds->Map({BassBiquadOp}); ds = ds->Map({BassBiquadOp});
EXPECT_NE(ds, nullptr); EXPECT_NE(ds, nullptr);
@ -353,7 +412,7 @@ TEST_F(MindDataTestPipeline, Level0_TestBassBiquad002) {
// Check sample_rate // Check sample_rate
MS_LOG(INFO) << "sample_rate is zero."; MS_LOG(INFO) << "sample_rate is zero.";
auto bass_biquad_op_01 = audio::BassBiquad(0,50,200.0); auto bass_biquad_op_01 = audio::BassBiquad(0, 50, 200.0);
ds01 = ds->Map({bass_biquad_op_01}); ds01 = ds->Map({bass_biquad_op_01});
EXPECT_NE(ds01, nullptr); EXPECT_NE(ds01, nullptr);
@ -362,7 +421,7 @@ TEST_F(MindDataTestPipeline, Level0_TestBassBiquad002) {
// Check Q_ // Check Q_
MS_LOG(INFO) << "Q_ is zero."; MS_LOG(INFO) << "Q_ is zero.";
auto bass_biquad_op_02 = audio::BassBiquad(44100,50,200.0,0); auto bass_biquad_op_02 = audio::BassBiquad(44100, 50, 200.0, 0);
ds02 = ds->Map({bass_biquad_op_02}); ds02 = ds->Map({bass_biquad_op_02});
EXPECT_NE(ds02, nullptr); EXPECT_NE(ds02, nullptr);

View File

@ -20,6 +20,7 @@
#include "minddata/dataset/include/dataset/execute.h" #include "minddata/dataset/include/dataset/execute.h"
#include "minddata/dataset/include/dataset/transforms.h" #include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/include/dataset/vision.h" #include "minddata/dataset/include/dataset/vision.h"
#include "minddata/dataset/include/dataset/audio.h"
#include "minddata/dataset/include/dataset/text.h" #include "minddata/dataset/include/dataset/text.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
@ -99,6 +100,65 @@ TEST_F(MindDataTestExecute, TestAdjustGammaEager2) {
EXPECT_EQ(rc, Status::OK()); EXPECT_EQ(rc, Status::OK());
} }
TEST_F(MindDataTestExecute, TestAmplitudeToDB) {
MS_LOG(INFO) << "Basic Function Test With Eager.";
// Original waveform
std::vector<float> labels = {
2.716064453125000000e-03, 6.347656250000000000e-03, 9.246826171875000000e-03, 1.089477539062500000e-02,
1.138305664062500000e-02, 1.156616210937500000e-02, 1.394653320312500000e-02, 1.550292968750000000e-02,
1.614379882812500000e-02, 1.840209960937500000e-02, 1.718139648437500000e-02, 1.599121093750000000e-02,
1.647949218750000000e-02, 1.510620117187500000e-02, 1.385498046875000000e-02, 1.345825195312500000e-02,
1.419067382812500000e-02, 1.284790039062500000e-02, 1.052856445312500000e-02, 9.368896484375000000e-03,
1.419067382812500000e-02, 1.284790039062500000e-02, 1.052856445312500000e-02, 9.368896484375000000e-03};
std::shared_ptr<Tensor> input;
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 2, 2, 3}), &input));
auto input_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
std::shared_ptr<TensorTransform> amplitude_to_db_op = std::make_shared<audio::AmplitudeToDB>();
// apply amplitude_to_db
mindspore::dataset::Execute trans({amplitude_to_db_op});
Status status = trans(input_ms, &input_ms);
EXPECT_TRUE(status.IsOk());
}
TEST_F(MindDataTestExecute, TestAmplitudeToDBWrongArgs) {
MS_LOG(INFO) << "Wrong Arg.";
// Original waveform
std::vector<float> labels = {
2.716064453125000000e-03, 6.347656250000000000e-03, 9.246826171875000000e-03, 1.089477539062500000e-02,
1.138305664062500000e-02, 1.156616210937500000e-02, 1.394653320312500000e-02, 1.550292968750000000e-02,
1.614379882812500000e-02, 1.840209960937500000e-02, 1.718139648437500000e-02, 1.599121093750000000e-02,
1.647949218750000000e-02, 1.510620117187500000e-02, 1.385498046875000000e-02, 1.345825195312500000e-02,
1.419067382812500000e-02, 1.284790039062500000e-02, 1.052856445312500000e-02, 9.368896484375000000e-03};
std::shared_ptr<Tensor> input;
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 10}), &input));
auto input_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
std::shared_ptr<TensorTransform> amplitude_to_db_op =
std::make_shared<audio::AmplitudeToDB>(ScaleType::kPower, 1.0, -1e-10, 80.0);
// apply amplitude_to_db
mindspore::dataset::Execute trans({amplitude_to_db_op});
Status status = trans(input_ms, &input_ms);
EXPECT_FALSE(status.IsOk());
}
TEST_F(MindDataTestExecute, TestAmplitudeToDBWrongInput) {
MS_LOG(INFO) << "Wrong Input.";
// Original waveform
std::vector<float> labels = {
2.716064453125000000e-03, 6.347656250000000000e-03, 9.246826171875000000e-03, 1.089477539062500000e-02,
1.138305664062500000e-02, 1.156616210937500000e-02, 1.394653320312500000e-02, 1.550292968750000000e-02,
1.614379882812500000e-02, 1.840209960937500000e-02, 1.718139648437500000e-02, 1.599121093750000000e-02,
1.647949218750000000e-02, 1.510620117187500000e-02, 1.385498046875000000e-02, 1.345825195312500000e-02,
1.419067382812500000e-02, 1.284790039062500000e-02, 1.052856445312500000e-02, 9.368896484375000000e-03};
std::shared_ptr<Tensor> input;
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({20}), &input));
auto input_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
std::shared_ptr<TensorTransform> amplitude_to_db_op = std::make_shared<audio::AmplitudeToDB>();
// apply amplitude_to_db
mindspore::dataset::Execute trans({amplitude_to_db_op});
Status status = trans(input_ms, &input_ms);
EXPECT_FALSE(status.IsOk());
}
TEST_F(MindDataTestExecute, TestComposeTransforms) { TEST_F(MindDataTestExecute, TestComposeTransforms) {
MS_LOG(INFO) << "Doing TestComposeTransforms."; MS_LOG(INFO) << "Doing TestComposeTransforms.";

View File

@ -0,0 +1,137 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing AmplitudeToDB op in DE
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.audio.transforms as c_audio
from mindspore import log as logger
from mindspore.dataset.audio.utils import ScaleType
CHANNEL = 1
FREQ = 20
TIME = 15
def gen(shape):
np.random.seed(0)
data = np.random.random(shape)
yield(np.array(data, dtype=np.float32),)
def _count_unequal_element(data_expected, data_me, rtol, atol):
""" Precision calculation func """
assert data_expected.shape == data_me.shape
total_count = len(data_expected.flatten())
error = np.abs(data_expected - data_me)
greater = np.greater(error, atol + np.abs(data_expected) * rtol)
loss_count = np.count_nonzero(greater)
assert (loss_count / total_count) < rtol, \
"\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
format(data_expected[greater], data_me[greater], error[greater])
def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
""" Precision calculation formula """
if np.any(np.isnan(data_expected)):
assert np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan)
elif not np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan):
_count_unequal_element(data_expected, data_me, rtol, atol)
else:
assert True
def test_func_amplitude_to_db_eager():
""" mindspore eager mode normal testcase:amplitude_to_db op"""
logger.info("check amplitude_to_db op output")
ndarr_in = np.array([[[[-0.2197528, 0.3821656]]],
[[[0.57418776, 0.46741104]]],
[[[-0.20381108, -0.9303914]]],
[[[0.3693608, -0.2017813]]],
[[[-1.727381, -1.3708513]]],
[[[1.259975, 0.4981323]]],
[[[0.76986176, -0.5793846]]]]).astype(np.float32)
# cal from benchmark
out_expect = np.array([[[[-84.17748, -4.177484]]],
[[[-2.4094608, -3.3030105]]],
[[[-100., -100.]]],
[[[-4.325492, -84.32549]]],
[[[-100., -100.]]],
[[[1.0036192, -3.0265532]]],
[[[-1.1358725, -81.13587]]]]).astype(np.float32)
amplitude_to_db_op = c_audio.AmplitudeToDB()
out_mindspore = amplitude_to_db_op(ndarr_in)
allclose_nparray(out_mindspore, out_expect, 0.0001, 0.0001)
def test_func_amplitude_to_db_pipeline():
""" mindspore pipeline mode normal testcase:amplitude_to_db op"""
logger.info("test AmplitudeToDB op with default value")
generator = gen([CHANNEL, FREQ, TIME])
data1 = ds.GeneratorDataset(source=generator, column_names=["multi_dimensional_data"])
transforms = [
c_audio.AmplitudeToDB()
]
data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
out_put = item["multi_dimensional_data"]
assert out_put.shape == (CHANNEL, FREQ, TIME)
def test_amplitude_to_db_invalid_input():
def test_invalid_input(test_name, stype, ref_value, amin, top_db, error, error_msg):
logger.info("Test AmplitudeToDB with bad input: {0}".format(test_name))
with pytest.raises(error) as error_info:
c_audio.AmplitudeToDB(stype=stype, ref_value=ref_value, amin=amin, top_db=top_db)
assert error_msg in str(error_info.value)
test_invalid_input("invalid stype parameter value", "test", 1.0, 1e-10, 80.0, TypeError,
"Argument stype with value test is not of type [<enum 'ScaleType'>], but got <class 'str'>.")
test_invalid_input("invalid ref_value parameter value", ScaleType.POWER, -1.0, 1e-10, 80.0, ValueError,
"Input ref_value is not within the required interval of (0, 16777216]")
test_invalid_input("invalid amin parameter value", ScaleType.POWER, 1.0, -1e-10, 80.0, ValueError,
"Input amin is not within the required interval of (0, 16777216]")
test_invalid_input("invalid top_db parameter value", ScaleType.POWER, 1.0, 1e-10, -80.0, ValueError,
"Input top_db is not within the required interval of (0, 16777216]")
test_invalid_input("invalid stype parameter value", True, 1.0, 1e-10, 80.0, TypeError,
"Argument stype with value True is not of type [<enum 'ScaleType'>], but got <class 'bool'>.")
test_invalid_input("invalid ref_value parameter value", ScaleType.POWER, "value", 1e-10, 80.0, TypeError,
"Argument ref_value with value value is not of type [<class 'int'>, <class 'float'>], " +
"but got <class 'str'>")
test_invalid_input("invalid amin parameter value", ScaleType.POWER, 1.0, "value", -80.0, TypeError,
"Argument amin with value value is not of type [<class 'int'>, <class 'float'>], " +
"but got <class 'str'>")
test_invalid_input("invalid top_db parameter value", ScaleType.POWER, 1.0, 1e-10, "value", TypeError,
"Argument top_db with value value is not of type [<class 'int'>, <class 'float'>], " +
"but got <class 'str'>")
if __name__ == "__main__":
test_func_amplitude_to_db_eager()
test_func_amplitude_to_db_pipeline()
test_amplitude_to_db_invalid_input()