forked from mindspore-Ecosystem/mindspore
[feat] [assistant] [I3J6U8] Add a new audio operator DetectPitchFrequency
This commit is contained in:
parent
6d1a73b4c1
commit
b7e6520e62
|
@ -28,6 +28,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/contrast_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/detect_pitch_frequency_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/fade_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h"
|
||||
|
@ -218,6 +219,30 @@ std::shared_ptr<TensorOperation> DeemphBiquad::Parse() {
|
|||
return std::make_shared<DeemphBiquadOperation>(data_->sample_rate_);
|
||||
}
|
||||
|
||||
// DetectPitchFrequency Transform Operation.
|
||||
struct DetectPitchFrequency::Data {
|
||||
Data(int32_t sample_rate, float frame_time, int32_t win_length, int32_t freq_low, int32_t freq_high)
|
||||
: sample_rate_(sample_rate),
|
||||
frame_time_(frame_time),
|
||||
win_length_(win_length),
|
||||
freq_low_(freq_low),
|
||||
freq_high_(freq_high) {}
|
||||
int32_t sample_rate_;
|
||||
float frame_time_;
|
||||
int32_t win_length_;
|
||||
int32_t freq_low_;
|
||||
int32_t freq_high_;
|
||||
};
|
||||
|
||||
DetectPitchFrequency::DetectPitchFrequency(int32_t sample_rate, float frame_time, int32_t win_length, int32_t freq_low,
|
||||
int32_t freq_high)
|
||||
: data_(std::make_shared<Data>(sample_rate, frame_time, win_length, freq_low, freq_high)) {}
|
||||
|
||||
std::shared_ptr<TensorOperation> DetectPitchFrequency::Parse() {
|
||||
return std::make_shared<DetectPitchFrequencyOperation>(data_->sample_rate_, data_->frame_time_, data_->win_length_,
|
||||
data_->freq_low_, data_->freq_high_);
|
||||
}
|
||||
|
||||
// EqualizerBiquad Transform Operation.
|
||||
struct EqualizerBiquad::Data {
|
||||
Data(int32_t sample_rate, float center_freq, float gain, float Q)
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/contrast_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/detect_pitch_frequency_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/fade_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h"
|
||||
|
@ -186,6 +187,19 @@ PYBIND_REGISTER(
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(DetectPitchFrequencyOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<audio::DetectPitchFrequencyOperation, TensorOperation,
|
||||
std::shared_ptr<audio::DetectPitchFrequencyOperation>>(
|
||||
*m, "DetectPitchFrequencyOperation")
|
||||
.def(py::init([](int32_t sample_rate, float frame_time, int32_t win_length, int32_t freq_low,
|
||||
int32_t freq_high) {
|
||||
auto detect_pitch_frequency = std::make_shared<audio::DetectPitchFrequencyOperation>(
|
||||
sample_rate, frame_time, win_length, freq_low, freq_high);
|
||||
THROW_IF_ERROR(detect_pitch_frequency->ValidateParams());
|
||||
return detect_pitch_frequency;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(EqualizerBiquadOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<audio::EqualizerBiquadOperation, TensorOperation,
|
||||
std::shared_ptr<audio::EqualizerBiquadOperation>>(*m, "EqualizerBiquadOperation")
|
||||
|
|
|
@ -14,6 +14,7 @@ add_library(audio-ir-kernels OBJECT
|
|||
contrast_ir.cc
|
||||
dc_shift_ir.cc
|
||||
deemph_biquad_ir.cc
|
||||
detect_pitch_frequency_ir.cc
|
||||
equalizer_biquad_ir.cc
|
||||
fade_ir.cc
|
||||
frequency_masking_ir.cc
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/audio/ir/kernels/detect_pitch_frequency_ir.h"
|
||||
|
||||
#include "minddata/dataset/audio/ir/validators.h"
|
||||
#include "minddata/dataset/audio/kernels/detect_pitch_frequency_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
// DetectPitchFrequencyOperation
|
||||
DetectPitchFrequencyOperation::DetectPitchFrequencyOperation(int32_t sample_rate, float frame_time, int32_t win_length,
|
||||
int32_t freq_low, int32_t freq_high)
|
||||
: sample_rate_(sample_rate),
|
||||
frame_time_(frame_time),
|
||||
win_length_(win_length),
|
||||
freq_low_(freq_low),
|
||||
freq_high_(freq_high) {}
|
||||
|
||||
Status DetectPitchFrequencyOperation::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateScalarNotZero("DetectPitchFrequency", "sample_rate", sample_rate_));
|
||||
RETURN_IF_NOT_OK(ValidateFloatScalarPositive("DetectPitchFrequency", "frame_time", frame_time_));
|
||||
RETURN_IF_NOT_OK(ValidateIntScalarPositive("DetectPitchFrequency", "win_length", win_length_));
|
||||
RETURN_IF_NOT_OK(ValidateIntScalarPositive("DetectPitchFrequency", "freq_low", freq_low_));
|
||||
RETURN_IF_NOT_OK(ValidateIntScalarPositive("DetectPitchFrequency", "freq_high", freq_high_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> DetectPitchFrequencyOperation::Build() {
|
||||
std::shared_ptr<DetectPitchFrequencyOp> tensor_op =
|
||||
std::make_shared<DetectPitchFrequencyOp>(sample_rate_, frame_time_, win_length_, freq_low_, freq_high_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
Status DetectPitchFrequencyOperation::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["sample_rate"] = sample_rate_;
|
||||
args["frame_time"] = frame_time_;
|
||||
args["win_length"] = win_length_;
|
||||
args["freq_low"] = freq_low_;
|
||||
args["freq_high"] = freq_high_;
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DETECT_PITCH_FREQUENCY_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DETECT_PITCH_FREQUENCY_IR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/include/dataset/transforms.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
constexpr char kDetectPitchFrequencyOperation[] = "DetectPitchFrequency";
|
||||
|
||||
class DetectPitchFrequencyOperation : public TensorOperation {
|
||||
public:
|
||||
DetectPitchFrequencyOperation(int32_t sample_rate, float frame_time, int32_t win_length, int32_t freq_low,
|
||||
int32_t freq_high);
|
||||
|
||||
~DetectPitchFrequencyOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
std::string Name() const override { return kDetectPitchFrequencyOperation; }
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
int32_t sample_rate_;
|
||||
float frame_time_;
|
||||
int32_t win_length_;
|
||||
int32_t freq_low_;
|
||||
int32_t freq_high_;
|
||||
};
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DETECT_PITCH_FREQUENCY_IR_H_
|
|
@ -15,6 +15,7 @@ add_library(audio-kernels OBJECT
|
|||
contrast_op.cc
|
||||
dc_shift_op.cc
|
||||
deemph_biquad_op.cc
|
||||
detect_pitch_frequency_op.cc
|
||||
equalizer_biquad_op.cc
|
||||
fade_op.cc
|
||||
frequency_masking_op.cc
|
||||
|
|
|
@ -650,5 +650,79 @@ Status Magphase(const TensorRow &input, TensorRow *output, float power) {
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MedianSmoothing(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t win_length) {
|
||||
auto channel = input->shape()[0];
|
||||
auto num_of_frames = input->shape()[1];
|
||||
// Centered windowed
|
||||
int32_t pad_length = (win_length - 1) / 2;
|
||||
int32_t out_length = num_of_frames + pad_length - win_length + 1;
|
||||
TensorShape out_shape({channel, out_length});
|
||||
std::vector<int> signal;
|
||||
std::vector<int> out;
|
||||
std::vector<int> indices(channel * (num_of_frames + pad_length), 0);
|
||||
// "replicate" padding in any dimension
|
||||
for (auto itr = input->begin<int>(); itr != input->end<int>(); ++itr) {
|
||||
signal.push_back(*itr);
|
||||
}
|
||||
for (int i = 0; i < channel; ++i) {
|
||||
for (int j = 0; j < pad_length; ++j) {
|
||||
indices[i * (num_of_frames + pad_length) + j] = signal[i * num_of_frames];
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < channel; ++i) {
|
||||
for (int j = 0; j < num_of_frames; ++j) {
|
||||
indices[i * (num_of_frames + pad_length) + j + pad_length] = signal[i * num_of_frames + j];
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < channel; ++i) {
|
||||
int32_t index = i * (num_of_frames + pad_length);
|
||||
for (int j = 0; j < out_length; ++j) {
|
||||
std::vector<int> tem(indices.begin() + index, indices.begin() + win_length + index);
|
||||
std::sort(tem.begin(), tem.end());
|
||||
out.push_back(tem[pad_length]);
|
||||
++index;
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromVector(out, out_shape, output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status DetectPitchFrequency(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t sample_rate,
|
||||
float frame_time, int32_t win_length, int32_t freq_low, int32_t freq_high) {
|
||||
std::shared_ptr<Tensor> nccf;
|
||||
std::shared_ptr<Tensor> indices;
|
||||
std::shared_ptr<Tensor> smooth_indices;
|
||||
// pack batch
|
||||
TensorShape input_shape = input->shape();
|
||||
TensorShape to_shape({input->Size() / input_shape[-1], input_shape[-1]});
|
||||
RETURN_IF_NOT_OK(input->Reshape(to_shape));
|
||||
if (input->type() == DataType(DataType::DE_FLOAT32)) {
|
||||
RETURN_IF_NOT_OK(ComputeNccf<float>(input, &nccf, sample_rate, frame_time, freq_low));
|
||||
RETURN_IF_NOT_OK(FindMaxPerFrame<float>(nccf, &indices, sample_rate, freq_high));
|
||||
} else if (input->type() == DataType(DataType::DE_FLOAT64)) {
|
||||
RETURN_IF_NOT_OK(ComputeNccf<double>(input, &nccf, sample_rate, frame_time, freq_low));
|
||||
RETURN_IF_NOT_OK(FindMaxPerFrame<double>(nccf, &indices, sample_rate, freq_high));
|
||||
} else {
|
||||
RETURN_IF_NOT_OK(ComputeNccf<float16>(input, &nccf, sample_rate, frame_time, freq_low));
|
||||
RETURN_IF_NOT_OK(FindMaxPerFrame<float16>(nccf, &indices, sample_rate, freq_high));
|
||||
}
|
||||
RETURN_IF_NOT_OK(MedianSmoothing(indices, &smooth_indices, win_length));
|
||||
|
||||
// Convert indices to frequency
|
||||
constexpr double EPSILON = 1e-9;
|
||||
TensorShape freq_shape = smooth_indices->shape();
|
||||
std::vector<float> out;
|
||||
for (auto itr_fre = smooth_indices->begin<int>(); itr_fre != smooth_indices->end<int>(); ++itr_fre) {
|
||||
out.push_back(sample_rate / (EPSILON + *itr_fre));
|
||||
}
|
||||
|
||||
// unpack batch
|
||||
auto shape_vec = input_shape.AsVector();
|
||||
shape_vec[shape_vec.size() - 1] = freq_shape[-1];
|
||||
TensorShape out_shape(shape_vec);
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromVector(out, out_shape, output));
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -372,6 +372,191 @@ Status Vol(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output
|
|||
/// \param power: Power of the norm.
|
||||
Status Magphase(const TensorRow &input, TensorRow *output, float power);
|
||||
|
||||
/// \brief Compute Normalized Cross-Correlation Function (NCCF).
|
||||
/// \param input: Tensor of shape <channel,waveform_length>.
|
||||
/// \param output: Tensor of shape <channel, num_of_frames, lags>.
|
||||
/// \param sample_rate: The sample rate of the waveform (Hz).
|
||||
/// \param frame_time: Duration of a frame.
|
||||
/// \param freq_low: Lowest frequency that can be detected (Hz).
|
||||
/// \return Status code.
|
||||
template <typename T>
|
||||
Status ComputeNccf(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t sample_rate,
|
||||
float frame_time, int32_t freq_low) {
|
||||
auto channel = input->shape()[0];
|
||||
auto waveform_length = input->shape()[1];
|
||||
size_t idx = 0;
|
||||
size_t channel_idx = 1;
|
||||
int32_t lags = static_cast<int32_t>(ceil(static_cast<float>(sample_rate) / freq_low));
|
||||
int32_t frame_size = static_cast<int32_t>(ceil(sample_rate * frame_time));
|
||||
int32_t num_of_frames = static_cast<int32_t>(ceil(static_cast<float>(waveform_length) / frame_size));
|
||||
int32_t p = lags + num_of_frames * frame_size - waveform_length;
|
||||
TensorShape output_shape({channel, num_of_frames, lags});
|
||||
DataType intput_type = input->type();
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(output_shape, intput_type, output));
|
||||
// pad p 0 in -1 dimension
|
||||
std::vector<T> signal;
|
||||
// Tensor -> vector
|
||||
for (auto itr = input->begin<T>(); itr != input->end<T>();) {
|
||||
while (idx < waveform_length * channel_idx) {
|
||||
signal.push_back(*itr);
|
||||
++itr;
|
||||
++idx;
|
||||
}
|
||||
// Each channel is processed with the sliding window
|
||||
// waveform:[channel, time] --> waveform:[channel, time+p]
|
||||
for (size_t i = 0; i < p; ++i) {
|
||||
signal.push_back(static_cast<T>(0.0));
|
||||
}
|
||||
if (idx % waveform_length == 0) {
|
||||
++channel_idx;
|
||||
}
|
||||
}
|
||||
// compute ncc
|
||||
for (dsize_t lag = 1; lag <= lags; ++lag) {
|
||||
// compute one ncc
|
||||
// one ncc out
|
||||
std::vector<T> out;
|
||||
channel_idx = 1;
|
||||
idx = 0;
|
||||
size_t win_idx = 0;
|
||||
size_t waveform_length_p = waveform_length + p;
|
||||
// Traversal signal
|
||||
for (auto itr = signal.begin(); itr != signal.end();) {
|
||||
// Each channel is processed with the sliding window
|
||||
size_t s1 = idx;
|
||||
size_t s2 = idx + lag;
|
||||
size_t frame_count = 0;
|
||||
T s1_norm = static_cast<T>(0);
|
||||
T s2_norm = static_cast<T>(0);
|
||||
T ncc_umerator = static_cast<T>(0);
|
||||
T ncc = static_cast<T>(0);
|
||||
while (idx < waveform_length_p * channel_idx) {
|
||||
// Sliding window
|
||||
if (frame_count == num_of_frames) {
|
||||
++itr;
|
||||
++idx;
|
||||
continue;
|
||||
}
|
||||
if (win_idx < frame_size) {
|
||||
ncc_umerator += signal[s1] * signal[s2];
|
||||
s1_norm += signal[s1] * signal[s1];
|
||||
s2_norm += signal[s2] * signal[s2];
|
||||
++win_idx;
|
||||
++s1;
|
||||
++s2;
|
||||
}
|
||||
if (win_idx == frame_size) {
|
||||
if (s1_norm != static_cast<T>(0.0) && s2_norm != static_cast<T>(0.0)) {
|
||||
ncc = ncc_umerator / s1_norm / s2_norm;
|
||||
} else {
|
||||
ncc = static_cast<T>(0.0);
|
||||
}
|
||||
out.push_back(ncc);
|
||||
ncc_umerator = static_cast<T>(0.0);
|
||||
s1_norm = static_cast<T>(0.0);
|
||||
s2_norm = static_cast<T>(0.0);
|
||||
++frame_count;
|
||||
win_idx = 0;
|
||||
}
|
||||
++itr;
|
||||
++idx;
|
||||
}
|
||||
if (idx % waveform_length_p == 0) {
|
||||
++channel_idx;
|
||||
}
|
||||
} // compute one ncc
|
||||
// cat tensor
|
||||
auto itr_out = out.begin();
|
||||
for (dsize_t row_idx = 0; row_idx < channel; ++row_idx) {
|
||||
for (dsize_t frame_idx = 0; frame_idx < num_of_frames; ++frame_idx) {
|
||||
RETURN_IF_NOT_OK((*output)->SetItemAt({row_idx, frame_idx, lag - 1}, *itr_out));
|
||||
++itr_out;
|
||||
}
|
||||
}
|
||||
} // compute ncc
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief For each frame, take the highest value of NCCF.
|
||||
/// \param input: Tensor of shape <channel, num_of_frames, lags>.
|
||||
/// \param output: Tensor of shape <channel, num_of_frames>.
|
||||
/// \param sample_rate: The sample rate of the waveform (Hz).
|
||||
/// \param freq_high: Highest frequency that can be detected (Hz).
|
||||
/// \return Status code.
|
||||
template <typename T>
|
||||
Status FindMaxPerFrame(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t sample_rate,
|
||||
int32_t freq_high) {
|
||||
std::vector<T> signal;
|
||||
std::vector<int> out;
|
||||
auto channel = input->shape()[0];
|
||||
auto num_of_frames = input->shape()[1];
|
||||
auto lags = input->shape()[2];
|
||||
int32_t lag_min = static_cast<int32_t>(ceil(static_cast<float>(sample_rate) / freq_high));
|
||||
TensorShape out_shape({channel, num_of_frames});
|
||||
// pack batch
|
||||
for (auto itr = input->begin<T>(); itr != input->end<T>(); ++itr) {
|
||||
signal.push_back(*itr);
|
||||
}
|
||||
// find the best nccf
|
||||
T best_max_value = static_cast<T>(0.0);
|
||||
T half_max_value = static_cast<T>(0.0);
|
||||
int32_t best_max_indices = 0;
|
||||
int32_t half_max_indices = 0;
|
||||
auto thresh = static_cast<T>(0.99);
|
||||
auto lags_half = lags / 2;
|
||||
for (dsize_t channel_idx = 0; channel_idx < channel; ++channel_idx) {
|
||||
for (dsize_t frame_idx = 0; frame_idx < num_of_frames; ++frame_idx) {
|
||||
auto index_01 = channel_idx * num_of_frames * lags + frame_idx * lags + lag_min;
|
||||
best_max_value = signal[index_01];
|
||||
half_max_value = signal[index_01];
|
||||
best_max_indices = lag_min;
|
||||
half_max_indices = lag_min;
|
||||
for (dsize_t lag_idx = 0; lag_idx < lags; ++lag_idx) {
|
||||
if (lag_idx > lag_min) {
|
||||
auto index_02 = channel_idx * num_of_frames * lags + frame_idx * lags + lag_idx;
|
||||
if (signal[index_02] > best_max_value) {
|
||||
best_max_value = signal[index_02];
|
||||
best_max_indices = lag_idx;
|
||||
if (lag_idx < lags_half) {
|
||||
half_max_value = signal[index_02];
|
||||
half_max_indices = lag_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Add back minimal lag
|
||||
// Add 1 empirical calibration offset
|
||||
if (half_max_value > best_max_value * thresh) {
|
||||
out.push_back(half_max_indices + 1);
|
||||
} else {
|
||||
out.push_back(best_max_indices + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
// unpack batch
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromVector(out, out_shape, output));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Apply median smoothing to the 1D tensor over the given window.
|
||||
/// \param input: Tensor of shape<channel, num_of_frames>.
|
||||
/// \param output: Tensor of shape <channel, num_of_window>.
|
||||
/// \param win_length: The window length for median smoothing (in number of frames).
|
||||
/// \return Status code.
|
||||
Status MedianSmoothing(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t win_length);
|
||||
|
||||
/// \brief Detect pitch frequency.
|
||||
/// \param input: Tensor of shape <channel,waveform_length>.
|
||||
/// \param output: Tensor of shape <channel, num_of_frames, lags>.
|
||||
/// \param sample_rate: The sample rate of the waveform (Hz).
|
||||
/// \param frame_time: Duration of a frame.
|
||||
/// \param win_length: The window length for median smoothing (in number of frames).
|
||||
/// \param freq_low: Lowest frequency that can be detected (Hz).
|
||||
/// \param freq_high: Highest frequency that can be detected (Hz).
|
||||
/// \return Status code.
|
||||
Status DetectPitchFrequency(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t sample_rate,
|
||||
float frame_time, int32_t win_length, int32_t freq_low, int32_t freq_high);
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/audio/kernels/detect_pitch_frequency_op.h"
|
||||
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
Status DetectPitchFrequencyOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
TensorShape input_shape = input->shape();
|
||||
// check input tensor dimension, it should be greater than 0.
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input_shape.Size() > 0,
|
||||
"DetectPitchFrequency: input tensor is not in shape of <..., time>.");
|
||||
// check input type, it should be DE_FLOAT16, DE_FLOAT32 or DE_FLOAT64
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
input->type() == DataType(DataType::DE_FLOAT32) || input->type() == DataType(DataType::DE_FLOAT64) ||
|
||||
input->type() == DataType(DataType::DE_FLOAT16),
|
||||
"DetectPitchFrequency: input tensor type should be float or double, but got: " + input->type().ToString());
|
||||
return DetectPitchFrequency(input, output, sample_rate_, frame_time_, win_length_, freq_low_, freq_high_);
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DETECT_PITCH_FREQUENCY_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DETECT_PITCH_FREQUENCY_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class DetectPitchFrequencyOp : public TensorOp {
|
||||
public:
|
||||
DetectPitchFrequencyOp(int32_t sample_rate, float frame_time, int32_t win_length, int32_t freq_low, int32_t freq_high)
|
||||
: sample_rate_(sample_rate),
|
||||
frame_time_(frame_time),
|
||||
win_length_(win_length),
|
||||
freq_low_(freq_low),
|
||||
freq_high_(freq_high) {}
|
||||
|
||||
~DetectPitchFrequencyOp() override = default;
|
||||
|
||||
void Print(std::ostream &out) const override {
|
||||
out << Name() << ": sample_rate: " << sample_rate_ << ", frame_time: " << frame_time_
|
||||
<< ", win_length: " << win_length_ << ", freq_low: " << freq_low_ << ", freq_high: " << freq_high_ << std::endl;
|
||||
}
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
std::string Name() const override { return kDetectPitchFrequencyOp; }
|
||||
|
||||
private:
|
||||
int32_t sample_rate_;
|
||||
float frame_time_;
|
||||
int32_t win_length_;
|
||||
int32_t freq_low_;
|
||||
int32_t freq_high_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DETECT_PITCH_FREQUENCY_OP_H_
|
|
@ -300,6 +300,33 @@ class DeemphBiquad final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Detect pitch frequency.
|
||||
class DetectPitchFrequency final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero.
|
||||
/// \param[in] frame_time Duration of a frame, the value must be greater than zero (default=0.02).
|
||||
/// \param[in] win_length The window length for median smoothing (in number of frames), the value must
|
||||
/// be greater than zero (default=30).
|
||||
/// \param[in] freq_low Lowest frequency that can be detected (Hz), the value must be greater than zero (default=85).
|
||||
/// \param[in] freq_high Highest frequency that can be detected (Hz), the value must be greater than
|
||||
/// zero (default=3400).
|
||||
explicit DetectPitchFrequency(int32_t sample_rate, float frame_time = 0.01, int32_t win_length = 30,
|
||||
int32_t freq_low = 85, int32_t freq_high = 3400);
|
||||
|
||||
/// \brief Destructor.
|
||||
~DetectPitchFrequency() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
||||
private:
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief EqualizerBiquad TensorTransform. Apply highpass biquad filter on audio.
|
||||
class EqualizerBiquad final : public TensorTransform {
|
||||
public:
|
||||
|
|
|
@ -155,6 +155,7 @@ constexpr char kComplexNormOp[] = "ComplexNormOp";
|
|||
constexpr char kContrastOp[] = "ContrastOp";
|
||||
constexpr char kDCShiftOp[] = "DCShiftOp";
|
||||
constexpr char kDeemphBiquadOp[] = "DeemphBiquadOp";
|
||||
constexpr char kDetectPitchFrequencyOp[] = "DetectPitchFrequencyOp";
|
||||
constexpr char kEqualizerBiquadOp[] = "EqualizerBiquadOp";
|
||||
constexpr char kFadeOp[] = "FadeOp";
|
||||
constexpr char kFrequencyMaskingOp[] = "FrequencyMaskingOp";
|
||||
|
|
|
@ -26,8 +26,8 @@ from ..transforms.c_transforms import TensorOperation
|
|||
from .utils import FadeShape, GainType, ScaleType
|
||||
from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
|
||||
check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_contrast, check_dc_shift, \
|
||||
check_deemph_biquad, check_equalizer_biquad, check_fade, check_highpass_biquad, check_lfilter, \
|
||||
check_lowpass_biquad, check_magphase, check_masking, check_mu_law_decoding, check_riaa_biquad, \
|
||||
check_deemph_biquad, check_detect_pitch_frequency, check_equalizer_biquad, check_fade, check_highpass_biquad, \
|
||||
check_lfilter, check_lowpass_biquad, check_magphase, check_masking, check_mu_law_decoding, check_riaa_biquad, \
|
||||
check_time_stretch, check_treble_biquad, check_vol
|
||||
|
||||
|
||||
|
@ -379,6 +379,45 @@ class DeemphBiquad(AudioTensorOperation):
|
|||
return cde.DeemphBiquadOperation(self.sample_rate)
|
||||
|
||||
|
||||
class DetectPitchFrequency(AudioTensorOperation):
|
||||
"""
|
||||
Detect pitch frequency.
|
||||
|
||||
It is implemented using normalized cross-correlation function and median smoothing.
|
||||
|
||||
Args:
|
||||
sample_rate (int): Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero.
|
||||
frame_time (float, optional): Duration of a frame, the value must be greater than zero (default=0.01).
|
||||
win_length (int, optional): The window length for median smoothing (in number of frames), the value must be
|
||||
greater than zero (default=30).
|
||||
freq_low (int, optional): Lowest frequency that can be detected (Hz), the value must be greater than zero
|
||||
(default=85).
|
||||
freq_high (int, optional): Highest frequency that can be detected (Hz), the value must be greater than zero
|
||||
(default=3400).
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.array([[0.716064e-03, 5.347656e-03, 6.246826e-03, 2.089477e-02, 7.138305e-02],
|
||||
... [4.156616e-02, 1.394653e-02, 3.550292e-02, 0.614379e-02, 3.840209e-02]])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.DetectPitchFrequency(30, 0.1, 3, 5, 25)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
"""
|
||||
|
||||
@check_detect_pitch_frequency
|
||||
def __init__(self, sample_rate, frame_time=0.01, win_length=30, freq_low=85, freq_high=3400):
|
||||
self.sample_rate = sample_rate
|
||||
self.frame_time = frame_time
|
||||
self.win_length = win_length
|
||||
self.freq_low = freq_low
|
||||
self.freq_high = freq_high
|
||||
|
||||
def parse(self):
|
||||
return cde.DetectPitchFrequencyOperation(self.sample_rate, self.frame_time,
|
||||
self.win_length, self.freq_low, self.freq_high)
|
||||
|
||||
|
||||
class EqualizerBiquad(AudioTensorOperation):
|
||||
"""
|
||||
Design biquad equalizer filter and perform filtering. Similar to SoX implementation.
|
||||
|
|
|
@ -453,3 +453,25 @@ def check_vol(method):
|
|||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_detect_pitch_frequency(method):
|
||||
"""Wrapper method to check the parameters of DetectPitchFrequency."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[sample_rate, frame_time, win_length, freq_low, freq_high], _ = parse_user_args(
|
||||
method, *args, **kwargs)
|
||||
type_check(sample_rate, (int,), "sample_rate")
|
||||
check_int32_not_zero(sample_rate, "sample_rate")
|
||||
type_check(frame_time, (float, int), "frame_time")
|
||||
check_pos_float32(frame_time, "frame_time")
|
||||
type_check(win_length, (int,), "win_length")
|
||||
check_pos_int32(win_length, "win_length")
|
||||
type_check(freq_low, (int, float), "freq_low")
|
||||
check_pos_float32(freq_low, "freq_low")
|
||||
type_check(freq_high, (int, float), "freq_high")
|
||||
check_pos_float32(freq_high, "freq_high")
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
|
@ -1375,7 +1375,7 @@ TEST_F(MindDataTestPipeline, TestMagphaseWrongArgs) {
|
|||
std::shared_ptr<TensorTransform> magphase(new audio::Magphase(power_wrong));
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
|
||||
//Magphase: power must be greater than or equal to 0.
|
||||
// Magphase: power must be greater than or equal to 0.
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
ASSERT_OK(schema->add_column("col1", mindspore::DataType::kNumberTypeFloat32, {2, 2}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(8, schema);
|
||||
|
@ -1385,3 +1385,100 @@ TEST_F(MindDataTestPipeline, TestMagphaseWrongArgs) {
|
|||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestDetectPitchFrequencyBasic) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDetectPitchFrequencyBasic.";
|
||||
// Original waveform
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 10000}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(50, schema);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
ds = ds->SetNumWorkers(4);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto DetectPitchFrequencyOp = audio::DetectPitchFrequency(44100);
|
||||
|
||||
ds = ds->Map({DetectPitchFrequencyOp});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Detect pitch frequency
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
std::vector<int64_t> expected = {2, 8};
|
||||
|
||||
int i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto col = row["waveform"];
|
||||
ASSERT_EQ(col.Shape(), expected);
|
||||
ASSERT_EQ(col.Shape().size(), 2);
|
||||
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
EXPECT_EQ(i, 50);
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestDetectPitchFrequencyParamCheck) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDetectPitchFrequencyParamCheck.";
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
// Original waveform
|
||||
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 2}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(50, schema);
|
||||
std::shared_ptr<Dataset> ds01;
|
||||
std::shared_ptr<Dataset> ds02;
|
||||
std::shared_ptr<Dataset> ds03;
|
||||
std::shared_ptr<Dataset> ds04;
|
||||
std::shared_ptr<Dataset> ds05;
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Check sample_rate
|
||||
MS_LOG(INFO) << "sample_rate is zero.";
|
||||
auto detect_pitch_frequency_op_01 = audio::DetectPitchFrequency(0);
|
||||
ds01 = ds->Map({detect_pitch_frequency_op_01});
|
||||
EXPECT_NE(ds01, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter01 = ds01->CreateIterator();
|
||||
EXPECT_EQ(iter01, nullptr);
|
||||
|
||||
// Check frame_time
|
||||
MS_LOG(INFO) << "frame_time is zero.";
|
||||
auto detect_pitch_frequency_op_02 = audio::DetectPitchFrequency(30, 0);
|
||||
ds02 = ds->Map({detect_pitch_frequency_op_02});
|
||||
EXPECT_NE(ds02, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter02 = ds02->CreateIterator();
|
||||
EXPECT_EQ(iter02, nullptr);
|
||||
|
||||
// Check win_length
|
||||
MS_LOG(INFO) << "win_length is zero.";
|
||||
auto detect_pitch_frequency_op_03 = audio::DetectPitchFrequency(30, 0.1, 0);
|
||||
ds03 = ds->Map({detect_pitch_frequency_op_03});
|
||||
EXPECT_NE(ds03, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter03 = ds03->CreateIterator();
|
||||
EXPECT_EQ(iter03, nullptr);
|
||||
|
||||
// Check freq_low
|
||||
MS_LOG(INFO) << "freq_low is zero.";
|
||||
auto detect_pitch_frequency_op_04 = audio::DetectPitchFrequency(30, 0.1, 3, 0);
|
||||
ds04 = ds->Map({detect_pitch_frequency_op_04});
|
||||
EXPECT_NE(ds04, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter04 = ds04->CreateIterator();
|
||||
EXPECT_EQ(iter04, nullptr);
|
||||
|
||||
// Check freq_high
|
||||
MS_LOG(INFO) << "freq_high is zero.";
|
||||
auto detect_pitch_frequency_op_05 = audio::DetectPitchFrequency(30, 0.1, 3, 5, 0);
|
||||
ds05 = ds->Map({detect_pitch_frequency_op_05});
|
||||
EXPECT_NE(ds05, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter05 = ds05->CreateIterator();
|
||||
EXPECT_EQ(iter05, nullptr);
|
||||
}
|
||||
|
|
|
@ -897,12 +897,8 @@ TEST_F(MindDataTestExecute, TestRiaaBiquadWithEager) {
|
|||
|
||||
TEST_F(MindDataTestExecute, TestRiaaBiquadWithWrongArg) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestRiaaBiquadWithWrongArg.";
|
||||
std::vector<float> labels = {
|
||||
3.156, 5.690, 1.362, 1.093,
|
||||
5.782, 6.381, 5.982, 3.098,
|
||||
1.222, 6.027, 3.909, 7.993,
|
||||
4.324, 1.092, 5.093, 0.991,
|
||||
1.099, 4.092, 8.111, 6.666};
|
||||
std::vector<float> labels = {3.156, 5.690, 1.362, 1.093, 5.782, 6.381, 5.982, 3.098, 1.222, 6.027,
|
||||
3.909, 7.993, 4.324, 1.092, 5.093, 0.991, 1.099, 4.092, 8.111, 6.666};
|
||||
std::shared_ptr<Tensor> input;
|
||||
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({4, 5}), &input));
|
||||
auto input01 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||
|
@ -917,12 +913,8 @@ TEST_F(MindDataTestExecute, TestRiaaBiquadWithWrongArg) {
|
|||
TEST_F(MindDataTestExecute, TestTrebleBiquadWithEager) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestTrebleBiquadWithEager.";
|
||||
// Original waveform
|
||||
std::vector<float> labels = {
|
||||
3.156, 5.690, 1.362, 1.093,
|
||||
5.782, 6.381, 5.982, 3.098,
|
||||
1.222, 6.027, 3.909, 7.993,
|
||||
4.324, 1.092, 5.093, 0.991,
|
||||
1.099, 4.092, 8.111, 6.666};
|
||||
std::vector<float> labels = {3.156, 5.690, 1.362, 1.093, 5.782, 6.381, 5.982, 3.098, 1.222, 6.027,
|
||||
3.909, 7.993, 4.324, 1.092, 5.093, 0.991, 1.099, 4.092, 8.111, 6.666};
|
||||
std::shared_ptr<Tensor> input;
|
||||
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 10}), &input));
|
||||
auto input_01 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||
|
@ -949,9 +941,10 @@ TEST_F(MindDataTestExecute, TestTrebleBiquadWithWrongArg) {
|
|||
std::shared_ptr<TensorTransform> treble_biquad_op01 = std::make_shared<audio::TrebleBiquad>(0.0, 200.0);
|
||||
mindspore::dataset::Execute Transform01({treble_biquad_op01});
|
||||
EXPECT_ERROR(Transform01(input01, &input01));
|
||||
//Check Q
|
||||
// Check Q
|
||||
MS_LOG(INFO) << "Q is zero.";
|
||||
std::shared_ptr<TensorTransform> treble_biquad_op02 = std::make_shared<audio::TrebleBiquad>(44100, 200.0, 3000.0, 0.0);
|
||||
std::shared_ptr<TensorTransform> treble_biquad_op02 =
|
||||
std::make_shared<audio::TrebleBiquad>(44100, 200.0, 3000.0, 0.0);
|
||||
mindspore::dataset::Execute Transform02({treble_biquad_op02});
|
||||
EXPECT_ERROR(Transform02(input02, &input02));
|
||||
}
|
||||
|
@ -1169,8 +1162,7 @@ TEST_F(MindDataTestExecute, TestMagphaseEager) {
|
|||
float power = 1.0;
|
||||
std::vector<mindspore::MSTensor> output_tensor;
|
||||
std::shared_ptr<Tensor> test;
|
||||
std::vector<float> test_vector = {3, 4, -3, 4, 3, -4, -3, -4,
|
||||
5, 12, -5, 12, 5, -12, -5, -12};
|
||||
std::vector<float> test_vector = {3, 4, -3, 4, 3, -4, -3, -4, 5, 12, -5, 12, 5, -12, -5, -12};
|
||||
Tensor::CreateFromVector(test_vector, TensorShape({2, 4, 2}), &test);
|
||||
auto input_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(test));
|
||||
std::shared_ptr<TensorTransform> magphase(new audio::Magphase({power}));
|
||||
|
@ -1234,3 +1226,68 @@ TEST_F(MindDataTestExecute, TestRandomAdjustSharpnessEager) {
|
|||
Status rc = transform(image, &image);
|
||||
EXPECT_EQ(rc, Status::OK());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestExecute, TestDetectPitchFrequencyWithEager) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestDetectPitchFrequencyWithEager.";
|
||||
// Original waveform
|
||||
std::vector<double> labels = {
|
||||
3.716064453125000000e-03, 2.347656250000000000e-03, 9.246826171875000000e-03, 4.089477539062500000e-02,
|
||||
3.138305664062500000e-02, 1.156616210937500000e-02, 0.394653320312500000e-02, 1.550292968750000000e-02,
|
||||
1.614379882812500000e-02, 0.840209960937500000e-02, 1.718139648437500000e-02, 2.599121093750000000e-02,
|
||||
5.647949218750000000e-02, 1.510620117187500000e-02, 2.385498046875000000e-02, 1.345825195312500000e-02,
|
||||
1.419067382812500000e-02, 3.284790039062500000e-02, 9.052856445312500000e-02, 2.368896484375000000e-03};
|
||||
std::shared_ptr<Tensor> input;
|
||||
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 10}), &input));
|
||||
auto input_02 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||
std::shared_ptr<TensorTransform> detect_pitch_frequency_01 =
|
||||
std::make_shared<audio::DetectPitchFrequency>(30, 0.1, 3, 5, 25);
|
||||
mindspore::dataset::Execute Transform01({detect_pitch_frequency_01});
|
||||
// Detect pitch frequence
|
||||
Status s01 = Transform01(input_02, &input_02);
|
||||
EXPECT_TRUE(s01.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestExecute, TestDetectPitchFrequencyWithWrongArg) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestDetectPitchFrequencyWithWrongArg.";
|
||||
std::vector<float> labels = {
|
||||
0.716064e-03, 5.347656e-03, 6.246826e-03, 2.089477e-02, 7.138305e-02,
|
||||
4.156616e-02, 1.394653e-02, 3.550292e-02, 0.614379e-02, 3.840209e-02,
|
||||
};
|
||||
std::shared_ptr<Tensor> input;
|
||||
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 5}), &input));
|
||||
auto input_02 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||
// Check frame_time
|
||||
MS_LOG(INFO) << "frame_time is zero.";
|
||||
std::shared_ptr<TensorTransform> detect_pitch_frequency_01 =
|
||||
std::make_shared<audio::DetectPitchFrequency>(40, 0, 3, 3, 20);
|
||||
mindspore::dataset::Execute Transform01({detect_pitch_frequency_01});
|
||||
Status s01 = Transform01(input_02, &input_02);
|
||||
EXPECT_FALSE(s01.IsOk());
|
||||
// Check win_length
|
||||
MS_LOG(INFO) << "win_length is zero.";
|
||||
std::shared_ptr<TensorTransform> detect_pitch_frequency_02 =
|
||||
std::make_shared<audio::DetectPitchFrequency>(40, 0.1, 0, 3, 20);
|
||||
mindspore::dataset::Execute Transform02({detect_pitch_frequency_02});
|
||||
Status s02 = Transform02(input_02, &input_02);
|
||||
EXPECT_FALSE(s02.IsOk());
|
||||
// Check freq_low
|
||||
MS_LOG(INFO) << "freq_low is zero.";
|
||||
std::shared_ptr<TensorTransform> detect_pitch_frequency_03 =
|
||||
std::make_shared<audio::DetectPitchFrequency>(40, 0.1, 3, 0, 20);
|
||||
mindspore::dataset::Execute Transform03({detect_pitch_frequency_03});
|
||||
Status s03 = Transform03(input_02, &input_02);
|
||||
EXPECT_FALSE(s03.IsOk());
|
||||
// Check freq_high
|
||||
MS_LOG(INFO) << "freq_high is zero.";
|
||||
std::shared_ptr<TensorTransform> detect_pitch_frequency_04 =
|
||||
std::make_shared<audio::DetectPitchFrequency>(40, 0.1, 3, 3, 0);
|
||||
mindspore::dataset::Execute Transform04({detect_pitch_frequency_04});
|
||||
Status s04 = Transform04(input_02, &input_02);
|
||||
EXPECT_FALSE(s04.IsOk());
|
||||
// Check sample_rate
|
||||
MS_LOG(INFO) << "sample_rate is zero.";
|
||||
std::shared_ptr<TensorTransform> detect_pitch_frequency_05 = std::make_shared<audio::DetectPitchFrequency>(0);
|
||||
mindspore::dataset::Execute Transform05({detect_pitch_frequency_05});
|
||||
Status s05 = Transform05(input_02, &input_02);
|
||||
EXPECT_FALSE(s05.IsOk());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.audio.transforms as audio
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
def count_unequal_element(data_expected, data_me, rtol, atol):
|
||||
assert data_expected.shape == data_me.shape
|
||||
total_count = len(data_expected.flatten())
|
||||
error = np.abs(data_expected - data_me)
|
||||
greater = np.greater(error, atol + np.abs(data_expected) * rtol)
|
||||
loss_count = np.count_nonzero(greater)
|
||||
assert (loss_count / total_count) < rtol, \
|
||||
"\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
|
||||
format(data_expected[greater], data_me[greater], error[greater])
|
||||
|
||||
|
||||
def test_detect_pitch_frequency_eager():
|
||||
""" mindspore eager mode normal testcase:detect_pitch_frequency op"""
|
||||
# Original waveform
|
||||
waveform = np.array([[2.716064453125e-03, 6.34765625e-03, 9.246826171875e-03, 1.0894775390625e-02,
|
||||
1.1383056640625e-02, 1.1566162109375e-02, 1.3946533203125e-02, 1.55029296875e-02,
|
||||
1.6143798828125e-02, 1.8402099609375e-02],
|
||||
[1.7181396484375e-02, 1.59912109375e-02, 1.64794921875e-02, 1.5106201171875e-02,
|
||||
1.385498046875e-02, 1.3458251953125e-02, 1.4190673828125e-02, 1.2847900390625e-02,
|
||||
1.0528564453125e-02, 9.368896484375e-03]], dtype=np.float64)
|
||||
# Expect waveform
|
||||
expect_waveform = np.array(
|
||||
[[10., 10., 10.], [5., 5., 10.]], dtype=np.float64)
|
||||
detect_pitch_frequency_op = audio.DetectPitchFrequency(30, 0.1, 3, 5, 25)
|
||||
# Detect pitch frequence
|
||||
output = detect_pitch_frequency_op(waveform)
|
||||
count_unequal_element(expect_waveform, output, 0.0001, 0.0001)
|
||||
|
||||
|
||||
def test_detect_pitch_frequency_pipeline():
|
||||
""" mindspore pipeline mode normal testcase:detect_pitch_frequency op"""
|
||||
# Original waveform
|
||||
waveform = np.array([[0.716064453125e-03, 5.34765625e-03, 6.246826171875e-03, 2.0894775390625e-02,
|
||||
7.1383056640625e-02], [4.1566162109375e-02, 1.3946533203125e-02, 3.55029296875e-02,
|
||||
0.6143798828125e-02, 3.8402099609375e-02]], dtype=np.float64)
|
||||
# Expect waveform
|
||||
expect_waveform = np.array([[10.0000], [7.5000]], dtype=np.float64)
|
||||
dataset = ds.NumpySlicesDataset(waveform, ["audio"], shuffle=False)
|
||||
detect_pitch_frequency_op = audio.DetectPitchFrequency(30, 0.1, 3, 5, 25)
|
||||
# Detect pitch frequence
|
||||
dataset = dataset.map(input_columns=["audio"],
|
||||
operations=detect_pitch_frequency_op, num_parallel_workers=8)
|
||||
i = 0
|
||||
for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
count_unequal_element(expect_waveform[i, :],
|
||||
item['audio'], 0.0001, 0.0001)
|
||||
i += 1
|
||||
|
||||
|
||||
def test_detect_pitch_frequency_invalid_input():
|
||||
def test_invalid_input(test_name, sample_rate, frame_time, win_length, freq_low, freq_high, error, error_msg):
|
||||
logger.info(
|
||||
"Test DetectPitchFrequency with bad input: {0}".format(test_name))
|
||||
with pytest.raises(error) as error_info:
|
||||
audio.DetectPitchFrequency(
|
||||
sample_rate, frame_time, win_length, freq_low, freq_high)
|
||||
assert error_msg in str(error_info.value)
|
||||
|
||||
test_invalid_input("invalid sample_rate parameter type as a float", 44100.5, 0.01, 30, 85, 3400, TypeError,
|
||||
"Argument sample_rate with value 44100.5 is not of type [<class 'int'>],"
|
||||
" but got <class 'float'>.")
|
||||
test_invalid_input("invalid sample_rate parameter type as a String", "44100", 0.01, 30, 85, 3400, TypeError,
|
||||
"Argument sample_rate with value 44100 is not of type [<class 'int'>], but got <class 'str'>.")
|
||||
test_invalid_input("invalid frame_time parameter type as a String", 44100, "0.01", 30, 85, 3400, TypeError,
|
||||
"Argument frame_time with value 0.01 is not of type [<class 'float'>, <class 'int'>],"
|
||||
" but got <class 'str'>.")
|
||||
test_invalid_input("invalid win_length parameter type as a float", 44100, 0.01, 30.1, 85, 3400, TypeError,
|
||||
"Argument win_length with value 30.1 is not of type [<class 'int'>], but got <class 'float'>.")
|
||||
test_invalid_input("invalid win_length parameter type as a String", 44100, 0.01, "30", 85, 3400, TypeError,
|
||||
"Argument win_length with value 30 is not of type [<class 'int'>], but got <class 'str'>.")
|
||||
test_invalid_input("invalid freq_low parameter type as a String", 44100, 0.01, 30, "85", 3400, TypeError,
|
||||
"Argument freq_low with value 85 is not of type [<class 'int'>, <class 'float'>],"
|
||||
" but got <class 'str'>.")
|
||||
test_invalid_input("invalid freq_high parameter type as a String", 44100, 0.01, 30, 85, "3400", TypeError,
|
||||
"Argument freq_high with value 3400 is not of type [<class 'int'>, <class 'float'>],"
|
||||
" but got <class 'str'>.")
|
||||
test_invalid_input("invalid sample_rate parameter value", 0, 0.01, 30, 85, 3400, ValueError,
|
||||
"Input sample_rate is not within the required interval of [-2147483648, 0) and (0, 2147483647].")
|
||||
test_invalid_input("invalid frame_time parameter value", 44100, 0, 30, 85, 3400, ValueError,
|
||||
"Input frame_time is not within the required interval of (0, 16777216].")
|
||||
test_invalid_input("invalid win_length parameter value", 44100, 0.01, 0, 85, 3400, ValueError,
|
||||
"Input win_length is not within the required interval of [1, 2147483647].")
|
||||
test_invalid_input("invalid freq_low parameter value", 44100, 0.01, 30, 0, 3400, ValueError,
|
||||
"Input freq_low is not within the required interval of (0, 16777216].")
|
||||
test_invalid_input("invalid freq_high parameter value", 44100, 0.01, 30, 85, 0, ValueError,
|
||||
"Input freq_high is not within the required interval of (0, 16777216].")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_detect_pitch_frequency_eager()
|
||||
test_detect_pitch_frequency_pipeline()
|
||||
test_detect_pitch_frequency_invalid_input()
|
Loading…
Reference in New Issue