fix api docs of audio

This commit is contained in:
Xiao Tianci 2021-10-14 20:11:13 +08:00
parent 9c069f04c0
commit 930a5b89c0
5 changed files with 19 additions and 23 deletions

View File

@ -23,7 +23,7 @@ Common imported modules in corresponding API examples are as follows:
.. code-block::
import mindspore.dataset as ds
from mindspore.dataset import audio
import mindspore.dataset.audio.transforms as audio
"""
from . import transforms
from . import utils

View File

@ -33,7 +33,7 @@ from .validators import check_allpass_biquad, check_amplitude_to_db, check_band_
class AudioTensorOperation(TensorOperation):
"""
Base class of Audio Tensor Ops
Base class of Audio Tensor Ops.
"""
def __call__(self, *input_tensor_list):
@ -266,6 +266,7 @@ class Biquad(TensorOperation):
>>> biquad_op = audio.Biquad(0.01, 0.02, 0.13, 1, 0.12, 0.3)
>>> waveform_filtered = biquad_op(waveform)
"""
@check_biquad
def __init__(self, b0, b1, b2, a0, a1, a2):
self.b0 = b0
@ -294,6 +295,7 @@ class ComplexNorm(AudioTensorOperation):
>>> transforms = [audio.ComplexNorm()]
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
"""
@check_complex_norm
def __init__(self, power=1.0):
self.power = power
@ -360,7 +362,7 @@ class DeemphBiquad(AudioTensorOperation):
Design two-pole deemph filter for audio waveform of dimension of (..., time).
Args:
Sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz),
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz),
the value must be 44100 or 48000.
Examples:
@ -371,6 +373,7 @@ class DeemphBiquad(AudioTensorOperation):
>>> transforms = [audio.DeemphBiquad(44100)]
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
"""
@check_deemph_biquad
def __init__(self, sample_rate):
self.sample_rate = sample_rate
@ -501,7 +504,6 @@ class Fade(AudioTensorOperation):
DE_C_MODULATION_TYPE = {Modulation.SINUSOIDAL: cde.Modulation.DE_MODULATION_SINUSOIDAL,
Modulation.TRIANGULAR: cde.Modulation.DE_MODULATION_TRIANGULAR}
DE_C_INTERPOLATION_TYPE = {Interpolation.LINEAR: cde.Interpolation.DE_INTERPOLATION_LINEAR,
Interpolation.QUADRATIC: cde.Interpolation.DE_INTERPOLATION_QUADRATIC}
@ -571,6 +573,7 @@ class FrequencyMasking(AudioTensorOperation):
>>> transforms = [audio.FrequencyMasking(frequency_mask_param=1)]
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
"""
@check_masking
def __init__(self, iid_masks=False, frequency_mask_param=0, mask_start=0, mask_value=0.0):
self.iid_masks = iid_masks
@ -634,6 +637,7 @@ class LFilter(AudioTensorOperation):
>>> transforms = [audio.LFilter(a_coeffs, b_coeffs)]
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
"""
@check_lfilter
def __init__(self, a_coeffs, b_coeffs, clamp=True):
self.a_coeffs = a_coeffs
@ -662,6 +666,7 @@ class LowpassBiquad(AudioTensorOperation):
>>> transforms = [audio.LowpassBiquad(4000, 1500, 0.7)]
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
"""
@check_lowpass_biquad
def __init__(self, sample_rate, cutoff_freq, Q=0.707):
self.sample_rate = sample_rate
@ -846,6 +851,7 @@ class TimeStretch(AudioTensorOperation):
>>> transforms = [audio.TimeStretch()]
>>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
"""
@check_time_stretch
def __init__(self, hop_length=None, n_freq=201, fixed_rate=None):
self.n_freq = n_freq

View File

@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================
"""
enum for audio ops
Enum for audio ops.
"""
from enum import Enum

View File

@ -31,24 +31,17 @@ def check_amplitude_to_db(method):
def new_method(self, *args, **kwargs):
[stype, ref_value, amin, top_db], _ = parse_user_args(method, *args, **kwargs)
# type check stype
type_check(stype, (ScaleType,), "stype")
# type check ref_value
type_check(ref_value, (int, float), "ref_value")
# value check ref_value
if ref_value is not None:
check_pos_float32(ref_value, "ref_value")
# type check amin
type_check(amin, (int, float), "amin")
# value check amin
if amin is not None:
check_pos_float32(amin, "amin")
# type check top_db
type_check(top_db, (int, float), "top_db")
# value check top_db
if top_db is not None:
check_pos_float32(top_db, "top_db")
@ -246,7 +239,7 @@ def check_equalizer_biquad(method):
def check_lfilter(method):
"""Wrapper method to check the parameters of lfilter."""
"""Wrapper method to check the parameters of LFilter."""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -361,7 +354,7 @@ def check_treble_biquad(method):
def check_masking(method):
"""Wrapper method to check the parameters of time_masking and FrequencyMasking"""
"""Wrapper method to check the parameters of TimeMasking and FrequencyMasking"""
@wraps(method)
def new_method(self, *args, **kwargs):
@ -455,9 +448,7 @@ def check_vol(method):
@wraps(method)
def new_method(self, *args, **kwargs):
[gain, gain_type], _ = parse_user_args(method, *args, **kwargs)
# type check gain
type_check(gain, (int, float), "gain")
# type check gain_type and value check gain
type_check(gain_type, (GainType,), "gain_type")
if gain_type == GainType.AMPLITUDE:
check_non_negative_float32(gain, "gain")
@ -520,8 +511,8 @@ def check_flanger(method):
type_check(phase, (float, int), "phase")
check_value(phase, [0, 100], "phase")
type_check(modulation, (Modulation), "modulation")
type_check(interpolation, (Interpolation), "interpolation")
type_check(modulation, (Modulation,), "modulation")
type_check(interpolation, (Interpolation,), "interpolation")
return method(self, *args, **kwargs)
return new_method

View File

@ -70,7 +70,6 @@ def test_flanger_eager_triangular_linear_int():
count_unequal_element(expect_waveform, output, 0.0001, 0.0001)
def test_flanger_shape_221():
""" mindspore eager mode normal testcase:flanger op"""
# Original waveform
@ -191,13 +190,13 @@ def test_invalid_flanger_input():
test_invalid_input("invalid modulation parameter value", 44100, 0.0, 2.0, 0.0, 71.0, 0.5, 25.0, "test",
Interpolation.LINEAR, TypeError,
"Argument modulation with value test is not of type [<Modulation.SINUSOIDAL: 'sinusoidal'>,"
" <Modulation.TRIANGULAR: 'triangular'>], but got <class 'str'>.")
"Argument modulation with value test is not of type [<enum 'Modulation'>], "
"but got <class 'str'>.")
test_invalid_input("invalid modulation parameter value", 44100, 0.0, 2.0, 0.0, 71.0, 0.5, 25.0,
Modulation.SINUSOIDAL, "test", TypeError,
"Argument interpolation with value test is not of type [<Interpolation.LINEAR: 'linear'>,"
" <Interpolation.QUADRATIC: 'quadratic'>], but got <class 'str'>.")
"Argument interpolation with value test is not of type [<enum 'Interpolation'>], "
"but got <class 'str'>.")
if __name__ == '__main__':