fix random cpu ops
This commit is contained in:
parent
6d99de6d5a
commit
5db0bf553b
|
@ -26,6 +26,7 @@ constexpr size_t kUniformRealInputsNum = 1;
|
|||
constexpr size_t kUniformIntOutputsNum = 1;
|
||||
constexpr size_t kUniformRealOutputsNum = 1;
|
||||
constexpr size_t kStandardNormalOutputsNum = 1;
|
||||
constexpr float kRandomBlockSize = 128.0;
|
||||
constexpr char kKernelName[] = "Random";
|
||||
} // namespace
|
||||
void StandardNormal(float *output, std::normal_distribution<float> distribution,
|
||||
|
@ -39,12 +40,12 @@ void LaunchStandardNormal(RandomCPUKernel *content, unsigned int seed, const std
|
|||
auto output = reinterpret_cast<float *>(outputs[0]->addr);
|
||||
// multithreading
|
||||
size_t lens = outputs[0]->size / sizeof(float);
|
||||
auto task = [&seed, &output](size_t start, size_t end) {
|
||||
std::default_random_engine random_generator(++seed);
|
||||
auto task = [&seed, &output, &random_generator](size_t start, size_t end) {
|
||||
std::normal_distribution<float> distribution;
|
||||
std::default_random_engine random_generator(++seed);
|
||||
StandardNormal(output, distribution, random_generator, start, end);
|
||||
};
|
||||
ParallelLaunchAutoSearch(task, lens, content, &content->parallel_search_info_);
|
||||
ParallelLaunch(task, lens, kRandomBlockSize, content);
|
||||
}
|
||||
|
||||
void LaunchUniformInt(unsigned int seed, const std::vector<AddressPtr> &inputs,
|
||||
|
|
|
@ -26,6 +26,7 @@ from collections import Counter
|
|||
import numpy as np
|
||||
|
||||
from mindspore import log as logger
|
||||
from mindspore import context
|
||||
from mindspore.common.initializer import Zero
|
||||
from .. import signature as sig
|
||||
from .._utils import get_broadcast_shape, is_shape_unknown
|
||||
|
@ -3304,6 +3305,10 @@ class StridedSlice(PrimitiveWithInfer):
|
|||
shrink_axis_mask=0):
|
||||
"""Initialize StridedSlice"""
|
||||
self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
|
||||
# auto parallel haven't support begin_mask and end_mask
|
||||
if context.get_auto_parallel_context("parallel_mode") in ["semi_auto_parallel", "auto_parallel"]:
|
||||
begin_mask = 0
|
||||
end_mask = 0
|
||||
validator.check_non_negative_int(begin_mask, 'begin_mask', self.name)
|
||||
validator.check_non_negative_int(end_mask, 'end_mask', self.name)
|
||||
validator.check_non_negative_int(ellipsis_mask, 'ellipsis_mask', self.name)
|
||||
|
|
Loading…
Reference in New Issue