forked from mindspore-Ecosystem/mindspore
[feat][assistant][I3J6U2] Add new audio operator CreateDct
This commit is contained in:
parent
e37c4fc246
commit
4f4777f0da
|
@ -2,6 +2,7 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc"
|
|||
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
|
||||
if(ENABLE_PYTHON)
|
||||
add_library(APItoPython OBJECT
|
||||
python/bindings/dataset/audio/bindings.cc
|
||||
python/bindings/dataset/audio/kernels/ir/bindings.cc
|
||||
python/bindings/dataset/callback/bindings.cc
|
||||
python/bindings/dataset/core/bindings.cc
|
||||
|
|
|
@ -44,6 +44,7 @@
|
|||
#include "minddata/dataset/audio/ir/kernels/time_stretch_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/treble_biquad_ir.h"
|
||||
#include "minddata/dataset/audio/ir/kernels/vol_ir.h"
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -209,6 +210,18 @@ std::shared_ptr<TensorOperation> DCShift::Parse() {
|
|||
return std::make_shared<DCShiftOperation>(data_->shift_, data_->limiter_gain_);
|
||||
}
|
||||
|
||||
Status CreateDct(mindspore::MSTensor *output, int32_t n_mfcc, int32_t n_mels, NormMode norm) {
|
||||
RETURN_UNEXPECTED_IF_NULL(output);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(n_mfcc > 0, "CreateDct: n_mfcc must be greater than 0, got: " + std::to_string(n_mfcc));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(n_mels > 0, "CreateDct: n_mels must be greater than 0, got: " + std::to_string(n_mels));
|
||||
|
||||
std::shared_ptr<dataset::Tensor> dct;
|
||||
RETURN_IF_NOT_OK(Dct(&dct, n_mfcc, n_mels, norm));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(dct->HasData(), "CreateDct: get an empty tensor with shape " + dct->shape().ToString());
|
||||
*output = mindspore::MSTensor(std::make_shared<DETensor>(dct));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// DeemphBiquad Transform Operation.
|
||||
struct DeemphBiquad::Data {
|
||||
explicit Data(int32_t sample_rate) : sample_rate_(sample_rate) {}
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "pybind11/pybind11.h"
|
||||
|
||||
#include "minddata/dataset/api/python/pybind_conversion.h"
|
||||
#include "minddata/dataset/api/python/pybind_register.h"
|
||||
#include "minddata/dataset/audio/kernels/audio_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
PYBIND_REGISTER(CreateDct, 1, ([](py::module *m) {
|
||||
(void)m->def("CreateDct", ([](int32_t n_mfcc, int32_t n_mels, NormMode norm) {
|
||||
std::shared_ptr<Tensor> out;
|
||||
THROW_IF_ERROR(Dct(&out, n_mfcc, n_mels, norm));
|
||||
return out;
|
||||
}));
|
||||
}));
|
||||
|
||||
PYBIND_REGISTER(NormMode, 0, ([](const py::module *m) {
|
||||
(void)py::enum_<NormMode>(*m, "NormMode", py::arithmetic())
|
||||
.value("DE_NORMMODE_NONE", NormMode::kNone)
|
||||
.value("DE_NORMMODE_ORTHO", NormMode::kOrtho)
|
||||
.export_values();
|
||||
}));
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
|
@ -348,6 +348,34 @@ Status TimeStretch(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor>
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Dct(std::shared_ptr<Tensor> *output, int n_mfcc, int n_mels, NormMode norm) {
|
||||
TensorShape dct_shape({n_mels, n_mfcc});
|
||||
Tensor::CreateEmpty(dct_shape, DataType(DataType::DE_FLOAT32), output);
|
||||
auto iter = (*output)->begin<float>();
|
||||
float sqrt_2 = 1 / sqrt(2);
|
||||
float sqrt_2_n_mels = sqrt(2.0 / n_mels);
|
||||
for (int i = 0; i < n_mels; i++) {
|
||||
for (int j = 0; j < n_mfcc; j++) {
|
||||
// calculate temp:
|
||||
// 1. while norm = None, use 2*cos(PI*(i+0.5)*j/n_mels)
|
||||
// 2. while norm = Ortho, divide the first row by sqrt(2),
|
||||
// then using sqrt(2.0 / n_mels)*cos(PI*(i+0.5)*j/n_mels)
|
||||
float temp = PI / n_mels * (i + 0.5) * j;
|
||||
temp = cos(temp);
|
||||
if (norm == NormMode::kOrtho) {
|
||||
if (j == 0) {
|
||||
temp *= sqrt_2;
|
||||
}
|
||||
temp *= sqrt_2_n_mels;
|
||||
} else {
|
||||
temp *= 2;
|
||||
}
|
||||
(*iter++) = temp;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RandomMaskAlongAxis(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t mask_param,
|
||||
float mask_value, int axis, std::mt19937 rnd) {
|
||||
std::uniform_int_distribution<int32_t> mask_width_value(0, mask_param);
|
||||
|
|
|
@ -304,6 +304,13 @@ Status RandomMaskAlongAxis(const std::shared_ptr<Tensor> &input, std::shared_ptr
|
|||
Status MaskAlongAxis(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int32_t mask_width,
|
||||
int32_t mask_start, float mask_value, int32_t axis);
|
||||
|
||||
/// \brief Create a DCT transformation matrix with shape (n_mels, n_mfcc), normalized depending on norm.
|
||||
/// \param n_mfcc: Number of mfc coefficients to retain, the value must be greater than 0.
|
||||
/// \param n_mels: Number of mel filterbanks, the value must be greater than 0.
|
||||
/// \param norm: Norm to use, can be NormMode::kNone or NormMode::kOrtho.
|
||||
/// \return Status code.
|
||||
Status Dct(std::shared_ptr<Tensor> *output, int32_t n_mfcc, int32_t n_mels, NormMode norm);
|
||||
|
||||
/// \brief Compute the norm of complex tensor input.
|
||||
/// \param power Power of the norm description (optional).
|
||||
/// \param input Tensor shape of <..., complex=2>.
|
||||
|
|
|
@ -25,8 +25,9 @@
|
|||
|
||||
#include "include/api/dual_abi_helper.h"
|
||||
#include "include/api/status.h"
|
||||
#include "minddata/dataset/include/dataset/constants.h"
|
||||
#include "minddata/dataset/include/dataset/transforms.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/dataset/constants.h"
|
||||
#include "include/dataset/transforms.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
@ -281,6 +282,12 @@ class DCShift : public TensorTransform {
|
|||
std::shared_ptr<Data> data_;
|
||||
};
|
||||
|
||||
/// \param[in] n_mfcc Number of mfc coefficients to retain, the value must be greater than 0.
|
||||
/// \param[in] n_mels Number of mel filterbanks, the value must be greater than 0.
|
||||
/// \param[in] norm Norm to use, can be NormMode::kNone or NormMode::kOrtho.
|
||||
/// \return Status error code, returns OK if no error encountered.
|
||||
Status CreateDct(mindspore::MSTensor *output, int32_t n_mfcc, int32_t n_mels, NormMode norm = NormMode::kNone);
|
||||
|
||||
/// \brief Design two-pole deemph filter. Similar to SoX implementation.
|
||||
class DeemphBiquad final : public TensorTransform {
|
||||
public:
|
||||
|
|
|
@ -62,6 +62,12 @@ enum class ConvertMode {
|
|||
COLOR_RGBA2GRAY = 11 ///< Convert RGBA image to GRAY image.
|
||||
};
|
||||
|
||||
/// \brief Values of norm in CreateDct.
|
||||
enum class NormMode {
|
||||
kNone = 0, ///< None type norm.
|
||||
kOrtho = 1 ///< Ortho type norm.
|
||||
};
|
||||
|
||||
/// \brief Target devices to perform map operation.
|
||||
enum class MapTargetDevice {
|
||||
kCpu, ///< CPU Device.
|
||||
|
|
|
@ -17,6 +17,7 @@ enum for audio ops
|
|||
"""
|
||||
|
||||
from enum import Enum
|
||||
import mindspore._c_dataengine as cde
|
||||
|
||||
|
||||
class FadeShape(str, Enum):
|
||||
|
@ -91,3 +92,52 @@ class ScaleType(str, Enum):
|
|||
"""
|
||||
POWER: str = "power"
|
||||
MAGNITUDE: str = "magnitude"
|
||||
|
||||
|
||||
class NormMode(str, Enum):
|
||||
"""
|
||||
Norm Types.
|
||||
|
||||
Possible enumeration values are: NormMode.NONE, NormMode.ORTHO.
|
||||
|
||||
- NormMode.NONE: means the mode of input audio is none.
|
||||
- NormMode.ORTHO: means the mode of input audio is ortho.
|
||||
"""
|
||||
NONE: str = "none"
|
||||
ORTHO: str = "ortho"
|
||||
|
||||
|
||||
DE_C_NORMMODE_TYPE = {NormMode.NONE: cde.NormMode.DE_NORMMODE_NONE,
|
||||
NormMode.ORTHO: cde.NormMode.DE_NORMMODE_ORTHO}
|
||||
|
||||
|
||||
def CreateDct(n_mfcc, n_mels, norm=NormMode.NONE):
|
||||
"""
|
||||
Create a DCT transformation matrix with shape (n_mels, n_mfcc), normalized depending on norm.
|
||||
|
||||
Args:
|
||||
n_mfcc (int): Number of mfc coefficients to retain, the value must be greater than 0.
|
||||
n_mels (int): Number of mel filterbanks, the value must be greater than 0.
|
||||
norm (NormMode): Normalization mode, can be NormMode.NONE or NormMode.ORTHO (default=NormMode.NONE).
|
||||
|
||||
Returns:
|
||||
numpy.ndarray, the transformation matrix, to be right-multiplied to row-wise data of size (n_mels, n_mfcc).
|
||||
|
||||
Examples:
|
||||
>>> dct = audio.CreateDct(100, 200, audio.NormMode.NONE)
|
||||
"""
|
||||
|
||||
if not isinstance(n_mfcc, int):
|
||||
raise TypeError("n_mfcc with value {0} is not of type {1}, but got {2}.".format(
|
||||
n_mfcc, int, type(n_mfcc)))
|
||||
if not isinstance(n_mels, int):
|
||||
raise TypeError("n_mels with value {0} is not of type {1}, but got {2}.".format(
|
||||
n_mels, int, type(n_mels)))
|
||||
if not isinstance(norm, NormMode):
|
||||
raise TypeError("norm with value {0} is not of type {1}, but got {2}.".format(
|
||||
norm, NormMode, type(norm)))
|
||||
if n_mfcc <= 0:
|
||||
raise ValueError("n_mfcc must be greater than 0, but got {0}.".format(n_mfcc))
|
||||
if n_mels <= 0:
|
||||
raise ValueError("n_mels must be greater than 0, but got {0}.".format(n_mels))
|
||||
return cde.CreateDct(n_mfcc, n_mels, DE_C_NORMMODE_TYPE[norm]).as_array()
|
||||
|
|
|
@ -1662,3 +1662,39 @@ TEST_F(MindDataTestPipeline, TestFlangerParamCheck) {
|
|||
std::shared_ptr<Iterator> iterPhase = dsPhase->CreateIterator();
|
||||
EXPECT_EQ(iterPhase, nullptr);
|
||||
}
|
||||
|
||||
/// Feature: CreateDct
|
||||
/// Description: test CreateDct in eager mode
|
||||
/// Expectation: the returned result is as expected
|
||||
TEST_F(MindDataTestPipeline, TestCreateDctNone) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCreateDctNone.";
|
||||
mindspore::MSTensor output;
|
||||
Status s01 = audio::CreateDct(&output, 200, 400, NormMode::kNone);
|
||||
EXPECT_TRUE(s01.IsOk());
|
||||
}
|
||||
|
||||
/// Feature: CreateDct
|
||||
/// Description: test CreateDct in eager mode
|
||||
/// Expectation: the returned result is as expected
|
||||
TEST_F(MindDataTestPipeline, TestCreateDctOrtho) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCreateDctOrtho.";
|
||||
mindspore::MSTensor output;
|
||||
Status s02 = audio::CreateDct(&output, 200, 400, NormMode::kOrtho);
|
||||
EXPECT_TRUE(s02.IsOk());
|
||||
}
|
||||
|
||||
/// Feature: CreateDct
|
||||
/// Description: test WrongArg of CreateDct
|
||||
/// Expectation: return error
|
||||
TEST_F(MindDataTestPipeline, TestCreateDctWrongArg) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCreateDctWrongArg.";
|
||||
mindspore::MSTensor output;
|
||||
// Check n_mfcc
|
||||
MS_LOG(INFO) << "n_mfcc is negative.";
|
||||
Status s03 = audio::CreateDct(&output, -200, 400, NormMode::kNone);
|
||||
EXPECT_FALSE(s03.IsOk());
|
||||
// Check n_mels
|
||||
MS_LOG(INFO) << "n_mels is negative.";
|
||||
Status s04 = audio::CreateDct(&output, 200, -400, NormMode::kOrtho);
|
||||
EXPECT_FALSE(s04.IsOk());
|
||||
}
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.dataset.audio.utils as audio
|
||||
from mindspore import log as logger
|
||||
|
||||
|
||||
def count_unequal_element(data_expected, data_me, rtol, atol):
|
||||
assert data_expected.shape == data_me.shape
|
||||
total_count = len(data_expected.flatten())
|
||||
error = np.abs(data_expected - data_me)
|
||||
greater = np.greater(error, atol + np.abs(data_expected) * rtol)
|
||||
loss_count = np.count_nonzero(greater)
|
||||
assert (loss_count / total_count) < rtol, \
|
||||
"\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}". \
|
||||
format(data_expected[greater], data_me[greater], error[greater])
|
||||
|
||||
|
||||
def test_create_dct_none():
|
||||
"""
|
||||
Feature: CreateDct
|
||||
Description: test CreateDct in eager mode
|
||||
Expectation: the returned result is as expected
|
||||
"""
|
||||
expect = np.array([[2.00000000, 1.84775901],
|
||||
[2.00000000, 0.76536685],
|
||||
[2.00000000, -0.76536703],
|
||||
[2.00000000, -1.84775925]], dtype=np.float64)
|
||||
output = audio.CreateDct(2, 4, audio.NormMode.NONE)
|
||||
count_unequal_element(expect, output, 0.0001, 0.0001)
|
||||
|
||||
|
||||
def test_create_dct_ortho():
|
||||
"""
|
||||
Feature: CreateDct
|
||||
Description: test CreateDct in eager mode
|
||||
Expectation: the returned result is as expected
|
||||
"""
|
||||
output = audio.CreateDct(1, 3, audio.NormMode.ORTHO)
|
||||
expect = np.array([[0.57735026],
|
||||
[0.57735026],
|
||||
[0.57735026]], dtype=np.float64)
|
||||
count_unequal_element(expect, output, 0.0001, 0.0001)
|
||||
|
||||
|
||||
def test_createdct_invalid_input():
|
||||
"""
|
||||
Feature: CreateDct
|
||||
Description: Error detection
|
||||
Expectation: return error
|
||||
"""
|
||||
def test_invalid_input(test_name, n_mfcc, n_mels, norm, error, error_msg):
|
||||
logger.info("Test CreateDct with bad input: {0}".format(test_name))
|
||||
with pytest.raises(error) as error_info:
|
||||
audio.CreateDct(n_mfcc, n_mels, norm)
|
||||
assert error_msg in str(error_info.value)
|
||||
|
||||
test_invalid_input("invalid n_mfcc parameter type as a float", 100.5, 200, audio.NormMode.NONE, TypeError,
|
||||
"n_mfcc with value 100.5 is not of type <class 'int'>, but got <class 'float'>.")
|
||||
test_invalid_input("invalid n_mfcc parameter type as a String", "100", 200, audio.NormMode.NONE, TypeError,
|
||||
"n_mfcc with value 100 is not of type <class 'int'>, but got <class 'str'>.")
|
||||
test_invalid_input("invalid n_mels parameter type as a String", 100, "200", audio.NormMode.NONE, TypeError,
|
||||
"n_mels with value 200 is not of type <class 'int'>, but got <class 'str'>.")
|
||||
test_invalid_input("invalid n_mels parameter type as a String", 0, 200, audio.NormMode.NONE, ValueError,
|
||||
"n_mfcc must be greater than 0, but got 0.")
|
||||
test_invalid_input("invalid n_mels parameter type as a String", 100, 0, audio.NormMode.NONE, ValueError,
|
||||
"n_mels must be greater than 0, but got 0.")
|
||||
test_invalid_input("invalid n_mels parameter type as a String", -100, 200, audio.NormMode.NONE, ValueError,
|
||||
"n_mfcc must be greater than 0, but got -100.")
|
||||
test_invalid_input("invalid n_mfcc parameter value", None, 100, audio.NormMode.NONE, TypeError,
|
||||
"n_mfcc with value None is not of type <class 'int'>, but got <class 'NoneType'>.")
|
||||
test_invalid_input("invalid n_mels parameter value", 100, None, audio.NormMode.NONE, TypeError,
|
||||
"n_mels with value None is not of type <class 'int'>, but got <class 'NoneType'>.")
|
||||
test_invalid_input("invalid n_mels parameter value", 100, 200, "None", TypeError,
|
||||
"norm with value None is not of type <enum 'NormMode'>, but got <class 'str'>.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_create_dct_none()
|
||||
test_create_dct_ortho()
|
||||
test_createdct_invalid_input()
|
Loading…
Reference in New Issue