fix random cpu ops

This commit is contained in:
fangzehua 2022-01-11 16:00:34 +08:00
parent 6d99de6d5a
commit 5db0bf553b
2 changed files with 9 additions and 3 deletions

View File

@ -26,6 +26,7 @@ constexpr size_t kUniformRealInputsNum = 1;
constexpr size_t kUniformIntOutputsNum = 1;
constexpr size_t kUniformRealOutputsNum = 1;
constexpr size_t kStandardNormalOutputsNum = 1;
constexpr float kRandomBlockSize = 128.0;
constexpr char kKernelName[] = "Random";
} // namespace
void StandardNormal(float *output, std::normal_distribution<float> distribution,
@ -39,12 +40,12 @@ void LaunchStandardNormal(RandomCPUKernel *content, unsigned int seed, const std
auto output = reinterpret_cast<float *>(outputs[0]->addr);
// multithreading
size_t lens = outputs[0]->size / sizeof(float);
auto task = [&seed, &output](size_t start, size_t end) {
std::default_random_engine random_generator(++seed);
auto task = [&seed, &output, &random_generator](size_t start, size_t end) {
std::normal_distribution<float> distribution;
std::default_random_engine random_generator(++seed);
StandardNormal(output, distribution, random_generator, start, end);
};
ParallelLaunchAutoSearch(task, lens, content, &content->parallel_search_info_);
ParallelLaunch(task, lens, kRandomBlockSize, content);
}
void LaunchUniformInt(unsigned int seed, const std::vector<AddressPtr> &inputs,

View File

@ -26,6 +26,7 @@ from collections import Counter
import numpy as np
from mindspore import log as logger
from mindspore import context
from mindspore.common.initializer import Zero
from .. import signature as sig
from .._utils import get_broadcast_shape, is_shape_unknown
@ -3304,6 +3305,10 @@ class StridedSlice(PrimitiveWithInfer):
shrink_axis_mask=0):
"""Initialize StridedSlice"""
self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
# auto parallel haven't support begin_mask and end_mask
if context.get_auto_parallel_context("parallel_mode") in ["semi_auto_parallel", "auto_parallel"]:
begin_mask = 0
end_mask = 0
validator.check_non_negative_int(begin_mask, 'begin_mask', self.name)
validator.check_non_negative_int(end_mask, 'end_mask', self.name)
validator.check_non_negative_int(ellipsis_mask, 'ellipsis_mask', self.name)