forked from mindspore-Ecosystem/mindspore
fix random
This commit is contained in:
parent
8539de0e01
commit
66d11d44ab
|
@ -54,6 +54,15 @@ message(PATCH_EXECUTABLE = ${Patch_EXECUTABLE})
|
||||||
|
|
||||||
find_required_package(Threads)
|
find_required_package(Threads)
|
||||||
|
|
||||||
|
# add openmp if the onednn use ms threadpool
|
||||||
|
if(USE_MS_THREADPOOL_FOR_DNNL)
|
||||||
|
find_package(OpenMP)
|
||||||
|
if(OPENMP_FOUND)
|
||||||
|
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
|
||||||
|
else()
|
||||||
|
message(WARNING "OpenMP not found")
|
||||||
|
endif()
|
||||||
|
endif()
|
||||||
|
|
||||||
## packages used on Linux
|
## packages used on Linux
|
||||||
if(NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
|
if(NOT CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||||
|
|
|
@ -30,7 +30,7 @@ option(ENABLE_FAST_HASH_TABLE "Enable use fast hash table instead of std ones" O
|
||||||
option(USE_LLVM "use llvm" OFF)
|
option(USE_LLVM "use llvm" OFF)
|
||||||
option(USE_MS_THREADPOOL_FOR_DNNL "use ms threadpool for onednn ops" ON)
|
option(USE_MS_THREADPOOL_FOR_DNNL "use ms threadpool for onednn ops" ON)
|
||||||
|
|
||||||
if(CMAKE_SYSTEM_NAME MATCHES "Windows")
|
if(NOT CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||||
set(USE_MS_THREADPOOL_FOR_DNNL OFF)
|
set(USE_MS_THREADPOOL_FOR_DNNL OFF)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
|
@ -18,14 +18,6 @@ if(ENABLE_D)
|
||||||
add_subdirectory(backend/kernel_compiler/aicpu/aicpu_ops)
|
add_subdirectory(backend/kernel_compiler/aicpu/aicpu_ops)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# add openmp
|
|
||||||
find_package(OpenMP)
|
|
||||||
if(OPENMP_FOUND)
|
|
||||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
|
|
||||||
else()
|
|
||||||
message(FATAL_ERROR "OpenMP not found")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if(ENABLE_CPU)
|
if(ENABLE_CPU)
|
||||||
if(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "aarch64")
|
if(${CMAKE_HOST_SYSTEM_PROCESSOR} MATCHES "aarch64")
|
||||||
set(PLATFORM_ARM64 "on")
|
set(PLATFORM_ARM64 "on")
|
||||||
|
|
|
@ -40,9 +40,14 @@ void LaunchStandardNormal(RandomCPUKernel *content, unsigned int seed, const std
|
||||||
auto output = reinterpret_cast<float *>(outputs[0]->addr);
|
auto output = reinterpret_cast<float *>(outputs[0]->addr);
|
||||||
// multithreading
|
// multithreading
|
||||||
size_t lens = outputs[0]->size / sizeof(float);
|
size_t lens = outputs[0]->size / sizeof(float);
|
||||||
std::default_random_engine random_generator(++seed);
|
auto thread_pool = GetActorMgrInnerThreadPool();
|
||||||
auto task = [&seed, &output, &random_generator](size_t start, size_t end) {
|
size_t max_thread_num = thread_pool->GetKernelThreadNum();
|
||||||
std::normal_distribution<float> distribution;
|
size_t thread_num = lens < kRandomBlockSize * max_thread_num ? std::ceil(lens / kRandomBlockSize) : max_thread_num;
|
||||||
|
size_t once_compute_size = (lens + thread_num - 1) / thread_num;
|
||||||
|
std::normal_distribution<float> distribution;
|
||||||
|
auto task = [&](size_t start, size_t end) {
|
||||||
|
auto task_id = start / once_compute_size;
|
||||||
|
std::default_random_engine random_generator(seed + task_id);
|
||||||
StandardNormal(output, distribution, random_generator, start, end);
|
StandardNormal(output, distribution, random_generator, start, end);
|
||||||
};
|
};
|
||||||
ParallelLaunch(task, lens, kRandomBlockSize, content);
|
ParallelLaunch(task, lens, kRandomBlockSize, content);
|
||||||
|
|
|
@ -832,7 +832,8 @@ def set_context(**kwargs):
|
||||||
Through context.set_context(backend_policy="ms")
|
Through context.set_context(backend_policy="ms")
|
||||||
Default: The value must be in ['ge', 'vm', 'ms'].
|
Default: The value must be in ['ge', 'vm', 'ms'].
|
||||||
runtime_num_threads(int): The thread pool number of cpu kernel and actor used in runtime,
|
runtime_num_threads(int): The thread pool number of cpu kernel and actor used in runtime,
|
||||||
which must bigger than 0.
|
which must bigger than 0. Default value if 0.6 times of the machine threads, if you run many processes at
|
||||||
|
the same time, you should set the value smaller to avoid thread contention.
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If input key is not an attribute in context.
|
ValueError: If input key is not an attribute in context.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue