forked from mindspore-Ecosystem/mindspore
[feat][assistant][I3J6U3] add new audio operator DCShift
This commit is contained in:
parent
8a51311e57
commit
0ba84345b0
|
@ -25,6 +25,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/bass_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/complex_norm_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/contrast_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h"
|
||||
|
@ -166,6 +167,21 @@ std::shared_ptr<TensorOperation> Contrast::Parse() {
|
|||
return std::make_shared<ContrastOperation>(data_->enhancement_amount_);
|
||||
}
|
||||
|
||||
// DCShift Transform Operation.
|
||||
struct DCShift::Data {
|
||||
Data(float shift, float limiter_gain) : shift_(shift), limiter_gain_(limiter_gain) {}
|
||||
float limiter_gain_;
|
||||
float shift_;
|
||||
};
|
||||
|
||||
DCShift::DCShift(float shift) : data_(std::make_shared<Data>(shift, shift)) {}
|
||||
|
||||
DCShift::DCShift(float shift, float limiter_gain) : data_(std::make_shared<Data>(shift, limiter_gain)) {}
|
||||
|
||||
std::shared_ptr<TensorOperation> DCShift::Parse() {
|
||||
return std::make_shared<DCShiftOperation>(data_->shift_, data_->limiter_gain_);
|
||||
}
|
||||
|
||||
// DeemphBiquad Transform Operation.
|
||||
struct DeemphBiquad::Data {
|
||||
explicit Data(int32_t sample_rate) : sample_rate_(sample_rate) {}
|
||||
|
|
|
@ -29,6 +29,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/bass_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/complex_norm_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/contrast_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h"
|
||||
|
@ -148,6 +149,16 @@ PYBIND_REGISTER(ContrastOperation, 1, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(DCShiftOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<audio::DCShiftOperation, TensorOperation, std::shared_ptr<audio::DCShiftOperation>>(
|
||||
*m, "DCShiftOperation")
|
||||
.def(py::init([](float shift, float limiter_gain) {
|
||||
auto dc_shift = std::make_shared<audio::DCShiftOperation>(shift, limiter_gain);
|
||||
THROW_IF_ERROR(dc_shift->ValidateParams());
|
||||
return dc_shift;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
DeemphBiquadOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<audio::DeemphBiquadOperation, TensorOperation, std::shared_ptr<audio::DeemphBiquadOperation>>(
|
||||
|
|
|
@ -11,6 +11,7 @@ add_library(audio-ir-kernels OBJECT
|
|||
bass_biquad_ir.cc
|
||||
complex_norm_ir.cc
|
||||
contrast_ir.cc
|
||||
dc_shift_ir.cc
|
||||
deemph_biquad_ir.cc
|
||||
equalizer_biquad_ir.cc
|
||||
frequency_masking_ir.cc
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
|
||||
|
||||
#include "minddata/dataset/audio/ir/validators.h"
|
||||
#include "minddata/dataset/audio/kernels/dc_shift_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
|
||||
// DCShiftOperation
|
||||
DCShiftOperation::DCShiftOperation(float shift, float limiter_gain) : shift_(shift), limiter_gain_(limiter_gain) {}
|
||||
|
||||
Status DCShiftOperation::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateScalar("DCShift", "shift", shift_, {-2.0, 2.0}, false, false));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> DCShiftOperation::Build() {
|
||||
std::shared_ptr<DCShiftOp> tensor_op = std::make_shared<DCShiftOp>(shift_, limiter_gain_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
Status DCShiftOperation::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["shift"] = shift_;
|
||||
args["limiter_gain"] = limiter_gain_;
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DC_SHIFT_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DC_SHIFT_IR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/include/dataset/transforms.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
|
||||
constexpr char kDCShiftOperation[] = "DCShift";
|
||||
|
||||
class DCShiftOperation : public TensorOperation {
|
||||
public:
|
||||
DCShiftOperation(float shift, float limiter_gain);
|
||||
|
||||
~DCShiftOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
std::string Name() const override { return kDCShiftOperation; }
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
float shift_;
|
||||
float limiter_gain_;
|
||||
};
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DC_SHIFT_IR_H_
|
|
@ -12,6 +12,7 @@ add_library(audio-kernels OBJECT
|
|||
bass_biquad_op.cc
|
||||
complex_norm_op.cc
|
||||
contrast_op.cc
|
||||
dc_shift_op.cc
|
||||
deemph_biquad_op.cc
|
||||
equalizer_biquad_op.cc
|
||||
frequency_masking_op.cc
|
||||
|
|
|
@ -150,6 +150,40 @@ Status Contrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Apply a DC shift to the audio.
|
||||
/// \param input/output: Tensor of shape <...,time>.
|
||||
/// \param shift: the amount to shift the audio.
|
||||
/// \param limiter_gain: used only on peaks to prevent clipping.
|
||||
/// \return Status code.
|
||||
template <typename T>
|
||||
Status DCShift(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, float shift, float limiter_gain) {
|
||||
float limiter_threshold = 0.0;
|
||||
if (shift != limiter_gain && shift != 0) {
|
||||
limiter_threshold = 1.0 - (std::abs(shift) - limiter_gain);
|
||||
for (auto itr = input->begin<T>(); itr != input->end<T>(); itr++) {
|
||||
if (*itr > limiter_threshold && shift > 0) {
|
||||
T peak = (*itr - limiter_threshold) * limiter_gain / (1 - limiter_threshold);
|
||||
T sample = (peak + limiter_threshold + shift);
|
||||
*itr = sample > limiter_threshold ? limiter_threshold : sample;
|
||||
} else if (*itr < -limiter_threshold && shift < 0) {
|
||||
T peak = (*itr + limiter_threshold) * limiter_gain / (1 - limiter_threshold);
|
||||
T sample = (peak + limiter_threshold + shift);
|
||||
*itr = sample < -limiter_threshold ? -limiter_threshold : sample;
|
||||
} else {
|
||||
T sample = (*itr + shift);
|
||||
*itr = (sample > 1 || sample < -1) ? (sample > 1 ? 1 : -1) : sample;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (auto itr = input->begin<T>(); itr != input->end<T>(); itr++) {
|
||||
T sample = (*itr + shift);
|
||||
*itr = sample > 1 || sample < -1 ? (sample > 1 ? 1 : -1) : sample;
|
||||
}
|
||||
}
|
||||
*output = input;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Perform an IIR filter by evaluating difference equation.
|
||||
/// \param input/output: Tensor of shape <..., time>
|
||||
/// \param a_coeffs: denominator coefficients of difference equation of dimension of (n_order + 1).
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/audio/kernels/dc_shift_op.h"
|
||||
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
#include "minddata/dataset/kernels/data/data_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
Status DCShiftOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
// input <..., time>.
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input->Rank() > 0, "ComplexNorm: input tensor is not in shape of <..., time>.");
|
||||
// If datatype is not a numeric type, then we cannot deal with the data.
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
input->type().IsNumeric(),
|
||||
"DCShift: input tensor type should be int, float or double, but got: " + input->type().ToString());
|
||||
if (input->type() == DataType(DataType::DE_FLOAT64)) {
|
||||
return DCShift<double>(input, output, shift_, limiter_gain_);
|
||||
} else {
|
||||
std::shared_ptr<Tensor> tmp;
|
||||
TypeCast(input, &tmp, DataType(DataType::DE_FLOAT32));
|
||||
return DCShift<float>(tmp, output, shift_, limiter_gain_);
|
||||
}
|
||||
}
|
||||
|
||||
Status DCShiftOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
|
||||
RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
inputs[0].IsNumeric(),
|
||||
"DCShift: input tensor type should be int, float or double, but got: " + inputs[0].ToString());
|
||||
if (inputs[0] == DataType(DataType::DE_FLOAT64)) {
|
||||
outputs[0] = DataType(DataType::DE_FLOAT64);
|
||||
} else {
|
||||
outputs[0] = DataType(DataType::DE_FLOAT32);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DC_SHIFT_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DC_SHIFT_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class DCShiftOp : public TensorOp {
|
||||
public:
|
||||
DCShiftOp(float shift, float limiter_gain) : shift_(shift), limiter_gain_(limiter_gain) {}
|
||||
|
||||
~DCShiftOp() override = default;
|
||||
|
||||
void Print(std::ostream &out) const override {
|
||||
out << Name() << ":: shift: " << shift_ << ", limiter_gain: " << limiter_gain_;
|
||||
}
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
std::string Name() const override { return kDCShiftOp; }
|
||||
|
||||
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
|
||||
|
||||
protected:
|
||||
float shift_;
|
||||
float limiter_gain_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DC_SHIFT_OP_H_
|
|
@ -230,6 +230,32 @@ class Contrast final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Apply a DC shift to the audio.
|
||||
class DCShift : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
/// \param[in] shift Indicates the amount to shift the audio, the value must be in the range [-2.0, 2.0].
|
||||
/// \param[in] limiter_gain Used only on peaks to prevent clipping.
|
||||
DCShift(float shift, float limiter_gain);
|
||||
|
||||
/// \brief Constructor
|
||||
/// \param[in] shift Indicates the amount to shift the audio.
|
||||
/// \note This constructor will use `shift` as `limiter_gain`.
|
||||
explicit DCShift(float shift);
|
||||
|
||||
/// \brief Destructor.
|
||||
~DCShift() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
||||
private:
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Design two-pole deemph filter. Similar to SoX implementation.
|
||||
class DeemphBiquad final : public TensorTransform {
|
||||
public:
|
||||
|
|
|
@ -148,6 +148,7 @@ constexpr char kBandrejectBiquadOp[] = "BandrejectBiquadOp";
|
|||
constexpr char kBassBiquadOp[] = "BassBiquadOp";
|
||||
constexpr char kComplexNormOp[] = "ComplexNormOp";
|
||||
constexpr char kContrastOp[] = "ContrastOp";
|
||||
constexpr char kDCShiftOp[] = "DCShiftOp";
|
||||
constexpr char kDeemphBiquadOp[] = "DeemphBiquadOp";
|
||||
constexpr char kEqualizerBiquadOp[] = "EqualizerBiquadOp";
|
||||
constexpr char kFrequencyMaskingOp[] = "FrequencyMaskingOp";
|
||||
|
|
|
@ -25,9 +25,9 @@ import mindspore._c_dataengine as cde
|
|||
from ..transforms.c_transforms import TensorOperation
|
||||
from .utils import ScaleType
|
||||
from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
|
||||
check_bandreject_biquad, check_bass_biquad, check_complex_norm, check_contrast, check_deemph_biquad, \
|
||||
check_equalizer_biquad, check_highpass_biquad, check_lfilter, check_lowpass_biquad, check_masking, \
|
||||
check_mu_law_decoding, check_time_stretch
|
||||
check_bandreject_biquad, check_bass_biquad, check_complex_norm, check_contrast, check_dc_shift, \
|
||||
check_deemph_biquad, check_equalizer_biquad, check_highpass_biquad, check_lfilter, check_lowpass_biquad, \
|
||||
check_masking, check_mu_law_decoding, check_time_stretch
|
||||
|
||||
|
||||
class AudioTensorOperation(TensorOperation):
|
||||
|
@ -295,6 +295,33 @@ class Contrast(AudioTensorOperation):
|
|||
return cde.ContrastOperation(self.enhancement_amount)
|
||||
|
||||
|
||||
class DCShift(AudioTensorOperation):
|
||||
"""
|
||||
Apply a DC shift to the audio.
|
||||
|
||||
Args:
|
||||
shift (float): The amount to shift the audio, the value must be in the range [-2.0, 2.0].
|
||||
limiter_gain (float, optional): Used only on peaks to prevent clipping,
|
||||
the value should be much less than 1, such as 0.05 or 0.02.
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.array([0.60, 0.97, -1.04, -1.26, 0.97, 0.91, 0.48, 0.93])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.DCShift(0.5, 0.02)]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operation=transforms, input_columns=["audio"])
|
||||
"""
|
||||
|
||||
@check_dc_shift
|
||||
def __init__(self, shift, limiter_gain=None):
|
||||
self.shift = shift
|
||||
self.limiter_gain = limiter_gain if limiter_gain else shift
|
||||
|
||||
def parse(self):
|
||||
return cde.DCShiftOperation(self.shift, self.limiter_gain)
|
||||
|
||||
|
||||
class DeemphBiquad(AudioTensorOperation):
|
||||
"""
|
||||
Design two-pole deemph filter for audio waveform of dimension of (..., time).
|
||||
|
|
|
@ -201,6 +201,20 @@ def check_contrast(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_dc_shift(method):
|
||||
"""Wrapper method to check the parameters of DCShift."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[shift, limiter_gain], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(shift, (float, int), "shift")
|
||||
check_value(shift, [-2.0, 2.0], "shift")
|
||||
if limiter_gain is not None:
|
||||
type_check(limiter_gain, (float, int), "limiter_gain")
|
||||
return method(self, *args, **kwargs)
|
||||
return new_method
|
||||
|
||||
|
||||
def check_deemph_biquad(method):
|
||||
"""Wrapper method to check the parameters of CutMixBatch."""
|
||||
|
||||
|
|
|
@ -1016,3 +1016,54 @@ TEST_F(MindDataTestPipeline, TestLfilterWrongArgs) {
|
|||
std::shared_ptr<Iterator> iter01 = ds01->CreateIterator();
|
||||
EXPECT_EQ(iter01, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestDCShiftPipeline) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDCShiftPipeline.";
|
||||
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {1, 2, 100}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(50, schema);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto dc_shift_op = audio::DCShift(0.8, 0.02);
|
||||
|
||||
ds = ds->Map({dc_shift_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
std::vector<int64_t> expected = {1, 2, 100};
|
||||
|
||||
int i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto col = row["waveform"];
|
||||
ASSERT_EQ(col.Shape(), expected);
|
||||
ASSERT_EQ(col.Shape().size(), 3);
|
||||
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
EXPECT_EQ(i, 50);
|
||||
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestDCShiftPipelineError) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDCShiftPipelineError.";
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {100}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(4, schema);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto dc_shift_op = audio::DCShift(3, 0.02);
|
||||
|
||||
ds = ds->Map({dc_shift_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
|
|
@ -913,3 +913,17 @@ TEST_F(MindDataTestExecute, TestLFilterWithWrongArg) {
|
|||
Status s01 = Transform01(input_02, &input_02);
|
||||
EXPECT_FALSE(s01.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestExecute, TestDCShiftEager) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestDCShiftEager.";
|
||||
|
||||
std::vector<float> origin = {0.67443, 1.87523, 0.73465, -0.74553, -1.54346, 1.54093, -1.23453};
|
||||
std::shared_ptr<Tensor> de_tensor;
|
||||
Tensor::CreateFromVector(origin, &de_tensor);
|
||||
|
||||
std::shared_ptr<TensorTransform> dc_shift = std::make_shared<audio::DCShift>(0.5, 0.02);
|
||||
auto input = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(de_tensor));
|
||||
mindspore::dataset::Execute Transform({dc_shift});
|
||||
Status s = Transform(input, &input);
|
||||
ASSERT_TRUE(s.IsOk());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.audio.transforms as a_c_trans
|
||||
|
||||
|
||||
def count_unequal_element(data_expected, data_me, rtol, atol):
|
||||
assert data_expected.shape == data_me.shape
|
||||
total_count = len(data_expected.flatten())
|
||||
error = np.abs(data_expected - data_me)
|
||||
greater = np.greater(error, atol + np.abs(data_expected) * rtol)
|
||||
loss_count = np.count_nonzero(greater)
|
||||
assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
|
||||
data_expected[greater], data_me[greater], error[greater])
|
||||
|
||||
|
||||
def test_func_dc_shift_eager():
|
||||
"""
|
||||
Eager Test
|
||||
"""
|
||||
arr = np.array([0.60, 0.97, -1.04, -1.26, 0.97, 0.91, 0.48, 0.93, 0.71, 0.61], dtype=np.double)
|
||||
expected = np.array([0.0400, 0.0400, -0.0400, -0.2600, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400, 0.0400],
|
||||
dtype=np.double)
|
||||
dcshift_op = a_c_trans.DCShift(1.0, 0.04)
|
||||
output = dcshift_op(arr)
|
||||
count_unequal_element(expected, output, 0.0001, 0.0001)
|
||||
|
||||
|
||||
def test_func_dc_shift_pipeline():
|
||||
"""
|
||||
Pipeline Test
|
||||
"""
|
||||
arr = np.array([[1.14, -1.06, 0.94, 0.90], [-1.11, 1.40, -0.33, 1.43]], dtype=np.double)
|
||||
expected = np.array([[0.2300, -0.2600, 0.2300, 0.2300], [-0.3100, 0.2300, 0.4700, 0.2300]], dtype=np.double)
|
||||
dataset = ds.NumpySlicesDataset(arr, column_names=["col1"], shuffle=False)
|
||||
dcshift_op = a_c_trans.DCShift(0.8, 0.03)
|
||||
dataset = dataset.map(operations=dcshift_op, input_columns=["col1"])
|
||||
for item1, item2 in zip(dataset.create_dict_iterator(output_numpy=True), expected):
|
||||
count_unequal_element(item2, item1['col1'], 0.0001, 0.0001)
|
||||
|
||||
|
||||
def test_func_dc_shift_pipeline_error():
|
||||
"""
|
||||
Pipeline Error Test
|
||||
"""
|
||||
arr = np.random.uniform(-2, 2, size=(1000)).astype(np.float)
|
||||
label = np.random.sample((1000, 1))
|
||||
data = (arr, label)
|
||||
dataset = ds.NumpySlicesDataset(data, column_names=["col1", "col2"], shuffle=False)
|
||||
num_itr = 0
|
||||
with pytest.raises(ValueError, match=r"Input shift is not within the required interval of \[-2.0, 2.0\]."):
|
||||
dcshift_op = a_c_trans.DCShift(2.5, 0.03)
|
||||
dataset = dataset.map(operations=dcshift_op, input_columns=["col1"])
|
||||
for _ in dataset.create_dict_iterator(output_numpy=True):
|
||||
num_itr += 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_func_dc_shift_eager()
|
||||
test_func_dc_shift_pipeline()
|
||||
test_func_dc_shift_pipeline_error()
|
Loading…
Reference in New Issue