!29195 Update API docs for Audio and Callback to match the previous modifications
Merge pull request !29195 from xiaotianci/fix_chinese_api
commit f4bd87380b
@@ -3,13 +3,13 @@ mindspore.dataset.WaitedDSCallback

 .. py:class:: mindspore.dataset.WaitedDSCallback(step_size=1)

-    Abstract base class for blocking dataset processing callbacks, used for synchronization with the training callback class (`mindspore.callback <https://mindspore.cn/docs/api/zh-CN/master/api_python/mindspore.train.html#mindspore.train.callback.Callback>`_).
+    Abstract base class for blocking dataset processing callbacks, used for synchronization with the training callback class `mindspore.train.callback <https://mindspore.cn/docs/api/zh-CN/master/api_python/mindspore.train.html#mindspore.train.callback.Callback>`_ .

     It can be used to execute a custom callback method before a step or an epoch starts, for example updating the parameter configuration of augmentation operators according to the loss of the previous epoch in auto augmentation.

     Note that the call is triggered only at the beginning of the second step or epoch.

-    Users can obtain model information through `train_run_context`, such as `network`, `train_network`, `epoch_num`, `batch_num`, `loss_fn`, `optimizer`, `parallel_mode`, `device_number`, `list_callback`, `cur_epoch_num`, `cur_step_num`, `dataset_sink_mode`, `net_outputs`, etc. For details, see `mindspore.callback <https://mindspore.cn/docs/api/zh-CN/master/api_python/mindspore.train.html#mindspore.train.callback.Callback>`_ .
+    Users can obtain network training information through `train_run_context`, such as `network`, `train_network`, `epoch_num`, `batch_num`, `loss_fn`, `optimizer`, `parallel_mode`, `device_number`, `list_callback`, `cur_epoch_num`, `cur_step_num`, `dataset_sink_mode`, `net_outputs`, etc. For details, see `mindspore.train.callback <https://mindspore.cn/docs/api/zh-CN/master/api_python/mindspore.train.html#mindspore.train.callback.Callback>`_ .

     Users can obtain dataset pipeline information through `ds_run_context`, including `cur_epoch_num` (current epoch number), `cur_step_num_in_epoch` (current step number within the epoch) and `cur_step_num` (current step number).
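The synchronization contract documented above is easiest to see in code. Below is a minimal, hypothetical sketch based only on names that appear elsewhere in this commit (`sync_epoch_begin`, `train_run_context.net_outputs`, `ds_run_context.cur_epoch_num`); the augmentation operator and its `update_strength` method are invented for illustration:

    from mindspore.dataset import WaitedDSCallback

    class UpdateAugCallback(WaitedDSCallback):
        """Block the pipeline at each epoch start until training reports the previous loss."""

        def __init__(self, aug_op, step_size=1):
            super().__init__(step_size)
            self.aug_op = aug_op  # hypothetical augmentation operator with a tunable parameter

        def sync_epoch_begin(self, train_run_context, ds_run_context):
            # train_run_context carries training info such as net_outputs (the loss);
            # ds_run_context carries pipeline info such as cur_epoch_num.
            loss = train_run_context.net_outputs
            self.aug_op.update_strength(loss)  # hypothetical parameter update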
@@ -14,7 +14,7 @@ mindspore.dataset.audio.transforms.FrequencyMasking

     - **mask_start** (int, optional) - Starting position of the mask, which only takes effect when `iid_masks` is True. The value should be in range [0, freq_length - freq_mask_param], where `freq_length` is the length of the audio waveform in the frequency domain. Default: 0.
     - **mask_value** (float, optional) - Value to fill the mask with. Default: 0.0.

-    .. image:: api_img/dataset/frequency_masking_original.png
+    .. image:: api_img/frequency_masking_original.png

-    .. image:: api_img/dataset/frequency_masking.png
+    .. image:: api_img/frequency_masking.png
@@ -14,6 +14,6 @@ mindspore.dataset.audio.transforms.TimeMasking

     - **mask_start** (int, optional) - Starting position of the mask, which only takes effect when `iid_masks` is True. The value should be in range [0, time_length - time_mask_param], where `time_length` is the length of the audio waveform in the time domain. Default: 0.
     - **mask_value** (float, optional) - Value to fill the mask with. Default: 0.0.

-    .. image:: api_img/dataset/time_masking_original.png
+    .. image:: api_img/time_masking_original.png

-    .. image:: api_img/dataset/time_masking.png
+    .. image:: api_img/time_masking.png
@@ -13,8 +13,8 @@ mindspore.dataset.audio.transforms.TimeStretch

     - **n_freq** (int, optional) - Number of filter banks in the STFT. Default: 201.
     - **fixed_rate** (float, optional) - Rate at which the spectrogram is sped up or slowed down in time. Default: None, keeps the original rate.

-    .. image:: api_img/dataset/time_stretch_rate1.5.png
+    .. image:: api_img/time_stretch_rate1.5.png

-    .. image:: api_img/dataset/time_stretch_original.png
+    .. image:: api_img/time_stretch_original.png

-    .. image:: api_img/dataset/time_stretch_rate0.8.png
+    .. image:: api_img/time_stretch_rate0.8.png
@@ -23,7 +23,7 @@
 namespace mindspore {
 namespace dataset {
 PYBIND_REGISTER(CreateDct, 1, ([](py::module *m) {
-                  (void)m->def("CreateDct", ([](int32_t n_mfcc, int32_t n_mels, NormMode norm) {
+                  (void)m->def("create_dct", ([](int32_t n_mfcc, int32_t n_mels, NormMode norm) {
                     std::shared_ptr<Tensor> out;
                     THROW_IF_ERROR(Dct(&out, n_mfcc, n_mels, norm));
                     return out;
@@ -32,8 +32,8 @@ PYBIND_REGISTER(CreateDct, 1, ([](py::module *m) {

 PYBIND_REGISTER(MelscaleFbanks, 1, ([](py::module *m) {
                   (void)m->def(
-                    "MelscaleFbanks", ([](int32_t n_freqs, float f_min, float f_max, int32_t n_mels,
-                                          int32_t sample_rate, NormType norm, MelType mel_type) {
+                    "melscale_fbanks", ([](int32_t n_freqs, float f_min, float f_max, int32_t n_mels,
+                                           int32_t sample_rate, NormType norm, MelType mel_type) {
                       std::shared_ptr<Tensor> fb;
                       THROW_IF_ERROR(CreateFbanks(&fb, n_freqs, f_min, f_max, n_mels, sample_rate, norm, mel_type));
                       return fb;
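The two hunks above rename the Python-facing bindings from CamelCase to snake_case. A sketch of the user-visible effect (assuming `mindspore._c_dataengine` is built with this change; the `.as_array()` call and matrix shape come from the `create_dct` docstring later in this commit):

    import mindspore._c_dataengine as cde

    # Before this commit: cde.CreateDct(...); after: cde.create_dct(...)
    dct = cde.create_dct(13, 128, cde.NormMode.DE_NORM_MODE_NONE).as_array()
    print(dct.shape)  # (n_mels, n_mfcc) == (128, 13)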
@@ -42,22 +42,22 @@ PYBIND_REGISTER(MelscaleFbanks, 1, ([](py::module *m) {

 PYBIND_REGISTER(MelType, 0, ([](const py::module *m) {
                   (void)py::enum_<MelType>(*m, "MelType", py::arithmetic())
-                    .value("DE_MELTYPE_HTK", MelType::kHtk)
-                    .value("DE_MELTYPE_SLANEY", MelType::kSlaney)
+                    .value("DE_MEL_TYPE_HTK", MelType::kHtk)
+                    .value("DE_MEL_TYPE_SLANEY", MelType::kSlaney)
                     .export_values();
                 }));

 PYBIND_REGISTER(NormType, 0, ([](const py::module *m) {
                   (void)py::enum_<NormType>(*m, "NormType", py::arithmetic())
-                    .value("DE_NORMTYPE_NONE", NormType::kNone)
-                    .value("DE_NORMTYPE_SLANEY", NormType::kSlaney)
+                    .value("DE_NORM_TYPE_NONE", NormType::kNone)
+                    .value("DE_NORM_TYPE_SLANEY", NormType::kSlaney)
                     .export_values();
                 }));

 PYBIND_REGISTER(NormMode, 0, ([](const py::module *m) {
                   (void)py::enum_<NormMode>(*m, "NormMode", py::arithmetic())
-                    .value("DE_NORMMODE_NONE", NormMode::kNone)
-                    .value("DE_NORMMODE_ORTHO", NormMode::kOrtho)
+                    .value("DE_NORM_MODE_NONE", NormMode::kNone)
+                    .value("DE_NORM_MODE_ORTHO", NormMode::kOrtho)
                     .export_values();
                 }));
 }  // namespace dataset
@@ -84,8 +84,8 @@ PYBIND_REGISTER(

 PYBIND_REGISTER(ScaleType, 0, ([](const py::module *m) {
                   (void)py::enum_<ScaleType>(*m, "ScaleType", py::arithmetic())
-                    .value("DE_SCALETYPE_MAGNITUDE", ScaleType::kMagnitude)
-                    .value("DE_SCALETYPE_POWER", ScaleType::kPower)
+                    .value("DE_SCALE_TYPE_MAGNITUDE", ScaleType::kMagnitude)
+                    .value("DE_SCALE_TYPE_POWER", ScaleType::kPower)
                     .export_values();
                 }));

@@ -234,9 +234,9 @@ PYBIND_REGISTER(DetectPitchFrequencyOperation, 1, ([](const py::module *m) {

 PYBIND_REGISTER(DensityFunction, 0, ([](const py::module *m) {
                   (void)py::enum_<DensityFunction>(*m, "DensityFunction", py::arithmetic())
-                    .value("DE_DENSITYFUNCTION_TPDF", DensityFunction::kTPDF)
-                    .value("DE_DENSITYFUNCTION_RPDF", DensityFunction::kRPDF)
-                    .value("DE_DENSITYFUNCTION_GPDF", DensityFunction::kGPDF)
+                    .value("DE_DENSITY_FUNCTION_TPDF", DensityFunction::kTPDF)
+                    .value("DE_DENSITY_FUNCTION_RPDF", DensityFunction::kRPDF)
+                    .value("DE_DENSITY_FUNCTION_GPDF", DensityFunction::kGPDF)
                     .export_values();
                 }));

@@ -263,11 +263,11 @@ PYBIND_REGISTER(EqualizerBiquadOperation, 1, ([](const py::module *m) {

 PYBIND_REGISTER(FadeShape, 0, ([](const py::module *m) {
                   (void)py::enum_<FadeShape>(*m, "FadeShape", py::arithmetic())
-                    .value("DE_FADESHAPE_LINEAR", FadeShape::kLinear)
-                    .value("DE_FADESHAPE_EXPONENTIAL", FadeShape::kExponential)
-                    .value("DE_FADESHAPE_LOGARITHMIC", FadeShape::kLogarithmic)
-                    .value("DE_FADESHAPE_QUARTERSINE", FadeShape::kQuarterSine)
-                    .value("DE_FADESHAPE_HALFSINE", FadeShape::kHalfSine)
+                    .value("DE_FADE_SHAPE_LINEAR", FadeShape::kLinear)
+                    .value("DE_FADE_SHAPE_EXPONENTIAL", FadeShape::kExponential)
+                    .value("DE_FADE_SHAPE_LOGARITHMIC", FadeShape::kLogarithmic)
+                    .value("DE_FADE_SHAPE_QUARTER_SINE", FadeShape::kQuarterSine)
+                    .value("DE_FADE_SHAPE_HALF_SINE", FadeShape::kHalfSine)
                     .export_values();
                 }));

@@ -442,11 +442,11 @@ PYBIND_REGISTER(SlidingWindowCmnOperation, 1, ([](const py::module *m) {

 PYBIND_REGISTER(WindowType, 0, ([](const py::module *m) {
                   (void)py::enum_<WindowType>(*m, "WindowType", py::arithmetic())
-                    .value("DE_BARTLETT", WindowType::kBartlett)
-                    .value("DE_BLACKMAN", WindowType::kBlackman)
-                    .value("DE_HAMMING", WindowType::kHamming)
-                    .value("DE_HANN", WindowType::kHann)
-                    .value("DE_KAISER", WindowType::kKaiser)
+                    .value("DE_WINDOW_TYPE_BARTLETT", WindowType::kBartlett)
+                    .value("DE_WINDOW_TYPE_BLACKMAN", WindowType::kBlackman)
+                    .value("DE_WINDOW_TYPE_HAMMING", WindowType::kHamming)
+                    .value("DE_WINDOW_TYPE_HANN", WindowType::kHann)
+                    .value("DE_WINDOW_TYPE_KAISER", WindowType::kKaiser)
                     .export_values();
                 }));

@@ -522,9 +522,9 @@ PYBIND_REGISTER(VolOperation, 1, ([](const py::module *m) {

 PYBIND_REGISTER(GainType, 0, ([](const py::module *m) {
                   (void)py::enum_<GainType>(*m, "GainType", py::arithmetic())
-                    .value("DE_GAINTYPE_AMPLITUDE", GainType::kAmplitude)
-                    .value("DE_GAINTYPE_POWER", GainType::kPower)
-                    .value("DE_GAINTYPE_DB", GainType::kDb)
+                    .value("DE_GAIN_TYPE_AMPLITUDE", GainType::kAmplitude)
+                    .value("DE_GAIN_TYPE_POWER", GainType::kPower)
+                    .value("DE_GAIN_TYPE_DB", GainType::kDb)
                     .export_values();
                 }));
 }  // namespace dataset
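All of the enum hunks in this file follow one mechanical pattern: the exported value names gain underscores between words. Any Python code that reached these names through `mindspore._c_dataengine` directly has to migrate; a sketch for one enum (names taken verbatim from the hunks above):

    import mindspore._c_dataengine as cde

    # Pre-commit name                    ->  post-commit name
    # cde.GainType.DE_GAINTYPE_AMPLITUDE ->  cde.GainType.DE_GAIN_TYPE_AMPLITUDE
    # cde.GainType.DE_GAINTYPE_POWER     ->  cde.GainType.DE_GAIN_TYPE_POWER
    # cde.GainType.DE_GAINTYPE_DB        ->  cde.GainType.DE_GAIN_TYPE_DB
    gain_db = cde.GainType.DE_GAIN_TYPE_DB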
@@ -50,12 +50,24 @@ class AudioTensorOperation(TensorOperation):

 class AllpassBiquad(AudioTensorOperation):
     """
-    Design two-pole all-pass filter for audio waveform of dimension of (..., time).
+    Design two-pole all-pass filter with central frequency and bandwidth for audio waveform.
+
+    An all-pass filter changes the audio's frequency to phase relationship without changing
+    its frequency to amplitude relationship. The system function is:
+
+    .. math::
+        H(s) = \frac{s^2 - \frac{s}{Q} + 1}{s^2 + \frac{s}{Q} + 1}
+
+    Similar to `SoX <http://sox.sourceforge.net/sox.html>`_ implementation.
+
+    Note:
+        The dimension of the audio waveform to be processed needs to be (..., time).

     Args:
-        sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero.
-        central_freq (float): central frequency (in Hz).
-        Q(float, optional): Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1] (default=0.707).
+        sample_rate (int): Sampling rate (in Hz), which can't be zero.
+        central_freq (float): Central frequency (in Hz).
+        Q (float, optional): `Quality factor <https://en.wikipedia.org/wiki/Q_factor>`_ ,
+            in range of (0, 1]. Default: 0.707.

     Examples:
         >>> import numpy as np
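The docstring example above is cut off in this view; a plausible completion, following the pattern the other examples in this file use (the random input shape is illustrative only):

    import numpy as np
    import mindspore.dataset as ds
    import mindspore.dataset.audio.transforms as audio

    waveform = np.random.random([5, 16])  # arbitrary (..., time) input
    numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
    transforms = [audio.AllpassBiquad(44100, 200.0)]
    numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])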
@@ -76,26 +88,34 @@ class AllpassBiquad(AudioTensorOperation):
         return cde.AllpassBiquadOperation(self.sample_rate, self.central_freq, self.Q)


-DE_C_SCALETYPE_TYPE = {ScaleType.MAGNITUDE: cde.ScaleType.DE_SCALETYPE_MAGNITUDE,
-                       ScaleType.POWER: cde.ScaleType.DE_SCALETYPE_POWER}
+DE_C_SCALE_TYPE = {ScaleType.POWER: cde.ScaleType.DE_SCALE_TYPE_POWER,
+                   ScaleType.MAGNITUDE: cde.ScaleType.DE_SCALE_TYPE_MAGNITUDE}


 class AmplitudeToDB(AudioTensorOperation):
     """
-    Converts the input tensor from amplitude/power scale to decibel scale.
+    Turn the input audio waveform from the amplitude/power scale to decibel scale.
+
+    Note:
+        The dimension of the audio waveform to be processed needs to be (..., freq, time).

     Args:
-        stype (ScaleType, optional): Scale of the input tensor (default=ScaleType.POWER).
-            It can be one of ScaleType.MAGNITUDE or ScaleType.POWER.
-        ref_value (float, optional): Param for generate db_multiplier (default=1.0).
-        amin (float, optional): Lower bound to clamp the input waveform. It must be greater than zero (default=1e-10).
-        top_db (float, optional): Minimum cut-off decibels. The range of values is non-negative.
-            Commonly set at 80 (default=80.0).
+        stype (ScaleType, optional): Scale of the input waveform, which can be
+            ScaleType.POWER or ScaleType.MAGNITUDE. Default: ScaleType.POWER.
+        ref_value (float, optional): Multiplier reference value for generating
+            `db_multiplier`. Default: 1.0. The formula is
+
+            :math:`\text{db_multiplier} = Log10(max(\text{ref_value}, amin))`.
+
+        amin (float, optional): Lower bound to clamp the input waveform, which must
+            be greater than zero. Default: 1e-10.
+        top_db (float, optional): Minimum cut-off decibels, which must be non-negative. Default: 80.0.

     Examples:
         >>> import numpy as np
         >>> from mindspore.dataset.audio import ScaleType
         >>>
-        >>> waveform = np.random.random([1, 400//2+1, 30])
+        >>> waveform = np.random.random([1, 400 // 2 + 1, 30])
         >>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
         >>> transforms = [audio.AmplitudeToDB(stype=ScaleType.POWER)]
         >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
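The new `ref_value` wording embeds a formula; restated as runnable code (a sketch — `AmplitudeToDB` computes this internally, the helper below is only illustrative):

    import numpy as np

    def db_multiplier(ref_value=1.0, amin=1e-10):
        # db_multiplier = log10(max(ref_value, amin)), per the docstring above
        return np.log10(max(ref_value, amin))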
@@ -109,13 +129,16 @@ class AmplitudeToDB(AudioTensorOperation):
         self.top_db = top_db

     def parse(self):
-        return cde.AmplitudeToDBOperation(DE_C_SCALETYPE_TYPE[self.stype], self.ref_value, self.amin, self.top_db)
+        return cde.AmplitudeToDBOperation(DE_C_SCALE_TYPE[self.stype], self.ref_value, self.amin, self.top_db)


 class Angle(AudioTensorOperation):
     """
-    Calculate the angle of the complex number sequence of shape (..., 2).
-    The first dimension represents the real part while the second represents the imaginary.
+    Calculate the angle of complex number sequence.
+
+    Note:
+        The dimension of the audio waveform to be processed needs to be (..., complex=2).
+        The first dimension represents the real part while the second represents the imaginary.

     Examples:
         >>> import numpy as np
@@ -132,14 +155,24 @@ class Angle(AudioTensorOperation):

 class BandBiquad(AudioTensorOperation):
     """
-    Design two-pole band filter for audio waveform of dimension of (..., time).
+    Design two-pole band-pass filter for audio waveform.
+
+    The frequency response drops logarithmically around the center frequency. The
+    bandwidth gives the slope of the drop. The frequencies at band edge will be
+    half of their original amplitudes.
+
+    Similar to `SoX <http://sox.sourceforge.net/sox.html>`_ implementation.
+
+    Note:
+        The dimension of the audio waveform to be processed needs to be (..., time).

     Args:
-        sample_rate (int): Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero.
+        sample_rate (int): Sampling rate (in Hz), which can't be zero.
         central_freq (float): Central frequency (in Hz).
-        Q(float, optional): Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1] (default=0.707).
+        Q (float, optional): `Quality factor <https://en.wikipedia.org/wiki/Q_factor>`_ ,
+            in range of (0, 1]. Default: 0.707.
         noise (bool, optional) : If True, uses the alternate mode for un-pitched audio (e.g. percussion).
-            If False, uses mode oriented to pitched audio, i.e. voice, singing, or instrumental music (default=False).
+            If False, uses mode oriented to pitched audio, i.e. voice, singing, or instrumental music. Default: False.

     Examples:
         >>> import numpy as np
@@ -162,15 +195,32 @@ class BandBiquad(AudioTensorOperation):


 class BandpassBiquad(AudioTensorOperation):
-    """
-    Design two-pole band-pass filter. Similar to SoX implementation.
+    r"""
+    Design two-pole Butterworth band-pass filter for audio waveform.
+
+    The frequency response of the Butterworth filter is maximally flat (i.e. has no ripples)
+    in the passband and rolls off towards zero in the stopband.
+
+    The system function of Butterworth band-pass filter is:
+
+    .. math::
+        H(s) = \begin{cases}
+            \frac{s}{s^2 + \frac{s}{Q} + 1}, &\text{if const_skirt_gain=True}; \cr
+            \frac{\frac{s}{Q}}{s^2 + \frac{s}{Q} + 1}, &\text{if const_skirt_gain=False}.
+        \end{cases}
+
+    Similar to `SoX <http://sox.sourceforge.net/sox.html>`_ implementation.
+
+    Note:
+        The dimension of the audio waveform to be processed needs to be (..., time).

     Args:
-        sample_rate (int): Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero.
+        sample_rate (int): Sampling rate (in Hz), which can't be zero.
         central_freq (float): Central frequency (in Hz).
-        Q (float, optional): Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0,1] (default=0.707).
-        const_skirt_gain (bool, optional) : If True, uses a constant skirt gain (peak gain = Q).
-            If False, uses a constant 0dB peak gain (default=False).
+        Q (float, optional): `Quality factor <https://en.wikipedia.org/wiki/Q_factor>`_ ,
+            in range of (0, 1]. Default: 0.707.
+        const_skirt_gain (bool, optional) : If True, uses a constant skirt gain (peak gain = Q);
+            If False, uses a constant 0dB peak gain. Default: False.

     Examples:
         >>> import numpy as np
@@ -194,12 +244,26 @@ class BandpassBiquad(AudioTensorOperation):

 class BandrejectBiquad(AudioTensorOperation):
     """
-    Design two-pole band-reject filter for audio waveform of dimension of (..., time).
+    Design two-pole Butterworth band-reject filter for audio waveform.
+
+    The frequency response of the Butterworth filter is maximally flat (i.e. has no ripples)
+    in the passband and rolls off towards zero in the stopband.
+
+    The system function of Butterworth band-reject filter is:
+
+    .. math::
+        H(s) = \frac{s^2 + 1}{s^2 + \frac{s}{Q} + 1}
+
+    Similar to `SoX <http://sox.sourceforge.net/sox.html>`_ implementation.
+
+    Note:
+        The dimension of the audio waveform to be processed needs to be (..., time).

     Args:
-        sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero.
-        central_freq (float): central frequency (in Hz).
-        Q(float, optional): Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1] (default=0.707).
+        sample_rate (int): Sampling rate (in Hz), which can't be zero.
+        central_freq (float): Central frequency (in Hz).
+        Q (float, optional): `Quality factor <https://en.wikipedia.org/wiki/Q_factor>`_ ,
+            in range of (0, 1]. Default: 0.707.

     Examples:
         >>> import numpy as np
@@ -221,14 +285,26 @@ class BandrejectBiquad(AudioTensorOperation):


 class BassBiquad(AudioTensorOperation):
-    """
-    Design a bass tone-control effect for audio waveform of dimension of (..., time).
+    r"""
+    Design a bass tone-control effect, also known as two-pole low-shelf filter for audio waveform.
+
+    A low-shelf filter passes all frequencies, but increases or reduces frequencies below the shelf
+    frequency by a specified amount. The system function is:
+
+    .. math::
+        H(s) = A\frac{s^2 + \frac{\sqrt{A}}{Q}s + A}{As^2 + \frac{\sqrt{A}}{Q}s + 1}
+
+    Similar to `SoX <http://sox.sourceforge.net/sox.html>`_ implementation.
+
+    Note:
+        The dimension of the audio waveform to be processed needs to be (..., time).

     Args:
-        sample_rate (int): Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero.
+        sample_rate (int): Sampling rate (in Hz), which can't be zero.
         gain (float): Desired gain at the boost (or attenuation) in dB.
-        central_freq (float): Central frequency (in Hz) (default=100.0).
-        Q(float, optional): Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1] (default=0.707).
+        central_freq (float, optional): Central frequency (in Hz). Default: 100.0.
+        Q (float, optional): `Quality factor <https://en.wikipedia.org/wiki/Q_factor>`_ ,
+            in range of (0, 1]. Default: 0.707.

     Examples:
         >>> import numpy as np
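In the low-shelf transfer function above, :math:`A` is the shelf gain derived from `gain`. The diff itself does not define it; in the usual biquad-cookbook convention (an assumption here, not stated in this commit) it is computed as:

    def shelf_gain(gain_db):
        # A = 10 ** (gain_db / 40): amplitude form of a dB gain split across the shelf
        return 10 ** (gain_db / 40.0)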
@@ -252,7 +328,7 @@ class BassBiquad(AudioTensorOperation):

 class Biquad(TensorOperation):
     """
-    Perform a biquad filter of input tensor.
+    Perform a biquad filter of input audio.

     Args:
         b0 (float): Numerator coefficient of current input, x[n].
@@ -285,10 +361,14 @@ class Biquad(TensorOperation):

 class ComplexNorm(AudioTensorOperation):
     """
-    Compute the norm of complex tensor input.
+    Compute the norm of complex number sequence.
+
+    Note:
+        The dimension of the audio waveform to be processed needs to be (..., complex=2).
+        The first dimension represents the real part while the second represents the imaginary.

     Args:
-        power (float, optional): Power of the norm, which must be non-negative (default=1.0).
+        power (float, optional): Power of the norm, which must be non-negative. Default: 1.0.

     Examples:
         >>> import numpy as np
@@ -355,12 +435,19 @@ class ComputeDeltas(AudioTensorOperation):

 class Contrast(AudioTensorOperation):
     """
-    Apply contrast effect. Similar to SoX implementation.
+    Apply contrast effect for audio waveform.
+
+    Comparable with compression, this effect modifies an audio signal to make it sound louder.
+
+    Similar to `SoX <http://sox.sourceforge.net/sox.html>`_ implementation.
+
+    Note:
+        The dimension of the audio waveform to be processed needs to be (..., time).

     Args:
-        enhancement_amount (float): Controls the amount of the enhancement. Allowed range is [0, 100] (default=75.0).
-            Note that enhancement_amount equal to 0 still gives a significant contrast enhancement.
+        enhancement_amount (float, optional): Controls the amount of the enhancement,
+            in range of [0, 100]. Default: 75.0. Note that `enhancement_amount` equal
+            to 0 still gives a significant contrast enhancement.

     Examples:
         >>> import numpy as np
@@ -420,7 +507,7 @@ class DCShift(AudioTensorOperation):
         >>> waveform = np.array([0.60, 0.97, -1.04, -1.26, 0.97, 0.91, 0.48, 0.93])
         >>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
         >>> transforms = [audio.DCShift(0.5, 0.02)]
-        >>> numpy_slices_dataset = numpy_slices_dataset.map(operation=transforms, input_columns=["audio"])
+        >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
     """

     @check_dc_shift
@@ -496,9 +583,9 @@ class DetectPitchFrequency(AudioTensorOperation):
                                                 self.win_length, self.freq_low, self.freq_high)


-DE_C_DENSITYFUNCTION_TYPE = {DensityFunction.TPDF: cde.DensityFunction.DE_DENSITYFUNCTION_TPDF,
-                             DensityFunction.RPDF: cde.DensityFunction.DE_DENSITYFUNCTION_RPDF,
-                             DensityFunction.GPDF: cde.DensityFunction.DE_DENSITYFUNCTION_GPDF}
+DE_C_DENSITY_FUNCTION = {DensityFunction.TPDF: cde.DensityFunction.DE_DENSITY_FUNCTION_TPDF,
+                         DensityFunction.RPDF: cde.DensityFunction.DE_DENSITY_FUNCTION_RPDF,
+                         DensityFunction.GPDF: cde.DensityFunction.DE_DENSITY_FUNCTION_GPDF}


 class Dither(AudioTensorOperation):
@@ -530,7 +617,7 @@ class Dither(AudioTensorOperation):
         self.noise_shaping = noise_shaping

     def parse(self):
-        return cde.DitherOperation(DE_C_DENSITYFUNCTION_TYPE[self.density_function], self.noise_shaping)
+        return cde.DitherOperation(DE_C_DENSITY_FUNCTION[self.density_function], self.noise_shaping)


 class EqualizerBiquad(AudioTensorOperation):
@@ -563,11 +650,11 @@ class EqualizerBiquad(AudioTensorOperation):
         return cde.EqualizerBiquadOperation(self.sample_rate, self.center_freq, self.gain, self.Q)


-DE_C_FADESHAPE_TYPE = {FadeShape.LINEAR: cde.FadeShape.DE_FADESHAPE_LINEAR,
-                       FadeShape.EXPONENTIAL: cde.FadeShape.DE_FADESHAPE_EXPONENTIAL,
-                       FadeShape.LOGARITHMIC: cde.FadeShape.DE_FADESHAPE_LOGARITHMIC,
-                       FadeShape.QUARTERSINE: cde.FadeShape.DE_FADESHAPE_QUARTERSINE,
-                       FadeShape.HALFSINE: cde.FadeShape.DE_FADESHAPE_HALFSINE}
+DE_C_FADE_SHAPE = {FadeShape.QUARTER_SINE: cde.FadeShape.DE_FADE_SHAPE_QUARTER_SINE,
+                   FadeShape.HALF_SINE: cde.FadeShape.DE_FADE_SHAPE_HALF_SINE,
+                   FadeShape.LINEAR: cde.FadeShape.DE_FADE_SHAPE_LINEAR,
+                   FadeShape.LOGARITHMIC: cde.FadeShape.DE_FADE_SHAPE_LOGARITHMIC,
+                   FadeShape.EXPONENTIAL: cde.FadeShape.DE_FADE_SHAPE_EXPONENTIAL}


 class Fade(AudioTensorOperation):
@@ -578,17 +665,18 @@ class Fade(AudioTensorOperation):
         fade_in_len (int, optional): Length of fade-in (time frames), which must be non-negative (default=0).
         fade_out_len (int, optional): Length of fade-out (time frames), which must be non-negative (default=0).
         fade_shape (FadeShape, optional): Shape of fade (default=FadeShape.LINEAR). Can be one of
-            [FadeShape.LINEAR, FadeShape.EXPONENTIAL, FadeShape.LOGARITHMIC, FadeShape.QUARTERSINC, FadeShape.HALFSINC].
+            FadeShape.QUARTER_SINE, FadeShape.HALF_SINE, FadeShape.LINEAR, FadeShape.LOGARITHMIC or
+            FadeShape.EXPONENTIAL.

-            -FadeShape.LINEAR, means it linear to 0.
+            -FadeShape.QUARTER_SINE, means it tends to 0 following a quarter sine function.

-            -FadeShape.EXPONENTIAL, means it tend to 0 in an exponential function.
+            -FadeShape.HALF_SINE, means it tends to 0 following a half sine function.

-            -FadeShape.LOGARITHMIC, means it tend to 0 in an logrithmic function.
+            -FadeShape.LINEAR, means it fades to 0 linearly.

-            -FadeShape.QUARTERSINE, means it tend to 0 in an quarter sin function.
+            -FadeShape.LOGARITHMIC, means it tends to 0 following a logarithmic function.

-            -FadeShape.HALFSINE, means it tend to 0 in an half sin function.
+            -FadeShape.EXPONENTIAL, means it tends to 0 following an exponential function.

     Raises:
         RuntimeError: If fade_in_len exceeds waveform length.
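A short usage sketch of the renamed fade shapes, mirroring the example style of this file (input shape illustrative; the `FadeShape` import follows the `ScaleType` import pattern shown earlier):

    import numpy as np
    import mindspore.dataset as ds
    import mindspore.dataset.audio.transforms as audio
    from mindspore.dataset.audio import FadeShape

    waveform = np.random.random([2, 160])  # (..., time)
    numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
    transforms = [audio.Fade(fade_in_len=20, fade_out_len=20, fade_shape=FadeShape.QUARTER_SINE)]
    numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])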
@@ -611,14 +699,14 @@ class Fade(AudioTensorOperation):
         self.fade_shape = fade_shape

     def parse(self):
-        return cde.FadeOperation(self.fade_in_len, self.fade_out_len, DE_C_FADESHAPE_TYPE[self.fade_shape])
+        return cde.FadeOperation(self.fade_in_len, self.fade_out_len, DE_C_FADE_SHAPE[self.fade_shape])


-DE_C_MODULATION_TYPE = {Modulation.SINUSOIDAL: cde.Modulation.DE_MODULATION_SINUSOIDAL,
-                        Modulation.TRIANGULAR: cde.Modulation.DE_MODULATION_TRIANGULAR}
+DE_C_MODULATION = {Modulation.SINUSOIDAL: cde.Modulation.DE_MODULATION_SINUSOIDAL,
+                   Modulation.TRIANGULAR: cde.Modulation.DE_MODULATION_TRIANGULAR}

-DE_C_INTERPOLATION_TYPE = {Interpolation.LINEAR: cde.Interpolation.DE_INTERPOLATION_LINEAR,
-                           Interpolation.QUADRATIC: cde.Interpolation.DE_INTERPOLATION_QUADRATIC}
+DE_C_INTERPOLATION = {Interpolation.LINEAR: cde.Interpolation.DE_INTERPOLATION_LINEAR,
+                      Interpolation.QUADRATIC: cde.Interpolation.DE_INTERPOLATION_QUADRATIC}


 class Flanger(AudioTensorOperation):
@@ -662,21 +750,27 @@ class Flanger(AudioTensorOperation):

     def parse(self):
         return cde.FlangerOperation(self.sample_rate, self.delay, self.depth, self.regen, self.width, self.speed,
-                                    self.phase, DE_C_MODULATION_TYPE[self.modulation],
-                                    DE_C_INTERPOLATION_TYPE[self.interpolation])
+                                    self.phase, DE_C_MODULATION[self.modulation],
+                                    DE_C_INTERPOLATION[self.interpolation])


 class FrequencyMasking(AudioTensorOperation):
     """
     Apply masking to a spectrogram in the frequency domain.

+    Note:
+        The dimension of the audio waveform to be processed needs to be (..., freq, time).
+
     Args:
-        iid_masks (bool, optional): Whether to apply different masks to each example (default=false).
-        frequency_mask_param (int): Maximum possible length of the mask, range: [0, freq_length] (default=0).
-            Indices uniformly sampled from [0, frequency_mask_param].
-        mask_start (int): Mask start takes effect when iid_masks=true,
-            range: [0, freq_length-frequency_mask_param] (default=0).
-        mask_value (double): Mask value (default=0.0).
+        iid_masks (bool, optional): Whether to apply different masks to each example/channel. Default: False.
+        freq_mask_param (int, optional): When `iid_masks` is True, length of the mask will be uniformly sampled
+            from [0, freq_mask_param]; When `iid_masks` is False, directly use it as length of the mask.
+            The value should be in range of [0, freq_length], where `freq_length` is the length of audio waveform
+            in frequency domain. Default: 0.
+        mask_start (int): Starting point to apply mask, only works when `iid_masks` is True. The value should
+            be in range of [0, freq_length - freq_mask_param], where `freq_length` is the length of audio waveform
+            in frequency domain. Default: 0.
+        mask_value (float, optional): Value to assign to the masked columns. Default: 0.0.

     Examples:
         >>> import numpy as np
@@ -685,12 +779,16 @@ class FrequencyMasking(AudioTensorOperation):
         >>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
         >>> transforms = [audio.FrequencyMasking(frequency_mask_param=1)]
         >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
+
+        .. image:: api_img/frequency_masking_original.png
+
+        .. image:: api_img/frequency_masking.png
     """

     @check_masking
-    def __init__(self, iid_masks=False, frequency_mask_param=0, mask_start=0, mask_value=0.0):
+    def __init__(self, iid_masks=False, freq_mask_param=0, mask_start=0, mask_value=0.0):
         self.iid_masks = iid_masks
-        self.frequency_mask_param = frequency_mask_param
+        self.frequency_mask_param = freq_mask_param
         self.mask_start = mask_start
         self.mask_value = mask_value
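The reworded `freq_mask_param`/`mask_start` semantics above reduce to two cases; a sketch (constructor calls only, shapes illustrative):

    import numpy as np
    import mindspore.dataset.audio.transforms as audio

    spectrogram = np.random.random([1, 201, 30]).astype(np.float32)  # (..., freq, time)

    # iid_masks=False: the mask length is exactly freq_mask_param; mask_start is ignored.
    fixed_mask = audio.FrequencyMasking(iid_masks=False, freq_mask_param=10)

    # iid_masks=True: mask length is sampled from [0, freq_mask_param], starting at mask_start.
    random_mask = audio.FrequencyMasking(iid_masks=True, freq_mask_param=10, mask_start=5)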
|
@ -787,12 +885,24 @@ class LFilter(AudioTensorOperation):
|
|||
|
||||
class LowpassBiquad(AudioTensorOperation):
|
||||
"""
|
||||
Design biquad lowpass filter and perform filtering. Similar to SoX implementation.
|
||||
Design two-pole low-pass filter for audio waveform.
|
||||
|
||||
A low-pass filter passes frequencies lower than a selected cutoff frequency
|
||||
but attenuates frequencies higher than it. The system function is:
|
||||
|
||||
.. math::
|
||||
H(s) = \frac{1}{s^2 + \frac{s}{Q} + 1}
|
||||
|
||||
Similar to `SoX <http://sox.sourceforge.net/sox.html>`_ implementation.
|
||||
|
||||
Note:
|
||||
The dimension of the audio waveform to be processed needs to be (..., time).
|
||||
|
||||
Args:
|
||||
sample_rate (int): Sampling rate of the waveform, e.g. 44100 (Hz), the value can't be zero.
|
||||
cutoff_freq (float): Filter cutoff frequency.
|
||||
Q(float, optional): Quality factor, https://en.wikipedia.org/wiki/Q_factor, range: (0, 1] (default=0.707).
|
||||
sample_rate (int): Sampling rate (in Hz), which can't be zero.
|
||||
cutoff_freq (float): Filter cutoff frequency (in Hz).
|
||||
Q (float, optional): `Quality factor <https://en.wikipedia.org/wiki/Q_factor>`_ ,
|
||||
in range of (0, 1]. Default: 0.707.
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
|
@@ -1012,11 +1122,11 @@ class SlidingWindowCmn(AudioTensorOperation):
         return cde.SlidingWindowCmnOperation(self.cmn_window, self.min_cmn_window, self.center, self.norm_vars)


-DE_C_WINDOW_TYPE = {WindowType.BARTLETT: cde.WindowType.DE_BARTLETT,
-                    WindowType.BLACKMAN: cde.WindowType.DE_BLACKMAN,
-                    WindowType.HAMMING: cde.WindowType.DE_HAMMING,
-                    WindowType.HANN: cde.WindowType.DE_HANN,
-                    WindowType.KAISER: cde.WindowType.DE_KAISER}
+DE_C_WINDOW_TYPE = {WindowType.BARTLETT: cde.WindowType.DE_WINDOW_TYPE_BARTLETT,
+                    WindowType.BLACKMAN: cde.WindowType.DE_WINDOW_TYPE_BLACKMAN,
+                    WindowType.HAMMING: cde.WindowType.DE_WINDOW_TYPE_HAMMING,
+                    WindowType.HANN: cde.WindowType.DE_WINDOW_TYPE_HANN,
+                    WindowType.KAISER: cde.WindowType.DE_WINDOW_TYPE_KAISER}


 class SpectralCentroid(TensorOperation):
@@ -1110,13 +1220,19 @@ class TimeMasking(AudioTensorOperation):
     """
     Apply masking to a spectrogram in the time domain.

+    Note:
+        The dimension of the audio waveform to be processed needs to be (..., freq, time).
+
     Args:
-        iid_masks (bool, optional): Whether to apply different masks to each example (default=false).
-        time_mask_param (int): Maximum possible length of the mask, range: [0, time_length] (default=0).
-            Indices uniformly sampled from [0, time_mask_param].
-        mask_start (int): Mask start takes effect when iid_masks=true,
-            range: [0, time_length-time_mask_param] (default=0).
-        mask_value (double): Mask value (default=0.0).
+        iid_masks (bool, optional): Whether to apply different masks to each example/channel. Default: False.
+        time_mask_param (int): When `iid_masks` is True, length of the mask will be uniformly sampled
+            from [0, time_mask_param]; When `iid_masks` is False, directly use it as length of the mask.
+            The value should be in range of [0, time_length], where `time_length` is the length of audio waveform
+            in time domain. Default: 0.
+        mask_start (int): Starting point to apply mask, only works when `iid_masks` is True. The value should
+            be in range of [0, time_length - time_mask_param], where `time_length` is the length of audio waveform
+            in time domain. Default: 0.
+        mask_value (float, optional): Value to assign to the masked columns. Default: 0.0.

     Examples:
         >>> import numpy as np
@@ -1125,6 +1241,10 @@ class TimeMasking(AudioTensorOperation):
         >>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
         >>> transforms = [audio.TimeMasking(time_mask_param=1)]
         >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
+
+        .. image:: api_img/time_masking_original.png
+
+        .. image:: api_img/time_masking.png
     """

     @check_masking
@@ -1140,13 +1260,18 @@ class TimeMasking(AudioTensorOperation):

 class TimeStretch(AudioTensorOperation):
     """
-    Stretch STFT in time at a given rate, without changing the pitch.
+    Stretch Short Time Fourier Transform (STFT) in time without modifying pitch for a given rate.
+
+    Note:
+        The dimension of the audio waveform to be processed needs to be (..., freq, time, complex=2).
+        The first dimension represents the real part while the second represents the imaginary.

     Args:
-        hop_length (int, optional): Length of hop between STFT windows (default=None, will use ((n_freq - 1) * 2) // 2).
-        n_freq (int, optional): Number of filter banks form STFT (default=201).
-        fixed_rate (float, optional): Rate to speed up or slow down the input in time
-            (default=None, will keep the original rate).
+        hop_length (int, optional): Length of hop between STFT windows, i.e. the number of samples
+            between consecutive frames. Default: None, will use `n_freq - 1`.
+        n_freq (int, optional): Number of filter banks from STFT. Default: 201.
+        fixed_rate (float, optional): Rate to speed up or slow down by. Default: None, will keep
+            the original rate.

     Examples:
         >>> import numpy as np
@@ -1155,6 +1280,12 @@ class TimeStretch(AudioTensorOperation):
         >>> numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
         >>> transforms = [audio.TimeStretch()]
         >>> numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])
+
+        .. image:: api_img/time_stretch_rate1.5.png
+
+        .. image:: api_img/time_stretch_original.png
+
+        .. image:: api_img/time_stretch_rate0.8.png
     """

     @check_time_stretch
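A usage sketch for the updated `TimeStretch` docstring, with the complex spectrogram layout described in the Note above (values illustrative):

    import numpy as np
    import mindspore.dataset as ds
    import mindspore.dataset.audio.transforms as audio

    spectrogram = np.random.random([44, 10, 2])  # (freq, time, complex=2)
    numpy_slices_dataset = ds.NumpySlicesDataset(data=[spectrogram], column_names=["audio"])
    transforms = [audio.TimeStretch(n_freq=44, fixed_rate=1.5)]  # 1.5x faster, pitch unchanged
    numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])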
@@ -1200,9 +1331,9 @@ class TrebleBiquad(AudioTensorOperation):
         return cde.TrebleBiquadOperation(self.sample_rate, self.gain, self.central_freq, self.Q)


-DE_C_GAINTYPE_TYPE = {GainType.AMPLITUDE: cde.GainType.DE_GAINTYPE_AMPLITUDE,
-                      GainType.POWER: cde.GainType.DE_GAINTYPE_POWER,
-                      GainType.DB: cde.GainType.DE_GAINTYPE_DB}
+DE_C_GAIN_TYPE = {GainType.AMPLITUDE: cde.GainType.DE_GAIN_TYPE_AMPLITUDE,
+                  GainType.POWER: cde.GainType.DE_GAIN_TYPE_POWER,
+                  GainType.DB: cde.GainType.DE_GAIN_TYPE_DB}


 class Vol(AudioTensorOperation):
@@ -1233,4 +1364,4 @@ class Vol(AudioTensorOperation):
         self.gain_type = gain_type

     def parse(self):
-        return cde.VolOperation(self.gain, DE_C_GAINTYPE_TYPE[self.gain_type])
+        return cde.VolOperation(self.gain, DE_C_GAIN_TYPE[self.gain_type])
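The renamed `DE_C_GAIN_TYPE` table is exercised through `Vol`; a sketch of typical use, mirroring this file's example style (shapes illustrative):

    import numpy as np
    import mindspore.dataset as ds
    import mindspore.dataset.audio.transforms as audio
    from mindspore.dataset.audio import GainType

    waveform = np.random.random([2, 160])
    numpy_slices_dataset = ds.NumpySlicesDataset(data=waveform, column_names=["audio"])
    transforms = [audio.Vol(gain=-10.0, gain_type=GainType.DB)]  # attenuate by 10 dB
    numpy_slices_dataset = numpy_slices_dataset.map(operations=transforms, input_columns=["audio"])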
@@ -1,4 +1,4 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
+# Copyright 2021-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,19 +18,39 @@ Enum for audio ops.
 from enum import Enum

 import mindspore._c_dataengine as cde
-from mindspore.dataset.core.validator_helpers import check_non_negative_float32, check_non_negative_int32, check_pos_float32, check_pos_int32, \
-    type_check
+from mindspore.dataset.core.validator_helpers import check_non_negative_float32, check_non_negative_int32, \
+    check_pos_float32, check_pos_int32, type_check
+
+
+class BorderType(str, Enum):
+    """
+    Padding Mode, BorderType Type.
+
+    Possible enumeration values are: BorderType.CONSTANT, BorderType.EDGE, BorderType.REFLECT, BorderType.SYMMETRIC.
+
+    - BorderType.CONSTANT: means it fills the border with constant values.
+    - BorderType.EDGE: means it pads with the last value on the edge.
+    - BorderType.REFLECT: means it reflects the values on the edge omitting the last value of edge.
+    - BorderType.SYMMETRIC: means it reflects the values on the edge repeating the last value of edge.
+
+    Note: This class derived from class str to support json serializable.
+    """
+    CONSTANT: str = "constant"
+    EDGE: str = "edge"
+    REFLECT: str = "reflect"
+    SYMMETRIC: str = "symmetric"


 class DensityFunction(str, Enum):
     """
     Density Functions.

-    Possible enumeration values are: DensityFunction.TPDF, DensityFunction.GPDF,
-    DensityFunction.RPDF.
+    Possible enumeration values are: DensityFunction.TPDF, DensityFunction.RPDF,
+    DensityFunction.GPDF.

     - DensityFunction.TPDF: means triangular probability density function.
-    - DensityFunction.GPDF: means gaussian probability density function.
     - DensityFunction.RPDF: means rectangular probability density function.
+    - DensityFunction.GPDF: means gaussian probability density function.
     """
     TPDF: str = "TPDF"
     RPDF: str = "RPDF"
@@ -41,34 +61,34 @@ class FadeShape(str, Enum):
     """
     Fade Shapes.

-    Possible enumeration values are: FadeShape.EXPONENTIAL, FadeShape.HALFSINE, FadeShape.LINEAR,
-    FadeShape.LOGARITHMIC, FadeShape.QUARTERSINE.
+    Possible enumeration values are: FadeShape.QUARTER_SINE, FadeShape.HALF_SINE, FadeShape.LINEAR,
+    FadeShape.LOGARITHMIC, FadeShape.EXPONENTIAL.

-    - FadeShape.EXPONENTIAL: means the fade shape is exponential mode.
-    - FadeShape.HALFSINE: means the fade shape is half_sine mode.
+    - FadeShape.QUARTER_SINE: means the fade shape is quarter_sine mode.
+    - FadeShape.HALF_SINE: means the fade shape is half_sine mode.
     - FadeShape.LINEAR: means the fade shape is linear mode.
     - FadeShape.LOGARITHMIC: means the fade shape is logarithmic mode.
-    - FadeShape.QUARTERSINE: means the fade shape is quarter_sine mode.
+    - FadeShape.EXPONENTIAL: means the fade shape is exponential mode.
     """
+    QUARTER_SINE: str = "quarter_sine"
+    HALF_SINE: str = "half_sine"
     LINEAR: str = "linear"
-    EXPONENTIAL: str = "exponential"
     LOGARITHMIC: str = "logarithmic"
-    QUARTERSINE: str = "quarter_sine"
-    HALFSINE: str = "half_sine"
+    EXPONENTIAL: str = "exponential"


 class GainType(str, Enum):
     """
     Gain Types.

-    Possible enumeration values are: GainType.AMPLITUDE, GainType.DB, GainType.POWER.
+    Possible enumeration values are: GainType.AMPLITUDE, GainType.POWER, GainType.DB.

     - GainType.AMPLITUDE: means input gain type is amplitude.
-    - GainType.DB: means input gain type is decibel.
     - GainType.POWER: means input gain type is power.
+    - GainType.DB: means input gain type is decibel.
     """
-    POWER: str = "power"
     AMPLITUDE: str = "amplitude"
+    POWER: str = "power"
     DB: str = "db"
@@ -85,49 +105,6 @@ class Interpolation(str, Enum):
     QUADRATIC: str = "quadratic"


-class Modulation(str, Enum):
-    """
-    Modulation Type.
-
-    Possible enumeration values are: Modulation.SINUSOIDAL, Modulation.TRIANGULAR.
-
-    - Modulation.SINUSOIDAL: means input modulation type is sinusoidal.
-    - Modulation.TRIANGULAR: means input modulation type is triangular.
-    """
-    SINUSOIDAL: str = "sinusoidal"
-    TRIANGULAR: str = "triangular"
-
-
-class ScaleType(str, Enum):
-    """
-    Scale Types.
-
-    Possible enumeration values are: ScaleType.MAGNITUDE, ScaleType.POWER.
-
-    - ScaleType.MAGNITUDE: means the scale of input audio is magnitude.
-    - ScaleType.POWER: means the scale of input audio is power.
-    """
-    POWER: str = "power"
-    MAGNITUDE: str = "magnitude"
-
-
-class NormType(str, Enum):
-    """
-    Norm Types.
-
-    Possible enumeration values are: NormType.NONE, NormType.SLANEY.
-
-    - NormType.NONE: norm the input data with none.
-    - NormType.SLANEY: norm the input data with slaney.
-    """
-    NONE: str = "none"
-    SLANEY: str = "slaney"
-
-
-DE_C_NORMTYPE_TYPE = {NormType.NONE: cde.NormType.DE_NORMTYPE_NONE,
-                      NormType.SLANEY: cde.NormType.DE_NORMTYPE_SLANEY}
-
-
 class MelType(str, Enum):
     """
     Mel Types.
@@ -141,8 +118,121 @@ class MelType(str, Enum):
     SLANEY: str = "slaney"


-DE_C_MELTYPE_TYPE = {MelType.HTK: cde.MelType.DE_MELTYPE_HTK,
-                     MelType.SLANEY: cde.MelType.DE_MELTYPE_SLANEY}
+class Modulation(str, Enum):
+    """
+    Modulation Type.
+
+    Possible enumeration values are: Modulation.SINUSOIDAL, Modulation.TRIANGULAR.
+
+    - Modulation.SINUSOIDAL: means input modulation type is sinusoidal.
+    - Modulation.TRIANGULAR: means input modulation type is triangular.
+    """
+    SINUSOIDAL: str = "sinusoidal"
+    TRIANGULAR: str = "triangular"
+
+
+class NormMode(str, Enum):
+    """
+    Norm Modes.
+
+    Possible enumeration values are: NormMode.ORTHO, NormMode.NONE.
+
+    - NormMode.ORTHO: means the mode of input audio is ortho.
+    - NormMode.NONE: means the mode of input audio is none.
+    """
+    ORTHO: str = "ortho"
+    NONE: str = "none"
+
+
+class NormType(str, Enum):
+    """
+    Norm Types.
+
+    Possible enumeration values are: NormType.SLANEY, NormType.NONE.
+
+    - NormType.SLANEY: norm the input data with slaney.
+    - NormType.NONE: norm the input data with none.
+    """
+    SLANEY: str = "slaney"
+    NONE: str = "none"
+
+
+class ScaleType(str, Enum):
+    """
+    Scale Types.
+
+    Possible enumeration values are: ScaleType.POWER, ScaleType.MAGNITUDE.
+
+    - ScaleType.POWER: means the scale of input audio is power.
+    - ScaleType.MAGNITUDE: means the scale of input audio is magnitude.
+    """
+    POWER: str = "power"
+    MAGNITUDE: str = "magnitude"
+
+
+class WindowType(str, Enum):
+    """
+    Window Function types.
+
+    Possible enumeration values are: WindowType.BARTLETT, WindowType.BLACKMAN, WindowType.HAMMING, WindowType.HANN,
+    WindowType.KAISER.
+
+    - WindowType.BARTLETT: means the type of window function is Bartlett.
+    - WindowType.BLACKMAN: means the type of window function is Blackman.
+    - WindowType.HAMMING: means the type of window function is Hamming.
+    - WindowType.HANN: means the type of window function is Hann.
+    - WindowType.KAISER: means the type of window function is Kaiser, currently not supported on macOS.
+    """
+    BARTLETT: str = "bartlett"
+    BLACKMAN: str = "blackman"
+    HAMMING: str = "hamming"
+    HANN: str = "hann"
+    KAISER: str = "kaiser"
+
+
+DE_C_NORM_MODE = {NormMode.ORTHO: cde.NormMode.DE_NORM_MODE_ORTHO,
+                  NormMode.NONE: cde.NormMode.DE_NORM_MODE_NONE}
+
+
+def create_dct(n_mfcc, n_mels, norm=NormMode.NONE):
+    """
+    Create a DCT transformation matrix with shape (n_mels, n_mfcc), normalized depending on norm.
+
+    Args:
+        n_mfcc (int): Number of mfc coefficients to retain, the value must be greater than 0.
+        n_mels (int): Number of mel filterbanks, the value must be greater than 0.
+        norm (NormMode): Normalization mode, can be NormMode.NONE or NormMode.ORTHO (default=NormMode.NONE).
+
+    Returns:
+        numpy.ndarray, the transformation matrix, to be right-multiplied to row-wise data of size (n_mels, n_mfcc).
+
+    Examples:
+        >>> from mindspore.dataset.audio import create_dct
+        >>>
+        >>> dct = create_dct(100, 200, audio.NormMode.NONE)
+    """
+    if not isinstance(n_mfcc, int):
+        raise TypeError("n_mfcc with value {0} is not of type {1}, but got {2}.".format(
+            n_mfcc, int, type(n_mfcc)))
+    if not isinstance(n_mels, int):
+        raise TypeError("n_mels with value {0} is not of type {1}, but got {2}.".format(
+            n_mels, int, type(n_mels)))
+    if not isinstance(norm, NormMode):
+        raise TypeError("norm with value {0} is not of type {1}, but got {2}.".format(
+            norm, NormMode, type(norm)))
+    if n_mfcc <= 0:
+        raise ValueError("n_mfcc must be greater than 0, but got {0}.".format(n_mfcc))
+    if n_mels <= 0:
+        raise ValueError("n_mels must be greater than 0, but got {0}.".format(n_mels))
+    return cde.create_dct(n_mfcc, n_mels, DE_C_NORM_MODE[norm]).as_array()
+
+
+DE_C_MEL_TYPE = {MelType.HTK: cde.MelType.DE_MEL_TYPE_HTK,
+                 MelType.SLANEY: cde.MelType.DE_MEL_TYPE_SLANEY}
+
+DE_C_NORM_TYPE = {NormType.SLANEY: cde.NormType.DE_NORM_TYPE_SLANEY,
+                  NormType.NONE: cde.NormType.DE_NORM_TYPE_NONE}


 def melscale_fbanks(n_freqs, f_min, f_max, n_mels, sample_rate, norm=NormType.NONE, mel_type=MelType.HTK):
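The `create_dct` docstring says the matrix is right-multiplied against row-wise data; concretely (a sketch, input values illustrative):

    import numpy as np
    from mindspore.dataset.audio import create_dct

    n_mels, n_mfcc = 128, 13
    dct_mat = create_dct(n_mfcc, n_mels)         # shape (n_mels, n_mfcc)
    mel_frames = np.random.random([40, n_mels])  # hypothetical (time, n_mels) input
    mfcc = mel_frames @ dct_mat                  # -> (time, n_mfcc)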
@@ -162,7 +252,9 @@ def melscale_fbanks(n_freqs, f_min, f_max, n_mels, sample_rate, norm=NormType.NONE, mel_type=MelType.HTK):
         numpy.ndarray, the frequency transformation matrix.

     Examples:
-        >>> melscale_fbanks = audio.melscale_fbanks(n_freqs=4096, f_min=0, f_max=8000, n_mels=40, sample_rate=16000)
+        >>> from mindspore.dataset.audio import melscale_fbanks
+        >>>
+        >>> fbanks = melscale_fbanks(n_freqs=4096, f_min=0, f_max=8000, n_mels=40, sample_rate=16000)
     """

     type_check(n_freqs, (int,), "n_freqs")
@@ -185,94 +277,5 @@ def melscale_fbanks(n_freqs, f_min, f_max, n_mels, sample_rate, norm=NormType.NONE, mel_type=MelType.HTK):

     type_check(norm, (NormType,), "norm")
     type_check(mel_type, (MelType,), "mel_type")
-    return cde.MelscaleFbanks(n_freqs, f_min, f_max, n_mels, sample_rate, DE_C_NORMTYPE_TYPE[norm],
-                              DE_C_MELTYPE_TYPE[mel_type]).as_array()
-
-
-class NormMode(str, Enum):
-    """
-    Norm Types.
-
-    Possible enumeration values are: NormMode.NONE, NormMode.ORTHO.
-
-    - NormMode.NONE: means the mode of input audio is none.
-    - NormMode.ORTHO: means the mode of input audio is ortho.
-    """
-    NONE: str = "none"
-    ORTHO: str = "ortho"
-
-
-DE_C_NORMMODE_TYPE = {NormMode.NONE: cde.NormMode.DE_NORMMODE_NONE,
-                      NormMode.ORTHO: cde.NormMode.DE_NORMMODE_ORTHO}
-
-
-def CreateDct(n_mfcc, n_mels, norm=NormMode.NONE):
-    """
-    Create a DCT transformation matrix with shape (n_mels, n_mfcc), normalized depending on norm.
-
-    Args:
-        n_mfcc (int): Number of mfc coefficients to retain, the value must be greater than 0.
-        n_mels (int): Number of mel filterbanks, the value must be greater than 0.
-        norm (NormMode): Normalization mode, can be NormMode.NONE or NormMode.ORTHO (default=NormMode.NONE).
-
-    Returns:
-        numpy.ndarray, the transformation matrix, to be right-multiplied to row-wise data of size (n_mels, n_mfcc).
-
-    Examples:
-        >>> dct = audio.CreateDct(100, 200, audio.NormMode.NONE)
-    """
-    if not isinstance(n_mfcc, int):
-        raise TypeError("n_mfcc with value {0} is not of type {1}, but got {2}.".format(
-            n_mfcc, int, type(n_mfcc)))
-    if not isinstance(n_mels, int):
-        raise TypeError("n_mels with value {0} is not of type {1}, but got {2}.".format(
-            n_mels, int, type(n_mels)))
-    if not isinstance(norm, NormMode):
-        raise TypeError("norm with value {0} is not of type {1}, but got {2}.".format(
-            norm, NormMode, type(norm)))
-    if n_mfcc <= 0:
-        raise ValueError("n_mfcc must be greater than 0, but got {0}.".format(n_mfcc))
-    if n_mels <= 0:
-        raise ValueError("n_mels must be greater than 0, but got {0}.".format(n_mels))
-    return cde.CreateDct(n_mfcc, n_mels, DE_C_NORMMODE_TYPE[norm]).as_array()
-
-
-class BorderType(str, Enum):
-    """
-    Padding Mode, BorderType Type.
-
-    Possible enumeration values are: BorderType.CONSTANT, BorderType.EDGE, BorderType.REFLECT, BorderType.SYMMETRIC.
-
-    - BorderType.CONSTANT: means it fills the border with constant values.
-    - BorderType.EDGE: means it pads with the last value on the edge.
-    - BorderType.REFLECT: means it reflects the values on the edge omitting the last value of edge.
-    - BorderType.SYMMETRIC: means it reflects the values on the edge repeating the last value of edge.
-
-    Note: This class derived from class str to support json serializable.
-    """
-    CONSTANT: str = "constant"
-    EDGE: str = "edge"
-    REFLECT: str = "reflect"
-    SYMMETRIC: str = "symmetric"
-
-
-class WindowType(str, Enum):
-    """
-    Window Function types,
-
-    Possible enumeration values are: WindowType.BARTLETT, WindowType.BLACKMAN, WindowType.HAMMING, WindowType.HANN,
-    WindowType.KAISER.
-
-    - WindowType.BARTLETT: means the type of window function is bartlett.
-    - WindowType.BLACKMAN: means the type of window function is blackman.
-    - WindowType.HAMMING: means the type of window function is hamming.
-    - WindowType.HANN: means the type of window function is hann.
-    - WindowType.KAISER: means the type of window function is kaiser.
-        Currently kaiser window is not supported on macOS.
-    """
-    BARTLETT: str = "bartlett"
-    BLACKMAN: str = "blackman"
-    HAMMING: str = "hamming"
-    HANN: str = "hann"
-    KAISER: str = "kaiser"
+    return cde.melscale_fbanks(n_freqs, f_min, f_max, n_mels, sample_rate, DE_C_NORM_TYPE[norm],
+                               DE_C_MEL_TYPE[mel_type]).as_array()
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2022 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -24,10 +24,14 @@ from .validators import check_callback

 class DSCallback:
     """
-    Abstract base class used to build a dataset callback class.
+    Abstract base class used to build dataset callback classes.
+
+    Users can obtain the dataset pipeline context through `ds_run_context`, including
+    `cur_epoch_num`, `cur_step_num_in_epoch` and `cur_step_num`.

     Args:
-        step_size (int, optional): The number of steps between the step_begin and step_end are called (Default=1).
+        step_size (int, optional): The number of steps between adjacent `ds_step_begin`/`ds_step_end`
+            calls. Default: 1, will be called at each step.

     Examples:
         >>> from mindspore.dataset import DSCallback
@@ -37,7 +41,7 @@ class DSCallback:
        ...     print(cb_params.cur_epoch_num)
        ...     print(cb_params.cur_step_num)
        >>>
-        >>> # dataset is an instance of Dataset object
+        >>> # dataset is a MindSpore dataset object and op is a data processing operator
        >>> dataset = dataset.map(operations=op, callbacks=PrintInfo())
    """

@@ -50,7 +54,7 @@ class DSCallback:
        Called before the data pipeline is started.

        Args:
-            ds_run_context (RunContext): Include some information of the pipeline.
+            ds_run_context (RunContext): Include some information of the data pipeline.
        """

    def ds_epoch_begin(self, ds_run_context):
@@ -58,7 +62,7 @@ class DSCallback:
        Called before a new epoch is started.

        Args:
-            ds_run_context (RunContext): Include some information of the pipeline.
+            ds_run_context (RunContext): Include some information of the data pipeline.
        """

    def ds_epoch_end(self, ds_run_context):
@@ -66,28 +70,28 @@ class DSCallback:
        Called after an epoch is finished.

        Args:
-            ds_run_context (RunContext): Include some information of the pipeline.
+            ds_run_context (RunContext): Include some information of the data pipeline.
        """

    def ds_step_begin(self, ds_run_context):
        """
-        Called before each step start.
+        Called before a step starts.

        Args:
-            ds_run_context (RunContext): Include some information of the pipeline.
+            ds_run_context (RunContext): Include some information of the data pipeline.
        """

    def ds_step_end(self, ds_run_context):
        """
-        Called after each step finished.
+        Called after a step finishes.

        Args:
-            ds_run_context (RunContext): Include some information of the pipeline.
+            ds_run_context (RunContext): Include some information of the data pipeline.
        """

    def create_runtime_obj(self):
        """
-        Creates a runtime (C++) object from the callback methods defined by the user.
+        Internal method, creates a runtime (C++) object from the callback methods defined by the user.

        Returns:
            _c_dataengine.PyDSCallback.
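A minimal sketch tying `step_size` to the renamed step hooks (a hypothetical logging callback, using only fields documented in this diff):

    from mindspore.dataset import DSCallback

    class LogEveryN(DSCallback):
        """Prints pipeline progress; with step_size=10 the step hooks fire every 10 steps."""

        def __init__(self, step_size=10):
            super().__init__(step_size)

        def ds_step_begin(self, ds_run_context):
            print("epoch", ds_run_context.cur_epoch_num, "step", ds_run_context.cur_step_num)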
@@ -122,24 +126,93 @@ class DSCallback:

 class WaitedDSCallback(Callback, DSCallback):
     """
-    Abstract base class used to build a dataset callback class that is synchronized with the training callback.
+    Abstract base class used to build dataset callback classes that are synchronized with the training callback class
+    `mindspore.train.callback <https://mindspore.cn/docs/api/en/master/api_python/
+    mindspore.train.html#mindspore.train.callback.Callback>`_.

-    This class can be used to execute a user defined logic right after the previous step or epoch.
-    For example, one augmentation needs the loss from the previous trained epoch to update some of its parameters.
+    It can be used to execute a custom callback method before a step or an epoch, such as
+    updating the parameters of operators according to the loss of the previous training epoch in auto augmentation.
+
+    Note that the call is triggered only at the beginning of the second step or epoch.
+
+    Users can obtain the network training context through `train_run_context`, such as
+    `network`, `train_network`, `epoch_num`, `batch_num`, `loss_fn`, `optimizer`, `parallel_mode`,
+    `device_number`, `list_callback`, `cur_epoch_num`, `cur_step_num`, `dataset_sink_mode`,
+    `net_outputs`, etc., see
+    `mindspore.train.callback <https://mindspore.cn/docs/api/en/master/api_python/
+    mindspore.train.html#mindspore.train.callback.Callback>`_.
+
+    Users can obtain the dataset pipeline context through `ds_run_context`, including
+    `cur_epoch_num`, `cur_step_num_in_epoch` and `cur_step_num`.

     Args:
-        step_size (int, optional): The number of rows in each step. Usually the step size
-            will be equal to the batch size (Default=1).
+        step_size (int, optional): The number of rows in each step, usually set equal to the batch size. Default: 1.

     Examples:
+        >>> import mindspore.nn as nn
         >>> from mindspore.dataset import WaitedDSCallback
+        >>> from mindspore import context
+        >>> from mindspore.train import Model
+        >>> from mindspore.train.callback import Callback
         >>>
-        >>> my_cb = WaitedDSCallback(32)
-        >>> # dataset is an instance of Dataset object
-        >>> dataset = dataset.map(operations=AugOp(), callbacks=my_cb)
-        >>> dataset = dataset.batch(32)
-        >>> # define the model
-        >>> model.train(epochs, data, callbacks=[my_cb])
+        >>> context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
+        >>>
+        >>> # custom callback class for data synchronization in data pipeline
+        >>> class MyWaitedCallback(WaitedDSCallback):
+        ...     def __init__(self, events, step_size=1):
+        ...         super().__init__(step_size)
+        ...         self.events = events
+        ...
+        ...     # callback method to be executed by data pipeline before the epoch starts
+        ...     def sync_epoch_begin(self, train_run_context, ds_run_context):
+        ...         event = f"ds_epoch_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
+        ...         self.events.append(event)
+        ...
+        ...     # callback method to be executed by data pipeline before the step starts
+        ...     def sync_step_begin(self, train_run_context, ds_run_context):
+        ...         event = f"ds_step_begin_{ds_run_context.cur_epoch_num}_{ds_run_context.cur_step_num}"
+        ...         self.events.append(event)
+        >>>
+        >>> # custom callback class for data synchronization in network training
+        >>> class MyMSCallback(Callback):
+        ...     def __init__(self, events):
+        ...         self.events = events
+        ...
+        ...     # callback method to be executed by network training after the epoch ends
+        ...     def epoch_end(self, run_context):
+        ...         cb_params = run_context.original_args()
+        ...         event = f"ms_epoch_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
+        ...         self.events.append(event)
+        ...
+        ...     # callback method to be executed by network training after the step ends
+        ...     def step_end(self, run_context):
+        ...         cb_params = run_context.original_args()
+        ...         event = f"ms_step_end_{cb_params.cur_epoch_num}_{cb_params.cur_step_num}"
+        ...         self.events.append(event)
+        >>>
+        >>> # custom network
+        >>> class Net(nn.Cell):
+        ...     def construct(self, x, y):
+        ...         return x
+        >>>
+        >>> # define a parameter that needs to be synchronized between data pipeline and network training
+        >>> events = []
+        >>>
+        >>> # define callback classes of data pipeline and network training
+        >>> my_cb1 = MyWaitedCallback(events, 1)
+        >>> my_cb2 = MyMSCallback(events)
+        >>> arr = [1, 2, 3, 4]
+        >>>
+        >>> # construct data pipeline
+        >>> data = ds.NumpySlicesDataset((arr, arr), column_names=["c1", "c2"], shuffle=False)
+        >>> # map the data callback object into the pipeline
+        >>> data = data.map(operations=(lambda x: x), callbacks=my_cb1)
+        >>>
+        >>> net = Net()
+        >>> model = Model(net)
+        >>>
+        >>> # add the data and network callback objects to the model training callback list
+        >>> model.train(2, data, dataset_sink_mode=False, callbacks=[my_cb2, my_cb1])
     """

     def __init__(self, step_size=1):
@@ -159,7 +232,7 @@ class WaitedDSCallback(Callback, DSCallback):

         Args:
             train_run_context: Include some information of the model with feedback from the previous epoch.
-            ds_run_context: Include some information of the dataset pipeline.
+            ds_run_context: Include some information of the data pipeline.
         """

     def sync_step_begin(self, train_run_context, ds_run_context):
@@ -168,7 +241,7 @@ class WaitedDSCallback(Callback, DSCallback):

         Args:
             train_run_context: Include some information of the model with feedback from the previous step.
-            ds_run_context: Include some information of the dataset pipeline.
+            ds_run_context: Include some information of the data pipeline.
         """

     def epoch_end(self, run_context):
@@ -183,10 +256,11 @@ class WaitedDSCallback(Callback, DSCallback):

     def ds_epoch_begin(self, ds_run_context):
         """
-        Internal method, do not call/override. Defines ds_epoch_begin of DSCallback to wait for MS epoch_end callback.
+        Internal method, do not call/override. Define mindspore.dataset.DSCallback.ds_epoch_begin
+        to wait for mindspore.train.callback.Callback.epoch_end.

         Args:
-            ds_run_context: Include some information of the pipeline.
+            ds_run_context: Include some information of the data pipeline.
         """
         if ds_run_context.cur_epoch_num > 1:
             if not self.training_ended:
@@ -209,10 +283,11 @@ class WaitedDSCallback(Callback, DSCallback):

     def ds_step_begin(self, ds_run_context):
         """
-        Internal method, do not call/override. Defines ds_step_begin of DSCallback to wait for MS step_end callback.
+        Internal method, do not call/override. Define mindspore.dataset.DSCallback.ds_step_begin
+        to wait for mindspore.train.callback.Callback.step_end.

         Args:
-            ds_run_context: Include some information of the pipeline.
+            ds_run_context: Include some information of the data pipeline.
         """
         if ds_run_context.cur_step_num > self.step_size:
             if not self.training_ended:
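The two guards above show the synchronization pattern: the dataset thread blocks until the training thread signals that the previous epoch or step has ended, and `end` unconditionally releases the wait so the pipeline cannot hang. A conceptual sketch of that handshake using a plain `threading.Event`, as a model of the behavior rather than MindSpore's actual internals:

    import threading

    class EpochHandshake:
        """Conceptual model of WaitedDSCallback: the dataset thread waits
        until the training thread has finished the previous epoch."""

        def __init__(self):
            self.epoch_done = threading.Event()
            self.training_ended = False

        # runs on the training thread, mirrors Callback.epoch_end
        def epoch_end(self):
            self.epoch_done.set()          # release the waiting pipeline

        # runs on the dataset thread, mirrors DSCallback.ds_epoch_begin
        def ds_epoch_begin(self, cur_epoch_num):
            if cur_epoch_num > 1 and not self.training_ended:
                self.epoch_done.wait()     # block until epoch_end fired
                self.epoch_done.clear()    # re-arm for the next epoch

        # runs when training ends, mirrors Callback.end
        def end(self):
            self.training_ended = True
            self.epoch_done.set()          # never leave the pipeline stuck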
@@ -225,7 +300,7 @@ class WaitedDSCallback(Callback, DSCallback):

     def create_runtime_obj(self):
         """
-        Creates a runtime (C++) object from the callback methods defined by the user. This method is internal.
+        Internal method, creates a runtime (C++) object from the callback methods defined by the user.

         Returns:
             _c_dataengine.PyDSCallback.
@@ -249,7 +324,7 @@ class WaitedDSCallback(Callback, DSCallback):

     def end(self, run_context):
         """
-        Internal method, release the wait if training is ended.
+        Internal method, release wait when the network training ends.

         Args:
             run_context: Include some information of the model.
@@ -410,7 +410,7 @@ class CoNLL2000Dataset(SourceDataset, TextBaseDataset):

     Examples:
         >>> conll2000_dataset_dir = "/path/to/conll2000_dataset_dir"
-        >>> dataset = ds.CoNLL2000Dataset(dataset_files=conll2000_dataset_dir, usage='all')
+        >>> dataset = ds.CoNLL2000Dataset(dataset_dir=conll2000_dataset_dir, usage='all')
     """

     @check_conll2000_dataset
@@ -786,7 +786,7 @@ class IWSLT2016Dataset(SourceDataset, TextBaseDataset):

     Examples:
         >>> iwslt2016_dataset_dir = "/path/to/iwslt2016_dataset_dir"
-        >>> dataset = ds.IWSLT2016Dataset(dataset_files=iwslt2016_dataset_dir, usage='all',
+        >>> dataset = ds.IWSLT2016Dataset(dataset_dir=iwslt2016_dataset_dir, usage='all',
         ...                               language_pair=('de', 'en'), valid_set='tst2013', test_set='tst2014')

     About IWSLT2016 dataset:
@@ -907,7 +907,7 @@ class IWSLT2017Dataset(SourceDataset, TextBaseDataset):

     Examples:
         >>> iwslt2017_dataset_dir = "/path/to/iwslt207_dataset_dir"
-        >>> dataset = ds.IWSLT2017Dataset(dataset_files=iwslt2017_dataset_dir, usage='all', language_pair=('de', 'en'))
+        >>> dataset = ds.IWSLT2017Dataset(dataset_dir=iwslt2017_dataset_dir, usage='all', language_pair=('de', 'en'))

     About IWSLT2017 dataset:
@@ -1092,7 +1092,7 @@ class SogouNewsDataset(SourceDataset, TextBaseDataset):

     Examples:
         >>> sogou_news_dataset_dir = "/path/to/sogou_news_dataset_dir"
-        >>> dataset = ds.SogouNewsDataset(dataset_files=sogou_news_dataset_dir, usage='all')
+        >>> dataset = ds.SogouNewsDataset(dataset_dir=sogou_news_dataset_dir, usage='all')

     About SogouNews Dataset:
@@ -1234,7 +1234,7 @@ class UDPOSDataset(SourceDataset, TextBaseDataset):

     Examples:
         >>> udpos_dataset_dir = "/path/to/udpos_dataset_dir"
-        >>> dataset = ds.UDPOSDataset(dataset_files=udpos_dataset_dir, usage='all')
+        >>> dataset = ds.UDPOSDataset(dataset_dir=udpos_dataset_dir, usage='all')
     """

     @check_udpos_dataset
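Each of the five text-dataset hunks above makes the same correction: the docstring examples previously passed `dataset_files`, while these constructors take `dataset_dir`. A before/after sketch, with a placeholder path:

    import mindspore.dataset as ds

    conll2000_dataset_dir = "/path/to/conll2000_dataset_dir"
    # before this commit the example (incorrectly) read:
    #     ds.CoNLL2000Dataset(dataset_files=conll2000_dataset_dir, usage='all')
    dataset = ds.CoNLL2000Dataset(dataset_dir=conll2000_dataset_dir, usage='all')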
@@ -97,27 +97,27 @@ DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBO
 DE_C_SLICE_MODE = {SliceMode.PAD: cde.SliceMode.DE_SLICE_PAD,
                    SliceMode.DROP: cde.SliceMode.DE_SLICE_DROP}

-DE_C_CONVERTCOLOR_MODE = {ConvertMode.COLOR_BGR2BGRA: cde.ConvertMode.DE_COLOR_BGR2BGRA,
-                          ConvertMode.COLOR_RGB2RGBA: cde.ConvertMode.DE_COLOR_RGB2RGBA,
-                          ConvertMode.COLOR_BGRA2BGR: cde.ConvertMode.DE_COLOR_BGRA2BGR,
-                          ConvertMode.COLOR_RGBA2RGB: cde.ConvertMode.DE_COLOR_RGBA2RGB,
-                          ConvertMode.COLOR_BGR2RGBA: cde.ConvertMode.DE_COLOR_BGR2RGBA,
-                          ConvertMode.COLOR_RGB2BGRA: cde.ConvertMode.DE_COLOR_RGB2BGRA,
-                          ConvertMode.COLOR_RGBA2BGR: cde.ConvertMode.DE_COLOR_RGBA2BGR,
-                          ConvertMode.COLOR_BGRA2RGB: cde.ConvertMode.DE_COLOR_BGRA2RGB,
-                          ConvertMode.COLOR_BGR2RGB: cde.ConvertMode.DE_COLOR_BGR2RGB,
-                          ConvertMode.COLOR_RGB2BGR: cde.ConvertMode.DE_COLOR_RGB2BGR,
-                          ConvertMode.COLOR_BGRA2RGBA: cde.ConvertMode.DE_COLOR_BGRA2RGBA,
-                          ConvertMode.COLOR_RGBA2BGRA: cde.ConvertMode.DE_COLOR_RGBA2BGRA,
-                          ConvertMode.COLOR_BGR2GRAY: cde.ConvertMode.DE_COLOR_BGR2GRAY,
-                          ConvertMode.COLOR_RGB2GRAY: cde.ConvertMode.DE_COLOR_RGB2GRAY,
-                          ConvertMode.COLOR_GRAY2BGR: cde.ConvertMode.DE_COLOR_GRAY2BGR,
-                          ConvertMode.COLOR_GRAY2RGB: cde.ConvertMode.DE_COLOR_GRAY2RGB,
-                          ConvertMode.COLOR_GRAY2BGRA: cde.ConvertMode.DE_COLOR_GRAY2BGRA,
-                          ConvertMode.COLOR_GRAY2RGBA: cde.ConvertMode.DE_COLOR_GRAY2RGBA,
-                          ConvertMode.COLOR_BGRA2GRAY: cde.ConvertMode.DE_COLOR_BGRA2GRAY,
-                          ConvertMode.COLOR_RGBA2GRAY: cde.ConvertMode.DE_COLOR_RGBA2GRAY,
-                          }
+DE_C_CONVERT_COLOR_MODE = {ConvertMode.COLOR_BGR2BGRA: cde.ConvertMode.DE_COLOR_BGR2BGRA,
+                           ConvertMode.COLOR_RGB2RGBA: cde.ConvertMode.DE_COLOR_RGB2RGBA,
+                           ConvertMode.COLOR_BGRA2BGR: cde.ConvertMode.DE_COLOR_BGRA2BGR,
+                           ConvertMode.COLOR_RGBA2RGB: cde.ConvertMode.DE_COLOR_RGBA2RGB,
+                           ConvertMode.COLOR_BGR2RGBA: cde.ConvertMode.DE_COLOR_BGR2RGBA,
+                           ConvertMode.COLOR_RGB2BGRA: cde.ConvertMode.DE_COLOR_RGB2BGRA,
+                           ConvertMode.COLOR_RGBA2BGR: cde.ConvertMode.DE_COLOR_RGBA2BGR,
+                           ConvertMode.COLOR_BGRA2RGB: cde.ConvertMode.DE_COLOR_BGRA2RGB,
+                           ConvertMode.COLOR_BGR2RGB: cde.ConvertMode.DE_COLOR_BGR2RGB,
+                           ConvertMode.COLOR_RGB2BGR: cde.ConvertMode.DE_COLOR_RGB2BGR,
+                           ConvertMode.COLOR_BGRA2RGBA: cde.ConvertMode.DE_COLOR_BGRA2RGBA,
+                           ConvertMode.COLOR_RGBA2BGRA: cde.ConvertMode.DE_COLOR_RGBA2BGRA,
+                           ConvertMode.COLOR_BGR2GRAY: cde.ConvertMode.DE_COLOR_BGR2GRAY,
+                           ConvertMode.COLOR_RGB2GRAY: cde.ConvertMode.DE_COLOR_RGB2GRAY,
+                           ConvertMode.COLOR_GRAY2BGR: cde.ConvertMode.DE_COLOR_GRAY2BGR,
+                           ConvertMode.COLOR_GRAY2RGB: cde.ConvertMode.DE_COLOR_GRAY2RGB,
+                           ConvertMode.COLOR_GRAY2BGRA: cde.ConvertMode.DE_COLOR_GRAY2BGRA,
+                           ConvertMode.COLOR_GRAY2RGBA: cde.ConvertMode.DE_COLOR_GRAY2RGBA,
+                           ConvertMode.COLOR_BGRA2GRAY: cde.ConvertMode.DE_COLOR_BGRA2GRAY,
+                           ConvertMode.COLOR_RGBA2GRAY: cde.ConvertMode.DE_COLOR_RGBA2GRAY,
+                           }


 def parse_padding(padding):
@@ -165,6 +165,7 @@ class AdjustGamma(ImageTensorOperation):
         >>> image_folder_dataset = image_folder_dataset.map(operations=transforms_list,
         ...                                                 input_columns=["image"])
     """

+    @check_adjust_gamma
     def __init__(self, gamma, gain=1):
         self.gamma = gamma
@@ -426,12 +427,13 @@ class ConvertColor(ImageTensorOperation):
         >>> image_folder_dataset_1 = image_folder_dataset_1.map(operations=convert_op,
         ...                                                     input_columns=["image"])
     """

+    @check_convert_color
     def __init__(self, convert_mode):
         self.convert_mode = convert_mode

     def parse(self):
-        return cde.ConvertColorOperation(DE_C_CONVERTCOLOR_MODE[self.convert_mode])
+        return cde.ConvertColorOperation(DE_C_CONVERT_COLOR_MODE[self.convert_mode])


 class Crop(ImageTensorOperation):
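As a quick orientation for the two hunks above, a sketch of how `ConvertColor` is driven by `ConvertMode` (the input array is fabricated purely for illustration; the renamed `DE_C_CONVERT_COLOR_MODE` dict is what `parse` uses to translate the enum into the C++ backend value):

    import numpy as np
    import mindspore.dataset as ds
    import mindspore.dataset.vision.c_transforms as c_vision
    from mindspore.dataset.vision.utils import ConvertMode

    # a fake 4x4 BGR image, just to exercise the operator
    image = np.random.randint(0, 255, (4, 4, 3), dtype=np.uint8)
    dataset = ds.NumpySlicesDataset([image], column_names=["image"], shuffle=False)

    convert_op = c_vision.ConvertColor(ConvertMode.COLOR_BGR2RGB)
    dataset = dataset.map(operations=convert_op, input_columns=["image"])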
@@ -136,7 +136,7 @@ class SliceMode(IntEnum):
     DROP = 1


-class AutoAugmentPolicy(IntEnum):
+class AutoAugmentPolicy(str, Enum):
     """
     AutoAugment policy for different datasets.
@@ -195,6 +195,6 @@ class AutoAugmentPolicy(IntEnum):
         (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
         (("ShearX", 0.7, 2), ("Invert", 0.1, None))]
     """
-    IMAGENET = 0
-    CIFAR10 = 1
-    SVHN = 2
+    IMAGENET: str = "imagenet"
+    CIFAR10: str = "cifar10"
+    SVHN: str = "svhn"
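The switch from `IntEnum` to a `str`-backed `Enum` means policy members now compare equal to their lowercase string names, which is convenient when a policy arrives as plain text. A small standalone sketch of the Python semantics involved (not MindSpore code):

    from enum import Enum

    class AutoAugmentPolicy(str, Enum):
        IMAGENET = "imagenet"
        CIFAR10 = "cifar10"
        SVHN = "svhn"

    assert AutoAugmentPolicy.IMAGENET == "imagenet"                   # str comparison now works
    assert AutoAugmentPolicy("cifar10") is AutoAugmentPolicy.CIFAR10  # lookup by value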
@@ -15,7 +15,7 @@
 import numpy as np
 import pytest

-import mindspore.dataset.audio.utils as audio
+from mindspore.dataset.audio import create_dct, NormMode
 from mindspore import log as logger
@@ -40,7 +40,7 @@ def test_create_dct_none():
                        [2.00000000, 0.76536685],
                        [2.00000000, -0.76536703],
                        [2.00000000, -1.84775925]], dtype=np.float64)
-    output = audio.CreateDct(2, 4, audio.NormMode.NONE)
+    output = create_dct(2, 4, NormMode.NONE)
     count_unequal_element(expect, output, 0.0001, 0.0001)
@@ -50,7 +50,7 @@ def test_create_dct_ortho():
     Description: test CreateDct in eager mode
     Expectation: the returned result is as expected
     """
-    output = audio.CreateDct(1, 3, audio.NormMode.ORTHO)
+    output = create_dct(1, 3, NormMode.ORTHO)
     expect = np.array([[0.57735026],
                        [0.57735026],
                        [0.57735026]], dtype=np.float64)
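The expected arrays in these two tests follow from the DCT-II matrix that `create_dct` returns, shaped `(n_mels, n_mfcc)`. A numpy sketch that reproduces the values above, assuming the usual torchaudio-style definition (this is my reading of the tests, not code from the commit):

    import numpy as np

    def dct_matrix(n_mfcc, n_mels, ortho=False):
        n = np.arange(n_mels)[:, None]          # mel-bin index
        k = np.arange(n_mfcc)[None, :]          # coefficient index
        mat = 2.0 * np.cos(np.pi * (n + 0.5) * k / n_mels)
        if ortho:
            mat[:, 0] *= 1.0 / np.sqrt(2.0)     # special scaling for k = 0
            mat *= np.sqrt(1.0 / (2.0 * n_mels))
        return mat

    print(dct_matrix(2, 4))          # column 0 is all 2.0; column 1 is +/-0.7653..., +/-1.8477...
    print(dct_matrix(1, 3, True))    # all entries 0.57735026 = 1/sqrt(3)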
@@ -66,24 +66,24 @@ def test_createdct_invalid_input():
     def test_invalid_input(test_name, n_mfcc, n_mels, norm, error, error_msg):
         logger.info("Test CreateDct with bad input: {0}".format(test_name))
         with pytest.raises(error) as error_info:
-            audio.CreateDct(n_mfcc, n_mels, norm)
+            create_dct(n_mfcc, n_mels, norm)
         assert error_msg in str(error_info.value)

-    test_invalid_input("invalid n_mfcc parameter type as a float", 100.5, 200, audio.NormMode.NONE, TypeError,
+    test_invalid_input("invalid n_mfcc parameter type as a float", 100.5, 200, NormMode.NONE, TypeError,
                        "n_mfcc with value 100.5 is not of type <class 'int'>, but got <class 'float'>.")
-    test_invalid_input("invalid n_mfcc parameter type as a String", "100", 200, audio.NormMode.NONE, TypeError,
+    test_invalid_input("invalid n_mfcc parameter type as a String", "100", 200, NormMode.NONE, TypeError,
                        "n_mfcc with value 100 is not of type <class 'int'>, but got <class 'str'>.")
-    test_invalid_input("invalid n_mels parameter type as a String", 100, "200", audio.NormMode.NONE, TypeError,
+    test_invalid_input("invalid n_mels parameter type as a String", 100, "200", NormMode.NONE, TypeError,
                        "n_mels with value 200 is not of type <class 'int'>, but got <class 'str'>.")
-    test_invalid_input("invalid n_mels parameter type as a String", 0, 200, audio.NormMode.NONE, ValueError,
+    test_invalid_input("invalid n_mels parameter type as a String", 0, 200, NormMode.NONE, ValueError,
                        "n_mfcc must be greater than 0, but got 0.")
-    test_invalid_input("invalid n_mels parameter type as a String", 100, 0, audio.NormMode.NONE, ValueError,
+    test_invalid_input("invalid n_mels parameter type as a String", 100, 0, NormMode.NONE, ValueError,
                        "n_mels must be greater than 0, but got 0.")
-    test_invalid_input("invalid n_mels parameter type as a String", -100, 200, audio.NormMode.NONE, ValueError,
+    test_invalid_input("invalid n_mels parameter type as a String", -100, 200, NormMode.NONE, ValueError,
                        "n_mfcc must be greater than 0, but got -100.")
-    test_invalid_input("invalid n_mfcc parameter value", None, 100, audio.NormMode.NONE, TypeError,
+    test_invalid_input("invalid n_mfcc parameter value", None, 100, NormMode.NONE, TypeError,
                        "n_mfcc with value None is not of type <class 'int'>, but got <class 'NoneType'>.")
-    test_invalid_input("invalid n_mels parameter value", 100, None, audio.NormMode.NONE, TypeError,
+    test_invalid_input("invalid n_mels parameter value", 100, None, NormMode.NONE, TypeError,
                        "n_mels with value None is not of type <class 'int'>, but got <class 'NoneType'>.")
     test_invalid_input("invalid n_mels parameter value", 100, 200, "None", TypeError,
                        "norm with value None is not of type <enum 'NormMode'>, but got <class 'str'>.")
@@ -101,7 +101,7 @@ def test_fade_quarter_sine():
                          [5, 7, 3, 78, 8, 4],
                          [1, 2, 3, 4, 5, 6]]], dtype=np.float64)
     dataset = ds.NumpySlicesDataset(data=waveform, column_names='audio', shuffle=False)
-    transforms = [audio.Fade(fade_in_len=6, fade_out_len=6, fade_shape=FadeShape.QUARTERSINE)]
+    transforms = [audio.Fade(fade_in_len=6, fade_out_len=6, fade_shape=FadeShape.QUARTER_SINE)]
     dataset = dataset.map(operations=transforms, input_columns=["audio"])

     for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
@@ -124,7 +124,7 @@ def test_fade_half_sine():
                          [0.04125976562500, 0.060577392578125, 0.0499572753906250,
                           0.01306152343750, -0.019683837890625, -0.018829345703125]]]
     dataset = ds.NumpySlicesDataset(data=waveform, column_names='audio', shuffle=False)
-    transforms = [audio.Fade(fade_in_len=3, fade_out_len=3, fade_shape=FadeShape.HALFSINE)]
+    transforms = [audio.Fade(fade_in_len=3, fade_out_len=3, fade_shape=FadeShape.HALF_SINE)]
    dataset = dataset.map(operations=transforms, input_columns=["audio"])

     for item in dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
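For intuition about the renamed fade shapes, here is how the two fade-in envelopes are commonly defined, in a small numpy sketch following the torchaudio-style conventions that these transforms mirror; treat the exact formulas as assumptions, not as the commit's code:

    import numpy as np

    def fade_in(n, shape):
        t = np.linspace(0.0, 1.0, n)
        if shape == "quarter_sine":            # rises over a quarter period of sin
            return np.sin(t * np.pi / 2.0)
        if shape == "half_sine":               # smooth S-curve over half a period
            return np.sin(t * np.pi - np.pi / 2.0) / 2.0 + 0.5
        raise ValueError(shape)

    waveform = np.ones(6)
    print(waveform * fade_in(6, "quarter_sine"))
    print(waveform * fade_in(6, "half_sine"))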