[feat][assistant][I3CEGC] add op timestretch

This commit is contained in:
chenx2ovo 2021-08-07 22:53:41 +08:00
parent 3693625d6f
commit 04705e5b0d
21 changed files with 1024 additions and 3 deletions

View File

@@ -23,6 +23,7 @@
#include "minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/bandreject_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/bass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
namespace mindspore {
namespace dataset {
@@ -132,6 +133,23 @@ BassBiquad::BassBiquad(int32_t sample_rate, float gain, float central_freq, floa
std::shared_ptr<TensorOperation> BassBiquad::Parse() {
return std::make_shared<BassBiquadOperation>(data_->sample_rate_, data_->gain_, data_->central_freq_, data_->Q_);
}
// TimeStretch Operation.
struct TimeStretch::Data {
explicit Data(float hop_length, int n_freq, float fixed_rate)
: hop_length_(hop_length), n_freq_(n_freq), fixed_rate_(fixed_rate) {}
float hop_length_;
int n_freq_;
float fixed_rate_;
};
TimeStretch::TimeStretch(float hop_length, int n_freq, float fixed_rate)
: data_(std::make_shared<Data>(hop_length, n_freq, fixed_rate)) {}
std::shared_ptr<TensorOperation> TimeStretch::Parse() {
return std::make_shared<TimeStretchOperation>(data_->hop_length_, data_->n_freq_, data_->fixed_rate_);
}
} // namespace audio
} // namespace dataset
} // namespace mindspore

View File

@@ -24,6 +24,7 @@
#include "minddata/dataset/audio/ir/kernels/bandpass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/bandreject_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/bass_biquad_ir.h"
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
#include "minddata/dataset/include/dataset/transforms.h"
namespace mindspore {
@@ -113,5 +114,15 @@ PYBIND_REGISTER(
}));
}));
PYBIND_REGISTER(
TimeStretchOperation, 1, ([](const py::module *m) {
(void)py::class_<audio::TimeStretchOperation, TensorOperation, std::shared_ptr<audio::TimeStretchOperation>>(
*m, "TimeStretchOperation")
.def(py::init([](float hop_length, int n_freq, float fixed_rate) {
auto timestretch = std::make_shared<audio::TimeStretchOperation>(hop_length, n_freq, fixed_rate);
THROW_IF_ERROR(timestretch->ValidateParams());
return timestretch;
}));
}));
} // namespace dataset
} // namespace mindspore

View File

@@ -9,4 +9,5 @@ add_library(audio-ir-kernels OBJECT
bandpass_biquad_ir.cc
bandreject_biquad_ir.cc
bass_biquad_ir.cc
time_stretch_ir.cc
)

View File

@@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
#include "minddata/dataset/audio/kernels/time_stretch_op.h"
#include "minddata/dataset/audio/ir/validators.h"
namespace mindspore {
namespace dataset {
namespace audio {
// TimeStretch
TimeStretchOperation::TimeStretchOperation(float hop_length, int n_freq, float fixed_rate)
: hop_length_(hop_length), n_freq_(n_freq), fixed_rate_(fixed_rate) {}
TimeStretchOperation::~TimeStretchOperation() = default;
std::string TimeStretchOperation::Name() const { return kTimeStretchOperation; }
Status TimeStretchOperation::ValidateParams() {
// param check
RETURN_IF_NOT_OK(CheckFloatScalarPositive("TimeStretch", "hop_length", hop_length_));
RETURN_IF_NOT_OK(CheckIntScalarPositive("TimeStretch", "n_freq", n_freq_));
RETURN_IF_NOT_OK(CheckFloatScalarNotNan("TimeStretch", "fixed_rate", fixed_rate_));
RETURN_IF_NOT_OK(CheckFloatScalarPositive("TimeStretch", "fixed_rate", fixed_rate_));
return Status::OK();
}
std::shared_ptr<TensorOp> TimeStretchOperation::Build() {
std::shared_ptr<TimeStretchOp> tensor_op = std::make_shared<TimeStretchOp>(hop_length_, n_freq_, fixed_rate_);
return tensor_op;
}
Status TimeStretchOperation::to_json(nlohmann::json *out_json) {
nlohmann::json args;
args["hop_length"] = hop_length_;
args["n_freq"] = n_freq_;
args["fixed_rate"] = fixed_rate_;
*out_json = args;
return Status::OK();
}
} // namespace audio
} // namespace dataset
} // namespace mindspore
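For reference, a minimal sketch of exercising the to_json hook above. The helper name is hypothetical and not part of this commit; only APIs introduced here are assumed:

#include <nlohmann/json.hpp>
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"

// Hypothetical helper, not part of this commit: serialize the op's parameters.
void DumpTimeStretchJson() {
  // Arbitrary values: hop_length=512, n_freq=1025, fixed_rate=1.3.
  mindspore::dataset::audio::TimeStretchOperation op(512.0f, 1025, 1.3f);
  nlohmann::json args;
  if (op.to_json(&args).IsOk()) {
    // args now holds {"fixed_rate": 1.3, "hop_length": 512.0, "n_freq": 1025}.
  }
}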

View File

@@ -0,0 +1,55 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_TIME_STRETCH_IR_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_TIME_STRETCH_IR_H_
#include <memory>
#include <string>
#include "include/api/status.h"
#include "minddata/dataset/kernels/ir/tensor_operation.h"
namespace mindspore {
namespace dataset {
namespace audio {
constexpr char kTimeStretchOperation[] = "TimeStretch";
class TimeStretchOperation : public TensorOperation {
public:
TimeStretchOperation(float hop_length, int n_freq, float fixed_rate);
~TimeStretchOperation();
std::shared_ptr<TensorOp> Build() override;
Status ValidateParams() override;
std::string Name() const override;
Status to_json(nlohmann::json *out_json) override;
private:
float hop_length_;
int n_freq_;
float fixed_rate_;
};
} // namespace audio
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_TIME_STRETCH_IR_H_

View File

@@ -23,11 +23,25 @@ Status CheckFloatScalarPositive(const std::string &op_name, const std::string &s
return Status::OK();
}
Status CheckFloatScalarNotNan(const std::string &op_name, const std::string &scalar_name, float scalar) {
if (std::isnan(scalar)) {
std::string err_msg = op_name + ":" + scalar_name + " should be specified, got: NaN.";
MS_LOG(ERROR) << err_msg;
return Status(StatusCode::kMDSyntaxError, __LINE__, __FILE__, err_msg);
}
return Status::OK();
}
Status CheckFloatScalarNonNegative(const std::string &op_name, const std::string &scalar_name, float scalar) {
RETURN_IF_NOT_OK(CheckScalar(op_name, scalar_name, scalar, {0}, false));
return Status::OK();
}
Status CheckIntScalarPositive(const std::string &op_name, const std::string &scalar_name, int32_t scalar) {
RETURN_IF_NOT_OK(CheckScalar(op_name, scalar_name, scalar, {0}, true));
return Status::OK();
}
Status CheckStringScalarInList(const std::string &op_name, const std::string &scalar_name, const std::string &scalar,
const std::vector<std::string> &str_vec) {
auto ret = std::find(str_vec.begin(), str_vec.end(), scalar);
@@ -78,5 +92,7 @@ Status CheckScalar(const std::string &op_name, const std::string &scalar_name, c
template Status CheckScalar(const std::string &op_name, const std::string &scalar_name, const float scalar,
const std::vector<float> &range, bool left_open_interval, bool right_open_interval);
template Status CheckScalar(const std::string &op_name, const std::string &scalar_name, const int32_t scalar,
const std::vector<int32_t> &range, bool left_open_interval, bool right_open_interval);
} // namespace dataset
} // namespace mindspore

View File

@@ -28,6 +28,15 @@
namespace mindspore {
namespace dataset {
// Helper function to check that a float scalar is not NaN
Status CheckFloatScalarNotNan(const std::string &op_name, const std::string &scalar_name, float scalar);
// Helper function to check a positive float scalar
Status CheckFloatScalarPositive(const std::string &op_name, const std::string &scalar_name, float scalar);
// Helper function to check a positive int scalar
Status CheckIntScalarPositive(const std::string &op_name, const std::string &scalar_name, int32_t scalar);
// Helper function to check scalar is not equal to zero
template <typename T>
Status CheckScalarNotZero(const std::string &op_name, const std::string &scalar_name, const T scalar) {

View File

@@ -10,4 +10,5 @@ add_library(audio-kernels OBJECT
bandpass_biquad_op.cc
bandreject_biquad_op.cc
bass_biquad_op.cc
time_stretch_op.cc
)

View File

@@ -61,5 +61,325 @@ template Status AmplitudeToDB<float>(const std::shared_ptr<Tensor> &input, std::
float multiplier, float amin, float db_multiplier, float top_db);
template Status AmplitudeToDB<double>(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output,
double multiplier, double amin, double db_multiplier, double top_db);
/// \brief Generate linearly spaced vector.
/// \param[out] output - Tensor of n linearly spaced points; the spacing between points is (end - start) / (n - 1).
/// \param[in] start - Value of the start point.
/// \param[in] end - Value of the end point.
/// \param[in] n - Number of points in the output tensor.
/// \return Status return code
template <typename T>
Status Linespace(std::shared_ptr<Tensor> *output, T start, T end, int n) {
if (start > end) {
std::string err = "Linespace: input param end must be no less than start.";
RETURN_STATUS_UNEXPECTED(err);
}
TensorShape out_shape({n});
std::vector<T> linear_vect(n);
T interval = (end - start) / (n - 1);
for (size_t i = 0; i < linear_vect.size(); ++i) {
linear_vect[i] = start + i * interval;
}
std::shared_ptr<Tensor> out_t;
RETURN_IF_NOT_OK(Tensor::CreateFromVector(linear_vect, out_shape, &out_t));
linear_vect.clear();
linear_vect.shrink_to_fit();
*output = out_t;
return Status::OK();
}
/// \brief Calculate complex tensor angle
/// \param[in] input - Input tensor, must be complex, <channel, freq, time, complex=2>.
/// \param[out] output - Complex tensor angle.
/// \return Status return code
template <typename T>
Status ComplexAngle(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
// check complex
if (!input->IsComplex()) {
std::string err_msg = "ComplexAngle: input tensor is not in shape of <..., 2>.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
TensorShape input_shape = input->shape();
TensorShape out_shape({input_shape[0], input_shape[1], input_shape[2]});
std::vector<T> phase(input_shape[0] * input_shape[1] * input_shape[2]);
int ind = 0;
for (auto itr = input->begin<T>(); itr != input->end<T>(); itr++, ind++) {
auto x = (*itr);
itr++;
auto y = (*itr);
phase[ind] = atan2(y, x);
}
std::shared_ptr<Tensor> out_t;
RETURN_IF_NOT_OK(Tensor::CreateFromVector(phase, out_shape, &out_t));
phase.clear();
phase.shrink_to_fit();
*output = out_t;
return Status::OK();
}
/// \brief Calculate complex tensor abs
/// \param[in] input - Input tensor, must be complex, <channel, freq, time, complex=2>.
/// \param[out] output - Complex tensor abs.
/// \return Status return code
template <typename T>
Status ComplexAbs(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
// check complex
if (!input->IsComplex()) {
std::string err_msg = "ComplexAngle: input tensor is not in shape of <..., 2>.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
TensorShape input_shape = input->shape();
TensorShape out_shape({input_shape[0], input_shape[1], input_shape[2]});
std::vector<T> abs(input_shape[0] * input_shape[1] * input_shape[2]);
int ind = 0;
for (auto itr = input->begin<T>(); itr != input->end<T>(); itr++, ind++) {
T x = (*itr);
itr++;
T y = (*itr);
abs[ind] = sqrt(pow(y, 2) + pow(x, 2));
}
std::shared_ptr<Tensor> out_t;
RETURN_IF_NOT_OK(Tensor::CreateFromVector(abs, out_shape, &out_t));
*output = out_t;
return Status::OK();
}
/// \brief Reconstruct complex tensor from norm and angle
/// \param[in] abs - The absolute value of the complex tensor.
/// \param[in] angle - The angle of the complex tensor.
/// \param[out] output - Complex tensor, <channel, freq, time, complex=2>.
/// \return Status return code
template <typename T>
Status Polar(const std::shared_ptr<Tensor> &abs, const std::shared_ptr<Tensor> &angle,
std::shared_ptr<Tensor> *output) {
// check shape
if (abs->shape() != angle->shape()) {
std::string err_msg = "Polar: input shape of abs and angle must be same.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
TensorShape input_shape = abs->shape();
TensorShape out_shape({input_shape[0], input_shape[1], input_shape[2], 2});
std::vector<T> complex_vec(input_shape[0] * input_shape[1] * input_shape[2] * 2);
int ind = 0;
auto itr_abs = abs->begin<T>();
auto itr_angle = angle->begin<T>();
for (; itr_abs != abs->end<T>(); itr_abs++, itr_angle++) {
complex_vec[ind++] = cos(*itr_angle) * (*itr_abs);
complex_vec[ind++] = sin(*itr_angle) * (*itr_abs);
}
std::shared_ptr<Tensor> out_t;
RETURN_IF_NOT_OK(Tensor::CreateFromVector(complex_vec, out_shape, &out_t));
*output = out_t;
return Status::OK();
}
/// \brief Pad complex tensor
/// \param[in] input - The complex tensor.
/// \param[in] length - The length of padding.
/// \param[in] dim - The dim index for padding.
/// \param[out] output - Complex tensor, <channel, freq, time, complex=2>.
/// \return Status return code
template <typename T>
Status PadComplexTensor(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int length, int dim) {
TensorShape input_shape = input->shape();
std::vector<int64_t> pad_shape_vec = {input_shape[0], input_shape[1], input_shape[2], input_shape[3]};
pad_shape_vec[dim] += length;
TensorShape input_shape_with_pad(pad_shape_vec);
std::vector<T> in_vect(input_shape_with_pad[0] * input_shape_with_pad[1] * input_shape_with_pad[2] *
input_shape_with_pad[3]);
auto itr_input = input->begin<T>();
int input_cnt = 0;
for (int ind = 0; ind < in_vect.size(); ind++) {
in_vect[ind] = (*itr_input);
input_cnt = (input_cnt + 1) % (input_shape[2] * input_shape[3]);
itr_input++;
// the last dim of a complex tensor equals 2, so fill 2 * length zeros per row
if (input_cnt == 0 && ind != 0) {
for (int c = 0; c < length * 2; c++) {
in_vect[++ind] = 0.0f;
}
}
}
std::shared_ptr<Tensor> out_t;
RETURN_IF_NOT_OK(Tensor::CreateFromVector(in_vect, input_shape_with_pad, &out_t));
*output = out_t;
return Status::OK();
}
/// \brief Calculate phase
/// \param[in] angle_0 - Angle of the spectrogram slice taken at the lower time steps.
/// \param[in] angle_1 - Angle of the spectrogram slice taken at the upper time steps.
/// \param[in] phase_advance - The phase advance.
/// \param[in] phase_time0 - The phase at time 0.
/// \param[out] output - Phase tensor.
/// \return Status return code
template <typename T>
Status Phase(const std::shared_ptr<Tensor> &angle_0, const std::shared_ptr<Tensor> &angle_1,
const std::shared_ptr<Tensor> &phase_advance, const std::shared_ptr<Tensor> &phase_time0,
std::shared_ptr<Tensor> *output) {
TensorShape phase_shape = angle_0->shape();
std::vector<T> phase(phase_shape[0] * phase_shape[1] * phase_shape[2]);
auto itr_angle_0 = angle_0->begin<T>();
auto itr_angle_1 = angle_1->begin<T>();
auto itr_pa = phase_advance->begin<T>();
for (int ind = 0, input_cnt = 0; itr_angle_0 != angle_0->end<T>(); itr_angle_0++, itr_angle_1++, ind++) {
if (ind != 0 && ind % phase_shape[2] == 0) {
itr_pa++;
if (itr_pa == phase_advance->end<T>()) {
itr_pa = phase_advance->begin<T>();
}
input_cnt++;
}
phase[ind] = (*itr_angle_1) - (*itr_angle_0) - (*itr_pa);
phase[ind] = phase[ind] - 2 * PI * round(phase[ind] / (2 * PI)) + (*itr_pa);
}
// concatenate the phase at time 0
int ind = 0;
auto itr_p0 = phase_time0->begin<T>();
phase.insert(phase.begin(), (*itr_p0));
while (itr_p0 != phase_time0->end<T>()) {
itr_p0++;
ind += phase_shape[2];
phase[ind] = (*itr_p0);
}
phase.erase(phase.begin() + static_cast<int>(angle_0->Size()), phase.end());
// calculate the cumulative phase
for (ind = 0; ind < phase.size(); ind++) {
if (ind % phase_shape[2] != 0) {
phase[ind] = phase[ind] + phase[ind - 1];
}
}
std::shared_ptr<Tensor> phase_tensor;
RETURN_IF_NOT_OK(Tensor::CreateFromVector(phase, phase_shape, &phase_tensor));
*output = phase_tensor;
return Status::OK();
}
/// \brief Calculate magnitude
/// \param[in] abs_0 - Magnitude of the spectrogram slice taken at the lower time steps.
/// \param[in] abs_1 - Magnitude of the spectrogram slice taken at the upper time steps.
/// \param[out] output - Magnitude tensor.
/// \param[in] alphas - Interpolation weights in [0, 1).
/// \return Status return code
template <typename T>
Status Mag(const std::shared_ptr<Tensor> &abs_0, const std::shared_ptr<Tensor> &abs_1, std::shared_ptr<Tensor> *output,
const std::vector<T> &alphas) {
TensorShape mag_shape = abs_0->shape();
std::vector<T> mag(mag_shape[0] * mag_shape[1] * mag_shape[2]);
auto itr_abs_0 = abs_0->begin<T>();
auto itr_abs_1 = abs_1->begin<T>();
for (int ind = 0; itr_abs_0 != abs_0->end<T>(); itr_abs_0++, itr_abs_1++, ind++) {
mag[ind] = alphas[ind % mag_shape[2]] * (*itr_abs_1) + (1 - alphas[ind % mag_shape[2]]) * (*itr_abs_0);
}
std::shared_ptr<Tensor> mag_tensor;
RETURN_IF_NOT_OK(Tensor::CreateFromVector(mag, mag_shape, &mag_tensor));
*output = mag_tensor;
return Status::OK();
}
template <typename T>
Status TimeStretch(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, float rate,
std::shared_ptr<Tensor> phase_advance) {
// pack <..., freq, time, complex>
TensorShape input_shape = input->shape();
TensorShape toShape({input->Size() / (input_shape[-1] * input_shape[-2] * input_shape[-3]), input_shape[-3],
input_shape[-2], input_shape[-1]});
RETURN_IF_NOT_OK(input->Reshape(toShape));
if (rate == 1.0) {
*output = input;
return Status::OK();
}
// calculate time step and alphas
int ind = 0;
std::vector<dsize_t> time_steps_0, time_steps_1;
std::vector<T> alphas;
for (T val = 0;; ind++) {
val = ind * rate;
if (val >= input_shape[-2]) break;
int val_int = static_cast<int>(val);
time_steps_0.push_back(val_int);
time_steps_1.push_back(val_int + 1);
alphas.push_back(fmod(val, 1));
}
// calculate phase on time 0
std::shared_ptr<Tensor> spec_time0, phase_time0;
RETURN_IF_NOT_OK(
input->Slice(&spec_time0, std::vector<SliceOption>({SliceOption(true), SliceOption(true),
SliceOption(std::vector<dsize_t>{0}), SliceOption(true)})));
RETURN_IF_NOT_OK(ComplexAngle<T>(spec_time0, &phase_time0));
// time pad: add zero to time dim
RETURN_IF_NOT_OK(PadComplexTensor<T>(input, &input, 2, 2));
// slice
std::shared_ptr<Tensor> spec_0;
RETURN_IF_NOT_OK(input->Slice(&spec_0, std::vector<SliceOption>({SliceOption(true), SliceOption(true),
SliceOption(time_steps_0), SliceOption(true)})));
std::shared_ptr<Tensor> spec_1;
RETURN_IF_NOT_OK(input->Slice(&spec_1, std::vector<SliceOption>({SliceOption(true), SliceOption(true),
SliceOption(time_steps_1), SliceOption(true)})));
// new slices angle and abs <channel, freq, time>
std::shared_ptr<Tensor> angle_0, angle_1, abs_0, abs_1;
RETURN_IF_NOT_OK(ComplexAngle<T>(spec_0, &angle_0));
RETURN_IF_NOT_OK(ComplexAbs<T>(spec_0, &abs_0));
RETURN_IF_NOT_OK(ComplexAngle<T>(spec_1, &angle_1));
RETURN_IF_NOT_OK(ComplexAbs<T>(spec_1, &abs_1));
// calculate phase; there is some precision difference between MindSpore and PyTorch here
std::shared_ptr<Tensor> phase_tensor;
RETURN_IF_NOT_OK(Phase<T>(angle_0, angle_1, phase_advance, phase_time0, &phase_tensor));
// calculate magnitude
std::shared_ptr<Tensor> mag_tensor;
RETURN_IF_NOT_OK(Mag<T>(abs_0, abs_1, &mag_tensor, alphas));
// reconstruct complex from norm and angle
std::shared_ptr<Tensor> complex_spec_stretch;
RETURN_IF_NOT_OK(Polar<T>(mag_tensor, phase_tensor, &complex_spec_stretch));
// unpack
auto output_shape_vec = input_shape.AsVector();
output_shape_vec.pop_back();
output_shape_vec.pop_back();
output_shape_vec.push_back(complex_spec_stretch->shape()[-2]);
output_shape_vec.push_back(input_shape[-1]);
RETURN_IF_NOT_OK(complex_spec_stretch->Reshape(TensorShape(output_shape_vec)));
*output = complex_spec_stretch;
return Status::OK();
}
Status TimeStretch(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, float rate, float hop_length,
float n_freq) {
std::shared_ptr<Tensor> phase_advance;
switch (input->type().value()) {
case DataType::DE_FLOAT32:
RETURN_IF_NOT_OK(Linespace<float>(&phase_advance, 0, PI * hop_length, n_freq));
RETURN_IF_NOT_OK(TimeStretch<float>(input, output, rate, phase_advance));
break;
case DataType::DE_FLOAT64:
RETURN_IF_NOT_OK(Linespace<double>(&phase_advance, 0, PI * hop_length, n_freq));
RETURN_IF_NOT_OK(TimeStretch<double>(input, output, rate, phase_advance));
break;
default:
RETURN_STATUS_UNEXPECTED(
"TimeStretch: unsupported type, currently supported types include "
"[float, double].");
}
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
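Restated compactly, the helpers above implement a standard phase-vocoder update; in the notation below (derived from the code itself), with stretch rate $r$, expected per-bin phase advance $\omega_k$ (phase_advance, linearly spaced over $[0, \pi \cdot \text{hop\_length}]$), and spectrogram slices $S_0$, $S_1$ taken at time steps $\lfloor tr \rfloor$ and $\lfloor tr \rfloor + 1$, Mag and Phase compute

$\alpha_t = (tr) \bmod 1, \qquad |M_t| = (1 - \alpha_t)\,|S_0| + \alpha_t\,|S_1|$

$\phi_t = \phi_{t-1} + \omega_k + \operatorname{wrap}(\angle S_1 - \angle S_0 - \omega_k), \qquad \operatorname{wrap}(x) = x - 2\pi \operatorname{round}(x / 2\pi),$

with $\phi_0$ taken from the input's first frame, after which Polar reconstructs the stretched spectrogram as $|M_t|\,e^{i\phi_t}$.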

View File

@@ -176,6 +176,16 @@ Status LFilter(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *ou
delete m_py;
return Status::OK();
}
/// \brief Stretch STFT in time at a given rate, without changing the pitch.
/// \param[in] input - Tensor of shape <..., freq, time, complex=2>.
/// \param[in] rate - Stretch factor.
/// \param[in] hop_length - Length of hop between STFT windows.
/// \param[in] n_freq - Number of filter banks from STFT.
/// \param[out] output - Tensor after stretch in time domain.
/// \return Status return code
Status TimeStretch(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, float rate, float hop_length,
float n_freq);
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_
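A minimal calling sketch for the declaration above, assuming only the internal Tensor API already used elsewhere in this commit; the function name, shapes, and rate are arbitrary:

Status RunTimeStretchSketch() {
  std::shared_ptr<Tensor> input, output;
  // Dummy <channel=1, freq=201, time=30, complex=2> spectrogram filled with 0.5.
  RETURN_IF_NOT_OK(Tensor::CreateFromVector(std::vector<float>(1 * 201 * 30 * 2, 0.5f),
                                            TensorShape({1, 201, 30, 2}), &input));
  // rate=1.3 with hop_length=200 (= n_freq - 1) and n_freq=201;
  // the time dimension becomes ceil(30 / 1.3) = 24 frames.
  RETURN_IF_NOT_OK(TimeStretch(input, &output, 1.3f, 200.0f, 201.0f));
  return Status::OK();
}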

View File

@@ -0,0 +1,80 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "minddata/dataset/audio/kernels/time_stretch_op.h"
#include <limits>
#include "minddata/dataset/audio/kernels/audio_utils.h"
#include "minddata/dataset/kernels/data/data_utils.h"
#include "minddata/dataset/util/status.h"
namespace mindspore {
namespace dataset {
const float TimeStretchOp::kHopLength = std::numeric_limits<float>::quiet_NaN();
const int TimeStretchOp::kNFreq = 201;
const float TimeStretchOp::kFixedRate = std::numeric_limits<float>::quiet_NaN();
Status TimeStretchOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
// check and init
IO_CHECK(input, output);
// check shape
if (input->shape().Rank() < 3) {
std::string err_msg = "TimeStretch: input tensor shape is not <..., freq, num_frame, complex=2>.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
// check complex
if (!input->IsComplex()) {
std::string err_msg = "TimeStretch: input tensor is not in shape of <..., 2>.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
std::shared_ptr<Tensor> input_tensor;
float hop_length = std::isnan(hop_length_) ? (n_freq_ - 1) : hop_length_;
// typecast
CHECK_FAIL_RETURN_UNEXPECTED(input->type() != DataType::DE_STRING,
"TimeStretch: input tensor type should be [int, float, double], but got string.");
if (input->type() != DataType::DE_FLOAT64) {
RETURN_IF_NOT_OK(TypeCast(input, &input_tensor, DataType(DataType::DE_FLOAT32)));
} else {
input_tensor = input;
}
return TimeStretch(input_tensor, output, fixed_rate_, hop_length, n_freq_);
}
Status TimeStretchOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
outputs.clear();
for (auto s : inputs) {
std::vector<dsize_t> s_vec = s.AsVector();
s_vec.pop_back();
s_vec.pop_back();
s_vec.push_back(std::ceil(s[-2] / fixed_rate_));
// push back complex
s_vec.push_back(2);
outputs.emplace_back(TensorShape(s_vec));
}
CHECK_FAIL_RETURN_UNEXPECTED(!outputs.empty(), "TimeStretch: invalid input shape.");
return Status::OK();
}
} // namespace dataset
} // namespace mindspore
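As a concrete check of the OutputShape arithmetic above: an input of shape <2, 1025, 300, 2> with fixed_rate = 1.3 keeps the leading dimensions, replaces the frame count with ceil(300 / 1.3) = ceil(230.77) = 231, and re-appends the complex dimension, giving <2, 1025, 231, 2>, which is the same shape the Python pipeline test at the end of this commit asserts.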

View File

@@ -0,0 +1,58 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_TIME_STRETCH_OP_H_
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_TIME_STRETCH_OP_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/kernels/tensor_op.h"
namespace mindspore {
namespace dataset {
class TimeStretchOp : public TensorOp {
public:
/// Default value
static const float kHopLength;
static const int kNFreq;
static const float kFixedRate;
explicit TimeStretchOp(float hop_length = kHopLength, int n_freq = kNFreq, float fixed_rate = kFixedRate)
: hop_length_(hop_length), n_freq_(n_freq), fixed_rate_(fixed_rate) {}
~TimeStretchOp() override = default;
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
std::string Name() const override { return kTimeStretchOp; }
/// \param[in] inputs - Vector of input tensor shapes.
/// \param[out] outputs - Vector of output tensor shapes, filled in by this function.
/// \return Status code
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
private:
float hop_length_;
int n_freq_;
float fixed_rate_;
};
} // namespace dataset
} // namespace mindspore
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_TIME_STRETCH_OP_H_

View File

@@ -306,6 +306,13 @@ class Tensor {
/// \return bool - true if tensor is not empty
bool HasData() const { return data_ != nullptr; }
/// Check if tensor is complex
/// \return bool - true if tensor is complex
bool IsComplex() const {
// check that the last dimension is 2 (real and imaginary parts)
return shape_[-1] == 2;
}
/// Reshape the tensor. The given shape should have the same number of elements in the Tensor
/// \param shape
virtual Status Reshape(const TensorShape &shape);

View File

@@ -186,6 +186,30 @@ class BassBiquad final : public TensorTransform {
struct Data;
std::shared_ptr<Data> data_;
};
/// \brief TimeStretch TensorTransform
/// \note Stretch STFT in time at a given rate, without changing the pitch.
class TimeStretch final : public TensorTransform {
public:
/// \brief Constructor.
/// \param[in] hop_length Length of hop between STFT windows. Default: None.
/// \param[in] n_freq Number of filter banks from STFT. Default: 201.
/// \param[in] fixed_rate Rate to speed up or slow down the input in time. Default: None.
explicit TimeStretch(float hop_length = std::numeric_limits<float>::quiet_NaN(), int n_freq = 201,
float fixed_rate = std::numeric_limits<float>::quiet_NaN());
/// \brief Destructor.
~TimeStretch() = default;
protected:
/// \brief Function to convert TensorTransform object into a TensorOperation object.
/// \return Shared pointer to TensorOperation object.
std::shared_ptr<TensorOperation> Parse() override;
private:
struct Data;
std::shared_ptr<Data> data_;
};
} // namespace audio
} // namespace dataset
} // namespace mindspore

View File

@@ -145,6 +145,7 @@ constexpr char kBandBiquadOp[] = "BandBiquadOp";
constexpr char kBandpassBiquadOp[] = "BandpassBiquadOp";
constexpr char kBandrejectBiquadOp[] = "BandrejectBiquadOp";
constexpr char kBassBiquadOp[] = "BassBiquadOp";
constexpr char kTimeStretchOp[] = "TimeStretchOp";
// data
constexpr char kConcatenateOp[] = "ConcatenateOp";

View File

@@ -22,7 +22,7 @@ import numpy as np
from ..transforms.c_transforms import TensorOperation
from .utils import ScaleType
from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
check_bandreject_biquad, check_bass_biquad, check_time_stretch
class AudioTensorOperation(TensorOperation):
@@ -249,3 +249,36 @@ class BassBiquad(AudioTensorOperation):
def parse(self):
return cde.BassBiquadOperation(self.sample_rate, self.gain, self.central_freq, self.Q)
class TimeStretch(AudioTensorOperation):
"""
Stretch STFT in time at a given rate, without changing the pitch.
Args:
hop_length (int, optional): Length of hop between STFT windows (default=None, will use n_fft // 2, where n_fft equals (n_freq - 1) * 2).
n_freq (int, optional): Number of filter banks from STFT (default=201).
fixed_rate (float, optional): Rate to speed up or slow down the input in time (default=None).
Examples:
>>> import mindspore.dataset.audio.transforms as audio
>>> freq = 44100
>>> num_frame = 30
>>> def gen():
... np.random.seed(0)
... data = np.random.random([freq, num_frame, 2])
... yield (np.array(data, dtype=np.float32), )
>>> data1 = ds.GeneratorDataset(source=gen, column_names=["multi_dimensional_data"])
>>> transforms = [audio.TimeStretch(fixed_rate=1.3)]
>>> data1 = data1.map(operations=transforms, input_columns=["multi_dimensional_data"])
"""
@check_time_stretch
def __init__(self, hop_length=None, n_freq=201, fixed_rate=None):
self.n_freq = n_freq
n_fft = (n_freq - 1) * 2
self.hop_length = hop_length if hop_length is not None else n_fft // 2
self.fixed_rate = fixed_rate if fixed_rate is not None else np.nan
def parse(self):
return cde.TimeStretchOperation(self.hop_length, self.n_freq, self.fixed_rate)
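Note that this Python default agrees with the kernel's NaN fallback: with the default n_freq = 201, n_fft = (201 - 1) * 2 = 400 and hop_length = 400 // 2 = 200, which equals the (n_freq_ - 1) value substituted in TimeStretchOp::Compute when hop_length_ arrives as NaN.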

View File

@@ -16,8 +16,8 @@
Validators for TensorOps.
"""
from functools import wraps
from mindspore.dataset.core.validator_helpers import check_not_zero, check_int32, check_float32, check_value, \
check_value_normalize_std, check_value_ratio, FLOAT_MAX_INTEGER, INT64_MAX, parse_user_args, type_check
from .utils import ScaleType
@@ -164,3 +164,25 @@ def check_bass_biquad(method):
return method(self, *args, **kwargs)
return new_method
def check_time_stretch(method):
"""Wrapper method to check the parameters of time_stretch."""
@wraps(method)
def new_method(self, *args, **kwargs):
[hop_length, n_freq, fixed_rate], _ = parse_user_args(method, *args, **kwargs)
# type check
type_check(hop_length, (int, type(None)), "hop_length")
type_check(n_freq, (int,), "n_freq")
type_check(fixed_rate, (int, float, type(None)), "fixed_rate")
# value check
if hop_length is not None:
check_value(hop_length, (1, INT64_MAX), "hop_length")
check_value(n_freq, (1, INT64_MAX), "n_freq")
if fixed_rate is not None:
check_value_ratio(fixed_rate, (0, FLOAT_MAX_INTEGER), "fixed_rate")
return method(self, *args, **kwargs)
return new_method

View File

@@ -13,6 +13,7 @@ SET(DE_UT_SRCS
buddy_test.cc
build_vocab_test.cc
c_api_audio_a_to_q_test.cc
c_api_audio_r_to_z_test.cc
c_api_cache_test.cc
c_api_dataset_album_test.cc

View File

@@ -0,0 +1,96 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common.h"
#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/include/dataset/datasets.h"
#include "minddata/dataset/include/dataset/audio.h"
using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;
class MindDataTestPipeline : public UT::Common {
public:
};
TEST_F(MindDataTestPipeline, TestTimeStretchPipeline) {
MS_LOG(INFO) << "Doing test TimeStretchOp with custom param value. Pipeline.";
// op param
int freq = 1025;
int hop_length = 512;
float rate = 1.2;
// Original waveform
std::shared_ptr<SchemaObj> schema = Schema();
ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {2, freq, 400, 2}));
std::shared_ptr<Dataset> ds = RandomData(50, schema);
EXPECT_NE(ds, nullptr);
ds = ds->SetNumWorkers(4);
EXPECT_NE(ds, nullptr);
auto TimeStretchOp = audio::TimeStretch(hop_length, freq, rate);
ds = ds->Map({TimeStretchOp});
EXPECT_NE(ds, nullptr);
// apply timestretch
std::shared_ptr<Iterator> iter = ds->CreateIterator();
EXPECT_NE(iter, nullptr);
std::unordered_map<std::string, mindspore::MSTensor> row;
ASSERT_OK(iter->GetNextRow(&row));
std::vector<int64_t> expected = {2, freq, int(std::ceil(400 / rate)), 2};
int i = 0;
while (row.size() != 0) {
auto col = row["inputData"];
ASSERT_EQ(col.Shape(), expected);
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
ASSERT_OK(iter->GetNextRow(&row));
i++;
}
EXPECT_EQ(i, 50);
iter->Stop();
}
TEST_F(MindDataTestPipeline, TestTimeStretchPipelineWrongArgs) {
MS_LOG(INFO) << "Doing test TimeStretchOp with wrong param value. Pipeline.";
// op param
int freq = 1025;
int hop_length = 512;
float rate = -2;
// Original waveform
std::shared_ptr<SchemaObj> schema = Schema();
ASSERT_OK(schema->add_column("inputData", mindspore::DataType::kNumberTypeFloat32, {2, freq, 400, 2}));
std::shared_ptr<Dataset> ds = RandomData(50, schema);
EXPECT_NE(ds, nullptr);
ds = ds->SetNumWorkers(4);
EXPECT_NE(ds, nullptr);
auto TimeStretchOp = audio::TimeStretch(hop_length, freq, rate);
ds = ds->Map({TimeStretchOp});
EXPECT_NE(ds, nullptr);
// apply timestretch
std::shared_ptr<Iterator> iter = ds->CreateIterator();
// Expect failure
EXPECT_EQ(iter, nullptr);
}

View File

@@ -19,6 +19,7 @@
#include "minddata/dataset/include/dataset/audio.h"
#include "minddata/dataset/include/dataset/execute.h"
#include "minddata/dataset/include/dataset/transforms.h"
#include "minddata/dataset/include/dataset/audio.h"
#include "minddata/dataset/include/dataset/vision.h"
#include "minddata/dataset/include/dataset/audio.h"
#include "minddata/dataset/include/dataset/text.h"
@@ -196,6 +197,65 @@ TEST_F(MindDataTestExecute, TestCrop) {
EXPECT_EQ(image.Shape()[1], 15);
}
TEST_F(MindDataTestExecute, TestTimeStretchEager) {
MS_LOG(INFO) << "Doing test TimeStretchOp with custom param value. Eager.";
std::shared_ptr<Tensor> input_tensor_;
// op param
int freq = 4;
int hop_length = 20;
float rate = 1.3;
int frame_num = 10;
// create tensor
TensorShape s = TensorShape({2, freq, frame_num, 2});
// init input vec
std::vector<float> input_vec(2 * freq * frame_num * 2);
for (int ind = 0; ind < input_vec.size(); ind++) {
input_vec[ind] = std::rand() % (1000) / (1000.0f);
}
ASSERT_OK(Tensor::CreateFromVector(input_vec, s, &input_tensor_));
auto input_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input_tensor_));
std::shared_ptr<TensorTransform> time_stretch_op = std::make_shared<audio::TimeStretch>(hop_length, freq, rate);
// apply timestretch
mindspore::dataset::Execute Transform({time_stretch_op});
Status status = Transform(input_ms, &input_ms);
EXPECT_TRUE(status.IsOk());
}
TEST_F(MindDataTestExecute, TestTimeStretchParamCheck1) {
MS_LOG(INFO) << "Doing MindDataTestTimeStretch-TestTimeStretchParamCheck with invalid parameters.";
// Create an input
std::shared_ptr<Tensor> input_tensor_;
std::shared_ptr<Tensor> output_tensor;
TensorShape s = TensorShape({1, 4, 3, 2});
ASSERT_OK(Tensor::CreateFromVector(
std::vector<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f,
1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}),
s, &input_tensor_));
auto input_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input_tensor_));
std::shared_ptr<TensorTransform> timestretch = std::make_shared<audio::TimeStretch>(4, 512, -2);
mindspore::dataset::Execute Transform({timestretch});
Status status = Transform(input_ms, &input_ms);
EXPECT_FALSE(status.IsOk());
}
TEST_F(MindDataTestExecute, TestTimeStretchParamCheck2) {
MS_LOG(INFO) << "Doing MindDataTestTimeStretch-TestTimeStretchParamCheck with invalid parameters.";
// Create an input
std::shared_ptr<Tensor> input_tensor_;
std::shared_ptr<Tensor> output_tensor;
TensorShape s = TensorShape({1, 4, 3, 2});
ASSERT_OK(Tensor::CreateFromVector(
std::vector<float>({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f,
1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}),
s, &input_tensor_));
auto input_ms = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input_tensor_));
std::shared_ptr<TensorTransform> timestretch = std::make_shared<audio::TimeStretch>(4, -512, 2);
mindspore::dataset::Execute Transform({timestretch});
Status status = Transform(input_ms, &input_ms);
EXPECT_FALSE(status.IsOk());
}
TEST_F(MindDataTestExecute, TestTransformInput1) {
MS_LOG(INFO) << "Doing MindDataTestExecute-TestTransformInput1.";
// Test Execute with transform op input using API constructors, with std::shared_ptr<TensorTransform pointers,

View File

@@ -0,0 +1,142 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Testing TimeStretch op in DE
"""
import numpy as np
import pytest
import mindspore.dataset as ds
import mindspore.dataset.audio.transforms as c_audio
from mindspore import log as logger
CHANNEL_NUM = 2
FREQ = 1025
FRAME_NUM = 300
COMPLEX = 2
def gen(shape):
np.random.seed(0)
data = np.random.random(shape)
yield (np.array(data, dtype=np.float32),)
def _count_unequal_element(data_expected, data_me, rtol, atol):
assert data_expected.shape == data_me.shape
total_count = len(data_expected.flatten())
error = np.abs(data_expected - data_me)
greater = np.greater(error, atol + np.abs(data_expected) * rtol)
loss_count = np.count_nonzero(greater)
assert (loss_count / total_count) < rtol, \
"\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
format(data_expected[greater], data_me[greater], error[greater])
def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True):
if np.any(np.isnan(data_expected)):
assert np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan)
elif not np.allclose(data_me, data_expected, rtol, atol, equal_nan=equal_nan):
_count_unequal_element(data_expected, data_me, rtol, atol)
else:
assert True
def test_time_stretch_pipeline():
"""
Test TimeStretch op. Pipeline.
"""
logger.info("test TimeStretch op")
generator = gen([CHANNEL_NUM, FREQ, FRAME_NUM, COMPLEX])
data1 = ds.GeneratorDataset(source=generator, column_names=[
"multi_dimensional_data"])
transforms = [
c_audio.TimeStretch(512, FREQ, 1.3)
]
data1 = data1.map(operations=transforms, input_columns=[
"multi_dimensional_data"])
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
out_put = item["multi_dimensional_data"]
assert out_put.shape == (CHANNEL_NUM, FREQ, np.ceil(FRAME_NUM/1.3), COMPLEX)
def test_time_stretch_pipeline_invalid_param():
"""
Test TimeStretch op. Set invalid param. Pipeline.
"""
logger.info("test TimeStretch op with invalid values")
generator = gen([CHANNEL_NUM, FREQ, FRAME_NUM, COMPLEX])
data1 = ds.GeneratorDataset(source=generator, column_names=[
"multi_dimensional_data"])
with pytest.raises(ValueError, match=r"Input fixed_rate is not within the required interval of \(0, 16777216\]."):
transforms = [
c_audio.TimeStretch(512, FREQ, -1.3)
]
data1 = data1.map(operations=transforms, input_columns=[
"multi_dimensional_data"])
for item in data1.create_dict_iterator(num_epochs=1, output_numpy=True):
out_put = item["multi_dimensional_data"]
assert out_put.shape == (CHANNEL_NUM, FREQ, np.ceil(FRAME_NUM/1.3), COMPLEX)
def test_time_stretch_eager():
"""
Test TimeStretch op. Set param. Eager.
"""
logger.info("test TimeStretch op with customized parameter values")
spectrogram = next(gen([CHANNEL_NUM, FREQ, FRAME_NUM, COMPLEX]))[0]
out_put = c_audio.TimeStretch(512, FREQ, 1.3)(spectrogram)
assert out_put.shape == (CHANNEL_NUM, FREQ, np.ceil(FRAME_NUM/1.3), COMPLEX)
def test_precision_time_stretch_eager():
"""
Test TimeStretch op. Compare precision. Eager.
"""
logger.info("test TimeStretch op with default values")
spectrogram = np.array([[[[1.0402449369430542, 0.3807601034641266],
[-1.120057225227356, -0.12819576263427734],
[1.4303032159805298, -0.08839055150747299]],
[[1.4198592901229858, 0.6900091767311096],
[-1.8593409061431885, 0.16363371908664703],
[-2.3349387645721436, -1.4366451501846313]]],
[[[-0.7083967328071594, 0.9325454831123352],
[-1.9133838415145874, 0.011225821450352669],
[1.477278232574463, -1.0551637411117554]],
[[-0.6668586134910583, -0.23143270611763],
[-2.4390718936920166, 0.17638640105724335],
[-0.4795735776424408, 0.1345423310995102]]]]).astype(np.float64)
out_expect = np.array([[[[1.0402449369430542, 0.3807601034641266],
[-1.302264928817749, -0.1490504890680313]],
[[1.4198592901229858, 0.6900091767311096],
[-2.382312774658203, 0.2096325159072876]]],
[[[-0.7083966732025146, 0.9325454831123352],
[-1.8545820713043213, 0.010880803689360619]],
[[-0.6668586134910583, -0.23143276572227478],
[-1.2737033367156982, 0.09211209416389465]]]]).astype(np.float64)
out_ms = c_audio.TimeStretch(64, 2, 1.6)(spectrogram)
allclose_nparray(out_ms, out_expect, 0.001, 0.001)
if __name__ == '__main__':
test_time_stretch_pipeline()
test_time_stretch_pipeline_invalid_param()
test_time_stretch_eager()
test_precision_time_stretch_eager()