forked from mindspore-Ecosystem/mindspore
[feat][assistant][I3J6UU] add new audio operator Overdrive
This commit is contained in:
parent
58a2a02400
commit
9b7d7be65c
|
@ -38,6 +38,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/magphase_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/overdrive_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/riaa_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
|
||||
|
@ -400,6 +401,19 @@ std::shared_ptr<TensorOperation> MuLawDecoding::Parse() {
|
|||
return std::make_shared<MuLawDecodingOperation>(data_->quantization_channels_);
|
||||
}
|
||||
|
||||
// Overdrive Transform Operation.
|
||||
// Parameter bundle held by the Overdrive transform behind its data_ pointer.
struct Overdrive::Data {
  Data(float gain, float color) : gain_{gain}, color_{color} {}
  // Gain value captured at construction.
  float gain_;
  // Color value captured at construction.
  float color_;
};
|
||||
|
||||
Overdrive::Overdrive(float gain, float color) : data_(std::make_shared<Data>(gain, color)) {}
|
||||
|
||||
std::shared_ptr<TensorOperation> Overdrive::Parse() {
|
||||
return std::make_shared<OverdriveOperation>(data_->gain_, data_->color_);
|
||||
}
|
||||
|
||||
// RiaaBiquad Transform Operation.
|
||||
struct RiaaBiquad::Data {
|
||||
explicit Data(int32_t sample_rate) : sample_rate_(sample_rate) {}
|
||||
|
|
|
@ -42,6 +42,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/magphase_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/overdrive_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/riaa_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
|
||||
|
@ -325,6 +326,17 @@ PYBIND_REGISTER(
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(OverdriveOperation, 1, ([](const py::module *m) {
                  using OverdriveIr = audio::OverdriveOperation;
                  // Expose the IR node to Python. The factory validates the parameters right after
                  // construction and raises through THROW_IF_ERROR when validation fails.
                  (void)py::class_<OverdriveIr, TensorOperation, std::shared_ptr<OverdriveIr>>(*m,
                                                                                             "OverdriveOperation")
                    .def(py::init([](float gain, float color) {
                      auto op = std::make_shared<OverdriveIr>(gain, color);
                      THROW_IF_ERROR(op->ValidateParams());
                      return op;
                    }));
                }));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
RiaaBiquadOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<audio::RiaaBiquadOperation, TensorOperation, std::shared_ptr<audio::RiaaBiquadOperation>>(
|
||||
|
|
|
@ -24,6 +24,7 @@ add_library(audio-ir-kernels OBJECT
|
|||
lowpass_biquad_ir.cc
|
||||
magphase_ir.cc
|
||||
mu_law_decoding_ir.cc
|
||||
overdrive_ir.cc
|
||||
riaa_biquad_ir.cc
|
||||
time_masking_ir.cc
|
||||
time_stretch_ir.cc
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/audio/ir/kernels/overdrive_ir.h"
|
||||
|
||||
#include "minddata/dataset/audio/kernels/overdrive_op.h"
|
||||
#include "minddata/dataset/kernels/ir/validators.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
OverdriveOperation::OverdriveOperation(float gain, float color) : gain_(gain), color_(color) {}
|
||||
|
||||
// Check gain first, then color, against the range {0, 100}; the first failing check is returned.
Status OverdriveOperation::ValidateParams() {
  const std::vector<float> valid_range = {0.0f, 100.0f};
  RETURN_IF_NOT_OK(ValidateScalar("Overdrive", "gain", gain_, valid_range, false, false));
  RETURN_IF_NOT_OK(ValidateScalar("Overdrive", "color", color_, valid_range, false, false));
  return Status::OK();
}
|
||||
|
||||
std::shared_ptr<TensorOp> OverdriveOperation::Build() {
|
||||
std::shared_ptr<OverdriveOp> tensor_op = std::make_shared<OverdriveOp>(gain_, color_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
// Write the parameters as a JSON object with the keys "gain" and "color" into *out_json.
Status OverdriveOperation::to_json(nlohmann::json *out_json) {
  nlohmann::json serialized;
  serialized["gain"] = gain_;
  serialized["color"] = color_;
  // Overwrites whatever *out_json held before.
  *out_json = serialized;
  return Status::OK();
}
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,57 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_OVERDRIVE_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_OVERDRIVE_IR_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/include/dataset/transforms.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
constexpr char kOverdriveOperation[] = "Overdrive";
|
||||
|
||||
class OverdriveOperation : public TensorOperation {
|
||||
public:
|
||||
explicit OverdriveOperation(float gain, float color);
|
||||
|
||||
~OverdriveOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
std::string Name() const override { return kOverdriveOperation; }
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
float gain_;
|
||||
float color_;
|
||||
};
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_OVERDRIVE_IR_H_
|
|
@ -25,6 +25,7 @@ add_library(audio-kernels OBJECT
|
|||
lowpass_biquad_op.cc
|
||||
magphase_op.cc
|
||||
mu_law_decoding_op.cc
|
||||
overdrive_op.cc
|
||||
riaa_biquad_op.cc
|
||||
time_masking_op.cc
|
||||
time_stretch_op.cc
|
||||
|
|
|
@ -318,6 +318,80 @@ Status ComplexNorm(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor>
|
|||
/// \return Status code.
|
||||
Status MuLawDecoding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int quantization_channels);
|
||||
|
||||
/// \brief Apply a overdrive effect to the audio.
|
||||
/// \param input Tensor of shape <..., time>.
|
||||
/// \param output Tensor of shape <..., time>.
|
||||
/// \param gain Coefficient of overload in dB.
|
||||
/// \param color Coefficient of translation.
|
||||
/// \return Status code.
|
||||
template <typename T>
|
||||
Status Overdrive(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, float gain, float color) {
|
||||
TensorShape input_shape = input->shape();
|
||||
// input->2D.
|
||||
auto rows = input->Size() / input_shape[-1];
|
||||
auto cols = input_shape[-1];
|
||||
TensorShape to_shape({rows, cols});
|
||||
RETURN_IF_NOT_OK(input->Reshape(to_shape));
|
||||
// apply dB2Linear on gain, 20dB is expect to gain.
|
||||
float gain_ex = exp(gain * log(10) / 20.0);
|
||||
color = color / 200;
|
||||
// declare the array used to store the input.
|
||||
std::vector<T> input_vec;
|
||||
// out_vec is used to save the result of applying overdrive.
|
||||
std::vector<T> out_vec;
|
||||
// store intermediate results of input.
|
||||
std::vector<T> temp;
|
||||
// scale and pan the input two-dimensional sound wave array to a certain extent.
|
||||
for (auto itr = input->begin<T>(); itr != input->end<T>(); itr++) {
|
||||
// store the value of traverse the input.
|
||||
T temp_fp = *itr;
|
||||
input_vec.push_back(temp_fp);
|
||||
// use 0 to initialize out_vec.
|
||||
out_vec.push_back(0);
|
||||
T temp_fp2 = temp_fp * gain_ex + color;
|
||||
// 0.5 + 2/3 * 0.75 = 1, zoom and shift the sound.
|
||||
if (temp_fp2 < -1) {
|
||||
temp.push_back(-2.0 / 3.0);
|
||||
} else if (temp_fp2 > 1) {
|
||||
temp.push_back(2.0 / 3.0);
|
||||
} else {
|
||||
temp.push_back(temp_fp2 - temp_fp2 * temp_fp2 * temp_fp2 / 3.0);
|
||||
}
|
||||
}
|
||||
// last_in and last_out are the intermediate values for processing each moment.
|
||||
std::vector<T> last_in;
|
||||
std::vector<T> last_out;
|
||||
for (size_t i = 0; i < cols; i++) {
|
||||
last_in.push_back(0.0);
|
||||
last_out.push_back(0.0);
|
||||
}
|
||||
// overdrive core loop.
|
||||
for (size_t i = 0; i < cols; i++) {
|
||||
size_t index = 0;
|
||||
// calculate the value of each moment according to the rules of overdrive.
|
||||
for (size_t j = i; j < rows * cols; j += cols, index++) {
|
||||
// 0.995 is the preservation ratio of sound waves.
|
||||
last_out[index] = temp[j] - last_in[index] + last_out[index] * 0.995;
|
||||
last_in[index] = temp[j];
|
||||
// 0.5 + 2/3 * 0.75 = 1, zoom and shift the sound.
|
||||
T temp_fp = input_vec[j] * 0.5 + last_out[index] * 0.75;
|
||||
// clamp min=-1, max=1.
|
||||
if (temp_fp < -1) {
|
||||
out_vec[j] = -1.0;
|
||||
} else if (temp_fp > 1) {
|
||||
out_vec[j] = 1.0;
|
||||
} else {
|
||||
out_vec[j] = temp_fp;
|
||||
}
|
||||
}
|
||||
}
|
||||
// move data to output tensor.
|
||||
std::shared_ptr<Tensor> out;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromVector(out_vec, input_shape, &out));
|
||||
*output = out;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Add a fade in and/or fade out to an input.
|
||||
/// \param[in] input: The input tensor.
|
||||
/// \param[out] output: Added fade in and/or fade out audio with the same shape.
|
||||
|
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/audio/kernels/overdrive_op.h"
|
||||
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
#include "minddata/dataset/kernels/data/data_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
OverdriveOp::OverdriveOp(float gain, float color) : gain_(gain), color_(color) {}
|
||||
|
||||
// Validate the input, then run the Overdrive kernel: float64 input is processed as double,
// every other numeric type is first cast to float32.
Status OverdriveOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);
  // the input needs at least one dimension.
  CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() > 0, "Overdrive: input tensor is not in shape of <..., time>.");
  // only numeric element types are accepted.
  CHECK_FAIL_RETURN_UNEXPECTED(
    input->type().IsNumeric(),
    "Overdrive: input tensor type should be int, float or double, but got: " + input->type().ToString());

  if (input->type() != DataType::DE_FLOAT64) {
    std::shared_ptr<Tensor> as_float;
    RETURN_IF_NOT_OK(TypeCast(input, &as_float, DataType(DataType::DE_FLOAT32)));
    return Overdrive<float>(as_float, output, gain_, color_);
  }
  return Overdrive<double>(input, output, gain_, color_);
}
|
||||
|
||||
// Report float64 for float64 input and float32 for any other numeric input; non-numeric input is an error.
Status OverdriveOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
  RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
  const DataType &in_type = inputs[0];
  if (!in_type.IsNumeric()) {
    RETURN_STATUS_UNEXPECTED("Overdrive: input tensor type should be int, float or double, but got: " +
                             in_type.ToString());
  }
  if (in_type == DataType(DataType::DE_FLOAT64)) {
    outputs[0] = DataType(DataType::DE_FLOAT64);
    return Status::OK();
  }
  outputs[0] = DataType(DataType::DE_FLOAT32);
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_OVERDRIVE_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_OVERDRIVE_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class OverdriveOp : public TensorOp {
|
||||
public:
|
||||
explicit OverdriveOp(float gain, float color);
|
||||
|
||||
~OverdriveOp() override = default;
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
std::string Name() const override { return kOverdriveOp; }
|
||||
|
||||
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
|
||||
|
||||
private:
|
||||
float gain_;
|
||||
float color_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_OVERDRIVE_OP_H_
|
|
@ -543,6 +543,27 @@ class MuLawDecoding final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Overdrive TensorTransform.
|
||||
class Overdrive final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] gain Coefficient of overload in dB, in range of [0, 100] (Default: 20.0).
|
||||
/// \param[in] color Coefficient of translation, in range of [0, 100] (Default: 20.0).
|
||||
explicit Overdrive(float gain = 20.0f, float color = 20.0f);
|
||||
|
||||
/// \brief Destructor.
|
||||
~Overdrive() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
||||
private:
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Apply RIAA vinyl playback equalization.
|
||||
class RiaaBiquad final : public TensorTransform {
|
||||
public:
|
||||
|
|
|
@ -165,6 +165,7 @@ constexpr char kLFilterOp[] = "LFilterOp";
|
|||
constexpr char kLowpassBiquadOp[] = "LowpassBiquadOp";
|
||||
constexpr char kMagphaseOp[] = "MagphaseOp";
|
||||
constexpr char kMuLawDecodingOp[] = "MuLawDecodingOp";
|
||||
constexpr char kOverdriveOp[] = "OverdriveOp";
|
||||
constexpr char kRiaaBiquadOp[] = "RiaaBiquadOp";
|
||||
constexpr char kTimeMaskingOp[] = "TimeMaskingOp";
|
||||
constexpr char kTimeStretchOp[] = "TimeStretchOp";
|
||||
|
|
|
@ -28,7 +28,7 @@ from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_
|
|||
check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_contrast, check_dc_shift, \
|
||||
check_deemph_biquad, check_detect_pitch_frequency, check_equalizer_biquad, check_fade, check_flanger, \
|
||||
check_highpass_biquad, check_lfilter, check_lowpass_biquad, check_magphase, check_masking, check_mu_law_decoding, \
|
||||
check_riaa_biquad, check_time_stretch, check_treble_biquad, check_vol
|
||||
check_overdrive, check_riaa_biquad, check_time_stretch, check_treble_biquad, check_vol
|
||||
|
||||
|
||||
class AudioTensorOperation(TensorOperation):
|
||||
|
@ -719,6 +719,33 @@ class MuLawDecoding(AudioTensorOperation):
|
|||
return cde.MuLawDecodingOperation(self.quantization_channels)
|
||||
|
||||
|
||||
class Overdrive(AudioTensorOperation):
    """
    Apply an overdrive effect to the input audio.

    Args:
        gain (float): Desired gain at the boost (or attenuation) in dB, in range of [0, 100] (default=20.0).
        color (float): Controls the amount of even harmonic content in the over-driven output,
            in range of [0, 100] (default=20.0).

    Examples:
        >>> import numpy as np
        >>>
        >>> waveform = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
        >>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
        >>> transforms = [audio.Overdrive()]
        >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
    """

    @check_overdrive
    def __init__(self, gain=20.0, color=20.0):
        # Arguments are checked by the check_overdrive decorator before they are stored.
        self.gain, self.color = gain, color

    def parse(self):
        # Hand the stored parameters to the C++ IR node.
        return cde.OverdriveOperation(self.gain, self.color)
|
||||
|
||||
|
||||
class RiaaBiquad(AudioTensorOperation):
|
||||
"""
|
||||
Apply RIAA vinyl playback equalization. Similar to SoX implementation.
|
||||
|
|
|
@ -292,6 +292,21 @@ def check_mu_law_decoding(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_overdrive(method):
    """Wrapper method to check the `gain` and `color` parameters of Overdrive."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [gain, color], _ = parse_user_args(method, *args, **kwargs)
        # Each argument must be a float or an int and must pass the [0, 100] range check.
        for value, arg_name in ((gain, "gain"), (color, "color")):
            type_check(value, (float, int), arg_name)
            check_value(value, [0, 100], arg_name)
        return method(self, *args, **kwargs)

    return new_method
|
||||
|
||||
|
||||
def check_riaa_biquad(method):
|
||||
"""Wrapper method to check the parameters of RiaaBiquad."""
|
||||
|
||||
|
|
|
@ -955,6 +955,77 @@ TEST_F(MindDataTestPipeline, TestMuLawDecodingWrongArgs) {
|
|||
EXPECT_EQ(iter1, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: Overdrive
|
||||
/// Description: test basic usage of Overdrive
|
||||
/// Expectation: get correct number of data
|
||||
TEST_F(MindDataTestPipeline, TestOverdriveBasic) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestOverdriveBasic.";
|
||||
// Original waveform
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 200}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(50, schema);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
ds = ds->SetNumWorkers(4);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto OverdriveOp = audio::Overdrive();
|
||||
|
||||
ds = ds->Map({OverdriveOp});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Apply a phasing effect to the audio
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
std::vector<int64_t> expected = {2, 200};
|
||||
|
||||
int i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto col = row["waveform"];
|
||||
ASSERT_EQ(col.Shape(), expected);
|
||||
ASSERT_EQ(col.Shape().size(), 2);
|
||||
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
EXPECT_EQ(i, 50);
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: Overdrive
|
||||
/// Description: test invalid parameter of Overdrive
|
||||
/// Expectation: throw exception correctly
|
||||
TEST_F(MindDataTestPipeline, TestOverdriveWrongArg) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestOverdriveWrongArg.";
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
// Original waveform
|
||||
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 2}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(50, schema);
|
||||
std::shared_ptr<Dataset> ds01;
|
||||
std::shared_ptr<Dataset> ds02;
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Check gain out of range [0,100]
|
||||
MS_LOG(INFO) << "gain is less than 0.";
|
||||
auto overdrive_op_01 = audio::Overdrive(-0.2, 20.0);
|
||||
ds01 = ds->Map({overdrive_op_01});
|
||||
EXPECT_NE(ds01, nullptr);
|
||||
std::shared_ptr<Iterator> iter01 = ds01->CreateIterator();
|
||||
EXPECT_EQ(iter01, nullptr);
|
||||
|
||||
// Check color out of range [0,100]
|
||||
MS_LOG(INFO) << "color is greater than 100.";
|
||||
auto overdrive_op_02 = audio::Overdrive(20.0, 102.3);
|
||||
ds02 = ds->Map({overdrive_op_02});
|
||||
EXPECT_NE(ds02, nullptr);
|
||||
std::shared_ptr<Iterator> iter02 = ds02->CreateIterator();
|
||||
EXPECT_EQ(iter02, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestLfilterPipeline) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLfilterPipeline.";
|
||||
// Original waveform
|
||||
|
|
|
@ -876,6 +876,56 @@ TEST_F(MindDataTestExecute, TestMuLawDecodingEager) {
|
|||
EXPECT_TRUE(s01.IsOk());
|
||||
}
|
||||
|
||||
/// Feature: Overdrive
|
||||
/// Description: test basic usage of Overdrive
|
||||
/// Expectation: get correct number of data
|
||||
TEST_F(MindDataTestExecute, TestOverdriveBasicWithEager) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestOverdriveBasicWithEager.";
|
||||
// Original waveform
|
||||
std::vector<float> labels = {
|
||||
2.716064453125000000e-03, 6.347656250000000000e-03, 9.246826171875000000e-03, 1.089477539062500000e-02,
|
||||
1.138305664062500000e-02, 1.156616210937500000e-02, 1.394653320312500000e-02, 1.550292968750000000e-02,
|
||||
1.614379882812500000e-02, 1.840209960937500000e-02, 1.718139648437500000e-02, 1.599121093750000000e-02,
|
||||
1.647949218750000000e-02, 1.510620117187500000e-02, 1.385498046875000000e-02, 1.345825195312500000e-02,
|
||||
1.419067382812500000e-02, 1.284790039062500000e-02, 1.052856445312500000e-02, 9.368896484375000000e-03};
|
||||
std::shared_ptr<Tensor> input;
|
||||
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 10}), &input));
|
||||
auto input_02 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||
std::shared_ptr<TensorTransform> phaser_op_01 = std::make_shared<audio::Overdrive>(5.0, 3.0);
|
||||
mindspore::dataset::Execute Transform01({phaser_op_01});
|
||||
Status s01 = Transform01(input_02, &input_02);
|
||||
EXPECT_TRUE(s01.IsOk());
|
||||
}
|
||||
|
||||
/// Feature: Overdrive
|
||||
/// Description: test invalid parameter of Overdrive
|
||||
/// Expectation: throw exception correctly
|
||||
TEST_F(MindDataTestExecute, TestOverdriveWrongArgWithEager) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestOverdriveWrongArgWithEager";
|
||||
std::vector<double> labels = {
|
||||
0.271, 1.634, 9.246, 0.108,
|
||||
1.138, 1.156, 3.394, 1.55,
|
||||
3.614, 1.8402, 0.718, 4.599,
|
||||
5.64, 2.510620117187500000e-02, 1.38, 5.825,
|
||||
4.1906, 5.28, 1.052, 9.36};
|
||||
std::shared_ptr<Tensor> input;
|
||||
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({4, 5}), &input));
|
||||
|
||||
// verify the gain range from 0 to 100
|
||||
auto input_01 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||
std::shared_ptr<TensorTransform> overdrive_op1 = std::make_shared<audio::Overdrive>(100.1);
|
||||
mindspore::dataset::Execute Transform01({overdrive_op1});
|
||||
Status s01 = Transform01(input_01, &input_01);
|
||||
EXPECT_FALSE(s01.IsOk());
|
||||
|
||||
// verify the color range from 0 to 100
|
||||
auto input_02 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||
std::shared_ptr<TensorTransform> overdrive_op2 = std::make_shared<audio::Overdrive>(5.0, 100.1);
|
||||
mindspore::dataset::Execute Transform02({overdrive_op2});
|
||||
Status s02 = Transform02(input_02, &input_02);
|
||||
EXPECT_FALSE(s02.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestExecute, TestRiaaBiquadWithEager) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestRiaaBiquadWithEager.";
|
||||
// Original waveform
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.audio.transforms as audio
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
def count_unequal_element(data_expected, data_me, rtol, atol):
    """
    Assert that `data_me` matches `data_expected` within tolerance.

    An element counts as mismatched when its absolute error exceeds
    `atol + abs(expected) * rtol`. The check fails when the shapes differ or
    when the fraction of mismatched elements is not below `rtol`.
    """
    assert data_expected.shape == data_me.shape
    abs_diff = np.abs(data_expected - data_me)
    tolerance = atol + np.abs(data_expected) * rtol
    mismatched = abs_diff > tolerance
    mismatch_ratio = np.count_nonzero(mismatched) / data_expected.size
    assert mismatch_ratio < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
        data_expected[mismatched], data_me[mismatched], abs_diff[mismatched])
|
||||
|
||||
|
||||
def test_overdrive_eager():
    """
    Feature: Overdrive
    Description: test Overdrive in eager mode
    Expectation: the results are as expected
    """
    # Input waveform and the output expected from Overdrive with its default parameters
    waveform = np.array([[1.47, 4.722, 5.863], [0.492, 0.235, 0.56]], dtype=np.float32)
    expect_waveform = np.array([[1., 1., 1.],
                                [0.74600005, 0.615, 0.77501255]], dtype=np.float32)
    # Call the op directly on the numpy array and compare with the expectation
    output = audio.Overdrive()(waveform)
    count_unequal_element(expect_waveform, output, 0.0001, 0.0001)
|
||||
|
||||
|
||||
def test_overdrive_pipeline():
    """
    Feature: Overdrive
    Description: test Overdrive in pipeline mode
    Expectation: the results are as expected
    """
    # Original waveform
    waveform = np.array([[0.1, 0.2], [0.4, 2.6]], dtype=np.float32)
    # Expect waveform
    expect_waveform = np.array([[0.29598799, 0.52081579],
                                [0.7, 1.]], dtype=np.float32)
    dataset = ds.NumpySlicesDataset(waveform, ["waveform"], shuffle=False)
    overdrive_op = audio.Overdrive(10.0, 5.0)
    # Filtered waveform by overdrive
    dataset = dataset.map(
        input_columns=["waveform"], operations=overdrive_op)
    num_rows = 0
    for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        count_unequal_element(expect_waveform[num_rows, :],
                              item['waveform'], 0.0001, 0.0001)
        num_rows += 1
    # Without this check a pipeline that yields no rows would pass without comparing anything.
    assert num_rows == expect_waveform.shape[0]
|
||||
|
||||
|
||||
def test_overdrive_invalid_input():
    """
    Feature: Overdrive
    Description: test invalid parameter of Overdrive
    Expectation: catch exceptions correctly
    """
    def test_invalid_input(test_name, gain, color, error, error_msg):
        logger.info("Test Overdrive with bad input: {0}".format(test_name))
        with pytest.raises(error) as error_info:
            audio.Overdrive(gain, color)
        assert error_msg in str(error_info.value)

    # (test name, gain, color, expected exception, expected message fragment)
    bad_cases = (
        ("invalid gain parameter type as a str", "20", 20, TypeError,
         "Argument gain with value 20 is not of type [<class 'float'>, <class 'int'>],"
         + " but got <class 'str'>."),
        ("invalid color parameter type as a str", 10, "5", TypeError,
         "Argument color with value 5 is not of type [<class 'float'>, <class 'int'>],"
         + " but got <class 'str'>."),
        ("invalid gain out of range [0, 100]", 100.23, 5.0, ValueError,
         "Input gain is not within the required interval of [0, 100]."),
        ("invalid color out of range [0, 100]", 30, -0.333, ValueError,
         "Input color is not within the required interval of [0, 100]."),
    )
    for case in bad_cases:
        test_invalid_input(*case)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the three tests in their original order when the file is executed as a script.
    for run_case in (test_overdrive_eager, test_overdrive_pipeline, test_overdrive_invalid_input):
        run_case()
|
Loading…
Reference in New Issue