forked from mindspore-Ecosystem/mindspore
[feat][assistant][I3CEG3] add new data OP MuLawEncoding
This commit is contained in:
parent
ec981124d0
commit
75e2c3041c
|
@ -38,6 +38,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/magphase_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/mu_law_encoding_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/overdrive_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/riaa_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
|
||||
|
@ -401,6 +402,18 @@ std::shared_ptr<TensorOperation> MuLawDecoding::Parse() {
|
|||
return std::make_shared<MuLawDecodingOperation>(data_->quantization_channels_);
|
||||
}
|
||||
|
||||
// MuLawEncoding Transform Operation.
|
||||
// Private state of the MuLawEncoding transform: keeps the constructor argument
// so that Parse() can forward it to the IR operation.
struct MuLawEncoding::Data {
  /// \brief Capture the number of quantization channels.
  /// \param[in] quantization_channels Value stored in quantization_channels_.
  explicit Data(int32_t quantization_channels) : quantization_channels_{quantization_channels} {}

  int32_t quantization_channels_;  // forwarded unchanged to MuLawEncodingOperation
};
|
||||
|
||||
// Build the transform by allocating its Data holder with the requested channel count.
MuLawEncoding::MuLawEncoding(int32_t quantization_channels)
    : data_(std::make_shared<Data>(quantization_channels)) {}
|
||||
|
||||
std::shared_ptr<TensorOperation> MuLawEncoding::Parse() {
|
||||
return std::make_shared<MuLawEncodingOperation>(data_->quantization_channels_);
|
||||
}
|
||||
|
||||
// Overdrive Transform Operation.
|
||||
struct Overdrive::Data {
|
||||
Data(float gain, float color) : gain_(gain), color_(color) {}
|
||||
|
|
|
@ -42,6 +42,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/magphase_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/mu_law_encoding_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/overdrive_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/riaa_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
|
||||
|
@ -319,13 +320,24 @@ PYBIND_REGISTER(
|
|||
MuLawDecodingOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<audio::MuLawDecodingOperation, TensorOperation, std::shared_ptr<audio::MuLawDecodingOperation>>(
|
||||
*m, "MuLawDecodingOperation")
|
||||
.def(py::init([](int quantization_channels) {
|
||||
.def(py::init([](int32_t quantization_channels) {
|
||||
auto mu_law_decoding = std::make_shared<audio::MuLawDecodingOperation>(quantization_channels);
|
||||
THROW_IF_ERROR(mu_law_decoding->ValidateParams());
|
||||
return mu_law_decoding;
|
||||
}));
|
||||
}));
|
||||
|
||||
// Python binding for audio::MuLawEncodingOperation. The init factory validates the
// parameters right after construction, so invalid arguments raise on the Python side.
PYBIND_REGISTER(
  MuLawEncodingOperation, 1, ([](const py::module *m) {
    using OperationT = audio::MuLawEncodingOperation;
    (void)py::class_<OperationT, TensorOperation, std::shared_ptr<OperationT>>(*m, "MuLawEncodingOperation")
      .def(py::init([](int32_t quantization_channels) {
        auto operation = std::make_shared<OperationT>(quantization_channels);
        THROW_IF_ERROR(operation->ValidateParams());
        return operation;
      }));
  }));
|
||||
|
||||
PYBIND_REGISTER(OverdriveOperation, 1, ([](const py::module *m) {
|
||||
(void)
|
||||
py::class_<audio::OverdriveOperation, TensorOperation, std::shared_ptr<audio::OverdriveOperation>>(
|
||||
|
|
|
@ -24,6 +24,7 @@ add_library(audio-ir-kernels OBJECT
|
|||
lowpass_biquad_ir.cc
|
||||
magphase_ir.cc
|
||||
mu_law_decoding_ir.cc
|
||||
mu_law_encoding_ir.cc
|
||||
overdrive_ir.cc
|
||||
riaa_biquad_ir.cc
|
||||
time_masking_ir.cc
|
||||
|
|
|
@ -21,13 +21,13 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
MuLawDecodingOperation::MuLawDecodingOperation(int quantization_channels)
|
||||
MuLawDecodingOperation::MuLawDecodingOperation(int32_t quantization_channels)
|
||||
: quantization_channels_(quantization_channels) {}
|
||||
|
||||
MuLawDecodingOperation::~MuLawDecodingOperation() = default;
|
||||
|
||||
Status MuLawDecodingOperation::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateIntScalarPositive("MuLawEncoding", "quantization_channels", quantization_channels_));
|
||||
RETURN_IF_NOT_OK(ValidateIntScalarPositive("MuLawDecoding", "quantization_channels", quantization_channels_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ constexpr char kMuLawDecodingOperation[] = "MuLawDecoding";
|
|||
|
||||
class MuLawDecodingOperation : public TensorOperation {
|
||||
public:
|
||||
explicit MuLawDecodingOperation(int quantization_channels);
|
||||
explicit MuLawDecodingOperation(int32_t quantization_channels);
|
||||
|
||||
~MuLawDecodingOperation();
|
||||
|
||||
|
@ -44,7 +44,7 @@ class MuLawDecodingOperation : public TensorOperation {
|
|||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
int quantization_channels_;
|
||||
int32_t quantization_channels_;
|
||||
}; // class MuLawDecodingOperation
|
||||
|
||||
} // namespace audio
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/audio/ir/kernels/mu_law_encoding_ir.h"
|
||||
|
||||
#include "minddata/dataset/audio/ir/validators.h"
|
||||
#include "minddata/dataset/audio/kernels/mu_law_encoding_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
MuLawEncodingOperation::MuLawEncodingOperation(int32_t quantization_channels)
|
||||
: quantization_channels_(quantization_channels) {}
|
||||
|
||||
MuLawEncodingOperation::~MuLawEncodingOperation() = default;
|
||||
|
||||
Status MuLawEncodingOperation::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateIntScalarPositive("MuLawEncoding", "quantization_channels", quantization_channels_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MuLawEncodingOperation::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["quantization_channels"] = quantization_channels_;
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> MuLawEncodingOperation::Build() {
|
||||
std::shared_ptr<MuLawEncodingOp> tensor_op = std::make_shared<MuLawEncodingOp>(quantization_channels_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
std::string MuLawEncodingOperation::Name() const { return kMuLawEncodingOperation; }
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,53 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MU_LAW_ENCODING_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MU_LAW_ENCODING_IR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
|
||||
constexpr char kMuLawEncodingOperation[] = "MuLawEncoding";
|
||||
|
||||
class MuLawEncodingOperation : public TensorOperation {
|
||||
public:
|
||||
explicit MuLawEncodingOperation(int32_t quantization_channels);
|
||||
|
||||
~MuLawEncodingOperation();
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
std::string Name() const override;
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
int32_t quantization_channels_;
|
||||
}; // class MuLawEncodingOperation
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MU_LAW_ENCODING_IR_H_
|
|
@ -25,6 +25,7 @@ add_library(audio-kernels OBJECT
|
|||
lowpass_biquad_op.cc
|
||||
magphase_op.cc
|
||||
mu_law_decoding_op.cc
|
||||
mu_law_encoding_op.cc
|
||||
overdrive_op.cc
|
||||
riaa_biquad_op.cc
|
||||
time_masking_op.cc
|
||||
|
|
|
@ -16,13 +16,10 @@
|
|||
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
|
||||
#include <complex>
|
||||
|
||||
#include "mindspore/core/base/float16.h"
|
||||
#include "minddata/dataset/core/type_id.h"
|
||||
#include "minddata/dataset/kernels/data/data_utils.h"
|
||||
#include "minddata/dataset/util/random.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -493,8 +490,10 @@ Status Decoding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MuLawDecoding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int quantization_channels) {
|
||||
if (input->type().value() >= DataType::DE_INT8 && input->type().value() <= DataType::DE_FLOAT32) {
|
||||
Status MuLawDecoding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
|
||||
int32_t quantization_channels) {
|
||||
if (input->type().IsInt() || input->type() == DataType(DataType::DE_FLOAT16) ||
|
||||
input->type() == DataType(DataType::DE_FLOAT32)) {
|
||||
float f_mu = static_cast<float>(quantization_channels) - 1;
|
||||
|
||||
// convert the data type to float
|
||||
|
@ -502,7 +501,7 @@ Status MuLawDecoding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tenso
|
|||
RETURN_IF_NOT_OK(TypeCast(input, &input_tensor, DataType(DataType::DE_FLOAT32)));
|
||||
|
||||
RETURN_IF_NOT_OK(Decoding<float>(input_tensor, output, f_mu));
|
||||
} else if (input->type().value() == DataType::DE_FLOAT64) {
|
||||
} else if (input->type() == DataType(DataType::DE_FLOAT64)) {
|
||||
double f_mu = static_cast<double>(quantization_channels) - 1;
|
||||
|
||||
RETURN_IF_NOT_OK(Decoding<double>(input, output, f_mu));
|
||||
|
@ -513,6 +512,49 @@ Status MuLawDecoding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tenso
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status Encoding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, T mu) {
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), DataType(DataType::DE_INT32), output));
|
||||
auto itr_out = (*output)->begin<int32_t>();
|
||||
auto itr = input->begin<T>();
|
||||
auto end = input->end<T>();
|
||||
|
||||
while (itr != end) {
|
||||
auto x = *itr;
|
||||
x = sgn(x) * log1p(mu * fabs(x)) / log1p(mu);
|
||||
x = (x + 1) / 2 * mu + 0.5;
|
||||
*itr_out = static_cast<int32_t>(x);
|
||||
++itr_out;
|
||||
++itr;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Dispatch mu-law encoding on the input element type:
//  - integer and float16 inputs are cast to float32 first and encoded as float,
//  - float32 inputs are encoded as float directly,
//  - float64 inputs are encoded as double,
//  - every other type is rejected.
Status MuLawEncoding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
                     int32_t quantization_channels) {
  const bool needs_cast = input->type().IsInt() || input->type() == DataType(DataType::DE_FLOAT16);
  const bool is_float32 = input->type() == DataType(DataType::DE_FLOAT32);
  const bool is_float64 = input->type() == DataType(DataType::DE_FLOAT64);

  if (!needs_cast && !is_float32 && !is_float64) {
    RETURN_STATUS_UNEXPECTED("MuLawEncoding: input tensor type should be int, float or double, but got: " +
                             input->type().ToString());
  }

  if (is_float64) {
    double f_mu = static_cast<double>(quantization_channels) - 1;
    RETURN_IF_NOT_OK(Encoding<double>(input, output, f_mu));
    return Status::OK();
  }

  float f_mu = static_cast<float>(quantization_channels) - 1;
  if (needs_cast) {
    // convert the data type to float
    std::shared_ptr<Tensor> input_tensor;
    RETURN_IF_NOT_OK(TypeCast(input, &input_tensor, DataType(DataType::DE_FLOAT32)));
    RETURN_IF_NOT_OK(Encoding<float>(input_tensor, output, f_mu));
  } else {
    RETURN_IF_NOT_OK(Encoding<float>(input, output, f_mu));
  }
  return Status::OK();
}
|
||||
|
||||
template <typename T>
|
||||
Status FadeIn(std::shared_ptr<Tensor> *output, int32_t fade_in_len, FadeShape fade_shape) {
|
||||
T start = 0;
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <complex>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
|
@ -316,7 +317,16 @@ Status ComplexNorm(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor>
|
|||
/// \param output Tensor of shape <..., time>.
|
||||
/// \param quantization_channels Number of channels.
|
||||
/// \return Status code.
|
||||
Status MuLawDecoding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int quantization_channels);
|
||||
Status MuLawDecoding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
|
||||
int32_t quantization_channels);
|
||||
|
||||
/// \brief Encode signal based on mu-law companding.
|
||||
/// \param input Tensor of shape <..., time>.
|
||||
/// \param output Tensor of shape <..., time>.
|
||||
/// \param quantization_channels Number of channels.
|
||||
/// \return Status code.
|
||||
Status MuLawEncoding(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
|
||||
int32_t quantization_channels);
|
||||
|
||||
/// \brief Apply a overdrive effect to the audio.
|
||||
/// \param input Tensor of shape <..., time>.
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// constructor
|
||||
MuLawDecodingOp::MuLawDecodingOp(int quantization_channels) : quantization_channels_(quantization_channels) {}
|
||||
MuLawDecodingOp::MuLawDecodingOp(int32_t quantization_channels) : quantization_channels_(quantization_channels) {}
|
||||
|
||||
// main function
|
||||
Status MuLawDecodingOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
|
@ -28,7 +28,7 @@ Status MuLawDecodingOp::Compute(const std::shared_ptr<Tensor> &input, std::share
|
|||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input->Rank() >= 1, "MuLawDecoding: input tensor is not in shape of <..., time>.");
|
||||
|
||||
if (input->type().value() >= DataType::DE_INT8 && input->type().value() <= DataType::DE_FLOAT64) {
|
||||
if (input->type().IsNumeric()) {
|
||||
return MuLawDecoding(input, output, quantization_channels_);
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("MuLawDecoding: input tensor type should be int, float or double, but got: " +
|
||||
|
@ -40,7 +40,8 @@ Status MuLawDecodingOp::OutputType(const std::vector<DataType> &inputs, std::vec
|
|||
RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
|
||||
if (inputs[0] == DataType(DataType::DE_FLOAT64)) {
|
||||
outputs[0] = DataType(DataType::DE_FLOAT64);
|
||||
} else if (inputs[0] >= DataType(DataType::DE_INT8) || inputs[0] <= DataType(DataType::DE_FLOAT32)) {
|
||||
} else if (inputs[0].IsInt() || inputs[0] == DataType(DataType::DE_FLOAT16) ||
|
||||
inputs[0] == DataType(DataType::DE_FLOAT32)) {
|
||||
outputs[0] = DataType(DataType::DE_FLOAT32);
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("MuLawDecoding: input tensor type should be int, float or double, but got: " +
|
||||
|
|
|
@ -28,7 +28,7 @@ namespace dataset {
|
|||
|
||||
class MuLawDecodingOp : public TensorOp {
|
||||
public:
|
||||
explicit MuLawDecodingOp(int quantization_channels = 256);
|
||||
explicit MuLawDecodingOp(int32_t quantization_channels = 256);
|
||||
|
||||
~MuLawDecodingOp() override = default;
|
||||
|
||||
|
@ -39,7 +39,7 @@ class MuLawDecodingOp : public TensorOp {
|
|||
std::string Name() const override { return kMuLawDecodingOp; }
|
||||
|
||||
private:
|
||||
int quantization_channels_;
|
||||
int32_t quantization_channels_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/audio/kernels/mu_law_encoding_op.h"
|
||||
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// constructor
|
||||
MuLawEncodingOp::MuLawEncodingOp(int32_t quantization_channels) : quantization_channels_(quantization_channels) {}
|
||||
|
||||
// main function
|
||||
Status MuLawEncodingOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
IO_CHECK(input, output);
|
||||
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input->Rank() >= 1, "MuLawEncoding: input tensor is not in shape of <..., time>.");
|
||||
|
||||
if (input->type().IsNumeric()) {
|
||||
return MuLawEncoding(input, output, quantization_channels_);
|
||||
} else {
|
||||
RETURN_STATUS_UNEXPECTED("MuLawEncoding: input tensor type should be int, float or double, but got: " +
|
||||
input->type().ToString());
|
||||
}
|
||||
}
|
||||
|
||||
Status MuLawEncodingOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
|
||||
RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
|
||||
outputs[0] = DataType(DataType::DE_INT32);
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MU_LAW_ENCODING_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MU_LAW_ENCODING_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class MuLawEncodingOp : public TensorOp {
|
||||
public:
|
||||
/// \brief Constructor for MuLawEncoding.
|
||||
/// \param[in] quantization_channels Number of channels.
|
||||
explicit MuLawEncodingOp(int32_t quantization_channels = 256);
|
||||
|
||||
~MuLawEncodingOp() override = default;
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
|
||||
|
||||
std::string Name() const override { return kMuLawEncodingOp; }
|
||||
|
||||
private:
|
||||
int32_t quantization_channels_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MU_LAW_ENCODING_OP_H_
|
|
@ -528,7 +528,7 @@ class MuLawDecoding final : public TensorTransform {
|
|||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] quantization_channels Number of channels, which must be positive (Default: 256).
|
||||
explicit MuLawDecoding(int quantization_channels = 256);
|
||||
explicit MuLawDecoding(int32_t quantization_channels = 256);
|
||||
|
||||
/// \brief Destructor.
|
||||
~MuLawDecoding() = default;
|
||||
|
@ -543,6 +543,27 @@ class MuLawDecoding final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief MuLawEncoding TensorTransform.
|
||||
/// \note Encode signal based on mu-law companding.
|
||||
class MuLawEncoding final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] quantization_channels Number of channels, which must be positive (Default: 256).
|
||||
explicit MuLawEncoding(int32_t quantization_channels = 256);
|
||||
|
||||
/// \brief Destructor.
|
||||
~MuLawEncoding() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
||||
private:
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Overdrive TensorTransform.
|
||||
class Overdrive final : public TensorTransform {
|
||||
public:
|
||||
|
|
|
@ -165,6 +165,7 @@ constexpr char kLFilterOp[] = "LFilterOp";
|
|||
constexpr char kLowpassBiquadOp[] = "LowpassBiquadOp";
|
||||
constexpr char kMagphaseOp[] = "MagphaseOp";
|
||||
constexpr char kMuLawDecodingOp[] = "MuLawDecodingOp";
|
||||
constexpr char kMuLawEncodingOp[] = "MuLawEncodingOp";
|
||||
constexpr char kOverdriveOp[] = "OverdriveOp";
|
||||
constexpr char kRiaaBiquadOp[] = "RiaaBiquadOp";
|
||||
constexpr char kTimeMaskingOp[] = "TimeMaskingOp";
|
||||
|
|
|
@ -27,7 +27,7 @@ from .utils import FadeShape, GainType, Interpolation, Modulation, ScaleType
|
|||
from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
|
||||
check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_contrast, check_dc_shift, \
|
||||
check_deemph_biquad, check_detect_pitch_frequency, check_equalizer_biquad, check_fade, check_flanger, \
|
||||
check_highpass_biquad, check_lfilter, check_lowpass_biquad, check_magphase, check_masking, check_mu_law_decoding, \
|
||||
check_highpass_biquad, check_lfilter, check_lowpass_biquad, check_magphase, check_masking, check_mu_law_coding, \
|
||||
check_overdrive, check_riaa_biquad, check_time_stretch, check_treble_biquad, check_vol
|
||||
|
||||
|
||||
|
@ -711,7 +711,8 @@ class MuLawDecoding(AudioTensorOperation):
|
|||
>>> transforms = [audio.MuLawDecoding()]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
"""
|
||||
@check_mu_law_decoding
|
||||
|
||||
@check_mu_law_coding
|
||||
def __init__(self, quantization_channels=256):
|
||||
self.quantization_channels = quantization_channels
|
||||
|
||||
|
@ -719,6 +720,30 @@ class MuLawDecoding(AudioTensorOperation):
|
|||
return cde.MuLawDecodingOperation(self.quantization_channels)
|
||||
|
||||
|
||||
class MuLawEncoding(AudioTensorOperation):
    """
    Encode signal based on mu-law companding.

    Args:
        quantization_channels (int): Number of channels, which must be positive (Default: 256).

    Examples:
        >>> import numpy as np
        >>>
        >>> waveform = np.random.random([1, 3, 4])
        >>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
        >>> transforms = [audio.MuLawEncoding()]
        >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
    """

    @check_mu_law_coding
    def __init__(self, quantization_channels=256):
        # The decorator checks the argument before it is stored.
        self.quantization_channels = quantization_channels

    def parse(self):
        # Build the C++ IR operation with the stored channel count.
        return cde.MuLawEncodingOperation(self.quantization_channels)
|
||||
|
||||
|
||||
class Overdrive(AudioTensorOperation):
|
||||
"""
|
||||
Apply overdrive on input audio.
|
||||
|
|
|
@ -280,8 +280,8 @@ def check_lowpass_biquad(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_mu_law_decoding(method):
|
||||
"""Wrapper method to check the parameters of MuLawDecoding"""
|
||||
def check_mu_law_coding(method):
|
||||
"""Wrapper method to check the parameters of MuLawDecoding and MuLawEncoding"""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
|
|
|
@ -902,16 +902,16 @@ TEST_F(MindDataTestPipeline, TestMuLawDecodingBasic) {
|
|||
|
||||
// Original waveform
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeInt64, {1, 100}));
|
||||
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeInt32, {1, 100}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(50, schema);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
ds = ds->SetNumWorkers(4);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto MuLawDecodingOp = audio::MuLawDecoding();
|
||||
auto mu_law_decoding_op = audio::MuLawDecoding();
|
||||
|
||||
ds = ds->Map({MuLawDecodingOp});
|
||||
ds = ds->Map({mu_law_decoding_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Filtered waveform by MuLawDecoding
|
||||
|
@ -925,7 +925,7 @@ TEST_F(MindDataTestPipeline, TestMuLawDecodingBasic) {
|
|||
|
||||
int i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto col = row["inputData"];
|
||||
auto col = row["waveform"];
|
||||
ASSERT_EQ(col.Shape(), expected);
|
||||
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
@ -941,18 +941,98 @@ TEST_F(MindDataTestPipeline, TestMuLawDecodingWrongArgs) {
|
|||
|
||||
// Original waveform
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeInt64, {1, 100}));
|
||||
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeInt32, {1, 100}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(50, schema);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
ds = ds->SetNumWorkers(4);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto MuLawDecodingOp = audio::MuLawDecoding(-10);
|
||||
// quantization_channels is negative
|
||||
auto mu_law_decoding_op1 = audio::MuLawDecoding(-10);
|
||||
|
||||
ds = ds->Map({MuLawDecodingOp});
|
||||
ds = ds->Map({mu_law_decoding_op1});
|
||||
std::shared_ptr<Iterator> iter1 = ds->CreateIterator();
|
||||
EXPECT_EQ(iter1, nullptr);
|
||||
|
||||
// quantization_channels is 0
|
||||
auto mu_law_decoding_op2 = audio::MuLawDecoding(0);
|
||||
|
||||
ds = ds->Map({mu_law_decoding_op2});
|
||||
std::shared_ptr<Iterator> iter2 = ds->CreateIterator();
|
||||
EXPECT_EQ(iter1, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: MuLawEncoding
|
||||
/// Description: test MuLawEncoding in pipeline mode
|
||||
/// Expectation: the data is processed successfully
|
||||
TEST_F(MindDataTestPipeline, TestMuLawEncodingBasic) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMuLawEncodingBasic.";

  // Original waveform
  std::shared_ptr<SchemaObj> schema = Schema();
  ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {1, 100}));
  std::shared_ptr<Dataset> ds = RandomData(50, schema);
  EXPECT_NE(ds, nullptr);

  ds = ds->SetNumWorkers(4);
  EXPECT_NE(ds, nullptr);

  auto mu_law_encoding_op = audio::MuLawEncoding();

  ds = ds->Map({mu_law_encoding_op});
  EXPECT_NE(ds, nullptr);

  // Waveform encoded by MuLawEncoding
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // The iterator is dereferenced right below, so a null iterator must abort the test here.
  ASSERT_NE(iter, nullptr);

  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  std::vector<int64_t> expected = {1, 100};

  // Every row keeps the input shape and is reported as int32.
  int i = 0;
  while (row.size() != 0) {
    auto col = row["waveform"];
    ASSERT_EQ(col.Shape(), expected);
    ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeInt32);
    ASSERT_OK(iter->GetNextRow(&row));
    i++;
  }
  EXPECT_EQ(i, 50);

  iter->Stop();
}
|
||||
|
||||
/// Feature: MuLawEncoding
|
||||
/// Description: test invalid parameter of MuLawEncoding
|
||||
/// Expectation: throw exception correctly
|
||||
TEST_F(MindDataTestPipeline, TestMuLawEncodingWrongArgs) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMuLawEncodingWrongArgs.";

  // Original waveform
  std::shared_ptr<SchemaObj> schema = Schema();
  ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {1, 100}));
  std::shared_ptr<Dataset> ds = RandomData(50, schema);
  EXPECT_NE(ds, nullptr);

  ds = ds->SetNumWorkers(4);
  EXPECT_NE(ds, nullptr);

  // Each invalid op is mapped onto its own branch of the pipeline, so the second
  // case is not rejected merely because the first invalid op is still attached.

  // quantization_channels is negative
  auto mu_law_encoding_op1 = audio::MuLawEncoding(-10);

  auto ds1 = ds->Map({mu_law_encoding_op1});
  ASSERT_NE(ds1, nullptr);
  std::shared_ptr<Iterator> iter1 = ds1->CreateIterator();
  EXPECT_EQ(iter1, nullptr);

  // quantization_channels is 0
  auto mu_law_encoding_op2 = audio::MuLawEncoding(0);

  auto ds2 = ds->Map({mu_law_encoding_op2});
  ASSERT_NE(ds2, nullptr);
  std::shared_ptr<Iterator> iter2 = ds2->CreateIterator();
  EXPECT_EQ(iter2, nullptr);
}
|
||||
|
||||
/// Feature: Overdrive
|
||||
|
|
|
@ -864,15 +864,33 @@ TEST_F(MindDataTestExecute, TestHighpassBiquadParamCheckSampleRate) {
|
|||
TEST_F(MindDataTestExecute, TestMuLawDecodingEager) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestMuLawDecodingEager.";
|
||||
// testing
|
||||
std::shared_ptr<Tensor> input_tensor_;
|
||||
Tensor::CreateFromVector(std::vector<float>({1, 254, 231, 155, 101, 77}), TensorShape({1, 6}), &input_tensor_);
|
||||
std::shared_ptr<Tensor> input_tensor;
|
||||
Tensor::CreateFromVector(std::vector<float>({1, 254, 231, 155, 101, 77}), TensorShape({1, 6}), &input_tensor);
|
||||
|
||||
auto input_02 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input_tensor_));
|
||||
auto input_01 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input_tensor));
|
||||
std::shared_ptr<TensorTransform> mu_law_encoding_01 = std::make_shared<audio::MuLawDecoding>(255);
|
||||
|
||||
// Filtered waveform by mulawencoding
|
||||
mindspore::dataset::Execute Transform01({mu_law_encoding_01});
|
||||
Status s01 = Transform01(input_02, &input_02);
|
||||
Status s01 = Transform01(input_01, &input_01);
|
||||
EXPECT_TRUE(s01.IsOk());
|
||||
}
|
||||
|
||||
/// Feature: MuLawEncoding
|
||||
/// Description: test MuLawEncoding in eager mode
|
||||
/// Expectation: the data is processed successfully
|
||||
TEST_F(MindDataTestExecute, TestMuLawEncodingEager) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestMuLawEncodingEager.";
|
||||
// testing
|
||||
std::shared_ptr<Tensor> input_tensor;
|
||||
Tensor::CreateFromVector(std::vector<float>({0.1, 0.2, 0.3, 0.4, 0.5, 0.6}), TensorShape({1, 6}), &input_tensor);
|
||||
|
||||
auto input_01 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input_tensor));
|
||||
std::shared_ptr<TensorTransform> mu_law_encoding_01 = std::make_shared<audio::MuLawEncoding>(255);
|
||||
|
||||
// Filtered waveform by mulawencoding
|
||||
mindspore::dataset::Execute Transform01({mu_law_encoding_01});
|
||||
Status s01 = Transform01(input_01, &input_01);
|
||||
EXPECT_TRUE(s01.IsOk());
|
||||
}
|
||||
|
||||
|
|
|
@ -52,9 +52,9 @@ def test_mu_law_decoding_eager():
|
|||
logger.info("Test MuLawDecoding callable.")
|
||||
|
||||
input_t = np.array([70, 170])
|
||||
output_t = audio.MuLawDecoding()(input_t)
|
||||
output_t = audio.MuLawDecoding(128)(input_t)
|
||||
assert output_t.shape == (2,)
|
||||
excepted = np.array([-0.04388953000307083, 0.02097884565591812])
|
||||
excepted = np.array([0.00506480922922492, 26.928272247314453])
|
||||
assert np.array_equal(output_t, excepted)
|
||||
|
||||
logger.info("Finish testing MuLawDecoding.")
|
|
@ -0,0 +1,105 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing MuLawEncoding op in DE.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.audio.transforms as audio
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
def test_mu_law_encoding():
    """
    Feature: MuLawEncoding
    Description: apply MuLawEncoding (default arguments) to one column of a GeneratorDataset pipeline
    Expectation: every row keeps shape (1, 4) and equals the expected encoded values
    """
    logger.info("Test MuLawEncoding.")

    column = "multi_dim_data"
    expected = np.array([[203, 218, 228, 234]])

    def gen():
        # Single-row source: one (1, 4) float32 waveform.
        yield (np.array([[0.1, 0.2, 0.3, 0.4]], dtype=np.float32),)

    pipeline = ds.GeneratorDataset(source=gen, column_names=[column])
    pipeline = pipeline.map(operations=audio.MuLawEncoding(), input_columns=[column])

    for row in pipeline.create_dict_iterator(num_epochs=1, output_numpy=True):
        encoded = row[column]
        assert encoded.shape == (1, 4)
        assert np.array_equal(encoded, expected)

    logger.info("Finish testing MuLawEncoding.")
|
||||
|
||||
|
||||
def test_mu_law_encoding_eager():
    """
    Feature: MuLawEncoding
    Description: call MuLawEncoding(128) directly on a NumPy array (eager mode)
    Expectation: the output has shape (1, 4) and equals the expected encoded values
    """
    logger.info("Test MuLawEncoding callable.")

    waveform = np.array([[0.1, 0.2, 0.3, 0.4]])
    encoder = audio.MuLawEncoding(128)
    encoded = encoder(waveform)

    assert encoded.shape == (1, 4)
    assert np.array_equal(encoded, np.array([[98, 106, 111, 115]]))

    logger.info("Finish testing MuLawEncoding.")
|
||||
|
||||
|
||||
def test_mu_law_encoding_uncallable():
    """
    Feature: MuLawEncoding
    Description: test param check of MuLawEncoding
    Expectation: throw correct error and message
    """
    logger.info("Test MuLawEncoding not callable.")

    input_t = np.random.rand(2, 4)
    try:
        audio.MuLawEncoding(-3)(input_t)
    except ValueError as e:
        assert 'Input quantization_channels is not within the required interval of [1, 2147483647].' in str(e)
    else:
        # Without this branch the test would pass silently when the invalid
        # quantization_channels value is accepted and no exception is raised.
        raise AssertionError("MuLawEncoding(-3) was expected to raise ValueError but did not.")

    logger.info("Finish testing MuLawEncoding.")
|
||||
|
||||
|
||||
def test_mu_law_encoding_and_decoding():
    """
    Feature: MuLawEncoding and MuLawDecoding
    Description: decode an array with MuLawDecoding(128), then re-encode it with MuLawEncoding(128), both in eager mode
    Expectation: the re-encoded array equals the original input
    """
    logger.info("Test MuLawEncoding and MuLawDecoding callable.")

    original = np.array([[98, 106, 111, 115]])
    # Round trip: decode first, then feed the decoded result back to the encoder.
    decoded = audio.MuLawDecoding(128)(original)
    round_trip = audio.MuLawEncoding(128)(decoded)
    assert np.array_equal(original, round_trip)

    logger.info("Finish testing MuLawEncoding and MuLawDecoding callable.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run every test in this module, in definition order, when executed as a script.
    for _case in (
            test_mu_law_encoding,
            test_mu_law_encoding_eager,
            test_mu_law_encoding_uncallable,
            test_mu_law_encoding_and_decoding,
    ):
        _case()
|
Loading…
Reference in New Issue