forked from mindspore-Ecosystem/mindspore
!16042 Fixes RNG_seed bug in StandardNormal operator
From: @huangbo77 Reviewed-by: @wuxuejian,@liangchenghui Signed-off-by: @wuxuejian
This commit is contained in:
commit
d632a34b16
mindspore/ccsrc/backend/kernel_compiler/cpu
|
@@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@@ -15,6 +15,7 @@
|
|||
*/
|
||||
#include <random>
|
||||
#include <thread>
|
||||
#include "common/thread_pool.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
#include "backend/kernel_compiler/cpu/random_cpu_kernel.h"
|
||||
|
||||
|
@@ -41,11 +42,23 @@ void LaunchStandardNormal(int seed, int seed2, const std::vector<AddressPtr> &ou
|
|||
auto output = reinterpret_cast<float *>(outputs[0]->addr);
|
||||
size_t lens = outputs[0]->size / sizeof(float);
|
||||
std::normal_distribution<float> distribution;
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
auto max_thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
|
||||
const float block_size = 128.0;
|
||||
size_t thread_num = lens < block_size * max_thread_num ? std::ceil(lens / block_size) : max_thread_num;
|
||||
std::vector<common::Task> tasks;
|
||||
size_t start = 0;
|
||||
size_t once_compute_size = (lens + thread_num - 1) / thread_num;
|
||||
while (start < lens) {
|
||||
size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
|
||||
std::default_random_engine random_generator(++RNG_seed);
|
||||
StandardNormal(output, distribution, random_generator, start, end);
|
||||
};
|
||||
CPUKernelUtils::ParallelFor(task, lens);
|
||||
auto block = [&, start, end]() {
|
||||
StandardNormal(output, distribution, random_generator, start, end);
|
||||
return common::SUCCESS;
|
||||
};
|
||||
tasks.emplace_back(block);
|
||||
start += once_compute_size;
|
||||
}
|
||||
common::ThreadPool::GetInstance().SyncRun(tasks);
|
||||
}
|
||||
|
||||
void RandomCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
|
|
Loading…
Reference in New Issue