forked from mindspore-Ecosystem/mindspore
!15209 fix l2normalize ops parallel problem
From: @yuanwei66 Reviewed-by: @wuxuejian,@c_34 Signed-off-by: @wuxuejian,@c_34
This commit is contained in:
commit
6d2c357db8
|
@ -22,7 +22,7 @@ namespace kernel {
|
|||
template <typename T>
|
||||
void L2NormalizeCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_node);
|
||||
epsilon_ = AnfAlgo::GetNodeAttr<float>(kernel_node, "epsilon");
|
||||
epsilon_ = static_cast<T>(AnfAlgo::GetNodeAttr<float>(kernel_node, "epsilon"));
|
||||
axis_ = LongToInt(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "axis"));
|
||||
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
|
@ -35,9 +35,6 @@ void L2NormalizeCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
|||
template <typename T>
|
||||
void L2NormalizeCPUKernel<T>::CalcDenominator(const T *input_addr, const size_t reduce_size, const int dims,
|
||||
std::unique_ptr<T[]> *denominator_addr) {
|
||||
T temp = (T)0.0;
|
||||
T epsilon = (T)epsilon_;
|
||||
T denominator = (T)0.0;
|
||||
// Calculate transpose axes and stride
|
||||
size_t stride = 1;
|
||||
std::vector<size_t> axes(input_shape_.size());
|
||||
|
@ -60,6 +57,8 @@ void L2NormalizeCPUKernel<T>::CalcDenominator(const T *input_addr, const size_t
|
|||
TransposeIterator tran_base_iter(std::move(transpose_shape), std::move(axes), input_shape_);
|
||||
|
||||
auto task = [&](size_t start, size_t end) {
|
||||
T temp = (T)0.0;
|
||||
T denominator = (T)0.0;
|
||||
auto iter = tran_base_iter;
|
||||
iter.SetPos(start * stride);
|
||||
for (size_t i = start; i < end; ++i) {
|
||||
|
@ -71,7 +70,7 @@ void L2NormalizeCPUKernel<T>::CalcDenominator(const T *input_addr, const size_t
|
|||
denominator += temp * temp;
|
||||
iter.GenNextPos();
|
||||
}
|
||||
denominator = (denominator > epsilon) ? denominator : epsilon;
|
||||
denominator = (denominator > epsilon_) ? denominator : epsilon_;
|
||||
(*denominator_addr)[i] = sqrt(denominator);
|
||||
}
|
||||
};
|
||||
|
@ -146,7 +145,7 @@ void L2NormalizeCPUKernel<T>::CheckParam(const CNodePtr &kernel_node) {
|
|||
if (axis_ < -dims || axis_ >= dims) {
|
||||
MS_LOG(EXCEPTION) << "Attr axis_ " << axis_ << " must be in " << -dims << "~" << dims;
|
||||
}
|
||||
if (epsilon_ == 0.0) {
|
||||
if (epsilon_ == (T)0.0) {
|
||||
MS_LOG(EXCEPTION) << "Attr epsilon can not be zero.";
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ class L2NormalizeCPUKernel : public CPUKernel {
|
|||
private:
|
||||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> output_shape_;
|
||||
float epsilon_;
|
||||
T epsilon_;
|
||||
int axis_;
|
||||
void CheckParam(const CNodePtr &kernel_node);
|
||||
};
|
||||
|
|
|
@ -37,10 +37,10 @@ class Net(Cell):
|
|||
@pytest.mark.platform_x86_cpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_l2normalize_float32():
|
||||
x = np.arange(96).astype(np.float32).reshape(2, 3, 4, 4)
|
||||
x = np.arange(20*20*20*20).astype(np.float32).reshape(20, 20, 20, 20)
|
||||
expect = x / np.sqrt(np.sum(x**2, axis=0, keepdims=True))
|
||||
x = Tensor(x)
|
||||
error = np.ones(shape=[2, 3, 4, 4]) * 1.0e-5
|
||||
error = np.ones(shape=[20, 20, 20, 20]) * 1.0e-5
|
||||
|
||||
norm_op = Net(axis=0)
|
||||
output = norm_op(x)
|
||||
|
|
Loading…
Reference in New Issue