!22686 [assistant][ops][Dither]
Merge pull request !22686 from Isaac/Dither
This commit is contained in:
commit
7331629ceb
|
@ -31,6 +31,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/detect_pitch_frequency_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/dither_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/fade_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/flanger_ir.h"
|
||||
|
@ -292,6 +293,21 @@ std::shared_ptr<TensorOperation> DetectPitchFrequency::Parse() {
|
|||
data_->freq_low_, data_->freq_high_);
|
||||
}
|
||||
|
||||
// Dither Transform Operation.
|
||||
struct Dither::Data {
|
||||
Data(DensityFunction density_function, bool noise_shaping)
|
||||
: density_function_(density_function), noise_shaping_(noise_shaping) {}
|
||||
DensityFunction density_function_;
|
||||
bool noise_shaping_;
|
||||
};
|
||||
|
||||
Dither::Dither(DensityFunction density_function, bool noise_shaping)
|
||||
: data_(std::make_shared<Data>(density_function, noise_shaping)) {}
|
||||
|
||||
std::shared_ptr<TensorOperation> Dither::Parse() {
|
||||
return std::make_shared<DitherOperation>(data_->density_function_, data_->noise_shaping_);
|
||||
}
|
||||
|
||||
// EqualizerBiquad Transform Operation.
|
||||
struct EqualizerBiquad::Data {
|
||||
Data(int32_t sample_rate, float center_freq, float gain, float Q)
|
||||
|
|
|
@ -35,6 +35,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/detect_pitch_frequency_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/dither_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/equalizer_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/fade_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/flanger_ir.h"
|
||||
|
@ -231,6 +232,24 @@ PYBIND_REGISTER(DetectPitchFrequencyOperation, 1, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(DensityFunction, 0, ([](const py::module *m) {
|
||||
(void)py::enum_<DensityFunction>(*m, "DensityFunction", py::arithmetic())
|
||||
.value("DE_DENSITYFUNCTION_TPDF", DensityFunction::kTPDF)
|
||||
.value("DE_DENSITYFUNCTION_RPDF", DensityFunction::kRPDF)
|
||||
.value("DE_DENSITYFUNCTION_GPDF", DensityFunction::kGPDF)
|
||||
.export_values();
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(DitherOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<audio::DitherOperation, TensorOperation, std::shared_ptr<audio::DitherOperation>>(
|
||||
*m, "DitherOperation")
|
||||
.def(py::init([](DensityFunction density_function, bool noise_shaping) {
|
||||
auto dither = std::make_shared<audio::DitherOperation>(density_function, noise_shaping);
|
||||
THROW_IF_ERROR(dither->ValidateParams());
|
||||
return dither;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(EqualizerBiquadOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<audio::EqualizerBiquadOperation, TensorOperation,
|
||||
std::shared_ptr<audio::EqualizerBiquadOperation>>(*m, "EqualizerBiquadOperation")
|
||||
|
|
|
@ -17,6 +17,7 @@ add_library(audio-ir-kernels OBJECT
|
|||
dc_shift_ir.cc
|
||||
deemph_biquad_ir.cc
|
||||
detect_pitch_frequency_ir.cc
|
||||
dither_ir.cc
|
||||
equalizer_biquad_ir.cc
|
||||
fade_ir.cc
|
||||
flanger_ir.cc
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/audio/ir/kernels/dither_ir.h"
|
||||
|
||||
#include "minddata/dataset/audio/kernels/dither_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
// DitherOperation
|
||||
DitherOperation::DitherOperation(DensityFunction density_function, bool noise_shaping)
|
||||
: density_function_(density_function), noise_shaping_(noise_shaping) {
|
||||
random_op_ = true;
|
||||
}
|
||||
|
||||
Status DitherOperation::ValidateParams() { return Status::OK(); }
|
||||
|
||||
std::shared_ptr<TensorOp> DitherOperation::Build() {
|
||||
std::shared_ptr<DitherOp> tensor_op = std::make_shared<DitherOp>(density_function_, noise_shaping_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
Status DitherOperation::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["density_function"] = density_function_;
|
||||
args["noise_shaping"] = noise_shaping_;
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DITHER_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DITHER_IR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/include/dataset/transforms.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
constexpr char kDitherOperation[] = "Dither";
|
||||
|
||||
class DitherOperation : public TensorOperation {
|
||||
public:
|
||||
DitherOperation(DensityFunction density_function, bool noise_shaping);
|
||||
|
||||
~DitherOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
std::string Name() const override { return kDitherOperation; }
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
DensityFunction density_function_;
|
||||
bool noise_shaping_;
|
||||
};
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DITHER_IR_H_
|
|
@ -18,6 +18,7 @@ add_library(audio-kernels OBJECT
|
|||
dc_shift_op.cc
|
||||
deemph_biquad_op.cc
|
||||
detect_pitch_frequency_op.cc
|
||||
dither_op.cc
|
||||
equalizer_biquad_op.cc
|
||||
fade_op.cc
|
||||
flanger_op.cc
|
||||
|
|
|
@ -38,7 +38,6 @@ constexpr int TWO = 2;
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
/// \brief Turn a tensor from the power/amplitude scale to the decibel scale.
|
||||
/// \param input/output: Tensor of shape <..., freq, time>.
|
||||
/// \param multiplier: power - 10, amplitude - 20.
|
||||
|
@ -110,6 +109,8 @@ Status Angle(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outp
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Bartlett(std::shared_ptr<Tensor> *output, int len);
|
||||
|
||||
/// \brief Perform a biquad filter of input tensor.
|
||||
/// \param input/output: Tensor of shape <..., time>.
|
||||
/// \param a0: denominator coefficient of current output y[n], typically 1.
|
||||
|
@ -1077,6 +1078,275 @@ Status SlidingWindowCmn(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
|
|||
/// \return Status code.
|
||||
Status ComputeDeltas(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t win_length,
|
||||
const BorderType &mode);
|
||||
|
||||
template <typename T>
|
||||
Status Mul(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, T value) {
|
||||
RETURN_UNEXPECTED_IF_NULL(output);
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), input->type(), output));
|
||||
auto iter_in = input->begin<T>();
|
||||
auto iter_out = (*output)->begin<T>();
|
||||
for (; iter_in != input->end<T>(); ++iter_in, ++iter_out) {
|
||||
*iter_out = (*iter_in) * value;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status Div(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, T value) {
|
||||
RETURN_UNEXPECTED_IF_NULL(output);
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), input->type(), output));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(value != 0, "Div: invalid parameter, 'value' can not be zero.");
|
||||
auto iter_in = input->begin<T>();
|
||||
auto iter_out = (*output)->begin<T>();
|
||||
for (; iter_in != input->end<T>(); ++iter_in, ++iter_out) {
|
||||
*iter_out = (*iter_in) / value;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status Add(const std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, T value) {
|
||||
RETURN_UNEXPECTED_IF_NULL(output);
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), input->type(), output));
|
||||
auto iter_in = input->begin<T>();
|
||||
auto iter_out = (*output)->begin<T>();
|
||||
for (; iter_in != input->end<T>(); ++iter_in, ++iter_out) {
|
||||
*iter_out = (*iter_in) + value;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status SubTensor(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int len) {
|
||||
RETURN_UNEXPECTED_IF_NULL(output);
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({len}), input->type(), output));
|
||||
RETURN_IF_NOT_OK(
|
||||
ValidateNoGreaterThan("SubTensor", "len", len, "size of input tensor", static_cast<int>(input->Size())));
|
||||
auto iter_in = input->begin<T>();
|
||||
auto iter_out = (*output)->begin<T>();
|
||||
for (; iter_out != (*output)->end<T>(); ++iter_in, ++iter_out) {
|
||||
*iter_out = *iter_in;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status TensorAdd(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &other,
|
||||
std::shared_ptr<Tensor> *output) {
|
||||
RETURN_UNEXPECTED_IF_NULL(output);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input->shape() == other->shape(), "TensorAdd: input tensor shape must be the same.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == other->type(), "TensorAdd: input tensor type must be the same.");
|
||||
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), input->type(), output));
|
||||
auto iter_in1 = input->begin<T>();
|
||||
auto iter_in2 = other->begin<T>();
|
||||
auto iter_out = (*output)->begin<T>();
|
||||
for (; iter_out != (*output)->end<T>(); ++iter_in1, ++iter_in2, ++iter_out) {
|
||||
*iter_out = (*iter_in1) + (*iter_in2);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status TensorSub(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &other,
|
||||
std::shared_ptr<Tensor> *output) {
|
||||
RETURN_UNEXPECTED_IF_NULL(output);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input->shape() == other->shape(), "TensorSub: input tensor shape must be the same.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == other->type(), "TensorSub: input tensor type must be the same.");
|
||||
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), input->type(), output));
|
||||
auto iter_in1 = input->begin<T>();
|
||||
auto iter_in2 = other->begin<T>();
|
||||
auto iter_out = (*output)->begin<T>();
|
||||
for (; iter_out != (*output)->end<T>(); ++iter_in1, ++iter_in2, ++iter_out) {
|
||||
*iter_out = (*iter_in1) - (*iter_in2);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status TensorCat(const std::shared_ptr<Tensor> &input, const std::shared_ptr<Tensor> &other,
|
||||
std::shared_ptr<Tensor> *output) {
|
||||
RETURN_UNEXPECTED_IF_NULL(output);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == other->type(), "TensorCat: input tensor type must be the same.");
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({input->shape()[-1] + other->shape()[-1]}), input->type(), output));
|
||||
auto iter_in1 = input->begin<T>();
|
||||
auto iter_in2 = other->begin<T>();
|
||||
auto iter_out = (*output)->begin<T>();
|
||||
for (; iter_in1 != input->end<T>(); ++iter_in1, ++iter_out) {
|
||||
*iter_out = *iter_in1;
|
||||
}
|
||||
for (; iter_in2 != other->end<T>(); ++iter_in2, ++iter_out) {
|
||||
*iter_out = *iter_in2;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status TensorRepeat(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int rank_repeat) {
|
||||
RETURN_UNEXPECTED_IF_NULL(output);
|
||||
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({rank_repeat, (input->shape()[-1])}), input->type(), output));
|
||||
auto iter_in = input->begin<T>();
|
||||
auto iter_out = (*output)->begin<T>();
|
||||
for (int i = 0; i < rank_repeat; i++) {
|
||||
auto iter_in = input->begin<T>();
|
||||
for (; iter_in != input->end<T>(); ++iter_in, ++iter_out) {
|
||||
*iter_out = *iter_in;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status TensorRowReplace(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int row) {
|
||||
RETURN_UNEXPECTED_IF_NULL(output);
|
||||
auto iter_in = input->begin<T>();
|
||||
auto iter_out = (*output)->begin<T>() + (*output)->shape()[-1] * row;
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(iter_out <= (*output)->end<T>(), "TensorRowReplace: pointer out of bounds");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input->Size() <= (*output)->shape()[-1], "TensorRowReplace: pointer out of bounds");
|
||||
for (; iter_in != input->end<T>(); ++iter_in, ++iter_out) {
|
||||
*iter_out = *iter_in;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status TensorRowAt(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int rank_index) {
|
||||
RETURN_UNEXPECTED_IF_NULL(output);
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({input->shape()[-1]}), input->type(), output));
|
||||
auto iter_in = input->begin<T>() + input->shape()[-1] * rank_index;
|
||||
auto iter_out = (*output)->begin<T>();
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(iter_in <= input->end<T>(), "TensorRowAt: pointer out of bounds");
|
||||
for (; iter_out != (*output)->end<T>(); ++iter_in, ++iter_out) {
|
||||
*iter_out = *iter_in;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status TensorRound(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
RETURN_UNEXPECTED_IF_NULL(output);
|
||||
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), input->type(), output));
|
||||
auto iter_in = input->begin<T>();
|
||||
auto iter_out = (*output)->begin<T>();
|
||||
for (; iter_in != input->end<T>(); ++iter_in, ++iter_out) {
|
||||
*iter_out = round(*iter_in);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ApplyProbabilityDistribution(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
|
||||
DensityFunction density_function, std::mt19937 rnd) {
|
||||
int channel_size = input->shape()[0] - 1;
|
||||
int time_size = input->shape()[-1] - 1;
|
||||
std::uniform_int_distribution<> dis_channel(0, channel_size);
|
||||
int random_channel = channel_size > 0 ? dis_channel(rnd) : 0;
|
||||
std::uniform_int_distribution<> dis_time(0, time_size);
|
||||
int random_time = time_size > 0 ? dis_time(rnd) : 0;
|
||||
int number_of_bits = 16;
|
||||
int up_scaling = static_cast<int>(pow(2, number_of_bits - 1) - 2);
|
||||
int down_scaling = static_cast<int>(pow(2, number_of_bits - 1));
|
||||
|
||||
std::shared_ptr<Tensor> signal_scaled;
|
||||
RETURN_IF_NOT_OK(Mul<T>(input, &signal_scaled, up_scaling));
|
||||
|
||||
std::shared_ptr<Tensor> signal_scaled_dis;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromTensor(input, &signal_scaled_dis));
|
||||
|
||||
if (density_function == DensityFunction::kRPDF) {
|
||||
auto iter_in = input->begin<T>();
|
||||
iter_in += (time_size + 1) * random_channel + random_time;
|
||||
auto RPDF = *(iter_in);
|
||||
RETURN_IF_NOT_OK(Add<T>(signal_scaled, &signal_scaled_dis, RPDF));
|
||||
} else if (density_function == DensityFunction::kGPDF) {
|
||||
int num_rand_variables = 6;
|
||||
auto iter_in = input->begin<T>();
|
||||
iter_in += (time_size + 1) * random_channel + random_time;
|
||||
auto gaussian = *(iter_in);
|
||||
for (int i = 0; i < num_rand_variables; i++) {
|
||||
int rand_channel = channel_size > 0 ? dis_channel(rnd) : 0;
|
||||
int rand_time = time_size > 0 ? dis_time(rnd) : 0;
|
||||
|
||||
auto iter_in_rand = input->begin<T>();
|
||||
iter_in_rand += (time_size + 1) * rand_channel + rand_time;
|
||||
gaussian += *(iter_in_rand);
|
||||
*(iter_in_rand) = gaussian;
|
||||
}
|
||||
RETURN_IF_NOT_OK(Add<T>(signal_scaled, &signal_scaled_dis, gaussian));
|
||||
} else {
|
||||
int window_length = time_size + 1;
|
||||
std::shared_ptr<Tensor> float_bartlett;
|
||||
RETURN_IF_NOT_OK(Bartlett(&float_bartlett, window_length));
|
||||
std::shared_ptr<Tensor> type_convert_bartlett;
|
||||
RETURN_IF_NOT_OK(TypeCast(float_bartlett, &type_convert_bartlett, input->type()));
|
||||
|
||||
int rank_repeat = channel_size + 1;
|
||||
std::shared_ptr<Tensor> TPDF;
|
||||
RETURN_IF_NOT_OK(TensorRepeat<T>(type_convert_bartlett, &TPDF, rank_repeat));
|
||||
RETURN_IF_NOT_OK(TensorAdd<T>(signal_scaled, TPDF, &signal_scaled_dis));
|
||||
}
|
||||
std::shared_ptr<Tensor> quantised_signal_scaled;
|
||||
RETURN_IF_NOT_OK(TensorRound<T>(signal_scaled_dis, &quantised_signal_scaled));
|
||||
|
||||
std::shared_ptr<Tensor> quantised_signal;
|
||||
RETURN_IF_NOT_OK(Div<T>(quantised_signal_scaled, &quantised_signal, down_scaling));
|
||||
*output = quantised_signal;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status AddNoiseShaping(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
std::shared_ptr<Tensor> dithered_waveform;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromTensor(*output, &dithered_waveform));
|
||||
std::shared_ptr<Tensor> waveform;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromTensor(input, &waveform));
|
||||
|
||||
std::shared_ptr<Tensor> error;
|
||||
RETURN_IF_NOT_OK(TensorSub<T>(dithered_waveform, waveform, &error));
|
||||
for (int i = 0; i < error->shape()[0]; i++) {
|
||||
std::shared_ptr<Tensor> err;
|
||||
RETURN_IF_NOT_OK(TensorRowAt<T>(error, &err, i));
|
||||
std::shared_ptr<Tensor> tensor_zero;
|
||||
std::vector<T> vector_zero(1, 0);
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromVector(vector_zero, TensorShape({1}), &tensor_zero));
|
||||
std::shared_ptr<Tensor> error_offset;
|
||||
RETURN_IF_NOT_OK(TensorCat<T>(tensor_zero, err, &error_offset));
|
||||
int k = error->shape()[-1];
|
||||
std::shared_ptr<Tensor> fresh_error_offset;
|
||||
RETURN_IF_NOT_OK(SubTensor<T>(error_offset, &fresh_error_offset, k));
|
||||
RETURN_IF_NOT_OK(TensorRowReplace<T>(fresh_error_offset, &error, i));
|
||||
}
|
||||
std::shared_ptr<Tensor> noise_shaped;
|
||||
RETURN_IF_NOT_OK(TensorAdd<T>(dithered_waveform, error, &noise_shaped));
|
||||
*output = noise_shaped;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Apply dither effect.
|
||||
/// \param input/output: Tensor of shape <..., time>.
|
||||
/// \param density_function: The density function of a continuous random variable.
|
||||
/// \param noise_shaing: A filtering process that shapes the spectral energy of quantisation error.
|
||||
/// \return Status code.
|
||||
template <typename T>
|
||||
Status Dither(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, DensityFunction density_function,
|
||||
bool noise_shaping, std::mt19937 rnd) {
|
||||
TensorShape shape = input->shape();
|
||||
TensorShape new_shape({input->Size() / shape[-1], shape[-1]});
|
||||
RETURN_IF_NOT_OK(input->Reshape(new_shape));
|
||||
|
||||
RETURN_IF_NOT_OK(ApplyProbabilityDistribution<T>(input, output, density_function, rnd));
|
||||
if (noise_shaping) {
|
||||
RETURN_IF_NOT_OK(AddNoiseShaping<T>(input, output));
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK((*output)->Reshape(shape));
|
||||
RETURN_IF_NOT_OK(input->Reshape(shape));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/audio/kernels/dither_op.h"
|
||||
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
#include "minddata/dataset/kernels/data/data_utils.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
Status DitherOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
// check input dimension, it should be greater than 0
|
||||
RETURN_IF_NOT_OK(ValidateLowRank("Dither", input, kMinAudioDim, "<..., time>"));
|
||||
|
||||
// check input type, it should be [int, float, double]
|
||||
RETURN_IF_NOT_OK(ValidateTensorNumeric("Dither", input));
|
||||
|
||||
if (input->type() == DataType(DataType::DE_FLOAT64)) {
|
||||
return Dither<double>(input, output, density_function_, noise_shaping_, rnd_);
|
||||
} else {
|
||||
std::shared_ptr<Tensor> float_input;
|
||||
RETURN_IF_NOT_OK(TypeCast(input, &float_input, DataType(DataType::DE_FLOAT32)));
|
||||
return Dither<float>(float_input, output, density_function_, noise_shaping_, rnd_);
|
||||
}
|
||||
}
|
||||
|
||||
Status DitherOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
|
||||
RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
|
||||
RETURN_IF_NOT_OK(ValidateTensorType("Dither", inputs[0].IsNumeric(), "[int, float, double]", inputs[0].ToString()));
|
||||
if (inputs[0] == DataType(DataType::DE_FLOAT64)) {
|
||||
outputs[0] = DataType(DataType::DE_FLOAT64);
|
||||
} else {
|
||||
outputs[0] = DataType(DataType::DE_FLOAT32);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DITHER_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DITHER_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class DitherOp : public TensorOp {
|
||||
public:
|
||||
DitherOp(DensityFunction density_function, bool noise_shaping)
|
||||
: density_function_(density_function), noise_shaping_(noise_shaping) {
|
||||
rnd_.seed(GetSeed());
|
||||
}
|
||||
|
||||
~DitherOp() override = default;
|
||||
|
||||
void Print(std::ostream &out) const override {
|
||||
out << Name() << " density_function: " << density_function_ << ", noise_shaping: " << noise_shaping_ << std::endl;
|
||||
}
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
|
||||
|
||||
std::string Name() const override { return kDitherOp; }
|
||||
|
||||
private:
|
||||
DensityFunction density_function_;
|
||||
bool noise_shaping_;
|
||||
std::mt19937 rnd_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DITHER_OP_H_
|
|
@ -377,6 +377,32 @@ class MS_API DetectPitchFrequency final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Dither increases the perceived dynamic range of audio stored at a
|
||||
/// particular bit-depth by eliminating nonlinear truncation distortion.
|
||||
class MS_API Dither final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] density_function The density function of a continuous random variable.
|
||||
/// Can be one of DensityFunction::kTPDF (Triangular Probability Density Function),
|
||||
/// DensityFunction::kRPDF (Rectangular Probability Density Function) or
|
||||
/// DensityFunction::kGPDF (Gaussian Probability Density Function) (Default: DensityFunction::kTPDF).
|
||||
/// \param[in] noise_shaping A filtering process that shapes the spectral energy of
|
||||
/// quantisation error (Default: false).
|
||||
explicit Dither(DensityFunction density_function = DensityFunction::kTPDF, bool noise_shaping = false);
|
||||
|
||||
/// \brief Destructor.
|
||||
~Dither() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
||||
private:
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief EqualizerBiquad TensorTransform. Apply highpass biquad filter on audio.
|
||||
class MS_API EqualizerBiquad final : public TensorTransform {
|
||||
public:
|
||||
|
|
|
@ -71,6 +71,13 @@ enum class MS_API ConvertMode {
|
|||
COLOR_RGBA2GRAY = 11 ///< Convert RGBA image to GRAY image.
|
||||
};
|
||||
|
||||
// \brief Possible density function in Dither.
|
||||
enum MS_API DensityFunction {
|
||||
kTPDF = 0, ///< Use triangular probability density function.
|
||||
kRPDF = 1, ///< Use rectangular probability density function.
|
||||
kGPDF = 2 ///< Use gaussian probability density function.
|
||||
};
|
||||
|
||||
/// \brief Values of norm in CreateDct.
|
||||
enum class MS_API NormMode {
|
||||
kNone = 0, ///< None type norm.
|
||||
|
|
|
@ -161,6 +161,7 @@ constexpr char kDBToAmplitudeOp[] = " DBToAmplitudeOp";
|
|||
constexpr char kDCShiftOp[] = "DCShiftOp";
|
||||
constexpr char kDeemphBiquadOp[] = "DeemphBiquadOp";
|
||||
constexpr char kDetectPitchFrequencyOp[] = "DetectPitchFrequencyOp";
|
||||
constexpr char kDitherOp[] = "DitherOp";
|
||||
constexpr char kEqualizerBiquadOp[] = "EqualizerBiquadOp";
|
||||
constexpr char kFadeOp[] = "FadeOp";
|
||||
constexpr char kFlangerOp[] = "FlangerOp";
|
||||
|
|
|
@ -23,14 +23,14 @@ import numpy as np
|
|||
|
||||
import mindspore._c_dataengine as cde
|
||||
from ..transforms.c_transforms import TensorOperation
|
||||
from .utils import BorderType, FadeShape, GainType, Interpolation, Modulation, ScaleType, WindowType
|
||||
from .utils import BorderType, DensityFunction, FadeShape, GainType, Interpolation, Modulation, ScaleType, WindowType
|
||||
from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
|
||||
check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_compute_deltas, \
|
||||
check_contrast, check_db_to_amplitude, check_dc_shift, check_deemph_biquad, check_detect_pitch_frequency, \
|
||||
check_equalizer_biquad, check_fade, check_flanger, check_gain, check_highpass_biquad, check_lfilter, \
|
||||
check_lowpass_biquad, check_magphase, check_masking, check_mu_law_coding, check_overdrive, check_phaser, \
|
||||
check_riaa_biquad, check_sliding_window_cmn, check_spectrogram, check_time_stretch, check_treble_biquad, \
|
||||
check_vol
|
||||
check_dither, check_equalizer_biquad, check_fade, check_flanger, check_gain, check_highpass_biquad, \
|
||||
check_lfilter, check_lowpass_biquad, check_magphase, check_masking, check_mu_law_coding, check_overdrive, \
|
||||
check_phaser, check_riaa_biquad, check_sliding_window_cmn, check_spectrogram, check_time_stretch, \
|
||||
check_treble_biquad, check_vol
|
||||
|
||||
|
||||
|
||||
|
@ -495,6 +495,43 @@ class DetectPitchFrequency(AudioTensorOperation):
|
|||
self.win_length, self.freq_low, self.freq_high)
|
||||
|
||||
|
||||
DE_C_DENSITYFUNCTION_TYPE = {DensityFunction.TPDF: cde.DensityFunction.DE_DENSITYFUNCTION_TPDF,
|
||||
DensityFunction.RPDF: cde.DensityFunction.DE_DENSITYFUNCTION_RPDF,
|
||||
DensityFunction.GPDF: cde.DensityFunction.DE_DENSITYFUNCTION_GPDF}
|
||||
|
||||
|
||||
class Dither(AudioTensorOperation):
|
||||
"""
|
||||
Dither increases the perceived dynamic range of audio stored at a
|
||||
particular bit-depth by eliminating nonlinear truncation distortion.
|
||||
|
||||
Args:
|
||||
density_function (DensityFunction, optional): The density function of a continuous
|
||||
random variable. Can be one of DensityFunction.TPDF (Triangular Probability Density Function),
|
||||
DensityFunction.RPDF (Rectangular Probability Density Function) or
|
||||
DensityFunction.GPDF (Gaussian Probability Density Function)
|
||||
(default=DensityFunction.TPDF).
|
||||
noise_shaping (bool, optional): A filtering process that shapes the spectral
|
||||
energy of quantisation error (default=False).
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.array([[1, 2, 3], [4, 5, 6]])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.Dither()]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
"""
|
||||
|
||||
@check_dither
|
||||
def __init__(self, density_function=DensityFunction.TPDF, noise_shaping=False):
|
||||
self.density_function = density_function
|
||||
self.noise_shaping = noise_shaping
|
||||
|
||||
def parse(self):
|
||||
return cde.DitherOperation(DE_C_DENSITYFUNCTION_TYPE[self.density_function], self.noise_shaping)
|
||||
|
||||
|
||||
class EqualizerBiquad(AudioTensorOperation):
|
||||
"""
|
||||
Design biquad equalizer filter and perform filtering. Similar to SoX implementation.
|
||||
|
|
|
@ -20,6 +20,22 @@ from enum import Enum
|
|||
import mindspore._c_dataengine as cde
|
||||
|
||||
|
||||
class DensityFunction(str, Enum):
|
||||
"""
|
||||
Density Functions.
|
||||
|
||||
Possible enumeration values are: DensityFunction.TPDF, DensityFunction.GPDF,
|
||||
DensityFunction.RPDF.
|
||||
|
||||
- DensityFunction.TPDF: means triangular probability density function.
|
||||
- DensityFunction.GPDF: means gaussian probability density function.
|
||||
- DensityFunction.RPDF: means rectangular probability density function.
|
||||
"""
|
||||
TPDF: str = "TPDF"
|
||||
RPDF: str = "RPDF"
|
||||
GPDF: str = "GPDF"
|
||||
|
||||
|
||||
class FadeShape(str, Enum):
|
||||
"""
|
||||
Fade Shapes.
|
||||
|
|
|
@ -21,7 +21,7 @@ from functools import wraps
|
|||
from mindspore.dataset.core.validator_helpers import check_float32, check_float32_not_zero, check_int32, \
|
||||
check_int32_not_zero, check_list_same_size, check_non_negative_float32, check_non_negative_int32, \
|
||||
check_pos_float32, check_pos_int32, check_value, INT32_MAX, parse_user_args, type_check
|
||||
from .utils import BorderType, FadeShape, GainType, Interpolation, Modulation, ScaleType, WindowType
|
||||
from .utils import BorderType, DensityFunction, FadeShape, GainType, Interpolation, Modulation, ScaleType, WindowType
|
||||
|
||||
|
||||
def check_amplitude_to_db(method):
|
||||
|
@ -239,6 +239,22 @@ def check_deemph_biquad(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_dither(method):
|
||||
"""Wrapper method to check the parameters of Dither."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[density_function, noise_shaping], _ = parse_user_args(
|
||||
method, *args, **kwargs)
|
||||
|
||||
type_check(density_function, (DensityFunction), "density_function")
|
||||
type_check(noise_shaping, (bool,), "noise_shaping")
|
||||
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_equalizer_biquad(method):
|
||||
"""Wrapper method to check the parameters of EqualizerBiquad."""
|
||||
|
||||
|
|
|
@ -841,6 +841,46 @@ TEST_F(MindDataTestPipeline, TestDeemphBiquadWrongArgs) {
|
|||
EXPECT_EQ(iter01, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: Dither
|
||||
/// Description: test basic usage of Dither in pipeline mode
|
||||
/// Expectation: the data is processed successfully
|
||||
TEST_F(MindDataTestPipeline, TestDitherBasic) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDitherBasic.";
|
||||
// Original waveform
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 200}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(50, schema);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
ds = ds->SetNumWorkers(2);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto DitherOp = audio::Dither();
|
||||
|
||||
ds = ds->Map({DitherOp});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Filtered waveform by Dither
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
std::vector<int64_t> expected = {2, 200};
|
||||
|
||||
int i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto col = row["waveform"];
|
||||
ASSERT_EQ(col.Shape(), expected);
|
||||
ASSERT_EQ(col.Shape().size(), 2);
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
EXPECT_EQ(i, 50);
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestHighpassBiquadSuccess) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestHighpassBiquadSuccess.";
|
||||
|
||||
|
|
|
@ -1526,6 +1526,28 @@ TEST_F(MindDataTestExecute, TestDetectPitchFrequencyWithWrongArg) {
|
|||
EXPECT_FALSE(s05.IsOk());
|
||||
}
|
||||
|
||||
/// Feature: Dither
|
||||
/// Description: test Dither in eager mode
|
||||
/// Expectation: the data is processed successfully
|
||||
TEST_F(MindDataTestExecute, TestDitherWithEager) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestDitherWithEager.";
|
||||
// Original waveform
|
||||
std::vector<float> labels = {
|
||||
2.716064453125000000e-03, 6.347656250000000000e-03, 9.246826171875000000e-03, 1.089477539062500000e-02,
|
||||
1.138305664062500000e-02, 1.156616210937500000e-02, 1.394653320312500000e-02, 1.550292968750000000e-02,
|
||||
1.614379882812500000e-02, 1.840209960937500000e-02, 1.718139648437500000e-02, 1.599121093750000000e-02,
|
||||
1.647949218750000000e-02, 1.510620117187500000e-02, 1.385498046875000000e-02, 1.345825195312500000e-02,
|
||||
1.419067382812500000e-02, 1.284790039062500000e-02, 1.052856445312500000e-02, 9.368896484375000000e-03};
|
||||
std::shared_ptr<Tensor> input;
|
||||
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 10}), &input));
|
||||
auto input_02 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||
std::shared_ptr<TensorTransform> dither_01 = std::make_shared<audio::Dither>();
|
||||
mindspore::dataset::Execute Transform01({dither_01});
|
||||
// Filtered waveform by Dither
|
||||
Status s01 = Transform01(input_02, &input_02);
|
||||
EXPECT_TRUE(s01.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestExecute, TestFlangerWithEager) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestFlangerWithEager.";
|
||||
// Original waveform
|
||||
|
|
|
@ -0,0 +1,168 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.audio.transforms as audio
|
||||
from mindspore import log as logger
|
||||
from mindspore.dataset.audio.utils import DensityFunction
|
||||
from util import visualize_audio, diff_mse
|
||||
|
||||
|
||||
def count_unequal_element(data_expected, data_me, rtol, atol):
|
||||
assert data_expected.shape == data_me.shape
|
||||
total_count = len(data_expected.flatten())
|
||||
error = np.abs(data_expected - data_me)
|
||||
greater = np.greater(error, atol + np.abs(data_expected) * rtol)
|
||||
loss_count = np.count_nonzero(greater)
|
||||
assert (loss_count / total_count) < rtol, \
|
||||
"\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
|
||||
format(data_expected[greater], data_me[greater], error[greater])
|
||||
|
||||
|
||||
def test_dither_eager_noise_shaping_false():
|
||||
"""
|
||||
Feature: Dither
|
||||
Description: test Dither in eager mode
|
||||
Expectation: the result is as expected
|
||||
"""
|
||||
logger.info("test Dither in eager mode")
|
||||
|
||||
# Original waveform
|
||||
waveform = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float64)
|
||||
# Expect waveform
|
||||
expect_waveform = np.array([[0.99993896, 1.99990845, 2.99984741],
|
||||
[3.99975586, 4.99972534, 5.99966431]], dtype=np.float64)
|
||||
dither_op = audio.Dither(DensityFunction.TPDF, False)
|
||||
# Filtered waveform by Dither
|
||||
output = dither_op(waveform)
|
||||
count_unequal_element(expect_waveform, output, 0.0001, 0.0001)
|
||||
|
||||
|
||||
def test_dither_eager_noise_shaping_true():
|
||||
"""
|
||||
Feature: Dither
|
||||
Description: test Dither in eager mode
|
||||
Expectation: the result is as expected
|
||||
"""
|
||||
logger.info("test Dither in eager mode")
|
||||
|
||||
# Original waveform
|
||||
waveform = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float64)
|
||||
# Expect waveform
|
||||
expect_waveform = np.array([[0.9999, 1.9998, 2.9998],
|
||||
[3.9998, 4.9995, 5.9994],
|
||||
[6.9996, 7.9991, 8.9990]], dtype=np.float64)
|
||||
dither_op = audio.Dither(DensityFunction.TPDF, True)
|
||||
# Filtered waveform by Dither
|
||||
output = dither_op(waveform)
|
||||
count_unequal_element(expect_waveform, output, 0.0001, 0.0001)
|
||||
|
||||
|
||||
def test_dither_pipeline(plot=False):
|
||||
"""
|
||||
Feature: Dither
|
||||
Description: test Dither in pipeline mode
|
||||
Expectation: the result is as expected
|
||||
"""
|
||||
logger.info("test Dither in pipeline mode")
|
||||
|
||||
# Original waveform
|
||||
waveform_tpdf = np.array([[0.4941969, 0.53911686, 0.4846254], [0.10841596, 0.029320478, 0.52353495],
|
||||
[0.23657, 0.087965, 0.43579]], dtype=np.float64)
|
||||
waveform_rpdf = np.array([[0.4941969, 0.53911686, 0.4846254], [0.10841596, 0.029320478, 0.52353495],
|
||||
[0.23657, 0.087965, 0.43579]], dtype=np.float64)
|
||||
waveform_gpdf = np.array([[0.4941969, 0.53911686, 0.4846254], [0.10841596, 0.029320478, 0.52353495],
|
||||
[0.23657, 0.087965, 0.43579]], dtype=np.float64)
|
||||
# Expect waveform
|
||||
expect_tpdf = np.array([[0.49417114, 0.53909302, 0.48461914],
|
||||
[0.10839844, 0.02932739, 0.52352905],
|
||||
[0.23654175, 0.08798218, 0.43579102]], dtype=np.float64)
|
||||
expect_rpdf = np.array([[0.4941, 0.5391, 0.4846],
|
||||
[0.1084, 0.0293, 0.5235],
|
||||
[0.2365, 0.0880, 0.4358]], dtype=np.float64)
|
||||
expect_gpdf = np.array([[0.4944, 0.5393, 0.4848],
|
||||
[0.1086, 0.0295, 0.5237],
|
||||
[0.2368, 0.0882, 0.4360]], dtype=np.float64)
|
||||
dataset_tpdf = ds.NumpySlicesDataset(waveform_tpdf, ["audio"], shuffle=False)
|
||||
dataset_rpdf = ds.NumpySlicesDataset(waveform_rpdf, ["audio"], shuffle=False)
|
||||
dataset_gpdf = ds.NumpySlicesDataset(waveform_gpdf, ["audio"], shuffle=False)
|
||||
|
||||
# Filtered waveform by Dither of TPDF
|
||||
dither_tpdf = audio.Dither()
|
||||
dataset_tpdf = dataset_tpdf.map(input_columns=["audio"], operations=dither_tpdf, num_parallel_workers=2)
|
||||
|
||||
# Filtered waveform by Dither of RPDF
|
||||
dither_rpdf = audio.Dither(DensityFunction.RPDF, False)
|
||||
dataset_rpdf = dataset_rpdf.map(input_columns=["audio"], operations=dither_rpdf, num_parallel_workers=2)
|
||||
|
||||
# Filtered waveform by Dither of GPDF
|
||||
dither_gpdf = audio.Dither(DensityFunction.GPDF, False)
|
||||
dataset_gpdf = dataset_gpdf.map(input_columns=["audio"], operations=dither_gpdf, num_parallel_workers=2)
|
||||
|
||||
i = 0
|
||||
for data1, data2, data3 in zip(dataset_tpdf.create_dict_iterator(output_numpy=True),
|
||||
dataset_rpdf.create_dict_iterator(output_numpy=True),
|
||||
dataset_gpdf.create_dict_iterator(output_numpy=True)):
|
||||
count_unequal_element(expect_tpdf[i, :], data1['audio'], 0.0001, 0.0001)
|
||||
dither_rpdf = data2['audio']
|
||||
dither_gpdf = data3['audio']
|
||||
mse_rpdf = diff_mse(dither_rpdf, expect_rpdf[i, :])
|
||||
logger.info("dither_rpdf_{}, mse: {}".format(i + 1, mse_rpdf))
|
||||
mse_gpdf = diff_mse(dither_gpdf, expect_gpdf[i, :])
|
||||
logger.info("dither_gpdf_{}, mse: {}".format(i + 1, mse_gpdf))
|
||||
i += 1
|
||||
if plot:
|
||||
visualize_audio(dither_rpdf, expect_rpdf[i, :])
|
||||
visualize_audio(dither_gpdf, expect_gpdf[i, :])
|
||||
|
||||
|
||||
def test_invalid_dither_input():
|
||||
"""
|
||||
Feature: Dither
|
||||
Description: test param check of Dither
|
||||
Expectation: throw correct error and message
|
||||
"""
|
||||
logger.info("test param check of Dither")
|
||||
|
||||
def test_invalid_input(test_name, density_function, noise_shaping, error, error_msg):
|
||||
logger.info("Test Dither with bad input: {0}".format(test_name))
|
||||
with pytest.raises(error) as error_info:
|
||||
audio.Dither(density_function, noise_shaping)
|
||||
assert error_msg in str(error_info.value)
|
||||
|
||||
test_invalid_input("invalid density function parameter value", "TPDF", False, TypeError,
|
||||
"Argument density_function with value TPDF is not of type"
|
||||
+ " [<DensityFunction.TPDF: 'TPDF'>, <DensityFunction.RPDF: 'RPDF'>"
|
||||
+ ", <DensityFunction.GPDF: 'GPDF'>], but got <class 'str'>.")
|
||||
|
||||
test_invalid_input("invalid density_function parameter value", 6, False, TypeError,
|
||||
"Argument density_function with value 6 is not of type"
|
||||
+ " [<DensityFunction.TPDF: 'TPDF'>, <DensityFunction.RPDF: 'RPDF'>"
|
||||
+ ", <DensityFunction.GPDF: 'GPDF'>], but got <class 'int'>.")
|
||||
|
||||
test_invalid_input("invalid noise_shaping parameter value", DensityFunction.GPDF, 1, TypeError,
|
||||
"Argument noise_shaping with value 1 is not of type [<class 'bool'>], but got <class 'int'>.")
|
||||
|
||||
test_invalid_input("invalid noise_shaping parameter value", DensityFunction.RPDF, "true", TypeError,
|
||||
"Argument noise_shaping with value true is not of type [<class 'bool'>], but got <class 'str'>")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_dither_eager_noise_shaping_false()
|
||||
test_dither_eager_noise_shaping_true()
|
||||
test_dither_pipeline(plot=False)
|
||||
test_invalid_dither_input()
|
|
@ -207,6 +207,26 @@ def diff_me(in1, in2):
|
|||
return mse / 255 * 100
|
||||
|
||||
|
||||
def visualize_audio(waveform, expect_waveform):
|
||||
"""
|
||||
Visualizes audio waveform.
|
||||
"""
|
||||
plt.figure(1)
|
||||
plt.subplot(1, 3, 1)
|
||||
plt.imshow(waveform)
|
||||
plt.title("waveform")
|
||||
|
||||
plt.subplot(1, 3, 2)
|
||||
plt.imshow(expect_waveform)
|
||||
plt.title("expect waveform")
|
||||
|
||||
plt.subplot(1, 3, 3)
|
||||
plt.imshow(waveform - expect_waveform)
|
||||
plt.title("difference")
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def visualize_one_channel_dataset(images_original, images_transformed, labels):
|
||||
"""
|
||||
Helper function to visualize one channel grayscale images
|
||||
|
|
Loading…
Reference in New Issue