!41862 fix Image, log_matrix_determinant, lp_norm kernel bugs

Merge pull request !41862 from zhuzhongrui/pub_master
i-robot 2022-09-16 02:34:10 +00:00 committed by Gitee
commit 002054321a
9 changed files with 60 additions and 60 deletions

View File

@@ -55,61 +55,49 @@ void Ceil(const float *input, float *output, size_t start, size_t end) {
 template <typename T, typename S>
 void Real(const T *input, S *output, size_t start, size_t end) {
-  if (!std::is_same<S, complex64>::value && !std::is_same<S, complex128>::value) {
-    for (size_t i = start; i < end; ++i) {
-      output[i] = static_cast<S>(std::real(input[i]));
-    }
-  } else {
-    MS_LOG(EXCEPTION) << "For Real, it's output data type only support these types: float or double";
+  for (size_t i = start; i < end; ++i) {
+    output[i] = static_cast<S>(std::real(input[i]));
   }
 }
 template <typename T, typename S>
 void Imag(const T *input, S *output, size_t start, size_t end) {
-  if constexpr (!std::is_same<S, std::complex<float>>::value && !std::is_same<S, std::complex<double>>::value) {
-    for (size_t i = start; i < end; ++i) {
-      output[i] = static_cast<S>(std::imag(input[i]));
-    }
-  } else {
-    MS_LOG(EXCEPTION) << "For Imag, it's output data type only support these types: float or double";
+  for (size_t i = start; i < end; ++i) {
+    output[i] = static_cast<S>(std::imag(input[i]));
   }
 }
 template <typename T, typename S>
 void Conj(const T *input, S *output, size_t start, size_t end) {
-  if constexpr (std::is_same<T, S>::value &&
-                (std::is_same<T, complex64>::value || std::is_same<T, complex128>::value)) {
-    for (size_t i = start; i < end; ++i) {
-      output[i] = static_cast<S>(std::conj(input[i]));
+  if constexpr (std::is_same<T, S>::value) {
+    if constexpr ((std::is_same<T, complex64>::value || std::is_same<T, complex128>::value)) {
+      for (size_t i = start; i < end; ++i) {
+        output[i] = static_cast<S>(std::conj(input[i]));
+      }
+    } else {
+      for (size_t i = start; i < end; ++i) {
+        output[i] = static_cast<S>(input[i]);
+      }
     }
   } else {
-    MS_LOG(EXCEPTION) << "For Conj, it's output data type only support these types: complex<float> or complex<double>";
+    MS_LOG(EXCEPTION) << "For Conj, it's output data type not equal to input data type.";
   }
 }
 template <typename T, typename S>
 void UnaryOpCpuKernelFunc<T, S>::GetUnaryOpFunc() {
-  // only support float
+  const std::map<std::string, UnaryOpFunc> kCommonSupportedMap = {{prim::kPrimReal->name(), &Real<T, S>},
+                                                                  {prim::kPrimImag->name(), &Imag<T, S>},
+                                                                  {prim::kPrimConj->name(), &Conj<T, S>}};
   if constexpr (std::is_same<T, float>::value) {
-    static std::map<std::string, UnaryOpFunc> kFloatSupportedMap = {{prim::kPrimCeil->name(), &Ceil}};
-    auto iter = kFloatSupportedMap.find(kernel_name_);
-    if (iter != kFloatSupportedMap.end()) {
-      unary_op_func_ = iter->second;
+    if (kernel_name_ == prim::kPrimCeil->name()) {
+      unary_op_func_ = &Ceil;
       return;
     }
   }
-  if constexpr (std::is_same<T, complex64>::value || std::is_same<T, complex128>::value) {
-    static std::map<std::string, UnaryOpFunc> kComplexSupportedTypeMap = {{prim::kPrimReal->name(), &Real<T, S>},
-                                                                          {prim::kPrimImag->name(), &Imag<T, S>},
-                                                                          {prim::kPrimConj->name(), &Conj<T, S>}};
-    auto iter = kComplexSupportedTypeMap.find(kernel_name_);
-    if (iter != kComplexSupportedTypeMap.end()) {
-      unary_op_func_ = iter->second;
-      return;
-    }
-    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << ", only support these types: Real, Imag, Conj currently, but got "
-                      << kernel_name_;
+  auto iter = kCommonSupportedMap.find(kernel_name_);
+  if (iter != kCommonSupportedMap.end()) {
+    unary_op_func_ = iter->second;
   }
 }
@@ -126,15 +114,12 @@ bool UnaryOpCpuKernelFunc<T, S>::RunFunc(const std::vector<AddressPtr> &inputs,
   auto output = outputs.front();
   const auto input_addr = reinterpret_cast<T *>(input->addr);
   auto output_addr = reinterpret_cast<S *>(output->addr);
-  if (unary_op_func_ != nullptr) {
-    ParallelLaunchAutoSearch(
-      std::bind(unary_op_func_, input_addr, output_addr, std::placeholders::_1, std::placeholders::_2),
-      output->size / sizeof(S), this, &parallel_search_info_);
-  } else {
-    if (memcpy_s(output_addr, output->size, input_addr, input->size) != EOK) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does memory copy fail.";
-    }
+  if (unary_op_func_ == nullptr) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it has no cpu backend implements.";
   }
+  ParallelLaunchAutoSearch(
+    std::bind(unary_op_func_, input_addr, output_addr, std::placeholders::_1, std::placeholders::_2),
+    output->size / sizeof(S), this, &parallel_search_info_);
   return true;
 }
@@ -192,7 +177,9 @@ std::map<std::string, std::vector<std::pair<KernelAttr, UnaryOpCpuFuncCreator>>>
     SpecializeUnaryFunc<float, float>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
     SpecializeUnaryFunc<double, double>},
-   {KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SpecializeUnaryFunc<bool, bool>}}},
+   {KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SpecializeUnaryFunc<bool, bool>},
+   {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
+    SpecializeUnaryFunc<uint8_t, uint8_t>}}},
  {kConj,
   {{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
     SpecializeUnaryFunc<complex128, complex128>},
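
For reference, a minimal standalone sketch of the `if constexpr` dispatch the reworked Conj kernel relies on: the std::conj branch is discarded at compile time for real element types, which is why the same template can also serve non-complex inputs by copying them through. This is illustrative only, not MindSpore code; the name ConjLike is invented.

#include <complex>
#include <cstddef>
#include <iostream>
#include <type_traits>
#include <vector>

// ConjLike dispatches at compile time: std::conj is only instantiated for complex
// element types, so real-typed instantiations compile down to a plain copy.
template <typename T>
void ConjLike(const T *input, T *output, std::size_t n) {
  if constexpr (std::is_same<T, std::complex<float>>::value || std::is_same<T, std::complex<double>>::value) {
    for (std::size_t i = 0; i < n; ++i) {
      output[i] = std::conj(input[i]);  // complex: negate the imaginary part
    }
  } else {
    for (std::size_t i = 0; i < n; ++i) {
      output[i] = input[i];  // real: conjugation is the identity
    }
  }
}

int main() {
  std::vector<std::complex<float>> c_in = {{1.0f, 2.0f}, {3.0f, -4.0f}};
  std::vector<std::complex<float>> c_out(c_in.size());
  ConjLike(c_in.data(), c_out.data(), c_in.size());

  std::vector<float> r_in = {1.5f, -2.5f};
  std::vector<float> r_out(r_in.size());
  ConjLike(r_in.data(), r_out.data(), r_in.size());  // compiles: the conj branch is discarded

  std::cout << c_out[1] << " " << r_out[1] << std::endl;  // prints "(3,4) -2.5"
  return 0;
}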

View File

@@ -30,6 +30,10 @@ abstract::TupleShapePtr LogMatrixDeterminantInferShape(const PrimitivePtr &primi
                                                         const std::vector<AbstractBasePtr> &input_args) {
   auto prim_name = primitive->name();
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
+  if (IsDynamicRank(x_shape)) {
+    abstract::ShapePtr out_shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
+    return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{out_shape, out_shape});
+  }
   auto x_rank = SizeToLong(x_shape.size());
   constexpr int64_t number1 = 1;
   constexpr int64_t number2 = 2;
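
A standalone sketch of the dynamic-rank guard added above, with shapes modeled as plain vectors. Illustrative only: the -2 sentinel for an unknown rank, the helper names, and the omitted validation are assumptions, not MindSpore's actual API.

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

using ShapeVector = std::vector<int64_t>;
constexpr int64_t kUnknownRank = -2;  // assumed stand-in for UNKNOWN_RANK

// A shape whose rank itself is unknown is represented by the single sentinel value.
bool IsDynamicRankLike(const ShapeVector &shape) { return shape.size() == 1 && shape[0] == kUnknownRank; }

// LogMatrixDeterminant returns two outputs (sign and log|det|) shaped like the batch
// dims x_shape[:-2]; when the rank is unknown we cannot strip the trailing [M, M],
// so both outputs are reported as unknown-rank too, mirroring the early return above.
std::pair<ShapeVector, ShapeVector> InferLogMatrixDeterminantShape(const ShapeVector &x_shape) {
  if (IsDynamicRankLike(x_shape)) {
    return {ShapeVector{kUnknownRank}, ShapeVector{kUnknownRank}};
  }
  // Real infer code also checks rank >= 2 and that the last two dims match; omitted here.
  ShapeVector batch(x_shape.begin(), x_shape.end() - 2);
  return {batch, batch};
}

int main() {
  auto known = InferLogMatrixDeterminantShape({4, 3, 3});
  auto dynamic = InferLogMatrixDeterminantShape({kUnknownRank});
  std::cout << known.first.size() << " " << dynamic.first[0] << std::endl;  // prints "1 -2"
  return 0;
}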

View File

@@ -34,6 +34,9 @@ abstract::ShapePtr LpNormInferShape(const PrimitivePtr &primitive, const std::ve
     MS_EXCEPTION_IF_NULL(item);
   }
   auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
+  if (IsDynamicRank(input_shape)) {
+    return std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
+  }
   auto output_shape = input_shape;
   auto input_rank = SizeToLong(input_shape.size());
   auto axis = GetValue<std::vector<int64_t>>(primitive->GetAttr("axis"));
@@ -44,13 +47,18 @@ abstract::ShapePtr LpNormInferShape(const PrimitivePtr &primitive, const std::ve
   } else {
     CheckAndConvertUtils::CheckInRange("axis size", axis.size(), kIncludeNeither, {0, input_rank + 1}, prim_name);
   }
-  if (axis.size() > 1) {
-    for (size_t i = 0; i < axis.size(); ++i) {
-      CheckAndConvertUtils::CheckInRange("axis value", axis[i], kIncludeLeft, {-input_rank, input_rank}, prim_name);
-      if (axis[i] < 0) {
-        axis[i] += input_rank;
-      }
-    }
+  for (int64_t &axi : axis) {
+    CheckAndConvertUtils::CheckInRange("axis value", axi, kIncludeLeft, {-input_rank, input_rank}, prim_name);
+    if (axi < 0) {
+      axi += input_rank;
+    }
+  }
+  bool invalid_axis = std::any_of(axis.begin(), axis.end(), [&input_rank](int64_t axis) { return axis >= input_rank; });
+  if (invalid_axis) {
+    MS_EXCEPTION(ValueError) << "For " << prim_name << ", the value of axis is out of range (-" << input_rank << ", "
+                             << input_rank << ").";
+  }
+  if (axis.size() > 1) {
     constexpr int64_t place_holder = INT64_MAX;
     for (size_t i = 0; i < axis.size(); ++i) {
       auto temp = axis;
@@ -77,9 +85,6 @@ abstract::ShapePtr LpNormInferShape(const PrimitivePtr &primitive, const std::ve
       }
     }
   } else {
-    if (axis[0] < 0) {
-      axis[0] += input_rank;
-    }
     if (!keep_dims) {
       (void)output_shape.erase(output_shape.begin() + axis[0]);
     } else {
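
A standalone sketch of the axis handling this fix gives LpNorm's shape inference. Illustrative only: NormalizeAxes is an invented helper, and it folds the patch's CheckInRange call plus the std::any_of post-check into a single bounds test.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

// Range-check every axis against [-rank, rank) and wrap negative axes into [0, rank),
// matching the behaviour the patch gives LpNorm before any dimension is erased.
std::vector<int64_t> NormalizeAxes(std::vector<int64_t> axes, int64_t rank) {
  for (int64_t &axis : axes) {
    if (axis < -rank || axis >= rank) {
      throw std::invalid_argument("axis " + std::to_string(axis) + " is out of range [-" + std::to_string(rank) +
                                  ", " + std::to_string(rank) + ")");
    }
    if (axis < 0) {
      axis += rank;  // e.g. -1 on a rank-3 tensor becomes 2
    }
  }
  return axes;
}

int main() {
  for (int64_t axis : NormalizeAxes({-1, 0, 2}, 3)) {
    std::cout << axis << " ";  // prints "2 0 2"
  }
  std::cout << std::endl;
  return 0;
}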

View File

@@ -1813,7 +1813,7 @@ class Tensor(Tensor_):
             ValueError: If either of `atol` and `rtol` is less than zero.

         Supported Platforms:
-            ``CPU``
+            ``Ascend`` ``GPU`` ``CPU``

         Examples:
             >>> input = Tensor(np.array([1.3, 2.1, 3.2, 4.1, 5.1]), mindspore.float16)

View File

@@ -2930,7 +2930,7 @@ def isclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):
         ValueError: If either of `atol` and `rtol` is less than zero.

     Supported Platforms:
-        ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``

     Examples:
         >>> input = Tensor(np.array([1.3, 2.1, 3.2, 4.1, 5.1]), mindspore.float16)

View File

@@ -53,7 +53,6 @@ isinstance_ = P.IsInstance()
 merge = P.Merge()
 geswitch = P.GeSwitch()
 check_bprop = P.CheckBprop()
-sqrt = P.Sqrt()
 reduce_sum = P.ReduceSum()
 reduce_max = P.ReduceMax()
 reduce_min = P.ReduceMin()

View File

@@ -115,11 +115,12 @@ class GlobalNorm(nn.Cell):
         super(GlobalNorm, self).__init__()
         self.norm = nn.Norm()
         self.hyper_map = C.HyperMap()
+        self.sqrt = P.Sqrt()

     def construct(self, grads):
         """Calculate global norm construct"""
         square_sum = self.hyper_map(get_square_sum, grads)
-        global_norms = F.sqrt(F.addn(square_sum))
+        global_norms = self.sqrt(F.addn(square_sum))
         return global_norms

View File

@@ -18,7 +18,7 @@ import mindspore as ms
 from mindspore import context, Tensor, Parameter
 from mindspore.nn import Cell
 import mindspore.nn as nn
-from mindspore.ops import operations as P, functional as F
+from mindspore.ops import operations as P
 from mindspore.common.initializer import initializer
 import mindspore.common.dtype as mstype
 from mindspore.common.api import _cell_graph_executor
@@ -63,11 +63,14 @@
         self.mul = P.Mul()
         self.div = P.RealDiv()
         self.square = P.Square()
+        self.sqrt = P.Sqrt()

     def construct(self, x):
         mean = self.mean(x, -1)
         variance = self.mean(self.square(self.sub(x, mean)))
-        output = self.div(self.sub(x, mean), F.sqrt(self.add(variance, self.eps)))
+        add_variance = self.add(variance, self.eps)
+        sqrt_variance = self.sqrt(add_variance)
+        output = self.div(self.sub(x, mean), sqrt_variance)
         rescaled_output = self.add(self.mul(output, self.gamma), self.beta)
         return rescaled_output

View File

@@ -18,7 +18,7 @@ import mindspore as ms
 from mindspore import context, Tensor, Parameter
 from mindspore.nn import Cell
 import mindspore.nn as nn
-from mindspore.ops import operations as P, functional as F
+from mindspore.ops import operations as P
 from mindspore.common.initializer import initializer
 import mindspore.common.dtype as mstype
 from mindspore.common.api import _cell_graph_executor
@@ -65,6 +65,7 @@
         self.square = P.Square()
         self.reshape = P.Reshape()
         self.shape = P.Shape()
+        self.sqrt = P.Sqrt()

     def construct(self, x):
         x_origin_shape = self.shape(x)
@@ -74,7 +75,7 @@
         x = self.reshape(x, x_target_shape)
         mean = self.mean(x, -1)
         variance = self.mean(self.square(self.sub(x, mean)))
-        output = self.div(self.sub(x, mean), F.sqrt(self.add(variance, self.eps)))
+        output = self.div(self.sub(x, mean), self.sqrt(self.add(variance, self.eps)))
         rescaled_output = self.add(self.mul(output, self.gamma), self.beta)
         output_shape = self.shape(rescaled_output) + (1,)
         rescaled_output = self.reshape(rescaled_output, output_shape)