Fix bugs in Quantile, MatrixSolveLs, InstanceNormV2 and TridiagonalMatMul

panzhihui 2022-12-28 21:25:04 +08:00
parent 9054d923ea
commit 3a6597c7ff
8 changed files with 16 additions and 35 deletions

View File

@@ -60,6 +60,7 @@ int TridiagonalMatMulCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
   auto input1_shape = inputs[kIndex1]->GetShapeVector();
   auto input2_shape = inputs[kIndex2]->GetShapeVector();
   auto input3_shape = inputs[kIndex3]->GetShapeVector();
+  rhs_shape_ = input3_shape;
   if ((input0_shape.size() < is_matrix) || (input1_shape.size() < is_matrix) || (input2_shape.size() < is_matrix) ||
       (input3_shape.size() < is_matrix)) {
     MS_LOG(ERROR) << "For '" << kernel_name_
@@ -140,10 +141,6 @@ bool TridiagonalMatMulCpuKernelMod::Launch(const std::vector<kernel::AddressPtr>
 template <typename T>
 void TridiagonalMatMulCpuKernelMod::LaunchTridiagonalMatMul(const std::vector<AddressPtr> &inputs,
                                                             const std::vector<AddressPtr> &outputs) {
-  auto node_ = node_wpt_.lock();
-  if (!node_) {
-    MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
-  }
   T *superdiag_ptr = reinterpret_cast<T *>(inputs[0]->addr);
   MS_EXCEPTION_IF_NULL(superdiag_ptr);
   T *maindiag_ptr = reinterpret_cast<T *>(inputs[1]->addr);
@@ -154,15 +151,14 @@ void TridiagonalMatMulCpuKernelMod::LaunchTridiagonalMatMul(const std::vector<Ad
   MS_EXCEPTION_IF_NULL(rhs_ptr);
   T *y_ptr = reinterpret_cast<T *>(outputs[0]->addr);
   MS_EXCEPTION_IF_NULL(y_ptr);
-  auto rhs_shape = AnfAlgo::GetInputDeviceShape(node_, kInputShapeIndex3);
-  size_t m = static_cast<size_t>(rhs_shape[rhs_shape.size() - row]);
-  size_t n = static_cast<size_t>(rhs_shape[rhs_shape.size() - col]);
+  size_t m = static_cast<size_t>(rhs_shape_[rhs_shape_.size() - row]);
+  size_t n = static_cast<size_t>(rhs_shape_[rhs_shape_.size() - col]);
   size_t size_mn = m * n;
   using VectorMap = Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 1>>;
   using MatrixMap = Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
   size_t rhs_num = 1;
-  for (size_t i = 0; i < rhs_shape.size(); i++) {
-    rhs_num *= static_cast<size_t>(rhs_shape[i]);
+  for (size_t i = 0; i < rhs_shape_.size(); i++) {
+    rhs_num *= static_cast<size_t>(rhs_shape_[i]);
   }
   size_t rhs_matrix_num = rhs_num / size_mn;
   for (size_t i = 0; i < rhs_matrix_num; i++) {
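
A note on the pattern above: the kernel used to lock a weak pointer to its graph node at launch time just to re-read the right-hand-side shape, and would throw "node_wpt_ is expired." once the node was gone. The fix snapshots the shape in Resize() and reads the cached member in Launch(). A minimal sketch of that pattern, using hypothetical names (MyKernelMod, a plain int return code) rather than the real MindSpore base classes:

#include <cstddef>
#include <cstdint>
#include <vector>

using ShapeVector = std::vector<int64_t>;

class MyKernelMod {
 public:
  // Resize() is re-run whenever input shapes change, so it is the right
  // place to snapshot shape information for later use in Launch().
  int Resize(const ShapeVector &rhs_shape) {
    rhs_shape_ = rhs_shape;  // cached copy; no graph-node access needed later
    return 0;                // 0 stands in for the success code
  }

  // Launch() reads only the cached member, so it no longer depends on the
  // originating graph node still being alive.
  size_t Launch() const {
    size_t rhs_num = 1;
    for (size_t i = 0; i < rhs_shape_.size(); i++) {
      rhs_num *= static_cast<size_t>(rhs_shape_[i]);
    }
    return rhs_num;  // total number of right-hand-side elements
  }

 private:
  ShapeVector rhs_shape_;
};

This pairs with the header change below, which swaps the CNodeWeakPtr member for a ShapeVector.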

View File

@@ -39,7 +39,7 @@ class TridiagonalMatMulCpuKernelMod : public NativeCpuKernelMod {
   std::vector<KernelAttr> GetOpSupport() override;

  private:
-  CNodeWeakPtr node_wpt_;
+  ShapeVector rhs_shape_;
   TypeId dtype_{kTypeUnknown};
   template <typename T>
   void LaunchTridiagonalMatMul(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

View File

@@ -135,7 +135,7 @@ void InstanceNormV2CpuKernelMod::CollectLinearAndConstant(const typename TTypes<
   } else {
     mean = running_mean(idx);
     float _std_ = std::sqrt(running_var(idx) + static_cast<float>(epsilon_));
-    MS_EXCEPTION_IF_ZERO("_std_", static_cast<int64_t>(_std_));
+    MS_EXCEPTION_IF_ZERO("_std_", _std_);
     invstd = float_init_one / _std_;
   }
   _alpha_[idx] = invstd * gamma(idx);
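
The cast removed above was the actual bug: a standard deviation in (0, 1) truncates to 0 as an int64_t, so the zero-guard threw for perfectly valid data. A self-contained illustration of the pitfall (hypothetical example, not the MindSpore macro itself):

#include <cassert>
#include <cstdint>

int main() {
  float std_dev = 0.25f;                       // valid, nonzero std
  assert(static_cast<int64_t>(std_dev) == 0);  // truncation: a false "zero"
  assert(std_dev != 0.0f);                     // the float itself is fine
  return 0;
}

Passing the float directly to MS_EXCEPTION_IF_ZERO, as the new line does, only trips on an actual zero. The grad kernel below gets the same fix.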

View File

@@ -108,7 +108,7 @@ bool InstanceNormV2GradCpuKernelMod::LaunchKernel(const std::vector<kernel::Addr
     float mean = float_init_zero, invstd = float_init_zero;
     mean = is_training_ ? save_mean_matrix(idx, c_idx) : running_mean_matrix(idx, c_idx);
     float _invstd_ = std::sqrt(running_var_matrix(idx, c_idx) + epsilon_);
-    MS_EXCEPTION_IF_ZERO("_invstd_", static_cast<int64_t>(_invstd_));
+    MS_EXCEPTION_IF_ZERO("_invstd_", _invstd_);
     invstd = is_training_ ? save_invstd_matrix(idx, c_idx) : float_init_one / _invstd_;
     double sum = double_init_zero, dotp = double_init_zero;

View File

@@ -34,8 +34,6 @@ bool QuantileCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std:
     MS_LOG(ERROR) << "cast Quantile ops failed!";
     return false;
   }
-  input_shape_ = inputs.at(0)->GetShapeVector();
-  q_shape_ = inputs.at(1)->GetShapeVector();
   kernel_name_ = kernel_ptr->name();
   dim_ = GetValue<int64_t>(base_operator->GetAttr("dim"));
   keep_dims_ = GetValue<bool>(base_operator->GetAttr("keep_dims"));
@@ -47,11 +45,13 @@ int QuantileCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std:
                                  const std::vector<KernelTensorPtr> &outputs,
                                  const std::map<uint32_t, tensor::TensorPtr> &others) {
   int ret = 0;
-  if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others)) != 0) {
+  if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others)) != KRET_OK) {
     MS_LOG(WARNING) << kernel_name_ << " reinit failed.";
     return ret;
   }
-  return 0;
+  input_shape_ = inputs.at(0)->GetShapeVector();
+  q_shape_ = inputs.at(1)->GetShapeVector();
+  return KRET_OK;
 }

 uint32_t QuantileCpuKernelMod::MaybeWrapDim(int dim, int dim_post_expr) {
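
Moving the shape reads from Init() to Resize() follows the same dynamic-shape rationale as the TridiagonalMatMul change above: at Init() time the shapes reported by GetShapeVector() may still contain unresolved placeholder dimensions, while Resize() runs again once concrete shapes are known. A hypothetical illustration of the difference (not MindSpore code):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Common dynamic-shape convention: -1 marks a not-yet-resolved dimension.
  std::vector<int64_t> shape_at_init = {-1, 4};    // placeholder at Init time
  std::vector<int64_t> shape_at_resize = {32, 4};  // concrete at Resize time
  std::cout << "Init sees dim0 = " << shape_at_init[0]
            << ", Resize sees dim0 = " << shape_at_resize[0] << '\n';
  return 0;
}

Comparing the base-class Resize() result against KRET_OK rather than the literal 0 is a readability change with the same behavior, assuming KRET_OK is the zero success code.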

View File

@@ -144,6 +144,7 @@ class MIND_API AGTridiagonalMatMulInfer : public abstract::OpInferBase {
   }
 };

 MIND_API_OPERATOR_IMPL(TridiagonalMatMul, BaseOperator);
+REGISTER_PRIMITIVE_OP_INFER_IMPL(TridiagonalMatMul, prim::kPrimTridiagonalMatMul, AGTridiagonalMatMulInfer, false);
 }  // namespace ops
 }  // namespace mindspore
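
Without this REGISTER_PRIMITIVE_OP_INFER_IMPL call, AGTridiagonalMatMulInfer was defined but never registered, leaving the framework with no way to look up the op's shape and dtype inference; the trailing false appears to indicate that the infer class does not implement compile-time value inference.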

View File

@@ -30,7 +30,7 @@
 namespace mindspore {
 namespace ops {
 constexpr auto kNameTridiagonalMatMul = "TridiagonalMatMul";
-class MS_CORE_API TridiagonalMatMul : public BaseOperator {
+class MIND_API TridiagonalMatMul : public BaseOperator {
  public:
   MIND_API_BASE_MEMBER(TridiagonalMatMul);
   TridiagonalMatMul() : BaseOperator(kNameTridiagonalMatMul) {
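
Switching the export macro to MIND_API keeps the class consistent with the MIND_API_BASE_MEMBER machinery it already uses; mixing it with MS_CORE_API risks mismatched symbol visibility between the core and API shared libraries.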

View File

@@ -601,22 +601,6 @@ def get_bprop_matrix_exp(self):
     return bprop


-@bprop_getters.register(P.MatrixInverse)
-def get_bprop_matrix_inverse(self):
-    """Generate bprop for MatrixInverse"""
-    matmul_x1 = nn.MatMul(transpose_x1=True)
-    matmul_x2 = nn.MatMul(transpose_x2=True)
-    neg = P.Neg()
-
-    def bprop(x, out, dout):
-        dx = matmul_x2(dout, out)
-        dx = matmul_x1(out, dx)
-        dx = neg(dx)
-        return (dx,)
-
-    return bprop
-
-
 @bprop_getters.register(MatrixPower)
 def get_bprop_matrix_power(self):
     """Generate bprop for MatrixPower"""
@@ -786,7 +770,7 @@ def get_bprop_matrix_solve_ls(self):
             grad_a = add(neg(matmul(matrix, zx_sym)), matmul(rhs, z_temp))
             grad_b = matmul(matrix, z)
-            return (grad_a, grad_b, None)
+            return (grad_a, grad_b, zeros_like(l2))

         def under_determined(matrix, rhs, l2, dout):
             if matrix.dtype == mstype.complex64:
@@ -833,7 +817,7 @@ def get_bprop_matrix_solve_ls(self):
             a2 = matmul(tmp, a2_temp)
             grad_a = add(a1, a2)
-            return (grad_a, grad_b, None)
+            return (grad_a, grad_b, zeros_like(l2))

         if fast is False:
             raise ValueError("For MatrixSolveLs, gradient not defined for fast=False")
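
The two return changes above are presumably the MatrixSolveLs fix named in the commit title: the l2 regularizer input has no meaningful gradient, but returning None breaks the bprop contract of one gradient per input, whereas zeros_like(l2) supplies a correctly shaped zero gradient for the non-differentiable input.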