fix codex on 1.6
This commit is contained in:
parent
0e1dcb50db
commit
96cec6da4f
|
@ -673,7 +673,8 @@ Status Execute::Run(const std::vector<std::shared_ptr<dataset::Execute>> &data_g
|
|||
extern "C" {
|
||||
// ExecuteRun_C has C-linkage specified, but returns user-defined type 'mindspore::Status' which is incompatible with C
|
||||
void ExecuteRun_C(const std::vector<std::shared_ptr<dataset::Execute>> &data_graph,
|
||||
std::vector<mindspore::MSTensor> &inputs, std::vector<mindspore::MSTensor> *outputs, Status *s) {
|
||||
const std::vector<mindspore::MSTensor> &inputs, std::vector<mindspore::MSTensor> *outputs,
|
||||
Status *s) {
|
||||
Status ret = Execute::Run(data_graph, inputs, outputs);
|
||||
if (s == nullptr) {
|
||||
return;
|
||||
|
|
|
@ -36,7 +36,7 @@ class AllpassBiquadOperation : public TensorOperation {
|
|||
public:
|
||||
AllpassBiquadOperation(int32_t sample_rate, float central_freq, float Q);
|
||||
|
||||
~AllpassBiquadOperation() = default;
|
||||
~AllpassBiquadOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ class AmplitudeToDBOperation : public TensorOperation {
|
|||
public:
|
||||
AmplitudeToDBOperation(ScaleType stype, float ref_value, float amin, float top_db);
|
||||
|
||||
~AmplitudeToDBOperation();
|
||||
~AmplitudeToDBOperation() override;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ class AngleOperation : public TensorOperation {
|
|||
public:
|
||||
AngleOperation();
|
||||
|
||||
~AngleOperation() = default;
|
||||
~AngleOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ class BandBiquadOperation : public TensorOperation {
|
|||
public:
|
||||
BandBiquadOperation(int32_t sample_rate, float central_freq, float Q, bool noise);
|
||||
|
||||
~BandBiquadOperation() = default;
|
||||
~BandBiquadOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ class BandpassBiquadOperation : public TensorOperation {
|
|||
public:
|
||||
BandpassBiquadOperation(int32_t sample_rate, float central_freq, float Q, bool const_skirt_gain);
|
||||
|
||||
~BandpassBiquadOperation() = default;
|
||||
~BandpassBiquadOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ class BandrejectBiquadOperation : public TensorOperation {
|
|||
public:
|
||||
BandrejectBiquadOperation(int32_t sample_rate, float central_freq, float Q);
|
||||
|
||||
~BandrejectBiquadOperation() = default;
|
||||
~BandrejectBiquadOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ class BassBiquadOperation : public TensorOperation {
|
|||
public:
|
||||
BassBiquadOperation(int32_t sample_rate, float gain, float central_freq, float Q);
|
||||
|
||||
~BassBiquadOperation() = default;
|
||||
~BassBiquadOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ class BiquadOperation : public TensorOperation {
|
|||
public:
|
||||
BiquadOperation(float b0, float b1, float b2, float a0, float a1, float a2);
|
||||
|
||||
~BiquadOperation() = default;
|
||||
~BiquadOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ class ComplexNormOperation : public TensorOperation {
|
|||
public:
|
||||
explicit ComplexNormOperation(float power);
|
||||
|
||||
~ComplexNormOperation();
|
||||
~ComplexNormOperation() override;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ class ComputeDeltasOperation : public TensorOperation {
|
|||
public:
|
||||
ComputeDeltasOperation(int32_t win_length, BorderType pad_mode);
|
||||
|
||||
~ComputeDeltasOperation() = default;
|
||||
~ComputeDeltasOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ class ContrastOperation : public TensorOperation {
|
|||
public:
|
||||
explicit ContrastOperation(float enhancement_amount);
|
||||
|
||||
~ContrastOperation() = default;
|
||||
~ContrastOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ class DBToAmplitudeOperation : public TensorOperation {
|
|||
public:
|
||||
DBToAmplitudeOperation(float ref, float power);
|
||||
|
||||
~DBToAmplitudeOperation() = default;
|
||||
~DBToAmplitudeOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ class DCShiftOperation : public TensorOperation {
|
|||
public:
|
||||
DCShiftOperation(float shift, float limiter_gain);
|
||||
|
||||
~DCShiftOperation() = default;
|
||||
~DCShiftOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ class DeemphBiquadOperation : public TensorOperation {
|
|||
public:
|
||||
explicit DeemphBiquadOperation(int32_t sample_rate);
|
||||
|
||||
~DeemphBiquadOperation() = default;
|
||||
~DeemphBiquadOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ class DetectPitchFrequencyOperation : public TensorOperation {
|
|||
DetectPitchFrequencyOperation(int32_t sample_rate, float frame_time, int32_t win_length, int32_t freq_low,
|
||||
int32_t freq_high);
|
||||
|
||||
~DetectPitchFrequencyOperation() = default;
|
||||
~DetectPitchFrequencyOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ class DitherOperation : public TensorOperation {
|
|||
public:
|
||||
DitherOperation(DensityFunction density_function, bool noise_shaping);
|
||||
|
||||
~DitherOperation() = default;
|
||||
~DitherOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ class EqualizerBiquadOperation : public TensorOperation {
|
|||
public:
|
||||
EqualizerBiquadOperation(int32_t sample_rate, float center_freq, float gain, float Q);
|
||||
|
||||
~EqualizerBiquadOperation() = default;
|
||||
~EqualizerBiquadOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ class FadeOperation : public TensorOperation {
|
|||
public:
|
||||
FadeOperation(int32_t fade_in_len, int32_t fade_out_len, FadeShape fade_shape);
|
||||
|
||||
~FadeOperation() = default;
|
||||
~FadeOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ class FlangerOperation : public TensorOperation {
|
|||
explicit FlangerOperation(int32_t sample_rate, float delay, float depth, float regen, float width, float speed,
|
||||
float phase, Modulation modulation, Interpolation interpolation);
|
||||
|
||||
~FlangerOperation() = default;
|
||||
~FlangerOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ class FrequencyMaskingOperation : public TensorOperation {
|
|||
public:
|
||||
FrequencyMaskingOperation(bool iid_masks, int32_t frequency_mask_param, int32_t mask_start, float mask_value);
|
||||
|
||||
~FrequencyMaskingOperation();
|
||||
~FrequencyMaskingOperation() override;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
@ -43,9 +43,9 @@ class FrequencyMaskingOperation : public TensorOperation {
|
|||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
private:
|
||||
bool iid_masks_;
|
||||
int32_t frequency_mask_param_;
|
||||
int32_t mask_start_;
|
||||
bool iid_masks_;
|
||||
float mask_value_;
|
||||
}; // class FrequencyMaskingOperation
|
||||
} // namespace audio
|
||||
|
|
|
@ -35,7 +35,7 @@ class GainOperation : public TensorOperation {
|
|||
public:
|
||||
explicit GainOperation(float gain_db);
|
||||
|
||||
~GainOperation() = default;
|
||||
~GainOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ class HighpassBiquadOperation : public TensorOperation {
|
|||
public:
|
||||
HighpassBiquadOperation(int32_t sample_rate, float cutoff_freq, float Q);
|
||||
|
||||
~HighpassBiquadOperation() = default;
|
||||
~HighpassBiquadOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ class LFilterOperation : public TensorOperation {
|
|||
public:
|
||||
LFilterOperation(const std::vector<float> &a_coeffs, const std::vector<float> &b_coeffs, bool clamp);
|
||||
|
||||
~LFilterOperation() = default;
|
||||
~LFilterOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ class LowpassBiquadOperation : public TensorOperation {
|
|||
public:
|
||||
LowpassBiquadOperation(int32_t sample_rate, float cutoff_freq, float Q);
|
||||
|
||||
~LowpassBiquadOperation() = default;
|
||||
~LowpassBiquadOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ class MagphaseOperation : public TensorOperation {
|
|||
public:
|
||||
explicit MagphaseOperation(float power);
|
||||
|
||||
~MagphaseOperation() = default;
|
||||
~MagphaseOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ class MuLawDecodingOperation : public TensorOperation {
|
|||
public:
|
||||
explicit MuLawDecodingOperation(int32_t quantization_channels);
|
||||
|
||||
~MuLawDecodingOperation();
|
||||
~MuLawDecodingOperation() override;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ class MuLawEncodingOperation : public TensorOperation {
|
|||
public:
|
||||
explicit MuLawEncodingOperation(int32_t quantization_channels);
|
||||
|
||||
~MuLawEncodingOperation();
|
||||
~MuLawEncodingOperation() override;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ class OverdriveOperation : public TensorOperation {
|
|||
public:
|
||||
explicit OverdriveOperation(float gain, float color);
|
||||
|
||||
~OverdriveOperation() = default;
|
||||
~OverdriveOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -38,7 +38,7 @@ class PhaserOperation : public TensorOperation {
|
|||
PhaserOperation(int32_t sample_rate, float gain_in, float gain_out, float delay_ms, float decay, float mod_speed,
|
||||
bool sinusoidal);
|
||||
|
||||
~PhaserOperation() = default;
|
||||
~PhaserOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ class RiaaBiquadOperation : public TensorOperation {
|
|||
public:
|
||||
explicit RiaaBiquadOperation(int32_t sample_rate);
|
||||
|
||||
~RiaaBiquadOperation() = default;
|
||||
~RiaaBiquadOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ class SlidingWindowCmnOperation : public TensorOperation {
|
|||
public:
|
||||
SlidingWindowCmnOperation(int32_t cmn_window, int32_t min_cmn_window, bool center, bool norm_vars);
|
||||
|
||||
~SlidingWindowCmnOperation();
|
||||
~SlidingWindowCmnOperation() override;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ class SpectralCentroidOperation : public TensorOperation {
|
|||
SpectralCentroidOperation(int32_t sample_rate, int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad,
|
||||
WindowType window);
|
||||
|
||||
~SpectralCentroidOperation() = default;
|
||||
~SpectralCentroidOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ class SpectrogramOperation : public TensorOperation {
|
|||
SpectrogramOperation(int32_t n_fft, int32_t win_length, int32_t hop_length, int32_t pad, WindowType window,
|
||||
float power, bool normalized, bool center, BorderType pad_mode, bool onesided);
|
||||
|
||||
~SpectrogramOperation() = default;
|
||||
~SpectrogramOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ class TimeMaskingOperation : public TensorOperation {
|
|||
public:
|
||||
TimeMaskingOperation(bool iid_masks, int32_t time_mask_param, int32_t mask_start, float mask_value);
|
||||
|
||||
~TimeMaskingOperation();
|
||||
~TimeMaskingOperation() override;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -32,7 +32,7 @@ class TimeStretchOperation : public TensorOperation {
|
|||
public:
|
||||
TimeStretchOperation(float hop_length, int n_freq, float fixed_rate);
|
||||
|
||||
~TimeStretchOperation();
|
||||
~TimeStretchOperation() override;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ class TrebleBiquadOperation : public TensorOperation {
|
|||
public:
|
||||
TrebleBiquadOperation(int32_t sample_rate, float gain, float central_freq, float Q);
|
||||
|
||||
~TrebleBiquadOperation() = default;
|
||||
~TrebleBiquadOperation() override = default;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ class VolOperation : public TensorOperation {
|
|||
public:
|
||||
VolOperation(float gain, GainType gain_type);
|
||||
|
||||
~VolOperation();
|
||||
~VolOperation() override;
|
||||
|
||||
std::shared_ptr<TensorOp> Build() override;
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ namespace dataset {
|
|||
/// \param[out] output - Tensor has n points with linearly space. The spacing between the points is (end-start)/(n-1).
|
||||
/// \return Status return code.
|
||||
template <typename T>
|
||||
Status Linspace(std::shared_ptr<Tensor> *output, T start, T end, int n) {
|
||||
Status Linspace(std::shared_ptr<Tensor> *output, T start, T end, int32_t n) {
|
||||
RETURN_IF_NOT_OK(ValidateNoGreaterThan("Linspace", "start", start, "end", end));
|
||||
n = std::isnan(n) ? 100 : n;
|
||||
TensorShape out_shape({n});
|
||||
|
@ -61,7 +61,7 @@ Status ComplexAngle(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor
|
|||
TensorShape input_shape = input->shape();
|
||||
TensorShape out_shape({input_shape[0], input_shape[1], input_shape[2]});
|
||||
std::vector<T> phase(input_shape[0] * input_shape[1] * input_shape[2]);
|
||||
int ind = 0;
|
||||
size_t ind = 0;
|
||||
|
||||
for (auto itr = input->begin<T>(); itr != input->end<T>(); itr++, ind++) {
|
||||
auto x = (*itr);
|
||||
|
@ -89,7 +89,7 @@ Status ComplexAbs(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor>
|
|||
TensorShape input_shape = input->shape();
|
||||
TensorShape out_shape({input_shape[0], input_shape[1], input_shape[2]});
|
||||
std::vector<T> abs(input_shape[0] * input_shape[1] * input_shape[2]);
|
||||
int ind = 0;
|
||||
size_t ind = 0;
|
||||
for (auto itr = input->begin<T>(); itr != input->end<T>(); itr++, ind++) {
|
||||
T x = (*itr);
|
||||
itr++;
|
||||
|
@ -121,7 +121,7 @@ Status Polar(const std::shared_ptr<Tensor> &abs, const std::shared_ptr<Tensor> &
|
|||
TensorShape input_shape = abs->shape();
|
||||
TensorShape out_shape({input_shape[0], input_shape[1], input_shape[2], 2});
|
||||
std::vector<T> complex_vec(input_shape[0] * input_shape[1] * input_shape[2] * 2);
|
||||
int ind = 0;
|
||||
size_t ind = 0;
|
||||
auto itr_abs = abs->begin<T>();
|
||||
auto itr_angle = angle->begin<T>();
|
||||
|
||||
|
@ -143,7 +143,7 @@ Status Polar(const std::shared_ptr<Tensor> &abs, const std::shared_ptr<Tensor> &
|
|||
/// \param[out] output - Complex tensor, <channel, freq, time, complex=2>.
|
||||
/// \return Status return code.
|
||||
template <typename T>
|
||||
Status PadComplexTensor(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int length, int dim) {
|
||||
Status PadComplexTensor(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, int length, size_t dim) {
|
||||
TensorShape input_shape = input->shape();
|
||||
std::vector<int64_t> pad_shape_vec = {input_shape[0], input_shape[1], input_shape[2], input_shape[3]};
|
||||
pad_shape_vec[dim] += static_cast<int64_t>(length);
|
||||
|
@ -153,7 +153,7 @@ Status PadComplexTensor(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
|
|||
auto itr_input = input->begin<T>();
|
||||
int64_t input_cnt = 0;
|
||||
/*lint -e{446} ind is modified in the body of the for loop */
|
||||
for (int ind = 0; ind < static_cast<int>(in_vect.size()); ind++) {
|
||||
for (auto ind = 0; ind < static_cast<int>(in_vect.size()); ind++) {
|
||||
in_vect[ind] = (*itr_input);
|
||||
input_cnt = (input_cnt + 1) % (input_shape[2] * input_shape[3]);
|
||||
itr_input++;
|
||||
|
@ -199,7 +199,7 @@ Status Phase(const std::shared_ptr<Tensor> &angle_0, const std::shared_ptr<Tenso
|
|||
}
|
||||
|
||||
// concat phase time 0
|
||||
int64_t ind = 0;
|
||||
size_t ind = 0;
|
||||
auto itr_p0 = phase_time0->begin<T>();
|
||||
(void)phase.insert(phase.begin(), (*itr_p0));
|
||||
itr_p0++;
|
||||
|
@ -235,7 +235,7 @@ Status Mag(const std::shared_ptr<Tensor> &abs_0, const std::shared_ptr<Tensor> &
|
|||
std::vector<T> mag(mag_shape[0] * mag_shape[1] * mag_shape[2]);
|
||||
auto itr_abs_0 = abs_0->begin<T>();
|
||||
auto itr_abs_1 = abs_1->begin<T>();
|
||||
for (int ind = 0; itr_abs_0 != abs_0->end<T>(); itr_abs_0++, itr_abs_1++, ind++) {
|
||||
for (auto ind = 0; itr_abs_0 != abs_0->end<T>(); itr_abs_0++, itr_abs_1++, ind++) {
|
||||
mag[ind] = alphas[ind % mag_shape[2]] * (*itr_abs_1) + (1 - alphas[ind % mag_shape[2]]) * (*itr_abs_0);
|
||||
}
|
||||
std::shared_ptr<Tensor> mag_tensor;
|
||||
|
@ -246,7 +246,7 @@ Status Mag(const std::shared_ptr<Tensor> &abs_0, const std::shared_ptr<Tensor> &
|
|||
|
||||
template <typename T>
|
||||
Status TimeStretch(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, float rate,
|
||||
std::shared_ptr<Tensor> phase_advance) {
|
||||
const std::shared_ptr<Tensor> &phase_advance) {
|
||||
// pack <..., freq, time, complex>
|
||||
TensorShape input_shape = input->shape();
|
||||
TensorShape toShape({input->Size() / (input_shape[-1] * input_shape[-2] * input_shape[-3]), input_shape[-3],
|
||||
|
@ -260,11 +260,11 @@ Status TimeStretch(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *outpu
|
|||
std::vector<dsize_t> time_steps_0, time_steps_1;
|
||||
std::vector<T> alphas;
|
||||
for (int ind = 0;; ind++) {
|
||||
auto val = ind * rate;
|
||||
T val = static_cast<float>(ind) * rate;
|
||||
if (val >= input_shape[-2]) {
|
||||
break;
|
||||
}
|
||||
int val_int = static_cast<int>(val);
|
||||
auto val_int = static_cast<dsize_t>(val);
|
||||
time_steps_0.push_back(val_int);
|
||||
time_steps_1.push_back(val_int + 1);
|
||||
alphas.push_back(fmod(val, 1));
|
||||
|
@ -319,7 +319,7 @@ Status TimeStretch(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *outpu
|
|||
}
|
||||
|
||||
Status TimeStretch(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, float rate, float hop_length,
|
||||
float n_freq) {
|
||||
int32_t n_freq) {
|
||||
std::shared_ptr<Tensor> phase_advance;
|
||||
switch (input->type().value()) {
|
||||
case DataType::DE_FLOAT32:
|
||||
|
@ -340,15 +340,15 @@ Status Dct(std::shared_ptr<Tensor> *output, int n_mfcc, int n_mels, NormMode nor
|
|||
TensorShape dct_shape({n_mels, n_mfcc});
|
||||
Tensor::CreateEmpty(dct_shape, DataType(DataType::DE_FLOAT32), output);
|
||||
auto iter = (*output)->begin<float>();
|
||||
const float sqrt_2 = 1 / sqrt(2);
|
||||
float sqrt_2_n_mels = sqrt(2.0 / n_mels);
|
||||
const float sqrt_2 = 1 / sqrt(2.0f);
|
||||
auto sqrt_2_n_mels = static_cast<float>(sqrt(2.0 / n_mels));
|
||||
for (int i = 0; i < n_mels; i++) {
|
||||
for (int j = 0; j < n_mfcc; j++) {
|
||||
// calculate temp:
|
||||
// 1. while norm = None, use 2*cos(PI*(i+0.5)*j/n_mels)
|
||||
// 2. while norm = Ortho, divide the first row by sqrt(2),
|
||||
// then using sqrt(2.0 / n_mels)*cos(PI*(i+0.5)*j/n_mels)
|
||||
float temp = PI / n_mels * (i + 0.5) * j;
|
||||
auto temp = static_cast<float>(PI / n_mels * (i + 0.5) * j);
|
||||
temp = cos(temp);
|
||||
if (norm == NormMode::kOrtho) {
|
||||
if (j == 0) {
|
||||
|
@ -400,11 +400,11 @@ Status MaskAlongAxis(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tenso
|
|||
std::to_string(mask_start) + ", 'mask_width' " + std::to_string(mask_width) + " and length " +
|
||||
std::to_string(input_shape[check_dim_ind]));
|
||||
|
||||
int32_t cell_size = input->type().SizeInBytes();
|
||||
size_t cell_size = input->type().SizeInBytes();
|
||||
|
||||
if (axis == 1) {
|
||||
// freq
|
||||
for (int ind = 0; ind < input->Size() / input_shape[-2] * mask_width; ind++) {
|
||||
for (auto ind = 0; ind < input->Size() / input_shape[-2] * mask_width; ind++) {
|
||||
int block_num = ind / (mask_width * input_shape[-1]);
|
||||
auto start_pos = ind % (mask_width * input_shape[-1]) + mask_start * input_shape[-1] +
|
||||
input_shape[-1] * input_shape[-2] * block_num;
|
||||
|
@ -448,7 +448,7 @@ template <typename T>
|
|||
Status Norm(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, float power) {
|
||||
// calculate the output dimension
|
||||
auto input_size = input->shape().AsVector();
|
||||
int32_t dim_back = static_cast<int32_t>(input_size.back());
|
||||
auto dim_back = input_size.back();
|
||||
RETURN_IF_NOT_OK(
|
||||
ValidateTensorShape("ComplexNorm", input->IsComplex(), "<..., complex=2>", std::to_string(dim_back)));
|
||||
input_size.pop_back();
|
||||
|
@ -489,8 +489,8 @@ Status ComplexNorm(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor>
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
float sgn(T val) {
|
||||
return static_cast<float>(static_cast<T>(0) < val) - static_cast<float>(val < static_cast<T>(0));
|
||||
inline float sgn(T val) {
|
||||
return (val > 0) ? 1 : ((val < 0) ? -1 : 0);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -642,10 +642,10 @@ Status Fade(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *outpu
|
|||
int32_t fade_out_len, FadeShape fade_shape) {
|
||||
RETURN_IF_NOT_OK(Tensor::CreateFromTensor(input, output));
|
||||
const TensorShape input_shape = input->shape();
|
||||
int32_t waveform_length = static_cast<int32_t>(input_shape[-1]);
|
||||
auto waveform_length = static_cast<int32_t>(input_shape[-1]);
|
||||
RETURN_IF_NOT_OK(ValidateNoGreaterThan("Fade", "fade_in_len", fade_in_len, "length of waveform", waveform_length));
|
||||
RETURN_IF_NOT_OK(ValidateNoGreaterThan("Fade", "fade_out_len", fade_out_len, "length of waveform", waveform_length));
|
||||
int32_t num_waveform = static_cast<int32_t>(input->Size() / waveform_length);
|
||||
auto num_waveform = static_cast<int32_t>(input->Size() / waveform_length);
|
||||
TensorShape toShape = TensorShape({num_waveform, waveform_length});
|
||||
RETURN_IF_NOT_OK((*output)->Reshape(toShape));
|
||||
TensorPtr fade_in;
|
||||
|
@ -792,7 +792,7 @@ Status GenerateWaveTable(std::shared_ptr<Tensor> *output, const DataType &type,
|
|||
RETURN_UNEXPECTED_IF_NULL(output);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(table_size > 0,
|
||||
"table_size must be more than 0, but got: " + std::to_string(table_size));
|
||||
int32_t phase_offset = static_cast<int32_t>(phase / PI / 2 * table_size + 0.5);
|
||||
auto phase_offset = static_cast<int32_t>(phase / PI / 2 * table_size + 0.5);
|
||||
// get the offset of the i-th
|
||||
std::vector<int32_t> point;
|
||||
for (auto i = 0; i < table_size; i++) {
|
||||
|
@ -814,7 +814,7 @@ Status GenerateWaveTable(std::shared_ptr<Tensor> *output, const DataType &type,
|
|||
// change phase
|
||||
*iter = point[i] * 2.0 / table_size;
|
||||
// get complete offset
|
||||
int32_t value = static_cast<int>(4 * point[i] / table_size);
|
||||
auto value = static_cast<int>(4 * point[i] / table_size);
|
||||
// change the value of the square wave according to the number of complete offsets
|
||||
if (value == 0) {
|
||||
*iter = *iter + 0.5;
|
||||
|
@ -859,7 +859,7 @@ Status ReadWaveFile(const std::string &wav_file_dir, std::vector<float> *wavefor
|
|||
std::ifstream in(file_path.ToString(), std::ios::in | std::ios::binary);
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(in.is_open(), "Invalid file, failed to open waveform file: " + file_path.ToString() +
|
||||
", make sure the file not damaged or permission denied.");
|
||||
WavHeader *header = new WavHeader();
|
||||
auto *header = new WavHeader();
|
||||
in.read(reinterpret_cast<char *>(header), sizeof(WavHeader));
|
||||
*sample_rate = header->sample_rate;
|
||||
float bytes_per_sample = header->bits_per_sample / 8;
|
||||
|
@ -1012,7 +1012,7 @@ Status SlidingWindowCmnHelper(const std::shared_ptr<Tensor> &input, std::shared_
|
|||
RETURN_IF_NOT_OK(
|
||||
input->Reshape(TensorShape({static_cast<int>(first_index / (num_frames * num_feats)), num_frames, num_feats})));
|
||||
|
||||
int32_t num_channels = static_cast<int32_t>(input->shape()[0]);
|
||||
auto num_channels = static_cast<int32_t>(input->shape()[0]);
|
||||
TensorPtr cmn_waveform;
|
||||
RETURN_IF_NOT_OK(
|
||||
Tensor::CreateEmpty(TensorShape({num_channels, num_frames, num_feats}), input->type(), &cmn_waveform));
|
||||
|
@ -1059,7 +1059,7 @@ Status Pad(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output
|
|||
RETURN_IF_NOT_OK(ValidateNonNegative("Pad", "pad_right", pad_right));
|
||||
TensorShape input_shape = input->shape();
|
||||
int32_t wave_length = input_shape[-1];
|
||||
int32_t num_wavs = static_cast<int32_t>(input->Size() / wave_length);
|
||||
auto num_wavs = static_cast<int32_t>(input->Size() / wave_length);
|
||||
TensorShape to_shape = TensorShape({num_wavs, wave_length});
|
||||
RETURN_IF_NOT_OK(input->Reshape(to_shape));
|
||||
int32_t pad_length = wave_length + pad_left + pad_right;
|
||||
|
@ -1553,7 +1553,7 @@ Status SpectralCentroidImpl(const std::shared_ptr<Tensor> &input, std::shared_pt
|
|||
}
|
||||
}
|
||||
specgram.push_back(tmp);
|
||||
specgram_sum.push_back(specgram[k].colwise().sum());
|
||||
specgram_sum.emplace_back(specgram[k].colwise().sum());
|
||||
}
|
||||
for (int k = 0; k < k_num; k++) {
|
||||
for (int i = 0; i < channals; ++i) {
|
||||
|
@ -1561,7 +1561,7 @@ Status SpectralCentroidImpl(const std::shared_ptr<Tensor> &input, std::shared_pt
|
|||
tmp(i, j) = freqs_r(i, 0) * specgram[k](i, j);
|
||||
}
|
||||
}
|
||||
specgram_result.push_back((tmp).colwise().sum());
|
||||
specgram_result.emplace_back((tmp).colwise().sum());
|
||||
}
|
||||
auto itr_output = output_tensor->begin<T>();
|
||||
for (int k = 0; k < k_num; k++) {
|
||||
|
|
|
@ -56,8 +56,8 @@ Status AmplitudeToDB(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tenso
|
|||
RETURN_IF_NOT_OK(input->Reshape(to_shape));
|
||||
|
||||
std::vector<T> max_val;
|
||||
int step = to_shape[-3] * input_shape[-2] * input_shape[-1];
|
||||
int cnt = 0;
|
||||
uint64_t step = to_shape[-3] * input_shape[-2] * input_shape[-1];
|
||||
uint64_t cnt = 0;
|
||||
T temp_max = std::numeric_limits<T>::lowest();
|
||||
for (auto itr = input->begin<T>(); itr != input->end<T>(); itr++) {
|
||||
// do clamp
|
||||
|
@ -73,10 +73,10 @@ Status AmplitudeToDB(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tenso
|
|||
}
|
||||
|
||||
if (!std::isnan(top_db)) {
|
||||
int ind = 0;
|
||||
uint64_t ind = 0;
|
||||
for (auto itr = input->begin<T>(); itr != input->end<T>(); itr++, ind++) {
|
||||
float lower_bound = max_val[ind / step] - top_db;
|
||||
*itr = std::max((*itr), static_cast<T>(lower_bound));
|
||||
T lower_bound = max_val[ind / step] - top_db;
|
||||
*itr = std::max((*itr), lower_bound);
|
||||
}
|
||||
}
|
||||
RETURN_IF_NOT_OK(input->Reshape(input_shape));
|
||||
|
@ -147,10 +147,9 @@ Status Contrast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *o
|
|||
RETURN_IF_NOT_OK(Tensor::CreateEmpty(output_shape, input->type(), &out));
|
||||
auto itr_out = out->begin<T>();
|
||||
for (auto itr_in = input->begin<T>(); itr_in != input->end<T>(); itr_in++) {
|
||||
T temp1, temp2 = 0;
|
||||
// PI / 2 is half of the constant PI
|
||||
temp1 = static_cast<T>(*itr_in) * (PI / TWO);
|
||||
temp2 = enhancement_amount_value * std::sin(temp1 * 4);
|
||||
T temp1 = static_cast<T>(*itr_in) * (PI / TWO);
|
||||
T temp2 = enhancement_amount_value * std::sin(temp1 * 4);
|
||||
*itr_out = std::sin(temp1 + temp2);
|
||||
itr_out++;
|
||||
}
|
||||
|
@ -242,8 +241,8 @@ Status LFilter(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *ou
|
|||
TensorShape input_shape = input->shape();
|
||||
TensorShape toShape({input->Size() / input_shape[-1], input_shape[-1]});
|
||||
input->Reshape(toShape);
|
||||
auto shape_0 = input->shape()[0];
|
||||
auto shape_1 = input->shape()[1];
|
||||
auto shape_0 = static_cast<size_t>(input->shape()[0]);
|
||||
auto shape_1 = static_cast<size_t>(input->shape()[1]);
|
||||
std::vector<T> signal;
|
||||
std::shared_ptr<Tensor> out;
|
||||
std::vector<T> out_vect(shape_0 * shape_1);
|
||||
|
@ -355,7 +354,7 @@ Status SpectralCentroid(const std::shared_ptr<Tensor> &input, std::shared_ptr<Te
|
|||
/// \param output: Tensor after stretch in time domain.
|
||||
/// \return Status code.
|
||||
Status TimeStretch(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, float rate, float hop_length,
|
||||
float n_freq);
|
||||
int32_t n_freq);
|
||||
|
||||
/// \brief Apply a mask along axis.
|
||||
/// \param input: Tensor of shape <..., freq, time>.
|
||||
|
|
|
@ -29,20 +29,20 @@ Status LFilterOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<
|
|||
} else if (input->type() == DataType(DataType::DE_FLOAT64)) {
|
||||
std::vector<double> a_coeffs_double;
|
||||
std::vector<double> b_coeffs_double;
|
||||
for (int i = 0; i < a_coeffs_.size(); i++) {
|
||||
for (auto i = 0; i < a_coeffs_.size(); i++) {
|
||||
a_coeffs_double.push_back(static_cast<double>(a_coeffs_[i]));
|
||||
}
|
||||
for (int i = 0; i < b_coeffs_.size(); i++) {
|
||||
for (auto i = 0; i < b_coeffs_.size(); i++) {
|
||||
b_coeffs_double.push_back(static_cast<double>(b_coeffs_[i]));
|
||||
}
|
||||
return LFilter(input, output, a_coeffs_double, b_coeffs_double, clamp_);
|
||||
} else {
|
||||
std::vector<float16> a_coeffs_float16;
|
||||
std::vector<float16> b_coeffs_float16;
|
||||
for (int i = 0; i < a_coeffs_.size(); i++) {
|
||||
for (auto i = 0; i < a_coeffs_.size(); i++) {
|
||||
a_coeffs_float16.push_back(static_cast<float16>(a_coeffs_[i]));
|
||||
}
|
||||
for (int i = 0; i < b_coeffs_.size(); i++) {
|
||||
for (auto i = 0; i < b_coeffs_.size(); i++) {
|
||||
b_coeffs_float16.push_back(static_cast<float16>(b_coeffs_[i]));
|
||||
}
|
||||
return LFilter(input, output, a_coeffs_float16, b_coeffs_float16, clamp_);
|
||||
|
|
|
@ -29,18 +29,18 @@ namespace dataset {
|
|||
|
||||
class LFilterOp : public TensorOp {
|
||||
public:
|
||||
LFilterOp(std::vector<float> a_coeffs, std::vector<float> b_coeffs, bool clamp)
|
||||
LFilterOp(const std::vector<float> &a_coeffs, const std::vector<float> &b_coeffs, bool clamp)
|
||||
: a_coeffs_(a_coeffs), b_coeffs_(b_coeffs), clamp_(clamp) {}
|
||||
|
||||
~LFilterOp() override = default;
|
||||
|
||||
void Print(std::ostream &out) const override {
|
||||
out << Name() << ": a_coeffs: ";
|
||||
for (int i = 0; i < a_coeffs_.size(); i++) {
|
||||
for (auto i = 0; i < a_coeffs_.size(); i++) {
|
||||
out << a_coeffs_[i] << " ";
|
||||
}
|
||||
out << "b_coeffs: ";
|
||||
for (int i = 0; i < b_coeffs_.size(); i++) {
|
||||
for (auto i = 0; i < b_coeffs_.size(); i++) {
|
||||
out << b_coeffs_[i] << " ";
|
||||
}
|
||||
out << "clamp: " << clamp_ << std::endl;
|
||||
|
|
|
@ -68,7 +68,7 @@ Status TimeStretchOp::OutputShape(const std::vector<TensorShape> &inputs, std::v
|
|||
std::vector<dsize_t> s_vec = s.AsVector();
|
||||
s_vec.pop_back();
|
||||
s_vec.pop_back();
|
||||
s_vec.push_back(std::ceil(s[-2] / fixed_rate_));
|
||||
s_vec.push_back(std::ceil(s[-2] / static_cast<dsize_t>(fixed_rate_)));
|
||||
// push back complex
|
||||
s_vec.push_back(2);
|
||||
outputs.emplace_back(TensorShape(s_vec));
|
||||
|
|
|
@ -75,7 +75,7 @@ class CallbackManager {
|
|||
|
||||
private:
|
||||
bool enabled_; // flag to enable callback, if false, all functions would return immediately
|
||||
DatasetOp *op_; // back pointer to DatasetOp, raw pointer to avoid circular ownership
|
||||
DatasetOp *op_ = nullptr; // back pointer to DatasetOp, raw pointer to avoid circular ownership
|
||||
std::vector<std::shared_ptr<DSCallback>> callbacks_; // list of callbacks the DatasetOp needs to call
|
||||
std::vector<size_t> begin_indices_;
|
||||
std::vector<size_t> end_indices_;
|
||||
|
|
|
@ -34,7 +34,7 @@ class DSCallback {
|
|||
explicit DSCallback(int32_t step_size = 1) : step_size_(step_size) {}
|
||||
|
||||
/// \brief Destructor
|
||||
~DSCallback() = default;
|
||||
virtual ~DSCallback() = default;
|
||||
|
||||
/// \brief actual callback function for begin, needs to be overridden in the derived class
|
||||
/// \param cb_param, callback parameter passed in from DatasetOp when calling the callback
|
||||
|
|
|
@ -42,7 +42,7 @@ class PyDSCallback : public DSCallback {
|
|||
epoch_end_needed_(false),
|
||||
step_end_needed_(false) {}
|
||||
|
||||
~PyDSCallback() = default;
|
||||
~PyDSCallback() override = default;
|
||||
|
||||
void SetBegin(const py::function &f);
|
||||
void SetEnd(const py::function &f);
|
||||
|
@ -128,5 +128,4 @@ class PyDSCallback : public DSCallback {
|
|||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_PY_DS_CALLBACK_H
|
||||
|
|
|
@ -63,7 +63,7 @@ class ConfigManager {
|
|||
|
||||
// Another debug print helper. Converts the print info to a string for you.
|
||||
// @return The string version of the debug print
|
||||
std::string ToString() {
|
||||
std::string ToString() const {
|
||||
std::stringstream ss;
|
||||
ss << *this;
|
||||
return ss.str();
|
||||
|
@ -178,7 +178,7 @@ class ConfigManager {
|
|||
|
||||
// getter function
|
||||
// @return The interval of monitor sampling
|
||||
int32_t monitor_sampling_interval() const { return monitor_sampling_interval_; }
|
||||
uint32_t monitor_sampling_interval() const { return monitor_sampling_interval_; }
|
||||
|
||||
// setter function
|
||||
// @param auto_num_workers - whether assign threads to each op automatically
|
||||
|
@ -200,13 +200,13 @@ class ConfigManager {
|
|||
|
||||
// getter function
|
||||
// @return The timeout DSWaitedCallback would wait for before raising an error
|
||||
int32_t callback_timeout() const { return callback_timout_; }
|
||||
uint32_t callback_timeout() const { return callback_timout_; }
|
||||
|
||||
// getter function
|
||||
// E.g. 0 would corresponds to a 1:1:1 ratio of num_worker among leaf batch and map.
|
||||
// please refer to AutoWorkerPass for detail on what each option is.
|
||||
// @return The experimental config used by AutoNumWorker, each 1 refers to a different setup configuration
|
||||
uint8_t get_auto_worker_config() { return auto_worker_config_; }
|
||||
uint8_t get_auto_worker_config() const { return auto_worker_config_; }
|
||||
|
||||
// setter function
|
||||
// E.g. set the value of 0 would corresponds to a 1:1:1 ratio of num_worker among leaf batch and map.
|
||||
|
@ -220,7 +220,7 @@ class ConfigManager {
|
|||
|
||||
// getter function
|
||||
// @return - Flag to indicate whether shared memory for multi-processing is enabled
|
||||
bool enable_shared_mem() { return enable_shared_mem_; }
|
||||
bool enable_shared_mem() const { return enable_shared_mem_; }
|
||||
|
||||
// setter function
|
||||
// @param offload - To enable automatic offloading of dataset ops
|
||||
|
@ -228,7 +228,7 @@ class ConfigManager {
|
|||
|
||||
// getter function
|
||||
// @return - Flag to indicate whether automatic offloading is enabled for the dataset
|
||||
bool get_auto_offload() { return auto_offload_; }
|
||||
bool get_auto_offload() const { return auto_offload_; }
|
||||
|
||||
// setter function
|
||||
// @param enable - To enable autotune
|
||||
|
@ -236,11 +236,11 @@ class ConfigManager {
|
|||
|
||||
// getter function
|
||||
// @return - Flag to indicate whether autotune is enabled
|
||||
bool enable_autotune() { return enable_autotune_; }
|
||||
bool enable_autotune() const { return enable_autotune_; }
|
||||
|
||||
// getter function
|
||||
// @return - autotune interval in steps
|
||||
int64_t autotune_interval() { return autotune_interval_; }
|
||||
int64_t autotune_interval() const { return autotune_interval_; }
|
||||
|
||||
// setter function
|
||||
// @param interval - autotune interval in steps
|
||||
|
@ -277,5 +277,4 @@ class ConfigManager {
|
|||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_CONFIG_MANAGER_H_
|
||||
|
|
|
@ -32,7 +32,7 @@ class DeviceTensor : public Tensor {
|
|||
public:
|
||||
DeviceTensor(const TensorShape &shape, const DataType &type);
|
||||
|
||||
~DeviceTensor() {}
|
||||
~DeviceTensor() override = default;
|
||||
|
||||
Status SetAttributes(uint8_t *data_ptr, const uint32_t &dataSize, const uint32_t &width, const uint32_t &widthStride,
|
||||
const uint32_t &height, const uint32_t &heightStride);
|
||||
|
|
|
@ -85,10 +85,12 @@ Tensor &Tensor::operator=(Tensor &&other) noexcept {
|
|||
data_ = other.GetMutableBuffer();
|
||||
data_end_ = other.data_end_;
|
||||
data_allocator_ = std::move(other.data_allocator_);
|
||||
yuv_shape_ = other.yuv_shape_;
|
||||
other.Invalidate();
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
Status Tensor::CreateEmpty(const TensorShape &shape, const DataType &type, TensorPtr *out) {
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(shape.known(), "Invalid shape.");
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(type != DataType::DE_UNKNOWN, "Invalid data type.");
|
||||
|
@ -111,6 +113,7 @@ Status Tensor::CreateEmpty(const TensorShape &shape, const DataType &type, Tenso
|
|||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Tensor::CreateFromMemory(const TensorShape &shape, const DataType &type, const uchar *src, TensorPtr *out) {
|
||||
RETURN_IF_NOT_OK(CreateEmpty(shape, type, out));
|
||||
if (src != nullptr && out != nullptr) {
|
||||
|
@ -680,7 +683,7 @@ Status Tensor::to_json(nlohmann::json *out_json) {
|
|||
RETURN_IF_NOT_OK(to_json_convert<double>(&args));
|
||||
} else if (type_ == DataType::DE_STRING) {
|
||||
std::vector<std::string> data_out;
|
||||
for (auto it = this->begin<std::string_view>(); it != this->end<std::string_view>(); it++) {
|
||||
for (auto it = this->begin<std::string_view>(); it != this->end<std::string_view>(); ++it) {
|
||||
data_out.emplace_back(*it);
|
||||
}
|
||||
args["data"] = data_out;
|
||||
|
@ -739,7 +742,8 @@ Status Tensor::from_json(nlohmann::json op_params, std::shared_ptr<Tensor> *tens
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
Status Tensor::from_json_convert(nlohmann::json json_data, TensorShape shape, std::shared_ptr<Tensor> *tensor) {
|
||||
Status Tensor::from_json_convert(const nlohmann::json &json_data, const TensorShape &shape,
|
||||
std::shared_ptr<Tensor> *tensor) {
|
||||
std::vector<T> data = json_data;
|
||||
RETURN_IF_NOT_OK(CreateFromVector(data, shape, tensor));
|
||||
return Status::OK();
|
||||
|
|
|
@ -225,7 +225,8 @@ class Tensor {
|
|||
static Status from_json(nlohmann::json op_params, std::shared_ptr<Tensor> *tensor);
|
||||
|
||||
template <typename T>
|
||||
static Status from_json_convert(nlohmann::json json_data, TensorShape shape, std::shared_ptr<Tensor> *tensor);
|
||||
static Status from_json_convert(const nlohmann::json &json_data, const TensorShape &shape,
|
||||
std::shared_ptr<Tensor> *tensor);
|
||||
|
||||
/// Get item located at `index`, caller needs to provide the type.
|
||||
/// \tparam T
|
||||
|
@ -481,6 +482,9 @@ class Tensor {
|
|||
~TensorIterator() = default;
|
||||
|
||||
TensorIterator<T> &operator=(const TensorIterator<T> &rhs) {
|
||||
if (this == &rhs) {
|
||||
return *this;
|
||||
}
|
||||
ptr_ = rhs.ptr_;
|
||||
return *this;
|
||||
}
|
||||
|
@ -565,7 +569,7 @@ class Tensor {
|
|||
using pointer = std::string_view *;
|
||||
using reference = std::string_view &;
|
||||
|
||||
explicit TensorIterator(uchar *data = nullptr, dsize_t index = 0) {
|
||||
explicit TensorIterator(const uchar *data = nullptr, dsize_t index = 0) {
|
||||
data_ = reinterpret_cast<const char *>(data);
|
||||
index_ = index;
|
||||
}
|
||||
|
@ -795,7 +799,7 @@ inline Status Tensor::CreateFromVector<std::string>(const std::vector<std::strin
|
|||
*out = std::allocate_shared<Tensor>(*alloc, TensorShape({static_cast<dsize_t>(items.size())}),
|
||||
DataType(DataType::DE_STRING));
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(out != nullptr, "Allocate memory failed.");
|
||||
if (items.size() == 0) {
|
||||
if (items.empty()) {
|
||||
if (shape.known()) {
|
||||
return (*out)->Reshape(shape);
|
||||
}
|
||||
|
|
|
@ -195,7 +195,7 @@ class Connector {
|
|||
int32_t expect_consumer_;
|
||||
|
||||
// The index to the queues_ where the next data should be popped.
|
||||
int32_t pop_from_;
|
||||
size_t pop_from_;
|
||||
|
||||
int32_t num_producers_;
|
||||
int32_t num_consumers_;
|
||||
|
|
|
@ -64,23 +64,28 @@ class ExecutionTree {
|
|||
++ind_;
|
||||
return *this;
|
||||
} // prefix ++ overload
|
||||
|
||||
Iterator operator++(int) {
|
||||
Iterator it = *this;
|
||||
it.ind_ = ind_;
|
||||
ind_++;
|
||||
return it;
|
||||
} // post-fix ++ overload
|
||||
|
||||
Iterator &operator--() {
|
||||
--ind_;
|
||||
return *this;
|
||||
} // prefix -- overload
|
||||
|
||||
Iterator operator--(int) {
|
||||
Iterator it = *this;
|
||||
it.ind_ = ind_;
|
||||
ind_--;
|
||||
return it;
|
||||
} // post-fix -- overload
|
||||
|
||||
DatasetOp &operator*() { return *nodes_[ind_]; } // dereference operator
|
||||
|
||||
std::shared_ptr<DatasetOp> operator->() { return nodes_[ind_]; }
|
||||
|
||||
// getter function
|
||||
|
@ -91,10 +96,10 @@ class ExecutionTree {
|
|||
|
||||
bool operator!=(const Iterator &rhs) { return nodes_[ind_] != rhs.nodes_[rhs.ind_]; }
|
||||
|
||||
int32_t NumNodes() { return nodes_.size(); }
|
||||
size_t NumNodes() const { return nodes_.size(); }
|
||||
|
||||
private:
|
||||
int32_t ind_; // the cur node our Iterator points to
|
||||
size_t ind_; // the cur node our Iterator points to
|
||||
std::vector<std::shared_ptr<DatasetOp>> nodes_; // store the nodes in post order
|
||||
void PostOrderTraverse(const std::shared_ptr<DatasetOp> &);
|
||||
};
|
||||
|
@ -140,7 +145,7 @@ class ExecutionTree {
|
|||
/// \param out - reference to the output stream being overloaded
|
||||
/// \param exe_tree - reference to the execution tree to display
|
||||
/// \return - the output stream must be returned
|
||||
friend std::ostream &operator<<(std::ostream &out, ExecutionTree &exe_tree) {
|
||||
friend std::ostream &operator<<(std::ostream &out, const ExecutionTree &exe_tree) {
|
||||
exe_tree.Print(out);
|
||||
return out;
|
||||
}
|
||||
|
@ -164,8 +169,10 @@ class ExecutionTree {
|
|||
/// \return Status The status code returned
|
||||
Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::vector<Task *> *worker_tasks,
|
||||
std::string name = "", int32_t operator_id = -1);
|
||||
|
||||
Status LaunchWorkers(int32_t num_workers, std::function<Status(uint32_t)> func, std::string name = "",
|
||||
int32_t operator_id = -1);
|
||||
|
||||
/// \brief Getter method
|
||||
/// \return shared_ptr to the root operator
|
||||
std::shared_ptr<DatasetOp> root() const { return root_; }
|
||||
|
@ -233,5 +240,4 @@ class ExecutionTree {
|
|||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_EXECUTION_TREE_H_
|
||||
|
|
|
@ -28,7 +28,6 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
namespace gnn {
|
||||
using EdgeType = int8_t;
|
||||
using EdgeIdType = int32_t;
|
||||
|
||||
class Edge {
|
||||
public:
|
||||
|
@ -38,7 +37,8 @@ class Edge {
|
|||
// @param WeightType weight - edge weight
|
||||
// @param std::shared_ptr<Node> src_node - source node
|
||||
// @param std::shared_ptr<Node> dst_node - destination node
|
||||
Edge(EdgeIdType id, EdgeType type, WeightType weight, std::shared_ptr<Node> src_node, std::shared_ptr<Node> dst_node)
|
||||
Edge(EdgeIdType id, EdgeType type, WeightType weight, const std::shared_ptr<Node> &src_node,
|
||||
const std::shared_ptr<Node> &dst_node)
|
||||
: id_(id), type_(type), weight_(weight), src_node_(src_node), dst_node_(dst_node) {}
|
||||
|
||||
virtual ~Edge() = default;
|
||||
|
|
|
@ -44,6 +44,9 @@ struct MetaInfo {
|
|||
|
||||
class GraphData {
|
||||
public:
|
||||
// Destructor
|
||||
virtual ~GraphData() = default;
|
||||
|
||||
// Get all nodes from the graph.
|
||||
// @param NodeType node_type - type of node
|
||||
// @param std::shared_ptr<Tensor> *out - Returned nodes id
|
||||
|
|
|
@ -17,13 +17,13 @@
|
|||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_GNN_GRAPH_DATA_CLIENT_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
#include "proto/gnn_graph_data.grpc.pb.h"
|
||||
|
@ -48,7 +48,7 @@ class GraphDataClient : public GraphData {
|
|||
// @param int32_t num_workers - number of parallel threads
|
||||
GraphDataClient(const std::string &dataset_file, const std::string &hostname, int32_t port);
|
||||
|
||||
~GraphDataClient();
|
||||
~GraphDataClient() override;
|
||||
|
||||
Status Init() override;
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ namespace mindspore {
|
|||
namespace dataset {
|
||||
namespace gnn {
|
||||
|
||||
GraphDataImpl::GraphDataImpl(std::string dataset_file, int32_t num_workers, bool server_mode)
|
||||
GraphDataImpl::GraphDataImpl(const std::string &dataset_file, int32_t num_workers, bool server_mode)
|
||||
: dataset_file_(dataset_file),
|
||||
num_workers_(num_workers),
|
||||
rnd_(GetRandomDevice()),
|
||||
|
@ -38,7 +38,7 @@ GraphDataImpl::GraphDataImpl(std::string dataset_file, int32_t num_workers, bool
|
|||
MS_LOG(INFO) << "num_workers:" << num_workers;
|
||||
}
|
||||
|
||||
GraphDataImpl::~GraphDataImpl() {}
|
||||
GraphDataImpl::~GraphDataImpl() = default;
|
||||
|
||||
Status GraphDataImpl::GetAllNodes(NodeType node_type, std::shared_ptr<Tensor> *out) {
|
||||
RETURN_UNEXPECTED_IF_NULL(out);
|
||||
|
|
|
@ -44,9 +44,9 @@ class GraphDataImpl : public GraphData {
|
|||
// Constructor
|
||||
// @param std::string dataset_file -
|
||||
// @param int32_t num_workers - number of parallel threads
|
||||
GraphDataImpl(std::string dataset_file, int32_t num_workers, bool server_mode = false);
|
||||
GraphDataImpl(const std::string &dataset_file, int32_t num_workers, bool server_mode = false);
|
||||
|
||||
~GraphDataImpl();
|
||||
~GraphDataImpl() override;
|
||||
|
||||
// Get all nodes from the graph.
|
||||
// @param NodeType node_type - type of node
|
||||
|
@ -150,11 +150,11 @@ class GraphDataImpl : public GraphData {
|
|||
Status GraphInfo(py::dict *out) override;
|
||||
#endif
|
||||
|
||||
const std::unordered_map<FeatureType, std::shared_ptr<Feature>> *GetAllDefaultNodeFeatures() {
|
||||
const std::unordered_map<FeatureType, std::shared_ptr<Feature>> *GetAllDefaultNodeFeatures() const {
|
||||
return &default_node_feature_map_;
|
||||
}
|
||||
|
||||
const std::unordered_map<FeatureType, std::shared_ptr<Feature>> *GetAllDefaultEdgeFeatures() {
|
||||
const std::unordered_map<FeatureType, std::shared_ptr<Feature>> *GetAllDefaultEdgeFeatures() const {
|
||||
return &default_edge_feature_map_;
|
||||
}
|
||||
|
||||
|
|
|
@ -48,9 +48,9 @@ class GraphDataServer {
|
|||
Status ClientRegister(int32_t pid);
|
||||
Status ClientUnRegister(int32_t pid);
|
||||
|
||||
enum ServerState state() { return state_; }
|
||||
enum ServerState state() const { return state_; }
|
||||
|
||||
bool IsStopped() {
|
||||
bool IsStopped() const {
|
||||
if (state_ == kGdsStopped) {
|
||||
return true;
|
||||
} else {
|
||||
|
@ -86,7 +86,7 @@ class GraphDataServer {
|
|||
#if !defined(_WIN32) && !defined(_WIN64)
|
||||
class UntypedCall {
|
||||
public:
|
||||
virtual ~UntypedCall() {}
|
||||
virtual ~UntypedCall() = default;
|
||||
|
||||
virtual Status operator()() = 0;
|
||||
|
||||
|
@ -112,7 +112,7 @@ class CallData : public UntypedCall {
|
|||
handle_request_function_(handle_request_function),
|
||||
responder_(&ctx_) {}
|
||||
|
||||
~CallData() = default;
|
||||
~CallData() override = default;
|
||||
|
||||
static Status EnqueueRequest(ServiceImpl *service_impl, AsyncService *async_service, grpc::ServerCompletionQueue *cq,
|
||||
EnqueueFunction enqueue_function, HandleRequestFunction handle_request_function) {
|
||||
|
|
|
@ -50,9 +50,9 @@ class GraphSharedMemory {
|
|||
|
||||
Status GetData(uint8_t *data, int64_t data_len, int64_t offset, int64_t get_data_len);
|
||||
|
||||
key_t memory_key() { return memory_key_; }
|
||||
key_t memory_key() const { return memory_key_; }
|
||||
|
||||
int64_t memory_size() { return memory_size_; }
|
||||
int64_t memory_size() const { return memory_size_; }
|
||||
|
||||
private:
|
||||
Status SharedMemoryImpl(const int &shmflg);
|
||||
|
|
|
@ -36,7 +36,7 @@ class LocalNode : public Node {
|
|||
// @param NodeType type - node type
|
||||
LocalNode(NodeIdType id, NodeType type, WeightType weight);
|
||||
|
||||
~LocalNode() = default;
|
||||
~LocalNode() override = default;
|
||||
|
||||
// Get the feature of a node
|
||||
// @param FeatureType feature_type - type of feature
|
||||
|
|
|
@ -23,17 +23,23 @@
|
|||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore::dataset {
|
||||
|
||||
class DatasetCache {
|
||||
public:
|
||||
virtual ~DatasetCache() = default;
|
||||
|
||||
virtual Status Build() = 0;
|
||||
|
||||
virtual Status ValidateParams() = 0;
|
||||
|
||||
virtual Status CreateCacheOp(int32_t num_workers, int32_t connector_queue_size, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetOp> *ds) = 0;
|
||||
|
||||
virtual Status CreateCacheLookupOp(int32_t num_workers, int32_t connector_queue_size,
|
||||
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetOp> *ds) = 0;
|
||||
|
||||
virtual Status CreateCacheMergeOp(int32_t num_workers, int32_t connector_queue_size,
|
||||
std::shared_ptr<DatasetOp> *ds) = 0;
|
||||
|
||||
virtual Status to_json(nlohmann::json *out_json) { return Status::OK(); }
|
||||
|
||||
#ifndef ENABLE_ANDROID
|
||||
|
@ -41,5 +47,4 @@ class DatasetCache {
|
|||
#endif
|
||||
};
|
||||
} // namespace mindspore::dataset
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_DATASET_CACHE_H_
|
||||
|
|
|
@ -75,7 +75,7 @@ class DatasetCacheImpl : public DatasetCache {
|
|||
|
||||
Status to_json(nlohmann::json *out_json) override;
|
||||
|
||||
~DatasetCacheImpl() = default;
|
||||
~DatasetCacheImpl() override = default;
|
||||
|
||||
private:
|
||||
std::shared_ptr<CacheClient> cache_client_;
|
||||
|
|
|
@ -37,7 +37,7 @@ class PreBuiltDatasetCache : public DatasetCacheImpl {
|
|||
cache_client_ = std::move(cc);
|
||||
}
|
||||
|
||||
~PreBuiltDatasetCache() = default;
|
||||
~PreBuiltDatasetCache() override = default;
|
||||
|
||||
/// Method to initialize the DatasetCache by creating an instance of a CacheClient
|
||||
/// \return Status Error code
|
||||
|
@ -45,5 +45,4 @@ class PreBuiltDatasetCache : public DatasetCacheImpl {
|
|||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_CACHE_PRE_BUILT_DATASET_CACHE_H_
|
||||
|
|
|
@ -28,7 +28,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class BatchNode : public DatasetNode {
|
||||
public:
|
||||
#ifdef ENABLE_PYTHON
|
||||
|
@ -43,7 +42,7 @@ class BatchNode : public DatasetNode {
|
|||
BatchNode(std::shared_ptr<DatasetNode> child, int32_t batch_size, bool drop_remainder);
|
||||
|
||||
/// \brief Destructor
|
||||
~BatchNode() = default;
|
||||
~BatchNode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -126,7 +125,6 @@ class BatchNode : public DatasetNode {
|
|||
#endif
|
||||
std::map<std::string, std::pair<TensorShape, std::shared_ptr<Tensor>>> pad_map_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BATCH_NODE_H_
|
||||
|
|
|
@ -27,7 +27,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class BucketBatchByLengthNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
@ -38,7 +37,7 @@ class BucketBatchByLengthNode : public DatasetNode {
|
|||
bool pad_to_bucket_boundary = false, bool drop_remainder = false);
|
||||
|
||||
/// \brief Destructor
|
||||
~BucketBatchByLengthNode() = default;
|
||||
~BucketBatchByLengthNode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -81,7 +80,6 @@ class BucketBatchByLengthNode : public DatasetNode {
|
|||
bool pad_to_bucket_boundary_;
|
||||
bool drop_remainder_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUCKET_BATCH_BY_LENGTH_NODE_H_
|
||||
|
|
|
@ -28,7 +28,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class BuildSentenceVocabNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
@ -37,7 +36,7 @@ class BuildSentenceVocabNode : public DatasetNode {
|
|||
SentencePieceModel model_type, const std::unordered_map<std::string, std::string> ¶ms);
|
||||
|
||||
/// \brief Destructor
|
||||
~BuildSentenceVocabNode() = default;
|
||||
~BuildSentenceVocabNode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -88,7 +87,6 @@ class BuildSentenceVocabNode : public DatasetNode {
|
|||
SentencePieceModel model_type_;
|
||||
std::unordered_map<std::string, std::string> params_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_SENTENCE_PIECE_VOCAB_NODE_H_
|
||||
|
|
|
@ -26,7 +26,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class BuildVocabNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
@ -35,7 +34,7 @@ class BuildVocabNode : public DatasetNode {
|
|||
const std::vector<std::string> &special_tokens, bool special_first);
|
||||
|
||||
/// \brief Destructor
|
||||
~BuildVocabNode() = default;
|
||||
~BuildVocabNode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -86,7 +85,6 @@ class BuildVocabNode : public DatasetNode {
|
|||
std::vector<std::string> special_tokens_;
|
||||
bool special_first_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_BUILD_VOCAB_NODE_H_
|
||||
|
|
|
@ -32,7 +32,7 @@ class CacheLookupNode : public DatasetNode, public SamplerObj {
|
|||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~CacheLookupNode() = default;
|
||||
~CacheLookupNode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -87,5 +87,4 @@ class CacheLookupNode : public DatasetNode, public SamplerObj {
|
|||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_LOOKUP_NODE_H_
|
||||
|
|
|
@ -30,7 +30,7 @@ class CacheMergeNode : public DatasetNode {
|
|||
CacheMergeNode(std::shared_ptr<DatasetNode> child, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~CacheMergeNode() = default;
|
||||
~CacheMergeNode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -67,5 +67,4 @@ class CacheMergeNode : public DatasetNode {
|
|||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_MERGE_NODE_H_
|
||||
|
|
|
@ -31,7 +31,7 @@ class CacheNode : public DatasetNode {
|
|||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~CacheNode() = default;
|
||||
~CacheNode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -71,5 +71,4 @@ class CacheNode : public DatasetNode {
|
|||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CACHE_NODE_H_
|
||||
|
|
|
@ -26,7 +26,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class ConcatNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
@ -36,7 +35,7 @@ class ConcatNode : public DatasetNode {
|
|||
const std::vector<std::pair<int, int>> &children_start_end_index = {});
|
||||
|
||||
/// \brief Destructor
|
||||
~ConcatNode() = default;
|
||||
~ConcatNode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -89,11 +88,6 @@ class ConcatNode : public DatasetNode {
|
|||
const std::vector<std::pair<int, int>> &ChildrenFlagAndNums() const { return children_flag_and_nums_; }
|
||||
const std::vector<std::pair<int, int>> &ChildrenStartEndIndex() const { return children_start_end_index_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
std::vector<std::pair<int, int>> children_flag_and_nums_;
|
||||
std::vector<std::pair<int, int>> children_start_end_index_;
|
||||
|
||||
/// \brief Base-class override for accepting IRNodePass visitor
|
||||
/// \param[in] p The node to visit
|
||||
/// \param[out] modified Indicator if the node was modified
|
||||
|
@ -105,8 +99,12 @@ class ConcatNode : public DatasetNode {
|
|||
/// \param[out] modified Indicator if the node was modified
|
||||
/// \return Status of the node visit
|
||||
Status AcceptAfter(IRNodePass *const p, bool *const modified) override;
|
||||
};
|
||||
|
||||
private:
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
std::vector<std::pair<int, int>> children_flag_and_nums_;
|
||||
std::vector<std::pair<int, int>> children_start_end_index_;
|
||||
};
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_CONCAT_NODE_H_
|
||||
|
|
|
@ -34,7 +34,6 @@ Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_ro
|
|||
RETURN_UNEXPECTED_IF_NULL(shuffle_size);
|
||||
const int64_t average_files_multiplier = 4;
|
||||
const int64_t shuffle_max = 10000;
|
||||
int64_t avg_rows_per_file = 0;
|
||||
|
||||
// Adjust the num rows per shard if sharding was given
|
||||
if (num_devices > 0) {
|
||||
|
@ -52,7 +51,7 @@ Status ComputeShuffleSize(int64_t num_files, int64_t num_devices, int64_t num_ro
|
|||
|
||||
// get the average per file
|
||||
CHECK_FAIL_RETURN_UNEXPECTED(num_files != 0, "The size of dataset_files must be greater than 0.");
|
||||
avg_rows_per_file = num_rows / num_files;
|
||||
int64_t avg_rows_per_file = num_rows / num_files;
|
||||
|
||||
*shuffle_size = std::max(avg_rows_per_file * average_files_multiplier, shuffle_max);
|
||||
return Status::OK();
|
||||
|
|
|
@ -181,7 +181,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
explicit DatasetNode(const std::shared_ptr<DatasetCache> &dataset_cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~DatasetNode() = default;
|
||||
virtual ~DatasetNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -286,7 +286,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
|
||||
/// \brief Check if this node is an orphan node
|
||||
/// \return True if this node isn't nullptr nor does it have any children and a parent
|
||||
static bool IsOrphanNode(std::shared_ptr<DatasetNode> node) {
|
||||
static bool IsOrphanNode(const std::shared_ptr<DatasetNode> &node) {
|
||||
return node != nullptr && node->parent_ == nullptr && node->Children().empty();
|
||||
}
|
||||
|
||||
|
@ -294,7 +294,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
void HasCacheAbove() { descendant_of_cache_ = true; }
|
||||
|
||||
/// \brief Getter of the number of workers
|
||||
int32_t NumWorkers() { return num_workers_; }
|
||||
int32_t NumWorkers() const { return num_workers_; }
|
||||
|
||||
/// \brief Getter of dataset cache
|
||||
std::shared_ptr<DatasetCache> GetDatasetCache() { return cache_; }
|
||||
|
@ -346,7 +346,7 @@ class DatasetNode : public std::enable_shared_from_this<DatasetNode> {
|
|||
void SetTotalRepeats(int32_t total_repeats) { total_repeats_ = total_repeats; }
|
||||
|
||||
/// \brief Setter function, set the number of epochs for the operator
|
||||
void SetNumEpochs(int32_t num_epochs) { num_epochs_ = num_epochs; }
|
||||
virtual void SetNumEpochs(int32_t num_epochs) { num_epochs_ = num_epochs; }
|
||||
|
||||
/// \brief Getter function
|
||||
/// \return The number of required repeats for the operator
|
||||
|
@ -399,7 +399,7 @@ class MappableSourceNode : public DatasetNode {
|
|||
Status Accept(IRNodePass *const p, bool *const modified) override;
|
||||
|
||||
/// \brief Destructor
|
||||
~MappableSourceNode() = default;
|
||||
virtual ~MappableSourceNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -430,7 +430,7 @@ class NonMappableSourceNode : public DatasetNode {
|
|||
Status Accept(IRNodePass *const p, bool *const modified) override;
|
||||
|
||||
/// \brief Destructor
|
||||
~NonMappableSourceNode() = default;
|
||||
virtual ~NonMappableSourceNode() = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
|
|
@ -27,7 +27,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class EpochCtrlNode : public RepeatNode {
|
||||
// Allow GeneratorNode to access internal members
|
||||
friend class GeneratorNode;
|
||||
|
@ -40,7 +39,7 @@ class EpochCtrlNode : public RepeatNode {
|
|||
EpochCtrlNode(std::shared_ptr<DatasetNode> child, int32_t num_epochs);
|
||||
|
||||
/// \brief Destructor
|
||||
~EpochCtrlNode() = default;
|
||||
~EpochCtrlNode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -75,7 +74,6 @@ class EpochCtrlNode : public RepeatNode {
|
|||
/// \return Status of the node visit
|
||||
Status AcceptAfter(IRNodePass *const p, bool *const modified) override;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_EPOCH_CTRL_NODE_H_
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class FilterNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
@ -33,7 +32,7 @@ class FilterNode : public DatasetNode {
|
|||
std::vector<std::string> input_columns = {});
|
||||
|
||||
/// \brief Destructor
|
||||
~FilterNode() = default;
|
||||
~FilterNode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -83,7 +82,6 @@ class FilterNode : public DatasetNode {
|
|||
std::shared_ptr<TensorOp> predicate_;
|
||||
std::vector<std::string> input_columns_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_FILTER_NODE_H_
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class MapNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
@ -36,7 +35,7 @@ class MapNode : public DatasetNode {
|
|||
ManualOffloadMode offload = ManualOffloadMode::kUnspecified);
|
||||
|
||||
/// \brief Destructor
|
||||
~MapNode() = default;
|
||||
~MapNode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -127,7 +126,6 @@ class MapNode : public DatasetNode {
|
|||
/// \brief ManualOffloadMode to indicate manual_offload status
|
||||
ManualOffloadMode offload_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_MAP_NODE_H_
|
||||
|
|
|
@ -25,14 +25,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class ProjectNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit ProjectNode(std::shared_ptr<DatasetNode> child, const std::vector<std::string> &columns);
|
||||
|
||||
/// \brief Destructor
|
||||
~ProjectNode() = default;
|
||||
~ProjectNode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -74,7 +73,6 @@ class ProjectNode : public DatasetNode {
|
|||
private:
|
||||
std::vector<std::string> columns_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_PROJECT_NODE_H_
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class RenameNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
@ -33,7 +32,7 @@ class RenameNode : public DatasetNode {
|
|||
const std::vector<std::string> &output_columns);
|
||||
|
||||
/// \brief Destructor
|
||||
~RenameNode() = default;
|
||||
~RenameNode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -77,7 +76,6 @@ class RenameNode : public DatasetNode {
|
|||
std::vector<std::string> input_columns_;
|
||||
std::vector<std::string> output_columns_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_RENAME_NODE_H_
|
||||
|
|
|
@ -27,7 +27,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class RepeatOp;
|
||||
|
||||
class RepeatNode : public DatasetNode {
|
||||
|
@ -42,7 +41,7 @@ class RepeatNode : public DatasetNode {
|
|||
RepeatNode(std::shared_ptr<DatasetNode> child, int32_t count);
|
||||
|
||||
/// \brief Destructor
|
||||
~RepeatNode() = default;
|
||||
~RepeatNode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -67,7 +66,7 @@ class RepeatNode : public DatasetNode {
|
|||
|
||||
/// \brief Getter
|
||||
/// \return Number of cycles to repeat the execution
|
||||
const int32_t Count() const { return repeat_count_; }
|
||||
int32_t Count() const { return repeat_count_; }
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
|
||||
|
@ -136,7 +135,6 @@ class RepeatNode : public DatasetNode {
|
|||
std::shared_ptr<RepeatNode> reset_ancestor_; // updated its immediate Repeat/EpochCtrl ancestor in GeneratorNodePass
|
||||
int32_t repeat_count_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_REPEAT_NODE_H_
|
||||
|
|
|
@ -35,7 +35,7 @@ class RootNode : public DatasetNode {
|
|||
explicit RootNode(std::shared_ptr<DatasetNode> child);
|
||||
|
||||
/// \brief Destructor
|
||||
~RootNode() = default;
|
||||
~RootNode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -55,10 +55,10 @@ class RootNode : public DatasetNode {
|
|||
Status Build(std::vector<std::shared_ptr<DatasetOp>> *const node_ops) override;
|
||||
|
||||
/// \brief Getter of number of epochs
|
||||
int32_t num_epochs() { return num_epochs_; }
|
||||
int32_t num_epochs() const { return num_epochs_; }
|
||||
|
||||
/// \brief Setter of number of epochs
|
||||
void SetNumEpochs(int32_t num_epochs) { num_epochs_ = num_epochs; }
|
||||
void SetNumEpochs(int32_t num_epochs) override { num_epochs_ = num_epochs; }
|
||||
|
||||
/// \brief Parameters validation
|
||||
/// \return Status Status::OK() if all the parameters are valid
|
||||
|
@ -79,7 +79,6 @@ class RootNode : public DatasetNode {
|
|||
private:
|
||||
int32_t num_epochs_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_ROOT_NODE_H_
|
||||
|
|
|
@ -27,12 +27,11 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class ShuffleNode : public DatasetNode {
|
||||
public:
|
||||
ShuffleNode(std::shared_ptr<DatasetNode> child, int32_t shuffle_size, bool reset_every_epoch);
|
||||
|
||||
~ShuffleNode() = default;
|
||||
~ShuffleNode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -76,7 +75,6 @@ class ShuffleNode : public DatasetNode {
|
|||
uint32_t shuffle_seed_;
|
||||
bool reset_every_epoch_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SHUFFLE_NODE_H_
|
||||
|
|
|
@ -25,14 +25,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class SkipNode : public DatasetNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
explicit SkipNode(std::shared_ptr<DatasetNode> child, int32_t count);
|
||||
|
||||
/// \brief Destructor
|
||||
~SkipNode() = default;
|
||||
~SkipNode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -57,7 +56,7 @@ class SkipNode : public DatasetNode {
|
|||
|
||||
/// \brief Getter
|
||||
/// \return Number of rows to skip
|
||||
const int32_t Count() const { return skip_count_; }
|
||||
int32_t Count() const { return skip_count_; }
|
||||
|
||||
/// \brief Base-class override for GetDatasetSize
|
||||
/// \param[in] size_getter Shared pointer to DatasetSizeGetter
|
||||
|
@ -99,7 +98,6 @@ class SkipNode : public DatasetNode {
|
|||
private:
|
||||
int32_t skip_count_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SKIP_NODE_H_
|
||||
|
|
|
@ -34,7 +34,7 @@ class AGNewsNode : public NonMappableSourceNode {
|
|||
int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~AGNewsNode() = default;
|
||||
~AGNewsNode() override = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
|
|
|
@ -99,7 +99,6 @@ Status AlbumNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_
|
|||
*dataset_size = dataset_size_;
|
||||
return Status::OK();
|
||||
}
|
||||
int64_t sample_size = -1;
|
||||
int64_t num_rows = 0;
|
||||
// iterate over the files in the directory and count files to initiate num_rows
|
||||
Path folder(dataset_dir_);
|
||||
|
@ -118,7 +117,7 @@ Status AlbumNode::GetDatasetSize(const std::shared_ptr<DatasetSizeGetter> &size_
|
|||
// give sampler the total number of files and check if num_samples is smaller
|
||||
std::shared_ptr<SamplerRT> sampler_rt = nullptr;
|
||||
RETURN_IF_NOT_OK(sampler_->SamplerBuild(&sampler_rt));
|
||||
sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
int64_t sample_size = sampler_rt->CalculateNumSamples(num_rows);
|
||||
if (sample_size == -1) {
|
||||
RETURN_IF_NOT_OK(size_getter->DryRun(shared_from_this(), &sample_size));
|
||||
}
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class AlbumNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
@ -34,7 +33,7 @@ class AlbumNode : public MappableSourceNode {
|
|||
const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~AlbumNode() = default;
|
||||
~AlbumNode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -102,7 +101,6 @@ class AlbumNode : public MappableSourceNode {
|
|||
bool decode_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_ALBUM_NODE_H_
|
||||
|
|
|
@ -32,7 +32,7 @@ class AmazonReviewNode : public NonMappableSourceNode {
|
|||
int32_t num_shards, int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~AmazonReviewNode() = default;
|
||||
~AmazonReviewNode() override = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
|
|
|
@ -37,7 +37,7 @@ class Caltech256Node : public MappableSourceNode {
|
|||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~Caltech256Node() = default;
|
||||
~Caltech256Node() override = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
|
|
|
@ -27,7 +27,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class CelebANode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
@ -35,7 +34,7 @@ class CelebANode : public MappableSourceNode {
|
|||
const bool &decode, const std::set<std::string> &extensions, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~CelebANode() = default;
|
||||
~CelebANode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -104,7 +103,6 @@ class CelebANode : public MappableSourceNode {
|
|||
std::set<std::string> extensions_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CELEBA_NODE_H_
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class Cifar100Node : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
@ -33,7 +32,7 @@ class Cifar100Node : public MappableSourceNode {
|
|||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~Cifar100Node() = default;
|
||||
~Cifar100Node() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -98,7 +97,6 @@ class Cifar100Node : public MappableSourceNode {
|
|||
std::string usage_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR100_NODE_H_
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class Cifar10Node : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
@ -33,7 +32,7 @@ class Cifar10Node : public MappableSourceNode {
|
|||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~Cifar10Node() = default;
|
||||
~Cifar10Node() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -98,7 +97,6 @@ class Cifar10Node : public MappableSourceNode {
|
|||
std::string usage_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CIFAR10_NODE_H_
|
||||
|
|
|
@ -28,14 +28,14 @@ namespace dataset {
|
|||
// Constructor for CityscapesNode
|
||||
CityscapesNode::CityscapesNode(const std::string &dataset_dir, const std::string &usage,
|
||||
const std::string &quality_mode, const std::string &task, bool decode,
|
||||
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache)
|
||||
: MappableSourceNode(std::move(cache)),
|
||||
const std::shared_ptr<SamplerObj> &sampler, const std::shared_ptr<DatasetCache> &cache)
|
||||
: MappableSourceNode(cache),
|
||||
dataset_dir_(dataset_dir),
|
||||
usage_(usage),
|
||||
quality_mode_(quality_mode),
|
||||
task_(task),
|
||||
sampler_(sampler),
|
||||
decode_(decode) {}
|
||||
decode_(decode),
|
||||
sampler_(sampler) {}
|
||||
|
||||
std::shared_ptr<DatasetNode> CityscapesNode::Copy() {
|
||||
std::shared_ptr<SamplerObj> sampler = (sampler_ == nullptr) ? nullptr : sampler_->SamplerCopy();
|
||||
|
|
|
@ -25,16 +25,15 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class CityscapesNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
CityscapesNode(const std::string &dataset_dir, const std::string &usage, const std::string &quality_mode,
|
||||
const std::string &task, bool decode, std::shared_ptr<SamplerObj> sampler,
|
||||
std::shared_ptr<DatasetCache> cache);
|
||||
const std::string &task, bool decode, const std::shared_ptr<SamplerObj> &sampler,
|
||||
const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~CityscapesNode() = default;
|
||||
~CityscapesNode() override = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
|
@ -105,7 +104,6 @@ class CityscapesNode : public MappableSourceNode {
|
|||
bool decode_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CITYSCAPES_NODE_H_
|
||||
|
|
|
@ -26,7 +26,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
/// \class CLUENode
|
||||
/// \brief A Dataset derived class to represent CLUE dataset
|
||||
class CLUENode : public NonMappableSourceNode {
|
||||
|
@ -36,7 +35,7 @@ class CLUENode : public NonMappableSourceNode {
|
|||
ShuffleMode shuffle, int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~CLUENode() = default;
|
||||
~CLUENode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -145,7 +144,6 @@ class CLUENode : public NonMappableSourceNode {
|
|||
int32_t num_shards_;
|
||||
int32_t shard_id_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CLUE_NODE_H_
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class CocoNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor
|
||||
|
@ -34,7 +33,7 @@ class CocoNode : public MappableSourceNode {
|
|||
const bool &extra_metadata);
|
||||
|
||||
/// \brief Destructor
|
||||
~CocoNode() = default;
|
||||
~CocoNode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -104,7 +103,6 @@ class CocoNode : public MappableSourceNode {
|
|||
std::shared_ptr<SamplerObj> sampler_;
|
||||
bool extra_metadata_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_COCO_NODE_H_
|
||||
|
|
|
@ -34,7 +34,7 @@ class CoNLL2000Node : public NonMappableSourceNode {
|
|||
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~CoNLL2000Node() = default;
|
||||
~CoNLL2000Node() override = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
|
|
|
@ -32,9 +32,12 @@ enum CsvType : uint8_t { INT = 0, FLOAT, STRING };
|
|||
class CsvBase {
|
||||
public:
|
||||
CsvBase() = default;
|
||||
|
||||
explicit CsvBase(CsvType t) : type(t) {}
|
||||
virtual ~CsvBase() {}
|
||||
CsvType type;
|
||||
|
||||
virtual ~CsvBase() = default;
|
||||
|
||||
CsvType type{INT};
|
||||
};
|
||||
|
||||
/// \brief CSV Record that can represent integer, float and string.
|
||||
|
@ -42,8 +45,11 @@ template <typename T>
|
|||
class CsvRecord : public CsvBase {
|
||||
public:
|
||||
CsvRecord() = default;
|
||||
|
||||
CsvRecord(CsvType t, T v) : CsvBase(t), value(v) {}
|
||||
~CsvRecord() {}
|
||||
|
||||
~CsvRecord() override = default;
|
||||
|
||||
T value;
|
||||
};
|
||||
|
||||
|
@ -56,7 +62,7 @@ class CSVNode : public NonMappableSourceNode {
|
|||
std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor
|
||||
~CSVNode() = default;
|
||||
~CSVNode() override = default;
|
||||
|
||||
/// \brief Node name getter
|
||||
/// \return Name of the current node
|
||||
|
@ -139,7 +145,6 @@ class CSVNode : public NonMappableSourceNode {
|
|||
int32_t num_shards_;
|
||||
int32_t shard_id_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_CSV_NODE_H_
|
||||
|
|
|
@ -33,7 +33,7 @@ class DBpediaNode : public NonMappableSourceNode {
|
|||
int32_t num_shards, int32_t shard_id, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~DBpediaNode() = default;
|
||||
~DBpediaNode() override = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class DIV2KNode : public MappableSourceNode {
|
||||
public:
|
||||
/// \brief Constructor.
|
||||
|
@ -33,7 +32,7 @@ class DIV2KNode : public MappableSourceNode {
|
|||
bool decode, const std::shared_ptr<SamplerObj> &sampler, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~DIV2KNode() = default;
|
||||
~DIV2KNode() override = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
|
|
|
@ -37,7 +37,7 @@ class EMnistNode : public MappableSourceNode {
|
|||
std::shared_ptr<SamplerObj> sampler, std::shared_ptr<DatasetCache> cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~EMnistNode() = default;
|
||||
~EMnistNode() override = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
|
@ -105,7 +105,6 @@ class EMnistNode : public MappableSourceNode {
|
|||
std::string usage_;
|
||||
std::shared_ptr<SamplerObj> sampler_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_ENGINE_IR_DATASETOPS_SOURCE_EMNIST_NODE_H_
|
||||
|
|
|
@ -40,7 +40,7 @@ class EnWik9Node : public NonMappableSourceNode {
|
|||
int32_t shard_id, const std::shared_ptr<DatasetCache> &cache);
|
||||
|
||||
/// \brief Destructor.
|
||||
~EnWik9Node() = default;
|
||||
~EnWik9Node() override = default;
|
||||
|
||||
/// \brief Node name getter.
|
||||
/// \return Name of the current node.
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue