forked from mindspore-Ecosystem/mindspore
[feat] [assistant] [I3CKEE] add new audio operator ComputeDeltas
This commit is contained in:
parent
f593f6f95a
commit
4c3931f5c9
|
@ -25,6 +25,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/bass_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/complex_norm_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/compute_deltas_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/contrast_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/db_to_amplitude_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
|
||||
|
@ -187,6 +188,20 @@ ComplexNorm::ComplexNorm(float power) : data_(std::make_shared<Data>(power)) {}
|
|||
|
||||
std::shared_ptr<TensorOperation> ComplexNorm::Parse() { return std::make_shared<ComplexNormOperation>(data_->power_); }
|
||||
|
||||
// ComputeDeltas Transform Operation.
|
||||
struct ComputeDeltas::Data {
|
||||
Data(int32_t win_length, BorderType pad_mode) : win_length_(win_length), pad_mode_(pad_mode) {}
|
||||
int32_t win_length_;
|
||||
BorderType pad_mode_;
|
||||
};
|
||||
|
||||
ComputeDeltas::ComputeDeltas(int32_t win_length, BorderType pad_mode)
|
||||
: data_(std::make_shared<Data>(win_length, pad_mode)) {}
|
||||
|
||||
std::shared_ptr<TensorOperation> ComputeDeltas::Parse() {
|
||||
return std::make_shared<ComputeDeltasOperation>(data_->win_length_, data_->pad_mode_);
|
||||
}
|
||||
|
||||
// Contrast Transform Operation.
|
||||
struct Contrast::Data {
|
||||
explicit Data(float enhancement_amount) : enhancement_amount_(enhancement_amount) {}
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/bass_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/complex_norm_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/compute_deltas_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/contrast_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/db_to_amplitude_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
|
||||
|
@ -161,6 +162,17 @@ PYBIND_REGISTER(
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
ComputeDeltasOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<audio::ComputeDeltasOperation, TensorOperation, std::shared_ptr<audio::ComputeDeltasOperation>>(
|
||||
*m, "ComputeDeltasOperation")
|
||||
.def(py::init([](int32_t win_length, BorderType pad_mode) {
|
||||
auto compute_deltas = std::make_shared<audio::ComputeDeltasOperation>(win_length, pad_mode);
|
||||
THROW_IF_ERROR(compute_deltas->ValidateParams());
|
||||
return compute_deltas;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(ContrastOperation, 1, ([](const py::module *m) {
|
||||
(void)
|
||||
py::class_<audio::ContrastOperation, TensorOperation, std::shared_ptr<audio::ContrastOperation>>(
|
||||
|
|
|
@ -11,6 +11,7 @@ add_library(audio-ir-kernels OBJECT
|
|||
bass_biquad_ir.cc
|
||||
biquad_ir.cc
|
||||
complex_norm_ir.cc
|
||||
compute_deltas_ir.cc
|
||||
contrast_ir.cc
|
||||
db_to_amplitude_ir.cc
|
||||
dc_shift_ir.cc
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/audio/ir/kernels/compute_deltas_ir.h"
|
||||
|
||||
#include "minddata/dataset/audio/ir/validators.h"
|
||||
#include "minddata/dataset/audio/kernels/compute_deltas_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
ComputeDeltasOperation::ComputeDeltasOperation(int32_t win_length, BorderType pad_mode)
|
||||
: win_length_(win_length), pad_mode_(pad_mode) {}
|
||||
|
||||
std::shared_ptr<TensorOp> ComputeDeltasOperation::Build() {
|
||||
return std::make_shared<ComputeDeltasOp>(win_length_, pad_mode_);
|
||||
}
|
||||
|
||||
Status ComputeDeltasOperation::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["win_length"] = win_length_;
|
||||
args["pad_mode"] = pad_mode_;
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ComputeDeltasOperation::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateScalar("ComputeDeltas", "win_length", win_length_, {3}, false));
|
||||
if (pad_mode_ != BorderType::kConstant && pad_mode_ != BorderType::kEdge && pad_mode_ != BorderType::kReflect &&
|
||||
pad_mode_ != BorderType::kSymmetric) {
|
||||
std::string err_msg = "ComputeDeltas: invalid BorderType, please check input value of enum.";
|
||||
LOG_AND_RETURN_STATUS_SYNTAX_ERROR(err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_COMPUTE_DELTAS_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_COMPUTE_DELTAS_IR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
constexpr char kComputeDeltasOperation[] = "ComputeDeltas";
|
||||
|
||||
class ComputeDeltasOperation : public TensorOperation {
|
||||
public:
|
||||
ComputeDeltasOperation(int32_t win_length, BorderType pad_mode);
|
||||
|
||||
~ComputeDeltasOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
std::string Name() const override { return kComputeDeltasOperation; }
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
int32_t win_length_;
|
||||
BorderType pad_mode_;
|
||||
};
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_COMPUTE_DELTAS_IR_H_
|
|
@ -12,6 +12,7 @@ add_library(audio-kernels OBJECT
|
|||
bass_biquad_op.cc
|
||||
biquad_op.cc
|
||||
complex_norm_op.cc
|
||||
compute_deltas_op.cc
|
||||
contrast_op.cc
|
||||
db_to_amplitude_op.cc
|
||||
dc_shift_op.cc
|
||||
|
|
|
@ -1063,5 +1063,147 @@ Status SlidingWindowCmn(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t pad_left, int32_t pad_right,
|
||||
BorderType padding_mode, T value = 0) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "Pad: input tensor is not in shape of <..., time>.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
input->type().IsNumeric(),
|
||||
"Pad: input tensor type should be int, float or double, but got: " + input->type().ToString());
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(pad_left >= 0 && pad_right >= 0,
|
||||
"Pad: left and right padding values must be non negative, but got pad_left: " +
|
||||
std::to_string(pad_left) + " and pad_right: " + std::to_string(pad_right));
|
||||
TensorShape input_shape = input->shape();
|
||||
int32_t wave_length = input_shape[-1];
|
||||
int32_t num_wavs = static_cast<int32_t>(input->Size() / wave_length);
|
||||
TensorShape to_shape = TensorShape({num_wavs, wave_length});
|
||||
RETURN_IF_NOT_OK(input->Reshape(to_shape));
|
||||
int32_t pad_length = wave_length + pad_left + pad_right;
|
||||
TensorShape new_shape = TensorShape({num_wavs, pad_length});
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(new_shape, input->type(), output));
|
||||
using MatrixXT = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||
using Eigen::Map;
|
||||
constexpr int pad_mul = 2;
|
||||
T *input_data = reinterpret_cast<T *>(const_cast<uchar *>(input->GetBuffer()));
|
||||
T *output_data = reinterpret_cast<T *>(const_cast<uchar *>((*output)->GetBuffer()));
|
||||
auto input_map = Map<MatrixXT>(input_data, num_wavs, wave_length);
|
||||
auto output_map = Map<MatrixXT>(output_data, num_wavs, pad_length);
|
||||
output_map.block(0, pad_left, num_wavs, wave_length) = input_map;
|
||||
if (padding_mode == BorderType::kConstant) {
|
||||
output_map.block(0, 0, num_wavs, pad_left).setConstant(value);
|
||||
output_map.block(0, pad_left + wave_length, num_wavs, pad_right).setConstant(value);
|
||||
} else if (padding_mode == BorderType::kEdge) {
|
||||
output_map.block(0, 0, num_wavs, pad_left).colwise() = input_map.col(0);
|
||||
output_map.block(0, pad_left + wave_length, num_wavs, pad_right).colwise() = input_map.col(wave_length - 1);
|
||||
} else if (padding_mode == BorderType::kReflect) {
|
||||
// First, deal with the pad operation on the right.
|
||||
int32_t current_pad = wave_length - 1;
|
||||
while (pad_right >= current_pad) {
|
||||
// current_pad: the length of pad required for current loop.
|
||||
// pad_right: the length of the remaining pad on the right.
|
||||
output_map.block(0, pad_left + current_pad + 1, num_wavs, current_pad) =
|
||||
output_map.block(0, pad_left, num_wavs, current_pad).rowwise().reverse();
|
||||
pad_right -= current_pad;
|
||||
current_pad += current_pad;
|
||||
}
|
||||
output_map.block(0, pad_length - pad_right, num_wavs, pad_right) =
|
||||
output_map.block(0, pad_length - pad_right * pad_mul - 1, num_wavs, pad_right).rowwise().reverse();
|
||||
// Next, deal with the pad operation on the left.
|
||||
current_pad = wave_length - 1;
|
||||
while (pad_left >= current_pad) {
|
||||
// current_pad: the length of pad required for current loop.
|
||||
// pad_left: the length of the remaining pad on the left.
|
||||
output_map.block(0, pad_left - current_pad, num_wavs, current_pad) =
|
||||
output_map.block(0, pad_left + 1, num_wavs, current_pad).rowwise().reverse();
|
||||
pad_left -= current_pad;
|
||||
current_pad += current_pad;
|
||||
}
|
||||
output_map.block(0, 0, num_wavs, pad_left) =
|
||||
output_map.block(0, pad_left + 1, num_wavs, pad_left).rowwise().reverse();
|
||||
} else if (padding_mode == BorderType::kSymmetric) {
|
||||
// First, deal with the pad operation on the right.
|
||||
int32_t current_pad = wave_length;
|
||||
while (pad_right >= current_pad) {
|
||||
// current_pad: the length of pad required for current loop.
|
||||
// pad_right: the length of the remaining pad on the right.
|
||||
output_map.block(0, pad_left + current_pad, num_wavs, current_pad) =
|
||||
output_map.block(0, pad_left, num_wavs, current_pad).rowwise().reverse();
|
||||
pad_right -= current_pad;
|
||||
current_pad += current_pad;
|
||||
}
|
||||
output_map.block(0, pad_length - pad_right, num_wavs, pad_right) =
|
||||
output_map.block(0, pad_length - pad_right * pad_mul, num_wavs, pad_right).rowwise().reverse();
|
||||
// Next, deal with the pad operation on the left.
|
||||
current_pad = wave_length;
|
||||
while (pad_left >= current_pad) {
|
||||
// current_pad: the length of pad required for current loop.
|
||||
// pad_left: the length of the remaining pad on the left.
|
||||
output_map.block(0, pad_left - current_pad, num_wavs, current_pad) =
|
||||
output_map.block(0, pad_left, num_wavs, current_pad).rowwise().reverse();
|
||||
pad_left -= current_pad;
|
||||
current_pad += current_pad;
|
||||
}
|
||||
output_map.block(0, 0, num_wavs, pad_left) = output_map.block(0, pad_left, num_wavs, pad_left).rowwise().reverse();
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("Pad: unsupported border type.");
|
||||
}
|
||||
std::vector<dsize_t> shape_vec = input_shape.AsVector();
|
||||
shape_vec[shape_vec.size() - 1] = static_cast<dsize_t>(pad_length);
|
||||
TensorShape output_shape(shape_vec);
|
||||
RETURN_IF_NOT_OK((*output)->Reshape(output_shape));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ComputeDeltasImpl(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int all_freqs,
|
||||
int n_frame, int n) {
|
||||
using VectorXT = Eigen::Matrix<T, Eigen::Dynamic, 1>;
|
||||
using MatrixXT = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||
using Eigen::Map;
|
||||
int32_t denom = n * (n + 1) * (n * 2 + 1) / 3;
|
||||
// twice sum of integer squared
|
||||
VectorXT kernel = VectorXT::LinSpaced(2 * n + 1, -n, n); // 2n+1
|
||||
T *input_data = reinterpret_cast<T *>(const_cast<uchar *>(input->GetBuffer())); // [all_freq,n_fram+2n]
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape{all_freqs, n_frame}, input->type(), output));
|
||||
T *output_data = reinterpret_cast<T *>(const_cast<uchar *>((*output)->GetBuffer()));
|
||||
for (int freq = 0; freq < all_freqs; ++freq) { // conv with im2col
|
||||
auto input_map = Map<MatrixXT, 0, Eigen::OuterStride<1>>(input_data + freq * (n_frame + 2 * n), n_frame,
|
||||
2 * n + 1); // n_frmae,2n+1
|
||||
Map<VectorXT>(output_data + freq * n_frame, n_frame) = (input_map * kernel).array() / T(denom);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ComputeDeltas(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t win_length,
|
||||
const BorderType &mode) {
|
||||
constexpr int min_shape_dim = 2;
|
||||
auto raw_shape = input->shape();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(raw_shape.Size() >= min_shape_dim,
|
||||
"ComputeDeltas: input tensor is not in shape of <..., freq, time>.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
input->type().IsNumeric(),
|
||||
"ComputeDeltas: input tensor type should be int, float or double, but got: " + input->type().ToString());
|
||||
|
||||
// reshape Tensor from <..., freq, time> to <-1, time>
|
||||
int32_t n_frames = raw_shape[-1];
|
||||
int32_t all_freqs = raw_shape.NumOfElements() / n_frames;
|
||||
RETURN_IF_NOT_OK(input->Reshape(TensorShape{all_freqs, n_frames}));
|
||||
|
||||
int32_t n = (win_length - 1) / 2;
|
||||
|
||||
std::shared_ptr<Tensor> specgram_local_pad;
|
||||
if (input->type() == DataType(DataType::DE_FLOAT64)) {
|
||||
RETURN_IF_NOT_OK(Pad<double>(input, &specgram_local_pad, n, n, mode));
|
||||
RETURN_IF_NOT_OK(ComputeDeltasImpl<double>(specgram_local_pad, output, all_freqs, n_frames, n));
|
||||
} else {
|
||||
std::shared_ptr<Tensor> float_tensor;
|
||||
RETURN_IF_NOT_OK(TypeCast(input, &float_tensor, DataType(DataType::DE_FLOAT32)));
|
||||
RETURN_IF_NOT_OK(Pad<float>(float_tensor, &specgram_local_pad, n, n, mode));
|
||||
RETURN_IF_NOT_OK(ComputeDeltasImpl<float>(specgram_local_pad, output, all_freqs, n_frames, n));
|
||||
}
|
||||
RETURN_IF_NOT_OK((*output)->Reshape(raw_shape));
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -1030,6 +1030,14 @@ Status ReadWaveFile(const std::string &wav_file_dir, std::vector<float> *wavefor
|
|||
/// \return Status code.
|
||||
Status SlidingWindowCmn(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t cmn_window,
|
||||
int32_t min_cmn_window, bool center, bool norm_vars);
|
||||
/// \brief Compute delta coefficients of a tensor, usually a spectrogram.
|
||||
/// \param input: Tensor of shape <...,freq,time>.
|
||||
/// \param output: Tensor of shape <...,freq,time>.
|
||||
/// \param win_length: The window length used for computing delta.
|
||||
/// \param mode: Padding mode.
|
||||
/// \return Status code.
|
||||
Status ComputeDeltas(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t win_length,
|
||||
const BorderType &mode);
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/audio/kernels/compute_deltas_op.h"
|
||||
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
ComputeDeltasOp::ComputeDeltasOp(int32_t win_length, BorderType mode) : win_length_(win_length), mode_(mode) {}
|
||||
|
||||
Status ComputeDeltasOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
return ComputeDeltas(input, output, win_length_, mode_);
|
||||
}
|
||||
|
||||
Status ComputeDeltasOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
|
||||
RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
|
||||
if (!inputs[0].IsNumeric()) {
|
||||
RETURN_STATUS_UNEXPECTED("ComputeDeltas: input tensor type should be int, float or double, but got: " +
|
||||
inputs[0].ToString());
|
||||
} else if (inputs[0] == DataType(DataType::DE_FLOAT64)) {
|
||||
outputs[0] = DataType(DataType::DE_FLOAT64);
|
||||
} else {
|
||||
outputs[0] = DataType(DataType::DE_FLOAT32);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_COMPUTE_DELTAS_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_COMPUTE_DELTAS_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class ComputeDeltasOp : public TensorOp {
|
||||
public:
|
||||
explicit ComputeDeltasOp(int32_t win_length = 5, BorderType mode = BorderType::kEdge);
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
|
||||
|
||||
std::string Name() const override { return kComputeDeltasOp; };
|
||||
|
||||
~ComputeDeltasOp() = default;
|
||||
|
||||
private:
|
||||
int32_t win_length_;
|
||||
BorderType mode_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_COMPUTE_DELTAS_OP_H_
|
|
@ -236,6 +236,28 @@ class ComplexNorm final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief ComputeDeltas Transform.
|
||||
/// \note Compute delta coefficients of a spectrogram.
|
||||
class ComputeDeltas final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Construct a new Compute Deltas object.
|
||||
/// \param[in] win_length The window length used for computing delta, must be no less than 3 (Default: 5).
|
||||
/// \param[in] pad_mode Mode parameter passed to padding (Default: BorderType::kEdge).
|
||||
explicit ComputeDeltas(int32_t win_length = 5, BorderType pad_mode = BorderType::kEdge);
|
||||
|
||||
/// \brief Destructor.
|
||||
~ComputeDeltas() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
||||
private:
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Apply contrast effect.
|
||||
class Contrast final : public TensorTransform {
|
||||
public:
|
||||
|
|
|
@ -153,6 +153,7 @@ constexpr char kBandrejectBiquadOp[] = "BandrejectBiquadOp";
|
|||
constexpr char kBassBiquadOp[] = "BassBiquadOp";
|
||||
constexpr char kBiquadOp[] = "BiquadOp";
|
||||
constexpr char kComplexNormOp[] = "ComplexNormOp";
|
||||
constexpr char kComputeDeltasOp[] = "ComputeDeltasOp";
|
||||
constexpr char kContrastOp[] = "ContrastOp";
|
||||
constexpr char kDBToAmplitudeOp[] = " DBToAmplitudeOp";
|
||||
constexpr char kDCShiftOp[] = "DCShiftOp";
|
||||
|
|
|
@ -23,13 +23,13 @@ import numpy as np
|
|||
|
||||
import mindspore._c_dataengine as cde
|
||||
from ..transforms.c_transforms import TensorOperation
|
||||
from .utils import FadeShape, GainType, Interpolation, Modulation, ScaleType
|
||||
from .utils import BorderType, FadeShape, GainType, Interpolation, Modulation, ScaleType
|
||||
from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
|
||||
check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_contrast, \
|
||||
check_db_to_amplitude, check_dc_shift, check_deemph_biquad, check_detect_pitch_frequency, check_equalizer_biquad, \
|
||||
check_fade, check_flanger, check_highpass_biquad, check_lfilter, check_lowpass_biquad, check_magphase, \
|
||||
check_masking, check_mu_law_coding, check_overdrive, check_phaser, check_riaa_biquad, check_sliding_window_cmn, \
|
||||
check_time_stretch, check_treble_biquad, check_vol
|
||||
check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_compute_deltas, \
|
||||
check_contrast, check_db_to_amplitude, check_dc_shift, check_deemph_biquad, check_detect_pitch_frequency, \
|
||||
check_equalizer_biquad, check_fade, check_flanger, check_highpass_biquad, check_lfilter, check_lowpass_biquad, \
|
||||
check_magphase, check_masking, check_mu_law_coding, check_overdrive, check_phaser, check_riaa_biquad, \
|
||||
check_sliding_window_cmn, check_time_stretch, check_treble_biquad, check_vol
|
||||
|
||||
|
||||
class AudioTensorOperation(TensorOperation):
|
||||
|
@ -305,6 +305,51 @@ class ComplexNorm(AudioTensorOperation):
|
|||
return cde.ComplexNormOperation(self.power)
|
||||
|
||||
|
||||
DE_C_BORDER_TYPE = {
|
||||
BorderType.CONSTANT: cde.BorderType.DE_BORDER_CONSTANT,
|
||||
BorderType.EDGE: cde.BorderType.DE_BORDER_EDGE,
|
||||
BorderType.REFLECT: cde.BorderType.DE_BORDER_REFLECT,
|
||||
BorderType.SYMMETRIC: cde.BorderType.DE_BORDER_SYMMETRIC,
|
||||
}
|
||||
|
||||
|
||||
class ComputeDeltas(AudioTensorOperation):
|
||||
"""
|
||||
Compute delta coefficients of a spectrogram.
|
||||
|
||||
Args:
|
||||
win_length (int): The window length used for computing delta, must be no less than 3 (default=5).
|
||||
mode (BorderType): Mode parameter passed to padding (default=BorderType.EDGE).It can be any of
|
||||
[BorderType.CONSTANT, BorderType.EDGE, BorderType.REFLECT, BordBorderTypeer.SYMMETRIC].
|
||||
|
||||
- BorderType.CONSTANT, means it fills the border with constant values.
|
||||
|
||||
- BorderType.EDGE, means it pads with the last value on the edge.
|
||||
|
||||
- BorderType.REFLECT, means it reflects the values on the edge omitting the last
|
||||
value of edge.
|
||||
|
||||
- BorderType.SYMMETRIC, means it reflects the values on the edge repeating the last
|
||||
value of edge.
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.random.random([1, 400//2+1, 30])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.ComputeDeltas(win_length=7, pad_mode = BorderType.EDGE)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
"""
|
||||
|
||||
@check_compute_deltas
|
||||
def __init__(self, win_length=5, pad_mode=BorderType.EDGE):
|
||||
self.win_len = win_length
|
||||
self.pad_mode = pad_mode
|
||||
|
||||
def parse(self):
|
||||
return cde.ComputeDeltasOperation(self.win_len, DE_C_BORDER_TYPE[self.pad_mode])
|
||||
|
||||
|
||||
class Contrast(AudioTensorOperation):
|
||||
"""
|
||||
Apply contrast effect. Similar to SoX implementation.
|
||||
|
|
|
@ -141,3 +141,23 @@ def CreateDct(n_mfcc, n_mels, norm=NormMode.NONE):
|
|||
if n_mels <= 0:
|
||||
raise ValueError("n_mels must be greater than 0, but got {0}.".format(n_mels))
|
||||
return cde.CreateDct(n_mfcc, n_mels, DE_C_NORMMODE_TYPE[norm]).as_array()
|
||||
|
||||
|
||||
class BorderType(str, Enum):
|
||||
"""
|
||||
Padding Mode, BorderType Type.
|
||||
|
||||
Possible enumeration values are: BorderType.CONSTANT, BorderType.EDGE, BorderType.REFLECT, BorderType.SYMMETRIC.
|
||||
|
||||
- BorderType.CONSTANT: means it fills the border with constant values.
|
||||
- BorderType.EDGE: means it pads with the last value on the edge.
|
||||
- BorderType.REFLECT: means it reflects the values on the edge omitting the last value of edge.
|
||||
- BorderType.SYMMETRIC: means it reflects the values on the edge repeating the last value of edge.
|
||||
|
||||
Note: This class derived from class str to support json serializable.
|
||||
"""
|
||||
CONSTANT: str = "constant"
|
||||
EDGE: str = "edge"
|
||||
REFLECT: str = "reflect"
|
||||
SYMMETRIC: str = "symmetric"
|
||||
|
|
@ -20,8 +20,8 @@ from functools import wraps
|
|||
|
||||
from mindspore.dataset.core.validator_helpers import check_float32, check_float32_not_zero, check_int32,\
|
||||
check_int32_not_zero, check_list_same_size, check_non_negative_float32, check_non_negative_int32, \
|
||||
check_pos_float32, check_pos_int32, check_value, parse_user_args, type_check
|
||||
from .utils import FadeShape, GainType, Interpolation, Modulation, ScaleType
|
||||
check_pos_float32, check_pos_int32, check_value, INT32_MAX, parse_user_args, type_check
|
||||
from .utils import BorderType, FadeShape, GainType, Interpolation, Modulation, ScaleType
|
||||
|
||||
|
||||
def check_amplitude_to_db(method):
|
||||
|
@ -576,3 +576,16 @@ def check_sliding_window_cmn(method):
|
|||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_compute_deltas(method):
|
||||
"""Wrapper method to check the parameter of ComputeDeltas."""
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[win_length, pad_mode], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(pad_mode, (BorderType,), "pad_mode")
|
||||
type_check(win_length, (int,), "win_length")
|
||||
check_value(win_length, (3, INT32_MAX), "win_length")
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
|
@ -1957,3 +1957,66 @@ TEST_F(MindDataTestPipeline, TestDBToAmplitudePipeline) {
|
|||
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: ComputeDeltas
|
||||
/// Description: test basic function of ComputeDeltas
|
||||
/// Expectation: get correct number of data
|
||||
TEST_F(MindDataTestPipeline, TestComputeDeltas) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestComputeDeltas.";
|
||||
// Original waveform
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {2, 200}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(50, schema);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
ds = ds->SetNumWorkers(4);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto compute_deltas_op = audio::ComputeDeltas();
|
||||
|
||||
ds = ds->Map({compute_deltas_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
std::vector<int64_t> expected = {2, 200};
|
||||
|
||||
int i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto col = row["inputData"];
|
||||
ASSERT_EQ(col.Shape(), expected);
|
||||
ASSERT_EQ(col.Shape().size(), 2);
|
||||
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
EXPECT_EQ(i, 50);
|
||||
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: ComputeDeltas
|
||||
/// Description: test wrong input args of ComputeDeltas
|
||||
/// Expectation: get nullptr of iterator
|
||||
TEST_F(MindDataTestPipeline, TestComputeDeltasWrongArgs) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestComputeDeltasWrongArgs.";
|
||||
// Original waveform
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {2, 200}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(50, schema);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
ds = ds->SetNumWorkers(4);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
auto compute_deltas_op = audio::ComputeDeltas(2, mindspore::dataset::BorderType::kEdge);
|
||||
ds = ds->Map({compute_deltas_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
|
@ -177,6 +177,57 @@ TEST_F(MindDataTestExecute, TestComposeTransforms) {
|
|||
EXPECT_EQ(30, image.Shape()[1]);
|
||||
}
|
||||
|
||||
/// Feature: ComputeDeltas
|
||||
/// Description: test basic function of ComputeDeltas
|
||||
/// Expectation: get correct number of data
|
||||
TEST_F(MindDataTestExecute, TestComputeDeltas) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestComputeDeltas.";
|
||||
std::shared_ptr<Tensor> input_tensor_;
|
||||
|
||||
int win_length = 5;
|
||||
|
||||
// create tensor
|
||||
TensorShape s = TensorShape({2, 15, 7});
|
||||
// init input vec
|
||||
std::vector<float> input_vec(s.NumOfElements());
|
||||
for (int ind = 0; ind < input_vec.size(); ind++) {
|
||||
input_vec[ind] = std::rand() % (1000) / (1000.0f);
|
||||
}
|
||||
ASSERT_OK(Tensor::CreateFromVector(input_vec, s, &input_tensor_));
|
||||
auto input_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input_tensor_));
|
||||
std::shared_ptr<TensorTransform> compute_deltas_op = std::make_shared<audio::ComputeDeltas>(win_length);
|
||||
|
||||
// apply compute_deltas
|
||||
mindspore::dataset::Execute Transform({compute_deltas_op});
|
||||
Status status = Transform(input_ms, &input_ms);
|
||||
EXPECT_TRUE(status.IsOk());
|
||||
}
|
||||
|
||||
/// Feature: ComputeDeltas
|
||||
/// Description: test wrong input args of ComputeDeltas
|
||||
/// Expectation: get nullptr of iterator
|
||||
TEST_F(MindDataTestExecute, TestComputeDeltasWrongArgs) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestComputeDeltasWrongArgs.";
|
||||
std::shared_ptr<Tensor> input_tensor_;
|
||||
// win_length is less than minimum of 3
|
||||
int win_length = 2;
|
||||
|
||||
// create tensor
|
||||
TensorShape s = TensorShape({2, 15, 7});
|
||||
// init input vec
|
||||
std::vector<float> input_vec(s.NumOfElements());
|
||||
for (int ind = 0; ind < input_vec.size(); ind++) {
|
||||
input_vec[ind] = std::rand() % (1000) / (1000.0f);
|
||||
}
|
||||
ASSERT_OK(Tensor::CreateFromVector(input_vec, s, &input_tensor_));
|
||||
auto input_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input_tensor_));
|
||||
|
||||
std::shared_ptr<TensorTransform> compute_deltas_op = std::make_shared<audio::ComputeDeltas>(win_length);
|
||||
mindspore::dataset::Execute Transform({compute_deltas_op});
|
||||
Status status = Transform(input_ms, &input_ms);
|
||||
EXPECT_FALSE(status.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestExecute, TestCrop) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestCrop.";
|
||||
|
||||
|
@ -937,12 +988,9 @@ TEST_F(MindDataTestExecute, TestOverdriveBasicWithEager) {
|
|||
/// Expectation: throw exception correctly
|
||||
TEST_F(MindDataTestExecute, TestOverdriveWrongArgWithEager) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestOverdriveWrongArgWithEager";
|
||||
std::vector<double> labels = {
|
||||
0.271, 1.634, 9.246, 0.108,
|
||||
1.138, 1.156, 3.394, 1.55,
|
||||
3.614, 1.8402, 0.718, 4.599,
|
||||
5.64, 2.510620117187500000e-02, 1.38, 5.825,
|
||||
4.1906, 5.28, 1.052, 9.36};
|
||||
std::vector<double> labels = {0.271, 1.634, 9.246, 0.108, 1.138, 1.156, 3.394,
|
||||
1.55, 3.614, 1.8402, 0.718, 4.599, 5.64, 2.510620117187500000e-02,
|
||||
1.38, 5.825, 4.1906, 5.28, 1.052, 9.36};
|
||||
std::shared_ptr<Tensor> input;
|
||||
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({4, 5}), &input));
|
||||
|
||||
|
|
|
@ -0,0 +1,137 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing ComputeDeltas op in DE
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.audio.transforms as c_audio
|
||||
from mindspore import log as logger
|
||||
from mindspore.dataset.audio.utils import BorderType
|
||||
|
||||
CHANNEL = 1
|
||||
FREQ = 20
|
||||
TIME = 15
|
||||
|
||||
|
||||
def gen(shape):
|
||||
np.random.seed(0)
|
||||
data = np.random.random(shape)
|
||||
yield (np.array(data, dtype=np.float32),)
|
||||
|
||||
|
||||
def count_unequal_element(data_expected, data_me, rtol, atol):
|
||||
""" Precision calculation func """
|
||||
assert data_expected.shape == data_me.shape
|
||||
total_count = len(data_expected.flatten())
|
||||
error = np.abs(data_expected - data_me)
|
||||
greater = np.greater(error, atol + np.abs(data_expected) * rtol)
|
||||
loss_count = np.count_nonzero(greater)
|
||||
assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
|
||||
data_expected[greater], data_me[greater], error[greater])
|
||||
|
||||
|
||||
def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
|
||||
""" Precision calculation formula """
|
||||
if np.any(np.isnan(data_expected)):
|
||||
assert np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan)
|
||||
elif not np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan):
|
||||
count_unequal_element(data_expected, data_me, rtol, atol)
|
||||
|
||||
|
||||
def test_compute_deltas_eager():
|
||||
"""
|
||||
Feature: test the basic function in eager mode.
|
||||
Description: mindspore eager mode normal testcase:compute_deltas op.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
|
||||
logger.info("check compute_deltas op output")
|
||||
ndarr_in = np.array([[[0.08746047, -0.33246294, 0.5240939, 0.6064913, -0.70366],
|
||||
[1.1420338, 0.50532603, 0.73435473, -0.83435977, -1.0607501],
|
||||
[-1.4298731, -0.86117035, -0.7773941, -0.60023546, 1.1807907],
|
||||
[0.4973711, 0.5299286, 0.818514, 0.7559297, -0.3418539],
|
||||
[-0.2824797, 0.30402678, 0.7848569, -0.4135576, 0.19522846],
|
||||
[-0.11636204, -0.4780833, 1.2691815, 0.9824286, 0.029275],
|
||||
[-1.2611166, -1.1957082, 0.26212585, 0.35354254, 0.3609486]]]).astype(np.float32)
|
||||
|
||||
out_expect = np.array([[[0.0453, 0.1475, -0.0643, -0.1970, -0.3766],
|
||||
[-0.1452, -0.4360, -0.5745, -0.4927, -0.3817],
|
||||
[0.1874, 0.2312, 0.5482, 0.6042, 0.5697],
|
||||
[0.0675, 0.0838, -0.1452, -0.2904, -0.3419],
|
||||
[0.2721, 0.0805, 0.0238, -0.0807, -0.0570],
|
||||
[0.2409, 0.3583, 0.1752, -0.0225, -0.3433],
|
||||
[0.3112, 0.4753, 0.4793, 0.3212, 0.0205]]]).astype(np.float32)
|
||||
|
||||
compute_deltas_op = c_audio.ComputeDeltas()
|
||||
out_mindspore = compute_deltas_op(ndarr_in)
|
||||
|
||||
allclose_nparray(out_mindspore, out_expect, 0.0001, 0.0001)
|
||||
|
||||
|
||||
def test_compute_deltas_pipeline():
|
||||
"""
|
||||
Feature: test the basic function in pipeline mode.
|
||||
Description: mindspore pipeline mode normal testcase:compute_deltas op.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
|
||||
logger.info("test ComputeDeltas op with default value")
|
||||
generator = gen([CHANNEL, FREQ, TIME])
|
||||
|
||||
data1 = ds.GeneratorDataset(
|
||||
source=generator, column_names=["multi_dimensional_data"]
|
||||
)
|
||||
|
||||
transforms = [c_audio.ComputeDeltas()]
|
||||
data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
|
||||
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
out_put = item["multi_dimensional_data"]
|
||||
assert out_put.shape == (CHANNEL, FREQ, TIME)
|
||||
|
||||
|
||||
def test_compute_deltas_invalid_input():
|
||||
"""
|
||||
Feature: test the validate function with invalid parameters.
|
||||
Description: mindspore invalid parameters testcase:compute_deltas op.
|
||||
Expectation: compile done without error.
|
||||
"""
|
||||
def test_invalid_input(test_name, win_length, pad_mode, error, error_msg):
|
||||
logger.info("Test ComputeDeltas with bad input: {0}".format(test_name))
|
||||
with pytest.raises(error) as error_info:
|
||||
c_audio.ComputeDeltas(win_length=win_length, pad_mode=pad_mode)
|
||||
assert error_msg in str(error_info.value)
|
||||
|
||||
test_invalid_input(
|
||||
"invalid win_length parameter value", "test", BorderType.EDGE, TypeError,
|
||||
"Argument win_length with value test is not of type [<class 'int'>], but got <class 'str'>.",
|
||||
)
|
||||
test_invalid_input(
|
||||
"invalid win_length parameter value", 2, BorderType.EDGE, ValueError,
|
||||
"Input win_length is not within the required interval of [3, 2147483647]",
|
||||
)
|
||||
test_invalid_input(
|
||||
"invalid pad_mode parameter value", 5, 2, TypeError,
|
||||
"Argument pad_mode with value 2 is not of type [<enum 'BorderType'>], but got <class 'int'>.",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_compute_deltas_eager()
|
||||
test_compute_deltas_pipeline()
|
||||
test_compute_deltas_invalid_input()
|
Loading…
Reference in New Issue