forked from mindspore-Ecosystem/mindspore
Fix bugs of quantile, matrixsolvel and instancenormv2
parent 9054d923ea
commit 3a6597c7ff

@@ -60,6 +60,7 @@ int TridiagonalMatMulCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
   auto input1_shape = inputs[kIndex1]->GetShapeVector();
   auto input2_shape = inputs[kIndex2]->GetShapeVector();
   auto input3_shape = inputs[kIndex3]->GetShapeVector();
+  rhs_shape_ = input3_shape;
   if ((input0_shape.size() < is_matrix) || (input1_shape.size() < is_matrix) || (input2_shape.size() < is_matrix) ||
       (input3_shape.size() < is_matrix)) {
     MS_LOG(ERROR) << "For '" << kernel_name_

@@ -140,10 +141,6 @@ bool TridiagonalMatMulCpuKernelMod::Launch(const std::vector<kernel::AddressPtr>
 template <typename T>
 void TridiagonalMatMulCpuKernelMod::LaunchTridiagonalMatMul(const std::vector<AddressPtr> &inputs,
                                                             const std::vector<AddressPtr> &outputs) {
-  auto node_ = node_wpt_.lock();
-  if (!node_) {
-    MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
-  }
   T *superdiag_ptr = reinterpret_cast<T *>(inputs[0]->addr);
   MS_EXCEPTION_IF_NULL(superdiag_ptr);
   T *maindiag_ptr = reinterpret_cast<T *>(inputs[1]->addr);

@@ -154,15 +151,14 @@ void TridiagonalMatMulCpuKernelMod::LaunchTridiagonalMatMul(const std::vector<Ad
   MS_EXCEPTION_IF_NULL(rhs_ptr);
   T *y_ptr = reinterpret_cast<T *>(outputs[0]->addr);
   MS_EXCEPTION_IF_NULL(y_ptr);
-  auto rhs_shape = AnfAlgo::GetInputDeviceShape(node_, kInputShapeIndex3);
-  size_t m = static_cast<size_t>(rhs_shape[rhs_shape.size() - row]);
-  size_t n = static_cast<size_t>(rhs_shape[rhs_shape.size() - col]);
+  size_t m = static_cast<size_t>(rhs_shape_[rhs_shape_.size() - row]);
+  size_t n = static_cast<size_t>(rhs_shape_[rhs_shape_.size() - col]);
   size_t size_mn = m * n;
   using VectorMap = Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 1>>;
   using MatrixMap = Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
   size_t rhs_num = 1;
-  for (size_t i = 0; i < rhs_shape.size(); i++) {
-    rhs_num *= static_cast<size_t>(rhs_shape[i]);
+  for (size_t i = 0; i < rhs_shape_.size(); i++) {
+    rhs_num *= static_cast<size_t>(rhs_shape_[i]);
   }
   size_t rhs_matrix_num = rhs_num / size_mn;
   for (size_t i = 0; i < rhs_matrix_num; i++) {

@@ -39,7 +39,7 @@ class TridiagonalMatMulCpuKernelMod : public NativeCpuKernelMod {
   std::vector<KernelAttr> GetOpSupport() override;

  private:
-  CNodeWeakPtr node_wpt_;
+  ShapeVector rhs_shape_;
   TypeId dtype_{kTypeUnknown};
   template <typename T>
   void LaunchTridiagonalMatMul(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

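Taken together, the four hunks above move shape discovery out of the launch path: `Resize()` now caches the right-hand-side shape in the new `rhs_shape_` member, so `LaunchTridiagonalMatMul()` no longer has to lock `node_wpt_`, which could already be expired and turned a launch into an exception. A minimal sketch of the pattern, using simplified stand-in names rather than the real MindSpore classes:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for the kernel-mod pattern this commit adopts: shapes are captured
// in Resize(), which the framework calls whenever shapes are (re)inferred, so
// the launch path never reaches back into the graph through a weak pointer.
class TridiagonalMatMulSketch {
 public:
  void Resize(const std::vector<int64_t> &input3_shape) {
    rhs_shape_ = input3_shape;  // mirrors `rhs_shape_ = input3_shape;` above
  }

  // Launch-side helper: reads only the cached copy.
  size_t RhsElementCount() const {
    size_t num = 1;
    for (size_t i = 0; i < rhs_shape_.size(); i++) {
      num *= static_cast<size_t>(rhs_shape_[i]);
    }
    return num;
  }

 private:
  std::vector<int64_t> rhs_shape_;  // replaces the removed CNodeWeakPtr node_wpt_
};
```
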
@@ -135,7 +135,7 @@ void InstanceNormV2CpuKernelMod::CollectLinearAndConstant(const typename TTypes<
     } else {
       mean = running_mean(idx);
       float _std_ = std::sqrt(running_var(idx) + static_cast<float>(epsilon_));
-      MS_EXCEPTION_IF_ZERO("_std_", static_cast<int64_t>(_std_));
+      MS_EXCEPTION_IF_ZERO("_std_", _std_);
       invstd = float_init_one / _std_;
     }
     _alpha_[idx] = invstd * gamma(idx);

@@ -108,7 +108,7 @@ bool InstanceNormV2GradCpuKernelMod::LaunchKernel(const std::vector<kernel::Addr
     float mean = float_init_zero, invstd = float_init_zero;
     mean = is_training_ ? save_mean_matrix(idx, c_idx) : running_mean_matrix(idx, c_idx);
     float _invstd_ = std::sqrt(running_var_matrix(idx, c_idx) + epsilon_);
-    MS_EXCEPTION_IF_ZERO("_invstd_", static_cast<int64_t>(_invstd_));
+    MS_EXCEPTION_IF_ZERO("_invstd_", _invstd_);
     invstd = is_training_ ? save_invstd_matrix(idx, c_idx) : float_init_one / _invstd_;

     double sum = double_init_zero, dotp = double_init_zero;

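Both `MS_EXCEPTION_IF_ZERO` hunks fix the same bug: casting a standard-deviation value to `int64_t` before the zero check truncates every value in (0, 1) to 0, so perfectly healthy variances raised the exception. A self-contained demonstration with illustrative values, not MindSpore code:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

int main() {
  // A small but valid variance: std = sqrt(0.25 + 1e-5) ≈ 0.50001.
  float running_var = 0.25f, epsilon = 1e-5f;
  float std_val = std::sqrt(running_var + epsilon);

  assert(static_cast<int64_t>(std_val) == 0);  // old check saw "zero" here
  assert(std_val != 0.0f);                     // fixed check inspects the float itself
  return 0;
}
```
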
@@ -34,8 +34,6 @@ bool QuantileCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std:
     MS_LOG(ERROR) << "cast Quantile ops failed!";
     return false;
   }
-  input_shape_ = inputs.at(0)->GetShapeVector();
-  q_shape_ = inputs.at(1)->GetShapeVector();
   kernel_name_ = kernel_ptr->name();
   dim_ = GetValue<int64_t>(base_operator->GetAttr("dim"));
   keep_dims_ = GetValue<bool>(base_operator->GetAttr("keep_dims"));

@@ -47,11 +45,13 @@ int QuantileCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std
                                  const std::vector<KernelTensorPtr> &outputs,
                                  const std::map<uint32_t, tensor::TensorPtr> &others) {
   int ret = 0;
-  if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others)) != 0) {
+  if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others)) != KRET_OK) {
     MS_LOG(WARNING) << kernel_name_ << " reinit failed.";
     return ret;
   }
-  return 0;
+  input_shape_ = inputs.at(0)->GetShapeVector();
+  q_shape_ = inputs.at(1)->GetShapeVector();
+  return KRET_OK;
 }

 uint32_t QuantileCpuKernelMod::MaybeWrapDim(int dim, int dim_post_expr) {

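The two Quantile hunks work as a pair: `Init()` runs once, so shapes cached there go stale for dynamic-shape inputs, while `Resize()` runs on every shape change and is the right place to refresh them. The fix also replaces the old bare `return 0;` with the shape refresh plus the named `KRET_OK`. A minimal sketch of the corrected contract, with hypothetical stub types standing in for the real `KernelTensorPtr` API:

```cpp
#include <cstdint>
#include <vector>

// Stub stand-ins for the kernel framework (illustrative only).
enum KernelRet : int { KRET_OK = 0, KRET_RESIZE_FAILED = 1 };

struct TensorStub {
  std::vector<int64_t> shape;
  const std::vector<int64_t> &GetShapeVector() const { return shape; }
};

class QuantileSketch {
 public:
  int Resize(const std::vector<TensorStub> &inputs) {
    // Re-read shapes on every Resize(): dynamic-shape inputs may have changed
    // since Init(), which is exactly the staleness this commit fixes.
    input_shape_ = inputs.at(0).GetShapeVector();
    q_shape_ = inputs.at(1).GetShapeVector();
    return KRET_OK;  // named constant instead of a bare 0
  }

 private:
  std::vector<int64_t> input_shape_;
  std::vector<int64_t> q_shape_;
};
```
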
@@ -144,6 +144,7 @@ class MIND_API AGTridiagonalMatMulInfer : public abstract::OpInferBase {
   }
 };

+MIND_API_OPERATOR_IMPL(TridiagonalMatMul, BaseOperator);
 REGISTER_PRIMITIVE_OP_INFER_IMPL(TridiagonalMatMul, prim::kPrimTridiagonalMatMul, AGTridiagonalMatMulInfer, false);
 }  // namespace ops
 }  // namespace mindspore

@@ -30,7 +30,7 @@
 namespace mindspore {
 namespace ops {
 constexpr auto kNameTridiagonalMatMul = "TridiagonalMatMul";
-class MS_CORE_API TridiagonalMatMul : public BaseOperator {
+class MIND_API TridiagonalMatMul : public BaseOperator {
  public:
   MIND_API_BASE_MEMBER(TridiagonalMatMul);
   TridiagonalMatMul() : BaseOperator(kNameTridiagonalMatMul) {

@@ -601,22 +601,6 @@ def get_bprop_matrix_exp(self):
     return bprop


-@bprop_getters.register(P.MatrixInverse)
-def get_bprop_matrix_inverse(self):
-    """Generate bprop for MatrixInverse"""
-    matmul_x1 = nn.MatMul(transpose_x1=True)
-    matmul_x2 = nn.MatMul(transpose_x2=True)
-    neg = P.Neg()
-
-    def bprop(x, out, dout):
-        dx = matmul_x2(dout, out)
-        dx = matmul_x1(out, dx)
-        dx = neg(dx)
-        return (dx,)
-
-    return bprop
-
-
 @bprop_getters.register(MatrixPower)
 def get_bprop_matrix_power(self):
     """Generate bprop for MatrixPower"""

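For reference, the getter removed above implemented the standard reverse-mode identity for the matrix inverse: with $Y = X^{-1}$ and upstream gradient $G = \partial L / \partial Y$,

$$\frac{\partial L}{\partial X} = -\,Y^{\top} G\, Y^{\top}.$$

That is exactly what the deleted lines compute: `matmul_x2(dout, out)` forms $G Y^{\top}$ (via `transpose_x2=True`), `matmul_x1(out, dx)` left-multiplies by $Y^{\top}$ (via `transpose_x1=True`), and `neg` supplies the sign. The diff alone does not show where, or whether, this bprop is re-registered.
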
@@ -786,7 +770,7 @@ def get_bprop_matrix_solve_ls(self):
         grad_a = add(neg(matmul(matrix, zx_sym)), matmul(rhs, z_temp))
         grad_b = matmul(matrix, z)

-        return (grad_a, grad_b, None)
+        return (grad_a, grad_b, zeros_like(l2))

     def under_determined(matrix, rhs, l2, dout):
         if matrix.dtype == mstype.complex64:

@@ -833,7 +817,7 @@ def get_bprop_matrix_solve_ls(self):
         a2 = matmul(tmp, a2_temp)

         grad_a = add(a1, a2)
-        return (grad_a, grad_b, None)
+        return (grad_a, grad_b, zeros_like(l2))

     if fast is False:
         raise ValueError("For MatrixSolveLs, gradient not defined for fast=False")

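A hedged note on the two `return` fixes in this file: a bprop returns one gradient per forward input, and `MatrixSolveLs` takes three (`matrix`, `rhs`, `l2`, as inner signatures such as `under_determined(matrix, rhs, l2, dout)` show). The fast-path forward solves the regularized normal equations

$$X = (A^{\top} A + \lambda I)^{-1} A^{\top} B$$

(with $A^{\top}$ read as the conjugate transpose for complex inputs), where $\lambda$ is the `l2` regularizer. The implementation treats $\lambda$ as non-differentiable, so replacing `None` with `zeros_like(l2)` fills that slot with an explicit, well-typed zero gradient instead of a hole in the gradient tuple.
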