forked from mindspore-Ecosystem/mindspore
[feat][assistant][I3CEGC] add op timestretch
This commit is contained in:
parent 3693625d6f
commit 04705e5b0d
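
The new TimeStretch op stretches a complex-valued STFT (shape <..., freq, num_frame, complex=2>) in time by a fixed rate without changing the pitch. A minimal Python usage sketch, pieced together from the tests added in this commit (the module path mindspore.dataset.audio.transforms and the GeneratorDataset pattern are taken from those tests; it is an illustration, not part of the diff):

import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.audio.transforms as c_audio

def gen():
    # complex spectrogram of shape <channel, freq, num_frame, complex=2>
    np.random.seed(0)
    yield (np.random.random([2, 1025, 300, 2]).astype(np.float32),)

# eager: call the op directly on a NumPy spectrogram
spec = next(gen())[0]
stretched = c_audio.TimeStretch(512, 1025, 1.3)(spec)  # output has ceil(300 / 1.3) frames

# pipeline: apply the op inside a dataset map
data = ds.GeneratorDataset(source=gen, column_names=["spec"])
data = data.map(operations=[c_audio.TimeStretch(512, 1025, 1.3)], input_columns=["spec"])
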
@@ -23,6 +23,7 @@
#include "minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/bandreject_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/bass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"

namespace mindspore {
namespace dataset {
@@ -132,6 +133,23 @@ BassBiquad::BassBiquad(int32_t sample_rate, float gain, float central_freq, floa
std::shared_ptr<TensorOperation> BassBiquad::Parse() {
  return std::make_shared<BassBiquadOperation>(data_->sample_rate_, data_->gain_, data_->central_freq_, data_->Q_);
}

// TimeStretch Operation.
struct TimeStretch::Data {
  explicit Data(float hop_length, int n_freq, float fixed_rate)
      : hop_length_(hop_length), n_freq_(n_freq), fixed_rate_(fixed_rate) {}
  float hop_length_;
  int n_freq_;
  float fixed_rate_;
};

TimeStretch::TimeStretch(float hop_length, int n_freq, float fixed_rate)
    : data_(std::make_shared<Data>(hop_length, n_freq, fixed_rate)) {}

std::shared_ptr<TensorOperation> TimeStretch::Parse() {
  return std::make_shared<TimeStretchOperation>(data_->hop_length_, data_->n_freq_, data_->fixed_rate_);
}

} // namespace audio
} // namespace dataset
} // namespace mindspore
@@ -24,6 +24,7 @@
#include "minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/bandreject_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/bass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
#include "minddata/dataset/include/dataset/transforms.h"

namespace mindspore {
@@ -113,5 +114,15 @@ PYBIND_REGISTER(
                  }));
                }));

PYBIND_REGISTER(
  TimeStretchOperation, 1, ([](const py::module *m) {
    (void)py::class_<audio::TimeStretchOperation, TensorOperation, std::shared_ptr<audio::TimeStretchOperation>>(
      *m, "TimeStretchOperation")
      .def(py::init([](float hop_length, int n_freq, float fixed_rate) {
        auto timestretch = std::make_shared<audio::TimeStretchOperation>(hop_length, n_freq, fixed_rate);
        THROW_IF_ERROR(timestretch->ValidateParams());
        return timestretch;
      }));
  }));
} // namespace dataset
} // namespace mindspore
@@ -9,4 +9,5 @@ add_library(audio-ir-kernels OBJECT
    bandpass_biquad_ir.cc
    bandreject_biquad_ir.cc
    bass_biquad_ir.cc
    time_stretch_ir.cc
    )
@@ -0,0 +1,56 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
#include "minddata/dataset/audio/kernels/time_stretch_op.h"

#include "minddata/dataset/audio/ir/validators.h"
namespace mindspore {
namespace dataset {
namespace audio {

// TimeStretch
TimeStretchOperation::TimeStretchOperation(float hop_length, int n_freq, float fixed_rate)
    : hop_length_(hop_length), n_freq_(n_freq), fixed_rate_(fixed_rate) {}

TimeStretchOperation::~TimeStretchOperation() = default;

std::string TimeStretchOperation::Name() const { return kTimeStretchOperation; }

Status TimeStretchOperation::ValidateParams() {
  // param check
  RETURN_IF_NOT_OK(CheckFloatScalarPositive("TimeStretch", "hop_length", hop_length_));
  RETURN_IF_NOT_OK(CheckIntScalarPositive("TimeStretch", "n_freq", n_freq_));
  RETURN_IF_NOT_OK(CheckFloatScalarNotNan("TimeStretch", "fixed_rate", fixed_rate_));
  RETURN_IF_NOT_OK(CheckFloatScalarPositive("TimeStretch", "fixed_rate", fixed_rate_));
  return Status::OK();
}

std::shared_ptr<TensorOp> TimeStretchOperation::Build() {
  std::shared_ptr<TimeStretchOp> tensor_op = std::make_shared<TimeStretchOp>(hop_length_, n_freq_, fixed_rate_);
  return tensor_op;
}

Status TimeStretchOperation::to_json(nlohmann::json *out_json) {
  nlohmann::json args;
  args["hop_length"] = hop_length_;
  args["n_freq"] = n_freq_;
  args["fixed_rate"] = fixed_rate_;
  *out_json = args;
  return Status::OK();
}
} // namespace audio
} // namespace dataset
} // namespace mindspore
@@ -0,0 +1,55 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_TIME_STRETCH_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_TIME_STRETCH_IR_H_

#include <memory>
#include <string>

#include "include/api/status.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"

namespace mindspore {
namespace dataset {
namespace audio {

constexpr char kTimeStretchOperation[] = "TimeStretch";

class TimeStretchOperation : public TensorOperation {
 public:
  TimeStretchOperation(float hop_length, int n_freq, float fixed_rate);

  ~TimeStretchOperation();

  std::shared_ptr<TensorOp> Build() override;

  Status ValidateParams() override;

  std::string Name() const override;

  Status to_json(nlohmann::json *out_json) override;

 private:
  float hop_length_;
  int n_freq_;
  float fixed_rate_;
};

} // namespace audio
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_TIME_STRETCH_IR_H_
@@ -23,11 +23,25 @@ Status CheckFloatScalarPositive(const std::string &op_name, const std::string &s
  return Status::OK();
}

Status CheckFloatScalarNotNan(const std::string &op_name, const std::string &scalar_name, float scalar) {
  if (std::isnan(scalar)) {
    std::string err_msg = op_name + ":" + scalar_name + " should be specified, got: Nan.";
    MS_LOG(ERROR) << err_msg;
    return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
  }
  return Status::OK();
}

Status CheckFloatScalarNonNegative(const std::string &op_name, const std::string &scalar_name, float scalar) {
  RETURN_IF_NOT_OK(CheckScalar(op_name, scalar_name, scalar, {0}, false));
  return Status::OK();
}

Status CheckIntScalarPositive(const std::string &op_name, const std::string &scalar_name, int32_t scalar) {
  RETURN_IF_NOT_OK(CheckScalar(op_name, scalar_name, scalar, {0}, true));
  return Status::OK();
}

Status CheckStringScalarInList(const std::string &op_name, const std::string &scalar_name, const std::string &scalar,
                               const std::vector<std::string> &str_vec) {
  auto ret = std::find(str_vec.begin(), str_vec.end(), scalar);
@@ -78,5 +92,7 @@ Status CheckScalar(const std::string &op_name, const std::string &scalar_name, c
template Status CheckScalar(const std::string &op_name, const std::string &scalar_name, const float scalar,
                            const std::vector<float> &range, bool left_open_interval, bool right_open_interval);

template Status CheckScalar(const std::string &op_name, const std::string &scalar_name, const int32_t scalar,
                            const std::vector<int32_t> &range, bool left_open_interval, bool right_open_interval);
} // namespace dataset
} // namespace mindspore
@@ -28,6 +28,15 @@
namespace mindspore {
namespace dataset {

// Helper function to check a non-NaN float scalar
Status CheckFloatScalarNotNan(const std::string &op_name, const std::string &scalar_name, float scalar);

// Helper function to check a positive float scalar
Status CheckFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar);

// Helper function to check a positive int scalar
Status CheckIntScalarPositive(const std::string &op_name, const std::string &scalar_name, int32_t scalar);

template <typename T>
// Helper function to check scalar is not equal to zero
Status CheckScalarNotZero(const std::string &op_name, const std::string &scalar_name, const T scalar) {
@@ -10,4 +10,5 @@ add_library(audio-kernels OBJECT
    bandpass_biquad_op.cc
    bandreject_biquad_op.cc
    bass_biquad_op.cc
    time_stretch_op.cc
    )
@@ -61,5 +61,325 @@ template Status AmplitudeToDB<float>(const std::shared_ptr<Tensor> &input, std::
                                     float multiplier, float amin, float db_multiplier, float top_db);
template Status AmplitudeToDB<double>(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
                                      double multiplier, double amin, double db_multiplier, double top_db);

/// \brief Generate linearly spaced vector
/// \param[in] start - Value of the startpoint.
/// \param[in] end - Value of the endpoint.
/// \param[in] n - N points in the output tensor.
/// \param[out] output - Tensor has n points with linearly space. The spacing between the points is (end-start)/(n-1).
/// \return Status return code
template <typename T>
Status Linespace(std::shared_ptr<Tensor> *output, T start, T end, int n) {
  if (start > end) {
    std::string err = "Linespace: input param end must be greater than start.";
    RETURN_STATUS_UNEXPECTED(err);
  }
  n = std::isnan(n) ? 100 : n;
  TensorShape out_shape({n});
  std::vector<T> linear_vect(n);
  T interval = (end - start) / (n - 1);
  for (int i = 0; i < linear_vect.size(); ++i) {
    linear_vect[i] = start + i * interval;
  }
  std::shared_ptr<Tensor> out_t;
  RETURN_IF_NOT_OK(Tensor::CreateFromVector(linear_vect, out_shape, &out_t));
  linear_vect.clear();
  linear_vect.shrink_to_fit();
  *output = out_t;
  return Status::OK();
}

/// \brief Calculate complex tensor angle
/// \param[in] input - Input tensor, must be complex, <channel, freq, time, complex=2>.
/// \param[out] output - Complex tensor angle.
/// \return Status return code
template <typename T>
Status ComplexAngle(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  // check complex
  if (!input->IsComplex()) {
    std::string err_msg = "ComplexAngle: input tensor is not in shape of <..., 2>.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  TensorShape input_shape = input->shape();
  TensorShape out_shape({input_shape[0], input_shape[1], input_shape[2]});
  std::vector<T> phase(input_shape[0] * input_shape[1] * input_shape[2]);
  int ind = 0;

  for (auto itr = input->begin<T>(); itr != input->end<T>(); itr++, ind++) {
    auto x = (*itr);
    itr++;
    auto y = (*itr);
    phase[ind] = atan2(y, x);
  }

  std::shared_ptr<Tensor> out_t;
  RETURN_IF_NOT_OK(Tensor::CreateFromVector(phase, out_shape, &out_t));
  phase.clear();
  phase.shrink_to_fit();
  *output = out_t;
  return Status::OK();
}

/// \brief Calculate complex tensor abs
/// \param[in] input - Input tensor, must be complex, <channel, freq, time, complex=2>.
/// \param[out] output - Complex tensor abs.
/// \return Status return code
template <typename T>
Status ComplexAbs(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  // check complex
  if (!input->IsComplex()) {
    std::string err_msg = "ComplexAngle: input tensor is not in shape of <..., 2>.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }
  TensorShape input_shape = input->shape();
  TensorShape out_shape({input_shape[0], input_shape[1], input_shape[2]});
  std::vector<T> abs(input_shape[0] * input_shape[1] * input_shape[2]);
  int ind = 0;
  for (auto itr = input->begin<T>(); itr != input->end<T>(); itr++, ind++) {
    T x = (*itr);
    itr++;
    T y = (*itr);
    abs[ind] = sqrt(pow(y, 2) + pow(x, 2));
  }

  std::shared_ptr<Tensor> out_t;
  RETURN_IF_NOT_OK(Tensor::CreateFromVector(abs, out_shape, &out_t));
  *output = out_t;
  return Status::OK();
}

/// \brief Reconstruct complex tensor from norm and angle
/// \param[in] abs - The absolute value of the complex tensor.
/// \param[in] angle - The angle of the complex tensor.
/// \param[out] output - Complex tensor, <channel, freq, time, complex=2>.
/// \return Status return code
template <typename T>
Status Polar(const std::shared_ptr<Tensor> &abs, const std::shared_ptr<Tensor> &angle,
             std::shared_ptr<Tensor> *output) {
  // check shape
  if (abs->shape() != angle->shape()) {
    std::string err_msg = "Polar: input shape of abs and angle must be same.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  TensorShape input_shape = abs->shape();
  TensorShape out_shape({input_shape[0], input_shape[1], input_shape[2], 2});
  std::vector<T> complex_vec(input_shape[0] * input_shape[1] * input_shape[2] * 2);
  int ind = 0;
  auto itr_abs = abs->begin<T>();
  auto itr_angle = angle->begin<T>();

  for (; itr_abs != abs->end<T>(); itr_abs++, itr_angle++) {
    complex_vec[ind++] = cos(*itr_angle) * (*itr_abs);
    complex_vec[ind++] = sin(*itr_angle) * (*itr_abs);
  }

  std::shared_ptr<Tensor> out_t;
  RETURN_IF_NOT_OK(Tensor::CreateFromVector(complex_vec, out_shape, &out_t));
  *output = out_t;
  return Status::OK();
}

/// \brief Pad complex tensor
/// \param[in] input - The complex tensor.
/// \param[in] length - The length of padding.
/// \param[in] dim - The dim index for padding.
/// \param[out] output - Complex tensor, <channel, freq, time, complex=2>.
/// \return Status return code
template <typename T>
Status PadComplexTensor(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int length, int dim) {
  TensorShape input_shape = input->shape();
  std::vector<int64_t> pad_shape_vec = {input_shape[0], input_shape[1], input_shape[2], input_shape[3]};
  pad_shape_vec[dim] += length;
  TensorShape input_shape_with_pad(pad_shape_vec);
  std::vector<T> in_vect(input_shape_with_pad[0] * input_shape_with_pad[1] * input_shape_with_pad[2] *
                         input_shape_with_pad[3]);
  auto itr_input = input->begin<T>();
  int input_cnt = 0;
  for (int ind = 0; ind < in_vect.size(); ind++) {
    in_vect[ind] = (*itr_input);
    input_cnt = (input_cnt + 1) % (input_shape[2] * input_shape[3]);
    itr_input++;
    // complex tensor last dim equals 2, fill zero count equals 2*width
    if (input_cnt == 0 && ind != 0) {
      for (int c = 0; c < length * 2; c++) {
        in_vect[++ind] = 0.0f;
      }
    }
  }
  std::shared_ptr<Tensor> out_t;
  RETURN_IF_NOT_OK(Tensor::CreateFromVector(in_vect, input_shape_with_pad, &out_t));
  *output = out_t;
  return Status::OK();
}

/// \brief Calculate phase
/// \param[in] angle_0 - The angle.
/// \param[in] angle_1 - The angle.
/// \param[in] phase_advance - The phase advance.
/// \param[in] phase_time0 - The phase at time 0.
/// \param[out] output - Phase tensor.
/// \return Status return code
template <typename T>
Status Phase(const std::shared_ptr<Tensor> &angle_0, const std::shared_ptr<Tensor> &angle_1,
             const std::shared_ptr<Tensor> &phase_advance, const std::shared_ptr<Tensor> &phase_time0,
             std::shared_ptr<Tensor> *output) {
  TensorShape phase_shape = angle_0->shape();
  std::vector<T> phase(phase_shape[0] * phase_shape[1] * phase_shape[2]);
  auto itr_angle_0 = angle_0->begin<T>();
  auto itr_angle_1 = angle_1->begin<T>();
  auto itr_pa = phase_advance->begin<T>();
  for (int ind = 0, input_cnt = 0; itr_angle_0 != angle_0->end<T>(); itr_angle_0++, itr_angle_1++, ind++) {
    if (ind != 0 && ind % phase_shape[2] == 0) {
      itr_pa++;
      if (itr_pa == phase_advance->end<T>()) {
        itr_pa = phase_advance->begin<T>();
      }
      input_cnt++;
    }
    phase[ind] = (*itr_angle_1) - (*itr_angle_0) - (*itr_pa);
    phase[ind] = phase[ind] - 2 * PI * round(phase[ind] / (2 * PI)) + (*itr_pa);
  }

  // concat phase time 0
  int ind = 0;
  auto itr_p0 = phase_time0->begin<T>();
  phase.insert(phase.begin(), (*itr_p0));
  while (itr_p0 != phase_time0->end<T>()) {
    itr_p0++;
    ind += phase_shape[2];
    phase[ind] = (*itr_p0);
  }
  phase.erase(phase.begin() + static_cast<int>(angle_0->Size()), phase.end());

  // cal phase accum
  for (ind = 0; ind < phase.size(); ind++) {
    if (ind % phase_shape[2] != 0) {
      phase[ind] = phase[ind] + phase[ind - 1];
    }
  }
  std::shared_ptr<Tensor> phase_tensor;
  RETURN_IF_NOT_OK(Tensor::CreateFromVector(phase, phase_shape, &phase_tensor));
  *output = phase_tensor;
  return Status::OK();
}

/// \brief Calculate magnitude
/// \param[in] alphas - The alphas.
/// \param[in] abs_0 - The norm.
/// \param[in] abs_1 - The norm.
/// \param[out] output - Magnitude tensor.
/// \return Status return code
template <typename T>
Status Mag(const std::shared_ptr<Tensor> &abs_0, const std::shared_ptr<Tensor> &abs_1, std::shared_ptr<Tensor> *output,
           const std::vector<T> &alphas) {
  TensorShape mag_shape = abs_0->shape();
  std::vector<T> mag(mag_shape[0] * mag_shape[1] * mag_shape[2]);
  auto itr_abs_0 = abs_0->begin<T>();
  auto itr_abs_1 = abs_1->begin<T>();
  for (int ind = 0; itr_abs_0 != abs_0->end<T>(); itr_abs_0++, itr_abs_1++, ind++) {
    mag[ind] = alphas[ind % mag_shape[2]] * (*itr_abs_1) + (1 - alphas[ind % mag_shape[2]]) * (*itr_abs_0);
  }
  std::shared_ptr<Tensor> mag_tensor;
  RETURN_IF_NOT_OK(Tensor::CreateFromVector(mag, mag_shape, &mag_tensor));
  *output = mag_tensor;
  return Status::OK();
}

template <typename T>
Status TimeStretch(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, float rate,
                   std::shared_ptr<Tensor> phase_advance) {
  // pack <..., freq, time, complex>
  TensorShape input_shape = input->shape();
  TensorShape toShape({input->Size() / (input_shape[-1] * input_shape[-2] * input_shape[-3]), input_shape[-3],
                       input_shape[-2], input_shape[-1]});
  RETURN_IF_NOT_OK(input->Reshape(toShape));
  if (rate == 1.0) {
    *output = input;
    return Status::OK();
  }
  // calculate time step and alphas
  int ind = 0;
  std::vector<dsize_t> time_steps_0, time_steps_1;
  std::vector<T> alphas;
  for (T val = 0;; ind++) {
    val = ind * rate;
    if (val >= input_shape[-2]) break;
    int val_int = static_cast<int>(val);
    time_steps_0.push_back(val_int);
    time_steps_1.push_back(val_int + 1);
    alphas.push_back(fmod(val, 1));
  }

  // calculate phase on time 0
  std::shared_ptr<Tensor> spec_time0, phase_time0;
  RETURN_IF_NOT_OK(
    input->Slice(&spec_time0, std::vector<SliceOption>({SliceOption(true), SliceOption(true),
                                                        SliceOption(std::vector<dsize_t>{0}), SliceOption(true)})));
  RETURN_IF_NOT_OK(ComplexAngle<T>(spec_time0, &phase_time0));

  // time pad: add zero to time dim
  RETURN_IF_NOT_OK(PadComplexTensor<T>(input, &input, 2, 2));

  // slice
  std::shared_ptr<Tensor> spec_0;
  RETURN_IF_NOT_OK(input->Slice(&spec_0, std::vector<SliceOption>({SliceOption(true), SliceOption(true),
                                                                   SliceOption(time_steps_0), SliceOption(true)})));
  std::shared_ptr<Tensor> spec_1;
  RETURN_IF_NOT_OK(input->Slice(&spec_1, std::vector<SliceOption>({SliceOption(true), SliceOption(true),
                                                                   SliceOption(time_steps_1), SliceOption(true)})));

  // new slices angle and abs <channel, freq, time>
  std::shared_ptr<Tensor> angle_0, angle_1, abs_0, abs_1;
  RETURN_IF_NOT_OK(ComplexAngle<T>(spec_0, &angle_0));
  RETURN_IF_NOT_OK(ComplexAbs<T>(spec_0, &abs_0));
  RETURN_IF_NOT_OK(ComplexAngle<T>(spec_1, &angle_1));
  RETURN_IF_NOT_OK(ComplexAbs<T>(spec_1, &abs_1));

  // cal phase, there exists precision loss between mindspore and pytorch
  std::shared_ptr<Tensor> phase_tensor;
  RETURN_IF_NOT_OK(Phase<T>(angle_0, angle_1, phase_advance, phase_time0, &phase_tensor));

  // calculate magnitude
  std::shared_ptr<Tensor> mag_tensor;
  RETURN_IF_NOT_OK(Mag<T>(abs_0, abs_1, &mag_tensor, alphas));

  // reconstruct complex from norm and angle
  std::shared_ptr<Tensor> complex_spec_stretch;
  RETURN_IF_NOT_OK(Polar<T>(mag_tensor, phase_tensor, &complex_spec_stretch));

  // unpack
  auto output_shape_vec = input_shape.AsVector();
  output_shape_vec.pop_back();
  output_shape_vec.pop_back();
  output_shape_vec.push_back(complex_spec_stretch->shape()[-2]);
  output_shape_vec.push_back(input_shape[-1]);
  RETURN_IF_NOT_OK(complex_spec_stretch->Reshape(TensorShape(output_shape_vec)));
  *output = complex_spec_stretch;
  return Status::OK();
}

Status TimeStretch(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, float rate, float hop_length,
                   float n_freq) {
  std::shared_ptr<Tensor> phase_advance;
  switch (input->type().value()) {
    case DataType::DE_FLOAT32:
      RETURN_IF_NOT_OK(Linespace<float>(&phase_advance, 0, PI * hop_length, n_freq));
      RETURN_IF_NOT_OK(TimeStretch<float>(input, output, rate, phase_advance));
      break;
    case DataType::DE_FLOAT64:
      RETURN_IF_NOT_OK(Linespace<double>(&phase_advance, 0, PI * hop_length, n_freq));
      RETURN_IF_NOT_OK(TimeStretch<double>(input, output, rate, phase_advance));
      break;
    default:
      RETURN_STATUS_UNEXPECTED(
        "TimeStretch: unsupported type, currently supported types include "
        "[float, double].");
  }
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore
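
The Phase and Mag helpers above implement the usual phase-vocoder recurrence; written out as a summary of that code (not text from the commit), with S_{k,t} the input STFT, Omega_k the expected phase advance per hop in bin k, and alpha_t the fractional part of t * rate:

\[
\phi_{k,t} = \phi_{k,t-1} + \Omega_k + \operatorname{princarg}\big(\angle S_{k,t+1} - \angle S_{k,t} - \Omega_k\big), \qquad \phi_{k,0} = \angle S_{k,0},
\]
\[
\lvert X_{k,t}\rvert = \alpha_t \lvert S_{k,t+1}\rvert + (1-\alpha_t)\lvert S_{k,t}\rvert, \qquad X_{k,t} = \lvert X_{k,t}\rvert\, e^{i\phi_{k,t}},
\]

where princarg wraps its argument to (-pi, pi]; Polar then rebuilds the complex spectrogram from |X| and phi.
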
@@ -176,6 +176,16 @@ Status LFilter(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *ou
  delete m_py;
  return Status::OK();
}

/// \brief Stretch STFT in time at a given rate, without changing the pitch.
/// \param[in] input - Tensor of shape <...,freq,time>.
/// \param[in] rate - Stretch factor.
/// \param[in] phase_advance - Expected phase advance in each bin.
/// \param[out] output - Tensor after stretch in time domain.
/// \return Status return code
Status TimeStretch(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, float rate, float hop_length,
                   float n_freq);

} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_
@@ -0,0 +1,80 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "minddata/dataset/audio/kernels/time_stretch_op.h"

#include <limits>

#include "minddata/dataset/audio/kernels/audio_utils.h"
#include "minddata/dataset/kernels/data/data_utils.h"
#include "minddata/dataset/util/status.h"

namespace mindspore {
namespace dataset {

const float TimeStretchOp::kHopLength = std::numeric_limits<float>::quiet_NaN();
const int TimeStretchOp::kNFreq = 201;
const float TimeStretchOp::kFixedRate = std::numeric_limits<float>::quiet_NaN();

Status TimeStretchOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  // check and init
  IO_CHECK(input, output);

  // check shape
  if (input->shape().Rank() < 3) {
    std::string err_msg = "TimeStretch: input tensor shape is not <..., freq, num_frame, complex=2>.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  // check complex
  if (!input->IsComplex()) {
    std::string err_msg = "TimeStretch: input tensor is not in shape of <..., 2>.";
    MS_LOG(ERROR) << err_msg;
    RETURN_STATUS_SYNTAX_ERROR(err_msg);
  }

  std::shared_ptr<Tensor> input_tensor;
  // std::shared_ptr<Tensor> phase_advance;
  float hop_length = std::isnan(hop_length_) ? (n_freq_ - 1) : hop_length_;
  // typecast
  CHECK_FAIL_RETURN_UNEXPECTED(input->type() != DataType::DE_STRING,
                               "TimeStretch: input tensor type should be [int, float, double], but got string.");
  if (input->type() != DataType::DE_FLOAT64) {
    RETURN_IF_NOT_OK(TypeCast(input, &input_tensor, DataType(DataType::DE_FLOAT32)));
  } else {
    input_tensor = input;
  }

  return TimeStretch(input_tensor, output, fixed_rate_, hop_length, n_freq_);
}

Status TimeStretchOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
  RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
  outputs.clear();
  for (auto s : inputs) {
    std::vector<dsize_t> s_vec = s.AsVector();
    s_vec.pop_back();
    s_vec.pop_back();
    s_vec.push_back(std::ceil(s[-2] / fixed_rate_));
    // push back complex
    s_vec.push_back(2);
    outputs.emplace_back(TensorShape(s_vec));
  }
  CHECK_FAIL_RETURN_UNEXPECTED(!outputs.empty(), "TimeStretch: invalid input shape.");
  return Status::OK();
}
} // namespace dataset
} // namespace mindspore
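
OutputShape above replaces the frame dimension with ceil(num_frame / fixed_rate), which is also what the Python test added later in this commit asserts; a quick illustrative check of that arithmetic (not part of the diff):

import math

num_frame, fixed_rate = 300, 1.3
out_frames = math.ceil(num_frame / fixed_rate)  # mirrors s_vec.push_back(std::ceil(s[-2] / fixed_rate_))
print(out_frames)  # 231, so a <2, 1025, 300, 2> input maps to <2, 1025, 231, 2>
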
@@ -0,0 +1,58 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_TIME_STRETCH_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_TIME_STRETCH_OP_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"

namespace mindspore {
namespace dataset {
class TimeStretchOp : public TensorOp {
 public:
  /// Default value
  static const float kHopLength;
  static const int kNFreq;
  static const float kFixedRate;

  explicit TimeStretchOp(float hop_length = kHopLength, int n_freq = kNFreq, float fixed_rate = kFixedRate)
      : hop_length_(hop_length), n_freq_(n_freq), fixed_rate_(fixed_rate) {}

  ~TimeStretchOp() override = default;

  Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;

  std::string Name() const override { return kTimeStretchOp; }

  /// \param[in] inputs
  /// \param[out] outputs
  /// \return Status code
  Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;

 private:
  float hop_length_;
  int n_freq_;
  float fixed_rate_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_TIME_STRETCH_OP_H_
@@ -306,6 +306,13 @@ class Tensor {
  /// \return bool - true if tensor is not empty
  bool HasData() const { return data_ != nullptr; }

  /// Check if tensor is complex
  /// \return bool - true if tensor is complex
  bool IsComplex() const {
    // check the last dim all be 2
    return shape_[-1] == 2;
  }

  /// Reshape the tensor. The given shape should have the same number of elements in the Tensor
  /// \param shape
  virtual Status Reshape(const TensorShape &shape);
@@ -186,6 +186,30 @@ class BassBiquad final : public TensorTransform {
  struct Data;
  std::shared_ptr<Data> data_;
};

/// \brief TimeStretch TensorTransform
/// \notes Stretch STFT in time at a given rate, without changing the pitch.
class TimeStretch final : public TensorTransform {
 public:
  /// \brief Constructor.
  /// \param[in] hop_length Length of hop between STFT windows. Default: None.
  /// \param[in] n_freq Number of filter banks from STFT. Default: 201.
  /// \param[in] fixed_rate Rate to speed up or slow down the input in time. Default: None.
  explicit TimeStretch(float hop_length = std::numeric_limits<float>::quiet_NaN(), int n_freq = 201,
                       float fixed_rate = std::numeric_limits<float>::quiet_NaN());

  /// \brief Destructor.
  ~TimeStretch() = default;

 protected:
  /// \brief Function to convert TensorTransform object into a TensorOperation object.
  /// \return Shared pointer to TensorOperation object.
  std::shared_ptr<TensorOperation> Parse() override;

 private:
  struct Data;
  std::shared_ptr<Data> data_;
};
} // namespace audio
} // namespace dataset
} // namespace mindspore
@@ -145,6 +145,7 @@ constexpr char kBandBiquadOp[] = "BandBiquadOp";
constexpr char kBandpassBiquadOp[] = "BandpassBiquadOp";
constexpr char kBandrejectBiquadOp[] = "BandrejectBiquadOp";
constexpr char kBassBiquadOp[] = "BassBiquadOp";
constexpr char kTimeStretchOp[] = "TimeStretchOp";

// data
constexpr char kConcatenateOp[] = "ConcatenateOp";
@@ -22,7 +22,7 @@ import numpy as np
from ..transforms.c_transforms import TensorOperation
from .utils import ScaleType
from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
    check_bandreject_biquad, check_bass_biquad, check_time_stretch


class AudioTensorOperation(TensorOperation):
@@ -249,3 +249,36 @@ class BassBiquad(AudioTensorOperation):

    def parse(self):
        return cde.BassBiquadOperation(self.sample_rate, self.gain, self.central_freq, self.Q)


class TimeStretch(AudioTensorOperation):
    """
    Stretch STFT in time at a given rate, without changing the pitch.

    Args:
        hop_length (int, optional): Length of hop between STFT windows (default=None).
        n_freq (int, optional): Number of filter banks from STFT (default=201).
        fixed_rate (float, optional): Rate to speed up or slow down the input in time (default=None).

    Examples:
        >>> freq = 44100
        >>> num_frame = 30
        >>> def gen():
        ...     np.random.seed(0)
        ...     data = np.random.random([freq, num_frame])
        ...     yield (np.array(data, dtype=np.float32), )
        >>> data1 = ds.GeneratorDataset(source=gen, column_names=["multi_dimensional_data"])
        >>> transforms = [py_audio.TimeStretch()]
        >>> data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
    """
    @check_time_stretch
    def __init__(self, hop_length=None, n_freq=201, fixed_rate=None):
        self.n_freq = n_freq
        self.fixed_rate = fixed_rate

        n_fft = (n_freq - 1) * 2
        self.hop_length = hop_length if hop_length is not None else n_fft // 2
        self.fixed_rate = fixed_rate if fixed_rate is not None else np.nan

    def parse(self):
        return cde.TimeStretchOperation(self.hop_length, self.n_freq, self.fixed_rate)
@@ -16,8 +16,8 @@
Validators for TensorOps.
"""
from functools import wraps
from mindspore.dataset.core.validator_helpers import check_not_zero, check_int32, check_float32, check_value, \
    check_value_normalize_std, check_value_ratio, FLOAT_MAX_INTEGER, INT64_MAX, parse_user_args, type_check
from .utils import ScaleType

@@ -164,3 +164,25 @@ def check_bass_biquad(method):
        return method(self, *args, **kwargs)

    return new_method


def check_time_stretch(method):
    """Wrapper method to check the parameters of time_stretch."""
    @wraps(method)
    def new_method(self, *args, **kwargs):
        [hop_length, n_freq, fixed_rate], _ = parse_user_args(method, *args, **kwargs)
        # type check
        type_check(hop_length, (int, type(None)), "hop_length")
        type_check(n_freq, (int,), "n_freq")
        type_check(fixed_rate, (int, float, type(None)), "fixed_rate")

        # value check
        if hop_length is not None:
            check_value(hop_length, (1, INT64_MAX), "hop_length")
        check_value(n_freq, (1, INT64_MAX), "n_freq")
        if fixed_rate is not None:
            check_value_ratio(fixed_rate, (0, FLOAT_MAX_INTEGER), "fixed_rate")

        return method(self, *args, **kwargs)

    return new_method
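
Since check_time_stretch is applied as a decorator on TimeStretch.__init__, invalid arguments are rejected at construction time. A small sketch of the expected behaviour (illustrative only; the error message is taken from the pytest case added later in this commit):

import mindspore.dataset.audio.transforms as c_audio

try:
    c_audio.TimeStretch(512, 1025, -1.3)  # fixed_rate must lie in (0, 16777216]
except ValueError as err:
    print(err)  # "Input fixed_rate is not within the required interval of (0, 16777216]."
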
@@ -13,6 +13,7 @@ SET(DE_UT_SRCS
    buddy_test.cc
    build_vocab_test.cc
    c_api_audio_a_to_q_test.cc
    c_api_audio_r_to_z_test.cc
    c_api_cache_test.cc
    c_api_dataset_album_test.cc
    c_api_audio_a_to_q_test.cc
@@ -0,0 +1,96 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "common/common.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/include/dataset/datasets.h"
#include "minddata/dataset/include/dataset/audio.h"

using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;

class MindDataTestPipeline : public UT::Common {
 public:
};

TEST_F(MindDataTestPipeline, TestTimeStretchPipeline) {
  MS_LOG(INFO) << "Doing test TimeStretchOp with custom param value. Pipeline.";
  // op param
  int freq = 1025;
  int hop_length = 512;
  float rate = 1.2;
  // Original waveform
  std::shared_ptr<SchemaObj> schema = Schema();
  ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {2, freq, 400, 2}));
  std::shared_ptr<Dataset> ds = RandomData(50, schema);
  EXPECT_NE(ds, nullptr);

  ds = ds->SetNumWorkers(4);
  EXPECT_NE(ds, nullptr);

  auto TimeStretchOp = audio::TimeStretch(hop_length, freq, rate);

  ds = ds->Map({TimeStretchOp});
  EXPECT_NE(ds, nullptr);

  // apply timestretch
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(ds, nullptr);

  std::unordered_map<std::string, mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));

  std::vector<int64_t> expected = {2, freq, int(std::ceil(400 / rate)), 2};

  int i = 0;
  while (row.size() != 0) {
    auto col = row["inputData"];
    ASSERT_EQ(col.Shape(), expected);
    ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
    ASSERT_OK(iter->GetNextRow(&row));
    i++;
  }
  EXPECT_EQ(i, 50);

  iter->Stop();
}

TEST_F(MindDataTestPipeline, TestTimeStretchPipelineWrongArgs) {
  MS_LOG(INFO) << "Doing test TimeStretchOp with wrong param value. Pipeline.";
  // op param
  int freq = 1025;
  int hop_length = 512;
  float rate = -2;
  // Original waveform
  std::shared_ptr<SchemaObj> schema = Schema();
  ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {2, freq, 400, 2}));
  std::shared_ptr<Dataset> ds = RandomData(50, schema);
  EXPECT_NE(ds, nullptr);

  ds = ds->SetNumWorkers(4);
  EXPECT_NE(ds, nullptr);

  auto TimeStretchOp = audio::TimeStretch(hop_length, freq, rate);

  ds = ds->Map({TimeStretchOp});
  EXPECT_NE(ds, nullptr);

  // apply timestretch
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  // Expect failure
  EXPECT_EQ(iter, nullptr);
}
@@ -19,6 +19,7 @@
#include "minddata/dataset/include/dataset/audio.h"
#include "minddata/dataset/include/dataset/execute.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/include/dataset/audio.h"
#include "minddata/dataset/include/dataset/vision.h"
#include "minddata/dataset/include/dataset/audio.h"
#include "minddata/dataset/include/dataset/text.h"
@@ -196,6 +197,65 @@ TEST_F(MindDataTestExecute, TestCrop) {
  EXPECT_EQ(image.Shape()[1], 15);
}

TEST_F(MindDataTestExecute, TestTimeStretchEager) {
  MS_LOG(INFO) << "Doing test TimeStretchOp with custom param value. Eager.";
  std::shared_ptr<Tensor> input_tensor_;
  // op param
  int freq = 4;
  int hop_length = 20;
  float rate = 1.3;
  int frame_num = 10;
  // create tensor
  TensorShape s = TensorShape({2, freq, frame_num, 2});
  // init input vec
  std::vector<float> input_vec(2 * freq * frame_num * 2);
  for (int ind = 0; ind < input_vec.size(); ind++) {
    input_vec[ind] = std::rand() % (1000) / (1000.0f);
  }
  ASSERT_OK(Tensor::CreateFromVector(input_vec, s, &input_tensor_));
  auto input_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input_tensor_));
  std::shared_ptr<TensorTransform> time_stretch_op = std::make_shared<audio::TimeStretch>(hop_length, freq, rate);

  // apply timestretch
  mindspore::dataset::Execute Transform({time_stretch_op});
  Status status = Transform(input_ms, &input_ms);
  EXPECT_TRUE(status.IsOk());
}

TEST_F(MindDataTestExecute, TestTimeStretchParamCheck1) {
  MS_LOG(INFO) << "Doing MindDataTestTimeStretch-TestTimeStretchParamCheck with invalid parameters.";
  // Create an input
  std::shared_ptr<Tensor> input_tensor_;
  std::shared_ptr<Tensor> output_tensor;
  TensorShape s = TensorShape({1, 4, 3, 2});
  ASSERT_OK(Tensor::CreateFromVector(
    std::vector<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f,
                        1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}),
    s, &input_tensor_));
  auto input_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input_tensor_));
  std::shared_ptr<TensorTransform> timestretch = std::make_shared<audio::TimeStretch>(4, 512, -2);
  mindspore::dataset::Execute Transform({timestretch});
  Status status = Transform(input_ms, &input_ms);
  EXPECT_FALSE(status.IsOk());
}

TEST_F(MindDataTestExecute, TestTimeStretchParamCheck2) {
  MS_LOG(INFO) << "Doing MindDataTestTimeStretch-TestTimeStretchParamCheck with invalid parameters.";
  // Create an input
  std::shared_ptr<Tensor> input_tensor_;
  std::shared_ptr<Tensor> output_tensor;
  TensorShape s = TensorShape({1, 4, 3, 2});
  ASSERT_OK(Tensor::CreateFromVector(
    std::vector<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f,
                        1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}),
    s, &input_tensor_));
  auto input_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input_tensor_));
  std::shared_ptr<TensorTransform> timestretch = std::make_shared<audio::TimeStretch>(4, -512, 2);
  mindspore::dataset::Execute Transform({timestretch});
  Status status = Transform(input_ms, &input_ms);
  EXPECT_FALSE(status.IsOk());
}

TEST_F(MindDataTestExecute, TestTransformInput1) {
  MS_LOG(INFO) << "Doing MindDataTestExecute-TestTransformInput1.";
  // Test Execute with transform op input using API constructors, with std::shared_ptr<TensorTransform pointers,
@@ -0,0 +1,142 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing TimeStretch op in DE
"""
import numpy as np
import pytest

import mindspore.dataset as ds
import mindspore.dataset.audio.transforms as c_audio
from mindspore import log as logger

CHANNEL_NUM = 2
FREQ = 1025
FRAME_NUM = 300
COMPLEX = 2


def gen(shape):
    np.random.seed(0)
    data = np.random.random(shape)
    yield (np.array(data, dtype=np.float32),)


def _count_unequal_element(data_expected, data_me, rtol, atol):
    assert data_expected.shape == data_me.shape
    total_count = len(data_expected.flatten())
    error = np.abs(data_expected - data_me)
    greater = np.greater(error, atol + np.abs(data_expected) * rtol)
    loss_count = np.count_nonzero(greater)
    assert (loss_count / total_count) < rtol, \
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
        format(data_expected[greater], data_me[greater], error[greater])


def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
    if np.any(np.isnan(data_expected)):
        assert np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan)
    elif not np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan):
        _count_unequal_element(data_expected, data_me, rtol, atol)
    else:
        assert True


def test_time_stretch_pipeline():
    """
    Test TimeStretch op. Pipeline.
    """
    logger.info("test TimeStretch op")
    generator = gen([CHANNEL_NUM, FREQ, FRAME_NUM, COMPLEX])
    data1 = ds.GeneratorDataset(source=generator, column_names=[
        "multi_dimensional_data"])

    transforms = [
        c_audio.TimeStretch(512, FREQ, 1.3)
    ]
    data1 = data1.map(operations=transforms, input_columns=[
        "multi_dimensional_data"])

    for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
        out_put = item["multi_dimensional_data"]
    assert out_put.shape == (CHANNEL_NUM, FREQ, np.ceil(FRAME_NUM / 1.3), COMPLEX)


def test_time_stretch_pipeline_invalid_param():
    """
    Test TimeStretch op. Set invalid param. Pipeline.
    """
    logger.info("test TimeStretch op with invalid values")
    generator = gen([CHANNEL_NUM, FREQ, FRAME_NUM, COMPLEX])
    data1 = ds.GeneratorDataset(source=generator, column_names=[
        "multi_dimensional_data"])

    with pytest.raises(ValueError, match=r"Input fixed_rate is not within the required interval of \(0, 16777216\]."):
        transforms = [
            c_audio.TimeStretch(512, FREQ, -1.3)
        ]
        data1 = data1.map(operations=transforms, input_columns=[
            "multi_dimensional_data"])

        for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
            out_put = item["multi_dimensional_data"]
        assert out_put.shape == (CHANNEL_NUM, FREQ, np.ceil(FRAME_NUM / 1.3), COMPLEX)


def test_time_stretch_eager():
    """
    Test TimeStretch op. Set param. Eager.
    """
    logger.info("test TimeStretch op with customized parameter values")
    spectrogram = next(gen([CHANNEL_NUM, FREQ, FRAME_NUM, COMPLEX]))[0]
    out_put = c_audio.TimeStretch(512, FREQ, 1.3)(spectrogram)
    assert out_put.shape == (CHANNEL_NUM, FREQ, np.ceil(FRAME_NUM / 1.3), COMPLEX)


def test_percision_time_stretch_eager():
    """
    Test TimeStretch op. Compare precision. Eager.
    """
    logger.info("test TimeStretch op with default values")
    spectrogram = np.array([[[[1.0402449369430542, 0.3807601034641266],
                              [-1.120057225227356, -0.12819576263427734],
                              [1.4303032159805298, -0.08839055150747299]],
                             [[1.4198592901229858, 0.6900091767311096],
                              [-1.8593409061431885, 0.16363371908664703],
                              [-2.3349387645721436, -1.4366451501846313]]],
                            [[[-0.7083967328071594, 0.9325454831123352],
                              [-1.9133838415145874, 0.011225821450352669],
                              [1.477278232574463, -1.0551637411117554]],
                             [[-0.6668586134910583, -0.23143270611763],
                              [-2.4390718936920166, 0.17638640105724335],
                              [-0.4795735776424408, 0.1345423310995102]]]]).astype(np.float64)
    out_expect = np.array([[[[1.0402449369430542, 0.3807601034641266],
                             [-1.302264928817749, -0.1490504890680313]],
                            [[1.4198592901229858, 0.6900091767311096],
                             [-2.382312774658203, 0.2096325159072876]]],
                           [[[-0.7083966732025146, 0.9325454831123352],
                             [-1.8545820713043213, 0.010880803689360619]],
                            [[-0.6668586134910583, -0.23143276572227478],
                             [-1.2737033367156982, 0.09211209416389465]]]]).astype(np.float64)
    out_ms = c_audio.TimeStretch(64, 2, 1.6)(spectrogram)

    allclose_nparray(out_ms, out_expect, 0.001, 0.001)


if __name__ == '__main__':
    test_time_stretch_pipeline()
    test_time_stretch_pipeline_invalid_param()
    test_time_stretch_eager()
    test_percision_time_stretch_eager()