!22686 [assistant][ops][Dither]

Merge pull request !22686 from Isaac/Dither
This commit is contained in:
i-robot 2021-12-20 11:16:20 +00:00 committed by Gitee
commit 7331629ceb
19 changed files with 879 additions and 7 deletions

View File

@ -31,6 +31,7 @@
#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/detect_pitch_frequency_ir.h"
#include "minddata/dataset/audio/ir/kernels/dither_ir.h"
#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/fade_ir.h"
#include "minddata/dataset/audio/ir/kernels/flanger_ir.h"
@ -292,6 +293,21 @@ std::shared_ptr<TensorOperation> DetectPitchFrequency::Parse() {
data_->freq_low_, data_->freq_high_);
}
// Dither Transform Operation.
struct Dither::Data {
Data(DensityFunction density_function, bool noise_shaping)
: density_function_(density_function), noise_shaping_(noise_shaping) {}
DensityFunction density_function_;
bool noise_shaping_;
};
Dither::Dither(DensityFunction density_function, bool noise_shaping)
: data_(std::make_shared<Data>(density_function, noise_shaping)) {}
std::shared_ptr<TensorOperation> Dither::Parse() {
return std::make_shared<DitherOperation>(data_->density_function_, data_->noise_shaping_);
}
// EqualizerBiquad Transform Operation.
struct EqualizerBiquad::Data {
Data(int32_t sample_rate, float center_freq, float gain, float Q)

View File

@ -35,6 +35,7 @@
#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/detect_pitch_frequency_ir.h"
#include "minddata/dataset/audio/ir/kernels/dither_ir.h"
#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/fade_ir.h"
#include "minddata/dataset/audio/ir/kernels/flanger_ir.h"
@ -231,6 +232,24 @@ PYBIND_REGISTER(DetectPitchFrequencyOperation, 1, ([](const py::module *m) {
}));
}));
PYBIND_REGISTER(DensityFunction, 0, ([](const py::module *m) {
(void)py::enum_<DensityFunction>(*m, "DensityFunction", py::arithmetic())
.value("DE_DENSITYFUNCTION_TPDF", DensityFunction::kTPDF)
.value("DE_DENSITYFUNCTION_RPDF", DensityFunction::kRPDF)
.value("DE_DENSITYFUNCTION_GPDF", DensityFunction::kGPDF)
.export_values();
}));
PYBIND_REGISTER(DitherOperation, 1, ([](const py::module *m) {
(void)py::class_<audio::DitherOperation, TensorOperation, std::shared_ptr<audio::DitherOperation>>(
*m, "DitherOperation")
.def(py::init([](DensityFunction density_function, bool noise_shaping) {
auto dither = std::make_shared<audio::DitherOperation>(density_function, noise_shaping);
THROW_IF_ERROR(dither->ValidateParams());
return dither;
}));
}));
PYBIND_REGISTER(EqualizerBiquadOperation, 1, ([](const py::module *m) {
(void)py::class_<audio::EqualizerBiquadOperation, TensorOperation,
std::shared_ptr<audio::EqualizerBiquadOperation>>(*m, "EqualizerBiquadOperation")

View File

@ -17,6 +17,7 @@ add_library(audio-ir-kernels OBJECT
dc_shift_ir.cc
deemph_biquad_ir.cc
detect_pitch_frequency_ir.cc
dither_ir.cc
equalizer_biquad_ir.cc
fade_ir.cc
flanger_ir.cc

View File

@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/audio/ir/kernels/dither_ir.h"
#include "minddata/dataset/audio/kernels/dither_op.h"
namespace mindspore {
namespace dataset {
namespace audio {
// DitherOperation
DitherOperation::DitherOperation(DensityFunction density_function, bool noise_shaping)
: density_function_(density_function), noise_shaping_(noise_shaping) {
random_op_ = true;
}
Status DitherOperation::ValidateParams() { return Status::OK(); }
std::shared_ptr<TensorOp> DitherOperation::Build() {
std::shared_ptr<DitherOp> tensor_op = std::make_shared<DitherOp>(density_function_, noise_shaping_);
return tensor_op;
}
Status DitherOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["density_function"] = density_function_;
args["noise_shaping"] = noise_shaping_;
*out_json = args;
return Status::OK();
}
} // namespace audio
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DITHER_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DITHER_IR_H_
#include <memory>
#include <string>
#include <vector>
#include "include/api/status.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
namespace audio {
constexpr char kDitherOperation[] = "Dither";
class DitherOperation : public TensorOperation {
public:
DitherOperation(DensityFunction density_function, bool noise_shaping);
~DitherOperation() = default;
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override { return kDitherOperation; }
Status to_json(nlohmann::json *out_json) override;
private:
DensityFunction density_function_;
bool noise_shaping_;
};
} // namespace audio
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DITHER_IR_H_

View File

@ -18,6 +18,7 @@ add_library(audio-kernels OBJECT
dc_shift_op.cc
deemph_biquad_op.cc
detect_pitch_frequency_op.cc
dither_op.cc
equalizer_biquad_op.cc
fade_op.cc
flanger_op.cc

View File

@ -38,7 +38,6 @@ constexpr int TWO = 2;
namespace mindspore {
namespace dataset {
/// \brief Turn a tensor from the power/amplitude scale to the decibel scale.
/// \param input/output: Tensor of shape <..., freq, time>.
/// \param multiplier: power - 10, amplitude - 20.
@ -110,6 +109,8 @@ Status Angle(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outp
return Status::OK();
}
Status Bartlett(std::shared_ptr<Tensor> *output, int len);
/// \brief Perform a biquad filter of input tensor.
/// \param input/output: Tensor of shape <..., time>.
/// \param a0: denominator coefficient of current output y[n], typically 1.
@ -1077,6 +1078,275 @@ Status SlidingWindowCmn(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
/// \return Status code.
Status ComputeDeltas(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t win_length,
const BorderType &mode);
template <typename T>
Status Mul(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, T value) {
RETURN_UNEXPECTED_IF_NULL(output);
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), input->type(), output));
auto iter_in = input->begin<T>();
auto iter_out = (*output)->begin<T>();
for (; iter_in != input->end<T>(); ++iter_in, ++iter_out) {
*iter_out = (*iter_in) * value;
}
return Status::OK();
}
template <typename T>
Status Div(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, T value) {
RETURN_UNEXPECTED_IF_NULL(output);
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), input->type(), output));
CHECK_FAIL_RETURN_UNEXPECTED(value != 0, "Div: invalid parameter, 'value' can not be zero.");
auto iter_in = input->begin<T>();
auto iter_out = (*output)->begin<T>();
for (; iter_in != input->end<T>(); ++iter_in, ++iter_out) {
*iter_out = (*iter_in) / value;
}
return Status::OK();
}
template <typename T>
Status Add(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, T value) {
RETURN_UNEXPECTED_IF_NULL(output);
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), input->type(), output));
auto iter_in = input->begin<T>();
auto iter_out = (*output)->begin<T>();
for (; iter_in != input->end<T>(); ++iter_in, ++iter_out) {
*iter_out = (*iter_in) + value;
}
return Status::OK();
}
template <typename T>
Status SubTensor(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int len) {
RETURN_UNEXPECTED_IF_NULL(output);
RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({len}), input->type(), output));
RETURN_IF_NOT_OK(
ValidateNoGreaterThan("SubTensor", "len", len, "size of input tensor", static_cast<int>(input->Size())));
auto iter_in = input->begin<T>();
auto iter_out = (*output)->begin<T>();
for (; iter_out != (*output)->end<T>(); ++iter_in, ++iter_out) {
*iter_out = *iter_in;
}
return Status::OK();
}
template <typename T>
Status TensorAdd(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &other,
std::shared_ptr<Tensor> *output) {
RETURN_UNEXPECTED_IF_NULL(output);
CHECK_FAIL_RETURN_UNEXPECTED(input->shape() == other->shape(), "TensorAdd: input tensor shape must be the same.");
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == other->type(), "TensorAdd: input tensor type must be the same.");
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), input->type(), output));
auto iter_in1 = input->begin<T>();
auto iter_in2 = other->begin<T>();
auto iter_out = (*output)->begin<T>();
for (; iter_out != (*output)->end<T>(); ++iter_in1, ++iter_in2, ++iter_out) {
*iter_out = (*iter_in1) + (*iter_in2);
}
return Status::OK();
}
template <typename T>
Status TensorSub(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &other,
std::shared_ptr<Tensor> *output) {
RETURN_UNEXPECTED_IF_NULL(output);
CHECK_FAIL_RETURN_UNEXPECTED(input->shape() == other->shape(), "TensorSub: input tensor shape must be the same.");
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == other->type(), "TensorSub: input tensor type must be the same.");
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), input->type(), output));
auto iter_in1 = input->begin<T>();
auto iter_in2 = other->begin<T>();
auto iter_out = (*output)->begin<T>();
for (; iter_out != (*output)->end<T>(); ++iter_in1, ++iter_in2, ++iter_out) {
*iter_out = (*iter_in1) - (*iter_in2);
}
return Status::OK();
}
template <typename T>
Status TensorCat(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &other,
std::shared_ptr<Tensor> *output) {
RETURN_UNEXPECTED_IF_NULL(output);
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == other->type(), "TensorCat: input tensor type must be the same.");
RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({input->shape()[-1] + other->shape()[-1]}), input->type(), output));
auto iter_in1 = input->begin<T>();
auto iter_in2 = other->begin<T>();
auto iter_out = (*output)->begin<T>();
for (; iter_in1 != input->end<T>(); ++iter_in1, ++iter_out) {
*iter_out = *iter_in1;
}
for (; iter_in2 != other->end<T>(); ++iter_in2, ++iter_out) {
*iter_out = *iter_in2;
}
return Status::OK();
}
template <typename T>
Status TensorRepeat(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int rank_repeat) {
RETURN_UNEXPECTED_IF_NULL(output);
RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({rank_repeat, (input->shape()[-1])}), input->type(), output));
auto iter_in = input->begin<T>();
auto iter_out = (*output)->begin<T>();
for (int i = 0; i < rank_repeat; i++) {
auto iter_in = input->begin<T>();
for (; iter_in != input->end<T>(); ++iter_in, ++iter_out) {
*iter_out = *iter_in;
}
}
return Status::OK();
}
template <typename T>
Status TensorRowReplace(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int row) {
RETURN_UNEXPECTED_IF_NULL(output);
auto iter_in = input->begin<T>();
auto iter_out = (*output)->begin<T>() + (*output)->shape()[-1] * row;
CHECK_FAIL_RETURN_UNEXPECTED(iter_out <= (*output)->end<T>(), "TensorRowReplace: pointer out of bounds");
CHECK_FAIL_RETURN_UNEXPECTED(input->Size() <= (*output)->shape()[-1], "TensorRowReplace: pointer out of bounds");
for (; iter_in != input->end<T>(); ++iter_in, ++iter_out) {
*iter_out = *iter_in;
}
return Status::OK();
}
template <typename T>
Status TensorRowAt(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int rank_index) {
RETURN_UNEXPECTED_IF_NULL(output);
RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({input->shape()[-1]}), input->type(), output));
auto iter_in = input->begin<T>() + input->shape()[-1] * rank_index;
auto iter_out = (*output)->begin<T>();
CHECK_FAIL_RETURN_UNEXPECTED(iter_in <= input->end<T>(), "TensorRowAt: pointer out of bounds");
for (; iter_out != (*output)->end<T>(); ++iter_in, ++iter_out) {
*iter_out = *iter_in;
}
return Status::OK();
}
template <typename T>
Status TensorRound(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
RETURN_UNEXPECTED_IF_NULL(output);
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), input->type(), output));
auto iter_in = input->begin<T>();
auto iter_out = (*output)->begin<T>();
for (; iter_in != input->end<T>(); ++iter_in, ++iter_out) {
*iter_out = round(*iter_in);
}
return Status::OK();
}
template <typename T>
Status ApplyProbabilityDistribution(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
DensityFunction density_function, std::mt19937 rnd) {
int channel_size = input->shape()[0] - 1;
int time_size = input->shape()[-1] - 1;
std::uniform_int_distribution<> dis_channel(0, channel_size);
int random_channel = channel_size > 0 ? dis_channel(rnd) : 0;
std::uniform_int_distribution<> dis_time(0, time_size);
int random_time = time_size > 0 ? dis_time(rnd) : 0;
int number_of_bits = 16;
int up_scaling = static_cast<int>(pow(2, number_of_bits - 1) - 2);
int down_scaling = static_cast<int>(pow(2, number_of_bits - 1));
std::shared_ptr<Tensor> signal_scaled;
RETURN_IF_NOT_OK(Mul<T>(input, &signal_scaled, up_scaling));
std::shared_ptr<Tensor> signal_scaled_dis;
RETURN_IF_NOT_OK(Tensor::CreateFromTensor(input, &signal_scaled_dis));
if (density_function == DensityFunction::kRPDF) {
auto iter_in = input->begin<T>();
iter_in += (time_size + 1) * random_channel + random_time;
auto RPDF = *(iter_in);
RETURN_IF_NOT_OK(Add<T>(signal_scaled, &signal_scaled_dis, RPDF));
} else if (density_function == DensityFunction::kGPDF) {
int num_rand_variables = 6;
auto iter_in = input->begin<T>();
iter_in += (time_size + 1) * random_channel + random_time;
auto gaussian = *(iter_in);
for (int i = 0; i < num_rand_variables; i++) {
int rand_channel = channel_size > 0 ? dis_channel(rnd) : 0;
int rand_time = time_size > 0 ? dis_time(rnd) : 0;
auto iter_in_rand = input->begin<T>();
iter_in_rand += (time_size + 1) * rand_channel + rand_time;
gaussian += *(iter_in_rand);
*(iter_in_rand) = gaussian;
}
RETURN_IF_NOT_OK(Add<T>(signal_scaled, &signal_scaled_dis, gaussian));
} else {
int window_length = time_size + 1;
std::shared_ptr<Tensor> float_bartlett;
RETURN_IF_NOT_OK(Bartlett(&float_bartlett, window_length));
std::shared_ptr<Tensor> type_convert_bartlett;
RETURN_IF_NOT_OK(TypeCast(float_bartlett, &type_convert_bartlett, input->type()));
int rank_repeat = channel_size + 1;
std::shared_ptr<Tensor> TPDF;
RETURN_IF_NOT_OK(TensorRepeat<T>(type_convert_bartlett, &TPDF, rank_repeat));
RETURN_IF_NOT_OK(TensorAdd<T>(signal_scaled, TPDF, &signal_scaled_dis));
}
std::shared_ptr<Tensor> quantised_signal_scaled;
RETURN_IF_NOT_OK(TensorRound<T>(signal_scaled_dis, &quantised_signal_scaled));
std::shared_ptr<Tensor> quantised_signal;
RETURN_IF_NOT_OK(Div<T>(quantised_signal_scaled, &quantised_signal, down_scaling));
*output = quantised_signal;
return Status::OK();
}
template <typename T>
Status AddNoiseShaping(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
std::shared_ptr<Tensor> dithered_waveform;
RETURN_IF_NOT_OK(Tensor::CreateFromTensor(*output, &dithered_waveform));
std::shared_ptr<Tensor> waveform;
RETURN_IF_NOT_OK(Tensor::CreateFromTensor(input, &waveform));
std::shared_ptr<Tensor> error;
RETURN_IF_NOT_OK(TensorSub<T>(dithered_waveform, waveform, &error));
for (int i = 0; i < error->shape()[0]; i++) {
std::shared_ptr<Tensor> err;
RETURN_IF_NOT_OK(TensorRowAt<T>(error, &err, i));
std::shared_ptr<Tensor> tensor_zero;
std::vector<T> vector_zero(1, 0);
RETURN_IF_NOT_OK(Tensor::CreateFromVector(vector_zero, TensorShape({1}), &tensor_zero));
std::shared_ptr<Tensor> error_offset;
RETURN_IF_NOT_OK(TensorCat<T>(tensor_zero, err, &error_offset));
int k = error->shape()[-1];
std::shared_ptr<Tensor> fresh_error_offset;
RETURN_IF_NOT_OK(SubTensor<T>(error_offset, &fresh_error_offset, k));
RETURN_IF_NOT_OK(TensorRowReplace<T>(fresh_error_offset, &error, i));
}
std::shared_ptr<Tensor> noise_shaped;
RETURN_IF_NOT_OK(TensorAdd<T>(dithered_waveform, error, &noise_shaped));
*output = noise_shaped;
return Status::OK();
}
/// \brief Apply dither effect.
/// \param input/output: Tensor of shape <..., time>.
/// \param density_function: The density function of a continuous random variable.
/// \param noise_shaing: A filtering process that shapes the spectral energy of quantisation error.
/// \return Status code.
template <typename T>
Status Dither(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, DensityFunction density_function,
bool noise_shaping, std::mt19937 rnd) {
TensorShape shape = input->shape();
TensorShape new_shape({input->Size() / shape[-1], shape[-1]});
RETURN_IF_NOT_OK(input->Reshape(new_shape));
RETURN_IF_NOT_OK(ApplyProbabilityDistribution<T>(input, output, density_function, rnd));
if (noise_shaping) {
RETURN_IF_NOT_OK(AddNoiseShaping<T>(input, output));
}
RETURN_IF_NOT_OK((*output)->Reshape(shape));
RETURN_IF_NOT_OK(input->Reshape(shape));
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_

View File

@ -0,0 +1,53 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/audio/kernels/dither_op.h"
#include "minddata/dataset/audio/kernels/audio_utils.h"
#include "minddata/dataset/kernels/data/data_utils.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
Status DitherOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
IO_CHECK(input, output);
// check input dimension, it should be greater than 0
RETURN_IF_NOT_OK(ValidateLowRank("Dither", input, kMinAudioDim, "<..., time>"));
// check input type, it should be [int, float, double]
RETURN_IF_NOT_OK(ValidateTensorNumeric("Dither", input));
if (input->type() == DataType(DataType::DE_FLOAT64)) {
return Dither<double>(input, output, density_function_, noise_shaping_, rnd_);
} else {
std::shared_ptr<Tensor> float_input;
RETURN_IF_NOT_OK(TypeCast(input, &float_input, DataType(DataType::DE_FLOAT32)));
return Dither<float>(float_input, output, density_function_, noise_shaping_, rnd_);
}
}
Status DitherOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
RETURN_IF_NOT_OK(ValidateTensorType("Dither", inputs[0].IsNumeric(), "[int, float, double]", inputs[0].ToString()));
if (inputs[0] == DataType(DataType::DE_FLOAT64)) {
outputs[0] = DataType(DataType::DE_FLOAT64);
} else {
outputs[0] = DataType(DataType::DE_FLOAT32);
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DITHER_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DITHER_OP_H_
#include <memory>
#include <string>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/kernels/tensor_op.h"
#include "minddata/dataset/util/random.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
class DitherOp : public TensorOp {
public:
DitherOp(DensityFunction density_function, bool noise_shaping)
: density_function_(density_function), noise_shaping_(noise_shaping) {
rnd_.seed(GetSeed());
}
~DitherOp() override = default;
void Print(std::ostream &out) const override {
out << Name() << " density_function: " << density_function_ << ", noise_shaping: " << noise_shaping_ << std::endl;
}
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
std::string Name() const override { return kDitherOp; }
private:
DensityFunction density_function_;
bool noise_shaping_;
std::mt19937 rnd_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DITHER_OP_H_

View File

@ -377,6 +377,32 @@ class MS_API DetectPitchFrequency final : public TensorTransform {
std::shared_ptr<Data> data_;
};
/// \brief Dither increases the perceived dynamic range of audio stored at a
/// particular bit-depth by eliminating nonlinear truncation distortion.
class MS_API Dither final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] density_function The density function of a continuous random variable.
/// Can be one of DensityFunction::kTPDF (Triangular Probability Density Function),
/// DensityFunction::kRPDF (Rectangular Probability Density Function) or
/// DensityFunction::kGPDF (Gaussian Probability Density Function) (Default: DensityFunction::kTPDF).
/// \param[in] noise_shaping A filtering process that shapes the spectral energy of
/// quantisation error (Default: false).
explicit Dither(DensityFunction density_function = DensityFunction::kTPDF, bool noise_shaping = false);
/// \brief Destructor.
~Dither() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief EqualizerBiquad TensorTransform. Apply highpass biquad filter on audio.
class MS_API EqualizerBiquad final : public TensorTransform {
public:

View File

@ -71,6 +71,13 @@ enum class MS_API ConvertMode {
COLOR_RGBA2GRAY = 11 ///< Convert RGBA image to GRAY image.
};
// \brief Possible density function in Dither.
enum MS_API DensityFunction {
kTPDF = 0, ///< Use triangular probability density function.
kRPDF = 1, ///< Use rectangular probability density function.
kGPDF = 2 ///< Use gaussian probability density function.
};
/// \brief Values of norm in CreateDct.
enum class MS_API NormMode {
kNone = 0, ///< None type norm.

View File

@ -161,6 +161,7 @@ constexpr char kDBToAmplitudeOp[] = " DBToAmplitudeOp";
constexpr char kDCShiftOp[] = "DCShiftOp";
constexpr char kDeemphBiquadOp[] = "DeemphBiquadOp";
constexpr char kDetectPitchFrequencyOp[] = "DetectPitchFrequencyOp";
constexpr char kDitherOp[] = "DitherOp";
constexpr char kEqualizerBiquadOp[] = "EqualizerBiquadOp";
constexpr char kFadeOp[] = "FadeOp";
constexpr char kFlangerOp[] = "FlangerOp";

View File

@ -23,14 +23,14 @@ import numpy as np
import mindspore._c_dataengine as cde
from ..transforms.c_transforms import TensorOperation
from .utils import BorderType, FadeShape, GainType, Interpolation, Modulation, ScaleType, WindowType
from .utils import BorderType, DensityFunction, FadeShape, GainType, Interpolation, Modulation, ScaleType, WindowType
from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_compute_deltas, \
check_contrast, check_db_to_amplitude, check_dc_shift, check_deemph_biquad, check_detect_pitch_frequency, \
check_equalizer_biquad, check_fade, check_flanger, check_gain, check_highpass_biquad, check_lfilter, \
check_lowpass_biquad, check_magphase, check_masking, check_mu_law_coding, check_overdrive, check_phaser, \
check_riaa_biquad, check_sliding_window_cmn, check_spectrogram, check_time_stretch, check_treble_biquad, \
check_vol
check_dither, check_equalizer_biquad, check_fade, check_flanger, check_gain, check_highpass_biquad, \
check_lfilter, check_lowpass_biquad, check_magphase, check_masking, check_mu_law_coding, check_overdrive, \
check_phaser, check_riaa_biquad, check_sliding_window_cmn, check_spectrogram, check_time_stretch, \
check_treble_biquad, check_vol
@ -495,6 +495,43 @@ class DetectPitchFrequency(AudioTensorOperation):
self.win_length, self.freq_low, self.freq_high)
DE_C_DENSITYFUNCTION_TYPE = {DensityFunction.TPDF: cde.DensityFunction.DE_DENSITYFUNCTION_TPDF,
DensityFunction.RPDF: cde.DensityFunction.DE_DENSITYFUNCTION_RPDF,
DensityFunction.GPDF: cde.DensityFunction.DE_DENSITYFUNCTION_GPDF}
class Dither(AudioTensorOperation):
"""
Dither increases the perceived dynamic range of audio stored at a
particular bit-depth by eliminating nonlinear truncation distortion.
Args:
density_function (DensityFunction, optional): The density function of a continuous
random variable. Can be one of DensityFunction.TPDF (Triangular Probability Density Function),
DensityFunction.RPDF (Rectangular Probability Density Function) or
DensityFunction.GPDF (Gaussian Probability Density Function)
(default=DensityFunction.TPDF).
noise_shaping (bool, optional): A filtering process that shapes the spectral
energy of quantisation error (default=False).
Examples:
>>> import numpy as np
>>>
>>> waveform = np.array([[1, 2, 3], [4, 5, 6]])
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
>>> transforms = [audio.Dither()]
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
"""
@check_dither
def __init__(self, density_function=DensityFunction.TPDF, noise_shaping=False):
self.density_function = density_function
self.noise_shaping = noise_shaping
def parse(self):
return cde.DitherOperation(DE_C_DENSITYFUNCTION_TYPE[self.density_function], self.noise_shaping)
class EqualizerBiquad(AudioTensorOperation):
"""
Design biquad equalizer filter and perform filtering. Similar to SoX implementation.

View File

@ -20,6 +20,22 @@ from enum import Enum
import mindspore._c_dataengine as cde
class DensityFunction(str, Enum):
"""
Density Functions.
Possible enumeration values are: DensityFunction.TPDF, DensityFunction.GPDF,
DensityFunction.RPDF.
- DensityFunction.TPDF: means triangular probability density function.
- DensityFunction.GPDF: means gaussian probability density function.
- DensityFunction.RPDF: means rectangular probability density function.
"""
TPDF: str = "TPDF"
RPDF: str = "RPDF"
GPDF: str = "GPDF"
class FadeShape(str, Enum):
"""
Fade Shapes.

View File

@ -21,7 +21,7 @@ from functools import wraps
from mindspore.dataset.core.validator_helpers import check_float32, check_float32_not_zero, check_int32, \
check_int32_not_zero, check_list_same_size, check_non_negative_float32, check_non_negative_int32, \
check_pos_float32, check_pos_int32, check_value, INT32_MAX, parse_user_args, type_check
from .utils import BorderType, FadeShape, GainType, Interpolation, Modulation, ScaleType, WindowType
from .utils import BorderType, DensityFunction, FadeShape, GainType, Interpolation, Modulation, ScaleType, WindowType
def check_amplitude_to_db(method):
@ -239,6 +239,22 @@ def check_deemph_biquad(method):
return new_method
def check_dither(method):
"""Wrapper method to check the parameters of Dither."""
@wraps(method)
def new_method(self, *args, **kwargs):
[density_function, noise_shaping], _ = parse_user_args(
method, *args, **kwargs)
type_check(density_function, (DensityFunction), "density_function")
type_check(noise_shaping, (bool,), "noise_shaping")
return method(self, *args, **kwargs)
return new_method
def check_equalizer_biquad(method):
"""Wrapper method to check the parameters of EqualizerBiquad."""

View File

@ -841,6 +841,46 @@ TEST_F(MindDataTestPipeline, TestDeemphBiquadWrongArgs) {
EXPECT_EQ(iter01, nullptr);
}
/// Feature: Dither
/// Description: test basic usage of Dither in pipeline mode
/// Expectation: the data is processed successfully
TEST_F(MindDataTestPipeline, TestDitherBasic) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDitherBasic.";
// Original waveform
std::shared_ptr<SchemaObj> schema = Schema();
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 200}));
std::shared_ptr<Dataset> ds = RandomData(50, schema);
EXPECT_NE(ds, nullptr);
ds = ds->SetNumWorkers(2);
EXPECT_NE(ds, nullptr);
auto DitherOp = audio::Dither();
ds = ds->Map({DitherOp});
EXPECT_NE(ds, nullptr);
// Filtered waveform by Dither
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(ds, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
std::vector<int64_t> expected = {2, 200};
int i = 0;
while (row.size() != 0) {
auto col = row["waveform"];
ASSERT_EQ(col.Shape(), expected);
ASSERT_EQ(col.Shape().size(), 2);
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 50);
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestHighpassBiquadSuccess) {
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestHighpassBiquadSuccess.";

View File

@ -1526,6 +1526,28 @@ TEST_F(MindDataTestExecute, TestDetectPitchFrequencyWithWrongArg) {
EXPECT_FALSE(s05.IsOk());
}
/// Feature: Dither
/// Description: test Dither in eager mode
/// Expectation: the data is processed successfully
TEST_F(MindDataTestExecute, TestDitherWithEager) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestDitherWithEager.";
// Original waveform
std::vector<float> labels = {
2.716064453125000000e-03, 6.347656250000000000e-03, 9.246826171875000000e-03, 1.089477539062500000e-02,
1.138305664062500000e-02, 1.156616210937500000e-02, 1.394653320312500000e-02, 1.550292968750000000e-02,
1.614379882812500000e-02, 1.840209960937500000e-02, 1.718139648437500000e-02, 1.599121093750000000e-02,
1.647949218750000000e-02, 1.510620117187500000e-02, 1.385498046875000000e-02, 1.345825195312500000e-02,
1.419067382812500000e-02, 1.284790039062500000e-02, 1.052856445312500000e-02, 9.368896484375000000e-03};
std::shared_ptr<Tensor> input;
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 10}), &input));
auto input_02 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
std::shared_ptr<TensorTransform> dither_01 = std::make_shared<audio::Dither>();
mindspore::dataset::Execute Transform01({dither_01});
// Filtered waveform by Dither
Status s01 = Transform01(input_02, &input_02);
EXPECT_TRUE(s01.IsOk());
}
TEST_F(MindDataTestExecute, TestFlangerWithEager) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestFlangerWithEager.";
// Original waveform

View File

@ -0,0 +1,168 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.audio.transforms as audio
from mindspore import log as logger
from mindspore.dataset.audio.utils import DensityFunction
from util import visualize_audio, diff_mse
def count_unequal_element(data_expected, data_me, rtol, atol):
assert data_expected.shape == data_me.shape
total_count = len(data_expected.flatten())
error = np.abs(data_expected - data_me)
greater = np.greater(error, atol + np.abs(data_expected) * rtol)
loss_count = np.count_nonzero(greater)
assert (loss_count / total_count) < rtol, \
"\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
format(data_expected[greater], data_me[greater], error[greater])
def test_dither_eager_noise_shaping_false():
"""
Feature: Dither
Description: test Dither in eager mode
Expectation: the result is as expected
"""
logger.info("test Dither in eager mode")
# Original waveform
waveform = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64)
# Expect waveform
expect_waveform = np.array([[0.99993896, 1.99990845, 2.99984741],
[3.99975586, 4.99972534, 5.99966431]], dtype=np.float64)
dither_op = audio.Dither(DensityFunction.TPDF, False)
# Filtered waveform by Dither
output = dither_op(waveform)
count_unequal_element(expect_waveform, output, 0.0001, 0.0001)
def test_dither_eager_noise_shaping_true():
"""
Feature: Dither
Description: test Dither in eager mode
Expectation: the result is as expected
"""
logger.info("test Dither in eager mode")
# Original waveform
waveform = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float64)
# Expect waveform
expect_waveform = np.array([[0.9999, 1.9998, 2.9998],
[3.9998, 4.9995, 5.9994],
[6.9996, 7.9991, 8.9990]], dtype=np.float64)
dither_op = audio.Dither(DensityFunction.TPDF, True)
# Filtered waveform by Dither
output = dither_op(waveform)
count_unequal_element(expect_waveform, output, 0.0001, 0.0001)
def test_dither_pipeline(plot=False):
"""
Feature: Dither
Description: test Dither in pipeline mode
Expectation: the result is as expected
"""
logger.info("test Dither in pipeline mode")
# Original waveform
waveform_tpdf = np.array([[0.4941969, 0.53911686, 0.4846254], [0.10841596, 0.029320478, 0.52353495],
[0.23657, 0.087965, 0.43579]], dtype=np.float64)
waveform_rpdf = np.array([[0.4941969, 0.53911686, 0.4846254], [0.10841596, 0.029320478, 0.52353495],
[0.23657, 0.087965, 0.43579]], dtype=np.float64)
waveform_gpdf = np.array([[0.4941969, 0.53911686, 0.4846254], [0.10841596, 0.029320478, 0.52353495],
[0.23657, 0.087965, 0.43579]], dtype=np.float64)
# Expect waveform
expect_tpdf = np.array([[0.49417114, 0.53909302, 0.48461914],
[0.10839844, 0.02932739, 0.52352905],
[0.23654175, 0.08798218, 0.43579102]], dtype=np.float64)
expect_rpdf = np.array([[0.4941, 0.5391, 0.4846],
[0.1084, 0.0293, 0.5235],
[0.2365, 0.0880, 0.4358]], dtype=np.float64)
expect_gpdf = np.array([[0.4944, 0.5393, 0.4848],
[0.1086, 0.0295, 0.5237],
[0.2368, 0.0882, 0.4360]], dtype=np.float64)
dataset_tpdf = ds.NumpySlicesDataset(waveform_tpdf, ["audio"], shuffle=False)
dataset_rpdf = ds.NumpySlicesDataset(waveform_rpdf, ["audio"], shuffle=False)
dataset_gpdf = ds.NumpySlicesDataset(waveform_gpdf, ["audio"], shuffle=False)
# Filtered waveform by Dither of TPDF
dither_tpdf = audio.Dither()
dataset_tpdf = dataset_tpdf.map(input_columns=["audio"], operations=dither_tpdf, num_parallel_workers=2)
# Filtered waveform by Dither of RPDF
dither_rpdf = audio.Dither(DensityFunction.RPDF, False)
dataset_rpdf = dataset_rpdf.map(input_columns=["audio"], operations=dither_rpdf, num_parallel_workers=2)
# Filtered waveform by Dither of GPDF
dither_gpdf = audio.Dither(DensityFunction.GPDF, False)
dataset_gpdf = dataset_gpdf.map(input_columns=["audio"], operations=dither_gpdf, num_parallel_workers=2)
i = 0
for data1, data2, data3 in zip(dataset_tpdf.create_dict_iterator(output_numpy=True),
dataset_rpdf.create_dict_iterator(output_numpy=True),
dataset_gpdf.create_dict_iterator(output_numpy=True)):
count_unequal_element(expect_tpdf[i, :], data1['audio'], 0.0001, 0.0001)
dither_rpdf = data2['audio']
dither_gpdf = data3['audio']
mse_rpdf = diff_mse(dither_rpdf, expect_rpdf[i, :])
logger.info("dither_rpdf_{}, mse: {}".format(i + 1, mse_rpdf))
mse_gpdf = diff_mse(dither_gpdf, expect_gpdf[i, :])
logger.info("dither_gpdf_{}, mse: {}".format(i + 1, mse_gpdf))
i += 1
if plot:
visualize_audio(dither_rpdf, expect_rpdf[i, :])
visualize_audio(dither_gpdf, expect_gpdf[i, :])
def test_invalid_dither_input():
"""
Feature: Dither
Description: test param check of Dither
Expectation: throw correct error and message
"""
logger.info("test param check of Dither")
def test_invalid_input(test_name, density_function, noise_shaping, error, error_msg):
logger.info("Test Dither with bad input: {0}".format(test_name))
with pytest.raises(error) as error_info:
audio.Dither(density_function, noise_shaping)
assert error_msg in str(error_info.value)
test_invalid_input("invalid density function parameter value", "TPDF", False, TypeError,
"Argument density_function with value TPDF is not of type"
+ " [<DensityFunction.TPDF: 'TPDF'>, <DensityFunction.RPDF: 'RPDF'>"
+ ", <DensityFunction.GPDF: 'GPDF'>], but got <class 'str'>.")
test_invalid_input("invalid density_function parameter value", 6, False, TypeError,
"Argument density_function with value 6 is not of type"
+ " [<DensityFunction.TPDF: 'TPDF'>, <DensityFunction.RPDF: 'RPDF'>"
+ ", <DensityFunction.GPDF: 'GPDF'>], but got <class 'int'>.")
test_invalid_input("invalid noise_shaping parameter value", DensityFunction.GPDF, 1, TypeError,
"Argument noise_shaping with value 1 is not of type [<class 'bool'>], but got <class 'int'>.")
test_invalid_input("invalid noise_shaping parameter value", DensityFunction.RPDF, "true", TypeError,
"Argument noise_shaping with value true is not of type [<class 'bool'>], but got <class 'str'>")
if __name__ == '__main__':
test_dither_eager_noise_shaping_false()
test_dither_eager_noise_shaping_true()
test_dither_pipeline(plot=False)
test_invalid_dither_input()

View File

@ -207,6 +207,26 @@ def diff_me(in1, in2):
return mse / 255 * 100
def visualize_audio(waveform, expect_waveform):
"""
Visualizes audio waveform.
"""
plt.figure(1)
plt.subplot(1, 3, 1)
plt.imshow(waveform)
plt.title("waveform")
plt.subplot(1, 3, 2)
plt.imshow(expect_waveform)
plt.title("expect waveform")
plt.subplot(1, 3, 3)
plt.imshow(waveform - expect_waveform)
plt.title("difference")
plt.show()
def visualize_one_channel_dataset(images_original, images_transformed, labels):
"""
Helper function to visualize one channel grayscale images