commit
d7ac70d50a
|
@ -58,6 +58,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/treble_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/vad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/vol_ir.h"
|
||||
#include "minddata/dataset/audio/ir/validators.h"
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
|
@ -863,6 +864,65 @@ std::shared_ptr<TensorOperation> TrebleBiquad::Parse() {
|
|||
return std::make_shared<TrebleBiquadOperation>(data_->sample_rate_, data_->gain_, data_->central_freq_, data_->Q_);
|
||||
}
|
||||
|
||||
// Vad Transform Operation.
|
||||
struct Vad::Data {
|
||||
Data(int32_t sample_rate, float trigger_level, float trigger_time, float search_time, float allowed_gap,
|
||||
float pre_trigger_time, float boot_time, float noise_up_time, float noise_down_time,
|
||||
float noise_reduction_amount, float measure_freq, float measure_duration, float measure_smooth_time,
|
||||
float hp_filter_freq, float lp_filter_freq, float hp_lifter_freq, float lp_lifter_freq)
|
||||
: sample_rate_(sample_rate),
|
||||
trigger_level_(trigger_level),
|
||||
trigger_time_(trigger_time),
|
||||
search_time_(search_time),
|
||||
allowed_gap_(allowed_gap),
|
||||
pre_trigger_time_(pre_trigger_time),
|
||||
boot_time_(boot_time),
|
||||
noise_up_time_(noise_up_time),
|
||||
noise_down_time_(noise_down_time),
|
||||
noise_reduction_amount_(noise_reduction_amount),
|
||||
measure_freq_(measure_freq),
|
||||
measure_duration_(measure_duration),
|
||||
measure_smooth_time_(measure_smooth_time),
|
||||
hp_filter_freq_(hp_filter_freq),
|
||||
lp_filter_freq_(lp_filter_freq),
|
||||
hp_lifter_freq_(hp_lifter_freq),
|
||||
lp_lifter_freq_(lp_lifter_freq) {}
|
||||
int32_t sample_rate_;
|
||||
float trigger_level_;
|
||||
float trigger_time_;
|
||||
float search_time_;
|
||||
float allowed_gap_;
|
||||
float pre_trigger_time_;
|
||||
float boot_time_;
|
||||
float noise_up_time_;
|
||||
float noise_down_time_;
|
||||
float noise_reduction_amount_;
|
||||
float measure_freq_;
|
||||
float measure_duration_;
|
||||
float measure_smooth_time_;
|
||||
float hp_filter_freq_;
|
||||
float lp_filter_freq_;
|
||||
float hp_lifter_freq_;
|
||||
float lp_lifter_freq_;
|
||||
};
|
||||
|
||||
Vad::Vad(int32_t sample_rate, float trigger_level, float trigger_time, float search_time, float allowed_gap,
|
||||
float pre_trigger_time, float boot_time, float noise_up_time, float noise_down_time,
|
||||
float noise_reduction_amount, float measure_freq, float measure_duration, float measure_smooth_time,
|
||||
float hp_filter_freq, float lp_filter_freq, float hp_lifter_freq, float lp_lifter_freq)
|
||||
: data_(std::make_shared<Data>(sample_rate, trigger_level, trigger_time, search_time, allowed_gap, pre_trigger_time,
|
||||
boot_time, noise_up_time, noise_down_time, noise_reduction_amount, measure_freq,
|
||||
measure_duration, measure_smooth_time, hp_filter_freq, lp_filter_freq,
|
||||
hp_lifter_freq, lp_lifter_freq)) {}
|
||||
|
||||
std::shared_ptr<TensorOperation> Vad::Parse() {
|
||||
return std::make_shared<VadOperation>(
|
||||
data_->sample_rate_, data_->trigger_level_, data_->trigger_time_, data_->search_time_, data_->allowed_gap_,
|
||||
data_->pre_trigger_time_, data_->boot_time_, data_->noise_up_time_, data_->noise_down_time_,
|
||||
data_->noise_reduction_amount_, data_->measure_freq_, data_->measure_duration_, data_->measure_smooth_time_,
|
||||
data_->hp_filter_freq_, data_->lp_filter_freq_, data_->hp_lifter_freq_, data_->lp_lifter_freq_);
|
||||
}
|
||||
|
||||
// Vol Transform Operation.
|
||||
struct Vol::Data {
|
||||
Data(float gain, GainType gain_type) : gain_(gain), gain_type_(gain_type) {}
|
||||
|
|
|
@ -62,6 +62,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/treble_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/vad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/vol_ir.h"
|
||||
|
||||
namespace mindspore {
|
||||
|
@ -590,6 +591,23 @@ PYBIND_REGISTER(
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(VadOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<audio::VadOperation, TensorOperation, std::shared_ptr<audio::VadOperation>>(
|
||||
*m, "VadOperation")
|
||||
.def(py::init([](int32_t sample_rate, float trigger_level, float trigger_time, float search_time,
|
||||
float allowed_gap, float pre_trigger_time, float boot_time, float noise_up_time,
|
||||
float noise_down_time, float noise_reduction_amount, float measure_freq,
|
||||
float measure_duration, float measure_smooth_time, float hp_filter_freq,
|
||||
float lp_filter_freq, float hp_lifter_freq, float lp_lifter_freq) {
|
||||
auto vad = std::make_shared<audio::VadOperation>(
|
||||
sample_rate, trigger_level, trigger_time, search_time, allowed_gap, pre_trigger_time, boot_time,
|
||||
noise_up_time, noise_down_time, noise_reduction_amount, measure_freq, measure_duration,
|
||||
measure_smooth_time, hp_filter_freq, lp_filter_freq, hp_lifter_freq, lp_lifter_freq);
|
||||
THROW_IF_ERROR(vad->ValidateParams());
|
||||
return vad;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(VolOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<audio::VolOperation, TensorOperation, std::shared_ptr<audio::VolOperation>>(
|
||||
*m, "VolOperation")
|
||||
|
|
|
@ -44,6 +44,7 @@ add_library(audio-ir-kernels OBJECT
|
|||
time_masking_ir.cc
|
||||
time_stretch_ir.cc
|
||||
treble_biquad_ir.cc
|
||||
vad_ir.cc
|
||||
vol_ir.cc
|
||||
)
|
||||
|
||||
|
|
|
@ -0,0 +1,109 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/audio/ir/kernels/vad_ir.h"
|
||||
|
||||
#include "minddata/dataset/audio/ir/validators.h"
|
||||
#include "minddata/dataset/audio/kernels/vad_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
// Vad
|
||||
VadOperation::VadOperation(int32_t sample_rate, float trigger_level, float trigger_time, float search_time,
|
||||
float allowed_gap, float pre_trigger_time, float boot_time, float noise_up_time,
|
||||
float noise_down_time, float noise_reduction_amount, float measure_freq,
|
||||
float measure_duration, float measure_smooth_time, float hp_filter_freq,
|
||||
float lp_filter_freq, float hp_lifter_freq, float lp_lifter_freq)
|
||||
: sample_rate_(sample_rate),
|
||||
trigger_level_(trigger_level),
|
||||
trigger_time_(trigger_time),
|
||||
search_time_(search_time),
|
||||
allowed_gap_(allowed_gap),
|
||||
pre_trigger_time_(pre_trigger_time),
|
||||
boot_time_(boot_time),
|
||||
noise_up_time_(noise_up_time),
|
||||
noise_down_time_(noise_down_time),
|
||||
noise_reduction_amount_(noise_reduction_amount),
|
||||
measure_freq_(measure_freq),
|
||||
measure_duration_(measure_duration),
|
||||
measure_smooth_time_(measure_smooth_time),
|
||||
hp_filter_freq_(hp_filter_freq),
|
||||
lp_filter_freq_(lp_filter_freq),
|
||||
hp_lifter_freq_(hp_lifter_freq),
|
||||
lp_lifter_freq_(lp_lifter_freq) {}
|
||||
|
||||
VadOperation::~VadOperation() = default;
|
||||
|
||||
std::string VadOperation::Name() const { return kVadOperation; }
|
||||
|
||||
Status VadOperation::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateIntScalarPositive("Vad", "sample_rate", sample_rate_));
|
||||
RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("Vad", "trigger_time", trigger_time_));
|
||||
RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("Vad", "search_time", search_time_));
|
||||
RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("Vad", "allowed_gap", allowed_gap_));
|
||||
RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("Vad", "pre_trigger_time", pre_trigger_time_));
|
||||
RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("Vad", "boot_time", boot_time_));
|
||||
RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("Vad", "noise_up_time", noise_up_time_));
|
||||
RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("Vad", "noise_down_time", noise_down_time_));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(noise_down_time_ <= noise_up_time_,
|
||||
"Vad: noise_up_time must be greater than or equal to noise_down_time.");
|
||||
RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("Vad", "noise_reduction_amount", noise_reduction_amount_));
|
||||
RETURN_IF_NOT_OK(ValidateFloatScalarPositive("Vad", "measure_freq", measure_freq_));
|
||||
RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("Vad", "measure_duration", measure_duration_));
|
||||
RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("Vad", "measure_smooth_time", measure_smooth_time_));
|
||||
RETURN_IF_NOT_OK(ValidateFloatScalarPositive("Vad", "hp_filter_freq", hp_filter_freq_));
|
||||
RETURN_IF_NOT_OK(ValidateFloatScalarPositive("Vad", "lp_filter_freq", lp_filter_freq_));
|
||||
RETURN_IF_NOT_OK(ValidateFloatScalarPositive("Vad", "hp_lifter_freq", hp_lifter_freq_));
|
||||
RETURN_IF_NOT_OK(ValidateFloatScalarPositive("Vad", "lp_lifter_freq", lp_lifter_freq_));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> VadOperation::Build() {
|
||||
float measure_duration = measure_duration_ == 0 ? 2.0 / measure_freq_ : measure_duration_;
|
||||
std::shared_ptr<VadOp> tensor_op = std::make_shared<VadOp>(
|
||||
sample_rate_, trigger_level_, trigger_time_, search_time_, allowed_gap_, pre_trigger_time_, boot_time_,
|
||||
noise_up_time_, noise_down_time_, noise_reduction_amount_, measure_freq_, measure_duration, measure_smooth_time_,
|
||||
hp_filter_freq_, lp_filter_freq_, hp_lifter_freq_, lp_lifter_freq_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
Status VadOperation::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["sample_rate"] = sample_rate_;
|
||||
args["trigger_level"] = trigger_level_;
|
||||
args["trigger_time"] = trigger_time_;
|
||||
args["search_time"] = search_time_;
|
||||
args["allowed_gap"] = allowed_gap_;
|
||||
args["pre_trigger_time"] = pre_trigger_time_;
|
||||
args["boot_time"] = boot_time_;
|
||||
args["noise_up_time"] = noise_up_time_;
|
||||
args["noise_down_time"] = noise_down_time_;
|
||||
args["noise_reduction_amount"] = noise_reduction_amount_;
|
||||
args["measure_freq"] = measure_freq_;
|
||||
args["measure_duration"] = measure_duration_;
|
||||
args["measure_smooth_time"] = measure_smooth_time_;
|
||||
args["hp_filter_freq"] = hp_filter_freq_;
|
||||
args["lp_filter_freq"] = lp_filter_freq_;
|
||||
args["hp_lifter_freq"] = hp_lifter_freq_;
|
||||
args["lp_lifter_freq"] = lp_lifter_freq_;
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,71 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_VAD_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_VAD_IR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
constexpr char kVadOperation[] = "Vad";
|
||||
|
||||
class VadOperation : public TensorOperation {
|
||||
public:
|
||||
VadOperation(int32_t sample_rate, float trigger_level, float trigger_time, float search_time, float allowed_gap,
|
||||
float pre_trigger_time, float boot_time, float noise_up_time, float noise_down_time,
|
||||
float noise_reduction_amount, float measure_freq, float measure_duration, float measure_smooth_time,
|
||||
float hp_filter_freq, float lp_filter_freq, float hp_lifter_freq, float lp_lifter_freq);
|
||||
|
||||
~VadOperation();
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
std::string Name() const override;
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
int32_t sample_rate_;
|
||||
float trigger_level_;
|
||||
float trigger_time_;
|
||||
float search_time_;
|
||||
float allowed_gap_;
|
||||
float pre_trigger_time_;
|
||||
float boot_time_;
|
||||
float noise_up_time_;
|
||||
float noise_down_time_;
|
||||
float noise_reduction_amount_;
|
||||
float measure_freq_;
|
||||
float measure_duration_;
|
||||
float measure_smooth_time_;
|
||||
float hp_filter_freq_;
|
||||
float lp_filter_freq_;
|
||||
float hp_lifter_freq_;
|
||||
float lp_lifter_freq_;
|
||||
};
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_VAD_IR_H_
|
|
@ -45,6 +45,7 @@ add_library(audio-kernels OBJECT
|
|||
time_masking_op.cc
|
||||
time_stretch_op.cc
|
||||
treble_biquad_op.cc
|
||||
vad_op.cc
|
||||
vol_op.cc
|
||||
)
|
||||
|
||||
|
|
|
@ -1725,6 +1725,306 @@ Status GriffinLim(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor>
|
|||
int32_t win_length, int32_t hop_length, WindowType window_type, float power, float momentum,
|
||||
int32_t length, bool rand_init, std::mt19937 rnd);
|
||||
|
||||
/// \brief Calculate measure for VAD.
|
||||
template <typename T>
|
||||
Status Measure(int32_t measure_len_ws, int32_t index, const Eigen::MatrixXd &samples, Eigen::MatrixXd *spectrum,
|
||||
Eigen::MatrixXd *noise_spectrum, const std::vector<T> &spectrum_window, int32_t spectrum_start,
|
||||
int32_t spectrum_end, const std::vector<T> &cepstrum_window, int32_t cepstrum_start,
|
||||
int32_t cepstrum_end, T noise_reduction_amount, T measure_smooth_time_mult, T noise_up_time_mult,
|
||||
T noise_down_time_mult, int32_t index_ns, int32_t boot_count, float *meas) {
|
||||
RETURN_UNEXPECTED_IF_NULL(spectrum);
|
||||
RETURN_UNEXPECTED_IF_NULL(noise_spectrum);
|
||||
RETURN_UNEXPECTED_IF_NULL(meas);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(spectrum->cols() == noise_spectrum->cols(),
|
||||
"Measure: the number of columns of spectrum must be equal to noise_spectrum.");
|
||||
int samples_len_ns = samples.cols();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(samples_len_ns != 0, "Measure: the number of columns of samples cannot be zero.");
|
||||
int dft_len_ws = spectrum->cols();
|
||||
std::vector<float> dft_buf(dft_len_ws, 0);
|
||||
for (int ind = 0; ind < measure_len_ws; ind++) {
|
||||
int index_new = (ind == 0) ? index_ns : ((index_ns + ind) % samples_len_ns);
|
||||
dft_buf[ind] = samples(index, index_new) * spectrum_window[ind];
|
||||
}
|
||||
|
||||
// use fft from eigen to calculate rfft
|
||||
Eigen::FFT<float> fft;
|
||||
std::vector<std::complex<float>> rfft_res;
|
||||
fft.fwd(rfft_res, dft_buf);
|
||||
// truncate redundant information in fft
|
||||
int rfft_len = rfft_res.size() - (rfft_res.size() - 1) / 2;
|
||||
rfft_res.resize(rfft_len);
|
||||
|
||||
float mult = (boot_count >= 0) ? (static_cast<float>(boot_count) / (1.0 + boot_count))
|
||||
: static_cast<float>(measure_smooth_time_mult);
|
||||
std::vector<float> dft_buf_abs(spectrum_end - spectrum_start);
|
||||
for (int i = 0; i < spectrum_end - spectrum_start; i++) {
|
||||
auto dba_i = std::abs(rfft_res[i + spectrum_start]);
|
||||
// inplace revise spectrum
|
||||
(*spectrum)(index, spectrum_start + i) *= mult;
|
||||
(*spectrum)(index, spectrum_start + i) += dba_i * (1 - mult);
|
||||
|
||||
dba_i = std::pow((*spectrum)(index, spectrum_start + i), TWO);
|
||||
|
||||
float mult2 = 0;
|
||||
// new mult
|
||||
if (boot_count >= 0) {
|
||||
mult2 = 0;
|
||||
} else {
|
||||
if (dba_i > (*noise_spectrum)(index, spectrum_start + i)) {
|
||||
mult2 = noise_up_time_mult;
|
||||
} else {
|
||||
mult2 = noise_down_time_mult;
|
||||
}
|
||||
}
|
||||
// inplace revise noise spectrum
|
||||
(*noise_spectrum)(index, spectrum_start + i) *= mult2;
|
||||
(*noise_spectrum)(index, spectrum_start + i) += dba_i * (1 - mult2);
|
||||
|
||||
dba_i = dba_i - noise_reduction_amount * (*noise_spectrum)(index, spectrum_start + i);
|
||||
dba_i = dba_i <= 0 ? 0 : std::sqrt(dba_i);
|
||||
|
||||
dft_buf_abs.push_back(dba_i);
|
||||
}
|
||||
|
||||
// cepstrum_buf
|
||||
std::vector<float> cepstrum_buf(dft_len_ws >> 1, 0);
|
||||
for (int i = 0; i < spectrum_end - spectrum_start; i++) {
|
||||
cepstrum_buf[spectrum_start + i] = dft_buf_abs[i] * cepstrum_window[i];
|
||||
}
|
||||
std::vector<std::complex<float>> rfft_res2;
|
||||
fft.fwd(rfft_res2, cepstrum_buf);
|
||||
rfft_len = rfft_res2.size() - (rfft_res.size() - 1) / TWO;
|
||||
rfft_res2.resize(rfft_len);
|
||||
|
||||
float result = 0;
|
||||
for (int i = cepstrum_start; i < cepstrum_end; i++) {
|
||||
result += std::pow(std::abs(rfft_res2[i]), TWO);
|
||||
}
|
||||
|
||||
result = result > 0 ? std::log(result / (cepstrum_end - cepstrum_start)) : INT_MIN;
|
||||
int base = 21;
|
||||
*meas = static_cast<float>(std::max(0.0f, base + result));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Update parameters in VAD calculation.
|
||||
inline Status UpdateVadParams(int32_t *samples_index_ns, int32_t *pos, const int32_t samples_len_ns,
|
||||
int32_t *measure_timer_ns, const int32_t measure_period_ns, int32_t *measures_index,
|
||||
const int32_t measures_len, int32_t *boot_count, const int32_t boot_count_max,
|
||||
bool has_triggered, const int32_t num_measures_to_flush, int32_t *flushed_len_ns) {
|
||||
RETURN_UNEXPECTED_IF_NULL(samples_index_ns);
|
||||
RETURN_UNEXPECTED_IF_NULL(pos);
|
||||
RETURN_UNEXPECTED_IF_NULL(measure_timer_ns);
|
||||
RETURN_UNEXPECTED_IF_NULL(measures_index);
|
||||
RETURN_UNEXPECTED_IF_NULL(boot_count);
|
||||
RETURN_UNEXPECTED_IF_NULL(flushed_len_ns);
|
||||
|
||||
*samples_index_ns = *samples_index_ns + 1;
|
||||
*pos += 1;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(measures_len != 0, "UpdateVadParams: the length of measures cannot be zero.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(samples_len_ns != 0,
|
||||
"UpdateVadParams: the number of columns of samples cannot be zero.");
|
||||
*samples_index_ns = *samples_index_ns == samples_len_ns ? 0 : *samples_index_ns;
|
||||
if (*measure_timer_ns == 0) {
|
||||
*measure_timer_ns = measure_period_ns;
|
||||
*measures_index += 1;
|
||||
*measures_index = *measures_index % measures_len;
|
||||
*boot_count = (*boot_count >= 0) && (*boot_count == boot_count_max) ? -1 : *boot_count + 1;
|
||||
}
|
||||
if (has_triggered) {
|
||||
*flushed_len_ns = (measures_len - num_measures_to_flush) * measure_period_ns;
|
||||
*samples_index_ns = (*samples_index_ns + *flushed_len_ns) % samples_len_ns;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Init spectrum window and cepstrum window for VAD.
|
||||
inline Status FlushMeasures(int measures_len, int measures_index, const int row_ind, const Eigen::MatrixXd &measures,
|
||||
float trigger_level, int gap_len, int *num_measures_to_flush) {
|
||||
RETURN_UNEXPECTED_IF_NULL(num_measures_to_flush);
|
||||
|
||||
int n = measures_len, k = measures_index;
|
||||
int j_trigger = n, j_zero = n, j = 0;
|
||||
for (int j = 0; j < n; j++) {
|
||||
if ((measures(row_ind, k) >= trigger_level) && (j <= (j_trigger + gap_len))) {
|
||||
j_trigger = j;
|
||||
j_zero = j_trigger;
|
||||
} else if ((measures(row_ind, k) == 0) && (j_trigger >= j_zero)) {
|
||||
j_zero = j;
|
||||
}
|
||||
k = (k + n - 1) % n;
|
||||
}
|
||||
j = std::min(j, j_zero);
|
||||
*num_measures_to_flush = std::min(std::max((*num_measures_to_flush), j), n);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Init spectrum window and cepstrum window for Vad.
|
||||
template <typename T>
|
||||
Status InitWindows(std::vector<T> *spectrum_window, std::vector<T> *cepstrum_window) {
|
||||
RETURN_UNEXPECTED_IF_NULL(spectrum_window);
|
||||
RETURN_UNEXPECTED_IF_NULL(cepstrum_window);
|
||||
|
||||
float half = 0.5;
|
||||
for (int i = 0; i < spectrum_window->size(); i++) {
|
||||
auto hann = half - half * std::cos(static_cast<T>(i) / spectrum_window->size() * PI * TWO);
|
||||
(*spectrum_window)[i] *= hann;
|
||||
}
|
||||
|
||||
for (int i = 0; i < cepstrum_window->size(); i++) {
|
||||
auto hann = half - half * std::cos(static_cast<T>(i) / cepstrum_window->size() * PI * TWO);
|
||||
(*cepstrum_window)[i] *= hann;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Voice activity detector.
|
||||
/// \param input/output Tensor of shape <..., time>.
|
||||
/// \param sample_rate Sample rate of audio signal.
|
||||
/// \param trigger_level The measurement level used to trigger activity detection.
|
||||
/// \param trigger_time The time constant (in seconds) used to help ignore short sounds.
|
||||
/// \param search_time The amount of audio (in seconds) to search for quieter/shorter sounds to include prior to
|
||||
/// the detected trigger point.
|
||||
/// \param allowed_gap The allowed gap (in seconds) between quiteter/shorter sounds to include prior to the
|
||||
/// detected trigger point.
|
||||
/// \param pre_trigger_time The amount of audio (in seconds) to preserve before the trigger point and any found
|
||||
/// quieter/shorter bursts.
|
||||
/// \param boot_time The time for the initial noise estimate.
|
||||
/// \param noise_up_time Time constant used by the adaptive noise estimator, when the noise level is increasing.
|
||||
/// \param noise_down_time Time constant used by the adaptive noise estimator, when the noise level is decreasing.
|
||||
/// \param noise_reduction_amount The amount of noise reduction used in the detection algorithm.
|
||||
/// \param measure_freq The frequency of the algorithm’s processing.
|
||||
/// \param measure_duration The duration of measurement.
|
||||
/// \param measure_smooth_time The time constant used to smooth spectral measurements.
|
||||
/// \param hp_filter_freq The "Brick-wall" frequency of high-pass filter applied at the input to the detector
|
||||
/// algorithm.
|
||||
/// \param lp_filter_freq The "Brick-wall" frequency of low-pass filter applied at the input to the detector
|
||||
/// algorithm.
|
||||
/// \param hp_lifter_freq The "Brick-wall" frequency of high-pass lifter applied at the input to the detector
|
||||
/// algorithm.
|
||||
/// \param lp_lifter_freq The "Brick-wall" frequency of low-pass lifter applied at the input to the detector
|
||||
/// algorithm.
|
||||
/// \return Status code.
|
||||
template <typename T>
|
||||
Status Vad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int sample_rate, T trigger_level,
|
||||
T trigger_time, T search_time, T allowed_gap, T pre_trigger_time, T boot_time, T noise_up_time,
|
||||
T noise_down_time, T noise_reduction_amount, T measure_freq, T measure_duration, T measure_smooth_time,
|
||||
T hp_filter_freq, T lp_filter_freq, T hp_lifter_freq, T lp_lifter_freq) {
|
||||
const int measure_len_ws = static_cast<int>(sample_rate * measure_duration + 0.5);
|
||||
const int measure_len_ns = measure_len_ws;
|
||||
|
||||
int dft_len_ws = 16;
|
||||
for (; dft_len_ws < measure_len_ws;) {
|
||||
dft_len_ws *= TWO;
|
||||
}
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(measure_freq != 0, "Vad: measure_freq cannot be zero.");
|
||||
const int measure_period_ns = static_cast<int>(sample_rate / measure_freq + 0.5);
|
||||
const int measures_len = std::ceil(search_time * measure_freq);
|
||||
const int search_pre_trigger_len_ns = static_cast<int>(measures_len * measure_period_ns);
|
||||
const int gap_len = static_cast<int>(allowed_gap * measure_freq + 0.5);
|
||||
const int fixed_pre_trigger_len_ns = static_cast<int>(pre_trigger_time * sample_rate + 0.5);
|
||||
const int samples_len_ns = fixed_pre_trigger_len_ns + search_pre_trigger_len_ns + measure_len_ns;
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(sample_rate != 0, "Vad: sample_rate cannot be zero.");
|
||||
auto spectrum_start = static_cast<int>(hp_filter_freq / sample_rate * dft_len_ws + 0.5);
|
||||
spectrum_start = std::max(spectrum_start, 1);
|
||||
auto spectrum_end = static_cast<int>(lp_filter_freq / sample_rate * dft_len_ws + 0.5);
|
||||
spectrum_end = std::min(spectrum_end, dft_len_ws / TWO);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(spectrum_end > spectrum_start,
|
||||
"Vad: the end of spectrum must be greater than the start. Check if `hp_filter_freq` is "
|
||||
"too large or `lp_filter_freq` is too small.");
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(lp_lifter_freq != 0, "Vad: lp_lifter_freq cannot be zero.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(hp_lifter_freq != 0, "Vad: hp_lifter_freq cannot be zero.");
|
||||
int cepstrum_start = std::ceil(sample_rate * 0.5 / lp_lifter_freq);
|
||||
int cepstrum_end = std::floor(sample_rate * 0.5 / hp_lifter_freq);
|
||||
cepstrum_end = std::min(cepstrum_end, dft_len_ws / (TWO * TWO));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(cepstrum_end > cepstrum_start,
|
||||
"Vad: the end of cepstrum must be greater than the start. Check if `hp_lifter_freq` is "
|
||||
"too large or `lp_lifter_freq` is too small.");
|
||||
// init spectrum & cepstrum window
|
||||
std::vector<T> spectrum_window(measure_len_ws, static_cast<T>(TWO / std::sqrt(static_cast<T>(measure_len_ws))));
|
||||
std::vector<T> cepstrum_window(spectrum_end - spectrum_start,
|
||||
static_cast<T>(TWO / std::sqrt(static_cast<T>(spectrum_end) - spectrum_start)));
|
||||
RETURN_IF_NOT_OK(InitWindows<T>(&spectrum_window, &cepstrum_window));
|
||||
|
||||
T noise_up_time_mult = std::exp(-1.0 / (noise_up_time * measure_freq));
|
||||
T noise_down_time_mult = std::exp(-1.0 / (noise_down_time * measure_freq));
|
||||
T measure_smooth_time_mult = std::exp(-1.0 / (measure_smooth_time * measure_freq));
|
||||
T trigger_meas_time_mult = std::exp(-1.0 / (trigger_time * measure_freq));
|
||||
|
||||
auto boot_count_max = static_cast<int>(boot_time * measure_freq - 0.5);
|
||||
auto measure_timer_ns = measure_len_ns;
|
||||
int boot_count = 0, measures_index = 0, flushed_len_ns = 0, samples_index_ns = 0;
|
||||
|
||||
// pack batch
|
||||
TensorShape input_shape = input->shape();
|
||||
TensorShape to_shape({input->Size() / input_shape[-1], input_shape[-1]});
|
||||
RETURN_IF_NOT_OK(input->Reshape(to_shape));
|
||||
int n_channels = to_shape[0], ilen = to_shape[1];
|
||||
std::vector<T> mean_meas(n_channels, 0);
|
||||
Eigen::MatrixXd samples = Eigen::MatrixXd::Zero(n_channels, samples_len_ns);
|
||||
Eigen::MatrixXd spectrum = Eigen::MatrixXd::Zero(n_channels, dft_len_ws);
|
||||
Eigen::MatrixXd noise_spectrum = Eigen::MatrixXd::Zero(n_channels, dft_len_ws);
|
||||
Eigen::MatrixXd measures = Eigen::MatrixXd::Zero(n_channels, measures_len);
|
||||
|
||||
bool has_triggered = false;
|
||||
int num_measures_to_flush = 0, pos = 0;
|
||||
|
||||
// convert input to eigen mat
|
||||
auto wave_ptr = &*input->begin<T>();
|
||||
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>> waveform_t(wave_ptr, ilen, n_channels);
|
||||
auto waveform = waveform_t.transpose();
|
||||
while ((pos < ilen) && (!has_triggered)) {
|
||||
measure_timer_ns -= 1;
|
||||
for (int i = 0; i < n_channels; i++) {
|
||||
samples(i, samples_index_ns) = waveform(i, pos);
|
||||
|
||||
if (measure_timer_ns == 0) {
|
||||
int index_ns = (samples_index_ns + samples_len_ns - measure_len_ns) % samples_len_ns;
|
||||
// measure
|
||||
float meas = 0;
|
||||
RETURN_IF_NOT_OK(Measure(measure_len_ws, i, samples, &spectrum, &noise_spectrum, spectrum_window,
|
||||
spectrum_start, spectrum_end, cepstrum_window, cepstrum_start, cepstrum_end,
|
||||
noise_reduction_amount, measure_smooth_time_mult, noise_up_time_mult,
|
||||
noise_down_time_mult, index_ns, boot_count, &meas));
|
||||
measures(i, measures_index) = meas;
|
||||
mean_meas[i] = mean_meas[i] * trigger_meas_time_mult + meas * (1.0 - trigger_meas_time_mult);
|
||||
|
||||
has_triggered = has_triggered || (mean_meas[i] >= trigger_level);
|
||||
if (has_triggered) {
|
||||
RETURN_IF_NOT_OK(
|
||||
FlushMeasures(measures_len, measures_index, i, measures, trigger_level, gap_len, &num_measures_to_flush));
|
||||
}
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(UpdateVadParams(&samples_index_ns, &pos, samples_len_ns, &measure_timer_ns, measure_period_ns,
|
||||
&measures_index, measures_len, &boot_count, boot_count_max, has_triggered,
|
||||
num_measures_to_flush, &flushed_len_ns));
|
||||
}
|
||||
|
||||
// results truncate
|
||||
int new_col_ind = pos - samples_len_ns + flushed_len_ns;
|
||||
if (new_col_ind < (-1 * waveform.cols())) {
|
||||
new_col_ind = 0;
|
||||
} else if (new_col_ind < 0) {
|
||||
new_col_ind += ilen;
|
||||
}
|
||||
int new_cols = ilen - new_col_ind;
|
||||
auto res = waveform.rightCols(new_cols);
|
||||
// unpack
|
||||
std::shared_ptr<Tensor> res_tensor;
|
||||
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic> mat_res = res.transpose();
|
||||
std::vector<T> res_vec(mat_res.data(), mat_res.data() + mat_res.size());
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromVector(res_vec, TensorShape({n_channels, new_cols}), &res_tensor));
|
||||
auto reshape_vec = input_shape.AsVector();
|
||||
reshape_vec[input_shape.Size() - 1] = new_cols;
|
||||
RETURN_IF_NOT_OK(res_tensor->Reshape(TensorShape(reshape_vec)));
|
||||
*output = res_tensor;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/audio/kernels/vad_op.h"
|
||||
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
#include "minddata/dataset/kernels/data/data_utils.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
Status VadOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
RETURN_IF_NOT_OK(ValidateLowRank("Vad", input, kMinAudioDim, "<..., time>"));
|
||||
RETURN_IF_NOT_OK(ValidateTensorNumeric("Vad", input));
|
||||
std::shared_ptr<Tensor> input_tensor;
|
||||
if (input->type() != DataType::DE_FLOAT64) {
|
||||
RETURN_IF_NOT_OK(TypeCast(input, &input_tensor, DataType(DataType::DE_FLOAT32)));
|
||||
return Vad<float>(input_tensor, output, sample_rate_, trigger_level_, trigger_time_, search_time_, allowed_gap_,
|
||||
pre_trigger_time_, boot_time_, noise_up_time_, noise_down_time_, noise_reduction_amount_,
|
||||
measure_freq_, measure_duration_, measure_smooth_time_, hp_filter_freq_, lp_filter_freq_,
|
||||
hp_lifter_freq_, lp_lifter_freq_);
|
||||
} else {
|
||||
input_tensor = input;
|
||||
return Vad<double>(input_tensor, output, sample_rate_, trigger_level_, trigger_time_, search_time_, allowed_gap_,
|
||||
pre_trigger_time_, boot_time_, noise_up_time_, noise_down_time_, noise_reduction_amount_,
|
||||
measure_freq_, measure_duration_, measure_smooth_time_, hp_filter_freq_, lp_filter_freq_,
|
||||
hp_lifter_freq_, lp_lifter_freq_);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status VadOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
|
||||
RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
|
||||
RETURN_IF_NOT_OK(ValidateTensorType("Vad", inputs[0].IsNumeric(), "[int, float, double]", inputs[0].ToString()));
|
||||
if (inputs[0] == DataType(DataType::DE_FLOAT64)) {
|
||||
outputs[0] = DataType(DataType::DE_FLOAT64);
|
||||
} else {
|
||||
outputs[0] = DataType(DataType::DE_FLOAT32);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,82 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_VAD_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_VAD_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class VadOp : public TensorOp {
|
||||
public:
|
||||
VadOp(int32_t sample_rate, float trigger_level, float trigger_time, float search_time, float allowed_gap,
|
||||
float pre_trigger_time, float boot_time, float noise_up_time, float noise_down_time,
|
||||
float noise_reduction_amount, float measure_freq, float measure_duration, float measure_smooth_time,
|
||||
float hp_filter_freq, float lp_filter_freq, float hp_lifter_freq, float lp_lifter_freq)
|
||||
: sample_rate_(sample_rate),
|
||||
trigger_level_(trigger_level),
|
||||
trigger_time_(trigger_time),
|
||||
search_time_(search_time),
|
||||
allowed_gap_(allowed_gap),
|
||||
pre_trigger_time_(pre_trigger_time),
|
||||
boot_time_(boot_time),
|
||||
noise_up_time_(noise_up_time),
|
||||
noise_down_time_(noise_down_time),
|
||||
noise_reduction_amount_(noise_reduction_amount),
|
||||
measure_freq_(measure_freq),
|
||||
measure_duration_(measure_duration),
|
||||
measure_smooth_time_(measure_smooth_time),
|
||||
hp_filter_freq_(hp_filter_freq),
|
||||
lp_filter_freq_(lp_filter_freq),
|
||||
hp_lifter_freq_(hp_lifter_freq),
|
||||
lp_lifter_freq_(lp_lifter_freq) {}
|
||||
|
||||
~VadOp() override = default;
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
std::string Name() const override { return kVadOp; }
|
||||
|
||||
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
|
||||
|
||||
private:
|
||||
int32_t sample_rate_;
|
||||
float trigger_level_;
|
||||
float trigger_time_;
|
||||
float search_time_;
|
||||
float allowed_gap_;
|
||||
float pre_trigger_time_;
|
||||
float boot_time_;
|
||||
float noise_up_time_;
|
||||
float noise_down_time_;
|
||||
float noise_reduction_amount_;
|
||||
float measure_freq_;
|
||||
float measure_duration_;
|
||||
float measure_smooth_time_;
|
||||
float hp_filter_freq_;
|
||||
float lp_filter_freq_;
|
||||
float hp_lifter_freq_;
|
||||
float lp_lifter_freq_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_VAD_OP_H_
|
|
@ -1099,6 +1099,57 @@ class MS_API TrebleBiquad final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Vad TensorTransform.
|
||||
/// \notes Attempt to trim silent background sounds from the end of the voice recording.
|
||||
class MS_API Vad final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] sample_rate Sample rate of audio signal.
|
||||
/// \param[in] trigger_level The measurement level used to trigger activity detection (Default: 7.0).
|
||||
/// \param[in] trigger_time The time constant (in seconds) used to help ignore short sounds (Default: 0.25).
|
||||
/// \param[in] search_time The amount of audio (in seconds) to search for quieter/shorter sounds to include prior to
|
||||
/// the detected trigger point (Default: 1.0).
|
||||
/// \param[in] allowed_gap The allowed gap (in seconds) between quiteter/shorter sounds to include prior to the
|
||||
/// detected trigger point (Default: 0.25).
|
||||
/// \param[in] pre_trigger_time The amount of audio (in seconds) to preserve before the trigger point and any found
|
||||
/// quieter/shorter bursts (Default: 0.0).
|
||||
/// \param[in] boot_time The time for the initial noise estimate (Default: 0.35).
|
||||
/// \param[in] noise_up_time Time constant used by the adaptive noise estimator, when the noise level is increasing
|
||||
/// (Default: 0.1).
|
||||
/// \param[in] noise_down_time Time constant used by the adaptive noise estimator, when the noise level is decreasing
|
||||
/// (Default: 0.01).
|
||||
/// \param[in] noise_reduction_amount The amount of noise reduction used in the detection algorithm (Default: 1.35).
|
||||
/// \param[in] measure_freq The frequency of the algorithm’s processing (Default: 20.0).
|
||||
/// \param[in] measure_duration The duration of measurement (Default: 0, use twice the measurement period).
|
||||
/// \param[in] measure_smooth_time The time constant used to smooth spectral measurements (Default: 0.4).
|
||||
/// \param[in] hp_filter_freq The "Brick-wall" frequency of high-pass filter applied at the input to the detector
|
||||
/// algorithm (Default: 50.0).
|
||||
/// \param[in] lp_filter_freq The "Brick-wall" frequency of low-pass filter applied at the input to the detector
|
||||
/// algorithm (Default: 6000.0).
|
||||
/// \param[in] hp_lifter_freq The "Brick-wall" frequency of high-pass lifter applied at the input to the detector
|
||||
/// algorithm (Default: 150.0).
|
||||
/// \param[in] lp_lifter_freq The "Brick-wall" frequency of low-pass lifter applied at the input to the detector
|
||||
/// algorithm (Default: 2000.0).
|
||||
explicit Vad(int32_t sample_rate, float trigger_level = 7.0, float trigger_time = 0.25, float search_time = 1.0,
|
||||
float allowed_gap = 0.25, float pre_trigger_time = 0.0, float boot_time = 0.35,
|
||||
float noise_up_time = 0.1, float noise_down_time = 0.01, float noise_reduction_amount = 1.35,
|
||||
float measure_freq = 20.0, float measure_duration = 0.0, float measure_smooth_time = 0.4,
|
||||
float hp_filter_freq = 50.0, float lp_filter_freq = 6000.0, float hp_lifter_freq = 150.0,
|
||||
float lp_lifter_freq = 2000.0);
|
||||
|
||||
/// \brief Destructor.
|
||||
~Vad() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
||||
private:
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Vol TensorTransform.
|
||||
/// \notes Add a volume to an waveform.
|
||||
class MS_API Vol final : public TensorTransform {
|
||||
|
|
|
@ -194,6 +194,7 @@ constexpr char kSpectrogramOp[] = "SpectrogramOp";
|
|||
constexpr char kTimeMaskingOp[] = "TimeMaskingOp";
|
||||
constexpr char kTimeStretchOp[] = "TimeStretchOp";
|
||||
constexpr char kTrebleBiquadOp[] = "TrebleBiquadOp";
|
||||
constexpr char kVadOp[] = "VadOp";
|
||||
constexpr char kVolOp[] = "VolOp";
|
||||
|
||||
// data
|
||||
|
|
|
@ -31,7 +31,7 @@ from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_
|
|||
check_highpass_biquad, check_inverse_mel_scale, check_lfilter, check_lowpass_biquad, check_magphase, \
|
||||
check_mask_along_axis, check_mask_along_axis_iid, check_masking, check_mel_scale, check_mu_law_coding, \
|
||||
check_overdrive, check_phase_vocoder, check_phaser, check_riaa_biquad, check_sliding_window_cmn, \
|
||||
check_spectral_centroid, check_spectrogram, check_time_stretch, check_treble_biquad, check_vol
|
||||
check_spectral_centroid, check_spectrogram, check_time_stretch, check_treble_biquad, check_vad, check_vol
|
||||
from ..transforms.py_transforms_util import Implementation
|
||||
from ..transforms.transforms import TensorOperation
|
||||
|
||||
|
@ -1767,6 +1767,81 @@ class TrebleBiquad(AudioTensorOperation):
|
|||
return cde.TrebleBiquadOperation(self.sample_rate, self.gain, self.central_freq, self.quality_factor)
|
||||
|
||||
|
||||
class Vad(AudioTensorOperation):
|
||||
"""
|
||||
Attempt to trim silent background sounds from the end of the voice recording.
|
||||
|
||||
Args:
|
||||
sample_rate (int): Sample rate of audio signal.
|
||||
trigger_level (float, optional): The measurement level used to trigger activity detection (default=7.0).
|
||||
trigger_time (float, optional): The time constant (in seconds) used to help ignore short sounds (default=0.25).
|
||||
search_time (float, optional): The amount of audio (in seconds) to search for quieter/shorter sounds to include
|
||||
prior to the detected trigger point (default=1.0).
|
||||
allowed_gap (float, optional): The allowed gap (in seconds) between quiteter/shorter sounds to include prior to
|
||||
the detected trigger point (default=0.25).
|
||||
pre_trigger_time (float, optional): The amount of audio (in seconds) to preserve before the trigger point and
|
||||
any found quieter/shorter bursts (default=0.0).
|
||||
boot_time (float, optional): The time for the initial noise estimate (default=0.35).
|
||||
noise_up_time (float, optional): Time constant used by the adaptive noise estimator, when the noise level is
|
||||
increasing (default=0.1).
|
||||
noise_down_time (float, optional): Time constant used by the adaptive noise estimator, when the noise level is
|
||||
decreasing (default=0.01).
|
||||
noise_reduction_amount (float, optional): The amount of noise reduction used in the detection algorithm
|
||||
(default=1.35).
|
||||
measure_freq (float, optional): The frequency of the algorithm’s processing (default=20.0).
|
||||
measure_duration (float, optional): The duration of measurement (default=None, use twice the measurement
|
||||
period).
|
||||
measure_smooth_time (float, optional): The time constant used to smooth spectral measurements (default=0.4).
|
||||
hp_filter_freq (float, optional): The "Brick-wall" frequency of high-pass filter applied at the input to the
|
||||
detector algorithm (default=50.0).
|
||||
lp_filter_freq (float, optional): The "Brick-wall" frequency of low-pass filter applied at the input to the
|
||||
detector algorithm (default=6000.0).
|
||||
hp_lifter_freq (float, optional): The "Brick-wall" frequency of high-pass lifter applied at the input to the
|
||||
detector algorithm (default=150.0).
|
||||
lp_lifter_freq (float, optional): The "Brick-wall" frequency of low-pass lifter applied at the input to the
|
||||
detector algorithm (default=2000.0).
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.random.random([2, 1000])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.Vad(sample_rate=600)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
"""
|
||||
|
||||
@check_vad
|
||||
def __init__(self, sample_rate, trigger_level=7.0, trigger_time=0.25, search_time=1.0, allowed_gap=0.25,
|
||||
pre_trigger_time=0.0, boot_time=0.35, noise_up_time=0.1, noise_down_time=0.01,
|
||||
noise_reduction_amount=1.35, measure_freq=20.0, measure_duration=None, measure_smooth_time=0.4,
|
||||
hp_filter_freq=50.0, lp_filter_freq=6000.0, hp_lifter_freq=150.0, lp_lifter_freq=2000.0):
|
||||
super().__init__()
|
||||
self.sample_rate = sample_rate
|
||||
self.trigger_level = trigger_level
|
||||
self.trigger_time = trigger_time
|
||||
self.search_time = search_time
|
||||
self.allowed_gap = allowed_gap
|
||||
self.pre_trigger_time = pre_trigger_time
|
||||
self.boot_time = boot_time
|
||||
self.noise_up_time = noise_up_time
|
||||
self.noise_down_time = noise_down_time
|
||||
self.noise_reduction_amount = noise_reduction_amount
|
||||
self.measure_freq = measure_freq
|
||||
self.measure_duration = measure_duration if measure_duration else 2.0 / measure_freq
|
||||
self.measure_smooth_time = measure_smooth_time
|
||||
self.hp_filter_freq = hp_filter_freq
|
||||
self.lp_filter_freq = lp_filter_freq
|
||||
self.hp_lifter_freq = hp_lifter_freq
|
||||
self.lp_lifter_freq = lp_lifter_freq
|
||||
|
||||
def parse(self):
|
||||
return cde.VadOperation(self.sample_rate, self.trigger_level, self.trigger_time, self.search_time,
|
||||
self.allowed_gap, self.pre_trigger_time, self.boot_time, self.noise_up_time,
|
||||
self.noise_down_time, self.noise_reduction_amount, self.measure_freq,
|
||||
self.measure_duration, self.measure_smooth_time, self.hp_filter_freq,
|
||||
self.lp_filter_freq, self.hp_lifter_freq, self.lp_lifter_freq)
|
||||
|
||||
|
||||
DE_C_GAIN_TYPE = {GainType.AMPLITUDE: cde.GainType.DE_GAIN_TYPE_AMPLITUDE,
|
||||
GainType.POWER: cde.GainType.DE_GAIN_TYPE_POWER,
|
||||
GainType.DB: cde.GainType.DE_GAIN_TYPE_DB}
|
||||
|
|
|
@ -719,6 +719,60 @@ def check_griffin_lim(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_vad(method):
|
||||
"""Wrapper method to check the parameters of Vad."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[sample_rate, trigger_level, trigger_time, search_time, allowed_gap, pre_trigger_time,
|
||||
boot_time, noise_up_time, noise_down_time, noise_reduction_amount, measure_freq,
|
||||
measure_duration, measure_smooth_time, hp_filter_freq, lp_filter_freq,
|
||||
hp_lifter_freq, lp_lifter_freq], _ = parse_user_args(method, *args, **kwargs)
|
||||
|
||||
type_check(sample_rate, (int,), "sample_rate")
|
||||
check_pos_int32(sample_rate, "sample_rate")
|
||||
type_check(trigger_level, (int, float), "trigger_level")
|
||||
check_float32(trigger_level, "trigger_level")
|
||||
type_check(trigger_time, (int, float), "trigger_time")
|
||||
check_non_negative_float32(trigger_time, "trigger_time")
|
||||
type_check(search_time, (int, float), "search_time")
|
||||
check_non_negative_float32(search_time, "search_time")
|
||||
type_check(allowed_gap, (int, float), "allowed_gap")
|
||||
check_non_negative_float32(allowed_gap, "allowed_gap")
|
||||
type_check(pre_trigger_time, (int, float), "pre_trigger_time")
|
||||
check_non_negative_float32(pre_trigger_time, "pre_trigger_time")
|
||||
type_check(boot_time, (int, float), "boot_time")
|
||||
check_non_negative_float32(boot_time, "boot_time")
|
||||
type_check(noise_up_time, (int, float), "noise_up_time")
|
||||
check_non_negative_float32(noise_up_time, "noise_up_time")
|
||||
type_check(noise_down_time, (int, float), "noise_down_time")
|
||||
check_non_negative_float32(noise_down_time, "noise_down_time")
|
||||
if noise_up_time < noise_down_time:
|
||||
raise ValueError("Input noise_up_time should be greater than noise_down_time, but got noise_up_time: {0} "
|
||||
"and noise_down_time: {1}.".format(noise_up_time, noise_down_time))
|
||||
type_check(noise_reduction_amount, (int, float), "noise_reduction_amount")
|
||||
check_non_negative_float32(noise_reduction_amount, "noise_reduction_amount")
|
||||
type_check(measure_freq, (int, float), "measure_freq")
|
||||
check_pos_float32(measure_freq, "measure_freq")
|
||||
if measure_duration:
|
||||
type_check(measure_duration, (int, float), "measure_duration")
|
||||
check_non_negative_float32(measure_duration, "measure_duration")
|
||||
type_check(measure_smooth_time, (int, float), "measure_smooth_time")
|
||||
check_non_negative_float32(measure_smooth_time, "measure_smooth_time")
|
||||
type_check(hp_filter_freq, (int, float), "hp_filter_freq")
|
||||
check_pos_float32(hp_filter_freq, "hp_filter_freq")
|
||||
type_check(lp_filter_freq, (int, float), "lp_filter_freq")
|
||||
check_pos_float32(lp_filter_freq, "lp_filter_freq")
|
||||
type_check(hp_lifter_freq, (int, float), "hp_lifter_freq")
|
||||
check_pos_float32(hp_lifter_freq, "hp_lifter_freq")
|
||||
type_check(lp_lifter_freq, (int, float), "lp_lifter_freq")
|
||||
check_pos_float32(lp_lifter_freq, "lp_lifter_freq")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_vol(method):
|
||||
"""Wrapper method to check the parameters of Vol."""
|
||||
|
||||
|
|
|
@ -1144,6 +1144,209 @@ TEST_F(MindDataTestPipeline, TestGriffinLimWrongArgs) {
|
|||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: Vad.
|
||||
/// Description: test pipeline.
|
||||
/// Expectation: success.
|
||||
TEST_F(MindDataTestPipeline, TestVadPipeline) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVadPipeline.";
|
||||
// Original waveform
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 1000}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(50, schema);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
ds = ds->SetNumWorkers(4);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
auto vad = audio::Vad(600);
|
||||
ds = ds->Map({vad});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(ds, nullptr);
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
std::vector<int64_t> expected = {2, 660};
|
||||
int i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto col = row["waveform"];
|
||||
ASSERT_EQ(col.Shape(), expected);
|
||||
ASSERT_EQ(col.Shape().size(), 2);
|
||||
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
EXPECT_EQ(i, 50);
|
||||
iter->Stop();
|
||||
|
||||
// Original waveform
|
||||
std::shared_ptr<SchemaObj> schema2 = Schema();
|
||||
ASSERT_OK(schema2->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {1000}));
|
||||
std::shared_ptr<Dataset> ds2 = RandomData(50, schema2);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
ds2 = ds2->SetNumWorkers(4);
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
auto vad2 = audio::Vad(1600);
|
||||
ds2 = ds2->Map({vad2});
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
|
||||
EXPECT_NE(ds2, nullptr);
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row2;
|
||||
ASSERT_OK(iter2->GetNextRow(&row2));
|
||||
std::vector<int64_t> expected2 = {760};
|
||||
i = 0;
|
||||
while (row2.size() != 0) {
|
||||
auto col = row2["waveform"];
|
||||
ASSERT_EQ(col.Shape(), expected2);
|
||||
ASSERT_EQ(col.Shape().size(), 1);
|
||||
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
|
||||
ASSERT_OK(iter2->GetNextRow(&row2));
|
||||
i++;
|
||||
}
|
||||
EXPECT_EQ(i, 50);
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: Vad.
|
||||
/// Description: test some invalid parameters.
|
||||
/// Expectation: success.
|
||||
TEST_F(MindDataTestPipeline, TestVadWrongArgs) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVadWrongArgs.";
|
||||
// Original waveform
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 1000}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(50, schema);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
ds = ds->SetNumWorkers(4);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// invalid sample rate
|
||||
auto vad_op = audio::Vad(-10);
|
||||
ds = ds->Map({vad_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
// Expect failure
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
|
||||
// invalid search_time
|
||||
vad_op = audio::Vad(100, 7.0, 0.25, -1.0);
|
||||
ds = ds->Map({vad_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
iter = ds->CreateIterator();
|
||||
// Expect failure
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
|
||||
// invalid allowed_gap
|
||||
vad_op = audio::Vad(100, 7.0, 0.25, 1.0, -10);
|
||||
ds = ds->Map({vad_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
iter = ds->CreateIterator();
|
||||
// Expect failure
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
|
||||
// invalid pre_trigger_time
|
||||
vad_op = audio::Vad(100, 7.0, 0.25, 1.0, 0.25, -10);
|
||||
ds = ds->Map({vad_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
iter = ds->CreateIterator();
|
||||
// Expect failure
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
|
||||
// invalid boot_time
|
||||
vad_op = audio::Vad(100, 7.0, 0.25, 1.0, 0.25, 0.0, -10);
|
||||
ds = ds->Map({vad_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
iter = ds->CreateIterator();
|
||||
// Expect failure
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
|
||||
// invalid noise_up_time
|
||||
vad_op = audio::Vad(100, 7.0, 0.25, 1.0, 0.25, 0.0, 0.35, -10);
|
||||
ds = ds->Map({vad_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
iter = ds->CreateIterator();
|
||||
// Expect failure
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
|
||||
// invalid noise_down_time
|
||||
vad_op = audio::Vad(100, 7.0, 0.25, 1.0, 0.25, 0.0, 0.35, 0.1, -10);
|
||||
ds = ds->Map({vad_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
iter = ds->CreateIterator();
|
||||
// Expect failure
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
|
||||
// invalid noise_up_time and noise_down_time
|
||||
vad_op = audio::Vad(100, 7.0, 0.25, 1.0, 0.25, 0.0, 0.35, 0.1, 0.2);
|
||||
ds = ds->Map({vad_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
iter = ds->CreateIterator();
|
||||
// Expect failure
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
|
||||
// invalid noise_reduction_amount
|
||||
vad_op = audio::Vad(100, 7.0, 0.25, 1.0, 0.25, 0.0, 0.35, 0.1, 0.01, -10);
|
||||
ds = ds->Map({vad_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
iter = ds->CreateIterator();
|
||||
// Expect failure
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
|
||||
// invalid measure_freq
|
||||
vad_op = audio::Vad(100, 7.0, 0.25, 1.0, 0.25, 0.0, 0.35, 0.1, 0.01, 1.35, -10);
|
||||
ds = ds->Map({vad_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
iter = ds->CreateIterator();
|
||||
// Expect failure
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
|
||||
// invalid measure_duration
|
||||
vad_op = audio::Vad(100, 7.0, 0.25, 1.0, 0.25, 0.0, 0.35, 0.1, 0.01, 1.35, 20.0, -10);
|
||||
ds = ds->Map({vad_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
iter = ds->CreateIterator();
|
||||
// Expect failure
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
|
||||
// invalid measure_smooth_time
|
||||
vad_op = audio::Vad(100, 7.0, 0.25, 1.0, 0.25, 0.0, 0.35, 0.1, 0.01, 1.35, 20.0, 0.0, -10);
|
||||
ds = ds->Map({vad_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
iter = ds->CreateIterator();
|
||||
// Expect failure
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
|
||||
// invalid hp_filter_freq
|
||||
vad_op = audio::Vad(100, 7.0, 0.25, 1.0, 0.25, 0.0, 0.35, 0.1, 0.01, 1.35, 20.0, 0.0, 0.4, -10);
|
||||
ds = ds->Map({vad_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
iter = ds->CreateIterator();
|
||||
// Expect failure
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
|
||||
// invalid lp_filter_freq
|
||||
vad_op = audio::Vad(100, 7.0, 0.25, 1.0, 0.25, 0.0, 0.35, 0.1, 0.01, 1.35, 20.0, 0.0, 0.4, 50, -10);
|
||||
ds = ds->Map({vad_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
iter = ds->CreateIterator();
|
||||
// Expect failure
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
|
||||
// invalid hp_lifter_freq
|
||||
vad_op = audio::Vad(100, 7.0, 0.25, 1.0, 0.25, 0.0, 0.35, 0.1, 0.01, 1.35, 20.0, 0.0, 0.4, 50, 6000.0, -10);
|
||||
ds = ds->Map({vad_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
iter = ds->CreateIterator();
|
||||
// Expect failure
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
|
||||
// invalid lp_lifter_freq
|
||||
vad_op = audio::Vad(100, 7.0, 0.25, 1.0, 0.25, 0.0, 0.35, 0.1, 0.01, 1.35, 20.0, 0.0, 0.4, 50, 6000.0, 150.0, -10);
|
||||
ds = ds->Map({vad_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
iter = ds->CreateIterator();
|
||||
// Expect failure
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestVolPipeline) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestVolPipeline.";
|
||||
// Original waveform
|
||||
|
|
|
@ -1456,7 +1456,6 @@ TEST_F(MindDataTestExecute, TestFadeWithBool) {
|
|||
EXPECT_TRUE(s01.IsOk());
|
||||
}
|
||||
|
||||
|
||||
/// Feature: GriffinLim
|
||||
/// Description: test basic usage of GriffinLim
|
||||
/// Expectation: success
|
||||
|
@ -1479,6 +1478,27 @@ TEST_F(MindDataTestExecute, TestGriffinLimDefaultValue) {
|
|||
EXPECT_TRUE(status.IsOk());
|
||||
}
|
||||
|
||||
/// Feature: Vad
|
||||
/// Description: test basic usage of Vad
|
||||
/// Expectation: success
|
||||
TEST_F(MindDataTestExecute, TestVadDefaultValue) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestVadDefaultValue.";
|
||||
// Random waveform
|
||||
std::mt19937 gen;
|
||||
std::normal_distribution<float> distribution(1.0, 0.5);
|
||||
std::vector<float> vec;
|
||||
for (int i = 0; i < 1000; ++i){
|
||||
vec.push_back(distribution(gen));
|
||||
}
|
||||
std::shared_ptr<Tensor> input;
|
||||
ASSERT_OK(Tensor::CreateFromVector(vec, TensorShape({1, 1000}), &input));
|
||||
auto input_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||
std::shared_ptr<TensorTransform> vad_op = std::make_shared<audio::Vad>(1600);
|
||||
// apply vad
|
||||
mindspore::dataset::Execute trans({vad_op});
|
||||
Status status = trans(input_ms, &input_ms);
|
||||
EXPECT_TRUE(status.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestExecute, TestVolDefalutValue) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestVolDefalutValue.";
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,374 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing Vad op in DE
|
||||
"""
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.audio.transforms as c_audio
|
||||
from mindspore import log as logger
|
||||
|
||||
DATA_DIR = "../data/dataset/audiorecord/"
|
||||
|
||||
|
||||
def count_unequal_element(data_expected, data_me, rtol, atol):
|
||||
assert data_expected.shape == data_me.shape
|
||||
total_count = len(data_expected.flatten())
|
||||
error = np.abs(data_expected - data_me)
|
||||
greater = np.greater(error, atol + np.abs(data_expected) * rtol)
|
||||
loss_count = np.count_nonzero(greater)
|
||||
assert (loss_count / total_count) < rtol, \
|
||||
"\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
|
||||
format(data_expected[greater], data_me[greater], error[greater])
|
||||
|
||||
|
||||
def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
|
||||
if np.any(np.isnan(data_expected)):
|
||||
assert np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan)
|
||||
elif not np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan):
|
||||
count_unequal_element(data_expected, data_me, rtol, atol)
|
||||
|
||||
|
||||
def test_vad_pipeline1():
|
||||
"""
|
||||
Feature: Vad
|
||||
Description: test Vad cpp op in pipeline
|
||||
Expectation: equal results from Mindspore and benchmark
|
||||
"""
|
||||
# <1000>
|
||||
dataset = ds.NumpySlicesDataset(np.load(DATA_DIR + "single_channel.npy")[np.newaxis, :],
|
||||
column_names=["multi_dimensional_data"],
|
||||
shuffle=False)
|
||||
dataset = dataset.map(operations=[c_audio.Vad(sample_rate=600)], input_columns=["multi_dimensional_data"])
|
||||
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
allclose_nparray(item["multi_dimensional_data"], np.load(DATA_DIR + "single_channel_res.npy"), 0.001, 0.001)
|
||||
|
||||
# <2, 1000>
|
||||
dataset = ds.NumpySlicesDataset(np.load(DATA_DIR + "double_channel.npy")
|
||||
[np.newaxis, :], column_names=["multi_dimensional_data"], shuffle=False)
|
||||
dataset = dataset.map(operations=[c_audio.Vad(sample_rate=1600)], input_columns=["multi_dimensional_data"])
|
||||
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
allclose_nparray(item["multi_dimensional_data"], np.load(DATA_DIR + "double_channel_res.npy"), 0.001, 0.001)
|
||||
|
||||
# <1, 1000>
|
||||
dataset = ds.NumpySlicesDataset(np.load(DATA_DIR + "single_channel.npy")
|
||||
[np.newaxis, np.newaxis, :], column_names=["multi_dimensional_data"], shuffle=False)
|
||||
transforms = [c_audio.Vad(sample_rate=600)]
|
||||
dataset = dataset.map(operations=transforms, input_columns=["multi_dimensional_data"])
|
||||
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
allclose_nparray(item["multi_dimensional_data"], np.load(DATA_DIR + "single_channel_res.npy"), 0.001, 0.001)
|
||||
|
||||
|
||||
def test_vad_pipeline2():
|
||||
"""
|
||||
Feature: Vad
|
||||
Description: test Vad cpp op in pipeline
|
||||
Expectation: equal results from Mindspore and benchmark
|
||||
"""
|
||||
# <1, 1000> trigger level and time
|
||||
dataset = ds.NumpySlicesDataset(np.load(DATA_DIR + "single_channel.npy")
|
||||
[np.newaxis, np.newaxis, :], column_names=["multi_dimensional_data"],
|
||||
shuffle=False)
|
||||
dataset = dataset.map(operations=[c_audio.Vad(sample_rate=700, trigger_level=14.0,
|
||||
trigger_time=1.0)], input_columns=["multi_dimensional_data"])
|
||||
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
allclose_nparray(item["multi_dimensional_data"], np.load(
|
||||
DATA_DIR + "single_channel_trigger_res.npy"), 0.001, 0.001)
|
||||
|
||||
# <1, 1000> search time
|
||||
dataset = ds.NumpySlicesDataset(np.load(DATA_DIR + "single_channel.npy")
|
||||
[np.newaxis, np.newaxis, :], column_names=["multi_dimensional_data"],
|
||||
shuffle=False)
|
||||
dataset = dataset.map(operations=[c_audio.Vad(sample_rate=750, trigger_level=14.0, trigger_time=1.0,
|
||||
search_time=2.0)],
|
||||
input_columns=["multi_dimensional_data"])
|
||||
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
allclose_nparray(item["multi_dimensional_data"], np.load(
|
||||
DATA_DIR + "single_channel_search_res.npy"), 0.001, 0.001)
|
||||
|
||||
# <1, 1000> allowed gap
|
||||
dataset = ds.NumpySlicesDataset(np.load(DATA_DIR + "single_channel.npy")
|
||||
[np.newaxis, np.newaxis, :],
|
||||
column_names=["multi_dimensional_data"], shuffle=False)
|
||||
dataset = dataset.map(operations=[c_audio.Vad(sample_rate=750, trigger_level=14.0, trigger_time=1.0,
|
||||
search_time=2.0, allowed_gap=0.125)],
|
||||
input_columns=["multi_dimensional_data"])
|
||||
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
allclose_nparray(item["multi_dimensional_data"], np.load(
|
||||
DATA_DIR + "single_channel_allowed_gap_res.npy"), 0.001, 0.001)
|
||||
|
||||
# <1, 1000> boot time
|
||||
dataset = ds.NumpySlicesDataset(np.load(DATA_DIR + "single_channel.npy")
|
||||
[np.newaxis, np.newaxis, :], column_names=["multi_dimensional_data"],
|
||||
shuffle=False)
|
||||
dataset = dataset.map(operations=[
|
||||
c_audio.Vad(sample_rate=750,
|
||||
trigger_level=14.0,
|
||||
trigger_time=1.0,
|
||||
search_time=2.0,
|
||||
allowed_gap=0.125,
|
||||
boot_time=0.7)
|
||||
], input_columns=["multi_dimensional_data"])
|
||||
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
allclose_nparray(item["multi_dimensional_data"], np.load(
|
||||
DATA_DIR + "single_channel_boot_time_res.npy"), 0.001, 0.001)
|
||||
|
||||
|
||||
def test_vad_pipeline3():
|
||||
"""
|
||||
Feature: Vad
|
||||
Description: test Vad cpp op in pipeline
|
||||
Expectation: equal results from Mindspore and benchmark
|
||||
"""
|
||||
# <1, 1000> noise
|
||||
dataset = ds.NumpySlicesDataset(np.load(DATA_DIR + "single_channel.npy")
|
||||
[np.newaxis, np.newaxis, :], column_names=["multi_dimensional_data"],
|
||||
shuffle=False)
|
||||
dataset = dataset.map(operations=[c_audio.Vad(sample_rate=750,
|
||||
trigger_level=14.0,
|
||||
trigger_time=1.0,
|
||||
search_time=2.0,
|
||||
allowed_gap=0.125,
|
||||
boot_time=0.7,
|
||||
noise_up_time=0.5,
|
||||
noise_down_time=0.1,
|
||||
noise_reduction_amount=2.7)],
|
||||
input_columns=["multi_dimensional_data"])
|
||||
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
allclose_nparray(item["multi_dimensional_data"], np.load(
|
||||
DATA_DIR + "single_channel_noise_res.npy"), 0.001, 0.001)
|
||||
|
||||
# <1, 1000> measure
|
||||
dataset = ds.NumpySlicesDataset(np.load(DATA_DIR + "single_channel.npy")
|
||||
[np.newaxis, np.newaxis, :], column_names=["multi_dimensional_data"],
|
||||
shuffle=False)
|
||||
dataset = dataset.map(operations=[c_audio.Vad(sample_rate=800,
|
||||
trigger_level=14.0,
|
||||
trigger_time=1.0,
|
||||
search_time=2.0,
|
||||
allowed_gap=0.125,
|
||||
boot_time=0.7,
|
||||
noise_up_time=0.5,
|
||||
noise_down_time=0.1,
|
||||
noise_reduction_amount=2.7,
|
||||
measure_freq=40,
|
||||
measure_duration=0.05,
|
||||
measure_smooth_time=1.0)], input_columns=["multi_dimensional_data"])
|
||||
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
allclose_nparray(item["multi_dimensional_data"], np.load(
|
||||
DATA_DIR + "single_channel_measure_res.npy"), 0.001, 0.001)
|
||||
|
||||
# <1, 1000> filter freq
|
||||
dataset = ds.NumpySlicesDataset(np.load(DATA_DIR + "single_channel.npy")[np.newaxis, np.newaxis, :],
|
||||
column_names=["multi_dimensional_data"],
|
||||
shuffle=False)
|
||||
dataset = dataset.map(operations=[c_audio.Vad(sample_rate=800,
|
||||
trigger_level=14.0,
|
||||
trigger_time=1.0,
|
||||
search_time=2.0,
|
||||
allowed_gap=0.125,
|
||||
boot_time=0.7,
|
||||
measure_freq=40,
|
||||
measure_duration=0.05,
|
||||
measure_smooth_time=1.0,
|
||||
hp_filter_freq=20.0,
|
||||
lp_filter_freq=3000.0,
|
||||
hp_lifter_freq=75.0,
|
||||
lp_lifter_freq=1000.0,
|
||||
noise_up_time=0.5,
|
||||
noise_down_time=0.1,
|
||||
noise_reduction_amount=2.7)],
|
||||
input_columns=["multi_dimensional_data"])
|
||||
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
allclose_nparray(item["multi_dimensional_data"], np.load(DATA_DIR + "single_channel_filter_res.npy"), 0.001,
|
||||
0.001)
|
||||
|
||||
|
||||
def test_vad_pipeline_invalid_param1():
|
||||
"""
|
||||
Feature: Vad
|
||||
Description: test Vad with invalid input parameters
|
||||
Expectation: throw ValueError or TypeError
|
||||
"""
|
||||
logger.info("test InverseMelScale op with default values")
|
||||
in_data = np.load(DATA_DIR + "single_channel.npy")[np.newaxis, :]
|
||||
data1 = ds.NumpySlicesDataset(in_data, column_names=["multi_dimensional_data"], shuffle=False)
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
match=r"Input sample_rate is not within the required interval of \[1, 2147483647\]."):
|
||||
transforms = [c_audio.Vad(sample_rate=-10)]
|
||||
data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
_ = item["multi_dimensional_data"]
|
||||
|
||||
with pytest.raises(ValueError, match=r"Input search_time is not within the required interval of \[0, 16777216\]."):
|
||||
transforms = [c_audio.Vad(sample_rate=1000, search_time=-10)]
|
||||
data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
_ = item["multi_dimensional_data"]
|
||||
|
||||
with pytest.raises(ValueError, match=r"Input allowed_gap is not within the required interval of \[0, 16777216\]."):
|
||||
transforms = [c_audio.Vad(sample_rate=1000, allowed_gap=-10)]
|
||||
data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
_ = item["multi_dimensional_data"]
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
match=r"Input pre_trigger_time is not within the required interval of \[0, 16777216\]."):
|
||||
transforms = [c_audio.Vad(sample_rate=1000, pre_trigger_time=-10)]
|
||||
data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
_ = item["multi_dimensional_data"]
|
||||
|
||||
with pytest.raises(ValueError, match=r"Input boot_time is not within the required interval of \[0, 16777216\]."):
|
||||
transforms = [c_audio.Vad(sample_rate=1000, boot_time=-10)]
|
||||
data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
_ = item["multi_dimensional_data"]
|
||||
|
||||
|
||||
def test_vad_pipeline_invalid_param2():
|
||||
"""
|
||||
Feature: Vad
|
||||
Description: test Vad with invalid input parameters
|
||||
Expectation: throw ValueError or TypeError
|
||||
"""
|
||||
logger.info("test InverseMelScale op with default values")
|
||||
in_data = np.load(DATA_DIR + "single_channel.npy")[np.newaxis, :]
|
||||
data1 = ds.NumpySlicesDataset(in_data, column_names=["multi_dimensional_data"], shuffle=False)
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
match=r"Input noise_up_time is not within the required interval of \[0, 16777216\]."):
|
||||
transforms = [c_audio.Vad(sample_rate=1000, noise_up_time=-10)]
|
||||
data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
_ = item["multi_dimensional_data"]
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
match=r"Input noise_down_time is not within the required interval of \[0, 16777216\]."):
|
||||
transforms = [c_audio.Vad(sample_rate=1000, noise_down_time=-10)]
|
||||
data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
_ = item["multi_dimensional_data"]
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
match=r"Input noise_up_time should be greater than noise_down_time, but got noise_up_time: 1 and"
|
||||
+ r" noise_down_time: 3."):
|
||||
transforms = [c_audio.Vad(sample_rate=1000, noise_up_time=1, noise_down_time=3)]
|
||||
data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
_ = item["multi_dimensional_data"]
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
match=r"Input noise_reduction_amount is not within the required interval of \[0, 16777216\]."):
|
||||
transforms = [c_audio.Vad(sample_rate=1000, noise_reduction_amount=-10)]
|
||||
data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
_ = item["multi_dimensional_data"]
|
||||
|
||||
|
||||
def test_vad_pipeline_invalid_param3():
|
||||
"""
|
||||
Feature: Vad
|
||||
Description: test Vad with invalid input parameters
|
||||
Expectation: throw ValueError or TypeError
|
||||
"""
|
||||
logger.info("test InverseMelScale op with default values")
|
||||
in_data = np.load(DATA_DIR + "single_channel.npy")[np.newaxis, :]
|
||||
data1 = ds.NumpySlicesDataset(in_data, column_names=["multi_dimensional_data"], shuffle=False)
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
match=r"Input measure_freq is not within the required interval of \(0, 16777216\]."):
|
||||
transforms = [c_audio.Vad(sample_rate=1000, measure_freq=-10)]
|
||||
data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
_ = item["multi_dimensional_data"]
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
match=r"Input measure_duration is not within the required interval of \[0, 16777216\]."):
|
||||
transforms = [c_audio.Vad(sample_rate=1000, measure_duration=-10)]
|
||||
data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
_ = item["multi_dimensional_data"]
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
match=r"Input measure_smooth_time is not within the required interval of \[0, 16777216\]."):
|
||||
transforms = [c_audio.Vad(sample_rate=1000, measure_smooth_time=-10)]
|
||||
data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
_ = item["multi_dimensional_data"]
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
match=r"Input hp_filter_freq is not within the required interval of \(0, 16777216\]."):
|
||||
transforms = [c_audio.Vad(sample_rate=1000, hp_filter_freq=-10)]
|
||||
data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
_ = item["multi_dimensional_data"]
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
match=r"Input lp_filter_freq is not within the required interval of \(0, 16777216\]."):
|
||||
transforms = [c_audio.Vad(sample_rate=1000, lp_filter_freq=-10)]
|
||||
data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
_ = item["multi_dimensional_data"]
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
match=r"Input hp_lifter_freq is not within the required interval of \(0, 16777216\]."):
|
||||
transforms = [c_audio.Vad(sample_rate=1000, hp_lifter_freq=-10)]
|
||||
data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
_ = item["multi_dimensional_data"]
|
||||
|
||||
with pytest.raises(ValueError,
|
||||
match=r"Input lp_lifter_freq is not within the required interval of \(0, 16777216\]."):
|
||||
transforms = [c_audio.Vad(sample_rate=1000, lp_lifter_freq=-10)]
|
||||
data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
|
||||
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
_ = item["multi_dimensional_data"]
|
||||
|
||||
|
||||
def test_vad_eager():
|
||||
"""
|
||||
Feature: Vad
|
||||
Description: test Vad cpp op with eager mode
|
||||
Expectation: equal results from Mindspore and benchmark
|
||||
"""
|
||||
spectrogram = np.load(DATA_DIR + "single_channel.npy")
|
||||
out_ms = c_audio.Vad(sample_rate=600)(spectrogram)
|
||||
out_expect = np.load(DATA_DIR + "single_channel_res.npy")
|
||||
allclose_nparray(out_ms, out_expect, 0.001, 0.001)
|
||||
|
||||
spectrogram = np.load(DATA_DIR + "double_channel.npy")
|
||||
out_ms = c_audio.Vad(sample_rate=1600)(spectrogram)
|
||||
out_expect = np.load(DATA_DIR + "double_channel_res.npy")
|
||||
allclose_nparray(out_ms, out_expect, 0.001, 0.001)
|
||||
|
||||
# benchmark op trigger warning
|
||||
spectrogram = np.load(DATA_DIR + "three_channel.npy")
|
||||
out_ms = c_audio.Vad(sample_rate=1600)(spectrogram)
|
||||
out_expect = np.load(DATA_DIR + "three_channel_res.npy")
|
||||
allclose_nparray(out_ms, out_expect, 0.001, 0.001)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_vad_pipeline1()
|
||||
test_vad_pipeline2()
|
||||
test_vad_pipeline3()
|
||||
test_vad_pipeline_invalid_param1()
|
||||
test_vad_pipeline_invalid_param2()
|
||||
test_vad_pipeline_invalid_param3()
|
||||
test_vad_eager()
|
Loading…
Reference in New Issue