forked from mindspore-Ecosystem/mindspore
!18114 [assistant][ops]Add new operator Magphase
Merge pull request !18114 from YJfuel123/Magphase
This commit is contained in:
commit
6489dacdab
|
@ -34,6 +34,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/lfilter_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/magphase_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
|
||||
|
@ -316,6 +317,16 @@ std::shared_ptr<TensorOperation> LowpassBiquad::Parse() {
|
|||
return std::make_shared<LowpassBiquadOperation>(data_->sample_rate_, data_->cutoff_freq_, data_->Q_);
|
||||
}
|
||||
|
||||
// Magphase Transform Operation.
|
||||
struct Magphase::Data {
|
||||
explicit Data(float power) : power_(power) {}
|
||||
float power_;
|
||||
};
|
||||
|
||||
Magphase::Magphase(float power) : data_(std::make_shared<Data>(power)) {}
|
||||
|
||||
std::shared_ptr<TensorOperation> Magphase::Parse() { return std::make_shared<MagphaseOperation>(data_->power_); }
|
||||
|
||||
// MuLawDecoding Transform Operation.
|
||||
struct MuLawDecoding::Data {
|
||||
explicit Data(int quantization_channels) : quantization_channels_(quantization_channels) {}
|
||||
|
|
|
@ -38,6 +38,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/lfilter_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/magphase_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_masking_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
|
||||
|
@ -259,6 +260,17 @@ PYBIND_REGISTER(
|
|||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(MagphaseOperation, 1, ([](const py::module *m) {
|
||||
(void)
|
||||
py::class_<audio::MagphaseOperation, TensorOperation, std::shared_ptr<audio::MagphaseOperation>>(
|
||||
*m, "MagphaseOperation")
|
||||
.def(py::init([](float power) {
|
||||
auto magphase = std::make_shared<audio::MagphaseOperation>(power);
|
||||
THROW_IF_ERROR(magphase->ValidateParams());
|
||||
return magphase;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(
|
||||
MuLawDecodingOperation, 1, ([](const py::module *m) {
|
||||
(void)py::class_<audio::MuLawDecodingOperation, TensorOperation, std::shared_ptr<audio::MuLawDecodingOperation>>(
|
||||
|
|
|
@ -20,6 +20,7 @@ add_library(audio-ir-kernels OBJECT
|
|||
highpass_biquad_ir.cc
|
||||
lfilter_ir.cc
|
||||
lowpass_biquad_ir.cc
|
||||
magphase_ir.cc
|
||||
mu_law_decoding_ir.cc
|
||||
time_masking_ir.cc
|
||||
time_stretch_ir.cc
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/audio/ir/kernels/magphase_ir.h"
|
||||
|
||||
#include "minddata/dataset/audio/ir/validators.h"
|
||||
#include "minddata/dataset/audio/kernels/magphase_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
namespace audio {
|
||||
|
||||
MagphaseOperation::MagphaseOperation(float power) : power_(power) {}
|
||||
|
||||
Status MagphaseOperation::ValidateParams() {
|
||||
RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("Magphase", "power", power_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<TensorOp> MagphaseOperation::Build() { return std::make_shared<MagphaseOp>(power_); }
|
||||
|
||||
Status MagphaseOperation::to_json(nlohmann::json *out_json) {
|
||||
nlohmann::json args;
|
||||
args["power"] = power_;
|
||||
*out_json = args;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MagphaseOperation::from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("power") != op_params.end(), "Fail to find power");
|
||||
float power = op_params["power"];
|
||||
*operation = std::make_shared<audio::MagphaseOperation>(power);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,58 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MAGPHASE_IR_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MAGPHASE_IR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/include/dataset/transforms.h"
|
||||
#include "minddata/dataset/kernels/ir/tensor_operation.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
namespace audio {
|
||||
|
||||
constexpr char kMagphaseOperation[] = "Magphase";
|
||||
|
||||
class MagphaseOperation : public TensorOperation {
|
||||
public:
|
||||
explicit MagphaseOperation(float power);
|
||||
|
||||
~MagphaseOperation() = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
std::string Name() const override { return kMagphaseOperation; }
|
||||
|
||||
Status ValidateParams() override;
|
||||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
static Status from_json(nlohmann::json op_params, std::shared_ptr<TensorOperation> *operation);
|
||||
|
||||
private:
|
||||
float power_;
|
||||
};
|
||||
|
||||
} // namespace audio
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MAGPHASE_IR_H_
|
|
@ -21,6 +21,7 @@ add_library(audio-kernels OBJECT
|
|||
highpass_biquad_op.cc
|
||||
lfilter_op.cc
|
||||
lowpass_biquad_op.cc
|
||||
magphase_op.cc
|
||||
mu_law_decoding_op.cc
|
||||
time_masking_op.cc
|
||||
time_stretch_op.cc
|
||||
|
|
|
@ -630,5 +630,23 @@ Status Fade(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Magphase(const TensorRow &input, TensorRow *output, float power) {
|
||||
std::shared_ptr<Tensor> mag;
|
||||
std::shared_ptr<Tensor> phase;
|
||||
|
||||
RETURN_IF_NOT_OK(ComplexNorm(input[0], &mag, power));
|
||||
if (input[0]->type() == DataType(DataType::DE_FLOAT64)) {
|
||||
RETURN_IF_NOT_OK(Angle<double>(input[0], &phase));
|
||||
} else {
|
||||
std::shared_ptr<Tensor> tmp;
|
||||
RETURN_IF_NOT_OK(TypeCast(input[0], &tmp, DataType(DataType::DE_FLOAT32)));
|
||||
RETURN_IF_NOT_OK(Angle<float>(tmp, &phase));
|
||||
}
|
||||
(*output).push_back(mag);
|
||||
(*output).push_back(phase);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -365,6 +365,13 @@ Status Vol(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output
|
|||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Separate a complex-valued spectrogram with shape (…, 2) into its magnitude and phase.
|
||||
/// \param input: Complex tensor.
|
||||
/// \param output: The magnitude and phase of the complex tensor.
|
||||
/// \param power: Power of the norm.
|
||||
Status Magphase(const TensorRow &input, TensorRow *output, float power);
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "minddata/dataset/audio/kernels/magphase_op.h"
|
||||
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
constexpr float MagphaseOp::kPower = 1.0;
|
||||
|
||||
Status MagphaseOp::Compute(const TensorRow &input, TensorRow *output) {
|
||||
IO_CHECK_VECTOR(input, output);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2 && input[0]->shape()[-1] == 2,
|
||||
"Magphase: input tensor is not in shape of <..., 2>.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(
|
||||
input[0]->type().IsNumeric(),
|
||||
"Magphase: input tensor type should be int, float or double, but got: " + input[0]->type().ToString());
|
||||
RETURN_IF_NOT_OK(Magphase(input, output, power_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MagphaseOp::OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) {
|
||||
RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs));
|
||||
outputs.clear();
|
||||
auto vec = inputs[0].AsVector();
|
||||
vec.pop_back();
|
||||
auto out = TensorShape(vec);
|
||||
outputs = {out, out};
|
||||
if (!outputs.empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
return Status(StatusCode::kMDUnexpectedError, "Magphase: invalid input wrong shape.");
|
||||
}
|
||||
|
||||
Status MagphaseOp::OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) {
|
||||
RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs));
|
||||
if (inputs[0] == DataType(DataType::DE_FLOAT64)) {
|
||||
outputs[0] = DataType(DataType::DE_FLOAT64);
|
||||
} else {
|
||||
outputs[0] = DataType(DataType::DE_FLOAT32);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MAGPHASE_OP_H_
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MAGPHASE_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "minddata/dataset/core/tensor.h"
|
||||
#include "minddata/dataset/kernels/tensor_op.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class MagphaseOp : public TensorOp {
|
||||
public:
|
||||
static const float kPower;
|
||||
|
||||
explicit MagphaseOp(float power = kPower) : power_(power) {}
|
||||
|
||||
~MagphaseOp() override = default;
|
||||
|
||||
Status Compute(const TensorRow &input, TensorRow *output) override;
|
||||
|
||||
Status OutputShape(const std::vector<TensorShape> &inputs, std::vector<TensorShape> &outputs) override;
|
||||
|
||||
Status OutputType(const std::vector<DataType> &inputs, std::vector<DataType> &outputs) override;
|
||||
|
||||
std::string Name() const override { return kMagphaseOp; }
|
||||
|
||||
private:
|
||||
float power_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MAGPHASE_OP_H_
|
|
@ -443,6 +443,26 @@ class LowpassBiquad final : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief Separate a complex-valued spectrogram with shape (..., 2) into its magnitude and phase.
|
||||
class Magphase final : public TensorTransform {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
/// \param[in] power Power of the norm, which must be non-negative (Default: 1.0).
|
||||
explicit Magphase(float power);
|
||||
|
||||
/// \brief Destructor.
|
||||
~Magphase() = default;
|
||||
|
||||
protected:
|
||||
/// \brief Function to convert TensorTransform object into a TensorOperation object.
|
||||
/// \return Shared pointer to TensorOperation object.
|
||||
std::shared_ptr<TensorOperation> Parse() override;
|
||||
|
||||
private:
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \brief MuLawDecoding TensorTransform.
|
||||
/// \note Decode mu-law encoded signal.
|
||||
class MuLawDecoding final : public TensorTransform {
|
||||
|
|
|
@ -157,6 +157,7 @@ constexpr char kFrequencyMaskingOp[] = "FrequencyMaskingOp";
|
|||
constexpr char kHighpassBiquadOp[] = "HighpassBiquadOp";
|
||||
constexpr char kLFilterOp[] = "LFilterOp";
|
||||
constexpr char kLowpassBiquadOp[] = "LowpassBiquadOp";
|
||||
constexpr char kMagphaseOp[] = "MagphaseOp";
|
||||
constexpr char kMuLawDecodingOp[] = "MuLawDecodingOp";
|
||||
constexpr char kTimeMaskingOp[] = "TimeMaskingOp";
|
||||
constexpr char kTimeStretchOp[] = "TimeStretchOp";
|
||||
|
|
|
@ -27,7 +27,7 @@ from .utils import FadeShape, GainType, ScaleType
|
|||
from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \
|
||||
check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_contrast, check_dc_shift, \
|
||||
check_deemph_biquad, check_equalizer_biquad, check_fade, check_highpass_biquad, check_lfilter, \
|
||||
check_lowpass_biquad, check_masking, check_mu_law_decoding, check_time_stretch, check_vol
|
||||
check_lowpass_biquad, check_magphase, check_masking, check_mu_law_decoding, check_time_stretch, check_vol
|
||||
|
||||
|
||||
class AudioTensorOperation(TensorOperation):
|
||||
|
@ -579,6 +579,30 @@ class LowpassBiquad(AudioTensorOperation):
|
|||
return cde.LowpassBiquadOperation(self.sample_rate, self.cutoff_freq, self.Q)
|
||||
|
||||
|
||||
class Magphase(AudioTensorOperation):
|
||||
"""
|
||||
Separate a complex-valued spectrogram with shape (..., 2) into its magnitude and phase.
|
||||
|
||||
Args:
|
||||
power (float): Power of the norm, which must be non-negative (default=1.0).
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>>
|
||||
>>> waveform = np.random.random([2, 4, 2])
|
||||
>>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
|
||||
>>> transforms = [audio.Magphase()]
|
||||
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
|
||||
"""
|
||||
|
||||
@check_magphase
|
||||
def __init__(self, power=1.0):
|
||||
self.power = power
|
||||
|
||||
def parse(self):
|
||||
return cde.MagphaseOperation(self.power)
|
||||
|
||||
|
||||
class MuLawDecoding(AudioTensorOperation):
|
||||
"""
|
||||
Decode mu-law encoded signal.
|
||||
|
|
|
@ -333,13 +333,30 @@ def check_masking(method):
|
|||
return new_method
|
||||
|
||||
|
||||
def check_power(power):
|
||||
"""Wrapper method to check the parameters of power."""
|
||||
type_check(power, (int, float), "power")
|
||||
check_non_negative_float32(power, "power")
|
||||
|
||||
|
||||
def check_complex_norm(method):
|
||||
"""Wrapper method to check the parameters of ComplexNorm."""
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[power], _ = parse_user_args(method, *args, **kwargs)
|
||||
type_check(power, (int, float), "power")
|
||||
check_non_negative_float32(power, "power")
|
||||
check_power(power)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
||||
|
||||
def check_magphase(method):
|
||||
"""Wrapper method to check the parameters of Magphase."""
|
||||
|
||||
@wraps(method)
|
||||
def new_method(self, *args, **kwargs):
|
||||
[power], _ = parse_user_args(method, *args, **kwargs)
|
||||
check_power(power)
|
||||
return method(self, *args, **kwargs)
|
||||
|
||||
return new_method
|
||||
|
|
|
@ -1343,3 +1343,45 @@ TEST_F(MindDataTestPipeline, TestFadeWithInvalidArg) {
|
|||
std::shared_ptr<Iterator> iter_02 = ds_02->CreateIterator();
|
||||
EXPECT_EQ(iter_02, nullptr);
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestMagphase) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMagphase.";
|
||||
|
||||
float power = 2.0;
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
ASSERT_OK(schema->add_column("col1", mindspore::DataType::kNumberTypeFloat32, {1, 2}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(8, schema);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
std::shared_ptr<TensorTransform> magphase(new audio::Magphase(power));
|
||||
auto ds1 = ds->Map({magphase}, {"col1"}, {"mag", "phase"});
|
||||
EXPECT_NE(ds1, nullptr);
|
||||
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
|
||||
EXPECT_NE(iter, nullptr);
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
uint64_t i = 0;
|
||||
while (row.size() != 0) {
|
||||
i++;
|
||||
ASSERT_OK(iter->GetNextRow(&row));
|
||||
}
|
||||
EXPECT_EQ(i, 8);
|
||||
iter->Stop();
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestMagphaseWrongArgs) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMagphaseWrongArgs.";
|
||||
|
||||
float power_wrong = -1.0;
|
||||
std::shared_ptr<TensorTransform> magphase(new audio::Magphase(power_wrong));
|
||||
std::unordered_map<std::string, mindspore::MSTensor> row;
|
||||
|
||||
//Magphase: power must be greater than or equal to 0.
|
||||
std::shared_ptr<SchemaObj> schema = Schema();
|
||||
ASSERT_OK(schema->add_column("col1", mindspore::DataType::kNumberTypeFloat32, {2, 2}));
|
||||
std::shared_ptr<Dataset> ds = RandomData(8, schema);
|
||||
EXPECT_NE(ds, nullptr);
|
||||
ds = ds->Map({magphase}, {"col1"}, {"mag", "phase"});
|
||||
EXPECT_NE(ds, nullptr);
|
||||
std::shared_ptr<Iterator> iter = ds->CreateIterator();
|
||||
EXPECT_EQ(iter, nullptr);
|
||||
}
|
||||
|
|
|
@ -1082,3 +1082,18 @@ TEST_F(MindDataTestExecute, TestVolGainTypePower) {
|
|||
Status status = transform(input_tensor, &input_tensor);
|
||||
EXPECT_TRUE(status.IsOk());
|
||||
}
|
||||
|
||||
TEST_F(MindDataTestExecute, TestMagphaseEager) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestExecute-TestMagphaseEager.";
|
||||
float power = 1.0;
|
||||
std::vector<mindspore::MSTensor> output_tensor;
|
||||
std::shared_ptr<Tensor> test;
|
||||
std::vector<float> test_vector = {3, 4, -3, 4, 3, -4, -3, -4,
|
||||
5, 12, -5, 12, 5, -12, -5, -12};
|
||||
Tensor::CreateFromVector(test_vector, TensorShape({2, 4, 2}), &test);
|
||||
auto input_tensor = mindspore::MSTensor(std::make_shared<mindspore::dataset::DETensor>(test));
|
||||
std::shared_ptr<TensorTransform> magphase(new audio::Magphase({power}));
|
||||
auto transform = Execute({magphase});
|
||||
Status rc = transform({input_tensor}, &output_tensor);
|
||||
ASSERT_TRUE(rc.IsOk());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Testing Magphase Python API
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.audio.transforms as audio
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
def test_magphase_pipeline():
|
||||
"""
|
||||
Test magphase (pipeline).
|
||||
"""
|
||||
logger.info("Test Magphase pipeline.")
|
||||
|
||||
data1 = [[[3.0, -4.0], [-5.0, 12.0]]]
|
||||
expected = [5, 13, -0.927295, 1.965587]
|
||||
dataset = ds.NumpySlicesDataset(data1, column_names=["col1"], shuffle=False)
|
||||
magphase_window = audio.Magphase(power=1.0)
|
||||
dataset = dataset.map(operations=magphase_window, input_columns=["col1"],
|
||||
output_columns=["mag", "phase"], column_order=["mag", "phase"])
|
||||
for data1, data2 in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True):
|
||||
assert abs(data1[0] - expected[0]) < 0.00001
|
||||
assert abs(data1[1] - expected[1]) < 0.00001
|
||||
assert abs(data2[0] - expected[2]) < 0.00001
|
||||
assert abs(data2[1] - expected[3]) < 0.00001
|
||||
|
||||
logger.info("Finish testing Magphase.")
|
||||
|
||||
|
||||
def test_magphase_eager():
|
||||
"""
|
||||
Test magphase (eager).
|
||||
"""
|
||||
logger.info("Test Magphase eager.")
|
||||
|
||||
input_number = np.array([41, 67, 34, 0, 69, 24, 78, 58]).reshape((2, 2, 2)).astype("double")
|
||||
mag = np.array([78.54934755, 34., 73.05477397, 97.20082304]).reshape((2, 2)).astype("double")
|
||||
phase = np.array([1.02164342, 0, 0.33473684, 0.63938591]).reshape((2, 2)).astype("double")
|
||||
magphase_window = audio.Magphase()
|
||||
data1, data2 = magphase_window(input_number)
|
||||
assert (abs(data1 - mag) < 0.00001).all()
|
||||
assert (abs(data2 - phase) < 0.00001).all()
|
||||
|
||||
logger.info("Finish testing Magphase.")
|
||||
|
||||
|
||||
def test_magphase_exception():
|
||||
"""
|
||||
Test magphase not callable.
|
||||
"""
|
||||
logger.info("Test Magphase not callable.")
|
||||
|
||||
try:
|
||||
input_number = np.array([1, 2, 3, 4]).reshape(4,).astype("double")
|
||||
magphase_window = audio.Magphase(power=2.0)
|
||||
_ = magphase_window(input_number)
|
||||
except RuntimeError as error:
|
||||
logger.info("Got an exception in Magphase: {}".format(str(error)))
|
||||
assert "Magphase: input tensor is not in shape of <..., 2>." in str(error)
|
||||
try:
|
||||
input_number = np.array([1, 2, 3, 4]).reshape(1, 4).astype("double")
|
||||
magphase_window = audio.Magphase(power=2.0)
|
||||
_ = magphase_window(input_number)
|
||||
except RuntimeError as error:
|
||||
logger.info("Got an exception in Magphase: {}".format(str(error)))
|
||||
assert "Magphase: input tensor is not in shape of <..., 2>." in str(error)
|
||||
try:
|
||||
input_number = np.array(['test', 'test']).reshape(1, 2)
|
||||
magphase_window = audio.Magphase(power=2.0)
|
||||
_ = magphase_window(input_number)
|
||||
except RuntimeError as error:
|
||||
logger.info("Got an exception in Magphase: {}".format(str(error)))
|
||||
assert "Magphase: input tensor type should be int, float or double" in str(error)
|
||||
try:
|
||||
input_number = np.array([1, 2, 3, 4]).reshape(2, 2).astype("double")
|
||||
magphase_window = audio.Magphase(power=-1.0)
|
||||
_ = magphase_window(input_number)
|
||||
except ValueError as error:
|
||||
logger.info("Got an exception in Magphase: {}".format(str(error)))
|
||||
assert "Input power is not within the required interval of [0, 16777216]." in str(error)
|
||||
|
||||
logger.info("Finish testing Magphase.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_magphase_pipeline()
|
||||
test_magphase_eager()
|
||||
test_magphase_exception()
|
Loading…
Reference in New Issue