forked from mindspore-Ecosystem/mindspore
[feat][assistant][I40GXD]add new data operator DBToAmplitude
This commit is contained in:
parent
b8c50982c5
commit
69f57bc5b1
|
@ -26,6 +26,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/complex_norm_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/contrast_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/db_to_amplitude_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/detect_pitch_frequency_ir.h"
|
||||
|
@ -197,6 +198,19 @@ std::shared_ptr<TensorOperation> Contrast::Parse() {
|
|||
return std::make_shared<ContrastOperation>(data_->enhancement_amount_);
|
||||
}
|
||||
|
||||
// DBToAmplitude Transform Operation.
|
||||
// Private parameter holder for the DBToAmplitude transform.
struct DBToAmplitude::Data {
  explicit Data(float ref, float power) : ref_{ref}, power_{power} {}

  float ref_;    // reference which the output will be scaled by
  float power_;  // exponent forwarded to the DBToAmplitude operation
};
|
||||
|
||||
// Stores the user-supplied parameters. `ref` must be forwarded as the first argument of Data;
// passing `power` twice would silently discard the caller's reference value.
DBToAmplitude::DBToAmplitude(float ref, float power) : data_(std::make_shared<Data>(ref, power)) {}
|
||||
|
||||
std::shared_ptr<TensorOperation> DBToAmplitude::Parse() {
|
||||
return std::make_shared<DBToAmplitudeOperation>(data_->ref_, data_->power_);
|
||||
}
|
||||
|
||||
// DCShift Transform Operation.
|
||||
struct DCShift::Data {
|
||||
Data(float shift, float limiter_gain) : shift_(shift), limiter_gain_(limiter_gain) {}
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/complex_norm_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/contrast_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/db_to_amplitude_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/dc_shift_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/deemph_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/detect_pitch_frequency_ir.h"
|
||||
|
@ -170,6 +171,17 @@ PYBIND_REGISTER(ContrastOperation, 1, ([](const py::module *m) {
|
|||
}));
|
||||
}));
|
||||
|
||||
// Registers audio::DBToAmplitudeOperation with pybind under the name "DBToAmplitudeOperation".
// The Python-side constructor builds the operation and applies THROW_IF_ERROR to ValidateParams()
// before handing the object back.
PYBIND_REGISTER(
  DBToAmplitudeOperation, 1, ([](const py::module *m) {
    (void)py::class_<audio::DBToAmplitudeOperation, TensorOperation, std::shared_ptr<audio::DBToAmplitudeOperation>>(
      *m, "DBToAmplitudeOperation")
      .def(py::init([](float ref, float power) {
        auto operation = std::make_shared<audio::DBToAmplitudeOperation>(ref, power);
        THROW_IF_ERROR(operation->ValidateParams());
        return operation;
      }));
  }));
|
||||
|
||||
PYBIND_REGISTER(DCShiftOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<audio::DCShiftOperation, TensorOperation, std::shared_ptr<audio::DCShiftOperation>>(
|
||||
*m, "DCShiftOperation")
|
||||
|
|
|
@ -12,6 +12,7 @@ add_library(audio-ir-kernels OBJECT
|
|||
biquad_ir.cc
|
||||
complex_norm_ir.cc
|
||||
contrast_ir.cc
|
||||
db_to_amplitude_ir.cc
|
||||
dc_shift_ir.cc
|
||||
deemph_biquad_ir.cc
|
||||
detect_pitch_frequency_ir.cc
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/audio/ir/kernels/db_to_amplitude_ir.h"
|
||||
|
||||
#include "minddata/dataset/audio/kernels/db_to_amplitude_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
// DBToAmplitudeOperation
|
||||
DBToAmplitudeOperation::DBToAmplitudeOperation(float ref, float power) : ref_(ref), power_(power) {}
|
||||
|
||||
// No constraint is placed on ref or power here, so validation always succeeds.
Status DBToAmplitudeOperation::ValidateParams() {
  return Status::OK();
}
|
||||
|
||||
std::shared_ptr<TensorOp> DBToAmplitudeOperation::Build() {
|
||||
std::shared_ptr<DBToAmplitudeOp> tensor_op = std::make_shared<DBToAmplitudeOp>(ref_, power_);
|
||||
return tensor_op;
|
||||
}
|
||||
|
||||
// Serializes the operation parameters into a JSON object with the keys "ref" and "power".
// The caller-provided out_json is overwritten with that object.
Status DBToAmplitudeOperation::to_json(nlohmann::json *out_json) {
  *out_json = nlohmann::json{{"ref", ref_}, {"power", power_}};
  return Status::OK();
}
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DB_TO_AMPLITUDE_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DB_TO_AMPLITUDE_IR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/include/dataset/transforms.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace audio {
|
||||
constexpr char kDBToAmplitudeOperation[] = "DBToAmplitude";
|
||||
|
||||
class DBToAmplitudeOperation : public TensorOperation {
|
||||
public:
|
||||
DBToAmplitudeOperation(float ref, float power);
|
||||
|
||||
~DBToAmplitudeOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
std::string Name() const override { return kDBToAmplitudeOperation; }
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
float ref_;
|
||||
float power_;
|
||||
};
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_DB_TO_AMPLITUDE_IR_H_
|
|
@ -13,6 +13,7 @@ add_library(audio-kernels OBJECT
|
|||
biquad_op.cc
|
||||
complex_norm_op.cc
|
||||
contrast_op.cc
|
||||
db_to_amplitude_op.cc
|
||||
dc_shift_op.cc
|
||||
deemph_biquad_op.cc
|
||||
detect_pitch_frequency_op.cc
|
||||
|
|
|
@ -152,6 +152,24 @@ Status Contrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Apply DBToAmplitude effect.
|
||||
/// \param input/output: Tensor of shape <...,time>
|
||||
/// \param ref: Reference which the output will be scaled by.
|
||||
/// \param power: If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude.
|
||||
/// \return Status code
|
||||
template <typename T>
|
||||
Status DBToAmplitude(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, T ref, T power) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(input->shape(), input->type(), &out));
|
||||
auto itr_out = out->begin<T>();
|
||||
for (auto itr_in = input->begin<T>(); itr_in != input->end<T>(); itr_in++) {
|
||||
*itr_out = ref * pow(pow(10, (*itr_in) * 0.1), power);
|
||||
itr_out++;
|
||||
}
|
||||
*output = out;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Apply a DC shift to the audio.
|
||||
/// \param input/output: Tensor of shape <...,time>.
|
||||
/// \param shift: the amount to shift the audio.
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/audio/kernels/db_to_amplitude_op.h"
|
||||
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
// Validates the input tensor and dispatches to the typed DBToAmplitude kernel.
// Inputs that are not float64 are first cast to float32; float64 inputs are processed as double.
Status DBToAmplitudeOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  IO_CHECK(input, output);
  // check input tensor dimension, it should be greater than 0.
  CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() > 0,
                               "DBToAmplitude: input tensor is not in shape of <..., time>.");
  CHECK_FAIL_RETURN_UNEXPECTED(
    input->type().IsNumeric(),
    "DBToAmplitude: input tensor type should be int, float or double, but got " + input->type().ToString());

  if (input->type() != DataType(DataType::DE_FLOAT64)) {
    // Everything except float64 goes through a float32 copy of the input.
    std::shared_ptr<Tensor> casted;
    RETURN_IF_NOT_OK(TypeCast(input, &casted, DataType(DataType::DE_FLOAT32)));
    return DBToAmplitude<float>(casted, output, ref_, power_);
  }
  return DBToAmplitude<double>(input, output, ref_, power_);
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DB_TO_AMPLITUDE_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DB_TO_AMPLITUDE_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
class DBToAmplitudeOp : public TensorOp {
|
||||
public:
|
||||
DBToAmplitudeOp(float ref, float power) : ref_(ref), power_(power) {}
|
||||
|
||||
~DBToAmplitudeOp() override = default;
|
||||
|
||||
void Print(std::ostream &out) const override {
|
||||
out << Name() << " ref: " << ref_ << ", power: " << power_ << std::endl;
|
||||
}
|
||||
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
std::string Name() const override { return kDBToAmplitudeOp; }
|
||||
|
||||
private:
|
||||
float ref_;
|
||||
float power_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_DB_TO_AMPLITUDE_OP_H_
|
|
@ -256,6 +256,27 @@ class Contrast final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Turn a waveform from the decibel scale to the power/amplitude scale.
|
||||
class DBToAmplitude final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
/// \param[in] ref Reference which the output will be scaled by.
|
||||
/// \param[in] power If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude.
|
||||
explicit DBToAmplitude(float ref, float power);
|
||||
|
||||
/// \brief Destructor.
|
||||
~DBToAmplitude() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
||||
private:
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Apply a DC shift to the audio.
|
||||
class DCShift : public TensorTransform {
|
||||
public:
|
||||
|
|
|
@ -154,6 +154,7 @@ constexpr char kBassBiquadOp[] = "BassBiquadOp";
|
|||
constexpr char kBiquadOp[] = "BiquadOp";
|
||||
constexpr char kComplexNormOp[] = "ComplexNormOp";
|
||||
constexpr char kContrastOp[] = "ContrastOp";
|
||||
// Op name for DBToAmplitudeOp; no leading whitespace, consistent with the neighbouring op-name constants.
constexpr char kDBToAmplitudeOp[] = "DBToAmplitudeOp";
|
||||
constexpr char kDCShiftOp[] = "DCShiftOp";
|
||||
constexpr char kDeemphBiquadOp[] = "DeemphBiquadOp";
|
||||
constexpr char kDetectPitchFrequencyOp[] = "DetectPitchFrequencyOp";
|
||||
|
|
|
@ -25,10 +25,11 @@ import mindspore._c_dataengine as cde
|
|||
from ..transforms.c_transforms import TensorOperation
|
||||
from .utils import FadeShape, GainType, Interpolation, Modulation, ScaleType
|
||||
from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
|
||||
check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_contrast, check_dc_shift, \
|
||||
check_deemph_biquad, check_detect_pitch_frequency, check_equalizer_biquad, check_fade, check_flanger, \
|
||||
check_highpass_biquad, check_lfilter, check_lowpass_biquad, check_magphase, check_masking, check_mu_law_coding, \
|
||||
check_overdrive, check_phaser, check_riaa_biquad, check_time_stretch, check_treble_biquad, check_vol
|
||||
check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_contrast, \
|
||||
check_db_to_amplitude, check_dc_shift, check_deemph_biquad, check_detect_pitch_frequency, check_equalizer_biquad, \
|
||||
check_fade, check_flanger, check_highpass_biquad, check_lfilter, check_lowpass_biquad, check_magphase, \
|
||||
check_masking, check_mu_law_coding, check_overdrive, check_phaser, check_riaa_biquad, check_time_stretch, \
|
||||
check_treble_biquad, check_vol
|
||||
|
||||
|
||||
class AudioTensorOperation(TensorOperation):
|
||||
|
@ -330,6 +331,32 @@ class Contrast(AudioTensorOperation):
|
|||
return cde.ContrastOperation(self.enhancement_amount)
|
||||
|
||||
|
||||
class DBToAmplitude(AudioTensorOperation):
    """
    Convert a waveform from the decibel scale to the power/amplitude scale.

    Args:
        ref (float): Reference which the output will be scaled by.
        power (float): If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude.

    Examples:
        >>> import numpy as np
        >>>
        >>> waveform = np.array([[2.716064453125e-03, 6.34765625e-03], [9.246826171875e-03, 1.0894775390625e-02]])
        >>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
        >>> transforms = [audio.DBToAmplitude(0.5, 0.5)]
        >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
    """

    @check_db_to_amplitude
    def __init__(self, ref, power):
        # Arguments are validated by the check_db_to_amplitude decorator before reaching this point.
        self.ref, self.power = ref, power

    def parse(self):
        # Hand both parameters to the C++ operation exposed through cde.
        ref, power = self.ref, self.power
        return cde.DBToAmplitudeOperation(ref, power)
|
||||
|
||||
|
||||
class DCShift(AudioTensorOperation):
|
||||
"""
|
||||
Apply a DC shift to the audio.
|
||||
|
|
|
@ -195,6 +195,21 @@ def check_contrast(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_db_to_amplitude(method):
    """Wrapper method to check the parameters of db_to_amplitude."""

    @wraps(method)
    def new_method(self, *args, **kwargs):
        [ref, power], _ = parse_user_args(method, *args, **kwargs)
        # Both parameters get the same two checks, ref first and power second.
        for value, arg_name in ((ref, "ref"), (power, "power")):
            type_check(value, (float, int), arg_name)
            check_float32(value, arg_name)
        return method(self, *args, **kwargs)

    return new_method
|
||||
|
||||
|
||||
def check_dc_shift(method):
|
||||
"""Wrapper method to check the parameters of DCShift."""
|
||||
|
||||
|
|
|
@ -1916,3 +1916,44 @@ TEST_F(MindDataTestPipeline, TestCreateDctWrongArg) {
|
|||
Status s04 = audio::CreateDct(&output, 200, -400, NormMode::kOrtho);
|
||||
EXPECT_FALSE(s04.IsOk());
|
||||
}
|
||||
|
||||
/// Feature: DBToAmplitude
|
||||
/// Description: test DBToAmplitude in pipeline mode
|
||||
/// Expectation: the data is processed successfully
|
||||
TEST_F(MindDataTestPipeline, TestDBToAmplitudePipeline) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestDBToAmplitudePipeline.";
|
||||
// Original waveform
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
ASSERT_OK(schema->add_column("waveform", mindspore::DataType::kNumberTypeFloat32, {2, 200}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(50, schema);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
ds = ds->SetNumWorkers(4);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
auto db_to_amplitude_op = audio::DBToAmplitude(2, 2);
|
||||
|
||||
ds = ds->Map({db_to_amplitude_op});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_NE(ds, nullptr);
|
||||
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
|
||||
std::vector<int64_t> expected = {2, 200};
|
||||
|
||||
int i = 0;
|
||||
while (row.size() != 0) {
|
||||
auto col = row["waveform"];
|
||||
ASSERT_EQ(col.Shape(), expected);
|
||||
ASSERT_EQ(col.Shape().size(), 2);
|
||||
ASSERT_EQ(col.DataType(), mindspore::DataType::kNumberTypeFloat32);
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
i++;
|
||||
}
|
||||
EXPECT_EQ(i, 50);
|
||||
|
||||
iter->Stop();
|
||||
}
|
||||
|
|
|
@ -1480,3 +1480,25 @@ TEST_F(MindDataTestExecute, TestFlangerWithWrongArg) {
|
|||
Status s01 = Transform01(input_02, &input_02);
|
||||
EXPECT_FALSE(s01.IsOk());
|
||||
}
|
||||
|
||||
// Feature: DBToAmplitude
|
||||
// Description: test DBToAmplitude in eager mode
|
||||
// Expectation: the data is processed successfully
|
||||
TEST_F(MindDataTestExecute, TestDBToAmplitudeWithEager) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestDBToAmplitudeWithEager.";
|
||||
// Original waveform
|
||||
std::vector<float> labels = {
|
||||
2.716064453125000000e-03, 6.347656250000000000e-03, 9.246826171875000000e-03, 1.089477539062500000e-02,
|
||||
1.138305664062500000e-02, 1.156616210937500000e-02, 1.394653320312500000e-02, 1.550292968750000000e-02,
|
||||
1.614379882812500000e-02, 1.840209960937500000e-02, 1.718139648437500000e-02, 1.599121093750000000e-02,
|
||||
1.647949218750000000e-02, 1.510620117187500000e-02, 1.385498046875000000e-02, 1.345825195312500000e-02,
|
||||
1.419067382812500000e-02, 1.284790039062500000e-02, 1.052856445312500000e-02, 9.368896484375000000e-03};
|
||||
std::shared_ptr<Tensor> input;
|
||||
ASSERT_OK(Tensor::CreateFromVector(labels, TensorShape({2, 10}), &input));
|
||||
auto input_02 = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(input));
|
||||
std::shared_ptr<TensorTransform> DBToAmplitude_01 = std::make_shared<audio::DBToAmplitude>(2, 2);
|
||||
mindspore::dataset::Execute Transform01({DBToAmplitude_01});
|
||||
// Filtered waveform by DBToAmplitude
|
||||
Status s01 = Transform01(input_02, &input_02);
|
||||
EXPECT_TRUE(s01.IsOk());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,105 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.audio.transforms as audio
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
def count_unequal_element(data_expected, data_me, rtol, atol):
    """
    Assert that two arrays agree within tolerance for all but a small fraction of elements.

    An element counts as mismatched when |expected - actual| > atol + |expected| * rtol.
    The check fails when the share of mismatched elements is not below rtol.

    Args:
        data_expected (numpy.ndarray): Reference values.
        data_me (numpy.ndarray): Values under test; must have the same shape as data_expected.
        rtol (float): Relative tolerance, also used as the allowed mismatch ratio.
        atol (float): Absolute tolerance.

    Raises:
        AssertionError: If the shapes differ or too many elements are out of tolerance.
    """
    assert data_expected.shape == data_me.shape
    total_count = data_expected.size
    if total_count == 0:
        # Nothing to compare; avoid dividing by zero below.
        return
    error = np.abs(data_expected - data_me)
    greater = np.greater(error, atol + np.abs(data_expected) * rtol)
    loss_count = np.count_nonzero(greater)
    assert (loss_count / total_count) < rtol, \
        "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
        format(data_expected[greater], data_me[greater], error[greater])
|
||||
|
||||
|
||||
def test_db_to_amplitude_eager():
    """
    Feature: DBToAmplitude
    Description: test DBToAmplitude in eager mode
    Expectation: the data is processed successfully
    """
    logger.info("mindspore eager mode normal testcase:DBToAmplitude op")

    # Input values on the decibel scale and the expected result for ref=2, power=2.
    db_values = np.array([1, 2, 3, 4, 5, 6], dtype=np.float64)
    expected = np.array([3.1698, 5.0238, 7.9621, 12.6191, 20.0000, 31.6979], dtype=np.float64)

    # Call the op directly on the numpy array.
    actual = audio.DBToAmplitude(2, 2)(db_values)
    count_unequal_element(expected, actual, 0.0001, 0.0001)
|
||||
|
||||
|
||||
def test_db_to_amplitude_pipeline():
    """
    Feature: DBToAmplitude
    Description: test DBToAmplitude in pipeline mode
    Expectation: the data is processed successfully
    """
    logger.info("mindspore pipeline mode normal testcase:DBToAmplitude op")

    # Two rows of decibel values and the expected result for ref=1, power=2.
    db_rows = np.array([[2, 2, 3], [0.1, 0.2, 0.3]], dtype=np.float64)
    expected_rows = np.array([[2.5119, 2.5119, 3.9811],
                              [1.0471, 1.0965, 1.1482]], dtype=np.float64)

    dataset = ds.NumpySlicesDataset(db_rows, ["audio"], shuffle=False)
    # Apply the op through the dataset map.
    dataset = dataset.map(input_columns=["audio"], operations=audio.DBToAmplitude(1, 2), num_parallel_workers=8)
    for row_index, item in enumerate(dataset.create_dict_iterator(output_numpy=True)):
        count_unequal_element(expected_rows[row_index, :], item['audio'], 0.0001, 0.0001)
|
||||
|
||||
|
||||
def test_db_to_amplitude_invalid_input():
    """
    Feature: DBToAmplitude
    Description: test param check of DBToAmplitude
    Expectation: throw correct error and message
    """
    # The log line names the operator under test (previously a copy-pasted name of another op).
    logger.info("mindspore eager mode invalid input testcase:DBToAmplitude op")

    def test_invalid_input(test_name, ref, power, error, error_msg):
        """Construct DBToAmplitude with bad arguments and check the raised error type and message."""
        logger.info("Test DBToAmplitude with bad input: {0}".format(test_name))
        with pytest.raises(error) as error_info:
            audio.DBToAmplitude(ref, power)
        assert error_msg in str(error_info.value)

    test_invalid_input("invalid ref parameter type as a String", "1.0", 1.0, TypeError,
                       "Argument ref with value 1.0 is not of type [<class 'float'>, <class 'int'>],"
                       " but got <class 'str'>.")
    test_invalid_input("invalid ref parameter value", 122323242445423534543, 1.0, ValueError,
                       "Input ref is not within the required interval of [-16777216, 16777216].")
    test_invalid_input("invalid power parameter type as a String", 1.0, "1.0", TypeError,
                       "Argument power with value 1.0 is not of type [<class 'float'>, <class 'int'>],"
                       " but got <class 'str'>.")
    test_invalid_input("invalid power parameter value", 1.0, 1343454254325445, ValueError,
                       "Input power is not within the required interval of [-16777216, 16777216].")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run every test case of this module in declaration order.
    for test_case in (test_db_to_amplitude_eager,
                      test_db_to_amplitude_pipeline,
                      test_db_to_amplitude_invalid_input):
        test_case()
|
Loading…
Reference in New Issue