diff --git a/mindspore/ccsrc/minddata/dataset/api/audio.cc b/mindspore/ccsrc/minddata/dataset/api/audio.cc index ec1af151891..61752919ff9 100755 --- a/mindspore/ccsrc/minddata/dataset/api/audio.cc +++ b/mindspore/ccsrc/minddata/dataset/api/audio.cc @@ -34,6 +34,7 @@ #include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h" #include "minddata/dataset/audio/ir/kernels/lfilter_ir.h" #include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/magphase_ir.h" #include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h" #include "minddata/dataset/audio/ir/kernels/time_masking_ir.h" #include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h" @@ -316,6 +317,16 @@ std::shared_ptr LowpassBiquad::Parse() { return std::make_shared(data_->sample_rate_, data_->cutoff_freq_, data_->Q_); } +// Magphase Transform Operation. +struct Magphase::Data { + explicit Data(float power) : power_(power) {} + float power_; +}; + +Magphase::Magphase(float power) : data_(std::make_shared(power)) {} + +std::shared_ptr Magphase::Parse() { return std::make_shared(data_->power_); } + // MuLawDecoding Transform Operation. struct MuLawDecoding::Data { explicit Data(int quantization_channels) : quantization_channels_(quantization_channels) {} diff --git a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc index 0aa8149eb14..2fd87e60869 100644 --- a/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc +++ b/mindspore/ccsrc/minddata/dataset/api/python/bindings/dataset/audio/kernels/ir/bindings.cc @@ -38,6 +38,7 @@ #include "minddata/dataset/audio/ir/kernels/highpass_biquad_ir.h" #include "minddata/dataset/audio/ir/kernels/lfilter_ir.h" #include "minddata/dataset/audio/ir/kernels/lowpass_biquad_ir.h" +#include "minddata/dataset/audio/ir/kernels/magphase_ir.h" #include "minddata/dataset/audio/ir/kernels/mu_law_decoding_ir.h" #include "minddata/dataset/audio/ir/kernels/time_masking_ir.h" #include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h" @@ -259,6 +260,17 @@ PYBIND_REGISTER( })); })); +PYBIND_REGISTER(MagphaseOperation, 1, ([](const py::module *m) { + (void) + py::class_>( + *m, "MagphaseOperation") + .def(py::init([](float power) { + auto magphase = std::make_shared(power); + THROW_IF_ERROR(magphase->ValidateParams()); + return magphase; + })); + })); + PYBIND_REGISTER( MuLawDecodingOperation, 1, ([](const py::module *m) { (void)py::class_>( diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt index cd54df25de4..a2abb33523c 100755 --- a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/CMakeLists.txt @@ -20,6 +20,7 @@ add_library(audio-ir-kernels OBJECT highpass_biquad_ir.cc lfilter_ir.cc lowpass_biquad_ir.cc + magphase_ir.cc mu_law_decoding_ir.cc time_masking_ir.cc time_stretch_ir.cc diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/magphase_ir.cc b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/magphase_ir.cc new file mode 100644 index 00000000000..baca2c6fc9b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/magphase_ir.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "minddata/dataset/audio/ir/kernels/magphase_ir.h" + +#include "minddata/dataset/audio/ir/validators.h" +#include "minddata/dataset/audio/kernels/magphase_op.h" + +namespace mindspore { +namespace dataset { + +namespace audio { + +MagphaseOperation::MagphaseOperation(float power) : power_(power) {} + +Status MagphaseOperation::ValidateParams() { + RETURN_IF_NOT_OK(ValidateFloatScalarNonNegative("Magphase", "power", power_)); + return Status::OK(); +} + +std::shared_ptr MagphaseOperation::Build() { return std::make_shared(power_); } + +Status MagphaseOperation::to_json(nlohmann::json *out_json) { + nlohmann::json args; + args["power"] = power_; + *out_json = args; + return Status::OK(); +} + +Status MagphaseOperation::from_json(nlohmann::json op_params, std::shared_ptr *operation) { + CHECK_FAIL_RETURN_UNEXPECTED(op_params.find("power") != op_params.end(), "Fail to find power"); + float power = op_params["power"]; + *operation = std::make_shared(power); + return Status::OK(); +} + +} // namespace audio +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/magphase_ir.h b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/magphase_ir.h new file mode 100644 index 00000000000..9a235ce6ae3 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/audio/ir/kernels/magphase_ir.h @@ -0,0 +1,58 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MAGPHASE_IR_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MAGPHASE_IR_H_ + +#include +#include + +#include "include/api/status.h" +#include "minddata/dataset/include/dataset/constants.h" +#include "minddata/dataset/include/dataset/transforms.h" +#include "minddata/dataset/kernels/ir/tensor_operation.h" + +namespace mindspore { +namespace dataset { + +namespace audio { + +constexpr char kMagphaseOperation[] = "Magphase"; + +class MagphaseOperation : public TensorOperation { + public: + explicit MagphaseOperation(float power); + + ~MagphaseOperation() = default; + + std::shared_ptr Build() override; + + std::string Name() const override { return kMagphaseOperation; } + + Status ValidateParams() override; + + Status to_json(nlohmann::json *out_json) override; + + static Status from_json(nlohmann::json op_params, std::shared_ptr *operation); + + private: + float power_; +}; + +} // namespace audio +} // namespace dataset +} // namespace mindspore +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_IR_KERNELS_MAGPHASE_IR_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt index 3cd09cec7e5..6ad8faa0616 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/CMakeLists.txt @@ -21,6 +21,7 @@ add_library(audio-kernels OBJECT highpass_biquad_op.cc lfilter_op.cc lowpass_biquad_op.cc + magphase_op.cc mu_law_decoding_op.cc time_masking_op.cc time_stretch_op.cc diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.cc index 45397fae18d..c8438da2216 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.cc +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.cc @@ -630,5 +630,23 @@ Status Fade(const std::shared_ptr &input, std::shared_ptr *outpu } return Status::OK(); } + +Status Magphase(const TensorRow &input, TensorRow *output, float power) { + std::shared_ptr mag; + std::shared_ptr phase; + + RETURN_IF_NOT_OK(ComplexNorm(input[0], &mag, power)); + if (input[0]->type() == DataType(DataType::DE_FLOAT64)) { + RETURN_IF_NOT_OK(Angle(input[0], &phase)); + } else { + std::shared_ptr tmp; + RETURN_IF_NOT_OK(TypeCast(input[0], &tmp, DataType(DataType::DE_FLOAT32))); + RETURN_IF_NOT_OK(Angle(tmp, &phase)); + } + (*output).push_back(mag); + (*output).push_back(phase); + + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h index db0c8f8f2a0..c3edcc40f9d 100644 --- a/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/audio_utils.h @@ -365,6 +365,13 @@ Status Vol(const std::shared_ptr &input, std::shared_ptr *output return Status::OK(); } + +/// \brief Separate a complex-valued spectrogram with shape (…, 2) into its magnitude and phase. +/// \param input: Complex tensor. +/// \param output: The magnitude and phase of the complex tensor. +/// \param power: Power of the norm. +Status Magphase(const TensorRow &input, TensorRow *output, float power); + } // namespace dataset } // namespace mindspore #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_AUDIO_UTILS_H_ diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/magphase_op.cc b/mindspore/ccsrc/minddata/dataset/audio/kernels/magphase_op.cc new file mode 100644 index 00000000000..68cee5f11db --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/magphase_op.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "minddata/dataset/audio/kernels/magphase_op.h" + +#include "minddata/dataset/audio/kernels/audio_utils.h" +#include "minddata/dataset/util/status.h" + +namespace mindspore { +namespace dataset { + +constexpr float MagphaseOp::kPower = 1.0; + +Status MagphaseOp::Compute(const TensorRow &input, TensorRow *output) { + IO_CHECK_VECTOR(input, output); + CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2 && input[0]->shape()[-1] == 2, + "Magphase: input tensor is not in shape of <..., 2>."); + CHECK_FAIL_RETURN_UNEXPECTED( + input[0]->type().IsNumeric(), + "Magphase: input tensor type should be int, float or double, but got: " + input[0]->type().ToString()); + RETURN_IF_NOT_OK(Magphase(input, output, power_)); + return Status::OK(); +} + +Status MagphaseOp::OutputShape(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + auto vec = inputs[0].AsVector(); + vec.pop_back(); + auto out = TensorShape(vec); + outputs = {out, out}; + if (!outputs.empty()) { + return Status::OK(); + } + return Status(StatusCode::kMDUnexpectedError, "Magphase: invalid input wrong shape."); +} + +Status MagphaseOp::OutputType(const std::vector &inputs, std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); + if (inputs[0] == DataType(DataType::DE_FLOAT64)) { + outputs[0] = DataType(DataType::DE_FLOAT64); + } else { + outputs[0] = DataType(DataType::DE_FLOAT32); + } + return Status::OK(); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/audio/kernels/magphase_op.h b/mindspore/ccsrc/minddata/dataset/audio/kernels/magphase_op.h new file mode 100644 index 00000000000..bf4e1a86c1c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/audio/kernels/magphase_op.h @@ -0,0 +1,52 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MAGPHASE_OP_H_ +#define MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MAGPHASE_OP_H_ + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { + +class MagphaseOp : public TensorOp { + public: + static const float kPower; + + explicit MagphaseOp(float power = kPower) : power_(power) {} + + ~MagphaseOp() override = default; + + Status Compute(const TensorRow &input, TensorRow *output) override; + + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + Status OutputType(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kMagphaseOp; } + + private: + float power_; +}; + +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_AUDIO_KERNELS_MAGPHASE_OP_H_ diff --git a/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h b/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h index 1aadb74923b..3d4b50177a5 100755 --- a/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h +++ b/mindspore/ccsrc/minddata/dataset/include/dataset/audio.h @@ -443,6 +443,26 @@ class LowpassBiquad final : public TensorTransform { std::shared_ptr data_; }; +/// \brief Separate a complex-valued spectrogram with shape (..., 2) into its magnitude and phase. +class Magphase final : public TensorTransform { + public: + /// \brief Constructor. + /// \param[in] power Power of the norm, which must be non-negative (Default: 1.0). + explicit Magphase(float power); + + /// \brief Destructor. + ~Magphase() = default; + + protected: + /// \brief Function to convert TensorTransform object into a TensorOperation object. + /// \return Shared pointer to TensorOperation object. + std::shared_ptr Parse() override; + + private: + struct Data; + std::shared_ptr data_; +}; + /// \brief MuLawDecoding TensorTransform. /// \note Decode mu-law encoded signal. class MuLawDecoding final : public TensorTransform { diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h index e32a8cece33..a7711604ee5 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h @@ -157,6 +157,7 @@ constexpr char kFrequencyMaskingOp[] = "FrequencyMaskingOp"; constexpr char kHighpassBiquadOp[] = "HighpassBiquadOp"; constexpr char kLFilterOp[] = "LFilterOp"; constexpr char kLowpassBiquadOp[] = "LowpassBiquadOp"; +constexpr char kMagphaseOp[] = "MagphaseOp"; constexpr char kMuLawDecodingOp[] = "MuLawDecodingOp"; constexpr char kTimeMaskingOp[] = "TimeMaskingOp"; constexpr char kTimeStretchOp[] = "TimeStretchOp"; diff --git a/mindspore/dataset/audio/transforms.py b/mindspore/dataset/audio/transforms.py index a78dca34d64..a2a40fd6baa 100755 --- a/mindspore/dataset/audio/transforms.py +++ b/mindspore/dataset/audio/transforms.py @@ -27,7 +27,7 @@ from .utils import FadeShape, GainType, ScaleType from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_biquad, check_bandpass_biquad, \ check_bandreject_biquad, check_bass_biquad, check_biquad, check_complex_norm, check_contrast, check_dc_shift, \ check_deemph_biquad, check_equalizer_biquad, check_fade, check_highpass_biquad, check_lfilter, \ - check_lowpass_biquad, check_masking, check_mu_law_decoding, check_time_stretch, check_vol + check_lowpass_biquad, check_magphase, check_masking, check_mu_law_decoding, check_time_stretch, check_vol class AudioTensorOperation(TensorOperation): @@ -579,6 +579,30 @@ class LowpassBiquad(AudioTensorOperation): return cde.LowpassBiquadOperation(self.sample_rate, self.cutoff_freq, self.Q) +class Magphase(AudioTensorOperation): + """ + Separate a complex-valued spectrogram with shape (..., 2) into its magnitude and phase. + + Args: + power (float): Power of the norm, which must be non-negative (default=1.0). + + Examples: + >>> import numpy as np + >>> + >>> waveform = np.random.random([2, 4, 2]) + >>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"]) + >>> transforms = [audio.Magphase()] + >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"]) + """ + + @check_magphase + def __init__(self, power=1.0): + self.power = power + + def parse(self): + return cde.MagphaseOperation(self.power) + + class MuLawDecoding(AudioTensorOperation): """ Decode mu-law encoded signal. diff --git a/mindspore/dataset/audio/validators.py b/mindspore/dataset/audio/validators.py index 7cf70783252..81b138ed30e 100755 --- a/mindspore/dataset/audio/validators.py +++ b/mindspore/dataset/audio/validators.py @@ -333,13 +333,30 @@ def check_masking(method): return new_method +def check_power(power): + """Wrapper method to check the parameters of power.""" + type_check(power, (int, float), "power") + check_non_negative_float32(power, "power") + + def check_complex_norm(method): """Wrapper method to check the parameters of ComplexNorm.""" @wraps(method) def new_method(self, *args, **kwargs): [power], _ = parse_user_args(method, *args, **kwargs) - type_check(power, (int, float), "power") - check_non_negative_float32(power, "power") + check_power(power) + return method(self, *args, **kwargs) + + return new_method + + +def check_magphase(method): + """Wrapper method to check the parameters of Magphase.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + [power], _ = parse_user_args(method, *args, **kwargs) + check_power(power) return method(self, *args, **kwargs) return new_method diff --git a/tests/ut/cpp/dataset/c_api_audio_a_to_q_test.cc b/tests/ut/cpp/dataset/c_api_audio_a_to_q_test.cc index e2799e23bd1..2cd31f57172 100644 --- a/tests/ut/cpp/dataset/c_api_audio_a_to_q_test.cc +++ b/tests/ut/cpp/dataset/c_api_audio_a_to_q_test.cc @@ -1343,3 +1343,45 @@ TEST_F(MindDataTestPipeline, TestFadeWithInvalidArg) { std::shared_ptr iter_02 = ds_02->CreateIterator(); EXPECT_EQ(iter_02, nullptr); } + +TEST_F(MindDataTestPipeline, TestMagphase) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMagphase."; + + float power = 2.0; + std::shared_ptr schema = Schema(); + ASSERT_OK(schema->add_column("col1", mindspore::DataType::kNumberTypeFloat32, {1, 2})); + std::shared_ptr ds = RandomData(8, schema); + EXPECT_NE(ds, nullptr); + std::shared_ptr magphase(new audio::Magphase(power)); + auto ds1 = ds->Map({magphase}, {"col1"}, {"mag", "phase"}); + EXPECT_NE(ds1, nullptr); + std::shared_ptr iter = ds1->CreateIterator(); + EXPECT_NE(iter, nullptr); + std::unordered_map row; + ASSERT_OK(iter->GetNextRow(&row)); + uint64_t i = 0; + while (row.size() != 0) { + i++; + ASSERT_OK(iter->GetNextRow(&row)); + } + EXPECT_EQ(i, 8); + iter->Stop(); +} + +TEST_F(MindDataTestPipeline, TestMagphaseWrongArgs) { + MS_LOG(INFO) << "Doing MindDataTestPipeline-TestMagphaseWrongArgs."; + + float power_wrong = -1.0; + std::shared_ptr magphase(new audio::Magphase(power_wrong)); + std::unordered_map row; + + //Magphase: power must be greater than or equal to 0. + std::shared_ptr schema = Schema(); + ASSERT_OK(schema->add_column("col1", mindspore::DataType::kNumberTypeFloat32, {2, 2})); + std::shared_ptr ds = RandomData(8, schema); + EXPECT_NE(ds, nullptr); + ds = ds->Map({magphase}, {"col1"}, {"mag", "phase"}); + EXPECT_NE(ds, nullptr); + std::shared_ptr iter = ds->CreateIterator(); + EXPECT_EQ(iter, nullptr); +} diff --git a/tests/ut/cpp/dataset/execute_test.cc b/tests/ut/cpp/dataset/execute_test.cc index 61a826443d9..b4bf81e0d1b 100644 --- a/tests/ut/cpp/dataset/execute_test.cc +++ b/tests/ut/cpp/dataset/execute_test.cc @@ -1082,3 +1082,18 @@ TEST_F(MindDataTestExecute, TestVolGainTypePower) { Status status = transform(input_tensor, &input_tensor); EXPECT_TRUE(status.IsOk()); } + +TEST_F(MindDataTestExecute, TestMagphaseEager) { + MS_LOG(INFO) << "Doing MindDataTestExecute-TestMagphaseEager."; + float power = 1.0; + std::vector output_tensor; + std::shared_ptr test; + std::vector test_vector = {3, 4, -3, 4, 3, -4, -3, -4, + 5, 12, -5, 12, 5, -12, -5, -12}; + Tensor::CreateFromVector(test_vector, TensorShape({2, 4, 2}), &test); + auto input_tensor = mindspore::MSTensor(std::make_shared(test)); + std::shared_ptr magphase(new audio::Magphase({power})); + auto transform = Execute({magphase}); + Status rc = transform({input_tensor}, &output_tensor); + ASSERT_TRUE(rc.IsOk()); +} diff --git a/tests/ut/python/dataset/test_magphase.py b/tests/ut/python/dataset/test_magphase.py new file mode 100644 index 00000000000..bd6daf135aa --- /dev/null +++ b/tests/ut/python/dataset/test_magphase.py @@ -0,0 +1,104 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Testing Magphase Python API +""" +import numpy as np + +import mindspore.dataset as ds +import mindspore.dataset.audio.transforms as audio +from mindspore import log as logger + + +def test_magphase_pipeline(): + """ + Test magphase (pipeline). + """ + logger.info("Test Magphase pipeline.") + + data1 = [[[3.0, -4.0], [-5.0, 12.0]]] + expected = [5, 13, -0.927295, 1.965587] + dataset = ds.NumpySlicesDataset(data1, column_names=["col1"], shuffle=False) + magphase_window = audio.Magphase(power=1.0) + dataset = dataset.map(operations=magphase_window, input_columns=["col1"], + output_columns=["mag", "phase"], column_order=["mag", "phase"]) + for data1, data2 in dataset.create_tuple_iterator(num_epochs=1, output_numpy=True): + assert abs(data1[0] - expected[0]) < 0.00001 + assert abs(data1[1] - expected[1]) < 0.00001 + assert abs(data2[0] - expected[2]) < 0.00001 + assert abs(data2[1] - expected[3]) < 0.00001 + + logger.info("Finish testing Magphase.") + + +def test_magphase_eager(): + """ + Test magphase (eager). + """ + logger.info("Test Magphase eager.") + + input_number = np.array([41, 67, 34, 0, 69, 24, 78, 58]).reshape((2, 2, 2)).astype("double") + mag = np.array([78.54934755, 34., 73.05477397, 97.20082304]).reshape((2, 2)).astype("double") + phase = np.array([1.02164342, 0, 0.33473684, 0.63938591]).reshape((2, 2)).astype("double") + magphase_window = audio.Magphase() + data1, data2 = magphase_window(input_number) + assert (abs(data1 - mag) < 0.00001).all() + assert (abs(data2 - phase) < 0.00001).all() + + logger.info("Finish testing Magphase.") + + +def test_magphase_exception(): + """ + Test magphase not callable. + """ + logger.info("Test Magphase not callable.") + + try: + input_number = np.array([1, 2, 3, 4]).reshape(4,).astype("double") + magphase_window = audio.Magphase(power=2.0) + _ = magphase_window(input_number) + except RuntimeError as error: + logger.info("Got an exception in Magphase: {}".format(str(error))) + assert "Magphase: input tensor is not in shape of <..., 2>." in str(error) + try: + input_number = np.array([1, 2, 3, 4]).reshape(1, 4).astype("double") + magphase_window = audio.Magphase(power=2.0) + _ = magphase_window(input_number) + except RuntimeError as error: + logger.info("Got an exception in Magphase: {}".format(str(error))) + assert "Magphase: input tensor is not in shape of <..., 2>." in str(error) + try: + input_number = np.array(['test', 'test']).reshape(1, 2) + magphase_window = audio.Magphase(power=2.0) + _ = magphase_window(input_number) + except RuntimeError as error: + logger.info("Got an exception in Magphase: {}".format(str(error))) + assert "Magphase: input tensor type should be int, float or double" in str(error) + try: + input_number = np.array([1, 2, 3, 4]).reshape(2, 2).astype("double") + magphase_window = audio.Magphase(power=-1.0) + _ = magphase_window(input_number) + except ValueError as error: + logger.info("Got an exception in Magphase: {}".format(str(error))) + assert "Input power is not within the required interval of [0, 16777216]." in str(error) + + logger.info("Finish testing Magphase.") + + +if __name__ == "__main__": + test_magphase_pipeline() + test_magphase_eager() + test_magphase_exception()