forked from mindspore-Ecosystem/mindspore
[feat][assistant][I3J6UV] add new audio operator Phaser
This commit is contained in:
parent
70cd1c77d5
commit
d13d2413f5
|
@ -40,6 +40,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/mu_law_encoding_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/overdrive_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/phaser_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/riaa_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
|
||||
|
@ -427,6 +428,35 @@ std::shared_ptr<TensorOperation> Overdrive::Parse() {
|
|||
return std::make_shared<OverdriveOperation>(data_->gain_, data_->color_);
|
||||
}
|
||||
|
||||
// Phaser Transform Operation.
|
||||
struct Phaser::Data {
|
||||
Data(int32_t sample_rate, float gain_in, float gain_out, float delay_ms, float decay, float mod_speed,
|
||||
bool sinusoidal)
|
||||
: sample_rate_(sample_rate),
|
||||
gain_in_(gain_in),
|
||||
gain_out_(gain_out),
|
||||
delay_ms_(delay_ms),
|
||||
decay_(decay),
|
||||
mod_speed_(mod_speed),
|
||||
sinusoidal_(sinusoidal) {}
|
||||
int32_t sample_rate_;
|
||||
float gain_in_;
|
||||
float gain_out_;
|
||||
float delay_ms_;
|
||||
float decay_;
|
||||
float mod_speed_;
|
||||
bool sinusoidal_;
|
||||
};
|
||||
|
||||
Phaser::Phaser(int32_t sample_rate, float gain_in, float gain_out, float delay_ms, float decay, float mod_speed,
|
||||
bool sinusoidal)
|
||||
: data_(std::make_shared<Data>(sample_rate, gain_in, gain_out, delay_ms, decay, mod_speed, sinusoidal)) {}
|
||||
|
||||
std::shared_ptr<TensorOperation> Phaser::Parse() {
|
||||
return std::make_shared<PhaserOperation>(data_->sample_rate_, data_->gain_in_, data_->gain_out_, data_->delay_ms_,
|
||||
data_->decay_, data_->mod_speed_, data_->sinusoidal_);
|
||||
}
|
||||
|
||||
// RiaaBiquad Transform Operation.
|
||||
struct RiaaBiquad::Data {
|
||||
explicit Data(int32_t sample_rate) : sample_rate_(sample_rate) {}
|
||||
|
|
|
@ -44,6 +44,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/mu_law_encoding_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/overdrive_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/phaser_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/riaa_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
|
||||
|
@ -349,6 +350,18 @@ PYBIND_REGISTER(OverdriveOperation, 1, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
// Expose audio::PhaserOperation to Python; parameters are validated when the Python side constructs the object.
PYBIND_REGISTER(PhaserOperation, 1, ([](const py::module *m) {
                  using PhaserIr = audio::PhaserOperation;
                  (void)py::class_<PhaserIr, TensorOperation, std::shared_ptr<PhaserIr>>(*m, "PhaserOperation")
                    .def(py::init([](int32_t sample_rate, float gain_in, float gain_out, float delay_ms, float decay,
                                     float mod_speed, bool sinusoidal) {
                      auto operation = std::make_shared<PhaserIr>(sample_rate, gain_in, gain_out, delay_ms, decay,
                                                                  mod_speed, sinusoidal);
                      // Invalid arguments are reported via THROW_IF_ERROR.
                      THROW_IF_ERROR(operation->ValidateParams());
                      return operation;
                    }));
                }));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
RiaaBiquadOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<audio::RiaaBiquadOperation, TensorOperation, std::shared_ptr<audio::RiaaBiquadOperation>>(
|
||||
|
|
|
@ -26,6 +26,7 @@ add_library(audio-ir-kernels OBJECT
|
|||
mu_law_decoding_ir.cc
|
||||
mu_law_encoding_ir.cc
|
||||
overdrive_ir.cc
|
||||
phaser_ir.cc
|
||||
riaa_biquad_ir.cc
|
||||
time_masking_ir.cc
|
||||
time_stretch_ir.cc
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/audio/ir/kernels/phaser_ir.h"
|
||||
|
||||
#include "minddata/dataset/audio/kernels/phaser_op.h"
|
||||
#include "minddata/dataset/kernels/ir/validators.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
PhaserOperation::PhaserOperation(int32_t sample_rate, float gain_in, float gain_out, float delay_ms, float decay,
|
||||
float mod_speed, bool sinusoidal)
|
||||
: sample_rate_(sample_rate),
|
||||
gain_in_(gain_in),
|
||||
gain_out_(gain_out),
|
||||
delay_ms_(delay_ms),
|
||||
decay_(decay),
|
||||
mod_speed_(mod_speed),
|
||||
sinusoidal_(sinusoidal) {}
|
||||
|
||||
// Check every float parameter against its allowed range; the first violation is returned.
// sample_rate_ and sinusoidal_ are not checked here.
Status PhaserOperation::ValidateParams() {
  constexpr char kOpName[] = "Phaser";
  // gain_in must lie within the bounds 0 and 1
  RETURN_IF_NOT_OK(ValidateScalar(kOpName, "gain_in", gain_in_, {0.0f, 1.0f}, false, false));
  // gain_out must lie within the bounds 0 and 1e9
  RETURN_IF_NOT_OK(ValidateScalar(kOpName, "gain_out", gain_out_, {0.0f, 1e9f}, false, false));
  // delay_ms must lie within the bounds 0 and 5
  RETURN_IF_NOT_OK(ValidateScalar(kOpName, "delay_ms", delay_ms_, {0.0f, 5.0f}, false, false));
  // decay must lie within the bounds 0 and 0.99
  RETURN_IF_NOT_OK(ValidateScalar(kOpName, "decay", decay_, {0.0f, 0.99f}, false, false));
  // mod_speed must lie within the bounds 0.1 and 2
  RETURN_IF_NOT_OK(ValidateScalar(kOpName, "mod_speed", mod_speed_, {0.1f, 2.0f}, false, false));
  return Status::OK();
}
|
||||
|
||||
std::shared_ptr<TensorOp> PhaserOperation::Build() {
|
||||
std::shared_ptr<PhaserOp> tensor_op =
|
||||
std::make_shared<PhaserOp>(sample_rate_, gain_in_, gain_out_, delay_ms_, decay_, mod_speed_, sinusoidal_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
// Serialize all seven parameters into a JSON object and store it in *out_json (overwriting its content).
Status PhaserOperation::to_json(nlohmann::json *out_json) {
  nlohmann::json args = {
    {"sample_rate", sample_rate_}, {"gain_in", gain_in_}, {"gain_out", gain_out_},     {"delay_ms", delay_ms_},
    {"decay", decay_},             {"mod_speed", mod_speed_}, {"sinusoidal", sinusoidal_},
  };
  *out_json = args;
  return Status::OK();
}
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_PHASER_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_PHASER_IR_H_
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/include/dataset/transforms.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
constexpr char kPhaserOperation[] = "Phaser";
|
||||
|
||||
class PhaserOperation : public TensorOperation {
|
||||
public:
|
||||
PhaserOperation(int32_t sample_rate, float gain_in, float gain_out, float delay_ms, float decay, float mod_speed,
|
||||
bool sinusoidal);
|
||||
|
||||
~PhaserOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
std::string Name() const override { return kPhaserOperation; }
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
int32_t sample_rate_;
|
||||
float gain_in_;
|
||||
float gain_out_;
|
||||
float delay_ms_;
|
||||
float decay_;
|
||||
float mod_speed_;
|
||||
bool sinusoidal_;
|
||||
};
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_PHASER_IR_H_
|
|
@ -27,6 +27,7 @@ add_library(audio-kernels OBJECT
|
|||
mu_law_decoding_op.cc
|
||||
mu_law_encoding_op.cc
|
||||
overdrive_op.cc
|
||||
phaser_op.cc
|
||||
riaa_biquad_op.cc
|
||||
time_masking_op.cc
|
||||
time_stretch_op.cc
|
||||
|
|
|
@ -655,6 +655,89 @@ Status DetectPitchFrequency(const std::shared_ptr<Tensor> &input, std::shared_pt
|
|||
Status GenerateWaveTable(std::shared_ptr<Tensor> *output, const DataType &type, Modulation modulation,
|
||||
int32_t table_size, float min, float max, float phase);
|
||||
|
||||
/// \brief Apply a phaser effect to the audio.
|
||||
/// \param input Tensor of shape <..., time>.
|
||||
/// \param output Tensor of shape <..., time>.
|
||||
/// \param sample_rate Sampling rate of the waveform.
|
||||
/// \param gain_in Desired input gain at the boost (or attenuation) in dB.
|
||||
/// \param gain_out Desired output gain at the boost (or attenuation) in dB.
|
||||
/// \param delay_ms Desired delay in milli seconds.
|
||||
/// \param decay Desired decay relative to gain-in.
|
||||
/// \param mod_speed Modulation speed in Hz.
|
||||
/// \param sinusoidal If true, use sinusoidal modulation. If false, use triangular modulation.
|
||||
/// \return Status code.
|
||||
template <typename T>
|
||||
Status Phaser(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t sample_rate, float gain_in,
|
||||
float gain_out, float delay_ms, float decay, float mod_speed, bool sinusoidal) {
|
||||
TensorShape input_shape = input->shape();
|
||||
// input convert to 2D (channels,time)
|
||||
auto channels = input->Size() / input_shape[-1];
|
||||
auto time = input_shape[-1];
|
||||
TensorShape to_shape({channels, time});
|
||||
RETURN_IF_NOT_OK(input->Reshape(to_shape));
|
||||
// input vector
|
||||
std::vector<std::vector<T>> input_vec(channels, std::vector<T>(time, 0));
|
||||
// output vector
|
||||
std::vector<std::vector<T>> out_vec(channels, std::vector<T>(time, 0));
|
||||
// input convert to vector
|
||||
auto input_itr = input->begin<T>();
|
||||
for (size_t i = 0; i < channels; i++) {
|
||||
for (size_t j = 0; j < time; j++) {
|
||||
input_vec[i][j] = *input_itr * gain_in;
|
||||
input_itr++;
|
||||
}
|
||||
}
|
||||
// compute
|
||||
// create delay buffer
|
||||
int delay_buf_nrow = channels;
|
||||
// calculate the length of the delay
|
||||
int delay_buf_len = static_cast<int>((delay_ms * 0.001 * sample_rate) + 0.5);
|
||||
std::vector<std::vector<T>> delay_buf(delay_buf_nrow, std::vector<T>(delay_buf_len, 0 * decay));
|
||||
// calculate the length after the momentum
|
||||
int mod_buf_len = static_cast<int>(sample_rate / mod_speed + 0.5);
|
||||
Modulation modulation = sinusoidal ? Modulation::kSinusoidal : Modulation::kTriangular;
|
||||
// create and compute mod buffer
|
||||
std::shared_ptr<Tensor> mod_buf_tensor;
|
||||
RETURN_IF_NOT_OK(GenerateWaveTable(&mod_buf_tensor, DataType(DataType::DE_INT32), modulation, mod_buf_len,
|
||||
static_cast<float>(1.0f), static_cast<float>(delay_buf_len),
|
||||
static_cast<float>(PI / 2)));
|
||||
// tensor mod_buf convert to vector
|
||||
std::vector<int> mod_buf;
|
||||
for (auto itr = mod_buf_tensor->begin<int>(); itr != mod_buf_tensor->end<int>(); itr++) {
|
||||
mod_buf.push_back(*itr);
|
||||
}
|
||||
dsize_t delay_pos = 0;
|
||||
dsize_t mod_pos = 0;
|
||||
// for every channal at the current time
|
||||
for (size_t i = 0; i < time; i++) {
|
||||
// calculate the delay data that should be added to each channal at this time
|
||||
int idx = static_cast<int>((delay_pos + mod_buf[mod_pos]) % delay_buf_len);
|
||||
mod_pos = (mod_pos + 1) % mod_buf_len;
|
||||
delay_pos = (delay_pos + 1) % delay_buf_len;
|
||||
// update the next delay data with the current result * decay
|
||||
for (size_t j = 0; j < channels; j++) {
|
||||
out_vec[j][i] = input_vec[j][i] + delay_buf[j][idx];
|
||||
delay_buf[j][delay_pos] = (input_vec[j][i] + delay_buf[j][idx]) * decay;
|
||||
}
|
||||
}
|
||||
std::vector<T> out_vec_one_d;
|
||||
for (size_t i = 0; i < channels; i++) {
|
||||
for (size_t j = 0; j < time; j++) {
|
||||
// gain_out on the output
|
||||
out_vec[i][j] *= gain_out;
|
||||
// clamp
|
||||
out_vec[i][j] = std::max<float>(-1.0f, std::min<float>(1.0f, out_vec[i][j]));
|
||||
// output vector is transformed from 2d to 1d
|
||||
out_vec_one_d.push_back(out_vec[i][j]);
|
||||
}
|
||||
}
|
||||
// move data to output tensor
|
||||
std::shared_ptr<Tensor> out;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromVector(out_vec_one_d, input_shape, &out));
|
||||
*output = out;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Flanger about interpolation effect.
|
||||
/// \param input: Tensor of shape <batch, channel, time>.
|
||||
/// \param int_delay: A dimensional vector about integer delay, subscript representing delay.
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/audio/kernels/phaser_op.h"
|
||||
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
#include "minddata/dataset/kernels/data/data_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
PhaserOp::PhaserOp(int32_t sample_rate, float gain_in, float gain_out, float delay_ms, float decay, float mod_speed,
|
||||
bool sinusoidal)
|
||||
: sample_rate_(sample_rate),
|
||||
gain_in_(gain_in),
|
||||
gain_out_(gain_out),
|
||||
delay_ms_(delay_ms),
|
||||
decay_(decay),
|
||||
mod_speed_(mod_speed),
|
||||
sinusoidal_(sinusoidal) {}
|
||||
|
||||
// Run the phaser on one tensor. float64 input is processed as double; every other numeric type is first cast
// to float32 and processed as float.
Status PhaserOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);
  // the shape must have at least one dimension (the trailing one is used as time)
  CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() > 0, "Phaser: input tensor is not in shape of <..., time>.");
  // only numeric element types are accepted
  CHECK_FAIL_RETURN_UNEXPECTED(
    input->type().IsNumeric(),
    "Phaser: input tensor type should be int, float or double, but got: " + input->type().ToString());

  if (input->type() != DataType::DE_FLOAT64) {
    std::shared_ptr<Tensor> float_input;
    RETURN_IF_NOT_OK(TypeCast(input, &float_input, DataType(DataType::DE_FLOAT32)));
    return Phaser<float>(float_input, output, sample_rate_, gain_in_, gain_out_, delay_ms_, decay_, mod_speed_,
                         sinusoidal_);
  }

  std::shared_ptr<Tensor> double_input = input;
  return Phaser<double>(double_input, output, sample_rate_, gain_in_, gain_out_, delay_ms_, decay_, mod_speed_,
                        sinusoidal_);
}
|
||||
|
||||
// Output element type: float64 stays float64, every other numeric input type becomes float32.
Status PhaserOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
  RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
  const DataType &in_type = inputs[0];
  if (!in_type.IsNumeric()) {
    RETURN_STATUS_UNEXPECTED("Phaser: input tensor type should be int, float or double, but got: " +
                             in_type.ToString());
  }
  if (in_type == DataType(DataType::DE_FLOAT64)) {
    outputs[0] = DataType(DataType::DE_FLOAT64);
  } else {
    outputs[0] = DataType(DataType::DE_FLOAT32);
  }
  return Status::OK();
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_PHASER_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_PHASER_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class PhaserOp : public TensorOp {
|
||||
public:
|
||||
PhaserOp(int32_t sample_rate, float gain_in, float gain_out, float delay_ms, float decay, float mod_speed,
|
||||
bool sinusoidal);
|
||||
|
||||
~PhaserOp() override = default;
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
std::string Name() const override { return kPhaserOp; };
|
||||
|
||||
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
|
||||
|
||||
private:
|
||||
int32_t sample_rate_;
|
||||
float gain_in_;
|
||||
float gain_out_;
|
||||
float delay_ms_;
|
||||
float decay_;
|
||||
float mod_speed_;
|
||||
bool sinusoidal_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_PHASER_OP_H_
|
|
@ -585,6 +585,36 @@ class Overdrive final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Phaser TensorTransform.
|
||||
class Phaser final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] sample_rate Sampling rate of the waveform, e.g. 44100 (Hz).
|
||||
/// \param[in] gain_in Desired input gain at the boost (or attenuation) in dB.
|
||||
/// Allowed range of values is [0, 1] (Default=0.4).
|
||||
/// \param[in] gain_out Desired output gain at the boost (or attenuation) in dB.
|
||||
/// Allowed range of values is [0, 1e9] (Default=0.74).
|
||||
/// \param[in] delay_ms Desired delay in milli seconds. Allowed range of values is [0, 5] (Default=3.0).
|
||||
/// \param[in] decay Desired decay relative to gain-in. Allowed range of values is [0, 0.99] (Default=0.4).
|
||||
/// \param[in] mod_speed Modulation speed in Hz. Allowed range of values is [0.1, 2] (Default=0.5).
|
||||
/// \param[in] sinusoidal If true, use sinusoidal modulation (preferable for multiple instruments).
|
||||
/// If false, use triangular modulation (gives single instruments a sharper phasing effect) (Default=true).
|
||||
Phaser(int32_t sample_rate, float gain_in = 0.4f, float gain_out = 0.74f, float delay_ms = 3.0f, float decay = 0.4f,
|
||||
float mod_speed = 0.5f, bool sinusoidal = true);
|
||||
|
||||
/// \brief Destructor.
|
||||
~Phaser() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
||||
private:
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Apply RIAA vinyl playback equalization.
|
||||
class RiaaBiquad final : public TensorTransform {
|
||||
public:
|
||||
|
|
|
@ -167,6 +167,7 @@ constexpr char kMagphaseOp[] = "MagphaseOp";
|
|||
constexpr char kMuLawDecodingOp[] = "MuLawDecodingOp";
|
||||
constexpr char kMuLawEncodingOp[] = "MuLawEncodingOp";
|
||||
constexpr char kOverdriveOp[] = "OverdriveOp";
|
||||
constexpr char kPhaserOp[] = "PhaserOp";
|
||||
constexpr char kRiaaBiquadOp[] = "RiaaBiquadOp";
|
||||
constexpr char kTimeMaskingOp[] = "TimeMaskingOp";
|
||||
constexpr char kTimeStretchOp[] = "TimeStretchOp";
|
||||
|
|
|
@ -28,7 +28,7 @@ from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_
|
|||
check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_contrast, check_dc_shift, \
|
||||
check_deemph_biquad, check_detect_pitch_frequency, check_equalizer_biquad, check_fade, check_flanger, \
|
||||
check_highpass_biquad, check_lfilter, check_lowpass_biquad, check_magphase, check_masking, check_mu_law_coding, \
|
||||
check_overdrive, check_riaa_biquad, check_time_stretch, check_treble_biquad, check_vol
|
||||
check_overdrive, check_phaser, check_riaa_biquad, check_time_stretch, check_treble_biquad, check_vol
|
||||
|
||||
|
||||
class AudioTensorOperation(TensorOperation):
|
||||
|
@ -771,6 +771,48 @@ class Overdrive(AudioTensorOperation):
|
|||
return cde.OverdriveOperation(self.gain, self.color)
|
||||
|
||||
|
||||
class Phaser(AudioTensorOperation):
    """
    Apply a phasing effect to the audio.

    Args:
        sample_rate (int): Sampling rate of the waveform, e.g. 44100 (Hz).
        gain_in (float): Desired input gain at the boost (or attenuation) in dB.
            Allowed range of values is [0, 1] (default=0.4).
        gain_out (float): Desired output gain at the boost (or attenuation) in dB.
            Allowed range of values is [0, 1e9] (default=0.74).
        delay_ms (float): Desired delay in milliseconds. Allowed range of values is [0, 5] (default=3.0).
        decay (float): Desired decay relative to gain-in. Allowed range of values is [0, 0.99] (default=0.4).
        mod_speed (float): Modulation speed in Hz. Allowed range of values is [0.1, 2] (default=0.5).
        sinusoidal (bool): If True, use sinusoidal modulation (preferable for multiple instruments).
            If False, use triangular modulation (gives single instruments a sharper
            phasing effect) (default=True).

    Examples:
        >>> import numpy as np
        >>>
        >>> waveform = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
        >>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
        >>> transforms = [audio.Phaser(44100)]
        >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
    """

    @check_phaser
    def __init__(self, sample_rate, gain_in=0.4, gain_out=0.74,
                 delay_ms=3.0, decay=0.4, mod_speed=0.5, sinusoidal=True):
        # Arguments are validated by check_phaser; store them in signature order.
        self.sample_rate = sample_rate
        self.gain_in = gain_in
        self.gain_out = gain_out
        self.delay_ms = delay_ms
        self.decay = decay
        self.mod_speed = mod_speed
        self.sinusoidal = sinusoidal

    def parse(self):
        # Build the C++ operation from the stored arguments.
        operation_args = (self.sample_rate, self.gain_in, self.gain_out,
                          self.delay_ms, self.decay, self.mod_speed, self.sinusoidal)
        return cde.PhaserOperation(*operation_args)
|
||||
|
||||
|
||||
class RiaaBiquad(AudioTensorOperation):
|
||||
"""
|
||||
Apply RIAA vinyl playback equalization. Similar to SoX implementation.
|
||||
|
|
|
@ -18,9 +18,9 @@ Validators for TensorOps.
|
|||
|
||||
from functools import wraps
|
||||
|
||||
from mindspore.dataset.core.validator_helpers import check_float32, check_float32_not_zero, check_int32_not_zero, \
|
||||
check_list_same_size, check_non_negative_float32, check_non_negative_int32, check_pos_float32, check_pos_int32, \
|
||||
check_value, parse_user_args, type_check
|
||||
from mindspore.dataset.core.validator_helpers import check_float32, check_float32_not_zero, check_int32,\
|
||||
check_int32_not_zero, check_list_same_size, check_non_negative_float32, check_non_negative_int32, \
|
||||
check_pos_float32, check_pos_int32, check_value, parse_user_args, type_check
|
||||
from .utils import FadeShape, GainType, Interpolation, Modulation, ScaleType
|
||||
|
||||
|
||||
|
@ -307,6 +307,31 @@ def check_overdrive(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_phaser(method):
    """Wrapper method to check the parameters of Phaser."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [sample_rate, gain_in, gain_out, delay_ms, decay,
         mod_speed, sinusoidal], _ = parse_user_args(method, *args, **kwargs)

        # sample_rate: must be an int accepted by check_int32
        type_check(sample_rate, (int,), "sample_rate")
        check_int32(sample_rate, "sample_rate")

        # numeric parameters: (value, allowed range, argument name), checked in signature order
        ranged_params = (
            (gain_in, [0, 1], "gain_in"),
            (gain_out, [0, 1e9], "gain_out"),
            (delay_ms, [0, 5.0], "delay_ms"),
            (decay, [0, 0.99], "decay"),
            (mod_speed, [0.1, 2], "mod_speed"),
        )
        for value, valid_range, arg_name in ranged_params:
            type_check(value, (float, int), arg_name)
            check_value(value, valid_range, arg_name)

        # sinusoidal: must be a bool
        type_check(sinusoidal, (bool,), "sinusoidal")
        return method(self, *args, **kwargs)

    return new_method
|
||||
|
||||
|
||||
def check_riaa_biquad(method):
|
||||
"""Wrapper method to check the parameters of RiaaBiquad."""
|
||||
|
||||
|
|
|
@ -1106,6 +1106,144 @@ TEST_F(MindDataTestPipeline, TestOverdriveWrongArg) {
|
|||
EXPECT_EQ(iter02, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: Phaser
|
||||
/// Description: test basic usage of Phaser
|
||||
/// Expectation: get correct number of data
|
||||
TEST_F(MindDataTestPipeline, TestPhaserBasic) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPhaserBasic";
|
||||
// Original waveform
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 200}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(50, schema);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
ds = ds->SetNumWorkers(4);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto PhaserOp = audio::Phaser(44100);
|
||||
|
||||
ds = ds->Map({PhaserOp});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Apply a phasing effect to the audio
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
std::vector<int64_t> expected = {2, 200};
|
||||
|
||||
int i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto col = row["waveform"];
|
||||
ASSERT_EQ(col.Shape(), expected);
|
||||
ASSERT_EQ(col.Shape().size(), 2);
|
||||
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
EXPECT_EQ(i, 50);
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
/// Feature: Phaser
|
||||
/// Description: test invalid parameter of Phaser
|
||||
/// Expectation: throw exception correctly
|
||||
TEST_F(MindDataTestPipeline, TestPhaserWrongArg) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPhaserWrongArg.";
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
// Original waveform
|
||||
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 2}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(50, schema);
|
||||
std::shared_ptr<Dataset> ds01;
|
||||
std::shared_ptr<Dataset> ds02;
|
||||
std::shared_ptr<Dataset> ds03;
|
||||
std::shared_ptr<Dataset> ds04;
|
||||
std::shared_ptr<Dataset> ds05;
|
||||
std::shared_ptr<Dataset> ds06;
|
||||
std::shared_ptr<Dataset> ds07;
|
||||
std::shared_ptr<Dataset> ds08;
|
||||
std::shared_ptr<Dataset> ds09;
|
||||
std::shared_ptr<Dataset> ds10;
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
// Check gain_in out of range [0,1]
|
||||
MS_LOG(INFO) << "gain_in is less than 0.";
|
||||
auto phaser_op_01 = audio::Phaser(44100, -0.2);
|
||||
ds01 = ds->Map({phaser_op_01});
|
||||
EXPECT_NE(ds01, nullptr);
|
||||
std::shared_ptr<Iterator> iter01 = ds01->CreateIterator();
|
||||
EXPECT_EQ(iter01, nullptr);
|
||||
|
||||
MS_LOG(INFO) << "gain_in is greater than 1.";
|
||||
auto phaser_op_02 = audio::Phaser(44100, 1.2);
|
||||
ds02 = ds->Map({phaser_op_02});
|
||||
EXPECT_NE(ds02, nullptr);
|
||||
std::shared_ptr<Iterator> iter02 = ds02->CreateIterator();
|
||||
EXPECT_EQ(iter02, nullptr);
|
||||
|
||||
// Check gain_out out of range [0,1e9]
|
||||
MS_LOG(INFO) << "gain_out is less than 0.";
|
||||
auto phaser_op_03 = audio::Phaser(44100, 0.2, -1.3);
|
||||
ds03 = ds->Map({phaser_op_03});
|
||||
EXPECT_NE(ds03, nullptr);
|
||||
std::shared_ptr<Iterator> iter03 = ds03->CreateIterator();
|
||||
EXPECT_EQ(iter03, nullptr);
|
||||
|
||||
MS_LOG(INFO) << "gain_out is greater than 1e9.";
|
||||
auto phaser_op_04 = audio::Phaser(44100, 0.3, 1e10);
|
||||
ds04 = ds->Map({phaser_op_04});
|
||||
EXPECT_NE(ds04, nullptr);
|
||||
std::shared_ptr<Iterator> iter04 = ds04->CreateIterator();
|
||||
EXPECT_EQ(iter04, nullptr);
|
||||
|
||||
// Check delay_ms out of range [0,5.0]
|
||||
MS_LOG(INFO) << "delay_ms is less than 0.";
|
||||
auto phaser_op_05 = audio::Phaser(44100, 0.2, 2, -2.0);
|
||||
ds05 = ds->Map({phaser_op_05});
|
||||
EXPECT_NE(ds05, nullptr);
|
||||
std::shared_ptr<Iterator> iter05 = ds05->CreateIterator();
|
||||
EXPECT_EQ(iter05, nullptr);
|
||||
|
||||
MS_LOG(INFO) << "delay_ms is greater than 5.0.";
|
||||
auto phaser_op_06 = audio::Phaser(44100, 0.3, 2, 6.0);
|
||||
ds06 = ds->Map({phaser_op_06});
|
||||
EXPECT_NE(ds06, nullptr);
|
||||
std::shared_ptr<Iterator> iter06 = ds06->CreateIterator();
|
||||
EXPECT_EQ(iter06, nullptr);
|
||||
|
||||
// Check decay out of range [0,0.99]
|
||||
MS_LOG(INFO) << "decay is less than 0.";
|
||||
auto phaser_op_07 = audio::Phaser(44100, 0.2, 2, 2.0, -1.0);
|
||||
ds07 = ds->Map({phaser_op_07});
|
||||
EXPECT_NE(ds07, nullptr);
|
||||
std::shared_ptr<Iterator> iter07 = ds07->CreateIterator();
|
||||
EXPECT_EQ(iter07, nullptr);
|
||||
|
||||
MS_LOG(INFO) << "decay is greater than 0.99.";
|
||||
auto phaser_op_08 = audio::Phaser(44100, 0.3, 2, 2.0, 1.2);
|
||||
ds08 = ds->Map({phaser_op_08});
|
||||
EXPECT_NE(ds08, nullptr);
|
||||
std::shared_ptr<Iterator> iter08 = ds08->CreateIterator();
|
||||
EXPECT_EQ(iter08, nullptr);
|
||||
|
||||
// Check mod_speed out of range [0.1,10]
|
||||
MS_LOG(INFO) << "mod_speed is less than 0.1 .";
|
||||
auto phaser_op_09 = audio::Phaser(44100, 0.2, 2, 2.0, 0.5, 0.002);
|
||||
ds09 = ds->Map({phaser_op_09});
|
||||
EXPECT_NE(ds09, nullptr);
|
||||
std::shared_ptr<Iterator> iter09 = ds09->CreateIterator();
|
||||
EXPECT_EQ(iter09, nullptr);
|
||||
|
||||
MS_LOG(INFO) << "mod_speed is greater than 10.";
|
||||
auto phaser_op_10 = audio::Phaser(44100, 0.3, 2, 2.0, 0.5, 12.0);
|
||||
ds10 = ds->Map({phaser_op_10});
|
||||
EXPECT_NE(ds10, nullptr);
|
||||
std::shared_ptr<Iterator> iter10 = ds10->CreateIterator();
|
||||
EXPECT_EQ(iter10, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestLfilterPipeline) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestLfilterPipeline.";
|
||||
// Original waveform
|
||||
|
|
|
@ -1055,6 +1055,77 @@ TEST_F(MindDataTestExecute, TestLFilterWithWrongArg) {
|
|||
EXPECT_FALSE(s01.IsOk());
|
||||
}
|
||||
|
||||
/// Feature: Phaser
|
||||
/// Description: test basic usage of Phaser
|
||||
/// Expectation: get correct number of data
|
||||
TEST_F(MindDataTestExecute, TestPhaserBasicWithEager) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestPhaserBasicWithEager.";
|
||||
// Original waveform
|
||||
std::vector<float> labels = {
|
||||
2.716064453125000000e-03, 6.347656250000000000e-03, 9.246826171875000000e-03, 1.089477539062500000e-02,
|
||||
1.138305664062500000e-02, 1.156616210937500000e-02, 1.394653320312500000e-02, 1.550292968750000000e-02,
|
||||
1.614379882812500000e-02, 1.840209960937500000e-02, 1.718139648437500000e-02, 1.599121093750000000e-02,
|
||||
1.647949218750000000e-02, 1.510620117187500000e-02, 1.385498046875000000e-02, 1.345825195312500000e-02,
|
||||
1.419067382812500000e-02, 1.284790039062500000e-02, 1.052856445312500000e-02, 9.368896484375000000e-03};
|
||||
std::shared_ptr<Tensor> input;
|
||||
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 10}), &input));
|
||||
auto input_02 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||
std::shared_ptr<TensorTransform> phaser_op_01 = std::make_shared<audio::Phaser>(44100);
|
||||
mindspore::dataset::Execute Transform01({phaser_op_01});
|
||||
Status s01 = Transform01(input_02, &input_02);
|
||||
EXPECT_TRUE(s01.IsOk());
|
||||
}
|
||||
|
||||
/// Feature: Phaser
|
||||
/// Description: test invalid parameter of Phaser
|
||||
/// Expectation: throw exception correctly
|
||||
TEST_F(MindDataTestExecute, TestPhaserInputArgWithEager) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestPhaserInputArgWithEager";
|
||||
std::vector<double> labels = {
|
||||
0.271, 1.634, 9.246, 0.108,
|
||||
1.138, 1.156, 3.394, 1.55,
|
||||
3.614, 1.8402, 0.718, 4.599,
|
||||
5.64, 2.510620117187500000e-02, 1.38, 5.825,
|
||||
4.1906, 5.28, 1.052, 9.36};
|
||||
std::shared_ptr<Tensor> input;
|
||||
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({4, 5}), &input));
|
||||
|
||||
// check gain_in rang [0.0,1.0]
|
||||
auto input_01 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||
std::shared_ptr<TensorTransform> phaser_op1 = std::make_shared<audio::Phaser>(44100, 2.0);
|
||||
mindspore::dataset::Execute Transform01({phaser_op1});
|
||||
Status s01 = Transform01(input_01, &input_01);
|
||||
EXPECT_FALSE(s01.IsOk());
|
||||
|
||||
// check gain_out range [0.0,1e9]
|
||||
auto input_02 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||
std::shared_ptr<TensorTransform> phaser_op2 = std::make_shared<audio::Phaser>(44100, 0.2, -0.1);
|
||||
mindspore::dataset::Execute Transform02({phaser_op2});
|
||||
Status s02 = Transform02(input_02, &input_02);
|
||||
EXPECT_FALSE(s02.IsOk());
|
||||
|
||||
// check delay_ms range [0.0,5.0]
|
||||
auto input_03 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||
std::shared_ptr<TensorTransform> phaser_op3 = std::make_shared<audio::Phaser>(44100, 0.2, 0.2, 6.0);
|
||||
mindspore::dataset::Execute Transform03({phaser_op3});
|
||||
Status s03 = Transform03(input_03, &input_03);
|
||||
EXPECT_FALSE(s03.IsOk());
|
||||
|
||||
// check decay range [0.0,0.99]
|
||||
auto input_04 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||
std::shared_ptr<TensorTransform> phaser_op4 = std::make_shared<audio::Phaser>(44100, 0.2, 0.2, 4.0, 1.0);
|
||||
mindspore::dataset::Execute Transform04({phaser_op4});
|
||||
Status s04 = Transform04(input_04, &input_04);
|
||||
EXPECT_FALSE(s04.IsOk());
|
||||
|
||||
// check mod_speed range [0.1, 2]
|
||||
auto input_05 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||
std::shared_ptr<TensorTransform> phaser_op5 = std::make_shared<audio::Phaser>(44100, 0.2, 0.2, 4.0, 0.8, 3.0);
|
||||
mindspore::dataset::Execute Transform05({phaser_op5});
|
||||
Status s05 = Transform05(input_05, &input_05);
|
||||
EXPECT_FALSE(s05.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestExecute, TestDCShiftEager) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestDCShiftEager.";
|
||||
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.audio.transforms as audio
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
def count_unequal_element(data_expected, data_me, rtol, atol):
    """
    Assert that two arrays agree element-wise within a combined absolute/relative tolerance.

    An element counts as a mismatch when ``|expected - actual| > atol + |expected| * rtol``,
    or when exactly one of the two values is NaN. The check passes while the fraction of
    mismatching elements stays strictly below ``rtol``.

    Args:
        data_expected (numpy.ndarray): Reference values.
        data_me (numpy.ndarray): Values under test; must have the same shape as ``data_expected``.
        rtol (float): Relative tolerance; also the exclusive upper bound on the mismatch fraction.
        atol (float): Absolute tolerance.

    Raises:
        AssertionError: If the shapes differ or too many elements mismatch.
    """
    assert data_expected.shape == data_me.shape
    total_count = len(data_expected.flatten())
    if total_count == 0:
        # Nothing to compare; avoids a ZeroDivisionError in the ratio below.
        return
    error = np.abs(data_expected - data_me)
    greater = np.greater(error, atol + np.abs(data_expected) * rtol)
    # A NaN on only one side produces a NaN error, which np.greater reports as False;
    # flag it explicitly so a NaN result cannot pass as a match.
    one_sided_nan = np.logical_xor(np.isnan(data_expected), np.isnan(data_me))
    mismatch = np.logical_or(greater, one_sided_nan)
    loss_count = np.count_nonzero(mismatch)
    assert (loss_count / total_count) < rtol, "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".format(
        data_expected[mismatch], data_me[mismatch], error[mismatch])
|
||||
|
||||
|
||||
def test_phaser_eager():
    """
    Feature: Phaser
    Description: Apply Phaser, configured with only a sample rate, to a 2x3 waveform in eager mode
    Expectation: The output matches the reference waveform within tolerance
    """
    sample_rate = 44100
    # Input samples and the reference output they are compared against.
    source = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32)
    reference = np.array([[0.296, 0.71040004, 1.],
                          [1., 1., 1.]], dtype=np.float32)
    # Calling the op object directly runs it eagerly on the NumPy array.
    phaser = audio.Phaser(sample_rate=sample_rate)
    result = phaser(source)
    count_unequal_element(reference, result, 0.0001, 0.0001)
|
||||
|
||||
|
||||
def test_phaser_pipeline():
    """
    Feature: Phaser
    Description: Test Phaser in pipeline mode
    Expectation: Every row produced by the pipeline matches the reference waveform, and the
        pipeline produces one row per reference row
    """
    # Original waveform
    waveform = np.array([[0.1, 1.2, 5.3], [0.4, 5.5, 1.6]], dtype=np.float32)
    # Expected waveform, one row per dataset item
    expect_waveform = np.array([[0.0296, 0.36704, 1.],
                                [0.11840001, 1., 1.]], dtype=np.float32)
    sample_rate = 44100
    dataset = ds.NumpySlicesDataset(waveform, ["waveform"], shuffle=False)
    phaser_op = audio.Phaser(sample_rate)
    # Filter the waveform column with Phaser
    dataset = dataset.map(input_columns=["waveform"], operations=phaser_op)
    num_rows = 0
    for i, item in enumerate(dataset.create_dict_iterator(num_epochs=1, output_numpy=True)):
        count_unequal_element(expect_waveform[i, :], item['waveform'], 0.0001, 0.0001)
        num_rows += 1
    # Without this check the test would pass vacuously if the pipeline yielded no rows.
    assert num_rows == expect_waveform.shape[0]
|
||||
|
||||
|
||||
def test_phaser_invalid_input():
    """
    Feature: Phaser
    Description: Construct Phaser with one invalid argument (wrong type or out-of-range value) per case
    Expectation: Each construction raises the expected exception type with the expected message
    """
    def check_rejected(case_name, sample_rate, gain_in, gain_out, delay_ms, decay, mod_speed, sinusoidal,
                       expected_error, expected_msg):
        # Build the op and verify both the exception type and a substring of its message.
        logger.info("Test Phaser with bad input: {0}".format(case_name))
        with pytest.raises(expected_error) as raised:
            audio.Phaser(sample_rate, gain_in, gain_out, delay_ms, decay, mod_speed, sinusoidal)
        assert expected_msg in str(raised.value)

    # Each entry: (case name, sample_rate, gain_in, gain_out, delay_ms, decay, mod_speed, sinusoidal,
    #              expected exception type, expected message fragment)
    bad_cases = [
        # Wrong argument types
        ("invalid sample_rate parameter type as a float", 44100.5, 0.4, 0.74, 3.0, 0.4, 0.5, True, TypeError,
         "Argument sample_rate with value 44100.5 is not of type [<class 'int'>],"
         " but got <class 'float'>."),
        ("invalid gain_in parameter type as a str", 44100, "1", 0.74, 3.0, 0.4, 0.5, True, TypeError,
         "Argument gain_in with value 1 is not of type [<class 'float'>, <class 'int'>],"
         " but got <class 'str'>."),
        ("invalid gain_out parameter type as a str", 44100, 0.4, "10", 3.0, 0.4, 0.5, True, TypeError,
         "Argument gain_out with value 10 is not of type [<class 'float'>, <class 'int'>],"
         " but got <class 'str'>."),
        ("invalid delay_ms parameter type as a str", 44100, 0.4, 0.74, "2", 0.4, 0.5, True, TypeError,
         "Argument delay_ms with value 2 is not of type [<class 'float'>, <class 'int'>],"
         " but got <class 'str'>."),
        ("invalid decay parameter type as a str", 44100, 0.4, 0.74, 3.0, "0", 0.5, True, TypeError,
         "Argument decay with value 0 is not of type [<class 'float'>, <class 'int'>],"
         " but got <class 'str'>."),
        ("invalid mod_speed parameter type as a str", 44100, 0.4, 0.74, 3.0, 0.4, "3", True, TypeError,
         "Argument mod_speed with value 3 is not of type [<class 'float'>, <class 'int'>],"
         " but got <class 'str'>."),
        ("invalid sinusoidal parameter type as a str", 44100, 0.4, 0.74, 3.0, 0.4, 0.5, "True", TypeError,
         "Argument sinusoidal with value True is not of type [<class 'bool'>],"
         " but got <class 'str'>."),
        # Values outside the accepted intervals
        ("invalid sample_rate parameter value", 441324343243242342345300, 0.5, 0.74, 3.0, 0.4, 0.5, True,
         ValueError,
         "Input sample_rate is not within the required interval of "
         "[-2147483648, 2147483647]."),
        ("invalid gain_in out of range [0, 1]", 44100, 2.0, 0.74, 3.0, 0.4, 0.5, True, ValueError,
         "Input gain_in is not within the required interval of [0, 1]."),
        ("invalid gain_out out of range [0, 1e9]", 44100, 0.4, -2.0, 3.0, 0.4, 0.5, True, ValueError,
         "Input gain_out is not within the required interval of [0, 1000000000.0]."),
        ("invalid delay_ms out of range [0, 5.0]", 44100, 0.4, 0.74, 6.0, 0.4, 0.5, True, ValueError,
         "Input delay_ms is not within the required interval of [0, 5.0]."),
        ("invalid decay out of range [0, 0.99]", 44100, 0.4, 0.74, 3.0, 1.2, 0.5, True, ValueError,
         "Input decay is not within the required interval of [0, 0.99]."),
        ("invalid mod_speed out of range [0.1, 2]", 44100, 0.4, 0.74, 3.0, 0.4, 0.003, True, ValueError,
         "Input mod_speed is not within the required interval of [0.1, 2]."),
    ]
    for case in bad_cases:
        check_rejected(*case)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # When executed as a script, run each Phaser test once, in declaration order.
    for run_case in (test_phaser_eager, test_phaser_pipeline, test_phaser_invalid_input):
        run_case()
|
Loading…
Reference in New Issue