forked from mindspore-Ecosystem/mindspore

Fix bugs of quantile, matrixsolvels and instancenormv2

This commit is contained in:
parent 9054d923ea
commit 3a6597c7ff
@@ -60,6 +60,7 @@ int TridiagonalMatMulCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
   auto input1_shape = inputs[kIndex1]->GetShapeVector();
   auto input2_shape = inputs[kIndex2]->GetShapeVector();
   auto input3_shape = inputs[kIndex3]->GetShapeVector();
+  rhs_shape_ = input3_shape;
   if ((input0_shape.size() < is_matrix) || (input1_shape.size() < is_matrix) || (input2_shape.size() < is_matrix) ||
       (input3_shape.size() < is_matrix)) {
     MS_LOG(ERROR) << "For '" << kernel_name_
@@ -140,10 +141,6 @@ bool TridiagonalMatMulCpuKernelMod::Launch(const std::vector<kernel::AddressPtr>
 template <typename T>
 void TridiagonalMatMulCpuKernelMod::LaunchTridiagonalMatMul(const std::vector<AddressPtr> &inputs,
                                                             const std::vector<AddressPtr> &outputs) {
-  auto node_ = node_wpt_.lock();
-  if (!node_) {
-    MS_LOG(EXCEPTION) << "node_wpt_ is expired.";
-  }
   T *superdiag_ptr = reinterpret_cast<T *>(inputs[0]->addr);
   MS_EXCEPTION_IF_NULL(superdiag_ptr);
   T *maindiag_ptr = reinterpret_cast<T *>(inputs[1]->addr);
@@ -154,15 +151,14 @@ void TridiagonalMatMulCpuKernelMod::LaunchTridiagonalMatMul(const std::vector<Ad
   MS_EXCEPTION_IF_NULL(rhs_ptr);
   T *y_ptr = reinterpret_cast<T *>(outputs[0]->addr);
   MS_EXCEPTION_IF_NULL(y_ptr);
-  auto rhs_shape = AnfAlgo::GetInputDeviceShape(node_, kInputShapeIndex3);
-  size_t m = static_cast<size_t>(rhs_shape[rhs_shape.size() - row]);
-  size_t n = static_cast<size_t>(rhs_shape[rhs_shape.size() - col]);
+  size_t m = static_cast<size_t>(rhs_shape_[rhs_shape_.size() - row]);
+  size_t n = static_cast<size_t>(rhs_shape_[rhs_shape_.size() - col]);
   size_t size_mn = m * n;
   using VectorMap = Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, 1>>;
   using MatrixMap = Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
   size_t rhs_num = 1;
-  for (size_t i = 0; i < rhs_shape.size(); i++) {
-    rhs_num *= static_cast<size_t>(rhs_shape[i]);
+  for (size_t i = 0; i < rhs_shape_.size(); i++) {
+    rhs_num *= static_cast<size_t>(rhs_shape_[i]);
   }
   size_t rhs_matrix_num = rhs_num / size_mn;
   for (size_t i = 0; i < rhs_matrix_num; i++) {
@@ -39,7 +39,7 @@ class TridiagonalMatMulCpuKernelMod : public NativeCpuKernelMod {
   std::vector<KernelAttr> GetOpSupport() override;

  private:
-  CNodeWeakPtr node_wpt_;
+  ShapeVector rhs_shape_;
   TypeId dtype_{kTypeUnknown};
   template <typename T>
   void LaunchTridiagonalMatMul(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
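Note on the TridiagonalMatMul hunks above: the kernel used to recover the rhs shape at launch time by locking node_wpt_ and querying the graph node, which raises "node_wpt_ is expired." once that node is gone. The fix captures the shape in a ShapeVector member while Resize runs, so Launch no longer touches graph state. A minimal sketch of the pattern, with hypothetical simplified names (and assuming the kernel's row/col constants select the last two dimensions, as the indexing in the diff suggests):

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    class TridiagonalKernelSketch {
     public:
      // Resize runs whenever shapes are (re-)inferred; cache what Launch needs.
      void Resize(const std::vector<int64_t> &rhs_shape) { rhs_shape_ = rhs_shape; }

      // Launch reads only the cached member; no weak pointer can expire here.
      void Launch() const {
        const size_t row = 2, col = 1;  // assumed offsets from the end
        size_t m = static_cast<size_t>(rhs_shape_[rhs_shape_.size() - row]);
        size_t n = static_cast<size_t>(rhs_shape_[rhs_shape_.size() - col]);
        std::cout << "rhs matrices are " << m << " x " << n << '\n';
      }

     private:
      std::vector<int64_t> rhs_shape_;  // stands in for ShapeVector rhs_shape_
    };

    int main() {
      TridiagonalKernelSketch kernel;
      kernel.Resize({4, 3, 5});  // batch of 4 right-hand sides, each 3 x 5
      kernel.Launch();           // prints: rhs matrices are 3 x 5
      return 0;
    }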
@@ -135,7 +135,7 @@ void InstanceNormV2CpuKernelMod::CollectLinearAndConstant(const typename TTypes<
     } else {
       mean = running_mean(idx);
       float _std_ = std::sqrt(running_var(idx) + static_cast<float>(epsilon_));
-      MS_EXCEPTION_IF_ZERO("_std_", static_cast<int64_t>(_std_));
+      MS_EXCEPTION_IF_ZERO("_std_", _std_);
       invstd = float_init_one / _std_;
     }
     _alpha_[idx] = invstd * gamma(idx);
@@ -108,7 +108,7 @@ bool InstanceNormV2GradCpuKernelMod::LaunchKernel(const std::vector<kernel::Addr
     float mean = float_init_zero, invstd = float_init_zero;
     mean = is_training_ ? save_mean_matrix(idx, c_idx) : running_mean_matrix(idx, c_idx);
     float _invstd_ = std::sqrt(running_var_matrix(idx, c_idx) + epsilon_);
-    MS_EXCEPTION_IF_ZERO("_invstd_", static_cast<int64_t>(_invstd_));
+    MS_EXCEPTION_IF_ZERO("_invstd_", _invstd_);
     invstd = is_training_ ? save_invstd_matrix(idx, c_idx) : float_init_one / _invstd_;

     double sum = double_init_zero, dotp = double_init_zero;
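Both InstanceNormV2 hunks fix the same truncation bug: MS_EXCEPTION_IF_ZERO was handed static_cast<int64_t>(value), and any standard deviation in (0, 1) truncates to 0, so perfectly valid statistics triggered a spurious divide-by-zero exception. Passing the float itself means the guard fires only when the value really is zero. A standalone repro of the truncation (plain C++, nothing MindSpore-specific):

    #include <cmath>
    #include <cstdint>
    #include <iostream>

    int main() {
      float running_var = 0.09f;
      float epsilon = 1e-5f;
      float std_val = std::sqrt(running_var + epsilon);  // ~0.3, clearly non-zero
      // The old check effectively tested this truncated value:
      std::cout << static_cast<int64_t>(std_val) << '\n';            // prints 0
      // The fixed check tests the float directly:
      std::cout << (std_val == 0.0f ? "zero" : "non-zero") << '\n';  // non-zero
      return 0;
    }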
@@ -34,8 +34,6 @@ bool QuantileCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std:
     MS_LOG(ERROR) << "cast Quantile ops failed!";
     return false;
   }
-  input_shape_ = inputs.at(0)->GetShapeVector();
-  q_shape_ = inputs.at(1)->GetShapeVector();
   kernel_name_ = kernel_ptr->name();
   dim_ = GetValue<int64_t>(base_operator->GetAttr("dim"));
   keep_dims_ = GetValue<bool>(base_operator->GetAttr("keep_dims"));
@@ -47,11 +45,13 @@ int QuantileCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std
                                  const std::vector<KernelTensorPtr> &outputs,
                                  const std::map<uint32_t, tensor::TensorPtr> &others) {
   int ret = 0;
-  if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others)) != 0) {
+  if ((ret = NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others)) != KRET_OK) {
     MS_LOG(WARNING) << kernel_name_ << " reinit failed.";
     return ret;
   }
-  return 0;
+  input_shape_ = inputs.at(0)->GetShapeVector();
+  q_shape_ = inputs.at(1)->GetShapeVector();
+  return KRET_OK;
 }

 uint32_t QuantileCpuKernelMod::MaybeWrapDim(int dim, int dim_post_expr) {
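The Quantile hunks move the GetShapeVector() calls from Init into Resize and compare the base-class result against KRET_OK rather than a bare 0. Init typically runs once while Resize runs again whenever shapes are re-inferred, so shapes captured in Init go stale under dynamic shapes; the old Resize also returned 0 without ever refreshing them. A hedged sketch of the corrected lifecycle (hypothetical names; the enum below merely stands in for the framework's real KRET_OK):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    enum KernelRet { KRET_OK = 0, KRET_RESIZE_FAILED = 1 };  // stand-in values

    struct QuantileKernelSketch {
      std::vector<int64_t> input_shape_;
      std::vector<int64_t> q_shape_;

      // Init: attributes only; nothing shape-dependent is cached here.
      bool Init() { return true; }

      // Resize: runs before each launch with the current shapes, so the
      // cached copies are refreshed every time the graph re-infers them.
      int Resize(const std::vector<int64_t> &input_shape, const std::vector<int64_t> &q_shape) {
        input_shape_ = input_shape;
        q_shape_ = q_shape;
        return KRET_OK;
      }
    };

    int main() {
      QuantileKernelSketch kernel;
      assert(kernel.Init());
      assert(kernel.Resize({8, 4}, {3}) == KRET_OK);  // first shapes
      assert(kernel.Resize({2, 6}, {3}) == KRET_OK);  // dynamic reshape: cache stays fresh
      assert(kernel.input_shape_.front() == 2);
      return 0;
    }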
@@ -144,6 +144,7 @@ class MIND_API AGTridiagonalMatMulInfer : public abstract::OpInferBase {
   }
 };

+MIND_API_OPERATOR_IMPL(TridiagonalMatMul, BaseOperator);
 REGISTER_PRIMITIVE_OP_INFER_IMPL(TridiagonalMatMul, prim::kPrimTridiagonalMatMul, AGTridiagonalMatMulInfer, false);
 }  // namespace ops
 }  // namespace mindspore
@@ -30,7 +30,7 @@
 namespace mindspore {
 namespace ops {
 constexpr auto kNameTridiagonalMatMul = "TridiagonalMatMul";
-class MS_CORE_API TridiagonalMatMul : public BaseOperator {
+class MIND_API TridiagonalMatMul : public BaseOperator {
  public:
   MIND_API_BASE_MEMBER(TridiagonalMatMul);
   TridiagonalMatMul() : BaseOperator(kNameTridiagonalMatMul) {
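The header hunk swaps the class export macro from MS_CORE_API to MIND_API, which lines up with the MIND_API_OPERATOR_IMPL(TridiagonalMatMul, BaseOperator) registration added in the previous hunk, so the operator is declared and implemented through the same API layer. Purely as a generic illustration (this is not MindSpore's actual macro definition), export macros of this kind usually expand to a symbol-visibility attribute:

    // Hypothetical stand-in for an export macro such as MIND_API.
    #if defined(_WIN32)
    #define MY_API __declspec(dllexport)
    #else
    #define MY_API __attribute__((visibility("default")))
    #endif

    class MY_API ExportedOperator {};  // symbol visible across shared-library boundaries

    int main() {
      ExportedOperator op;
      (void)op;
      return 0;
    }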
@@ -601,22 +601,6 @@ def get_bprop_matrix_exp(self):
     return bprop


-@bprop_getters.register(P.MatrixInverse)
-def get_bprop_matrix_inverse(self):
-    """Generate bprop for MatrixInverse"""
-    matmul_x1 = nn.MatMul(transpose_x1=True)
-    matmul_x2 = nn.MatMul(transpose_x2=True)
-    neg = P.Neg()
-
-    def bprop(x, out, dout):
-        dx = matmul_x2(dout, out)
-        dx = matmul_x1(out, dx)
-        dx = neg(dx)
-        return (dx,)
-
-    return bprop
-
-
 @bprop_getters.register(MatrixPower)
 def get_bprop_matrix_power(self):
     """Generate bprop for MatrixPower"""
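For reference, the deleted get_bprop_matrix_inverse implemented the standard gradient of the matrix inverse. With Y = A^{-1} and upstream gradient \bar{Y}, differentiating A A^{-1} = I gives

    dY = -A^{-1}\, dA\, A^{-1} \quad\Longrightarrow\quad \bar{A} = -Y^{\top}\,\bar{Y}\,Y^{\top} = -A^{-\top}\,\bar{Y}\,A^{-\top},

which is exactly the -out^T (dout out^T) the removed bprop computed; no replacement registration is visible in this diff.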
@@ -786,7 +770,7 @@ def get_bprop_matrix_solve_ls(self):
         grad_a = add(neg(matmul(matrix, zx_sym)), matmul(rhs, z_temp))
         grad_b = matmul(matrix, z)

-        return (grad_a, grad_b, None)
+        return (grad_a, grad_b, zeros_like(l2))

     def under_determined(matrix, rhs, l2, dout):
         if matrix.dtype == mstype.complex64:
@@ -833,7 +817,7 @@ def get_bprop_matrix_solve_ls(self):
         a2 = matmul(tmp, a2_temp)

         grad_a = add(a1, a2)
-        return (grad_a, grad_b, None)
+        return (grad_a, grad_b, zeros_like(l2))

     if fast is False:
         raise ValueError("For MatrixSolveLs, gradient not defined for fast=False")
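Both MatrixSolveLs hunks make the same change: the bprop now returns zeros_like(l2) rather than None as the gradient for the l2 regularizer input. MindSpore bprop functions are expected to return one gradient per forward input, and a correctly shaped zero tensor is the conventional way to mark an input as non-differentiable; returning None there presumably breaks once the bprop is compiled, which is the bug this part of the commit addresses.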