forked from mindspore-Ecosystem/mindspore
[feat] [assistant] [I3CKEH] add new audio operator SlidingWindowCmn
This commit is contained in:
parent
8bf903ba19
commit
706f9e2cbb
|
@ -43,6 +43,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/overdrive_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/phaser_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/riaa_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/sliding_window_cmn_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/treble_biquad_ir.h"
|
||||
|
@ -496,6 +497,24 @@ std::shared_ptr<TensorOperation> RiaaBiquad::Parse() {
|
|||
return std::make_shared<RiaaBiquadOperation>(data_->sample_rate_);
|
||||
}
|
||||
|
||||
// SlidingWindowCmn Transform Operation.
|
||||
struct SlidingWindowCmn::Data {
  /// \brief Capture the arguments handed to the SlidingWindowCmn constructor.
  /// \param[in] cmn_window Window in frames for the running average CMN computation.
  /// \param[in] min_cmn_window Minimum CMN window.
  /// \param[in] center Whether the window is centered on the current frame.
  /// \param[in] norm_vars Whether the variance is normalized to one.
  Data(int32_t cmn_window, int32_t min_cmn_window, bool center, bool norm_vars)
      : cmn_window_{cmn_window}, min_cmn_window_{min_cmn_window}, center_{center}, norm_vars_{norm_vars} {}

  int32_t cmn_window_;      // window length in frames
  int32_t min_cmn_window_;  // minimum window length in frames
  bool center_;             // centered window when true
  bool norm_vars_;          // variance normalization when true
};
|
||||
|
||||
SlidingWindowCmn::SlidingWindowCmn(int32_t cmn_window, int32_t min_cmn_window, bool center, bool norm_vars)
|
||||
: data_(std::make_shared<Data>(cmn_window, min_cmn_window, center, norm_vars)) {}
|
||||
|
||||
std::shared_ptr<TensorOperation> SlidingWindowCmn::Parse() {
|
||||
return std::make_shared<SlidingWindowCmnOperation>(data_->cmn_window_, data_->min_cmn_window_, data_->center_,
|
||||
data_->norm_vars_);
|
||||
}
|
||||
|
||||
// TimeMasking Transform Operation.
|
||||
struct TimeMasking::Data {
|
||||
Data(bool iid_masks, int32_t time_mask_param, int32_t mask_start, float mask_value)
|
||||
|
|
|
@ -47,6 +47,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/overdrive_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/phaser_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/riaa_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/sliding_window_cmn_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/treble_biquad_ir.h"
|
||||
|
@ -385,6 +386,17 @@ PYBIND_REGISTER(
|
|||
}));
|
||||
}));
|
||||
|
||||
// Python binding for audio::SlidingWindowCmnOperation. The parameters are validated inside the
// py::init factory, before the object is returned to Python.
PYBIND_REGISTER(SlidingWindowCmnOperation, 1, ([](const py::module *m) {
                  using OperationT = audio::SlidingWindowCmnOperation;
                  (void)py::class_<OperationT, TensorOperation, std::shared_ptr<OperationT>>(
                    *m, "SlidingWindowCmnOperation")
                    .def(py::init([](int32_t cmn_window, int32_t min_cmn_window, bool center, bool norm_vars) {
                      auto operation = std::make_shared<OperationT>(cmn_window, min_cmn_window, center, norm_vars);
                      THROW_IF_ERROR(operation->ValidateParams());
                      return operation;
                    }));
                }));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
TimeMaskingOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<audio::TimeMaskingOperation, TensorOperation, std::shared_ptr<audio::TimeMaskingOperation>>(
|
||||
|
|
|
@ -29,6 +29,7 @@ add_library(audio-ir-kernels OBJECT
|
|||
overdrive_ir.cc
|
||||
phaser_ir.cc
|
||||
riaa_biquad_ir.cc
|
||||
sliding_window_cmn_ir.cc
|
||||
time_masking_ir.cc
|
||||
time_stretch_ir.cc
|
||||
treble_biquad_ir.cc
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/audio/ir/kernels/sliding_window_cmn_ir.h"
|
||||
|
||||
#include "minddata/dataset/audio/ir/validators.h"
|
||||
#include "minddata/dataset/audio/kernels/sliding_window_cmn_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
SlidingWindowCmnOperation::SlidingWindowCmnOperation(int32_t cmn_window, int32_t min_cmn_window, bool center,
|
||||
bool norm_vars)
|
||||
: cmn_window_(cmn_window), min_cmn_window_(min_cmn_window), center_(center), norm_vars_(norm_vars) {}
|
||||
|
||||
SlidingWindowCmnOperation::~SlidingWindowCmnOperation() = default;
|
||||
|
||||
// Check that both window sizes are non-negative; returns the first validation failure, if any.
Status SlidingWindowCmnOperation::ValidateParams() {
  const char *const op_name = "SlidingWindowCmn";
  RETURN_IF_NOT_OK(ValidateIntScalarNonNegative(op_name, "cmn_window", cmn_window_));
  RETURN_IF_NOT_OK(ValidateIntScalarNonNegative(op_name, "min_cmn_window", min_cmn_window_));
  return Status::OK();
}
|
||||
|
||||
// Serialize the four operation parameters into *out_json.
// \param[out] out_json Destination JSON object; must not be null.
// \return Status::OK() on success, or an error status when out_json is null.
Status SlidingWindowCmnOperation::to_json(nlohmann::json *out_json) {
  // Guard the out-pointer instead of dereferencing it blindly.
  RETURN_UNEXPECTED_IF_NULL(out_json);
  nlohmann::json args;
  args["cmn_window"] = cmn_window_;
  args["min_cmn_window"] = min_cmn_window_;
  args["center"] = center_;
  args["norm_vars"] = norm_vars_;
  *out_json = args;
  return Status::OK();
}
|
||||
|
||||
std::shared_ptr<TensorOp> SlidingWindowCmnOperation::Build() {
|
||||
std::shared_ptr<SlidingWindowCmnOp> tensor_op =
|
||||
std::make_shared<SlidingWindowCmnOp>(cmn_window_, min_cmn_window_, center_, norm_vars_);
|
||||
return tensor_op;
|
||||
}
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_SLIDING_WINDOW_CMN_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_SLIDING_WINDOW_CMN_IR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
constexpr char kSlidingWindowCmnOperation[] = "SlidingWindowCmn";
|
||||
|
||||
class SlidingWindowCmnOperation : public TensorOperation {
|
||||
public:
|
||||
SlidingWindowCmnOperation(int32_t cmn_window, int32_t min_cmn_window, bool center, bool norm_vars);
|
||||
|
||||
~SlidingWindowCmnOperation();
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
std::string Name() const override { return kSlidingWindowCmnOperation; }
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
int32_t cmn_window_;
|
||||
int32_t min_cmn_window_;
|
||||
bool center_;
|
||||
bool norm_vars_;
|
||||
};
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_SLIDING_WINDOW_CMN_IR_H_
|
|
@ -30,6 +30,7 @@ add_library(audio-kernels OBJECT
|
|||
overdrive_op.cc
|
||||
phaser_op.cc
|
||||
riaa_biquad_op.cc
|
||||
sliding_window_cmn_op.cc
|
||||
time_masking_op.cc
|
||||
time_stretch_op.cc
|
||||
treble_biquad_op.cc
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
|
||||
#include <Eigen/Dense>
|
||||
#include <fstream>
|
||||
|
||||
#include "mindspore/core/base/float16.h"
|
||||
|
@ -888,5 +889,179 @@ Status ReadWaveFile(const std::string &wav_file_dir, std::vector<float> *wavefor
|
|||
delete header;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Compute the half-open frame range [start, end) of the CMN window used for frame idx.
/// \param[in] cmn_window Window length in frames; must be non-negative.
/// \param[in] min_cmn_window Minimum window end applied when center is false; must be non-negative.
/// \param[in] center If true the window is placed around idx, otherwise it extends to the left of idx.
/// \param[in] idx Index of the current frame.
/// \param[in] num_frames Total number of frames; the window end is clamped to this value.
/// \param[out] cmn_window_start_p Receives the window start (never negative).
/// \param[out] cmn_window_end_p Receives the window end (exclusive).
/// \return Status code.
Status ComputeCmnStartAndEnd(int32_t cmn_window, int32_t min_cmn_window, bool center, int32_t idx, int32_t num_frames,
                             int32_t *cmn_window_start_p, int32_t *cmn_window_end_p) {
  RETURN_UNEXPECTED_IF_NULL(cmn_window_start_p);
  RETURN_UNEXPECTED_IF_NULL(cmn_window_end_p);
  CHECK_FAIL_RETURN_UNEXPECTED(
    cmn_window >= 0, "SlidingWindowCmn: cmn_window must be non negative, but got: " + std::to_string(cmn_window));
  CHECK_FAIL_RETURN_UNEXPECTED(min_cmn_window >= 0, "SlidingWindowCmn: min_cmn_window must be non negative, but got: " +
                                                      std::to_string(min_cmn_window));
  constexpr int kHalfDivisor = 2;
  // Initial placement: centered around idx, or trailing so that idx is the last frame of the window.
  int32_t window_begin = center ? idx - cmn_window / kHalfDivisor : idx - cmn_window;
  int32_t window_finish = center ? window_begin + cmn_window : idx + 1;

  // A window that starts before frame 0 is shifted right by the overshoot.
  if (window_begin < 0) {
    window_finish -= window_begin;
    window_begin = 0;
  }
  // For the left-sided window, the end is replaced by max(idx + 1, min_cmn_window).
  if (!center && window_finish > idx) {
    window_finish = std::max(idx + 1, min_cmn_window);
  }
  // A window that runs past the last frame is shifted left, then clamped at frame 0.
  if (window_finish > num_frames) {
    window_begin -= (window_finish - num_frames);
    window_finish = num_frames;
    if (window_begin < 0) {
      window_begin = 0;
    }
  }

  *cmn_window_start_p = window_begin;
  *cmn_window_end_p = window_finish;
  return Status::OK();
}
|
||||
|
||||
template <typename T>
|
||||
Status ComputeCmnWaveform(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *cmn_waveform_p,
|
||||
int32_t num_channels, int32_t num_frames, int32_t num_feats, int32_t cmn_window,
|
||||
int32_t min_cmn_window, bool center, bool norm_vars) {
|
||||
using ArrayXT = Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||
constexpr int square_num = 2;
|
||||
int32_t last_window_start = -1, last_window_end = -1;
|
||||
ArrayXT cur_sum = ArrayXT(num_channels, num_feats);
|
||||
ArrayXT cur_sum_sq;
|
||||
if (norm_vars) {
|
||||
cur_sum_sq = ArrayXT(num_channels, num_feats);
|
||||
}
|
||||
for (int i = 0; i < num_frames; ++i) {
|
||||
int32_t cmn_window_start = 0, cmn_window_end = 0;
|
||||
RETURN_IF_NOT_OK(
|
||||
ComputeCmnStartAndEnd(cmn_window, min_cmn_window, center, i, num_frames, &cmn_window_start, &cmn_window_end));
|
||||
int32_t row = cmn_window_end - cmn_window_start * 2;
|
||||
int32_t cmn_window_frames = cmn_window_end - cmn_window_start;
|
||||
for (int32_t m = 0; m < num_channels; ++m) {
|
||||
if (last_window_start == -1) {
|
||||
auto it = reinterpret_cast<T *>(const_cast<uchar *>(input->GetBuffer()));
|
||||
it += (m * num_frames * num_feats + cmn_window_start * num_feats);
|
||||
auto tmp_map = Eigen::Map<ArrayXT>(it, row, num_feats);
|
||||
if (i > 0) {
|
||||
cur_sum.row(m) += tmp_map.colwise().sum();
|
||||
if (norm_vars) {
|
||||
cur_sum_sq.row(m) += tmp_map.pow(square_num).colwise().sum();
|
||||
}
|
||||
} else {
|
||||
cur_sum.row(m) = tmp_map.colwise().sum();
|
||||
if (norm_vars) {
|
||||
cur_sum_sq.row(m) = tmp_map.pow(square_num).colwise().sum();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (cmn_window_start > last_window_start) {
|
||||
auto it = reinterpret_cast<T *>(const_cast<uchar *>(input->GetBuffer()));
|
||||
it += (m * num_frames * num_feats + last_window_start * num_feats);
|
||||
auto tmp_map = Eigen::Map<ArrayXT>(it, 1, num_feats);
|
||||
cur_sum.row(m) -= tmp_map;
|
||||
if (norm_vars) {
|
||||
cur_sum_sq.row(m) -= tmp_map.pow(square_num);
|
||||
}
|
||||
}
|
||||
if (cmn_window_end > last_window_end) {
|
||||
auto it = reinterpret_cast<T *>(const_cast<uchar *>(input->GetBuffer()));
|
||||
it += (m * num_frames * num_feats + last_window_end * num_feats);
|
||||
auto tmp_map = Eigen::Map<ArrayXT>(it, 1, num_feats);
|
||||
cur_sum.row(m) += tmp_map;
|
||||
if (norm_vars) {
|
||||
cur_sum_sq.row(m) += tmp_map.pow(square_num);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto it = reinterpret_cast<T *>(const_cast<uchar *>(input->GetBuffer()));
|
||||
auto cmn_it = reinterpret_cast<T *>(const_cast<uchar *>((*cmn_waveform_p)->GetBuffer()));
|
||||
it += (m * num_frames * num_feats + i * num_feats);
|
||||
cmn_it += (m * num_frames * num_feats + i * num_feats);
|
||||
Eigen::Map<ArrayXT>(cmn_it, 1, num_feats) =
|
||||
Eigen::Map<ArrayXT>(it, 1, num_feats) - cur_sum.row(m) / cmn_window_frames;
|
||||
if (norm_vars) {
|
||||
if (cmn_window_frames == 1) {
|
||||
auto cmn_it_1 = reinterpret_cast<T *>(const_cast<uchar *>((*cmn_waveform_p)->GetBuffer()));
|
||||
cmn_it_1 += (m * num_frames * num_feats + i * num_feats);
|
||||
Eigen::Map<ArrayXT>(cmn_it_1, 1, num_feats).setZero();
|
||||
} else {
|
||||
auto variance = (Eigen::Map<ArrayXT>(cur_sum_sq.data(), num_channels, num_feats) / cmn_window_frames) -
|
||||
(cur_sum.pow(2) / std::pow(cmn_window_frames, 2));
|
||||
auto cmn_it_2 = reinterpret_cast<T *>(const_cast<uchar *>((*cmn_waveform_p)->GetBuffer()));
|
||||
cmn_it_2 += (m * num_frames * num_feats + i * num_feats);
|
||||
Eigen::Map<ArrayXT>(cmn_it_2, 1, num_feats) =
|
||||
Eigen::Map<ArrayXT>(cmn_it_2, 1, num_feats) * (1 / variance.sqrt()).row(m);
|
||||
}
|
||||
}
|
||||
}
|
||||
last_window_start = cmn_window_start;
|
||||
last_window_end = cmn_window_end;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status SlidingWindowCmnHelper(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t cmn_window,
|
||||
int32_t min_cmn_window, bool center, bool norm_vars) {
|
||||
int32_t num_frames = input->shape()[Tensor::HandleNeg(-2, input->shape().Size())];
|
||||
int32_t num_feats = input->shape()[Tensor::HandleNeg(-1, input->shape().Size())];
|
||||
|
||||
int32_t first_index = 1;
|
||||
std::vector<dsize_t> input_shape = input->shape().AsVector();
|
||||
std::for_each(input_shape.begin(), input_shape.end(), [&first_index](const dsize_t &item) { first_index *= item; });
|
||||
RETURN_IF_NOT_OK(
|
||||
input->Reshape(TensorShape({static_cast<int>(first_index / (num_frames * num_feats)), num_frames, num_feats})));
|
||||
|
||||
int32_t num_channels = static_cast<int32_t>(input->shape()[0]);
|
||||
TensorPtr cmn_waveform;
|
||||
RETURN_IF_NOT_OK(
|
||||
Tensor::CreateEmpty(TensorShape({num_channels, num_frames, num_feats}), input->type(), &cmn_waveform));
|
||||
RETURN_IF_NOT_OK(ComputeCmnWaveform<T>(input, &cmn_waveform, num_channels, num_frames, num_feats, cmn_window,
|
||||
min_cmn_window, center, norm_vars));
|
||||
|
||||
std::vector<dsize_t> re_shape = input_shape;
|
||||
auto r_it = re_shape.rbegin();
|
||||
*r_it++ = num_feats;
|
||||
*r_it = num_frames;
|
||||
RETURN_IF_NOT_OK(cmn_waveform->Reshape(TensorShape(re_shape)));
|
||||
|
||||
constexpr int specify_input_shape = 2;
|
||||
constexpr int specify_first_shape = 1;
|
||||
if (input_shape.size() == specify_input_shape && cmn_waveform->shape()[0] == specify_first_shape) {
|
||||
cmn_waveform->Squeeze();
|
||||
}
|
||||
*output = cmn_waveform;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Entry point of the SlidingWindowCmn kernel: checks the rank, then dispatches on the element type.
// Float64 input is processed as double; every other numeric type is first cast to float32.
Status SlidingWindowCmn(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t cmn_window,
                        int32_t min_cmn_window, bool center, bool norm_vars) {
  CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= kMinAudioRank,
                               "SlidingWindowCmn: input tensor is not in shape of <..., freq, time>.");

  const bool is_double = (input->type().value() == DataType::DE_FLOAT64);
  if (!is_double && !input->type().IsNumeric()) {
    RETURN_STATUS_UNEXPECTED("SlidingWindowCmn: input tensor type should be int, float or double, but got: " +
                             input->type().ToString());
  }
  if (is_double) {
    return SlidingWindowCmnHelper<double>(input, output, cmn_window, min_cmn_window, center, norm_vars);
  }
  std::shared_ptr<Tensor> as_float;
  RETURN_IF_NOT_OK(TypeCast(input, &as_float, DataType(DataType::DE_FLOAT32)));
  return SlidingWindowCmnHelper<float>(as_float, output, cmn_window, min_cmn_window, center, norm_vars);
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
constexpr double PI = 3.141592653589793;
|
||||
constexpr int kMinAudioRank = 2;
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -143,6 +144,7 @@ Status Contrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
|
|||
auto itr_out = out->begin<T>();
|
||||
for (auto itr_in = input->begin<T>(); itr_in != input->end<T>(); itr_in++) {
|
||||
T temp1, temp2 = 0;
|
||||
// PI / 2 is half of the constant PI
|
||||
temp1 = static_cast<T>(*itr_in) * (PI / 2);
|
||||
temp2 = enhancement_amount_value * std::sin(temp1 * 4);
|
||||
*itr_out = std::sin(temp1 + temp2);
|
||||
|
@ -261,10 +263,10 @@ Status LFilter(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *ou
|
|||
m_py[m_num_order] -= a_coeffs[j] * m_py[m_num_order - j];
|
||||
}
|
||||
if (clamp) {
|
||||
if (m_py[m_num_order] > static_cast<T>(1.))
|
||||
out_vect[i] = static_cast<T>(1.);
|
||||
else if (m_py[m_num_order] < static_cast<T>(-1.))
|
||||
out_vect[i] = static_cast<T>(-1.);
|
||||
if (m_py[m_num_order] > static_cast<T>(1))
|
||||
out_vect[i] = static_cast<T>(1);
|
||||
else if (m_py[m_num_order] < static_cast<T>(-1))
|
||||
out_vect[i] = static_cast<T>(-1);
|
||||
else
|
||||
out_vect[i] = m_py[m_num_order];
|
||||
} else {
|
||||
|
@ -386,8 +388,10 @@ Status Overdrive(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
|
|||
T temp_fp2 = temp_fp * gain_ex + color;
|
||||
// 0.5 + 2/3 * 0.75 = 1, zoom and shift the sound.
|
||||
if (temp_fp2 < -1) {
|
||||
// -2.0 / 3.0 is -2/3 in the formula.
|
||||
temp.push_back(-2.0 / 3.0);
|
||||
} else if (temp_fp2 > 1) {
|
||||
// 2.0 / 3.0 is 2/3 in the formula.
|
||||
temp.push_back(2.0 / 3.0);
|
||||
} else {
|
||||
temp.push_back(temp_fp2 - temp_fp2 * temp_fp2 * temp_fp2 / 3.0);
|
||||
|
@ -824,6 +828,7 @@ std::vector<std::vector<T>> FlangerInterpolation(const std::shared_ptr<Tensor> &
|
|||
for (int k = 0; k < n_channels; k++) {
|
||||
delayed_value_c[j][k] = delayed_value_c[j][k] - delayed_value_a[j][k];
|
||||
delayed_value_b[j][k] = delayed_value_b[j][k] - delayed_value_a[j][k];
|
||||
// delayed_value_c[j][k] * 0.5 is half of the delayed_value_c[j][k]
|
||||
frac_delay_coefficient[j][k] = delayed_value_c[j][k] * 0.5 - delayed_value_b[j][k];
|
||||
frac_delay_value[j][k] = delayed_value_b[j][k] * 2 - delayed_value_c[j][k] * 0.5;
|
||||
// the next delay is obtained by delaying the data in the buffer
|
||||
|
@ -1014,6 +1019,17 @@ struct WavHeader {
|
|||
/// \param sample_rate: sample rate.
|
||||
/// \return Status code.
|
||||
Status ReadWaveFile(const std::string &wav_file_dir, std::vector<float> *waveform_vec, int32_t *sample_rate);
|
||||
|
||||
/// \brief Apply sliding-window cepstral mean and variance (optional) normalization per utterance.
|
||||
/// \param input: Tensor of shape <..., freq, time>.
|
||||
/// \param output: Normalized tensor with the same shape as input (a 2-D input whose first dimension is 1 is squeezed).
|
||||
/// \param cmn_window: Window in frames for running average CMN computation.
|
||||
/// \param min_cmn_window: Minimum CMN window used at start of decoding.
|
||||
/// \param center: If true, use a window centered on the current frame. If false, window is to the left.
|
||||
/// \param norm_vars: If true, normalize variance to one.
|
||||
/// \return Status code.
|
||||
Status SlidingWindowCmn(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t cmn_window,
|
||||
int32_t min_cmn_window, bool center, bool norm_vars);
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/audio/kernels/sliding_window_cmn_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Validate the input/output pointers, then delegate to the SlidingWindowCmn kernel with the stored parameters.
Status SlidingWindowCmnOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);
  return SlidingWindowCmn(input, output, cmn_window_, min_cmn_window_, center_, norm_vars_);
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_SLIDING_WINDOW_CMN_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_SLIDING_WINDOW_CMN_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class SlidingWindowCmnOp : public TensorOp {
|
||||
public:
|
||||
/// \brief Constructor of SlidingWindowCmnOp.
|
||||
/// \param[in] cmn_window - The window in frames for running average CMN computation.
|
||||
/// \param[in] min_cmn_window - The minimum CMN window. Only applicable if center is false, ignored if center==true.
|
||||
/// \param[in] center - If true, use a window centered on the current frame. If false, window is to the left.
|
||||
/// \param[in] norm_vars - If true, normalize variance to one.
|
||||
SlidingWindowCmnOp(int32_t cmn_window, int32_t min_cmn_window, bool center, bool norm_vars)
|
||||
: cmn_window_(cmn_window), min_cmn_window_(min_cmn_window), center_(center), norm_vars_(norm_vars) {}
|
||||
|
||||
/// \brief Destructor of SlidingWindowCmnOp.
|
||||
~SlidingWindowCmnOp() override = default;
|
||||
|
||||
/// \brief Perform sliding window CMN to tensor.
|
||||
/// \param[in] input - Input tensor of Op.
|
||||
/// \param[out] output - Output tensor of Op.
|
||||
/// \return Status code.
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
/// \brief Print name of op.
|
||||
std::string Name() const override { return kSlidingWindowCmnOp; }
|
||||
|
||||
private:
|
||||
int32_t cmn_window_; // The window in frames for running average CMN computation.
|
||||
int32_t min_cmn_window_; // The minimum CMN window. Only applicable if center == false, ignored if center==true.
|
||||
bool center_; // If true, use a window centered on the current frame. If false, window is to the left.
|
||||
bool norm_vars_; // If true, normalize variance to one.
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_SLIDING_WINDOW_CMN_OP_H_
|
|
@ -664,6 +664,32 @@ class RiaaBiquad final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.
|
||||
class SlidingWindowCmn final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor of SlidingWindowCmnOp.
|
||||
/// \param[in] cmn_window The window in frames for running average CMN computation (Default: 600).
|
||||
/// \param[in] min_cmn_window The minimum CMN window. Only applicable if center is false, ignored if center
|
||||
/// is true (Default: 100).
|
||||
/// \param[in] center If true, use a window centered on the current frame. If false, window is to the left
|
||||
/// (Default: false).
|
||||
/// \param[in] norm_vars If true, normalize variance to one (Default: false).
|
||||
explicit SlidingWindowCmn(int32_t cmn_window = 600, int32_t min_cmn_window = 100, bool center = false,
|
||||
bool norm_vars = false);
|
||||
|
||||
/// \brief Destructor.
|
||||
~SlidingWindowCmn() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
||||
private:
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief TimeMasking TensorTransform.
|
||||
/// \notes Apply masking to a spectrogram in the time domain.
|
||||
class TimeMasking final : public TensorTransform {
|
||||
|
|
|
@ -171,6 +171,7 @@ constexpr char kMuLawEncodingOp[] = "MuLawEncodingOp";
|
|||
constexpr char kOverdriveOp[] = "OverdriveOp";
|
||||
constexpr char kPhaserOp[] = "PhaserOp";
|
||||
constexpr char kRiaaBiquadOp[] = "RiaaBiquadOp";
|
||||
constexpr char kSlidingWindowCmnOp[] = "SlidingWindowCmnOp";
|
||||
constexpr char kTimeMaskingOp[] = "TimeMaskingOp";
|
||||
constexpr char kTimeStretchOp[] = "TimeStretchOp";
|
||||
constexpr char kTrebleBiquadOp[] = "TrebleBiquadOp";
|
||||
|
|
|
@ -28,8 +28,8 @@ from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_
|
|||
check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_contrast, \
|
||||
check_db_to_amplitude, check_dc_shift, check_deemph_biquad, check_detect_pitch_frequency, check_equalizer_biquad, \
|
||||
check_fade, check_flanger, check_highpass_biquad, check_lfilter, check_lowpass_biquad, check_magphase, \
|
||||
check_masking, check_mu_law_coding, check_overdrive, check_phaser, check_riaa_biquad, check_time_stretch, \
|
||||
check_treble_biquad, check_vol
|
||||
check_masking, check_mu_law_coding, check_overdrive, check_phaser, check_riaa_biquad, check_sliding_window_cmn, \
|
||||
check_time_stretch, check_treble_biquad, check_vol
|
||||
|
||||
|
||||
class AudioTensorOperation(TensorOperation):
|
||||
|
@ -870,6 +870,38 @@ class RiaaBiquad(AudioTensorOperation):
|
|||
return cde.RiaaBiquadOperation(self.sample_rate)
|
||||
|
||||
|
||||
class SlidingWindowCmn(AudioTensorOperation):
    """
    Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

    Args:
        cmn_window (int, optional): Window in frames for running average CMN computation (default=600).
        min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at
            start). Only applicable if center is False, ignored if center is True (default=100).
        center (bool, optional): If True, use a window centered on the current frame. If False, the window
            is to the left (default=False).
        norm_vars (bool, optional): If True, normalize variance to one (default=False).

    Examples:
        >>> import numpy as np
        >>>
        >>> waveform = np.array([[[1, 2, 3], [4, 5, 6]]], dtype=np.float64)
        >>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
        >>> transforms = [audio.SlidingWindowCmn()]
        >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
    """

    @check_sliding_window_cmn
    def __init__(self, cmn_window=600, min_cmn_window=100, center=False, norm_vars=False):
        # Arguments are checked by the decorator before they are stored.
        self.cmn_window, self.min_cmn_window = cmn_window, min_cmn_window
        self.center, self.norm_vars = center, norm_vars

    def parse(self):
        # Build the C++ operation object from the stored arguments.
        operation_args = (self.cmn_window, self.min_cmn_window, self.center, self.norm_vars)
        return cde.SlidingWindowCmnOperation(*operation_args)
|
||||
|
||||
|
||||
class TimeMasking(AudioTensorOperation):
|
||||
"""
|
||||
Apply masking to a spectrogram in the time domain.
|
||||
|
|
|
@ -556,3 +556,23 @@ def check_flanger(method):
|
|||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_sliding_window_cmn(method):
    """Wrapper method to check the parameters of SlidingWindowCmn."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [cmn_window, min_cmn_window, center, norm_vars], _ = parse_user_args(method, *args, **kwargs)

        # Both window sizes must be non-negative int32 values.
        for value, arg_name in ((cmn_window, "cmn_window"), (min_cmn_window, "min_cmn_window")):
            type_check(value, (int,), arg_name)
            check_non_negative_int32(value, arg_name)

        # Both switches must be real booleans.
        for flag, arg_name in ((center, "center"), (norm_vars, "norm_vars")):
            type_check(flag, (bool,), arg_name)
        return method(self, *args, **kwargs)

    return new_method
|
||||
|
|
|
@ -203,6 +203,58 @@ TEST_F(MindDataTestPipeline, TestRiaaBiquadWrongArg) {
|
|||
EXPECT_EQ(iter01, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: SlidingWindowCmn
|
||||
/// Description: test basic function of SlidingWindowCmn
|
||||
/// Expectation: get correct number of data
|
||||
TEST_F(MindDataTestPipeline, TestSlidingWindowCmn) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSlidingWindowCmn.";
|
||||
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
ASSERT_OK(schema->add_column("col1", mindspore::DataType::kNumberTypeFloat32, {1, 2, 400}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(8, schema);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
auto sliding_window_cmn = audio::SlidingWindowCmn(600, 100, false, false);
|
||||
auto ds1 = ds->Map({sliding_window_cmn});
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
EXPECT_EQ(i, 8);
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: SlidingWindowCmn
|
||||
/// Description: test wrong input args of SlidingWindowCmn
|
||||
/// Expectation: get nullptr of iterator
|
||||
TEST_F(MindDataTestPipeline, TestSlidingWindowCmnWrongArgs) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestSlidingWindowCmnWrongArgs.";
|
||||
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
ASSERT_OK(schema->add_column("col1", mindspore::DataType::kNumberTypeFloat32, {1, 2, 400}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(8, schema);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// SlidingWindowCmn: cmn_window must be greater than or equal to 0.
|
||||
auto sliding_window_cmn_1 = audio::SlidingWindowCmn(-1, 100, false, false);
|
||||
auto ds_1 = ds->Map({sliding_window_cmn_1});
|
||||
EXPECT_NE(ds_1, nullptr);
|
||||
std::shared_ptr<Iterator> iter_1 = ds_1->CreateIterator();
|
||||
EXPECT_EQ(iter_1, nullptr);
|
||||
|
||||
// SlidingWindowCmn: min_cmn_window must be greater than or equal to 0.
|
||||
auto sliding_window_cmn_2 = audio::SlidingWindowCmn(600, -1, false, false);
|
||||
auto ds2 = ds->Map({sliding_window_cmn_2});
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
std::shared_ptr<Iterator> iter_2 = ds2->CreateIterator();
|
||||
EXPECT_EQ(iter_2, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestTimeMaskingPipeline) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestTimeMaskingPipeline.";
|
||||
// Original waveform
|
||||
|
|
|
@ -1502,3 +1502,65 @@ TEST_F(MindDataTestExecute, TestDBToAmplitudeWithEager) {
|
|||
Status s01 = Transform01(input_02, &input_02);
|
||||
EXPECT_TRUE(s01.IsOk());
|
||||
}
|
||||
|
||||
/// Feature: SlidingWindowCmn
|
||||
/// Description: test basic function of SlidingWindowCmn
|
||||
/// Expectation: get correct number of data
|
||||
TEST_F(MindDataTestExecute, TestSlidingWindowCmn) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestSlidingWindowCmn.";
|
||||
|
||||
std::shared_ptr<Tensor> input_tensor_;
|
||||
int32_t cmn_window = 500;
|
||||
int32_t min_cmn_window = 50;
|
||||
bool center = false;
|
||||
bool norm_vars = false;
|
||||
|
||||
// create tensor shape
|
||||
TensorShape s = TensorShape({2, 2, 500});
|
||||
// init input vector
|
||||
std::vector<float> input_vec(s.NumOfElements());
|
||||
for (int idx = 0; idx < input_vec.size(); ++idx) {
|
||||
input_vec[idx] = std::rand() % (1000) / (1000.0f);
|
||||
}
|
||||
ASSERT_OK(Tensor::CreateFromVector(input_vec, s, &input_tensor_));
|
||||
auto input_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input_tensor_));
|
||||
std::shared_ptr<TensorTransform> sliding_window_cmn_op =
|
||||
std::make_shared<audio::SlidingWindowCmn>(cmn_window, min_cmn_window, center, norm_vars);
|
||||
|
||||
// apply sliding_window_cmn
|
||||
mindspore::dataset::Execute Transform({sliding_window_cmn_op});
|
||||
Status status = Transform(input_ms, &input_ms);
|
||||
EXPECT_TRUE(status.IsOk());
|
||||
}
|
||||
|
||||
/// Feature: SlidingWindowCmn
|
||||
/// Description: test wrong input args of SlidingWindowCmn
|
||||
/// Expectation: get nullptr of iterator
|
||||
TEST_F(MindDataTestExecute, TestSlidingWindowCmnWrongArgs) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestSlidingWindowCmnWrongArgs.";
|
||||
|
||||
std::shared_ptr<Tensor> input_tensor_;
|
||||
// create tensor shape
|
||||
TensorShape s = TensorShape({2, 2, 500});
|
||||
// init input vector
|
||||
std::vector<float> input_vec(s.NumOfElements());
|
||||
for (int idx = 0; idx < input_vec.size(); ++idx) {
|
||||
input_vec[idx] = std::rand() % (1000) / (1000.0f);
|
||||
}
|
||||
ASSERT_OK(Tensor::CreateFromVector(input_vec, s, &input_tensor_));
|
||||
auto input_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input_tensor_));
|
||||
|
||||
// SlidingWindowCmn: cmn_window must be greater than or equal to 0.
|
||||
std::shared_ptr<TensorTransform> sliding_window_cmn_op_1 =
|
||||
std::make_shared<audio::SlidingWindowCmn>(-1, 100, false, false);
|
||||
mindspore::dataset::Execute Transform_1({sliding_window_cmn_op_1});
|
||||
Status status_1 = Transform_1(input_ms, &input_ms);
|
||||
EXPECT_FALSE(status_1.IsOk());
|
||||
|
||||
// SlidingWindowCmn: min_cmn_window must be greater than or equal to 0.
|
||||
std::shared_ptr<TensorTransform> sliding_window_cmn_op_2 =
|
||||
std::make_shared<audio::SlidingWindowCmn>(500, -1, false, false);
|
||||
mindspore::dataset::Execute Transform_2({sliding_window_cmn_op_2});
|
||||
Status status_2 = Transform_2(input_ms, &input_ms);
|
||||
EXPECT_FALSE(status_2.IsOk());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.audio.transforms as audio
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
def count_unequal_element(data_expected, data_me, rtol, atol):
    """
    Assert that two arrays agree element-wise within a tolerance.

    An element counts as a mismatch when ``|expected - actual|`` exceeds
    ``atol + |expected| * rtol``. The check fails when the fraction of
    mismatching elements is not below ``rtol``.

    Args:
        data_expected (numpy.ndarray): Reference values.
        data_me (numpy.ndarray): Values under test; must have the same shape.
        rtol (float): Relative tolerance, also used as the allowed mismatch ratio.
        atol (float): Absolute tolerance.

    Raises:
        AssertionError: If the shapes differ or too many elements mismatch.
    """
    assert data_expected.shape == data_me.shape
    total_count = data_expected.size
    # Nothing to compare; also avoids dividing by zero in the ratio below.
    if total_count == 0:
        return
    error = np.abs(data_expected - data_me)
    greater = np.greater(error, atol + np.abs(data_expected) * rtol)
    loss_count = np.count_nonzero(greater)
    assert (loss_count / total_count) < rtol, \
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
        format(data_expected[greater], data_me[greater], error[greater])
|
||||
|
||||
|
||||
def test_sliding_window_cmn_eager():
    """
    Feature: SlidingWindowCmn in eager mode.
    Description: call the op directly on three NumPy waveforms, covering
        float64 and float32 inputs and different center/norm_vars settings.
    Expectation: every output matches its expected waveform within tolerance.
    """
    def check(transform, source, expected):
        # Run the op eagerly and compare against the reference values.
        result = transform(source)
        count_unequal_element(expected, result, 0.0001, 0.0001)

    # Case 1: 3-D float64 input, norm_vars disabled.
    source_a = np.array([[[0.0000, 0.1000, 0.2000],
                          [0.3000, 0.4000, 0.5000]],
                         [[0.6000, 0.7000, 0.8000],
                          [0.9000, 1.0000, 1.1000]]], dtype=np.float64)
    expected_a = np.array([[[-0.1500, -0.1500, -0.1500],
                            [0.1500, 0.1500, 0.1500]],
                           [[-0.1500, -0.1500, -0.1500],
                            [0.1500, 0.1500, 0.1500]]], dtype=np.float64)
    check(audio.SlidingWindowCmn(500, 200, False, False), source_a, expected_a)

    # Case 2: 2-D float32 input, norm_vars enabled.
    source_b = np.array([[0.0050, 0.0306, 0.6146, 0.7620, 0.6369],
                         [0.9525, 0.0362, 0.6721, 0.6867, 0.8466]], dtype=np.float32)
    expected_b = np.array([[-1.0000, -1.0000, -1.0000, 1.0000, -1.0000],
                           [1.0000, 1.0000, 1.0000, -1.0000, 1.0000]], dtype=np.float32)
    check(audio.SlidingWindowCmn(600, 100, False, True), source_b, expected_b)

    # Case 3: 3-D float64 input, center and norm_vars both enabled.
    source_c = np.array([[[0.3764, 0.4168, 0.0635, 0.7082, 0.4596, 0.3457, 0.8438, 0.8860, 0.9151, 0.5746,
                           0.6630, 0.0260, 0.2631, 0.7410, 0.5627, 0.6749, 0.7099, 0.1120, 0.4794, 0.2778],
                          [0.4157, 0.2246, 0.2488, 0.2686, 0.0562, 0.4422, 0.9407, 0.0756, 0.5737, 0.7501,
                           0.3122, 0.7982, 0.3034, 0.1880, 0.2298, 0.0961, 0.7439, 0.9947, 0.8156, 0.2907]]],
                        dtype=np.float64)
    expected_c = np.array([[[-1.0000, 1.0000, -1.0000, 1.0000, 1.0000, -1.0000, -1.0000, 1.0000,
                             1.0000, -1.0000, 1.0000, -1.0000, -1.0000, 1.0000, 1.0000, 1.0000,
                             -1.0000, -1.0000, -1.0000, -1.0000],
                            [1.0000, -1.0000, 1.0000, -1.0000, -1.0000, 1.0000, 1.0000, -1.0000,
                             -1.0000, 1.0000, -1.0000, 1.0000, 1.0000, -1.0000, -1.0000, -1.0000,
                             1.0000, 1.0000, 1.0000, 1.0000]]], dtype=np.float64)
    check(audio.SlidingWindowCmn(3, 0, True, True), source_c, expected_c)
|
||||
|
||||
|
||||
def test_sliding_window_cmn_pipeline():
    """
    Feature: SlidingWindowCmn in pipeline mode.
    Description: map the op over a NumpySlicesDataset built from one waveform.
    Expectation: every produced row matches the expected waveform within
        tolerance, and the pipeline produces one row per expected entry.
    """
    # Original waveform
    waveform = np.array([[[3.2, 2.1, 1.3], [6.2, 5.3, 6]]], dtype=np.float64)
    # Expect waveform
    expect_waveform = np.array([[[-1.0000, -1.0000, -1.0000],
                                 [1.0000, 1.0000, 1.0000]]], dtype=np.float64)
    dataset = ds.NumpySlicesDataset(waveform, ["audio"], shuffle=False)
    sliding_window_cmn_op = audio.SlidingWindowCmn(600, 100, False, True)
    # Filtered waveform by sliding_window_cmn
    dataset = dataset.map(input_columns=["audio"], operations=sliding_window_cmn_op, num_parallel_workers=8)
    i = 0
    for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        count_unequal_element(expect_waveform[i, :],
                              item['audio'], 0.0001, 0.0001)
        i += 1
    # Guard against a vacuous pass when the pipeline yields no rows.
    assert i == len(expect_waveform)
|
||||
|
||||
|
||||
def test_sliding_window_cmn_invalid_input():
    """
    Feature: SlidingWindowCmn argument validation.
    Description: construct the op with wrongly typed or out-of-range arguments.
    Expectation: each construction raises the listed error with the listed message.
    """
    # Each entry: (test name, cmn_window, min_cmn_window, center, norm_vars, error type, message).
    bad_cases = [
        ("invalid cmn_window parameter type as a String", "600", 100, False, False, TypeError,
         "Argument cmn_window with value 600 is not of type [<class 'int'>],"
         " but got <class 'str'>."),
        ("invalid cmn_window parameter value", 441324343243242342345300, 100, False, False, ValueError,
         "Input cmn_window is not within the required interval of [0, 2147483647]."),
        ("invalid min_cmn_window parameter type as a String", 600, "100", False, False, TypeError,
         "Argument min_cmn_window with value 100 is not of type [<class 'int'>],"
         " but got <class 'str'>."),
        ("invalid min_cmn_window parameter value", 600, 441324343243242342345300, False, False,
         ValueError, "Input min_cmn_window is not within the required interval of [0, 2147483647]."),
        ("invalid center parameter type as a String", 600, 100, "False", False, TypeError,
         "Argument center with value False is not of type [<class 'bool'>],"
         " but got <class 'str'>."),
        ("invalid norm_vars parameter type as a String", 600, 100, False, "False", TypeError,
         "Argument norm_vars with value False is not of type [<class 'bool'>],"
         " but got <class 'str'>."),
    ]

    for case_name, window, min_window, centered, normalize, expected_error, expected_msg in bad_cases:
        logger.info("Test SlidingWindowCmn with bad input: {0}".format(case_name))
        with pytest.raises(expected_error) as error_info:
            audio.SlidingWindowCmn(window, min_window, centered, normalize)
        assert expected_msg in str(error_info.value)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Run every test in this module, in definition order, when executed as a script.
    for test_case in (test_sliding_window_cmn_eager,
                      test_sliding_window_cmn_pipeline,
                      test_sliding_window_cmn_invalid_input):
        test_case()
|
Loading…
Reference in New Issue