forked from mindspore-Ecosystem/mindspore
[feat][assistant][I3CEGF] add op fade
This commit is contained in:
parent
c0e821dc98
commit
1a3196b052
|
@ -29,6 +29,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/fade_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/lfilter_ir.h"
|
||||
|
@ -231,6 +232,22 @@ std::shared_ptr<TensorOperation> EqualizerBiquad::Parse() {
|
|||
return std::make_shared<EqualizerBiquadOperation>(data_->sample_rate_, data_->center_freq_, data_->gain_, data_->Q_);
|
||||
}
|
||||
|
||||
// Fade Transform Operation.
|
||||
// Parameter bundle held by the Fade transform until Parse() builds the IR node.
struct Fade::Data {
  Data(int32_t in_len, int32_t out_len, FadeShape shape)
      : fade_in_len_(in_len), fade_out_len_(out_len), fade_shape_(shape) {}

  int32_t fade_in_len_;   // Fade-in length given to the Fade constructor.
  int32_t fade_out_len_;  // Fade-out length given to the Fade constructor.
  FadeShape fade_shape_;  // Curve shape given to the Fade constructor.
};
|
||||
|
||||
// Store the three arguments in a shared Data object; no validation happens here.
Fade::Fade(int32_t fade_in_len, int32_t fade_out_len, FadeShape fade_shape)
    : data_(std::make_shared<Data>(fade_in_len, fade_out_len, fade_shape)) {}
|
||||
|
||||
std::shared_ptr<TensorOperation> Fade::Parse() {
|
||||
return std::make_shared<FadeOperation>(data_->fade_in_len_, data_->fade_out_len_, data_->fade_shape_);
|
||||
}
|
||||
|
||||
// FrequencyMasking Transform Operation.
|
||||
struct FrequencyMasking::Data {
|
||||
Data(bool iid_masks, int32_t frequency_mask_param, int32_t mask_start, float mask_value)
|
||||
|
|
|
@ -33,6 +33,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/fade_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/frequency_masking_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/lfilter_ir.h"
|
||||
|
@ -192,6 +193,26 @@ PYBIND_REGISTER(EqualizerBiquadOperation, 1, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
// Expose the C++ FadeShape enum to Python under the name "FadeShape".
// export_values() additionally publishes the DE_FADESHAPE_* names into the enclosing module scope.
PYBIND_REGISTER(FadeShape, 0, ([](const py::module *m) {
                  (void)py::enum_<FadeShape>(*m, "FadeShape", py::arithmetic())
                    .value("DE_FADESHAPE_LINEAR", FadeShape::kLinear)
                    .value("DE_FADESHAPE_EXPONENTIAL", FadeShape::kExponential)
                    .value("DE_FADESHAPE_LOGARITHMIC", FadeShape::kLogarithmic)
                    .value("DE_FADESHAPE_QUARTERSINE", FadeShape::kQuarterSine)
                    .value("DE_FADESHAPE_HALFSINE", FadeShape::kHalfSine)
                    .export_values();
                }));
|
||||
|
||||
// Expose audio::FadeOperation to Python as "FadeOperation".
// The Python-side constructor builds the IR node and runs ValidateParams() on it before returning;
// a non-OK status is handed to THROW_IF_ERROR.
PYBIND_REGISTER(FadeOperation, 1, ([](const py::module *m) {
                  (void)py::class_<audio::FadeOperation, TensorOperation, std::shared_ptr<audio::FadeOperation>>(
                    *m, "FadeOperation")
                    .def(py::init([](int fade_in_len, int fade_out_len, FadeShape fade_shape) {
                      auto fade = std::make_shared<audio::FadeOperation>(fade_in_len, fade_out_len, fade_shape);
                      THROW_IF_ERROR(fade->ValidateParams());
                      return fade;
                    }));
                }));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
FrequencyMaskingOperation, 1, ([](const py::module *m) {
|
||||
(void)
|
||||
|
|
|
@ -15,6 +15,7 @@ add_library(audio-ir-kernels OBJECT
|
|||
dc_shift_ir.cc
|
||||
deemph_biquad_ir.cc
|
||||
equalizer_biquad_ir.cc
|
||||
fade_ir.cc
|
||||
frequency_masking_ir.cc
|
||||
highpass_biquad_ir.cc
|
||||
lfilter_ir.cc
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/audio/ir/kernels/fade_ir.h"
|
||||
|
||||
#include "minddata/dataset/audio/ir/validators.h"
|
||||
#include "minddata/dataset/audio/kernels/fade_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
// Record the fade parameters as given; range checks are done separately in ValidateParams().
FadeOperation::FadeOperation(int32_t fade_in_len, int32_t fade_out_len, FadeShape fade_shape)
    : fade_in_len_(fade_in_len), fade_out_len_(fade_out_len), fade_shape_(fade_shape) {}
|
||||
|
||||
// Run ValidateIntScalarNonNegative on fade_in_len and fade_out_len, returning the first failure.
// fade_shape_ is not checked here.
Status FadeOperation::ValidateParams() {
  RETURN_IF_NOT_OK(ValidateIntScalarNonNegative("Fade", "fade_in_len", fade_in_len_));
  RETURN_IF_NOT_OK(ValidateIntScalarNonNegative("Fade", "fade_out_len", fade_out_len_));
  return Status::OK();
}
|
||||
|
||||
std::shared_ptr<TensorOp> FadeOperation::Build() {
|
||||
std::shared_ptr<FadeOp> tensor_op = std::make_shared<FadeOp>(fade_in_len_, fade_out_len_, fade_shape_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
// Write this node's three parameters into *out_json under the keys
// "fade_in_len", "fade_out_len" and "fade_shape".
Status FadeOperation::to_json(nlohmann::json *const out_json) {
  nlohmann::json serialized;
  serialized["fade_in_len"] = fade_in_len_;
  serialized["fade_out_len"] = fade_out_len_;
  serialized["fade_shape"] = fade_shape_;
  *out_json = serialized;
  return Status::OK();
}
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_FADE_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_FADE_IR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
|
||||
constexpr char kFadeOperation[] = "Fade";
|
||||
|
||||
/// \brief IR node for the Fade audio transform; holds the parameters and builds the FadeOp kernel.
class FadeOperation : public TensorOperation {
 public:
  /// \brief Constructor.
  /// \param[in] fade_in_len Length of the fade-in.
  /// \param[in] fade_out_len Length of the fade-out.
  /// \param[in] fade_shape Shape of the fade curve.
  explicit FadeOperation(int32_t fade_in_len, int32_t fade_out_len, FadeShape fade_shape);

  /// \brief Destructor.
  ~FadeOperation() = default;

  /// \brief Create the runtime tensor op for this node.
  /// \return Shared pointer to the created TensorOp.
  std::shared_ptr<TensorOp> Build() override;

  /// \brief Validate the stored parameters.
  /// \return Status of the validation.
  Status ValidateParams() override;

  /// \brief Get the name of this operation.
  /// \return The constant kFadeOperation.
  std::string Name() const override { return kFadeOperation; }

  /// \brief Get the arguments of node
  /// \param[out] out_json JSON string of all attributes
  /// \return Status of the function
  Status to_json(nlohmann::json *const out_json) override;

 private:
  int32_t fade_in_len_;   // Length of the fade-in.
  int32_t fade_out_len_;  // Length of the fade-out.
  FadeShape fade_shape_;  // Shape of the fade curve.
};
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_FADE_IR_H_
|
|
@ -16,6 +16,7 @@ add_library(audio-kernels OBJECT
|
|||
dc_shift_op.cc
|
||||
deemph_biquad_op.cc
|
||||
equalizer_biquad_op.cc
|
||||
fade_op.cc
|
||||
frequency_masking_op.cc
|
||||
highpass_biquad_op.cc
|
||||
lfilter_op.cc
|
||||
|
|
|
@ -42,7 +42,7 @@ Status Linspace(std::shared_ptr<Tensor> *output, T start, T end, int n) {
|
|||
n = std::isnan(n) ? 100 : n;
|
||||
TensorShape out_shape({n});
|
||||
std::vector<T> linear_vect(n);
|
||||
T interval = (end - start) / (n - 1);
|
||||
T interval = (n == 1) ? 0 : ((end - start) / (n - 1));
|
||||
for (int i = 0; i < linear_vect.size(); ++i) {
|
||||
linear_vect[i] = start + i * interval;
|
||||
}
|
||||
|
@ -509,5 +509,126 @@ Status MuLawDecoding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tenso
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status FadeIn(std::shared_ptr<Tensor> *output, int32_t fade_in_len, FadeShape fade_shape) {
|
||||
T start = 0;
|
||||
T end = 1;
|
||||
RETURN_IF_NOT_OK(Linspace<T>(output, start, end, fade_in_len));
|
||||
for (auto iter = (*output)->begin<T>(); iter != (*output)->end<T>(); iter++) {
|
||||
switch (fade_shape) {
|
||||
case FadeShape::kLinear:
|
||||
break;
|
||||
case FadeShape::kExponential:
|
||||
// Compute the scale factor of the exponential function, pow(2.0, *in_ter - 1.0) * (*in_ter)
|
||||
*iter = static_cast<T>(std::pow(2.0, *iter - 1.0) * (*iter));
|
||||
break;
|
||||
case FadeShape::kLogarithmic:
|
||||
// Compute the scale factor of the logarithmic function, log(*in_iter + 0.1) + 1.0
|
||||
*iter = static_cast<T>(std::log10(*iter + 0.1) + 1.0);
|
||||
break;
|
||||
case FadeShape::kQuarterSine:
|
||||
// Compute the scale factor of the quarter_sine function, sin((*in_iter - 1.0) * PI / 2.0)
|
||||
*iter = static_cast<T>(std::sin((*iter) * PI / 2.0));
|
||||
break;
|
||||
case FadeShape::kHalfSine:
|
||||
// Compute the scale factor of the half_sine function, sin((*in_iter) * PI - PI / 2.0) / 2.0 + 0.5
|
||||
*iter = static_cast<T>(std::sin((*iter) * PI - PI / 2.0) / 2.0 + 0.5);
|
||||
break;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status FadeOut(std::shared_ptr<Tensor> *output, int32_t fade_out_len, FadeShape fade_shape) {
|
||||
T start = 0;
|
||||
T end = 1;
|
||||
RETURN_IF_NOT_OK(Linspace<T>(output, start, end, fade_out_len));
|
||||
for (auto iter = (*output)->begin<T>(); iter != (*output)->end<T>(); iter++) {
|
||||
switch (fade_shape) {
|
||||
case FadeShape::kLinear:
|
||||
// In fade out, invert *out_iter
|
||||
*iter = static_cast<T>(1.0 - *iter);
|
||||
break;
|
||||
case FadeShape::kExponential:
|
||||
// Compute the scale factor of the exponential function
|
||||
*iter = static_cast<T>(std::pow(2.0, -*iter) * (1.0 - *iter));
|
||||
break;
|
||||
case FadeShape::kLogarithmic:
|
||||
// Compute the scale factor of the logarithmic function
|
||||
*iter = static_cast<T>(std::log10(1.1 - *iter) + 1.0);
|
||||
break;
|
||||
case FadeShape::kQuarterSine:
|
||||
// Compute the scale factor of the quarter_sine function
|
||||
*iter = static_cast<T>(std::sin((*iter) * PI / 2.0 + PI / 2.0));
|
||||
break;
|
||||
case FadeShape::kHalfSine:
|
||||
// Compute the scale factor of the half_sine function
|
||||
*iter = static_cast<T>(std::sin((*iter) * PI + PI / 2.0) / 2.0 + 0.5);
|
||||
break;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status Fade(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t fade_in_len,
|
||||
int32_t fade_out_len, FadeShape fade_shape) {
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromTensor(input, output));
|
||||
const TensorShape input_shape = input->shape();
|
||||
int32_t waveform_length = static_cast<int32_t>(input_shape[-1]);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(fade_in_len <= waveform_length, "Fade: fade_in_len exceeds waveform length.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(fade_out_len <= waveform_length, "Fade: fade_out_len exceeds waveform length.");
|
||||
int32_t num_waveform = static_cast<int32_t>(input->Size() / waveform_length);
|
||||
TensorShape toShape = TensorShape({num_waveform, waveform_length});
|
||||
RETURN_IF_NOT_OK((*output)->Reshape(toShape));
|
||||
TensorPtr fade_in;
|
||||
RETURN_IF_NOT_OK(FadeIn<T>(&fade_in, fade_in_len, fade_shape));
|
||||
TensorPtr fade_out;
|
||||
RETURN_IF_NOT_OK(FadeOut<T>(&fade_out, fade_out_len, fade_shape));
|
||||
|
||||
// Add fade in to input tensor
|
||||
auto output_iter = (*output)->begin<T>();
|
||||
for (auto fade_in_iter = fade_in->begin<T>(); fade_in_iter != fade_in->end<T>(); fade_in_iter++) {
|
||||
*output_iter = (*output_iter) * (*fade_in_iter);
|
||||
for (int32_t j = 1; j < num_waveform; j++) {
|
||||
output_iter += waveform_length;
|
||||
*output_iter = (*output_iter) * (*fade_in_iter);
|
||||
}
|
||||
output_iter -= ((num_waveform - 1) * waveform_length);
|
||||
++output_iter;
|
||||
}
|
||||
|
||||
// Add fade out to input tensor
|
||||
output_iter = (*output)->begin<T>();
|
||||
output_iter += (waveform_length - fade_out_len);
|
||||
for (auto fade_out_iter = fade_out->begin<T>(); fade_out_iter != fade_out->end<T>(); fade_out_iter++) {
|
||||
*output_iter = (*output_iter) * (*fade_out_iter);
|
||||
for (int32_t j = 1; j < num_waveform; j++) {
|
||||
output_iter += waveform_length;
|
||||
*output_iter = (*output_iter) * (*fade_out_iter);
|
||||
}
|
||||
output_iter -= ((num_waveform - 1) * waveform_length);
|
||||
++output_iter;
|
||||
}
|
||||
(*output)->Reshape(input_shape);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Dispatch Fade on the input element type: types in [DE_INT8, DE_FLOAT32] are cast to float32 and
// processed as float, DE_FLOAT64 is processed as double, and any other type is rejected.
Status Fade(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t fade_in_len,
            int32_t fade_out_len, FadeShape fade_shape) {
  const auto type_id = input->type().value();
  const bool processed_as_float = DataType::DE_INT8 <= type_id && type_id <= DataType::DE_FLOAT32;
  if (!processed_as_float && type_id != DataType::DE_FLOAT64) {
    RETURN_STATUS_UNEXPECTED("Fade: input tensor type should be int, float or double, but got: " +
                             input->type().ToString());
  }
  if (processed_as_float) {
    std::shared_ptr<Tensor> waveform;
    RETURN_IF_NOT_OK(TypeCast(input, &waveform, DataType(DataType::DE_FLOAT32)));
    RETURN_IF_NOT_OK(Fade<float>(waveform, output, fade_in_len, fade_out_len, fade_shape));
  } else {
    RETURN_IF_NOT_OK(Fade<double>(input, output, fade_in_len, fade_out_len, fade_shape));
  }
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -317,6 +317,14 @@ Status ComplexNorm(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor>
|
|||
/// \return Status code.
|
||||
Status MuLawDecoding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int quantization_channels);
|
||||
|
||||
/// \brief Add a fade in and/or fade out to an input.
/// \param[in] input: The input tensor.
/// \param[out] output: Added fade in and/or fade out audio with the same shape.
/// \param[in] fade_in_len: Length of fade-in (time frames).
/// \param[in] fade_out_len: Length of fade-out (time frames).
/// \param[in] fade_shape: Shape of fade.
/// \return Status code.
Status Fade(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t fade_in_len,
            int32_t fade_out_len, FadeShape fade_shape);
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/audio/kernels/fade_op.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Out-of-class definitions of FadeOp's static defaults: no fade-in, no fade-out, linear curve.
constexpr int32_t FadeOp::kFadeInLen = 0;
constexpr int32_t FadeOp::kFadeOutLen = 0;
constexpr FadeShape FadeOp::kFadeShape = FadeShape::kLinear;
|
||||
|
||||
// Apply the configured fade to one tensor.
// The shape must report Size() >= 2 and the element type must lie in [DE_INT8, DE_FLOAT64].
Status FadeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);
  CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "Fade: input tensor is not in shape of <..., time>.");
  CHECK_FAIL_RETURN_UNEXPECTED(
    DataType::DE_INT8 <= input->type().value() && input->type().value() <= DataType::DE_FLOAT64,
    "Fade: input tensor type should be int, float or double, but got: " + input->type().ToString());

  const bool nothing_to_fade = (fade_in_len_ == 0 && fade_out_len_ == 0);
  if (nothing_to_fade) {
    // NOTE(review): the input is handed back unchanged, so its dtype is kept here, whereas the
    // Fade() path casts non-double inputs to float32 - confirm this difference is intended.
    *output = input;
    return Status::OK();
  }
  RETURN_IF_NOT_OK(Fade(input, output, fade_in_len_, fade_out_len_, fade_shape_));
  return Status::OK();
}
|
||||
|
||||
// Report the output element type for a given input element type:
// inputs in [DE_INT8, DE_FLOAT32] produce float32, DE_FLOAT64 produces float64, anything else is an error.
Status FadeOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
  RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
  if (inputs[0] >= DataType::DE_INT8 && inputs[0] <= DataType::DE_FLOAT32) {
    // Assign (the original used '==', a comparison whose result was discarded).
    outputs[0] = DataType(DataType::DE_FLOAT32);
  } else if (inputs[0] == DataType::DE_FLOAT64) {
    outputs[0] = DataType(DataType::DE_FLOAT64);
  } else {
    RETURN_STATUS_UNEXPECTED("Fade: input tensor type should be int, float or double, but got: " +
                             inputs[0].ToString());
  }
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_FADE_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_FADE_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
/// \brief Tensor op that applies a fade-in and/or fade-out to its input.
class FadeOp : public TensorOp {
 public:
  /// Default fade-in length used when the constructor argument is omitted
  static const int32_t kFadeInLen;
  /// Default fade-out length used when the constructor argument is omitted
  static const int32_t kFadeOutLen;
  /// Default fade shape used when the constructor argument is omitted
  static const FadeShape kFadeShape;

  /// \brief Constructor.
  /// \param[in] fade_in_len Length of the fade-in.
  /// \param[in] fade_out_len Length of the fade-out.
  /// \param[in] fade_shape Shape of the fade curve.
  explicit FadeOp(int32_t fade_in_len = kFadeInLen, int32_t fade_out_len = kFadeOutLen,
                  FadeShape fade_shape = kFadeShape)
      : fade_in_len_(fade_in_len), fade_out_len_(fade_out_len), fade_shape_(fade_shape) {}

  /// \brief Destructor.
  ~FadeOp() override = default;

  /// \brief Apply the fade to one tensor.
  /// \param[in] input Input tensor.
  /// \param[out] output Result tensor.
  /// \return Status code.
  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

  /// \brief Get the name of this op.
  /// \return The constant kFadeOp.
  std::string Name() const override { return kFadeOp; }

  /// \brief Report the output data types for the given input data types.
  /// \param[in] inputs Data types of the inputs.
  /// \param[out] outputs Data types of the outputs.
  /// \return Status code.
  Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;

 private:
  int32_t fade_in_len_;   // Length of the fade-in.
  int32_t fade_out_len_;  // Length of the fade-out.
  FadeShape fade_shape_;  // Shape of the fade curve.
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_FADE_OP_H_
|
|
@ -323,6 +323,30 @@ class EqualizerBiquad final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Add fade in and/or fade out on the input audio.
class Fade final : public TensorTransform {
 public:
  /// \brief Constructor.
  /// \param[in] fade_in_len Length of fade-in (time frames), which must be non-negative
  ///     and no more than the length of waveform (Default: 0).
  /// \param[in] fade_out_len Length of fade-out (time frames), which must be non-negative
  ///     and no more than the length of waveform (Default: 0).
  /// \param[in] fade_shape An enum for the fade shape (Default: FadeShape::kLinear).
  explicit Fade(int32_t fade_in_len = 0, int32_t fade_out_len = 0, FadeShape fade_shape = FadeShape::kLinear);

  /// \brief Destructor.
  ~Fade() = default;

 protected:
  /// \brief Function to convert TensorTransform object into a TensorOperation object.
  /// \return Shared pointer to TensorOperation object.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  // Opaque holder for the constructor arguments; defined in the implementation file.
  struct Data;
  std::shared_ptr<Data> data_;
};
|
||||
|
||||
/// \brief FrequencyMasking TensorTransform.
|
||||
/// \notes Apply masking to a spectrogram in the frequency domain.
|
||||
class FrequencyMasking final : public TensorTransform {
|
||||
|
|
|
@ -186,6 +186,15 @@ enum class OutputFormat {
|
|||
kCsr = 2 ///< CSR format.
|
||||
};
|
||||
|
||||
/// \brief Possible options for fade shape.
enum class FadeShape {
  kLinear = 0,       ///< Linear fade curve.
  kExponential = 1,  ///< Exponential fade curve.
  kLogarithmic = 2,  ///< Logarithmic fade curve.
  kQuarterSine = 3,  ///< Quarter-sine fade curve.
  kHalfSine = 4,     ///< Half-sine fade curve.
};
|
||||
|
||||
/// \brief Convenience function to check bitmask for a 32bit int
|
||||
/// \param[in] bits a 32bit int to be tested
|
||||
/// \param[in] bitMask a 32bit int representing bit mask
|
||||
|
|
|
@ -152,6 +152,7 @@ constexpr char kContrastOp[] = "ContrastOp";
|
|||
constexpr char kDCShiftOp[] = "DCShiftOp";
|
||||
constexpr char kDeemphBiquadOp[] = "DeemphBiquadOp";
|
||||
constexpr char kEqualizerBiquadOp[] = "EqualizerBiquadOp";
|
||||
constexpr char kFadeOp[] = "FadeOp";
|
||||
constexpr char kFrequencyMaskingOp[] = "FrequencyMaskingOp";
|
||||
constexpr char kHighpassBiquadOp[] = "HighpassBiquadOp";
|
||||
constexpr char kLFilterOp[] = "LFilterOp";
|
||||
|
|
|
@ -23,11 +23,11 @@ import numpy as np
|
|||
|
||||
import mindspore._c_dataengine as cde
|
||||
from ..transforms.c_transforms import TensorOperation
|
||||
from .utils import ScaleType
|
||||
from .utils import FadeShape, ScaleType
|
||||
from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
|
||||
check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_contrast, check_dc_shift, \
|
||||
check_deemph_biquad, check_equalizer_biquad, check_highpass_biquad, check_lfilter, check_lowpass_biquad, \
|
||||
check_masking, check_mu_law_decoding, check_time_stretch
|
||||
check_deemph_biquad, check_equalizer_biquad, check_fade, check_highpass_biquad, check_lfilter, \
|
||||
check_lowpass_biquad, check_masking, check_mu_law_decoding, check_time_stretch
|
||||
|
||||
|
||||
class AudioTensorOperation(TensorOperation):
|
||||
|
@ -408,6 +408,56 @@ class EqualizerBiquad(AudioTensorOperation):
|
|||
return cde.EqualizerBiquadOperation(self.sample_rate, self.center_freq, self.gain, self.Q)
|
||||
|
||||
|
||||
# Maps each Python FadeShape member to the matching FadeShape value exported by the C++ bindings.
DE_C_FADESHAPE_TYPE = {FadeShape.LINEAR: cde.FadeShape.DE_FADESHAPE_LINEAR,
                       FadeShape.EXPONENTIAL: cde.FadeShape.DE_FADESHAPE_EXPONENTIAL,
                       FadeShape.LOGARITHMIC: cde.FadeShape.DE_FADESHAPE_LOGARITHMIC,
                       FadeShape.QUARTERSINE: cde.FadeShape.DE_FADESHAPE_QUARTERSINE,
                       FadeShape.HALFSINE: cde.FadeShape.DE_FADESHAPE_HALFSINE}
|
||||
|
||||
|
||||
class Fade(AudioTensorOperation):
    """
    Add a fade in and/or fade out to a waveform.

    Args:
        fade_in_len (int, optional): Length of fade-in (time frames), which must be non-negative (default=0).
        fade_out_len (int, optional): Length of fade-out (time frames), which must be non-negative (default=0).
        fade_shape (FadeShape, optional): Shape of fade (default=FadeShape.LINEAR). Can be one of
            [FadeShape.LINEAR, FadeShape.EXPONENTIAL, FadeShape.LOGARITHMIC, FadeShape.QUARTERSINE,
            FadeShape.HALFSINE].

            - FadeShape.LINEAR, means the fade follows a linear curve.

            - FadeShape.EXPONENTIAL, means the fade follows an exponential curve.

            - FadeShape.LOGARITHMIC, means the fade follows a logarithmic curve.

            - FadeShape.QUARTERSINE, means the fade follows a quarter sine curve.

            - FadeShape.HALFSINE, means the fade follows a half sine curve.

    Raises:
        RuntimeError: If fade_in_len exceeds waveform length.
        RuntimeError: If fade_out_len exceeds waveform length.

    Examples:
        >>> import numpy as np
        >>>
        >>> waveform = np.array([[2.716064453125e-03, 6.34765625e-03, 9.246826171875e-03, 1.0894775390625e-02]])
        >>> dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
        >>> transforms = [audio.Fade(fade_in_len=3, fade_out_len=2, fade_shape=FadeShape.LINEAR)]
        >>> dataset = dataset.map(operations=transforms, input_columns=["audio"])
    """

    @check_fade
    def __init__(self, fade_in_len=0, fade_out_len=0, fade_shape=FadeShape.LINEAR):
        self.fade_in_len = fade_in_len
        self.fade_out_len = fade_out_len
        self.fade_shape = fade_shape

    def parse(self):
        # Translate the Python enum to its C++ binding value before building the C++ operation.
        return cde.FadeOperation(self.fade_in_len, self.fade_out_len, DE_C_FADESHAPE_TYPE[self.fade_shape])
|
||||
|
||||
|
||||
class FrequencyMasking(AudioTensorOperation):
|
||||
"""
|
||||
Apply masking to a spectrogram in the frequency domain.
|
||||
|
|
|
@ -23,3 +23,12 @@ class ScaleType(str, Enum):
|
|||
"""Scale Type"""
|
||||
POWER: str = "power"
|
||||
MAGNITUDE: str = "magnitude"
|
||||
|
||||
|
||||
class FadeShape(str, Enum):
    """Fade shape options; each member's value is its lowercase string name."""
    LINEAR: str = "linear"
    EXPONENTIAL: str = "exponential"
    LOGARITHMIC: str = "logarithmic"
    QUARTERSINE: str = "quarter_sine"
    HALFSINE: str = "half_sine"
|
||||
|
|
|
@ -18,10 +18,10 @@ Validators for TensorOps.
|
|||
|
||||
from functools import wraps
|
||||
|
||||
from mindspore.dataset.core.validator_helpers import check_float32, check_float32_not_zero, \
|
||||
check_int32_not_zero, check_list_same_size, check_non_negative_float32, check_pos_float32, \
|
||||
check_pos_int32, check_value, parse_user_args, type_check
|
||||
from .utils import ScaleType
|
||||
from mindspore.dataset.core.validator_helpers import check_float32, check_float32_not_zero, check_int32_not_zero, \
|
||||
check_list_same_size, check_non_negative_float32, check_non_negative_int32, check_pos_float32, check_pos_int32, \
|
||||
check_value, parse_user_args, type_check
|
||||
from .utils import FadeShape, ScaleType
|
||||
|
||||
|
||||
def check_amplitude_to_db(method):
|
||||
|
@ -368,3 +368,19 @@ def check_biquad(method):
|
|||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_fade(method):
    """Wrapper method to check the parameters of Fade."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [fade_in_len, fade_out_len, fade_shape], _ = parse_user_args(method, *args, **kwargs)
        # Both lengths get the same two checks, in order: must be int, then within non-negative int32.
        for length, label in ((fade_in_len, "fade_in_len"), (fade_out_len, "fade_out_len")):
            type_check(length, (int,), label)
            check_non_negative_int32(length, label)
        type_check(fade_shape, (FadeShape,), "fade_shape")
        return method(self, *args, **kwargs)

    return new_method
|
||||
|
|
|
@ -302,6 +302,17 @@ def check_pos_int64(value, arg_name=""):
|
|||
check_value(value, [POS_INT_MIN, INT64_MAX])
|
||||
|
||||
|
||||
def check_non_negative_int32(value, arg_name=""):
    """
    Validates the value of a variable is within the range of non negative int32,
    i.e. within [UINT32_MIN, INT32_MAX].

    :param value: the value of the variable.
    :param arg_name: name of the variable to be validated.
    :return: nothing if the validation passes; an exception is raised otherwise.
    """
    check_value(value, [UINT32_MIN, INT32_MAX], arg_name)
|
||||
|
||||
|
||||
def check_float32(value, arg_name=""):
|
||||
"""
|
||||
Validates the value of a variable is within the range of float32.
|
||||
|
|
|
@ -1125,3 +1125,221 @@ TEST_F(MindDataTestPipeline, TestBiquadParamCheck) {
|
|||
std::shared_ptr<Iterator> iter01 = ds01->CreateIterator();
|
||||
EXPECT_EQ(iter01, nullptr);
|
||||
}
|
||||
|
||||
// Run an exponential Fade over 50 random float32 rows of shape {1, 200} and check that every
// row keeps its shape and float32 type.
TEST_F(MindDataTestPipeline, TestFadeWithPipeline) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFadeWithPipeline.";
  std::shared_ptr<SchemaObj> schema = Schema();
  ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {1, 200}));
  std::shared_ptr<Dataset> ds = RandomData(50, schema);
  EXPECT_NE(ds, nullptr);

  ds = ds->SetNumWorkers(4);
  EXPECT_NE(ds, nullptr);

  auto fade_op = audio::Fade(20, 30, FadeShape::kExponential);
  ds = ds->Map({fade_op});
  EXPECT_NE(ds, nullptr);

  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  const std::vector<int64_t> expected_shape = {1, 200};
  int num_rows = 0;
  // An empty row marks the end of the dataset.
  for (; !row.empty(); ++num_rows) {
    auto col = row["inputData"];
    ASSERT_EQ(col.Shape(), expected_shape);
    ASSERT_EQ(col.Shape().size(), 2);
    ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
    ASSERT_OK(iter->GetNextRow(&row));
  }
  EXPECT_EQ(num_rows, 50);

  iter->Stop();
}
|
||||
|
||||
// Run a linear Fade over 10 random float32 rows of shape {2, 10} and check that every
// row keeps its shape and float32 type.
TEST_F(MindDataTestPipeline, TestFadeWithLinear) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFadeWithLinear.";
  std::shared_ptr<SchemaObj> schema = Schema();
  ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {2, 10}));
  std::shared_ptr<Dataset> ds = RandomData(10, schema);
  EXPECT_NE(ds, nullptr);

  ds = ds->SetNumWorkers(4);
  EXPECT_NE(ds, nullptr);

  auto fade_op = audio::Fade(5, 5, FadeShape::kLinear);
  ds = ds->Map({fade_op});
  EXPECT_NE(ds, nullptr);

  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  const std::vector<int64_t> expected_shape = {2, 10};
  int num_rows = 0;
  // An empty row marks the end of the dataset.
  for (; !row.empty(); ++num_rows) {
    auto col = row["inputData"];
    ASSERT_EQ(col.Shape(), expected_shape);
    ASSERT_EQ(col.Shape().size(), 2);
    ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
    ASSERT_OK(iter->GetNextRow(&row));
  }
  EXPECT_EQ(num_rows, 10);

  iter->Stop();
}
|
||||
|
||||
// Run a logarithmic Fade over 30 random float64 rows of shape {1, 150} and check that every
// row keeps its shape and float64 type.
TEST_F(MindDataTestPipeline, TestFadeWithLogarithmic) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFadeWithLogarithmic.";
  std::shared_ptr<SchemaObj> schema = Schema();
  ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat64, {1, 150}));
  std::shared_ptr<Dataset> ds = RandomData(30, schema);
  EXPECT_NE(ds, nullptr);

  ds = ds->SetNumWorkers(4);
  EXPECT_NE(ds, nullptr);

  auto fade_op = audio::Fade(80, 100, FadeShape::kLogarithmic);
  ds = ds->Map({fade_op});
  EXPECT_NE(ds, nullptr);

  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  const std::vector<int64_t> expected_shape = {1, 150};
  int num_rows = 0;
  // An empty row marks the end of the dataset.
  for (; !row.empty(); ++num_rows) {
    auto col = row["inputData"];
    ASSERT_EQ(col.Shape(), expected_shape);
    ASSERT_EQ(col.Shape().size(), 2);
    ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat64);
    ASSERT_OK(iter->GetNextRow(&row));
  }
  EXPECT_EQ(num_rows, 30);

  iter->Stop();
}
|
||||
|
||||
// Pipeline test: map a quarter-sine-shaped Fade over an int32 {20, 20000} RandomData column and
// check that each of the 40 rows keeps its shape and is reported as float32 after the transform.
TEST_F(MindDataTestPipeline, TestFadeWithQuarterSine) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFadeWithQuarterSine.";

  // Schema with a single integer column named "inputData".
  std::shared_ptr<SchemaObj> schema = Schema();
  ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeInt32, {20, 20000}));

  std::shared_ptr<Dataset> dataset = RandomData(40, schema);
  EXPECT_NE(dataset, nullptr);

  dataset = dataset->SetNumWorkers(4);
  EXPECT_NE(dataset, nullptr);

  // Attach the Fade transform to the pipeline.
  auto fade_op = audio::Fade(1000, 1000, FadeShape::kQuarterSine);
  dataset = dataset->Map({fade_op});
  EXPECT_NE(dataset, nullptr);

  std::shared_ptr<Iterator> iterator = dataset->CreateIterator();
  EXPECT_NE(iterator, nullptr);

  const std::vector<int64_t> expected_shape = {20, 20000};
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iterator->GetNextRow(&row));

  // Walk the rows until an empty row is returned, counting them as we go.
  int row_count = 0;
  for (; !row.empty(); ++row_count) {
    auto column = row["inputData"];
    ASSERT_EQ(column.Shape(), expected_shape);
    ASSERT_EQ(column.Shape().size(), 2);
    ASSERT_EQ(column.DataType(), mindspore::DataType::kNumberTypeFloat32);
    ASSERT_OK(iterator->GetNextRow(&row));
  }
  EXPECT_EQ(row_count, 40);

  iterator->Stop();
}
|
||||
|
||||
// Pipeline test: map a half-sine-shaped Fade over an int16 {1, 200} RandomData column and
// check that each of the 40 rows keeps its shape and is reported as float32 after the transform.
TEST_F(MindDataTestPipeline, TestFadeWithHalfSine) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFadeWithHalfSine.";

  // Schema with a single integer column named "inputData".
  std::shared_ptr<SchemaObj> schema = Schema();
  ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeInt16, {1, 200}));

  std::shared_ptr<Dataset> dataset = RandomData(40, schema);
  EXPECT_NE(dataset, nullptr);

  dataset = dataset->SetNumWorkers(4);
  EXPECT_NE(dataset, nullptr);

  // Attach the Fade transform to the pipeline.
  auto fade_op = audio::Fade(100, 100, FadeShape::kHalfSine);
  dataset = dataset->Map({fade_op});
  EXPECT_NE(dataset, nullptr);

  std::shared_ptr<Iterator> iterator = dataset->CreateIterator();
  EXPECT_NE(iterator, nullptr);

  const std::vector<int64_t> expected_shape = {1, 200};
  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iterator->GetNextRow(&row));

  // Walk the rows until an empty row is returned, counting them as we go.
  int row_count = 0;
  for (; !row.empty(); ++row_count) {
    auto column = row["inputData"];
    ASSERT_EQ(column.Shape(), expected_shape);
    ASSERT_EQ(column.Shape().size(), 2);
    ASSERT_EQ(column.DataType(), mindspore::DataType::kNumberTypeFloat32);
    ASSERT_OK(iterator->GetNextRow(&row));
  }
  EXPECT_EQ(row_count, 40);

  iterator->Stop();
}
|
||||
|
||||
// Pipeline test: a Fade built with a negative fade-in or fade-out length must make
// CreateIterator() return nullptr.
TEST_F(MindDataTestPipeline, TestFadeWithInvalidArg) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestFadeWithInvalidArg.";

  // Both sub-cases share the same single-column float32 schema.
  std::shared_ptr<SchemaObj> schema = Schema();
  ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {1, 200}));

  // Case 1: negative fade-in length.
  std::shared_ptr<Dataset> negative_in_ds = RandomData(50, schema);
  EXPECT_NE(negative_in_ds, nullptr);

  negative_in_ds = negative_in_ds->SetNumWorkers(4);
  EXPECT_NE(negative_in_ds, nullptr);

  auto negative_in_fade = audio::Fade(-20, 30, FadeShape::kLogarithmic);
  negative_in_ds = negative_in_ds->Map({negative_in_fade});
  EXPECT_NE(negative_in_ds, nullptr);

  // Expect failure, fade in length less than zero
  std::shared_ptr<Iterator> negative_in_iter = negative_in_ds->CreateIterator();
  EXPECT_EQ(negative_in_iter, nullptr);

  // Case 2: negative fade-out length.
  std::shared_ptr<Dataset> negative_out_ds = RandomData(50, schema);
  EXPECT_NE(negative_out_ds, nullptr);

  negative_out_ds = negative_out_ds->SetNumWorkers(4);
  EXPECT_NE(negative_out_ds, nullptr);

  auto negative_out_fade = audio::Fade(5, -3, FadeShape::kExponential);
  negative_out_ds = negative_out_ds->Map({negative_out_fade});
  EXPECT_NE(negative_out_ds, nullptr);

  // Expect failure, fade out length less than zero
  std::shared_ptr<Iterator> negative_out_iter = negative_out_ds->CreateIterator();
  EXPECT_EQ(negative_out_iter, nullptr);
}
|
||||
|
|
|
@ -961,3 +961,99 @@ TEST_F(MindDataTestExecute, TestBiquadWithWrongArg) {
|
|||
Status s01 = Transform01(input_02, &input_02);
|
||||
EXPECT_FALSE(s01.IsOk());
|
||||
}
|
||||
|
||||
// Eager test: run Fade through Execute on a float {1, 20} waveform once per fade shape
// (linear, quarter sine, exponential, half sine, logarithmic) and expect every call to succeed.
TEST_F(MindDataTestExecute, TestFade) {
  MS_LOG(INFO) << "Doing MindDataTestExecute-TestFade.";
  std::vector<float> waveform = {
    2.716064453125000000e-03, 6.347656250000000000e-03, 9.246826171875000000e-03, 1.089477539062500000e-02,
    1.138305664062500000e-02, 1.156616210937500000e-02, 1.394653320312500000e-02, 1.550292968750000000e-02,
    1.614379882812500000e-02, 1.840209960937500000e-02, 1.718139648437500000e-02, 1.599121093750000000e-02,
    1.647949218750000000e-02, 1.510620117187500000e-02, 1.385498046875000000e-02, 1.345825195312500000e-02,
    1.419067382812500000e-02, 1.284790039062500000e-02, 1.052856445312500000e-02, 9.368896484375000000e-03};
  std::shared_ptr<Tensor> input;
  ASSERT_OK(Tensor::CreateFromVector(waveform, TensorShape({1, 20}), &input));

  // Linear.
  auto input_01 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
  std::shared_ptr<TensorTransform> fade01 = std::make_shared<audio::Fade>(5, 6, FadeShape::kLinear);
  mindspore::dataset::Execute Transform01({fade01});
  Status s01 = Transform01(input_01, &input_01);
  EXPECT_TRUE(s01.IsOk());

  // Quarter sine.
  auto input_02 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
  std::shared_ptr<TensorTransform> fade02 = std::make_shared<audio::Fade>(5, 6, FadeShape::kQuarterSine);
  mindspore::dataset::Execute Transform02({fade02});
  Status s02 = Transform02(input_02, &input_02);
  EXPECT_TRUE(s02.IsOk());

  // Exponential.
  auto input_03 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
  std::shared_ptr<TensorTransform> fade03 = std::make_shared<audio::Fade>(5, 6, FadeShape::kExponential);
  mindspore::dataset::Execute Transform03({fade03});
  Status s03 = Transform03(input_03, &input_03);
  EXPECT_TRUE(s03.IsOk());

  // Half sine. Must run Transform04; previously Transform01 (linear) was invoked by mistake,
  // so this shape was never exercised.
  auto input_04 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
  std::shared_ptr<TensorTransform> fade04 = std::make_shared<audio::Fade>(5, 6, FadeShape::kHalfSine);
  mindspore::dataset::Execute Transform04({fade04});
  Status s04 = Transform04(input_04, &input_04);
  EXPECT_TRUE(s04.IsOk());

  // Logarithmic. Must run Transform05; previously Transform01 (linear) was invoked by mistake.
  auto input_05 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
  std::shared_ptr<TensorTransform> fade05 = std::make_shared<audio::Fade>(5, 6, FadeShape::kLogarithmic);
  mindspore::dataset::Execute Transform05({fade05});
  Status s05 = Transform05(input_05, &input_05);
  EXPECT_TRUE(s05.IsOk());
}
|
||||
|
||||
// Eager test: run Fade through Execute on a double {2, 10} waveform while omitting
// zero, one or two trailing constructor arguments, and expect every call to succeed.
TEST_F(MindDataTestExecute, TestFadeDefaultArg) {
  MS_LOG(INFO) << "Doing MindDataTestExecute-TestFadeDefaultArg.";
  std::vector<double> samples = {
    1.573897564868000000e-03, 5.462374385400000000e-03, 3.584989689205400000e-03, 2.035667767462500000e-02,
    2.353543454062500000e-02, 1.256616210937500000e-02, 2.394653320312500000e-02, 5.243553968750000000e-02,
    2.434554533002500000e-02, 3.454566960937500000e-02, 2.343545454437500000e-02, 2.534343093750000000e-02,
    2.354465654550000000e-02, 1.453545517187500000e-02, 1.454645535875000000e-02, 1.433243195312500000e-02,
    1.434354554812500000e-02, 3.343435276865400000e-02, 1.234257687312500000e-02, 5.368896484375000000e-03};
  std::shared_ptr<Tensor> source_tensor;
  ASSERT_OK(Tensor::CreateFromVector(samples, TensorShape({2, 10}), &source_tensor));

  // No arguments given.
  auto all_default_io = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(source_tensor));
  std::shared_ptr<TensorTransform> all_default_fade = std::make_shared<audio::Fade>();
  mindspore::dataset::Execute all_default_exec({all_default_fade});
  Status all_default_status = all_default_exec(all_default_io, &all_default_io);
  EXPECT_TRUE(all_default_status.IsOk());

  // Only the fade-in length given.
  auto fade_in_only_io = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(source_tensor));
  std::shared_ptr<TensorTransform> fade_in_only_fade = std::make_shared<audio::Fade>(5);
  mindspore::dataset::Execute fade_in_only_exec({fade_in_only_fade});
  Status fade_in_only_status = fade_in_only_exec(fade_in_only_io, &fade_in_only_io);
  EXPECT_TRUE(fade_in_only_status.IsOk());

  // Fade-in and fade-out lengths given, fade shape omitted.
  auto default_shape_io = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(source_tensor));
  std::shared_ptr<TensorTransform> default_shape_fade = std::make_shared<audio::Fade>(5, 6);
  mindspore::dataset::Execute default_shape_exec({default_shape_fade});
  Status default_shape_status = default_shape_exec(default_shape_io, &default_shape_io);
  EXPECT_TRUE(default_shape_status.IsOk());
}
|
||||
|
||||
// Eager test: run Fade through Execute on a float {1, 20} waveform with four argument
// combinations that are each expected to yield a non-OK Status.
TEST_F(MindDataTestExecute, TestFadeWithInvalidArg) {
  MS_LOG(INFO) << "Doing MindDataTestExecute-TestFadeWithInvalidArg.";
  std::vector<float> samples = {
    2.716064453125000000e-03, 6.347656250000000000e-03, 9.246826171875000000e-03, 1.089477539062500000e-02,
    1.138305664062500000e-02, 1.156616210937500000e-02, 1.394653320312500000e-02, 1.550292968750000000e-02,
    1.614379882812500000e-02, 1.840209960937500000e-02, 1.718139648437500000e-02, 1.599121093750000000e-02,
    1.647949218750000000e-02, 1.510620117187500000e-02, 1.385498046875000000e-02, 1.345825195312500000e-02,
    1.419067382812500000e-02, 1.284790039062500000e-02, 1.052856445312500000e-02, 9.368896484375000000e-03};
  std::shared_ptr<Tensor> source_tensor;
  ASSERT_OK(Tensor::CreateFromVector(samples, TensorShape({1, 20}), &source_tensor));

  // Negative fade-in length.
  auto negative_in_io = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(source_tensor));
  std::shared_ptr<TensorTransform> negative_in_fade = std::make_shared<audio::Fade>(-5, 6);
  mindspore::dataset::Execute negative_in_exec({negative_in_fade});
  Status negative_in_status = negative_in_exec(negative_in_io, &negative_in_io);
  EXPECT_FALSE(negative_in_status.IsOk());

  // Negative fade-out length.
  auto negative_out_io = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(source_tensor));
  std::shared_ptr<TensorTransform> negative_out_fade = std::make_shared<audio::Fade>(0, -1);
  mindspore::dataset::Execute negative_out_exec({negative_out_fade});
  Status negative_out_status = negative_out_exec(negative_out_io, &negative_out_io);
  EXPECT_FALSE(negative_out_status.IsOk());

  // Fade-in length (30) larger than the 20 samples in the waveform.
  auto long_in_io = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(source_tensor));
  std::shared_ptr<TensorTransform> long_in_fade = std::make_shared<audio::Fade>(30, 10);
  mindspore::dataset::Execute long_in_exec({long_in_fade});
  Status long_in_status = long_in_exec(long_in_io, &long_in_io);
  EXPECT_FALSE(long_in_status.IsOk());

  // Fade-out length (30) larger than the 20 samples in the waveform.
  auto long_out_io = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(source_tensor));
  std::shared_ptr<TensorTransform> long_out_fade = std::make_shared<audio::Fade>(10, 30);
  mindspore::dataset::Execute long_out_exec({long_out_fade});
  Status long_out_status = long_out_exec(long_out_io, &long_out_io);
  EXPECT_FALSE(long_out_status.IsOk());
}
|
|
@ -0,0 +1,189 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing fade op in DE
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.audio.transforms as audio
|
||||
from mindspore.dataset.audio.utils import FadeShape
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
def test_fade_linear():
    """
    Test Fade, fade shape is linear.
    """
    logger.info("test fade, fade shape is 'linear'")

    waveform = [[[9.1553e-05, 6.1035e-05, 6.1035e-05, 6.1035e-05, 1.2207e-04, 1.2207e-04,
                  9.1553e-05, 9.1553e-05, 9.1553e-05, 9.1553e-05, 9.1553e-05, 6.1035e-05,
                  1.2207e-04, 1.2207e-04, 1.2207e-04, 9.1553e-05, 9.1553e-05, 9.1553e-05,
                  6.1035e-05, 9.1553e-05]]]
    dataset = ds.NumpySlicesDataset(data=waveform, column_names='audio', shuffle=False)
    transforms = [audio.Fade(fade_in_len=10, fade_out_len=5, fade_shape=FadeShape.LINEAR)]
    dataset = dataset.map(operations=transforms, input_columns=["audio"])

    for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        out_put = item["audio"]
        # The result of the reference operator
        expected_output = np.array([[0.0000000000000000000, 6.781666797905927e-06, 1.356333359581185e-05,
                                     2.034499993897043e-05, 5.425333438324742e-05, 6.781666888855398e-05,
                                     6.103533087298274e-05, 7.120789086911827e-05, 8.138045086525380e-05,
                                     9.155300358543172e-05, 9.155300358543172e-05, 6.103499981691129e-05,
                                     0.0001220699996338225, 0.0001220699996338225, 0.0001220699996338225,
                                     9.155300358543172e-05, 6.866475450806320e-05, 4.577650179271586e-05,
                                     1.525874995422782e-05, 0.0000000000000000000]], dtype=np.float32)
        # Use the mean ABSOLUTE error: a signed mean lets positive and negative errors cancel
        # and accepts any output that is smaller than expected.
        # NOTE(review): the samples here are ~1e-4, the same order as the tolerance, so this
        # bound is still loose; consider a relative tolerance once actual outputs are confirmed.
        assert np.mean(np.abs(out_put - expected_output)) < 0.0001
|
||||
|
||||
|
||||
def test_fade_exponential():
    """
    Test Fade, fade shape is exponential.
    """
    logger.info("test fade, fade shape is 'exponential'")

    waveform = [[[1, 2, 3, 4, 5, 6],
                 [5, 7, 3, 78, 8, 4]]]
    dataset = ds.NumpySlicesDataset(data=waveform, column_names='audio', shuffle=False)
    transforms = [audio.Fade(fade_in_len=5, fade_out_len=6, fade_shape=FadeShape.EXPONENTIAL)]
    dataset = dataset.map(operations=transforms, input_columns=["audio"])

    for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        out_put = item["audio"]
        # The result of the reference operator
        expected_output = np.array([[0.0000, 0.2071, 0.4823, 0.6657, 0.5743, 0.0000],
                                    [0.0000, 0.7247, 0.4823, 12.9820, 0.9190, 0.0000]], dtype=np.float32)
        # Use the mean ABSOLUTE error: a signed mean lets positive and negative errors cancel
        # and accepts any output that is smaller than expected.
        assert np.mean(np.abs(out_put - expected_output)) < 0.0001
|
||||
|
||||
|
||||
def test_fade_logarithmic():
    """
    Test Fade, fade shape is logarithmic.
    """
    logger.info("test fade, fade shape is 'logarithmic'")

    waveform = [[[0.03424072265625, 0.01476832226565, 0.04995727590625,
                  -0.0205993652375, -0.0356467868775, 0.01290893546875]]]
    dataset = ds.NumpySlicesDataset(data=waveform, column_names='audio', shuffle=False)
    transforms = [audio.Fade(fade_in_len=4, fade_out_len=2, fade_shape=FadeShape.LOGARITHMIC)]
    dataset = dataset.map(operations=transforms, input_columns=["audio"])

    for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        out_put = item["audio"]
        # The result of the reference operator
        expected_output = np.array([[0.0000e+00, 9.4048e-03, 4.4193e-02,
                                     -2.0599e-02, -3.5647e-02, 1.5389e-09]],
                                   dtype=np.float32)
        # Use the mean ABSOLUTE error: a signed mean lets positive and negative errors cancel
        # and accepts any output that is smaller than expected.
        assert np.mean(np.abs(out_put - expected_output)) < 0.0001
|
||||
|
||||
|
||||
def test_fade_quarter_sine():
    """
    Test Fade, fade shape is quarter_sine.
    """
    logger.info("test fade, fade shape is 'quarter sine'")

    waveform = np.array([[[1, 2, 3, 4, 5, 6],
                          [5, 7, 3, 78, 8, 4],
                          [1, 2, 3, 4, 5, 6]]], dtype=np.float64)
    dataset = ds.NumpySlicesDataset(data=waveform, column_names='audio', shuffle=False)
    transforms = [audio.Fade(fade_in_len=6, fade_out_len=6, fade_shape=FadeShape.QUARTERSINE)]
    dataset = dataset.map(operations=transforms, input_columns=["audio"])

    for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        out_put = item["audio"]
        # The result of the reference operator
        expected_output = np.array([[0.0000, 0.5878, 1.4266, 1.9021, 1.4695, 0.0000],
                                    [0.0000, 2.0572, 1.4266, 37.091, 2.3511, 0.0000],
                                    [0.0000, 0.5878, 1.4266, 1.9021, 1.4695, 0.0000]], dtype=np.float64)
        # Use the mean ABSOLUTE error: a signed mean lets positive and negative errors cancel
        # and accepts any output that is smaller than expected.
        assert np.mean(np.abs(out_put - expected_output)) < 0.0001
|
||||
|
||||
|
||||
def test_fade_half_sine():
    """
    Test Fade, fade shape is half_sine.
    """
    logger.info("test fade, fade shape is 'half sine'")

    waveform = [[[0.03424072265625, 0.013580322265625, -0.011871337890625,
                  -0.0205993652343, -0.01049804687500, 0.0129089355468750],
                 [0.04125976562500, 0.060577392578125, 0.0499572753906250,
                  0.01306152343750, -0.019683837890625, -0.018829345703125]]]
    dataset = ds.NumpySlicesDataset(data=waveform, column_names='audio', shuffle=False)
    transforms = [audio.Fade(fade_in_len=3, fade_out_len=3, fade_shape=FadeShape.HALFSINE)]
    dataset = dataset.map(operations=transforms, input_columns=["audio"])

    for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
        out_put = item["audio"]
        # The result of the reference operator
        expected_output = np.array([[0.0000, 0.0068, -0.0119, -0.0206, -0.0052, 0.0000],
                                    [0.0000, 0.0303, 0.0500, 0.0131, -0.0098, -0.0000]], dtype=np.float32)
        # Use the mean ABSOLUTE error: a signed mean lets positive and negative errors cancel
        # and accepts any output that is smaller than expected.
        assert np.mean(np.abs(out_put - expected_output)) < 0.0001
|
||||
|
||||
|
||||
def test_fade_wrong_arguments():
    """
    Test Fade with invalid arguments
    """
    logger.info("test fade with invalid arguments")
    # Each case has an `else` branch so the test fails if the constructor accepts the
    # invalid argument; without it a missing exception would pass silently.
    try:
        _ = audio.Fade(-1, 0)
    except ValueError as e:
        logger.info("Got an exception in Fade: {}".format(str(e)))
        assert "fade_in_len is not within the required interval of [0, 2147483647]" in str(e)
    else:
        raise AssertionError("Fade(-1, 0) did not raise ValueError")
    try:
        _ = audio.Fade(0, -1)
    except ValueError as e:
        logger.info("Got an exception in Fade: {}".format(str(e)))
        assert "fade_out_len is not within the required interval of [0, 2147483647]" in str(e)
    else:
        raise AssertionError("Fade(0, -1) did not raise ValueError")
    try:
        _ = audio.Fade(fade_shape='123')
    except TypeError as e:
        logger.info("Got an exception in Fade: {}".format(str(e)))
        assert "is not of type [<enum 'FadeShape'>]" in str(e)
    else:
        raise AssertionError("Fade(fade_shape='123') did not raise TypeError")
|
||||
|
||||
|
||||
def test_fade_eager():
    """
    Test Fade eager.
    """
    logger.info("test fade eager")

    data = np.array([[9.1553e-05, 6.1035e-05, 6.1035e-05, 6.1035e-05, 1.2207e-04, 1.2207e-04,
                      9.1553e-05, 9.1553e-05, 9.1553e-05, 9.1553e-05, 9.1553e-05, 6.1035e-05,
                      1.2207e-04, 1.2207e-04, 1.2207e-04, 9.1553e-05, 9.1553e-05, 9.1553e-05,
                      6.1035e-05, 9.1553e-05]]).astype(np.float32)
    expected_output = np.array([0.0000000000000000000, 6.781666797905927e-06, 1.356333359581185e-05,
                                2.034499993897043e-05, 5.425333438324742e-05, 6.781666888855398e-05,
                                6.103533087298274e-05, 7.120789086911827e-05, 8.138045086525380e-05,
                                9.155300358543172e-05, 9.155300358543172e-05, 6.103499981691129e-05,
                                0.0001220699996338225, 0.0001220699996338225, 0.0001220699996338225,
                                9.155300358543172e-05, 6.866475450806320e-05, 4.577650179271586e-05,
                                1.525874995422782e-05, 0.0000000000000000000], dtype=np.float32)
    fade = audio.Fade(10, 5, fade_shape=FadeShape.LINEAR)
    out_put = fade(data)
    # Use the mean ABSOLUTE error: a signed mean lets positive and negative errors cancel
    # and accepts any output that is smaller than expected.
    # NOTE(review): the samples here are ~1e-4, the same order as the tolerance, so this
    # bound is still loose; consider a relative tolerance once actual outputs are confirmed.
    assert np.mean(np.abs(out_put - expected_output)) < 0.0001
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # When run as a script, execute every Fade test case in declaration order.
    for _test_case in (test_fade_linear,
                       test_fade_exponential,
                       test_fade_logarithmic,
                       test_fade_quarter_sine,
                       test_fade_half_sine,
                       test_fade_wrong_arguments,
                       test_fade_eager):
        _test_case()
|
Loading…
Reference in New Issue