[feat] [assistant] [I3J6U8] Add a new audio operator DetectPitchFrequency

This commit is contained in:
Zhang Lian 2021-09-29 15:12:58 +08:00
parent 6d1a73b4c1
commit b7e6520e62
17 changed files with 892 additions and 19 deletions

View File

@ -28,6 +28,7 @@
#include "minddata/dataset/audio/ir/kernels/contrast_ir.h" #include "minddata/dataset/audio/ir/kernels/contrast_ir.h"
#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h" #include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h" #include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/detect_pitch_frequency_ir.h"
#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h" #include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/fade_ir.h" #include "minddata/dataset/audio/ir/kernels/fade_ir.h"
#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h" #include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h"
@ -218,6 +219,30 @@ std::shared_ptr<TensorOperation> DeemphBiquad::Parse() {
return std::make_shared<DeemphBiquadOperation>(data_->sample_rate_);
}
// DetectPitchFrequency Transform Operation.
struct DetectPitchFrequency::Data {
Data(int32_t sample_rate, float frame_time, int32_t win_length, int32_t freq_low, int32_t freq_high)
: sample_rate_(sample_rate),
frame_time_(frame_time),
win_length_(win_length),
freq_low_(freq_low),
freq_high_(freq_high) {}
int32_t sample_rate_;
float frame_time_;
int32_t win_length_;
int32_t freq_low_;
int32_t freq_high_;
};
DetectPitchFrequency::DetectPitchFrequency(int32_t sample_rate, float frame_time, int32_t win_length, int32_t freq_low,
int32_t freq_high)
: data_(std::make_shared<Data>(sample_rate, frame_time, win_length, freq_low, freq_high)) {}
std::shared_ptr<TensorOperation> DetectPitchFrequency::Parse() {
return std::make_shared<DetectPitchFrequencyOperation>(data_->sample_rate_, data_->frame_time_, data_->win_length_,
data_->freq_low_, data_->freq_high_);
}
// EqualizerBiquad Transform Operation.
struct EqualizerBiquad::Data {
Data(int32_t sample_rate, float center_freq, float gain, float Q)

View File

@ -32,6 +32,7 @@
#include "minddata/dataset/audio/ir/kernels/contrast_ir.h" #include "minddata/dataset/audio/ir/kernels/contrast_ir.h"
#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h" #include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h" #include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/detect_pitch_frequency_ir.h"
#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h" #include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/fade_ir.h" #include "minddata/dataset/audio/ir/kernels/fade_ir.h"
#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h" #include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h"
@ -186,6 +187,19 @@ PYBIND_REGISTER(
}));
}));
PYBIND_REGISTER(DetectPitchFrequencyOperation, 1, ([](const py::module *m) {
(void)py::class_<audio::DetectPitchFrequencyOperation, TensorOperation,
std::shared_ptr<audio::DetectPitchFrequencyOperation>>(
*m, "DetectPitchFrequencyOperation")
.def(py::init([](int32_t sample_rate, float frame_time, int32_t win_length, int32_t freq_low,
int32_t freq_high) {
auto detect_pitch_frequency = std::make_shared<audio::DetectPitchFrequencyOperation>(
sample_rate, frame_time, win_length, freq_low, freq_high);
THROW_IF_ERROR(detect_pitch_frequency->ValidateParams());
return detect_pitch_frequency;
}));
}));
PYBIND_REGISTER(EqualizerBiquadOperation, 1, ([](const py::module *m) {
(void)py::class_<audio::EqualizerBiquadOperation, TensorOperation,
std::shared_ptr<audio::EqualizerBiquadOperation>>(*m, "EqualizerBiquadOperation")

View File

@ -14,6 +14,7 @@ add_library(audio-ir-kernels OBJECT
contrast_ir.cc
dc_shift_ir.cc
deemph_biquad_ir.cc
detect_pitch_frequency_ir.cc
equalizer_biquad_ir.cc
fade_ir.cc
frequency_masking_ir.cc

View File

@ -0,0 +1,61 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/audio/ir/kernels/detect_pitch_frequency_ir.h"
#include "minddata/dataset/audio/ir/validators.h"
#include "minddata/dataset/audio/kernels/detect_pitch_frequency_op.h"
namespace mindspore {
namespace dataset {
namespace audio {
// DetectPitchFrequencyOperation
DetectPitchFrequencyOperation::DetectPitchFrequencyOperation(int32_t sample_rate, float frame_time, int32_t win_length,
int32_t freq_low, int32_t freq_high)
: sample_rate_(sample_rate),
frame_time_(frame_time),
win_length_(win_length),
freq_low_(freq_low),
freq_high_(freq_high) {}
Status DetectPitchFrequencyOperation::ValidateParams() {
RETURN_IF_NOT_OK(ValidateScalarNotZero("DetectPitchFrequency", "sample_rate", sample_rate_));
RETURN_IF_NOT_OK(ValidateFloatScalarPositive("DetectPitchFrequency", "frame_time", frame_time_));
RETURN_IF_NOT_OK(ValidateIntScalarPositive("DetectPitchFrequency", "win_length", win_length_));
RETURN_IF_NOT_OK(ValidateIntScalarPositive("DetectPitchFrequency", "freq_low", freq_low_));
RETURN_IF_NOT_OK(ValidateIntScalarPositive("DetectPitchFrequency", "freq_high", freq_high_));
return Status::OK();
}
std::shared_ptr<TensorOp> DetectPitchFrequencyOperation::Build() {
std::shared_ptr<DetectPitchFrequencyOp> tensor_op =
std::make_shared<DetectPitchFrequencyOp>(sample_rate_, frame_time_, win_length_, freq_low_, freq_high_);
return tensor_op;
}
Status DetectPitchFrequencyOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["sample_rate"] = sample_rate_;
args["frame_time"] = frame_time_;
args["win_length"] = win_length_;
args["freq_low"] = freq_low_;
args["freq_high"] = freq_high_;
*out_json = args;
return Status::OK();
}
} // namespace audio
} // namespace dataset
} // namespace mindspore
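For reference, to_json above emits one key per constructor argument. A hedged sketch of the serialized document, assuming the C++ API defaults introduced later in this commit plus an arbitrary 44100 Hz sample rate (not output captured from a run):

```python
# Hedged sketch of the JSON produced by DetectPitchFrequencyOperation::to_json.
# Key names come verbatim from the implementation above; the values are the
# C++ API defaults (frame_time=0.01, win_length=30, freq_low=85, freq_high=3400).
expected_json = {
    "sample_rate": 44100,
    "frame_time": 0.01,
    "win_length": 30,
    "freq_low": 85,
    "freq_high": 3400,
}
```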

View File

@ -0,0 +1,60 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DETECT_PITCH_FREQUENCY_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DETECT_PITCH_FREQUENCY_IR_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "include/api/status.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
namespace audio {
constexpr char kDetectPitchFrequencyOperation[] = "DetectPitchFrequency";
class DetectPitchFrequencyOperation : public TensorOperation {
public:
DetectPitchFrequencyOperation(int32_t sample_rate, float frame_time, int32_t win_length, int32_t freq_low,
int32_t freq_high);
~DetectPitchFrequencyOperation() = default;
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override { return kDetectPitchFrequencyOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
int32_t sample_rate_;
float frame_time_;
int32_t win_length_;
int32_t freq_low_;
int32_t freq_high_;
};
} // namespace audio
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DETECT_PITCH_FREQUENCY_IR_H_

View File

@ -15,6 +15,7 @@ add_library(audio-kernels OBJECT
contrast_op.cc
dc_shift_op.cc
deemph_biquad_op.cc
detect_pitch_frequency_op.cc
equalizer_biquad_op.cc
fade_op.cc
frequency_masking_op.cc

View File

@ -650,5 +650,79 @@ Status Magphase(const TensorRow &input, TensorRow *output, float power) {
return Status::OK();
}
Status MedianSmoothing(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t win_length) {
auto channel = input->shape()[0];
auto num_of_frames = input->shape()[1];
// Centered window
int32_t pad_length = (win_length - 1) / 2;
int32_t out_length = num_of_frames + pad_length - win_length + 1;
TensorShape out_shape({channel, out_length});
std::vector<int> signal;
std::vector<int> out;
std::vector<int> indices(channel * (num_of_frames + pad_length), 0);
// "replicate" padding in any dimension
for (auto itr = input->begin<int>(); itr != input->end<int>(); ++itr) {
signal.push_back(*itr);
}
for (int i = 0; i < channel; ++i) {
for (int j = 0; j < pad_length; ++j) {
indices[i * (num_of_frames + pad_length) + j] = signal[i * num_of_frames];
}
}
for (int i = 0; i < channel; ++i) {
for (int j = 0; j < num_of_frames; ++j) {
indices[i * (num_of_frames + pad_length) + j + pad_length] = signal[i * num_of_frames + j];
}
}
for (int i = 0; i < channel; ++i) {
int32_t index = i * (num_of_frames + pad_length);
for (int j = 0; j < out_length; ++j) {
std::vector<int> tem(indices.begin() + index, indices.begin() + win_length + index);
std::sort(tem.begin(), tem.end());
out.push_back(tem[pad_length]);
++index;
}
}
RETURN_IF_NOT_OK(Tensor::CreateFromVector(out, out_shape, output));
return Status::OK();
}
Status DetectPitchFrequency(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t sample_rate,
float frame_time, int32_t win_length, int32_t freq_low, int32_t freq_high) {
std::shared_ptr<Tensor> nccf;
std::shared_ptr<Tensor> indices;
std::shared_ptr<Tensor> smooth_indices;
// pack batch
TensorShape input_shape = input->shape();
TensorShape to_shape({input->Size() / input_shape[-1], input_shape[-1]});
RETURN_IF_NOT_OK(input->Reshape(to_shape));
if (input->type() == DataType(DataType::DE_FLOAT32)) {
RETURN_IF_NOT_OK(ComputeNccf<float>(input, &nccf, sample_rate, frame_time, freq_low));
RETURN_IF_NOT_OK(FindMaxPerFrame<float>(nccf, &indices, sample_rate, freq_high));
} else if (input->type() == DataType(DataType::DE_FLOAT64)) {
RETURN_IF_NOT_OK(ComputeNccf<double>(input, &nccf, sample_rate, frame_time, freq_low));
RETURN_IF_NOT_OK(FindMaxPerFrame<double>(nccf, &indices, sample_rate, freq_high));
} else {
RETURN_IF_NOT_OK(ComputeNccf<float16>(input, &nccf, sample_rate, frame_time, freq_low));
RETURN_IF_NOT_OK(FindMaxPerFrame<float16>(nccf, &indices, sample_rate, freq_high));
}
RETURN_IF_NOT_OK(MedianSmoothing(indices, &smooth_indices, win_length));
// Convert indices to frequency
constexpr double EPSILON = 1e-9;
TensorShape freq_shape = smooth_indices->shape();
std::vector<float> out;
for (auto itr_fre = smooth_indices->begin<int>(); itr_fre != smooth_indices->end<int>(); ++itr_fre) {
out.push_back(sample_rate / (EPSILON + *itr_fre));
}
// unpack batch
auto shape_vec = input_shape.AsVector();
shape_vec[shape_vec.size() - 1] = freq_shape[-1];
TensorShape out_shape(shape_vec);
RETURN_IF_NOT_OK(Tensor::CreateFromVector(out, out_shape, output));
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
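Taken together with the NCCF and per-frame peak picking declared in audio_utils.h (next file), the flow above is: NCCF over candidate lags, best lag per frame, median smoothing, then lag-to-frequency conversion. Below is a minimal single-channel NumPy sketch of that pipeline. It mirrors the C++ (including its normalization by the product of frame energies), but omits the half-lag octave-correction heuristic of FindMaxPerFrame and multi-channel handling; all names are illustrative, not the MindSpore API.

```python
# Minimal single-channel sketch of DetectPitchFrequency, mirroring the C++ above.
import numpy as np

def detect_pitch_sketch(waveform, sample_rate, frame_time=0.01,
                        win_length=30, freq_low=85, freq_high=3400):
    lags = int(np.ceil(sample_rate / freq_low))
    # float32 mirrors the C++ float multiply; in float64, e.g. 30 * 0.1 rounds
    # to 3.0000000000000004 and the ceil would overshoot by one.
    frame_size = int(np.ceil(np.float32(sample_rate) * np.float32(frame_time)))
    num_frames = int(np.ceil(len(waveform) / frame_size))
    # Zero-pad so every frame can be correlated at the maximum lag.
    pad = lags + num_frames * frame_size - len(waveform)
    x = np.pad(np.asarray(waveform, dtype=np.float64), (0, pad))
    nccf = np.zeros((num_frames, lags))
    for lag in range(1, lags + 1):
        for f in range(num_frames):
            s1 = x[f * frame_size:f * frame_size + frame_size]
            s2 = x[f * frame_size + lag:f * frame_size + frame_size + lag]
            denom = np.dot(s1, s1) * np.dot(s2, s2)
            nccf[f, lag - 1] = np.dot(s1, s2) / denom if denom != 0 else 0.0
    # Best lag per frame, ignoring lags shorter than the one for freq_high;
    # +1 restores the 1-based lag index, as in FindMaxPerFrame.
    lag_min = int(np.ceil(sample_rate / freq_high))
    best = np.argmax(nccf[:, lag_min:], axis=1) + lag_min + 1
    # Median smoothing with "replicate" padding on the left, as in MedianSmoothing.
    half = (win_length - 1) // 2
    idx = np.concatenate([np.full(half, best[0]), best])
    smoothed = np.array([np.sort(idx[i:i + win_length])[half]
                         for i in range(len(idx) - win_length + 1)])
    return sample_rate / (1e-9 + smoothed)  # lag index -> frequency (Hz)
```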

View File

@ -372,6 +372,191 @@ Status Vol(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output
/// \param power: Power of the norm.
Status Magphase(const TensorRow &input, TensorRow *output, float power);
/// \brief Compute Normalized Cross-Correlation Function (NCCF).
/// \param input: Tensor of shape <channel, waveform_length>.
/// \param output: Tensor of shape <channel, num_of_frames, lags>.
/// \param sample_rate: The sample rate of the waveform (Hz).
/// \param frame_time: Duration of a frame (seconds).
/// \param freq_low: Lowest frequency that can be detected (Hz).
/// \return Status code.
template <typename T>
Status ComputeNccf(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t sample_rate,
float frame_time, int32_t freq_low) {
auto channel = input->shape()[0];
auto waveform_length = input->shape()[1];
size_t idx = 0;
size_t channel_idx = 1;
int32_t lags = static_cast<int32_t>(ceil(static_cast<float>(sample_rate) / freq_low));
int32_t frame_size = static_cast<int32_t>(ceil(sample_rate * frame_time));
int32_t num_of_frames = static_cast<int32_t>(ceil(static_cast<float>(waveform_length) / frame_size));
int32_t p = lags + num_of_frames * frame_size - waveform_length;
TensorShape output_shape({channel, num_of_frames, lags});
DataType input_type = input->type();
RETURN_IF_NOT_OK(Tensor::CreateEmpty(output_shape, input_type, output));
// pad p zeros along the last dimension
std::vector<T> signal;
// Tensor -> vector
for (auto itr = input->begin<T>(); itr != input->end<T>();) {
while (idx < waveform_length * channel_idx) {
signal.push_back(*itr);
++itr;
++idx;
}
// Each channel is processed with the sliding window
// waveform[channel, time] --> waveform[channel, time+p]
for (size_t i = 0; i < p; ++i) {
signal.push_back(static_cast<T>(0.0));
}
if (idx % waveform_length == 0) {
++channel_idx;
}
}
// compute ncc
for (dsize_t lag = 1; lag <= lags; ++lag) {
// compute one ncc
// one ncc out
std::vector<T> out;
channel_idx = 1;
idx = 0;
size_t win_idx = 0;
size_t waveform_length_p = waveform_length + p;
// Traverse the signal
for (auto itr = signal.begin(); itr != signal.end();) {
// Each channel is processed with the sliding window
size_t s1 = idx;
size_t s2 = idx + lag;
size_t frame_count = 0;
T s1_norm = static_cast<T>(0);
T s2_norm = static_cast<T>(0);
T ncc_numerator = static_cast<T>(0);
T ncc = static_cast<T>(0);
while (idx < waveform_length_p * channel_idx) {
// Sliding window
if (frame_count == num_of_frames) {
++itr;
++idx;
continue;
}
if (win_idx < frame_size) {
ncc_numerator += signal[s1] * signal[s2];
s1_norm += signal[s1] * signal[s1];
s2_norm += signal[s2] * signal[s2];
++win_idx;
++s1;
++s2;
}
if (win_idx == frame_size) {
if (s1_norm != static_cast<T>(0.0) && s2_norm != static_cast<T>(0.0)) {
ncc = ncc_numerator / s1_norm / s2_norm;
} else {
ncc = static_cast<T>(0.0);
}
out.push_back(ncc);
ncc_numerator = static_cast<T>(0.0);
s1_norm = static_cast<T>(0.0);
s2_norm = static_cast<T>(0.0);
++frame_count;
win_idx = 0;
}
++itr;
++idx;
}
if (idx % waveform_length_p == 0) {
++channel_idx;
}
} // compute one ncc
// cat tensor
auto itr_out = out.begin();
for (dsize_t row_idx = 0; row_idx < channel; ++row_idx) {
for (dsize_t frame_idx = 0; frame_idx < num_of_frames; ++frame_idx) {
RETURN_IF_NOT_OK((*output)->SetItemAt({row_idx, frame_idx, lag - 1}, *itr_out));
++itr_out;
}
}
} // compute ncc
return Status::OK();
}
/// \brief For each frame, take the highest value of NCCF.
/// \param input: Tensor of shape <channel, num_of_frames, lags>.
/// \param output: Tensor of shape <channel, num_of_frames>.
/// \param sample_rate: The sample rate of the waveform (Hz).
/// \param freq_high: Highest frequency that can be detected (Hz).
/// \return Status code.
template <typename T>
Status FindMaxPerFrame(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t sample_rate,
int32_t freq_high) {
std::vector<T> signal;
std::vector<int> out;
auto channel = input->shape()[0];
auto num_of_frames = input->shape()[1];
auto lags = input->shape()[2];
int32_t lag_min = static_cast<int32_t>(ceil(static_cast<float>(sample_rate) / freq_high));
TensorShape out_shape({channel, num_of_frames});
// pack batch
for (auto itr = input->begin<T>(); itr != input->end<T>(); ++itr) {
signal.push_back(*itr);
}
// find the best nccf
T best_max_value = static_cast<T>(0.0);
T half_max_value = static_cast<T>(0.0);
int32_t best_max_indices = 0;
int32_t half_max_indices = 0;
auto thresh = static_cast<T>(0.99);
auto lags_half = lags / 2;
for (dsize_t channel_idx = 0; channel_idx < channel; ++channel_idx) {
for (dsize_t frame_idx = 0; frame_idx < num_of_frames; ++frame_idx) {
auto index_01 = channel_idx * num_of_frames * lags + frame_idx * lags + lag_min;
best_max_value = signal[index_01];
half_max_value = signal[index_01];
best_max_indices = lag_min;
half_max_indices = lag_min;
for (dsize_t lag_idx = 0; lag_idx < lags; ++lag_idx) {
if (lag_idx > lag_min) {
auto index_02 = channel_idx * num_of_frames * lags + frame_idx * lags + lag_idx;
if (signal[index_02] > best_max_value) {
best_max_value = signal[index_02];
best_max_indices = lag_idx;
if (lag_idx < lags_half) {
half_max_value = signal[index_02];
half_max_indices = lag_idx;
}
}
}
}
// Add back minimal lag
// Add 1 empirical calibration offset
if (half_max_value > best_max_value * thresh) {
out.push_back(half_max_indices + 1);
} else {
out.push_back(best_max_indices + 1);
}
}
}
// unpack batch
RETURN_IF_NOT_OK(Tensor::CreateFromVector(out, out_shape, output));
return Status::OK();
}
/// \brief Apply median smoothing to the 1D tensor over the given window.
/// \param input: Tensor of shape <channel, num_of_frames>.
/// \param output: Tensor of shape <channel, num_of_windows>.
/// \param win_length: The window length for median smoothing (in number of frames).
/// \return Status code.
Status MedianSmoothing(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t win_length);
/// \brief Detect pitch frequency.
/// \param input: Tensor of shape <channel, waveform_length>.
/// \param output: Tensor of shape <channel, num_of_frames>.
/// \param sample_rate: The sample rate of the waveform (Hz).
/// \param frame_time: Duration of a frame (seconds).
/// \param win_length: The window length for median smoothing (in number of frames).
/// \param freq_low: Lowest frequency that can be detected (Hz).
/// \param freq_high: Highest frequency that can be detected (Hz).
/// \return Status code.
Status DetectPitchFrequency(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t sample_rate,
float frame_time, int32_t win_length, int32_t freq_low, int32_t freq_high);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_
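To make the shape arithmetic in these comments concrete, a short hedged check using the toy parameters this commit's tests use (sample_rate=30, frame_time=0.1, freq_low=5, freq_high=25) and a 5-sample waveform per channel:

```python
# Hedged shape arithmetic for the declarations above.
import numpy as np

sample_rate, frame_time, freq_low, freq_high = 30, 0.1, 5, 25
waveform_length = 5
lags = int(np.ceil(sample_rate / freq_low))  # 6 candidate lags
# float32 mirrors the C++ float multiply: 30 * 0.1f is exactly 3.0 there,
# while in float64 it is 3.0000000000000004 and the ceil would give 4.
frame_size = int(np.ceil(np.float32(sample_rate) * np.float32(frame_time)))  # 3
num_of_frames = int(np.ceil(waveform_length / frame_size))  # 2 frames
print(lags, frame_size, num_of_frames)  # -> 6 3 2, so NCCF shape is <channel, 2, 6>
# A winning lag of 3 maps to 30 / 3 = 10 Hz and lag 4 to 30 / 4 = 7.5 Hz,
# matching the expected outputs in test_detect_pitch_frequency_pipeline below.
```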

View File

@ -0,0 +1,37 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/audio/kernels/detect_pitch_frequency_op.h"
#include "minddata/dataset/audio/kernels/audio_utils.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
Status DetectPitchFrequencyOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
TensorShape input_shape = input->shape();
// check input tensor dimension, it should be greater than 0.
CHECK_FAIL_RETURN_UNEXPECTED(input_shape.Size() > 0,
"DetectPitchFrequency: input tensor is not in shape of <..., time>.");
// check input type, it should be DE_FLOAT16, DE_FLOAT32 or DE_FLOAT64
CHECK_FAIL_RETURN_UNEXPECTED(
input->type() == DataType(DataType::DE_FLOAT32) || input->type() == DataType(DataType::DE_FLOAT64) ||
input->type() == DataType(DataType::DE_FLOAT16),
"DetectPitchFrequency: input tensor type should be float or double, but got: " + input->type().ToString());
return DetectPitchFrequency(input, output, sample_rate_, frame_time_, win_length_, freq_low_, freq_high_);
}
} // namespace dataset
} // namespace mindspore
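Compute() accepts only float16, float32, and float64 input. A hedged eager-mode sketch of how that check surfaces from Python; the exact exception type raised by the eager path is an assumption, so it is caught broadly:

```python
# Hedged sketch: integer input should be rejected by the dtype check in Compute().
import numpy as np
import mindspore.dataset.audio.transforms as audio

op = audio.DetectPitchFrequency(30, 0.1, 3, 5, 25)
bad = np.arange(10, dtype=np.int64).reshape(2, 5)
try:
    op(bad)
except Exception as e:  # expected message: "... input tensor type should be float or double ..."
    print(e)
```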

View File

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DETECT_PITCH_FREQUENCY_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DETECT_PITCH_FREQUENCY_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class DetectPitchFrequencyOp : public TensorOp {
public:
DetectPitchFrequencyOp(int32_t sample_rate, float frame_time, int32_t win_length, int32_t freq_low, int32_t freq_high)
: sample_rate_(sample_rate),
frame_time_(frame_time),
win_length_(win_length),
freq_low_(freq_low),
freq_high_(freq_high) {}
~DetectPitchFrequencyOp() override = default;
void Print(std::ostream &out) const override {
out << Name() << ": sample_rate: " << sample_rate_ << ", frame_time: " << frame_time_
<< ", win_length: " << win_length_ << ", freq_low: " << freq_low_ << ", freq_high: " << freq_high_ << std::endl;
}
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kDetectPitchFrequencyOp; }
private:
int32_t sample_rate_;
float frame_time_;
int32_t win_length_;
int32_t freq_low_;
int32_t freq_high_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DETECT_PITCH_FREQUENCY_OP_H_

View File

@ -300,6 +300,33 @@ class DeemphBiquad final : public TensorTransform {
std::shared_ptr<Data> data_;
};
/// \brief Detect pitch frequency.
class DetectPitchFrequency final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero.
/// \param[in] frame_time Duration of a frame, the value must be greater than zero (default=0.01).
/// \param[in] win_length The window length for median smoothing (in number of frames), the value must
/// be greater than zero (default=30).
/// \param[in] freq_low Lowest frequency that can be detected (Hz), the value must be greater than zero (default=85).
/// \param[in] freq_high Highest frequency that can be detected (Hz), the value must be greater than
/// zero (default=3400).
explicit DetectPitchFrequency(int32_t sample_rate, float frame_time = 0.01, int32_t win_length = 30,
int32_t freq_low = 85, int32_t freq_high = 3400);
/// \brief Destructor.
~DetectPitchFrequency() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief EqualizerBiquad TensorTransform. Apply a biquad equalizer filter on audio.
class EqualizerBiquad final : public TensorTransform {
public:

View File

@ -155,6 +155,7 @@ constexpr char kComplexNormOp[] = "ComplexNormOp";
constexpr char kContrastOp[] = "ContrastOp";
constexpr char kDCShiftOp[] = "DCShiftOp";
constexpr char kDeemphBiquadOp[] = "DeemphBiquadOp";
constexpr char kDetectPitchFrequencyOp[] = "DetectPitchFrequencyOp";
constexpr char kEqualizerBiquadOp[] = "EqualizerBiquadOp";
constexpr char kFadeOp[] = "FadeOp";
constexpr char kFrequencyMaskingOp[] = "FrequencyMaskingOp";

View File

@ -26,8 +26,8 @@ from ..transforms.c_transforms import TensorOperation
from .utils import FadeShape, GainType, ScaleType
from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_contrast, check_dc_shift, \
check_deemph_biquad, check_detect_pitch_frequency, check_equalizer_biquad, check_fade, check_highpass_biquad, \
check_lfilter, check_lowpass_biquad, check_magphase, check_masking, check_mu_law_decoding, check_riaa_biquad, \
check_time_stretch, check_treble_biquad, check_vol
@ -379,6 +379,45 @@ class DeemphBiquad(AudioTensorOperation):
return cde.DeemphBiquadOperation(self.sample_rate)
class DetectPitchFrequency(AudioTensorOperation):
"""
Detect pitch frequency.
It is implemented using the normalized cross-correlation function (NCCF) and median smoothing.
Args:
sample_rate (int): Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero.
frame_time (float, optional): Duration of a frame, the value must be greater than zero (default=0.01).
win_length (int, optional): The window length for median smoothing (in number of frames), the value must be
greater than zero (default=30).
freq_low (int, optional): Lowest frequency that can be detected (Hz), the value must be greater than zero
(default=85).
freq_high (int, optional): Highest frequency that can be detected (Hz), the value must be greater than zero
(default=3400).
Examples:
>>> import numpy as np
>>>
>>> waveform = np.array([[0.716064e-03, 5.347656e-03, 6.246826e-03, 2.089477e-02, 7.138305e-02],
... [4.156616e-02, 1.394653e-02, 3.550292e-02, 0.614379e-02, 3.840209e-02]])
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
>>> transforms = [audio.DetectPitchFrequency(30, 0.1, 3, 5, 25)]
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
"""
@check_detect_pitch_frequency
def __init__(self, sample_rate, frame_time=0.01, win_length=30, freq_low=85, freq_high=3400):
self.sample_rate = sample_rate
self.frame_time = frame_time
self.win_length = win_length
self.freq_low = freq_low
self.freq_high = freq_high
def parse(self):
return cde.DetectPitchFrequencyOperation(self.sample_rate, self.frame_time,
self.win_length, self.freq_low, self.freq_high)
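Besides the pipeline example in the docstring, the operator is directly callable on a NumPy array in eager mode, as the tests later in this commit do. A minimal hedged sketch; the parameters are the toy values used by those tests and the random waveform is arbitrary:

```python
# Minimal eager-mode sketch of the operator defined above.
import numpy as np
import mindspore.dataset.audio.transforms as audio

waveform = np.random.rand(2, 10).astype(np.float64)
op = audio.DetectPitchFrequency(30, 0.1, 3, 5, 25)
pitch = op(waveform)
print(pitch.shape)  # (2, n): one smoothed pitch track per channel
```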
class EqualizerBiquad(AudioTensorOperation):
"""
Design biquad equalizer filter and perform filtering. Similar to SoX implementation.

View File

@ -453,3 +453,25 @@ def check_vol(method):
return method(self, *args, **kwargs)
return new_method
def check_detect_pitch_frequency(method):
"""Wrapper method to check the parameters of DetectPitchFrequency."""
@wraps(method)
def new_method(self, *args, **kwargs):
[sample_rate, frame_time, win_length, freq_low, freq_high], _ = parse_user_args(
method, *args, **kwargs)
type_check(sample_rate, (int,), "sample_rate")
check_int32_not_zero(sample_rate, "sample_rate")
type_check(frame_time, (float, int), "frame_time")
check_pos_float32(frame_time, "frame_time")
type_check(win_length, (int,), "win_length")
check_pos_int32(win_length, "win_length")
type_check(freq_low, (int, float), "freq_low")
check_pos_float32(freq_low, "freq_low")
type_check(freq_high, (int, float), "freq_high")
check_pos_float32(freq_high, "freq_high")
return method(self, *args, **kwargs)
return new_method
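Because the wrapper runs at construction time, invalid arguments fail before any data flows. A short hedged sketch; the message text is the one asserted in test_detect_pitch_frequency_invalid_input later in this commit:

```python
# Hedged sketch: the validator rejects a zero sample_rate at construction time.
import mindspore.dataset.audio.transforms as audio

try:
    audio.DetectPitchFrequency(0)
except ValueError as e:
    print(e)  # "Input sample_rate is not within the required interval ..."
```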

View File

@ -1375,7 +1375,7 @@ TEST_F(MindDataTestPipeline, TestMagphaseWrongArgs) {
std::shared_ptr<TensorTransform> magphase(new audio::Magphase(power_wrong));
std::unordered_map<std::string, mindspore::MSTensor> row;
// Magphase: power must be greater than or equal to 0.
std::shared_ptr<SchemaObj> schema = Schema();
ASSERT_OK(schema->add_column("col1", mindspore::DataType::kNumberTypeFloat32, {2, 2}));
std::shared_ptr<Dataset> ds = RandomData(8, schema);
@ -1385,3 +1385,100 @@ TEST_F(MindDataTestPipeline, TestMagphaseWrongArgs) {
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_EQ(iter, nullptr);
}
TEST_F(MindDataTestPipeline, TestDetectPitchFrequencyBasic) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDetectPitchFrequencyBasic.";
// Original waveform
std::shared_ptr<SchemaObj> schema = Schema();
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 10000}));
std::shared_ptr<Dataset> ds = RandomData(50, schema);
EXPECT_NE(ds, nullptr);
ds = ds->SetNumWorkers(4);
EXPECT_NE(ds, nullptr);
auto DetectPitchFrequencyOp = audio::DetectPitchFrequency(44100);
ds = ds->Map({DetectPitchFrequencyOp});
EXPECT_NE(ds, nullptr);
// Detect pitch frequency
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
std::vector<int64_t> expected = {2, 8};
int i = 0;
while (row.size() != 0) {
auto col = row["waveform"];
ASSERT_EQ(col.Shape(), expected);
ASSERT_EQ(col.Shape().size(), 2);
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 50);
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestDetectPitchFrequencyParamCheck) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDetectPitchFrequencyParamCheck.";
std::shared_ptr<SchemaObj> schema = Schema();
// Original waveform
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 2}));
std::shared_ptr<Dataset> ds = RandomData(50, schema);
std::shared_ptr<Dataset> ds01;
std::shared_ptr<Dataset> ds02;
std::shared_ptr<Dataset> ds03;
std::shared_ptr<Dataset> ds04;
std::shared_ptr<Dataset> ds05;
EXPECT_NE(ds, nullptr);
// Check sample_rate
MS_LOG(INFO) << "sample_rate is zero.";
auto detect_pitch_frequency_op_01 = audio::DetectPitchFrequency(0);
ds01 = ds->Map({detect_pitch_frequency_op_01});
EXPECT_NE(ds01, nullptr);
std::shared_ptr<Iterator> iter01 = ds01->CreateIterator();
EXPECT_EQ(iter01, nullptr);
// Check frame_time
MS_LOG(INFO) << "frame_time is zero.";
auto detect_pitch_frequency_op_02 = audio::DetectPitchFrequency(30, 0);
ds02 = ds->Map({detect_pitch_frequency_op_02});
EXPECT_NE(ds02, nullptr);
std::shared_ptr<Iterator> iter02 = ds02->CreateIterator();
EXPECT_EQ(iter02, nullptr);
// Check win_length
MS_LOG(INFO) << "win_length is zero.";
auto detect_pitch_frequency_op_03 = audio::DetectPitchFrequency(30, 0.1, 0);
ds03 = ds->Map({detect_pitch_frequency_op_03});
EXPECT_NE(ds03, nullptr);
std::shared_ptr<Iterator> iter03 = ds03->CreateIterator();
EXPECT_EQ(iter03, nullptr);
// Check freq_low
MS_LOG(INFO) << "freq_low is zero.";
auto detect_pitch_frequency_op_04 = audio::DetectPitchFrequency(30, 0.1, 3, 0);
ds04 = ds->Map({detect_pitch_frequency_op_04});
EXPECT_NE(ds04, nullptr);
std::shared_ptr<Iterator> iter04 = ds04->CreateIterator();
EXPECT_EQ(iter04, nullptr);
// Check freq_high
MS_LOG(INFO) << "freq_high is zero.";
auto detect_pitch_frequency_op_05 = audio::DetectPitchFrequency(30, 0.1, 3, 5, 0);
ds05 = ds->Map({detect_pitch_frequency_op_05});
EXPECT_NE(ds05, nullptr);
std::shared_ptr<Iterator> iter05 = ds05->CreateIterator();
EXPECT_EQ(iter05, nullptr);
}

View File

@ -897,12 +897,8 @@ TEST_F(MindDataTestExecute, TestRiaaBiquadWithEager) {
TEST_F(MindDataTestExecute, TestRiaaBiquadWithWrongArg) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestRiaaBiquadWithWrongArg.";
std::vector<float> labels = {3.156, 5.690, 1.362, 1.093, 5.782, 6.381, 5.982, 3.098, 1.222, 6.027,
3.909, 7.993, 4.324, 1.092, 5.093, 0.991, 1.099, 4.092, 8.111, 6.666};
std::shared_ptr<Tensor> input;
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({4, 5}), &input));
auto input01 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
@ -917,12 +913,8 @@ TEST_F(MindDataTestExecute, TestRiaaBiquadWithWrongArg) {
TEST_F(MindDataTestExecute, TestTrebleBiquadWithEager) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestTrebleBiquadWithEager.";
// Original waveform
std::vector<float> labels = {3.156, 5.690, 1.362, 1.093, 5.782, 6.381, 5.982, 3.098, 1.222, 6.027,
3.909, 7.993, 4.324, 1.092, 5.093, 0.991, 1.099, 4.092, 8.111, 6.666};
std::shared_ptr<Tensor> input;
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 10}), &input));
auto input_01 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
@ -949,9 +941,10 @@ TEST_F(MindDataTestExecute, TestTrebleBiquadWithWrongArg) {
std::shared_ptr<TensorTransform> treble_biquad_op01 = std::make_shared<audio::TrebleBiquad>(0.0, 200.0);
mindspore::dataset::Execute Transform01({treble_biquad_op01});
EXPECT_ERROR(Transform01(input01, &input01));
// Check Q
MS_LOG(INFO) << "Q is zero.";
std::shared_ptr<TensorTransform> treble_biquad_op02 =
std::make_shared<audio::TrebleBiquad>(44100, 200.0, 3000.0, 0.0);
mindspore::dataset::Execute Transform02({treble_biquad_op02});
EXPECT_ERROR(Transform02(input02, &input02));
}
@ -1169,8 +1162,7 @@ TEST_F(MindDataTestExecute, TestMagphaseEager) {
float power = 1.0;
std::vector<mindspore::MSTensor> output_tensor;
std::shared_ptr<Tensor> test;
std::vector<float> test_vector = {3, 4, -3, 4, 3, -4, -3, -4, 5, 12, -5, 12, 5, -12, -5, -12};
Tensor::CreateFromVector(test_vector, TensorShape({2, 4, 2}), &test);
auto input_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(test));
std::shared_ptr<TensorTransform> magphase(new audio::Magphase({power}));
@ -1234,3 +1226,68 @@ TEST_F(MindDataTestExecute, TestRandomAdjustSharpnessEager) {
Status rc = transform(image, &image);
EXPECT_EQ(rc, Status::OK());
}
TEST_F(MindDataTestExecute, TestDetectPitchFrequencyWithEager) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestDetectPitchFrequencyWithEager.";
// Original waveform
std::vector<double> labels = {
3.716064453125000000e-03, 2.347656250000000000e-03, 9.246826171875000000e-03, 4.089477539062500000e-02,
3.138305664062500000e-02, 1.156616210937500000e-02, 0.394653320312500000e-02, 1.550292968750000000e-02,
1.614379882812500000e-02, 0.840209960937500000e-02, 1.718139648437500000e-02, 2.599121093750000000e-02,
5.647949218750000000e-02, 1.510620117187500000e-02, 2.385498046875000000e-02, 1.345825195312500000e-02,
1.419067382812500000e-02, 3.284790039062500000e-02, 9.052856445312500000e-02, 2.368896484375000000e-03};
std::shared_ptr<Tensor> input;
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 10}), &input));
auto input_02 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
std::shared_ptr<TensorTransform> detect_pitch_frequency_01 =
std::make_shared<audio::DetectPitchFrequency>(30, 0.1, 3, 5, 25);
mindspore::dataset::Execute Transform01({detect_pitch_frequency_01});
// Detect pitch frequency
Status s01 = Transform01(input_02, &input_02);
EXPECT_TRUE(s01.IsOk());
}
TEST_F(MindDataTestExecute, TestDetectPitchFrequencyWithWrongArg) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestDetectPitchFrequencyWithWrongArg.";
std::vector<float> labels = {
0.716064e-03, 5.347656e-03, 6.246826e-03, 2.089477e-02, 7.138305e-02,
4.156616e-02, 1.394653e-02, 3.550292e-02, 0.614379e-02, 3.840209e-02,
};
std::shared_ptr<Tensor> input;
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 5}), &input));
auto input_02 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
// Check frame_time
MS_LOG(INFO) << "frame_time is zero.";
std::shared_ptr<TensorTransform> detect_pitch_frequency_01 =
std::make_shared<audio::DetectPitchFrequency>(40, 0, 3, 3, 20);
mindspore::dataset::Execute Transform01({detect_pitch_frequency_01});
Status s01 = Transform01(input_02, &input_02);
EXPECT_FALSE(s01.IsOk());
// Check win_length
MS_LOG(INFO) << "win_length is zero.";
std::shared_ptr<TensorTransform> detect_pitch_frequency_02 =
std::make_shared<audio::DetectPitchFrequency>(40, 0.1, 0, 3, 20);
mindspore::dataset::Execute Transform02({detect_pitch_frequency_02});
Status s02 = Transform02(input_02, &input_02);
EXPECT_FALSE(s02.IsOk());
// Check freq_low
MS_LOG(INFO) << "freq_low is zero.";
std::shared_ptr<TensorTransform> detect_pitch_frequency_03 =
std::make_shared<audio::DetectPitchFrequency>(40, 0.1, 3, 0, 20);
mindspore::dataset::Execute Transform03({detect_pitch_frequency_03});
Status s03 = Transform03(input_02, &input_02);
EXPECT_FALSE(s03.IsOk());
// Check freq_high
MS_LOG(INFO) << "freq_high is zero.";
std::shared_ptr<TensorTransform> detect_pitch_frequency_04 =
std::make_shared<audio::DetectPitchFrequency>(40, 0.1, 3, 3, 0);
mindspore::dataset::Execute Transform04({detect_pitch_frequency_04});
Status s04 = Transform04(input_02, &input_02);
EXPECT_FALSE(s04.IsOk());
// Check sample_rate
MS_LOG(INFO) << "sample_rate is zero.";
std::shared_ptr<TensorTransform> detect_pitch_frequency_05 = std::make_shared<audio::DetectPitchFrequency>(0);
mindspore::dataset::Execute Transform05({detect_pitch_frequency_05});
Status s05 = Transform05(input_02, &input_02);
EXPECT_FALSE(s05.IsOk());
}

View File

@ -0,0 +1,114 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.audio.transforms as audio
from mindspore import log as logger
def count_unequal_element(data_expected, data_me, rtol, atol):
assert data_expected.shape == data_me.shape
total_count = len(data_expected.flatten())
error = np.abs(data_expected - data_me)
greater = np.greater(error, atol + np.abs(data_expected) * rtol)
loss_count = np.count_nonzero(greater)
assert (loss_count / total_count) < rtol, \
"\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
format(data_expected[greater], data_me[greater], error[greater])
def test_detect_pitch_frequency_eager():
""" mindspore eager mode normal testcase:detect_pitch_frequency op"""
# Original waveform
waveform = np.array([[2.716064453125e-03, 6.34765625e-03, 9.246826171875e-03, 1.0894775390625e-02,
1.1383056640625e-02, 1.1566162109375e-02, 1.3946533203125e-02, 1.55029296875e-02,
1.6143798828125e-02, 1.8402099609375e-02],
[1.7181396484375e-02, 1.59912109375e-02, 1.64794921875e-02, 1.5106201171875e-02,
1.385498046875e-02, 1.3458251953125e-02, 1.4190673828125e-02, 1.2847900390625e-02,
1.0528564453125e-02, 9.368896484375e-03]], dtype=np.float64)
# Expect waveform
expect_waveform = np.array(
[[10., 10., 10.], [5., 5., 10.]], dtype=np.float64)
detect_pitch_frequency_op = audio.DetectPitchFrequency(30, 0.1, 3, 5, 25)
# Detect pitch frequency
output = detect_pitch_frequency_op(waveform)
count_unequal_element(expect_waveform, output, 0.0001, 0.0001)
def test_detect_pitch_frequency_pipeline():
""" mindspore pipeline mode normal testcase:detect_pitch_frequency op"""
# Original waveform
waveform = np.array([[0.716064453125e-03, 5.34765625e-03, 6.246826171875e-03, 2.0894775390625e-02,
7.1383056640625e-02], [4.1566162109375e-02, 1.3946533203125e-02, 3.55029296875e-02,
0.6143798828125e-02, 3.8402099609375e-02]], dtype=np.float64)
# Expect waveform
expect_waveform = np.array([[10.0000], [7.5000]], dtype=np.float64)
dataset = ds.NumpySlicesDataset(waveform, ["audio"], shuffle=False)
detect_pitch_frequency_op = audio.DetectPitchFrequency(30, 0.1, 3, 5, 25)
# Detect pitch frequency
dataset = dataset.map(input_columns=["audio"],
operations=detect_pitch_frequency_op, num_parallel_workers=8)
i = 0
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
count_unequal_element(expect_waveform[i, :],
item['audio'], 0.0001, 0.0001)
i += 1
def test_detect_pitch_frequency_invalid_input():
def test_invalid_input(test_name, sample_rate, frame_time, win_length, freq_low, freq_high, error, error_msg):
logger.info(
"Test DetectPitchFrequency with bad input: {0}".format(test_name))
with pytest.raises(error) as error_info:
audio.DetectPitchFrequency(
sample_rate, frame_time, win_length, freq_low, freq_high)
assert error_msg in str(error_info.value)
test_invalid_input("invalid sample_rate parameter type as a float", 44100.5, 0.01, 30, 85, 3400, TypeError,
"Argument sample_rate with value 44100.5 is not of type [<class 'int'>],"
" but got <class 'float'>.")
test_invalid_input("invalid sample_rate parameter type as a String", "44100", 0.01, 30, 85, 3400, TypeError,
"Argument sample_rate with value 44100 is not of type [<class 'int'>], but got <class 'str'>.")
test_invalid_input("invalid frame_time parameter type as a String", 44100, "0.01", 30, 85, 3400, TypeError,
"Argument frame_time with value 0.01 is not of type [<class 'float'>, <class 'int'>],"
" but got <class 'str'>.")
test_invalid_input("invalid win_length parameter type as a float", 44100, 0.01, 30.1, 85, 3400, TypeError,
"Argument win_length with value 30.1 is not of type [<class 'int'>], but got <class 'float'>.")
test_invalid_input("invalid win_length parameter type as a String", 44100, 0.01, "30", 85, 3400, TypeError,
"Argument win_length with value 30 is not of type [<class 'int'>], but got <class 'str'>.")
test_invalid_input("invalid freq_low parameter type as a String", 44100, 0.01, 30, "85", 3400, TypeError,
"Argument freq_low with value 85 is not of type [<class 'int'>, <class 'float'>],"
" but got <class 'str'>.")
test_invalid_input("invalid freq_high parameter type as a String", 44100, 0.01, 30, 85, "3400", TypeError,
"Argument freq_high with value 3400 is not of type [<class 'int'>, <class 'float'>],"
" but got <class 'str'>.")
test_invalid_input("invalid sample_rate parameter value", 0, 0.01, 30, 85, 3400, ValueError,
"Input sample_rate is not within the required interval of [-2147483648, 0) and (0, 2147483647].")
test_invalid_input("invalid frame_time parameter value", 44100, 0, 30, 85, 3400, ValueError,
"Input frame_time is not within the required interval of (0, 16777216].")
test_invalid_input("invalid win_length parameter value", 44100, 0.01, 0, 85, 3400, ValueError,
"Input win_length is not within the required interval of [1, 2147483647].")
test_invalid_input("invalid freq_low parameter value", 44100, 0.01, 30, 0, 3400, ValueError,
"Input freq_low is not within the required interval of (0, 16777216].")
test_invalid_input("invalid freq_high parameter value", 44100, 0.01, 30, 85, 0, ValueError,
"Input freq_high is not within the required interval of (0, 16777216].")
if __name__ == "__main__":
test_detect_pitch_frequency_eager()
test_detect_pitch_frequency_pipeline()
test_detect_pitch_frequency_invalid_input()