Fix bugs of Quantile, MatrixSolveLs and InstanceNormV2

panzhihui 2022-12-28 21:25:04 +08:00
parent 9054d923ea
commit 3a6597c7ff
8 changed files with 16 additions and 35 deletions

View File

@@ -60,6 +60,7 @@ int TridiagonalMatMulCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
   auto input1_shape = inputs[kIndex1]->GetShapeVector();
   auto input2_shape = inputs[kIndex2]->GetShapeVector();
   auto input3_shape = inputs[kIndex3]->GetShapeVector();
+  rhs_shape_ = input3_shape;
   if ((input0_shape.size() < is_matrix) || (input1_shape.size() < is_matrix) || (input2_shape.size() < is_matrix) ||
       (input3_shape.size() < is_matrix)) {
     MS_LOG(ERROR) << "For '" << kernel_name_
@@ -140,10 +141,6 @@ bool TridiagonalMatMulCpuKernelMod::Launch(const std::vector<kernel::AddressPtr>
 template <typename T>
 void TridiagonalMatMulCpuKernelMod::LaunchTridiagonalMatMul(const std::vector<AddressPtr> &inputs,
                                                             const std::vector<AddressPtr> &outputs) {
-  auto node_ = node_wpt_.lock();
-  if (!node_) {
-    MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
-  }
   T *superdiag_ptr = reinterpret_cast<T *>(inputs[0]->addr);
   MS_EXCEPTION_IF_NULL(superdiag_ptr);
   T *maindiag_ptr = reinterpret_cast<T *>(inputs[1]->addr);
@@ -154,15 +151,14 @@ void TridiagonalMatMulCpuKernelMod::LaunchTridiagonalMatMul(const std::vector<Ad
   MS_EXCEPTION_IF_NULL(rhs_ptr);
   T *y_ptr = reinterpret_cast<T *>(outputs[0]->addr);
   MS_EXCEPTION_IF_NULL(y_ptr);
-  auto rhs_shape = AnfAlgo::GetInputDeviceShape(node_, kInputShapeIndex3);
-  size_t m = static_cast<size_t>(rhs_shape[rhs_shape.size() - row]);
-  size_t n = static_cast<size_t>(rhs_shape[rhs_shape.size() - col]);
+  size_t m = static_cast<size_t>(rhs_shape_[rhs_shape_.size() - row]);
+  size_t n = static_cast<size_t>(rhs_shape_[rhs_shape_.size() - col]);
   size_t size_mn = m * n;
   using VectorMap = Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 1>>;
   using MatrixMap = Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
   size_t rhs_num = 1;
-  for (size_t i = 0; i < rhs_shape.size(); i++) {
-    rhs_num *= static_cast<size_t>(rhs_shape[i]);
+  for (size_t i = 0; i < rhs_shape_.size(); i++) {
+    rhs_num *= static_cast<size_t>(rhs_shape_[i]);
   }
   size_t rhs_matrix_num = rhs_num / size_mn;
   for (size_t i = 0; i < rhs_matrix_num; i++) {

View File

@@ -39,7 +39,7 @@ class TridiagonalMatMulCpuKernelMod : public NativeCpuKernelMod {
   std::vector<KernelAttr> GetOpSupport() override;

  private:
-  CNodeWeakPtr node_wpt_;
+  ShapeVector rhs_shape_;
   TypeId dtype_{kTypeUnknown};
   template <typename T>
   void LaunchTridiagonalMatMul(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
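
Note on the TridiagonalMatMul kernel fix above: the launch path previously recovered the rhs shape through a weak pointer to the graph node, which can expire (hence the old "node_wpt_ is expired" exception path); the fix snapshots the shape once in Resize and reads the cached copy in Launch. A minimal sketch of the pattern, assuming simplified stand-in types (TridiagonalMatMulSketch and RhsMatrixCount are illustrative, not MindSpore API):

#include <cstddef>
#include <cstdint>
#include <vector>

using ShapeVector = std::vector<std::int64_t>;

class TridiagonalMatMulSketch {
 public:
  // Resize runs whenever input shapes change, so it is the safe place
  // to snapshot the rhs shape for later use in Launch.
  void Resize(const ShapeVector &rhs_shape) { rhs_shape_ = rhs_shape; }

  // Launch-side computation reads the cached copy; no weak pointer to
  // the graph node is needed, so nothing can expire in between.
  std::size_t RhsMatrixCount(std::size_t m, std::size_t n) const {
    std::size_t rhs_num = 1;
    for (std::size_t i = 0; i < rhs_shape_.size(); i++) {
      rhs_num *= static_cast<std::size_t>(rhs_shape_[i]);
    }
    return rhs_num / (m * n);
  }

 private:
  ShapeVector rhs_shape_;
};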

View File

@@ -135,7 +135,7 @@ void InstanceNormV2CpuKernelMod::CollectLinearAndConstant(const typename TTypes<
     } else {
       mean = running_mean(idx);
       float _std_ = std::sqrt(running_var(idx) + static_cast<float>(epsilon_));
-      MS_EXCEPTION_IF_ZERO("_std_", static_cast<int64_t>(_std_));
+      MS_EXCEPTION_IF_ZERO("_std_", _std_);
       invstd = float_init_one / _std_;
     }
     _alpha_[idx] = invstd * gamma(idx);

View File

@@ -108,7 +108,7 @@ bool InstanceNormV2GradCpuKernelMod::LaunchKernel(const std::vector<kernel::Addr
     float mean = float_init_zero, invstd = float_init_zero;
     mean = is_training_ ? save_mean_matrix(idx, c_idx) : running_mean_matrix(idx, c_idx);
     float _invstd_ = std::sqrt(running_var_matrix(idx, c_idx) + epsilon_);
-    MS_EXCEPTION_IF_ZERO("_invstd_", static_cast<int64_t>(_invstd_));
+    MS_EXCEPTION_IF_ZERO("_invstd_", _invstd_);
     invstd = is_training_ ? save_invstd_matrix(idx, c_idx) : float_init_one / _invstd_;
     double sum = double_init_zero, dotp = double_init_zero;

View File

@@ -34,8 +34,6 @@ bool QuantileCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std:
     MS_LOG(ERROR) << "cast Quantile ops failed!";
     return false;
   }
-  input_shape_ = inputs.at(0)->GetShapeVector();
-  q_shape_ = inputs.at(1)->GetShapeVector();
   kernel_name_ = kernel_ptr->name();
   dim_ = GetValue<int64_t>(base_operator->GetAttr("dim"));
   keep_dims_ = GetValue<bool>(base_operator->GetAttr("keep_dims"));
@@ -47,11 +45,13 @@ int QuantileCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std
                                  const std::vector<KernelTensorPtr> &outputs,
                                  const std::map<uint32_t, tensor::TensorPtr> &others) {
   int ret = 0;
-  if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others)) != 0) {
+  if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others)) != KRET_OK) {
     MS_LOG(WARNING) << kernel_name_ << " reinit failed.";
     return ret;
   }
-  return 0;
+  input_shape_ = inputs.at(0)->GetShapeVector();
+  q_shape_ = inputs.at(1)->GetShapeVector();
+  return KRET_OK;
 }

 uint32_t QuantileCpuKernelMod::MaybeWrapDim(int dim, int dim_post_expr) {
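
The Quantile change is an ordering fix: Init runs before shapes are final for dynamic-shape inputs, so shapes captured there can be stale placeholders, while Resize is re-invoked with the concrete shapes before every launch; it also compares the status against KRET_OK instead of a bare 0, matching the base-class return convention. A rough sketch of the division of labor, assuming simplified stand-in types (QuantileSketch and the local KRET_OK constant are illustrative, not the real kernel class):

#include <cstdint>
#include <vector>

using ShapeVector = std::vector<std::int64_t>;
constexpr int KRET_OK = 0;  // stand-in for the framework's status constant

struct QuantileSketch {
  std::int64_t dim_ = 0;
  bool keep_dims_ = false;
  ShapeVector input_shape_;
  ShapeVector q_shape_;

  // Init: only static attributes are safe to cache here; input shapes
  // may still be placeholders in dynamic-shape graphs.
  void Init(std::int64_t dim, bool keep_dims) {
    dim_ = dim;
    keep_dims_ = keep_dims;
  }

  // Resize: re-invoked with concrete shapes before each launch, so the
  // shapes snapshotted here are the ones Launch will actually see.
  int Resize(const ShapeVector &input_shape, const ShapeVector &q_shape) {
    input_shape_ = input_shape;
    q_shape_ = q_shape;
    return KRET_OK;
  }
};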

View File

@@ -144,6 +144,7 @@ class MIND_API AGTridiagonalMatMulInfer : public abstract::OpInferBase {
   }
 };

+MIND_API_OPERATOR_IMPL(TridiagonalMatMul, BaseOperator);
 REGISTER_PRIMITIVE_OP_INFER_IMPL(TridiagonalMatMul, prim::kPrimTridiagonalMatMul, AGTridiagonalMatMulInfer, false);
 }  // namespace ops
 }  // namespace mindspore

View File

@@ -30,7 +30,7 @@
 namespace mindspore {
 namespace ops {
 constexpr auto kNameTridiagonalMatMul = "TridiagonalMatMul";
-class MS_CORE_API TridiagonalMatMul : public BaseOperator {
+class MIND_API TridiagonalMatMul : public BaseOperator {
  public:
   MIND_API_BASE_MEMBER(TridiagonalMatMul);
   TridiagonalMatMul() : BaseOperator(kNameTridiagonalMatMul) {
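
The two operator-layer hunks appear complementary: MIND_API_BASE_MEMBER in the header declares MIND API type support that the newly added MIND_API_OPERATOR_IMPL in the .cc file supplies, and switching the export macro from MS_CORE_API to MIND_API keeps the class declaration consistent with that registration (this reading of the macros' roles is inferred from their pairing in this diff, not from their definitions).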

View File

@@ -601,22 +601,6 @@ def get_bprop_matrix_exp(self):
     return bprop


-@bprop_getters.register(P.MatrixInverse)
-def get_bprop_matrix_inverse(self):
-    """Generate bprop for MatrixInverse"""
-    matmul_x1 = nn.MatMul(transpose_x1=True)
-    matmul_x2 = nn.MatMul(transpose_x2=True)
-    neg = P.Neg()
-
-    def bprop(x, out, dout):
-        dx = matmul_x2(dout, out)
-        dx = matmul_x1(out, dx)
-        dx = neg(dx)
-        return (dx,)
-
-    return bprop
-
-
 @bprop_getters.register(MatrixPower)
 def get_bprop_matrix_power(self):
     """Generate bprop for MatrixPower"""
@@ -786,7 +770,7 @@ def get_bprop_matrix_solve_ls(self):

         grad_a = add(neg(matmul(matrix, zx_sym)), matmul(rhs, z_temp))
         grad_b = matmul(matrix, z)
-        return (grad_a, grad_b, None)
+        return (grad_a, grad_b, zeros_like(l2))

     def under_determined(matrix, rhs, l2, dout):
         if matrix.dtype == mstype.complex64:
@@ -833,7 +817,7 @@ def get_bprop_matrix_solve_ls(self):

         a2 = matmul(tmp, a2_temp)
         grad_a = add(a1, a2)
-        return (grad_a, grad_b, None)
+        return (grad_a, grad_b, zeros_like(l2))

     if fast is False:
         raise ValueError("For MatrixSolveLs, gradient not defined for fast=False")
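
The two MatrixSolveLs changes apply one rule: a bprop must return one gradient entry per forward input, and returning None for the l2 regularizer leaves a hole in that tuple; zeros_like(l2) keeps the tuple structurally complete while still propagating no gradient into l2. The removed get_bprop_matrix_inverse block has no replacement in this diff, so MatrixInverse's gradient is presumably provided through another registration path.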