!24263 [assistant][TimeStretch] fix out-of-bounds ptr

Merge pull request !24263 from QingfengLi/TimeStretch
This commit is contained in:
i-robot 2021-09-30 03:46:04 +00:00 committed by Gitee
commit a2c3a34ee7
5 changed files with 12 additions and 10 deletions

View File

@ -33,7 +33,6 @@ Status TimeStretchOperation::ValidateParams() {
// param check // param check
RETURN_IF_NOT_OK(ValidateFloatScalarPositive("TimeStretch", "hop_length", hop_length_)); RETURN_IF_NOT_OK(ValidateFloatScalarPositive("TimeStretch", "hop_length", hop_length_));
RETURN_IF_NOT_OK(ValidateIntScalarPositive("TimeStretch", "n_freq", n_freq_)); RETURN_IF_NOT_OK(ValidateIntScalarPositive("TimeStretch", "n_freq", n_freq_));
RETURN_IF_NOT_OK(ValidateFloatScalarNotNan("TimeStretch", "fixed_rate", fixed_rate_));
RETURN_IF_NOT_OK(ValidateFloatScalarPositive("TimeStretch", "fixed_rate", fixed_rate_)); RETURN_IF_NOT_OK(ValidateFloatScalarPositive("TimeStretch", "fixed_rate", fixed_rate_));
return Status::OK(); return Status::OK();
} }

View File

@ -213,10 +213,11 @@ Status Phase(const std::shared_ptr<Tensor> &angle_0, const std::shared_ptr<Tenso
int64_t ind = 0; int64_t ind = 0;
auto itr_p0 = phase_time0->begin<T>(); auto itr_p0 = phase_time0->begin<T>();
(void)phase.insert(phase.begin(), (*itr_p0)); (void)phase.insert(phase.begin(), (*itr_p0));
itr_p0++;
while (itr_p0 != phase_time0->end<T>()) { while (itr_p0 != phase_time0->end<T>()) {
itr_p0++;
ind += phase_shape[2]; ind += phase_shape[2];
phase[ind] = (*itr_p0); phase[ind] = (*itr_p0);
itr_p0++;
} }
(void)phase.erase(phase.begin() + static_cast<int>(angle_0->Size()), phase.end()); (void)phase.erase(phase.begin() + static_cast<int>(angle_0->Size()), phase.end());

View File

@ -39,8 +39,8 @@ Status TimeStretchOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_
} }
std::shared_ptr<Tensor> input_tensor; std::shared_ptr<Tensor> input_tensor;
// std::shared_ptr<Tensor> phase_advance;
float hop_length = std::isnan(hop_length_) ? (n_freq_ - 1) : hop_length_; float hop_length = std::isnan(hop_length_) ? (n_freq_ - 1) : hop_length_;
float fixed_rate = std::isnan(fixed_rate_) ? 1 : fixed_rate_;
// typecast // typecast
CHECK_FAIL_RETURN_UNEXPECTED(input->type() != DataType::DE_STRING, CHECK_FAIL_RETURN_UNEXPECTED(input->type() != DataType::DE_STRING,
"TimeStretch: input tensor type should be int, float or double, but got: string."); "TimeStretch: input tensor type should be int, float or double, but got: string.");
@ -50,7 +50,7 @@ Status TimeStretchOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_
input_tensor = input; input_tensor = input;
} }
return TimeStretch(input_tensor, output, fixed_rate_, hop_length, n_freq_); return TimeStretch(input_tensor, output, fixed_rate, hop_length, n_freq_);
} }
Status TimeStretchOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) { Status TimeStretchOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {

View File

@ -537,9 +537,10 @@ class TimeMasking final : public TensorTransform {
class TimeStretch final : public TensorTransform { class TimeStretch final : public TensorTransform {
public: public:
/// \brief Constructor. /// \brief Constructor.
/// \param[in] hop_length Length of hop between STFT windows. Default: None. /// \param[in] hop_length Length of hop between STFT windows (Default: None, will use ((n_freq - 1) * 2) // 2).
/// \param[in] n_freq Number of filter banks form STFT. Default: 201. /// \param[in] n_freq Number of filter banks form STFT (Default: 201).
/// \param[in] fixed_rate Rate to speed up or slow down the input in time. Default: None. /// \param[in] fixed_rate Rate to speed up or slow down the input in time
/// (Default: std::numeric_limits<float>::quiet_NaN(), will keep the original rate).
explicit TimeStretch(float hop_length = std::numeric_limits<float>::quiet_NaN(), int n_freq = 201, explicit TimeStretch(float hop_length = std::numeric_limits<float>::quiet_NaN(), int n_freq = 201,
float fixed_rate = std::numeric_limits<float>::quiet_NaN()); float fixed_rate = std::numeric_limits<float>::quiet_NaN());

View File

@ -689,9 +689,10 @@ class TimeStretch(AudioTensorOperation):
Stretch STFT in time at a given rate, without changing the pitch. Stretch STFT in time at a given rate, without changing the pitch.
Args: Args:
hop_length (int, optional): Length of hop between STFT windows (default=None). hop_length (int, optional): Length of hop between STFT windows (default=None, will use ((n_freq - 1) * 2) // 2).
n_freq (int, optional): Number of filter banks form STFT (default=201). n_freq (int, optional): Number of filter banks form STFT (default=201).
fixed_rate (float, optional): Rate to speed up or slow down the input in time (default=None). fixed_rate (float, optional): Rate to speed up or slow down the input in time
(default=None, will keep the original rate).
Examples: Examples:
>>> import numpy as np >>> import numpy as np
@ -708,7 +709,7 @@ class TimeStretch(AudioTensorOperation):
n_fft = (n_freq - 1) * 2 n_fft = (n_freq - 1) * 2
self.hop_length = hop_length if hop_length is not None else n_fft // 2 self.hop_length = hop_length if hop_length is not None else n_fft // 2
self.fixed_rate = fixed_rate if fixed_rate is not None else np.nan self.fixed_rate = fixed_rate if fixed_rate is not None else 1
def parse(self): def parse(self):
return cde.TimeStretchOperation(self.hop_length, self.n_freq, self.fixed_rate) return cde.TimeStretchOperation(self.hop_length, self.n_freq, self.fixed_rate)