diff --git a/mindspore/dataset/audio/__init__.py b/mindspore/dataset/audio/__init__.py index ceb55c8f64f..e0d76830106 100644 --- a/mindspore/dataset/audio/__init__.py +++ b/mindspore/dataset/audio/__init__.py @@ -23,7 +23,7 @@ Common imported modules in corresponding API examples are as follows: .. code-block:: import mindspore.dataset as ds - from mindspore.dataset import audio + import mindspore.dataset.audio.transforms as audio """ from . import transforms from . import utils diff --git a/mindspore/dataset/audio/transforms.py b/mindspore/dataset/audio/transforms.py index ddd74b26ebf..0d9476b8982 100644 --- a/mindspore/dataset/audio/transforms.py +++ b/mindspore/dataset/audio/transforms.py @@ -33,7 +33,7 @@ from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_ class AudioTensorOperation(TensorOperation): """ - Base class of Audio Tensor Ops + Base class of Audio Tensor Ops. """ def __call__(self, *input_tensor_list): @@ -266,6 +266,7 @@ class Biquad(TensorOperation): >>> biquad_op = audio.Biquad(0.01, 0.02, 0.13, 1, 0.12, 0.3) >>> waveform_filtered = biquad_op(waveform) """ + @check_biquad def __init__(self, b0, b1, b2, a0, a1, a2): self.b0 = b0 @@ -294,6 +295,7 @@ class ComplexNorm(AudioTensorOperation): >>> transforms = [audio.ComplexNorm()] >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"]) """ + @check_complex_norm def __init__(self, power=1.0): self.power = power @@ -360,7 +362,7 @@ class DeemphBiquad(AudioTensorOperation): Design two-pole deemph filter for audio waveform of dimension of (..., time). Args: - Sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz), + sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz), the value must be 44100 or 48000. Examples: @@ -371,6 +373,7 @@ class DeemphBiquad(AudioTensorOperation): >>> transforms = [audio.DeemphBiquad(44100)] >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"]) """ + @check_deemph_biquad def __init__(self, sample_rate): self.sample_rate = sample_rate @@ -501,7 +504,6 @@ class Fade(AudioTensorOperation): DE_C_MODULATION_TYPE = {Modulation.SINUSOIDAL: cde.Modulation.DE_MODULATION_SINUSOIDAL, Modulation.TRIANGULAR: cde.Modulation.DE_MODULATION_TRIANGULAR} - DE_C_INTERPOLATION_TYPE = {Interpolation.LINEAR: cde.Interpolation.DE_INTERPOLATION_LINEAR, Interpolation.QUADRATIC: cde.Interpolation.DE_INTERPOLATION_QUADRATIC} @@ -571,6 +573,7 @@ class FrequencyMasking(AudioTensorOperation): >>> transforms = [audio.FrequencyMasking(frequency_mask_param=1)] >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"]) """ + @check_masking def __init__(self, iid_masks=False, frequency_mask_param=0, mask_start=0, mask_value=0.0): self.iid_masks = iid_masks @@ -634,6 +637,7 @@ class LFilter(AudioTensorOperation): >>> transforms = [audio.LFilter(a_coeffs, b_coeffs)] >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"]) """ + @check_lfilter def __init__(self, a_coeffs, b_coeffs, clamp=True): self.a_coeffs = a_coeffs @@ -662,6 +666,7 @@ class LowpassBiquad(AudioTensorOperation): >>> transforms = [audio.LowpassBiquad(4000, 1500, 0.7)] >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"]) """ + @check_lowpass_biquad def __init__(self, sample_rate, cutoff_freq, Q=0.707): self.sample_rate = sample_rate @@ -846,6 +851,7 @@ class TimeStretch(AudioTensorOperation): >>> transforms = [audio.TimeStretch()] >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"]) """ + @check_time_stretch def __init__(self, hop_length=None, n_freq=201, fixed_rate=None): self.n_freq = n_freq diff --git a/mindspore/dataset/audio/utils.py b/mindspore/dataset/audio/utils.py index 09c344bcc95..c92f5480d3d 100644 --- a/mindspore/dataset/audio/utils.py +++ b/mindspore/dataset/audio/utils.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== """ -enum for audio ops +Enum for audio ops. """ from enum import Enum diff --git a/mindspore/dataset/audio/validators.py b/mindspore/dataset/audio/validators.py index 6d463fe171b..8fe86d0a789 100644 --- a/mindspore/dataset/audio/validators.py +++ b/mindspore/dataset/audio/validators.py @@ -31,24 +31,17 @@ def check_amplitude_to_db(method): def new_method(self, *args, **kwargs): [stype, ref_value, amin, top_db], _ = parse_user_args(method, *args, **kwargs) - # type check stype type_check(stype, (ScaleType,), "stype") - # type check ref_value type_check(ref_value, (int, float), "ref_value") - # value check ref_value if ref_value is not None: check_pos_float32(ref_value, "ref_value") - # type check amin type_check(amin, (int, float), "amin") - # value check amin if amin is not None: check_pos_float32(amin, "amin") - # type check top_db type_check(top_db, (int, float), "top_db") - # value check top_db if top_db is not None: check_pos_float32(top_db, "top_db") @@ -246,7 +239,7 @@ def check_equalizer_biquad(method): def check_lfilter(method): - """Wrapper method to check the parameters of lfilter.""" + """Wrapper method to check the parameters of LFilter.""" @wraps(method) def new_method(self, *args, **kwargs): @@ -361,7 +354,7 @@ def check_treble_biquad(method): def check_masking(method): - """Wrapper method to check the parameters of time_masking and FrequencyMasking""" + """Wrapper method to check the parameters of TimeMasking and FrequencyMasking""" @wraps(method) def new_method(self, *args, **kwargs): @@ -455,9 +448,7 @@ def check_vol(method): @wraps(method) def new_method(self, *args, **kwargs): [gain, gain_type], _ = parse_user_args(method, *args, **kwargs) - # type check gain type_check(gain, (int, float), "gain") - # type check gain_type and value check gain type_check(gain_type, (GainType,), "gain_type") if gain_type == GainType.AMPLITUDE: check_non_negative_float32(gain, "gain") @@ -520,8 +511,8 @@ def check_flanger(method): type_check(phase, (float, int), "phase") check_value(phase, [0, 100], "phase") - type_check(modulation, (Modulation), "modulation") - type_check(interpolation, (Interpolation), "interpolation") + type_check(modulation, (Modulation,), "modulation") + type_check(interpolation, (Interpolation,), "interpolation") return method(self, *args, **kwargs) return new_method diff --git a/tests/ut/python/dataset/test_flanger.py b/tests/ut/python/dataset/test_flanger.py index a515582ae4a..243b5fee8d2 100644 --- a/tests/ut/python/dataset/test_flanger.py +++ b/tests/ut/python/dataset/test_flanger.py @@ -70,7 +70,6 @@ def test_flanger_eager_triangular_linear_int(): count_unequal_element(expect_waveform, output, 0.0001, 0.0001) - def test_flanger_shape_221(): """ mindspore eager mode normal testcase:flanger op""" # Original waveform @@ -191,13 +190,13 @@ def test_invalid_flanger_input(): test_invalid_input("invalid modulation parameter value", 44100, 0.0, 2.0, 0.0, 71.0, 0.5, 25.0, "test", Interpolation.LINEAR, TypeError, - "Argument modulation with value test is not of type [," - " ], but got .") + "Argument modulation with value test is not of type [], " + "but got .") test_invalid_input("invalid modulation parameter value", 44100, 0.0, 2.0, 0.0, 71.0, 0.5, 25.0, Modulation.SINUSOIDAL, "test", TypeError, - "Argument interpolation with value test is not of type [," - " ], but got .") + "Argument interpolation with value test is not of type [], " + "but got .") if __name__ == '__main__':