fix l2normalize ops parallel problem

This commit is contained in:
yuanwei66 2021-04-15 11:43:51 +08:00
parent c43deb5469
commit ee9b452b57
3 changed files with 8 additions and 9 deletions

View File

@ -22,7 +22,7 @@ namespace kernel {
template <typename T>
void L2NormalizeCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
epsilon_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "epsilon");
epsilon_ = static_cast<T>(AnfAlgo::GetNodeAttr<float>(kernel_node, "epsilon"));
axis_ = LongToInt(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "axis"));
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
@ -35,9 +35,6 @@ void L2NormalizeCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
template <typename T>
void L2NormalizeCPUKernel<T>::CalcDenominator(const T *input_addr, const size_t reduce_size, const int dims,
std::unique_ptr<T[]> *denominator_addr) {
T temp = (T)0.0;
T epsilon = (T)epsilon_;
T denominator = (T)0.0;
// Calculate transpose axes and stride
size_t stride = 1;
std::vector<size_t> axes(input_shape_.size());
@ -60,6 +57,8 @@ void L2NormalizeCPUKernel<T>::CalcDenominator(const T *input_addr, const size_t
TransposeIterator tran_base_iter(std::move(transpose_shape), std::move(axes), input_shape_);
auto task = [&](size_t start, size_t end) {
T temp = (T)0.0;
T denominator = (T)0.0;
auto iter = tran_base_iter;
iter.SetPos(start * stride);
for (size_t i = start; i < end; ++i) {
@ -71,7 +70,7 @@ void L2NormalizeCPUKernel<T>::CalcDenominator(const T *input_addr, const size_t
denominator += temp * temp;
iter.GenNextPos();
}
denominator = (denominator > epsilon) ? denominator : epsilon;
denominator = (denominator > epsilon_) ? denominator : epsilon_;
(*denominator_addr)[i] = sqrt(denominator);
}
};
@ -146,7 +145,7 @@ void L2NormalizeCPUKernel<T>::CheckParam(const CNodePtr &kernel_node) {
if (axis_ < -dims || axis_ >= dims) {
MS_LOG(EXCEPTION) << "Attr axis_ " << axis_ << " must be in " << -dims << "~" << dims;
}
if (epsilon_ == 0.0) {
if (epsilon_ == (T)0.0) {
MS_LOG(EXCEPTION) << "Attr epsilon can not be zero.";
}
}

View File

@ -45,7 +45,7 @@ class L2NormalizeCPUKernel : public CPUKernel {
private:
std::vector<size_t> input_shape_;
std::vector<size_t> output_shape_;
float epsilon_;
T epsilon_;
int axis_;
void CheckParam(const CNodePtr &kernel_node);
};

View File

@ -37,10 +37,10 @@ class Net(Cell):
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_l2normalize_float32():
x = np.arange(96).astype(np.float32).reshape(2, 3, 4, 4)
x = np.arange(20*20*20*20).astype(np.float32).reshape(20, 20, 20, 20)
expect = x / np.sqrt(np.sum(x**2, axis=0, keepdims=True))
x = Tensor(x)
error = np.ones(shape=[2, 3, 4, 4]) * 1.0e-5
error = np.ones(shape=[20, 20, 20, 20]) * 1.0e-5
norm_op = Net(axis=0)
output = norm_op(x)