forked from mindspore-Ecosystem/mindspore
[feat][assistant][I3T96B]add new data operator LFilter
This commit is contained in:
parent
6b0e8fef6b
commit
53fc99b914
|
@ -29,6 +29,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/lfilter_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
|
||||
|
@ -230,6 +231,22 @@ std::shared_ptr<TensorOperation> HighpassBiquad::Parse() {
|
|||
return std::make_shared<HighpassBiquadOperation>(data_->sample_rate_, data_->cutoff_freq_, data_->Q_);
|
||||
}
|
||||
|
||||
// LFilter Transform Operation.
|
||||
struct LFilter::Data {
|
||||
Data(const std::vector<float> &a_coeffs, const std::vector<float> &b_coeffs, bool clamp)
|
||||
: a_coeffs_(a_coeffs), b_coeffs_(b_coeffs), clamp_(clamp) {}
|
||||
std::vector<float> a_coeffs_;
|
||||
std::vector<float> b_coeffs_;
|
||||
bool clamp_;
|
||||
};
|
||||
|
||||
LFilter::LFilter(std::vector<float> a_coeffs, std::vector<float> b_coeffs, bool clamp)
|
||||
: data_(std::make_shared<Data>(a_coeffs, b_coeffs, clamp)) {}
|
||||
|
||||
std::shared_ptr<TensorOperation> LFilter::Parse() {
|
||||
return std::make_shared<LFilterOperation>(data_->a_coeffs_, data_->b_coeffs_, data_->clamp_);
|
||||
}
|
||||
|
||||
// LowpassBiquad Transform Operation.
|
||||
struct LowpassBiquad::Data {
|
||||
Data(int32_t sample_rate, float cutoff_freq, float Q) : sample_rate_(sample_rate), cutoff_freq_(cutoff_freq), Q_(Q) {}
|
||||
|
@ -290,7 +307,6 @@ TimeStretch::TimeStretch(float hop_length, int n_freq, float fixed_rate)
|
|||
std::shared_ptr<TensorOperation> TimeStretch::Parse() {
|
||||
return std::make_shared<TimeStretchOperation>(data_->hop_length_, data_->n_freq_, data_->fixed_rate_);
|
||||
}
|
||||
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -33,6 +33,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/lfilter_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
|
||||
|
@ -193,6 +194,16 @@ PYBIND_REGISTER(
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(LFilterOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<audio::LFilterOperation, TensorOperation, std::shared_ptr<audio::LFilterOperation>>(
|
||||
*m, "LFilterOperation")
|
||||
.def(py::init([](std::vector<float> a_coeffs, std::vector<float> b_coeffs, bool clamp) {
|
||||
auto lfilter = std::make_shared<audio::LFilterOperation>(a_coeffs, b_coeffs, clamp);
|
||||
THROW_IF_ERROR(lfilter->ValidateParams());
|
||||
return lfilter;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
LowpassBiquadOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<audio::LowpassBiquadOperation, TensorOperation, std::shared_ptr<audio::LowpassBiquadOperation>>(
|
||||
|
|
|
@ -15,6 +15,7 @@ add_library(audio-ir-kernels OBJECT
|
|||
equalizer_biquad_ir.cc
|
||||
frequency_masking_ir.cc
|
||||
highpass_biquad_ir.cc
|
||||
lfilter_ir.cc
|
||||
lowpass_biquad_ir.cc
|
||||
mu_law_decoding_ir.cc
|
||||
time_masking_ir.cc
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/audio/ir/kernels/lfilter_ir.h"
|
||||
|
||||
#include "minddata/dataset/audio/ir/validators.h"
|
||||
#include "minddata/dataset/audio/kernels/lfilter_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
namespace audio {
|
||||
// LFilterOperation
|
||||
LFilterOperation::LFilterOperation(std::vector<float> a_coeffs, std::vector<float> b_coeffs, bool clamp)
|
||||
: a_coeffs_(a_coeffs), b_coeffs_(b_coeffs), clamp_(clamp) {}
|
||||
|
||||
Status LFilterOperation::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateVectorNotEmpty("lfilter", "a_coeffs", a_coeffs_));
|
||||
RETURN_IF_NOT_OK(ValidateVectorNotEmpty("lfilter", "b_coeffs", b_coeffs_));
|
||||
RETURN_IF_NOT_OK(ValidateVectorSameSize("lfilter", "a_coeffs", a_coeffs_, "b_coeffs", b_coeffs_));
|
||||
RETURN_IF_NOT_OK(ValidateScalarNotZero("lfilter", "a_coeffs[0]", a_coeffs_[0]));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> LFilterOperation::Build() {
|
||||
std::shared_ptr<LFilterOp> tensor_op = std::make_shared<LFilterOp>(a_coeffs_, b_coeffs_, clamp_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
Status LFilterOperation::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["a_coeffs"] = a_coeffs_;
|
||||
args["b_coeffs"] = b_coeffs_;
|
||||
args["clamp"] = clamp_;
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,59 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_LFILTER_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_LFILTER_IR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/include/dataset/transforms.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
namespace audio {
|
||||
// Char arrays storing name of corresponding classes (in alphabetical order)
|
||||
constexpr char kLFilterOperation[] = "LFilter";
|
||||
|
||||
class LFilterOperation : public TensorOperation {
|
||||
public:
|
||||
LFilterOperation(std::vector<float> a_coeffs, std::vector<float> b_coeffs, bool clamp);
|
||||
|
||||
~LFilterOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
std::string Name() const override { return kLFilterOperation; }
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
std::vector<float> a_coeffs_;
|
||||
std::vector<float> b_coeffs_;
|
||||
bool clamp_;
|
||||
};
|
||||
} // namespace audio
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_LFILTER_IR_H_
|
|
@ -34,8 +34,8 @@ Status ValidateIntScalarNonNegative(const std::string &op_name, const std::strin
|
|||
// Helper function to non-nan float scalar
|
||||
Status ValidateFloatScalarNotNan(const std::string &op_name, const std::string &scalar_name, float scalar);
|
||||
|
||||
template <typename T>
|
||||
// Helper function to check scalar is not equal to zero
|
||||
template <typename T>
|
||||
Status ValidateScalarNotZero(const std::string &op_name, const std::string &scalar_name, const T scalar) {
|
||||
if (scalar == 0) {
|
||||
std::string err_msg = op_name + ": " + scalar_name + " can't be zero, got: " + std::to_string(scalar);
|
||||
|
@ -45,6 +45,29 @@ Status ValidateScalarNotZero(const std::string &op_name, const std::string &scal
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Helper function to check vector is not empty
|
||||
template <typename T>
|
||||
Status ValidateVectorNotEmpty(const std::string &op_name, const std::string &vec_name, const std::vector<T> &vec) {
|
||||
if (vec.empty()) {
|
||||
std::string err_msg = op_name + ": " + vec_name + " can't be empty.";
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Helper function to check two vector size equal
|
||||
template <typename T>
|
||||
Status ValidateVectorSameSize(const std::string &op_name, const std::string &vec1_name, const std::vector<T> &vec1,
|
||||
const std::string &vec2_name, const std::vector<T> &vec2) {
|
||||
if (vec1.size() != vec2.size()) {
|
||||
std::string err_msg = op_name + ": the size of " + vec1_name + " should be the same as that of " + vec2_name;
|
||||
MS_LOG(ERROR) << err_msg;
|
||||
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ADUIO_IR_VALIDATORS_H_
|
||||
|
|
|
@ -16,8 +16,10 @@ add_library(audio-kernels OBJECT
|
|||
equalizer_biquad_op.cc
|
||||
frequency_masking_op.cc
|
||||
highpass_biquad_op.cc
|
||||
lfilter_op.cc
|
||||
lowpass_biquad_op.cc
|
||||
mu_law_decoding_op.cc
|
||||
time_masking_op.cc
|
||||
time_stretch_op.cc
|
||||
)
|
||||
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/audio/kernels/lfilter_op.h"
|
||||
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
Status LFilterOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
TensorShape input_shape = input->shape();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input_shape.Size() > 0, "LFilter: input tensor is not in shape of <..., time>.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType(DataType::DE_FLOAT32) ||
|
||||
input->type() == DataType(DataType::DE_FLOAT16) ||
|
||||
input->type() == DataType(DataType::DE_FLOAT64),
|
||||
"LFilter: input tensor type should be float, but got: " + input->type().ToString());
|
||||
if (input->type() == DataType(DataType::DE_FLOAT32)) {
|
||||
return LFilter(input, output, a_coeffs_, b_coeffs_, clamp_);
|
||||
} else if (input->type() == DataType(DataType::DE_FLOAT64)) {
|
||||
std::vector<double> a_coeffs_double;
|
||||
std::vector<double> b_coeffs_double;
|
||||
for (int i = 0; i < a_coeffs_.size(); i++) {
|
||||
a_coeffs_double.push_back(static_cast<double>(a_coeffs_[i]));
|
||||
}
|
||||
for (int i = 0; i < b_coeffs_.size(); i++) {
|
||||
b_coeffs_double.push_back(static_cast<double>(b_coeffs_[i]));
|
||||
}
|
||||
return LFilter(input, output, a_coeffs_double, b_coeffs_double, clamp_);
|
||||
} else {
|
||||
std::vector<float16> a_coeffs_float16;
|
||||
std::vector<float16> b_coeffs_float16;
|
||||
for (int i = 0; i < a_coeffs_.size(); i++) {
|
||||
a_coeffs_float16.push_back(static_cast<float16>(a_coeffs_[i]));
|
||||
}
|
||||
for (int i = 0; i < b_coeffs_.size(); i++) {
|
||||
b_coeffs_float16.push_back(static_cast<float16>(b_coeffs_[i]));
|
||||
}
|
||||
return LFilter(input, output, a_coeffs_float16, b_coeffs_float16, clamp_);
|
||||
}
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_LFILTER_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_LFILTER_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class LFilterOp : public TensorOp {
|
||||
public:
|
||||
LFilterOp(std::vector<float> a_coeffs, std::vector<float> b_coeffs, bool clamp)
|
||||
: a_coeffs_(a_coeffs), b_coeffs_(b_coeffs), clamp_(clamp) {}
|
||||
|
||||
~LFilterOp() override = default;
|
||||
|
||||
void Print(std::ostream &out) const override {
|
||||
out << Name() << ": a_coeffs: ";
|
||||
for (int i = 0; i < a_coeffs_.size(); i++) {
|
||||
out << a_coeffs_[i] << " ";
|
||||
}
|
||||
out << "b_coeffs: ";
|
||||
for (int i = 0; i < b_coeffs_.size(); i++) {
|
||||
out << b_coeffs_[i] << " ";
|
||||
}
|
||||
out << "clamp: " << clamp_ << std::endl;
|
||||
}
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
std::string Name() const override { return kLFilterOp; }
|
||||
|
||||
private:
|
||||
std::vector<float> a_coeffs_;
|
||||
std::vector<float> b_coeffs_;
|
||||
bool clamp_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_KERNELS_LFILTER_OP_H_
|
|
@ -321,6 +321,31 @@ class HighpassBiquad final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Design filter. Similar to SoX implementation.
|
||||
class LFilter final : public TensorTransform {
|
||||
public:
|
||||
/// \param[in] a_coeffs Numerator coefficients of difference equation of dimension of (n_order + 1).
|
||||
/// Lower delays coefficients are first, e.g. [a0, a1, a2, ...].
|
||||
/// Must be same size as b_coeffs (pad with 0's as necessary).
|
||||
/// \param[in] b_coeffs Numerator coefficients of difference equation of dimension of (n_order + 1).
|
||||
/// Lower delays coefficients are first, e.g. [b0, b1, b2, ...].
|
||||
/// Must be same size as a_coeffs (pad with 0's as necessary).
|
||||
/// \param[in] clamp If True, clamp the output signal to be in the range [-1, 1] (Default: True).
|
||||
explicit LFilter(std::vector<float> a_coeffs, std::vector<float> b_coeffs, bool clamp = true);
|
||||
|
||||
/// \brief Destructor.
|
||||
~LFilter() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
||||
private:
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Design biquad lowpass filter and perform filtering. Similar to SoX implementation.
|
||||
class LowpassBiquad final : public TensorTransform {
|
||||
public:
|
||||
|
|
|
@ -152,6 +152,7 @@ constexpr char kDeemphBiquadOp[] = "DeemphBiquadOp";
|
|||
constexpr char kEqualizerBiquadOp[] = "EqualizerBiquadOp";
|
||||
constexpr char kFrequencyMaskingOp[] = "FrequencyMaskingOp";
|
||||
constexpr char kHighpassBiquadOp[] = "HighpassBiquadOp";
|
||||
constexpr char kLFilterOp[] = "LFilterOp";
|
||||
constexpr char kLowpassBiquadOp[] = "LowpassBiquadOp";
|
||||
constexpr char kMuLawDecodingOp[] = "MuLawDecodingOp";
|
||||
constexpr char kTimeMaskingOp[] = "TimeMaskingOp";
|
||||
|
|
|
@ -26,8 +26,8 @@ from ..transforms.c_transforms import TensorOperation
|
|||
from .utils import ScaleType
|
||||
from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
|
||||
check_bandreject_biquad, check_bass_biquad, check_complex_norm, check_contrast, check_deemph_biquad, \
|
||||
check_equalizer_biquad, check_highpass_biquad, check_lowpass_biquad, check_masking, check_mu_law_decoding,\
|
||||
check_time_stretch
|
||||
check_equalizer_biquad, check_highpass_biquad, check_lfilter, check_lowpass_biquad, check_masking, \
|
||||
check_mu_law_decoding, check_time_stretch
|
||||
|
||||
|
||||
class AudioTensorOperation(TensorOperation):
|
||||
|
@ -409,6 +409,39 @@ class HighpassBiquad(AudioTensorOperation):
|
|||
return cde.HighpassBiquadOperation(self.sample_rate, self.cutoff_freq, self.Q)
|
||||
|
||||
|
||||
class LFilter(AudioTensorOperation):
|
||||
"""
|
||||
Design two-pole filter for audio waveform of dimension of (..., time).
|
||||
|
||||
Args:
|
||||
a_coeffs (sequence): denominator coefficients of difference equation of dimension of (n_order + 1).
|
||||
Lower delays coefficients are first, e.g. [a0, a1, a2, ...].
|
||||
Must be same size as b_coeffs (pad with 0's as necessary).
|
||||
b_coeffs (sequence): numerator coefficients of difference equation of dimension of (n_order + 1).
|
||||
Lower delays coefficients are first, e.g. [b0, b1, b2, ...].
|
||||
Must be same size as a_coeffs (pad with 0's as necessary).
|
||||
clamp (bool, optional): If True, clamp the output signal to be in the range [-1, 1] (default=True).
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.array([[2.716064453125e-03, 6.34765625e-03], [9.246826171875e-03, 1.0894775390625e-02]])
|
||||
>>> a_coeffs = [0.1, 0.2, 0.3]
|
||||
>>> b_coeffs = [0.1, 0.2, 0.3]
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.LFilter(a_coeffs, b_coeffs)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
"""
|
||||
@check_lfilter
|
||||
def __init__(self, a_coeffs, b_coeffs, clamp=True):
|
||||
self.a_coeffs = a_coeffs
|
||||
self.b_coeffs = b_coeffs
|
||||
self.clamp = clamp
|
||||
|
||||
def parse(self):
|
||||
return cde.LFilterOperation(self.a_coeffs, self.b_coeffs, self.clamp)
|
||||
|
||||
|
||||
class LowpassBiquad(AudioTensorOperation):
|
||||
"""
|
||||
Design biquad lowpass filter and perform filtering. Similar to SoX implementation.
|
||||
|
|
|
@ -18,8 +18,8 @@ Validators for TensorOps.
|
|||
|
||||
from functools import wraps
|
||||
|
||||
from mindspore.dataset.core.validator_helpers import check_float32, check_int32_not_zero, check_non_negative_float32, \
|
||||
check_pos_float32, check_pos_int32, check_value, parse_user_args, type_check
|
||||
from mindspore.dataset.core.validator_helpers import check_float32, check_int32_not_zero, check_list_same_size, \
|
||||
check_non_negative_float32, check_pos_float32, check_pos_int32, check_value, parse_user_args, type_check
|
||||
from .utils import ScaleType
|
||||
|
||||
|
||||
|
@ -230,6 +230,27 @@ def check_equalizer_biquad(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_lfilter(method):
|
||||
"""Wrapper method to check the parameters of lfilter."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[a_coeffs, b_coeffs, clamp], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(a_coeffs, (list, tuple), "a_coeffs")
|
||||
type_check(b_coeffs, (list, tuple), "b_coeffs")
|
||||
for i, value in enumerate(a_coeffs):
|
||||
type_check(value, (float, int), "a_coeffs[{0}]".format(i))
|
||||
check_float32(value, "a_coeffs[{0}]".format(i))
|
||||
for i, value in enumerate(b_coeffs):
|
||||
type_check(value, (float, int), "b_coeffs[{0}]".format(i))
|
||||
check_float32(value, "b_coeffs[{0}]".format(i))
|
||||
check_list_same_size(a_coeffs, b_coeffs, "a_coeffs", "b_coeffs")
|
||||
type_check(clamp, (bool,), "clamp")
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_lowpass_biquad(method):
|
||||
"""Wrapper method to check the parameters of LowpassBiquad."""
|
||||
|
||||
|
|
|
@ -532,6 +532,20 @@ def check_dir(dataset_dir):
|
|||
raise ValueError("The folder {} does not exist or is not a directory or permission denied!".format(dataset_dir))
|
||||
|
||||
|
||||
def check_list_same_size(list1, list2, list1_name="", list2_name=""):
|
||||
"""
|
||||
Validates the two lists as the same size.
|
||||
|
||||
:param list1: the first list to be validated
|
||||
:param list2: the secend list to be validated
|
||||
:param list1_name: name of the list1
|
||||
:param list2_name: name of the list2
|
||||
:return: Exception: when the two list no same size, nothing otherwise.
|
||||
"""
|
||||
if len(list1) != len(list2):
|
||||
raise ValueError("The size of {0} should be the same as that of {1}.".format(list1_name, list2_name))
|
||||
|
||||
|
||||
def check_file(dataset_file):
|
||||
"""
|
||||
Validates if the argument is a valid file name.
|
||||
|
|
|
@ -954,3 +954,65 @@ TEST_F(MindDataTestPipeline, TestMuLawDecodingWrongArgs) {
|
|||
std::shared_ptr<Iterator> iter1 = ds->CreateIterator();
|
||||
EXPECT_EQ(iter1, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestLfilterPipeline) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLfilterPipeline.";
|
||||
// Original waveform
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {2, 200}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(50, schema);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
ds = ds->SetNumWorkers(4);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::vector<float> a_coeffs = {0.1, 0.2, 0.3};
|
||||
std::vector<float> b_coeffs = {0.1, 0.2, 0.3};
|
||||
auto LFilterOp = audio::LFilter(a_coeffs, b_coeffs);
|
||||
|
||||
ds = ds->Map({LFilterOp});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Filtered waveform by lfilter
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
std::vector<int64_t> expected = {2, 200};
|
||||
|
||||
int i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto col = row["inputData"];
|
||||
ASSERT_EQ(col.Shape(), expected);
|
||||
ASSERT_EQ(col.Shape().size(), 2);
|
||||
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
EXPECT_EQ(i, 50);
|
||||
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestLfilterWrongArgs) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLfilterWrongArgs.";
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
// Original waveform
|
||||
ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {2, 2}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(50, schema);
|
||||
std::shared_ptr<Dataset> ds01;
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Check sample_rate
|
||||
MS_LOG(INFO) << "a_coeffs size not equal to b_coeffs";
|
||||
std::vector<float> a_coeffs = {0.1, 0.2, 0.3};
|
||||
std::vector<float> b_coeffs = {0.1, 0.2};
|
||||
auto LFilterOp = audio::LFilter(a_coeffs, b_coeffs);
|
||||
ds01 = ds->Map({LFilterOp});
|
||||
EXPECT_NE(ds01, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter01 = ds01->CreateIterator();
|
||||
EXPECT_EQ(iter01, nullptr);
|
||||
}
|
||||
|
|
|
@ -875,3 +875,41 @@ TEST_F(MindDataTestExecute, TestMuLawDecodingEager) {
|
|||
Status s01 = Transform01(input_02, &input_02);
|
||||
EXPECT_TRUE(s01.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestExecute, TestLFilterWithEager) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestLFilterWithEager.";
|
||||
// Original waveform
|
||||
std::vector<float> labels = {
|
||||
2.716064453125000000e-03, 6.347656250000000000e-03, 9.246826171875000000e-03, 1.089477539062500000e-02,
|
||||
1.138305664062500000e-02, 1.156616210937500000e-02, 1.394653320312500000e-02, 1.550292968750000000e-02,
|
||||
1.614379882812500000e-02, 1.840209960937500000e-02, 1.718139648437500000e-02, 1.599121093750000000e-02,
|
||||
1.647949218750000000e-02, 1.510620117187500000e-02, 1.385498046875000000e-02, 1.345825195312500000e-02,
|
||||
1.419067382812500000e-02, 1.284790039062500000e-02, 1.052856445312500000e-02, 9.368896484375000000e-03};
|
||||
std::shared_ptr<Tensor> input;
|
||||
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 10}), &input));
|
||||
auto input_02 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||
std::vector<float> a_coeffs = {0.1, 0.2, 0.3};
|
||||
std::vector<float> b_coeffs = {0.1, 0.2, 0.3};
|
||||
std::shared_ptr<TensorTransform> lfilter_01 = std::make_shared<audio::LFilter>(a_coeffs, b_coeffs);
|
||||
mindspore::dataset::Execute Transform01({lfilter_01});
|
||||
// Filtered waveform by lfilter
|
||||
Status s01 = Transform01(input_02, &input_02);
|
||||
EXPECT_TRUE(s01.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestExecute, TestLFilterWithWrongArg) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestLFilterWithWrongArg.";
|
||||
std::vector<double> labels = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6};
|
||||
std::shared_ptr<Tensor> input;
|
||||
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({1, 6}), &input));
|
||||
auto input_02 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||
|
||||
// Check a_coeffs size equal to b_coeffs
|
||||
MS_LOG(INFO) << "a_coeffs size not equal to b_coeffs";
|
||||
std::vector<float> a_coeffs = {0.1, 0.2, 0.3};
|
||||
std::vector<float> b_coeffs = {0.1, 0.2};
|
||||
std::shared_ptr<TensorTransform> lfilter_op = std::make_shared<audio::LFilter>(a_coeffs, b_coeffs);
|
||||
mindspore::dataset::Execute Transform01({lfilter_op});
|
||||
Status s01 = Transform01(input_02, &input_02);
|
||||
EXPECT_FALSE(s01.IsOk());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.audio.transforms as audio
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
def count_unequal_element(data_expected, data_me, rtol, atol):
|
||||
assert data_expected.shape == data_me.shape
|
||||
total_count = len(data_expected.flatten())
|
||||
error = np.abs(data_expected - data_me)
|
||||
greater = np.greater(error, atol + np.abs(data_expected) * rtol)
|
||||
loss_count = np.count_nonzero(greater)
|
||||
assert (loss_count / total_count) < rtol, \
|
||||
"\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
|
||||
format(data_expected[greater], data_me[greater], error[greater])
|
||||
|
||||
|
||||
def test_func_lfilter_eager():
|
||||
""" mindspore eager mode normal testcase:deemph_biquad op"""
|
||||
# Original waveform
|
||||
waveform = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float64)
|
||||
# Expect waveform
|
||||
expect_waveform = np.array([[0.25, 0.45, 0.425],
|
||||
[1., 1., 0.35]], dtype=np.float64)
|
||||
a_coeffs = [0.2, 0.2, 0.3]
|
||||
b_coeffs = [0.5, 0.4, 0.2]
|
||||
lfilter_op = audio.LFilter(a_coeffs, b_coeffs, True)
|
||||
output = lfilter_op(waveform)
|
||||
count_unequal_element(expect_waveform, output, 0.0001, 0.0001)
|
||||
|
||||
|
||||
def test_func_lfilter_pipeline():
|
||||
""" mindspore pipeline mode normal testcase:lfilter op"""
|
||||
|
||||
# Original waveform
|
||||
waveform = np.array([[0.1, 0.2, 0.3, 0.4], [0.4, 0.5, 0.6, 0.7]], dtype=np.float64)
|
||||
# Expect waveform
|
||||
expect_waveform = np.array([[0.4, 0.5, 0.6, 1.],
|
||||
[1., 0.8, 0.9, 1.]], dtype=np.float64)
|
||||
data = (waveform, waveform.shape)
|
||||
a_coeffs = [0.1, 0.2, 0.3]
|
||||
b_coeffs = [0.4, 0.5, 0.6]
|
||||
dataset = ds.NumpySlicesDataset(data, ["channel", "sample"], shuffle=False)
|
||||
lfilter_op = audio.LFilter(a_coeffs, b_coeffs)
|
||||
# Filtered waveform by lfilter
|
||||
dataset = dataset.map(input_columns=["channel"], operations=lfilter_op, num_parallel_workers=8)
|
||||
i = 0
|
||||
for data in dataset.create_dict_iterator(output_numpy=True):
|
||||
count_unequal_element(expect_waveform[i, :], data['channel'], 0.0001, 0.0001)
|
||||
i += 1
|
||||
|
||||
|
||||
def test_invalid_input_all():
|
||||
waveform = np.random.rand(2, 1000)
|
||||
|
||||
def test_invalid_input(test_name, a_coeffs, b_coeffs, clamp, error, error_msg):
|
||||
logger.info("Test LFilter with bad input: {0}".format(test_name))
|
||||
with pytest.raises(error) as error_info:
|
||||
audio.LFilter(a_coeffs, b_coeffs, clamp)(waveform)
|
||||
assert error_msg in str(error_info.value)
|
||||
|
||||
a_coeffs = ['0.1', '0.2', '0.3']
|
||||
b_coeffs = [0.1, 0.2, 0.3]
|
||||
test_invalid_input("invalid a_coeffs parameter type as a string", a_coeffs, b_coeffs, True, TypeError,
|
||||
"Argument a_coeffs[0] with value 0.1 is not of type [<class 'float'>, <class 'int'>], "
|
||||
+ "but got <class 'str'>.")
|
||||
a_coeffs = [234322354352353453651, 0.2, 0.3]
|
||||
b_coeffs = [0.1, 0.2, 0.3]
|
||||
test_invalid_input("invalid a_coeffs parameter value", a_coeffs, b_coeffs, True, ValueError,
|
||||
"Input a_coeffs[0] is not within the required interval of [-16777216, 16777216].")
|
||||
a_coeffs = [0.1, 0.2, 0.3]
|
||||
b_coeffs = [0.1, 0.2, 0.3]
|
||||
test_invalid_input("invalid clamp parameter type as a String", a_coeffs, b_coeffs, "True", TypeError,
|
||||
"Argument clamp with value True is not of type [<class 'bool'>],"
|
||||
+ " but got <class 'str'>.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_func_lfilter_eager()
|
||||
test_func_lfilter_pipeline()
|
||||
test_invalid_input_all()
|
||||
|
Loading…
Reference in New Issue