fixed reduce infer error info

This commit is contained in:
huoxinyou 2022-09-28 14:20:03 +08:00
parent 45ccd77e2d
commit fd55f299b0
12 changed files with 62 additions and 55 deletions

View File

@ -160,6 +160,7 @@ int EmbeddingLookUpCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
return ret;
}
auto kernel_ptr = std::dynamic_pointer_cast<ops::EmbeddingLookup>(base_operator);
if ((inputs.size() != kEmbeddingLookupInputsNum && inputs.size() != kEmbeddingLookupDynamicShapeInputsNum) ||
outputs.size() != 1) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', input and output size must be " << kEmbeddingLookupInputsNum
@ -182,16 +183,7 @@ int EmbeddingLookUpCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
input_indices_lens_ = SizeOf(input_indices_shape);
input_indices_dtype_ = inputs[kIndex1]->GetDtype();
if (inputs.size() == kEmbeddingLookupInputsNum) {
PrimitivePtr prim = base_operator->GetPrim();
MS_EXCEPTION_IF_NULL(prim);
auto value_ptr = prim->GetAttr(kAttrOffset);
if (value_ptr->isa<tensor::Tensor>()) {
auto off_vec = CheckAndConvertUtils::CheckTensorIntValue("offset", value_ptr, kernel_name_);
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', offset must be int, bug got " << off_vec;
offset_ = off_vec[0];
} else {
offset_ = GetValue<int64_t>(value_ptr);
}
offset_ = kernel_ptr->get_offset();
}
return KRET_OK;
}

View File

@ -207,15 +207,9 @@ int ReduceCpuKernelFunc<T>::Resize(const BaseOperatorPtr &base_operator, const s
const std::vector<KernelTensorPtr> &,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
input_shape_ = inputs[0]->GetDeviceShapeAdaptively();
PrimitivePtr prim = base_operator->GetPrim();
MS_EXCEPTION_IF_NULL(prim);
if (prim->HasAttr(kAttrAxis)) {
auto value_ptr = prim->GetAttr(kAttrAxis);
if (value_ptr->isa<tensor::Tensor>()) {
axis_ = CheckAndConvertUtils::CheckTensorIntValue("axis", value_ptr, kernel_name_);
} else {
axis_ = CheckAndConvertUtils::CheckIntOrTupleInt("axis", value_ptr, kernel_name_);
}
auto kernel_ptr = std::dynamic_pointer_cast<ops::Reduce>(base_operator);
if (kernel_ptr->HasAttr(kAttrAxis)) {
axis_ = kernel_ptr->get_axis();
}
(void)GetDynamicAttrIntValue(inputs, kAxisIndex_, inputsOnHost, kernel_name_, &axis_);
HandleInputAxis();

View File

@ -206,18 +206,7 @@ bool EmbeddingLookupGpuKernelMod::Init(const BaseOperatorPtr &base_operator, con
MS_LOG(INFO) << " EmbeddingLookup running in Dynamic Mode.";
} else if (inputs.size() == kEmbeddingLookupInputsNum) {
MS_LOG(INFO) << " EmbeddingLookup running in Normal Mode.";
PrimitivePtr prim = base_operator->GetPrim();
MS_EXCEPTION_IF_NULL(prim);
if (prim->HasAttr("offset")) {
auto value_ptr = prim->GetAttr("offset");
if (value_ptr->isa<tensor::Tensor>()) {
auto off_vec = CheckAndConvertUtils::CheckTensorIntValue("offset", value_ptr, kernel_name_);
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', offset must be int, bug got " << off_vec;
attr_ptr_->offset = off_vec[0];
} else {
attr_ptr_->offset = GetValue<int64_t>(value_ptr);
}
}
attr_ptr_->offset = kernel_ptr->get_offset();
} else {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 2 or 3, but got " << inputs.size();
}

View File

@ -17,6 +17,7 @@
#include "plugin/device/gpu/kernel/arrays/transpose_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#include "utils/check_convert_utils.h"
#include "ops/transpose.h"
namespace mindspore {
namespace kernel {
@ -166,17 +167,9 @@ bool TransposeGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std
return true;
}
auto attr = base_operator->GetPrim()->GetAttr(kAttrPerm);
if (attr == nullptr) {
MS_LOG(ERROR) << "The attr \"perm\" is not found in kernel 'Transpose'.";
return false;
}
auto kernel_ptr = std::dynamic_pointer_cast<ops::Transpose>(base_operator);
std::vector<int64_t> perm;
if (attr->isa<tensor::Tensor>()) {
perm = CheckAndConvertUtils::CheckTensorIntValue("perm", attr, kernel_name_);
} else {
perm = CheckAndConvertUtils::CheckIntOrTupleInt("perm", attr, kernel_name_);
}
perm = kernel_ptr->get_perm();
GetPermValue(perm);
return true;
}

View File

@ -22,6 +22,7 @@
#include "mindapi/src/helper.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "include/common/utils/utils.h"
namespace mindspore {
namespace ops {
@ -39,9 +40,21 @@ bool EmbeddingLookup::get_setattr_flag() const {
void EmbeddingLookup::set_offset(const int64_t offset) { (void)this->AddAttr(kAxis, api::MakeValue(offset)); }
int64_t EmbeddingLookup::get_offset() const {
auto value_ptr = this->GetAttr("offset");
return GetValue<int64_t>(value_ptr);
int64_t EmbeddingLookup::get_offset() {
PrimitivePtr prim = this->GetPrim();
MS_EXCEPTION_IF_NULL(prim);
int64_t offset = 0;
if (prim->HasAttr(kAttrOffset)) {
auto value_ptr = prim->GetAttr(kAttrOffset);
if (value_ptr->isa<tensor::Tensor>()) {
auto off_vec = CheckAndConvertUtils::CheckTensorIntValue(kAttrOffset, value_ptr, prim->name());
MS_LOG(EXCEPTION) << "For '" << prim->name() << "', offset must be int, bug got " << off_vec;
offset = off_vec[0];
} else {
offset = GetValue<int64_t>(value_ptr);
}
}
return offset;
}
class EmbeddingLookupInfer : public abstract::OpInferBase {

View File

@ -49,7 +49,7 @@ class MIND_API EmbeddingLookup : public BaseOperator {
bool get_setattr_flag() const;
///
/// \return offset.
int64_t get_offset() const;
int64_t get_offset();
};
} // namespace ops
} // namespace mindspore

View File

@ -23,6 +23,7 @@
#include "abstract/ops/primitive_infer_map.h"
#include "utils/log_adapter.h"
#include "mindapi/src/helper.h"
#include "include/common/utils/utils.h"
#include "ops/primitive_c.h"
#include "ops/op_utils.h"

View File

@ -182,7 +182,7 @@ bool CheckAndGetAxisValue(const std::vector<abstract::AbstractBasePtr> &input_ar
}
}
} else {
MS_LOG(EXCEPTION) << "For '" << op_name
MS_EXCEPTION(ValueError) << "For '" << op_name
<< "', the second input type should be tensor or scalar, but got invalid abstract type:"
<< input_args[kInputIndex1]->type_name() << ".";
}
@ -199,7 +199,7 @@ abstract::ShapePtr ReduceBaseInferShape(const PrimitivePtr &primitive,
auto keep_dimis_value_ptr = primitive->GetAttr(kKeepDims);
MS_EXCEPTION_IF_NULL(keep_dimis_value_ptr);
if (!keep_dimis_value_ptr->isa<BoolImm>()) {
MS_LOG(EXCEPTION) << "For '" << primitive->name() << "', 'keep_dims' must be Bool.";
MS_EXCEPTION(ValueError) << "For '" << primitive->name() << "', 'keep_dims' must be Bool.";
}
bool keep_dims = GetValue<bool>(keep_dimis_value_ptr);
std::vector<int64_t> axis_value;

View File

@ -20,6 +20,7 @@
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "include/common/utils/utils.h"
namespace mindspore {
namespace ops {
@ -31,12 +32,19 @@ void Reduce::Init(const bool keep_dims) { this->set_keep_dims(keep_dims); }
void Reduce::set_axis(const std::vector<int64_t> &axis) { (void)this->AddAttr(kAxis, api::MakeValue(axis)); }
std::vector<int64_t> Reduce::get_axis() const {
auto value_ptr = this->GetAttr(kAxis);
if (value_ptr->isa<api::Int64Imm>()) {
return {GetValue<int64_t>(value_ptr)};
std::vector<int64_t> Reduce::get_axis() {
PrimitivePtr prim = this->GetPrim();
MS_EXCEPTION_IF_NULL(prim);
std::vector<int64_t> axis = {};
if (prim->HasAttr(kAttrAxis)) {
auto value_ptr = prim->GetAttr(kAttrAxis);
if (value_ptr->isa<tensor::Tensor>()) {
axis = CheckAndConvertUtils::CheckTensorIntValue(kAttrAxis, value_ptr, prim->name());
} else {
axis = CheckAndConvertUtils::CheckIntOrTupleInt(kAttrAxis, value_ptr, prim->name());
}
return GetValue<std::vector<int64_t>>(value_ptr);
}
return axis;
}
MIND_API_OPERATOR_IMPL(Reduce, BaseOperator);

View File

@ -53,7 +53,7 @@ class MIND_API Reduce : public BaseOperator {
void set_axis(const std::vector<int64_t> &axis);
std::vector<int64_t> get_axis() const;
std::vector<int64_t> get_axis();
};
abstract::AbstractBasePtr ReduceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);

View File

@ -25,11 +25,27 @@
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "include/common/utils/utils.h"
namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(Transpose, BaseOperator);
std::vector<int64_t> Transpose::get_perm() {
PrimitivePtr prim = this->GetPrim();
MS_EXCEPTION_IF_NULL(prim);
std::vector<int64_t> perm = {};
if (prim->HasAttr(kAttrPerm)) {
auto value_ptr = prim->GetAttr(kAttrPerm);
if (value_ptr->isa<tensor::Tensor>()) {
perm = CheckAndConvertUtils::CheckTensorIntValue(kAttrPerm, value_ptr, prim->name());
} else {
perm = CheckAndConvertUtils::CheckIntOrTupleInt(kAttrPerm, value_ptr, prim->name());
}
}
return perm;
}
bool CheckAndGetPermValue(const std::vector<AbstractBasePtr> &input_args, ShapeVector *perm_value,
const PrimitivePtr &primitive) {
MS_EXCEPTION_IF_NULL(perm_value);

View File

@ -34,6 +34,7 @@ class MIND_API Transpose : public BaseOperator {
Transpose() : BaseOperator(kNameTranspose) { InitIOName({"x", "perm"}, {"output"}); }
/// \brief Init.
void Init() const {}
std::vector<int64_t> get_perm();
};
} // namespace ops
} // namespace mindspore