forked from mindspore-Ecosystem/mindspore
!41862 fix Image, log_matrix_determinant, lp_norm kernel bugs
Merge pull request !41862 from zhuzhongrui/pub_master
commit 002054321a
@@ -55,61 +55,49 @@ void Ceil(const float *input, float *output, size_t start, size_t end) {
 template <typename T, typename S>
 void Real(const T *input, S *output, size_t start, size_t end) {
-  if (!std::is_same<S, complex64>::value && !std::is_same<S, complex128>::value) {
-    for (size_t i = start; i < end; ++i) {
-      output[i] = static_cast<S>(std::real(input[i]));
-    }
-  } else {
-    MS_LOG(EXCEPTION) << "For Real, it's output data type only support these types: float or double";
-  }
+  for (size_t i = start; i < end; ++i) {
+    output[i] = static_cast<S>(std::real(input[i]));
+  }
 }

 template <typename T, typename S>
 void Imag(const T *input, S *output, size_t start, size_t end) {
-  if constexpr (!std::is_same<S, std::complex<float>>::value && !std::is_same<S, std::complex<double>>::value) {
-    for (size_t i = start; i < end; ++i) {
-      output[i] = static_cast<S>(std::imag(input[i]));
-    }
-  } else {
-    MS_LOG(EXCEPTION) << "For Imag, it's output data type only support these types: float or double";
-  }
+  for (size_t i = start; i < end; ++i) {
+    output[i] = static_cast<S>(std::imag(input[i]));
+  }
 }

 template <typename T, typename S>
 void Conj(const T *input, S *output, size_t start, size_t end) {
-  if constexpr (std::is_same<T, S>::value &&
-                (std::is_same<T, complex64>::value || std::is_same<T, complex128>::value)) {
-    for (size_t i = start; i < end; ++i) {
-      output[i] = static_cast<S>(std::conj(input[i]));
+  if constexpr (std::is_same<T, S>::value) {
+    if constexpr ((std::is_same<T, complex64>::value || std::is_same<T, complex128>::value)) {
+      for (size_t i = start; i < end; ++i) {
+        output[i] = static_cast<S>(std::conj(input[i]));
+      }
+    } else {
+      for (size_t i = start; i < end; ++i) {
+        output[i] = static_cast<S>(input[i]);
+      }
     }
   } else {
-    MS_LOG(EXCEPTION) << "For Conj, it's output data type only support these types: complex<float> or complex<double>";
+    MS_LOG(EXCEPTION) << "For Conj, it's output data type not equal to input data type.";
   }
 }

 template <typename T, typename S>
 void UnaryOpCpuKernelFunc<T, S>::GetUnaryOpFunc() {
-  // only support float
+  const std::map<std::string, UnaryOpFunc> kCommonSupportedMap = {{prim::kPrimReal->name(), &Real<T, S>},
+                                                                  {prim::kPrimImag->name(), &Imag<T, S>},
+                                                                  {prim::kPrimConj->name(), &Conj<T, S>}};
   if constexpr (std::is_same<T, float>::value) {
-    static std::map<std::string, UnaryOpFunc> kFloatSupportedMap = {{prim::kPrimCeil->name(), &Ceil}};
-    auto iter = kFloatSupportedMap.find(kernel_name_);
-    if (iter != kFloatSupportedMap.end()) {
-      unary_op_func_ = iter->second;
+    if (kernel_name_ == prim::kPrimCeil->name()) {
+      unary_op_func_ = &Ceil;
       return;
     }
   }

-  if constexpr (std::is_same<T, complex64>::value || std::is_same<T, complex128>::value) {
-    static std::map<std::string, UnaryOpFunc> kComplexSupportedTypeMap = {{prim::kPrimReal->name(), &Real<T, S>},
-                                                                          {prim::kPrimImag->name(), &Imag<T, S>},
-                                                                          {prim::kPrimConj->name(), &Conj<T, S>}};
-    auto iter = kComplexSupportedTypeMap.find(kernel_name_);
-    if (iter != kComplexSupportedTypeMap.end()) {
-      unary_op_func_ = iter->second;
-      return;
-    }
-    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << ", only support these types: Real, Imag, Conj currently, but got "
-                      << kernel_name_;
+  auto iter = kCommonSupportedMap.find(kernel_name_);
+  if (iter != kCommonSupportedMap.end()) {
+    unary_op_func_ = iter->second;
   }
 }
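The Conj rewrite above hinges on `if constexpr`: the branch that calls std::conj is only instantiated when the element type really is complex, so real-valued instantiations compile cleanly and reduce to a plain copy. A minimal standalone sketch of that dispatch pattern, independent of the MindSpore sources (ConjLike is a hypothetical name, not an actual MindSpore API):

#include <complex>
#include <cstddef>
#include <iostream>
#include <type_traits>

// Conjugates complex elements; for non-complex element types the std::conj branch
// is discarded at compile time and the loop degenerates to a copy.
template <typename T>
void ConjLike(const T *input, T *output, std::size_t start, std::size_t end) {
  if constexpr (std::is_same<T, std::complex<float>>::value || std::is_same<T, std::complex<double>>::value) {
    for (std::size_t i = start; i < end; ++i) {
      output[i] = std::conj(input[i]);
    }
  } else {
    for (std::size_t i = start; i < end; ++i) {
      output[i] = input[i];
    }
  }
}

int main() {
  std::complex<float> cin[2] = {{1.0f, 2.0f}, {3.0f, -4.0f}};
  std::complex<float> cbuf[2];
  ConjLike(cin, cbuf, 0, 2);
  std::cout << cbuf[0] << " " << cbuf[1] << "\n";  // (1,-2) (3,4)

  float fin[2] = {1.5f, -2.5f};
  float fout[2];
  ConjLike(fin, fout, 0, 2);  // compiles because the complex branch is never instantiated
  std::cout << fout[0] << " " << fout[1] << "\n";
  return 0;
}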
@@ -126,15 +114,12 @@ bool UnaryOpCpuKernelFunc<T, S>::RunFunc(const std::vector<AddressPtr> &inputs,
   auto output = outputs.front();
   const auto input_addr = reinterpret_cast<T *>(input->addr);
   auto output_addr = reinterpret_cast<S *>(output->addr);
-  if (unary_op_func_ != nullptr) {
-    ParallelLaunchAutoSearch(
-      std::bind(unary_op_func_, input_addr, output_addr, std::placeholders::_1, std::placeholders::_2),
-      output->size / sizeof(S), this, &parallel_search_info_);
-  } else {
-    if (memcpy_s(output_addr, output->size, input_addr, input->size) != EOK) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it does memory copy fail.";
-    }
-  }
+  if (unary_op_func_ == nullptr) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', it has no cpu backend implements.";
+  }
+  ParallelLaunchAutoSearch(
+    std::bind(unary_op_func_, input_addr, output_addr, std::placeholders::_1, std::placeholders::_2),
+    output->size / sizeof(S), this, &parallel_search_info_);
   return true;
 }
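For context, the std::bind call above turns unary_op_func_(input, output, start, end) into a callable that only takes a (start, end) pair, which the launcher then invokes over slices of the element count output->size / sizeof(S). The sketch below illustrates that slicing contract with plain std::thread; LaunchChunks is a hypothetical stand-in under stated assumptions, not MindSpore's ParallelLaunchAutoSearch:

#include <algorithm>
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

// Splits [0, total) into contiguous chunks and runs task(start, end) on each chunk,
// which is the same contract the bound unary functor is expected to satisfy.
void LaunchChunks(const std::function<void(std::size_t, std::size_t)> &task, std::size_t total,
                  std::size_t num_workers) {
  std::vector<std::thread> workers;
  std::size_t chunk = (total + num_workers - 1) / num_workers;
  for (std::size_t start = 0; start < total; start += chunk) {
    std::size_t end = std::min(start + chunk, total);
    workers.emplace_back(task, start, end);
  }
  for (auto &w : workers) {
    w.join();
  }
}

int main() {
  std::vector<float> data(1000, 1.0f);
  // Each worker doubles its own slice; slices do not overlap, so no locking is needed.
  LaunchChunks([&data](std::size_t s, std::size_t e) {
    for (std::size_t i = s; i < e; ++i) {
      data[i] *= 2.0f;
    }
  }, data.size(), 4);
  return 0;
}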
@@ -192,7 +177,9 @@ std::map<std::string, std::vector<std::pair<KernelAttr, UnaryOpCpuFuncCreator>>>
      SpecializeUnaryFunc<float, float>},
     {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
      SpecializeUnaryFunc<double, double>},
-    {KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SpecializeUnaryFunc<bool, bool>}}},
+    {KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), SpecializeUnaryFunc<bool, bool>},
+    {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
+     SpecializeUnaryFunc<uint8_t, uint8_t>}}},
   {kConj,
    {{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
      SpecializeUnaryFunc<complex128, complex128>},
@@ -30,6 +30,10 @@ abstract::TupleShapePtr LogMatrixDeterminantInferShape(const PrimitivePtr &primi
                                                        const std::vector<AbstractBasePtr> &input_args) {
   auto prim_name = primitive->name();
   auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
+  if (IsDynamicRank(x_shape)) {
+    abstract::ShapePtr out_shape = std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
+    return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{out_shape, out_shape});
+  }
   auto x_rank = SizeToLong(x_shape.size());
   constexpr int64_t number1 = 1;
   constexpr int64_t number2 = 2;
@@ -34,6 +34,9 @@ abstract::ShapePtr LpNormInferShape(const PrimitivePtr &primitive, const std::ve
     MS_EXCEPTION_IF_NULL(item);
   }
   auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
+  if (IsDynamicRank(input_shape)) {
+    return std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
+  }
   auto output_shape = input_shape;
   auto input_rank = SizeToLong(input_shape.size());
   auto axis = GetValue<std::vector<int64_t>>(primitive->GetAttr("axis"));
@@ -44,13 +47,18 @@ abstract::ShapePtr LpNormInferShape(const PrimitivePtr &primitive, const std::ve
   } else {
     CheckAndConvertUtils::CheckInRange("axis size", axis.size(), kIncludeNeither, {0, input_rank + 1}, prim_name);
   }
-  if (axis.size() > 1) {
-    for (size_t i = 0; i < axis.size(); ++i) {
-      CheckAndConvertUtils::CheckInRange("axis value", axis[i], kIncludeLeft, {-input_rank, input_rank}, prim_name);
-      if (axis[i] < 0) {
-        axis[i] += input_rank;
-      }
-    }
+  for (int64_t &axi : axis) {
+    CheckAndConvertUtils::CheckInRange("axis value", axi, kIncludeLeft, {-input_rank, input_rank}, prim_name);
+    if (axi < 0) {
+      axi += input_rank;
+    }
+  }
+  bool invalid_axis = std::any_of(axis.begin(), axis.end(), [&input_rank](int64_t axis) { return axis >= input_rank; });
+  if (invalid_axis) {
+    MS_EXCEPTION(ValueError) << "For " << prim_name << ", the value of axis is out of range (-" << input_rank << ", "
+                             << input_rank << ").";
+  }
+  if (axis.size() > 1) {
     constexpr int64_t place_holder = INT64_MAX;
     for (size_t i = 0; i < axis.size(); ++i) {
       auto temp = axis;
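The loop introduced above first range-checks every axis against [-input_rank, input_rank) and then folds negative values into their non-negative equivalents, so the later shape edits can index output_shape directly. A small sketch of that wrap-and-validate step, assuming a hypothetical NormalizeAxes helper rather than the actual CheckAndConvertUtils API:

#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Validates each axis against [-rank, rank) and maps negatives to rank + axis,
// mirroring the normalization the LpNorm infer-shape change performs before erasing dims.
std::vector<int64_t> NormalizeAxes(std::vector<int64_t> axes, int64_t rank) {
  for (int64_t &axis : axes) {
    if (axis < -rank || axis >= rank) {
      throw std::out_of_range("axis " + std::to_string(axis) + " is out of range [-" +
                              std::to_string(rank) + ", " + std::to_string(rank) + ")");
    }
    if (axis < 0) {
      axis += rank;
    }
  }
  return axes;
}

// Example: NormalizeAxes({-1, 0}, 3) returns {2, 0}; NormalizeAxes({3}, 3) throws.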
@@ -77,9 +85,6 @@ abstract::ShapePtr LpNormInferShape(const PrimitivePtr &primitive, const std::ve
       }
     }
   } else {
-    if (axis[0] < 0) {
-      axis[0] += input_rank;
-    }
     if (!keep_dims) {
       (void)output_shape.erase(output_shape.begin() + axis[0]);
     } else {
@@ -1813,7 +1813,7 @@ class Tensor(Tensor_):
             ValueError: If either of `atol` and `rtol` is less than zero.

         Supported Platforms:
-            ``CPU``
+            ``Ascend`` ``GPU`` ``CPU``

         Examples:
             >>> input = Tensor(np.array([1.3, 2.1, 3.2, 4.1, 5.1]), mindspore.float16)
@@ -2930,7 +2930,7 @@ def isclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):
         ValueError: If either of `atol` and `rtol` is less than zero.

     Supported Platforms:
-        ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``

     Examples:
         >>> input = Tensor(np.array([1.3, 2.1, 3.2, 4.1, 5.1]), mindspore.float16)
@@ -53,7 +53,6 @@ isinstance_ = P.IsInstance()
 merge = P.Merge()
 geswitch = P.GeSwitch()
 check_bprop = P.CheckBprop()
-sqrt = P.Sqrt()
 reduce_sum = P.ReduceSum()
 reduce_max = P.ReduceMax()
 reduce_min = P.ReduceMin()
@@ -115,11 +115,12 @@ class GlobalNorm(nn.Cell):
         super(GlobalNorm, self).__init__()
         self.norm = nn.Norm()
         self.hyper_map = C.HyperMap()
+        self.sqrt = P.Sqrt()

     def construct(self, grads):
         """Calculate global norm construct"""
         square_sum = self.hyper_map(get_square_sum, grads)
-        global_norms = F.sqrt(F.addn(square_sum))
+        global_norms = self.sqrt(F.addn(square_sum))
         return global_norms

@@ -18,7 +18,7 @@ import mindspore as ms
 from mindspore import context, Tensor, Parameter
 from mindspore.nn import Cell
 import mindspore.nn as nn
-from mindspore.ops import operations as P, functional as F
+from mindspore.ops import operations as P
 from mindspore.common.initializer import initializer
 import mindspore.common.dtype as mstype
 from mindspore.common.api import _cell_graph_executor
@@ -63,11 +63,14 @@ class LayerNorm(nn.Cell):
         self.mul = P.Mul()
         self.div = P.RealDiv()
         self.square = P.Square()
+        self.sqrt = P.Sqrt()

     def construct(self, x):
         mean = self.mean(x, -1)
         variance = self.mean(self.square(self.sub(x, mean)))
-        output = self.div(self.sub(x, mean), F.sqrt(self.add(variance, self.eps)))
+        add_variance = self.add(variance, self.eps)
+        sqrt_variance = self.sqrt(add_variance)
+        output = self.div(self.sub(x, mean), sqrt_variance)
         rescaled_output = self.add(self.mul(output, self.gamma), self.beta)
         return rescaled_output

@@ -18,7 +18,7 @@ import mindspore as ms
 from mindspore import context, Tensor, Parameter
 from mindspore.nn import Cell
 import mindspore.nn as nn
-from mindspore.ops import operations as P, functional as F
+from mindspore.ops import operations as P
 from mindspore.common.initializer import initializer
 import mindspore.common.dtype as mstype
 from mindspore.common.api import _cell_graph_executor
@@ -65,6 +65,7 @@ class LayerNorm(nn.Cell):
         self.square = P.Square()
         self.reshape = P.Reshape()
         self.shape = P.Shape()
+        self.sqrt = P.Sqrt()

     def construct(self, x):
         x_origin_shape = self.shape(x)
@@ -74,7 +75,7 @@ class LayerNorm(nn.Cell):
         x = self.reshape(x, x_target_shape)
         mean = self.mean(x, -1)
         variance = self.mean(self.square(self.sub(x, mean)))
-        output = self.div(self.sub(x, mean), F.sqrt(self.add(variance, self.eps)))
+        output = self.div(self.sub(x, mean), self.sqrt(self.add(variance, self.eps)))
         rescaled_output = self.add(self.mul(output, self.gamma), self.beta)
         output_shape = self.shape(rescaled_output) + (1,)
         rescaled_output = self.reshape(rescaled_output, output_shape)