forked from mindspore-Ecosystem/mindspore

add a process of make type and shape to be abstract

parent 109770b743
commit 68f3b5784f
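In substance, the commit adds shared helpers in abstract/utils (MakeAbstract, MakeAbstractTensor, MakeMonadAbstract) that build an abstract value from an inferred shape plus type, and the per-op infer functions in the diff below are rewritten to delegate to them instead of assembling AbstractTensor/AbstractTuple by hand. A minimal sketch of the new calling pattern follows; ExampleInfer is a hypothetical op used only for illustration, and InferShape/InferType stand in for the per-op helpers each ops/*.cc defines in its anonymous namespace:

// Sketch only, not part of the commit: a made-up ExampleInfer showing the pattern
// the diff applies to Add, AddN, Conv2D, ReLU, and the other ops below.
#include "abstract/abstract_value.h"
#include "abstract/utils.h"

namespace mindspore {
namespace ops {
AbstractBasePtr ExampleInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args) {
  // Old pattern (removed in this commit): build the AbstractTensor by hand.
  //   return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
  //                                                     InferShape(primitive, input_args)->shape());
  // New pattern: MakeAbstract combines the BaseShapePtr and TypePtr and also covers
  // tuple, list, none, and monad outputs uniformly.
  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
}  // namespace ops
}  // namespace mindspore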
@@ -505,12 +505,33 @@ py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrLi
   return py_args;
 }

+void CheckCustomPrimOutputInferResult(const PrimitivePtr &prim, const AbstractBasePtr &res_spec) {
+  const string kOutputNum = "output_num";
+  if (prim->IsCustomPrim()) {
+    // Raise error if output_num is not match the infer result.
+    auto output_num_value = prim->GetAttr(kOutputNum);
+    if (output_num_value == nullptr) {
+      MS_LOG(DEBUG) << "The output num may no need to check";
+      return;
+    }
+    int64_t output_num = GetValue<int64_t>(output_num_value);
+    if (res_spec->isa<AbstractTensor>() && output_num != 1) {
+      MS_LOG(EXCEPTION) << "Custom primitive " << prim->ToString() << " output_num " << output_num
+                        << " not matches the infer result.";
+    } else if (res_spec->isa<AbstractTuple>() &&
+               (res_spec->cast<AbstractTuplePtr>()->size() != LongToSize(output_num))) {
+      MS_LOG(EXCEPTION) << "Custom primitive " << prim->ToString() << " output_num " << output_num
+                        << " not matches the infer result.";
+    }
+  }
+}
+
 AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dict &output) {
   // Convert to AbstractValue based on type and shape
   auto out_dtype = output[ATTR_DTYPE];
   if (output[ATTR_VALUE].is_none()) {
     auto out_shape = output[ATTR_SHAPE];
-    return PyListDtype2AbstractTensor(out_shape, out_dtype, output);
+    return MakePyInferRes2Abstract(out_shape, out_dtype, output);
   }
   // Convert pyobject to Value, then to AbstractValue
   ValuePtr converted_ret = nullptr;
@@ -527,18 +548,7 @@ AbstractBasePtr PyInferRes2Abstract(const PrimitivePyPtr &prim_py, const py::dic
     res_tensor->set_value(converted_ret);
     SetValueRange(res_tensor, output);
   }
-  if (prim_py->IsCustomPrim()) {
-    // Raise error if output_num is not match the infer result.
-    int64_t output_num = GetValue<int64_t>(prim_py->GetAttr("output_num"));
-    if (res_spec->isa<AbstractTensor>() && output_num != 1) {
-      MS_LOG(EXCEPTION) << "Custom primitive " << prim_py->ToString() << " output_num " << output_num
-                        << " not matches the infer result.";
-    } else if (res_spec->isa<AbstractTuple>() &&
-               (res_spec->cast<AbstractTuplePtr>()->size() != LongToSize(output_num))) {
-      MS_LOG(EXCEPTION) << "Custom primitive " << prim_py->ToString() << " output_num " << output_num
-                        << " not matches the infer result.";
-    }
-  }
+  CheckCustomPrimOutputInferResult(prim_py, res_spec);
   return res_spec;
 }
 }  // end anonymous namespace
@@ -25,6 +25,7 @@
 #include <cfloat>

 #include "abstract/abstract_value.h"
+#include "abstract/utils.h"
 #include "pipeline/jit/parse/parse.h"
 #include "pipeline/jit/parse/parse_base.h"
 #include "ir/value.h"
@@ -357,9 +358,8 @@ void SetValueRange(const AbstractBasePtr &tensor, const py::object &output) {
   }
 }

-AbstractBasePtr PyList2DynamicShapeTensor(const py::object &shape_obj, const py::object &type_obj,
+AbstractBasePtr MakePyInferRes2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
                                           const py::object &output) {
-  AbstractBasePtr tensor = nullptr;
   auto ret_vec = shape_obj.cast<ShapeVector>();
   auto ret_dtype = type_obj.cast<TypePtr>();
   ShapeVector min_shape_vec;
@@ -379,15 +379,7 @@ AbstractBasePtr PyList2DynamicShapeTensor(const py::object &shape_obj, const py:
   }

   auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
-  if (ret_dtype->isa<TensorType>()) {
-    auto tensor_type = type_obj.cast<TensorTypePtr>();
-    MS_EXCEPTION_IF_NULL(tensor_type);
-    auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, tensor_type->element());
-    tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
-  } else {
-    auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
-    tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
-  }
+  AbstractBasePtr tensor = MakeAbstractTensor(ret_shape, ret_dtype);

   SetValueRange(tensor, output);
   return tensor;
@@ -404,17 +396,15 @@ static bool IsMonadType(const py::object &type_obj) {
 static AbstractBasePtr ToMonadAbstract(const py::object &type_obj) {
   if (py::isinstance<Type>(type_obj)) {
     auto type = type_obj.cast<Type *>();
-    if (type->isa<UMonadType>()) {
-      return kUMonad->ToAbstract();
-    }
-    if (type->isa<IOMonadType>()) {
-      return kIOMonad->ToAbstract();
-    }
-  }
+    if (!type->isa<MonadType>()) {
+      MS_LOG(EXCEPTION) << "Not a monad type object: " << py::str(type_obj);
+    }
+    return abstract::MakeMonadAbstract(type->cast<MonadTypePtr>());
+  }
   MS_LOG(EXCEPTION) << "Not a type object: " << py::str(type_obj);
 }

-AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
+AbstractBasePtr MakePyInferRes2Abstract(const py::object &shape_obj, const py::object &type_obj,
                                            const py::object &output) {
   if ((py::isinstance<py::list>(shape_obj) || py::isinstance<py::tuple>(shape_obj)) && py::isinstance<Type>(type_obj)) {
     auto ret_vec = shape_obj.cast<ShapeVector>();
@@ -425,13 +415,13 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
       abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, ret_dtype);
       return abs_scalar;
     }
-    return PyList2DynamicShapeTensor(shape_obj, type_obj, output);
+    return MakePyInferRes2AbstractTensor(shape_obj, type_obj, output);
   } else if (py::isinstance<py::tuple>(shape_obj) && py::isinstance<py::tuple>(type_obj)) {
     auto shape_tuple = shape_obj.cast<py::tuple>();
     auto typeid_tuple = type_obj.cast<py::tuple>();
     AbstractBasePtrList ptr_list;
     for (size_t it = 0; it < shape_tuple.size(); ++it) {
-      auto tensor_it = PyListDtype2AbstractTensor(shape_tuple[it], typeid_tuple[it]);
+      auto tensor_it = MakePyInferRes2Abstract(shape_tuple[it], typeid_tuple[it]);
       ptr_list.push_back(tensor_it);
     }
     auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
@@ -441,7 +431,7 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
     auto typeid_list = type_obj.cast<py::list>();
     AbstractBasePtrList ptr_list;
     for (size_t it = 0; it < shape_list.size(); ++it) {
-      auto tensor_it = PyListDtype2AbstractTensor(shape_list[it], typeid_list[it]);
+      auto tensor_it = MakePyInferRes2Abstract(shape_list[it], typeid_list[it]);
       ptr_list.push_back(tensor_it);
     }
     auto list = std::make_shared<abstract::AbstractList>(ptr_list);
@@ -36,7 +36,7 @@ py::object ValuePtrToPyData(const ValuePtr &value);
 bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple &args,
                                        const std::shared_ptr<py::object> &ret_val);

-AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj,
+AbstractBasePtr MakePyInferRes2Abstract(const py::object &shape_obj, const py::object &type_obj,
                                         const py::object &output = py::none());
 void SetValueRange(const AbstractBasePtr &tensor, const py::object &output);
 }  // namespace mindspore
@@ -21,6 +21,7 @@
 #include <string>
 #include <sstream>
 #include <memory>
+#include "utils/ms_context.h"
 #include "utils/symbolic.h"
 #include "abstract/param_validator.h"

@@ -348,5 +349,96 @@ int64_t GetUnsortedSegmentOpScalarArg(const AbstractBasePtrList &args_spec_list,
   }
   return num_segments_value;
 }
+
+AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type) {
+  MS_EXCEPTION_IF_NULL(shape);
+  MS_EXCEPTION_IF_NULL(type);
+  AbstractBasePtr tensor = nullptr;
+  auto ret_vec = shape->shape();
+  ShapeVector min_shape_vec;
+  ShapeVector max_shape_vec;
+
+  if (shape->IsDynamic()) {
+    if (!shape->min_shape().empty()) {
+      min_shape_vec = shape->min_shape();
+    }
+    if (!shape->max_shape().empty()) {
+      max_shape_vec = shape->max_shape();
+    }
+  }
+
+  auto ret_shape = std::make_shared<abstract::Shape>(ret_vec, min_shape_vec, max_shape_vec);
+  if (type->isa<TensorType>()) {
+    auto tensor_type = type->cast<TensorTypePtr>();
+    MS_EXCEPTION_IF_NULL(tensor_type);
+    auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, tensor_type->element());
+    tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
+  } else {
+    auto element = std::make_shared<abstract::AbstractScalar>(kAnyValue, type);
+    tensor = std::make_shared<abstract::AbstractTensor>(element, ret_shape);
+  }
+  return tensor;
+}
+
+AbstractBasePtr MakeMonadAbstract(const MonadTypePtr &type) {
+  if (type->isa<UMonadType>()) {
+    return kUMonad->ToAbstract();
+  } else if (type->isa<IOMonadType>()) {
+    return kIOMonad->ToAbstract();
+  }
+  MS_EXCEPTION(UnknownError) << "Unsupported to convert type " << type->ToString() << " to monad abstract";
+}
+
+AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type) {
+  MS_EXCEPTION_IF_NULL(base_shape);
+  MS_EXCEPTION_IF_NULL(type);
+  if ((base_shape->isa<Shape>())) {
+    auto shape = base_shape->cast<ShapePtr>();
+    MS_EXCEPTION_IF_NULL(shape);
+    auto shape_vec = shape->shape();
+    // if the size of shape list is empty, return an scalar abstract
+    if (shape_vec.empty() && (!type->isa<TensorType>())) {
+      abstract::AbstractScalarPtr abs_scalar = std::make_shared<abstract::AbstractScalar>(kAnyValue, type);
+      return abs_scalar;
+    }
+    return MakeAbstractTensor(shape, type);
+  } else if (base_shape->isa<TupleShape>() && type->isa<Tuple>()) {
+    auto shape_tuple = base_shape->cast<TupleShapePtr>();
+    auto type_tuple = type->cast<TuplePtr>();
+    AbstractBasePtrList ptr_list;
+    for (size_t it = 0; it < shape_tuple->size(); ++it) {
+      auto tensor_it = MakeAbstract((*shape_tuple)[it], (*type_tuple)[it]);
+      ptr_list.push_back(tensor_it);
+    }
+    auto tuple = std::make_shared<abstract::AbstractTuple>(ptr_list);
+    return tuple;
+  } else if (base_shape->isa<ListShape>() && type->isa<List>()) {
+    auto shape_list = base_shape->cast<ListShapePtr>();
+    auto type_list = type->cast<ListPtr>();
+    AbstractBasePtrList ptr_list;
+    for (size_t it = 0; it < shape_list->size(); ++it) {
+      auto tensor_it = MakeAbstract((*shape_list)[it], (*type_list)[it]);
+      ptr_list.push_back(tensor_it);
+    }
+    auto list = std::make_shared<abstract::AbstractList>(ptr_list);
+    return list;
+  } else if (base_shape->isa<NoShape>() && type->isa<TypeNone>()) {
+    // AbstractNone indicates there is no output for this CNode node.
+    auto abstract_none = std::make_shared<abstract::AbstractNone>();
+    return abstract_none;
+  } else if (type->isa<Monad>()) {
+    // Return monad abstract if it is monad type.
+    return MakeMonadAbstract(type->cast<MonadTypePtr>());
+  } else {
+    // When sparse enabled, the undetermined might be raised and eliminated in opt passes
+    auto context = MsContext::GetInstance();
+    MS_EXCEPTION_IF_NULL(context);
+    bool enable_sparse = context->get_param<bool>(MS_CTX_ENABLE_SPARSE);
+    if (enable_sparse) {
+      return std::make_shared<abstract::AbstractUndetermined>();
+    }
+    MS_LOG(EXCEPTION) << "evaluator return invalid shape " << base_shape->ToString() << "or type. " << type->ToString();
+  }
+}
 }  // namespace abstract
 }  // namespace mindspore
@@ -62,7 +62,9 @@ void CheckMinMaxShape(const ShapeVector &shape, ShapeVector *min_shape, ShapeVec

 // Get 3rd argument for UnsortedSegmentOps' inferImpl function
 int64_t GetUnsortedSegmentOpScalarArg(const AbstractBasePtrList &args_spec_list, const std::string &op_name);
-
+AbstractBasePtr MakeAbstract(const BaseShapePtr &base_shape, const TypePtr &type);
+AbstractBasePtr MakeMonadAbstract(const MonadTypePtr &type);
+AbstractBasePtr MakeAbstractTensor(const ShapePtr &shape, const TypePtr &type);
 }  // namespace abstract
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_ABSTRACT_UTILS_H_
@@ -50,8 +50,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &

 AbstractBasePtr AddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args)->shape());
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_C(kNameAdd, Add);
 }  // namespace ops
@@ -72,7 +72,8 @@ TypePtr AddNInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt
   }
   std::set<TypePtr> valid_types = common_valid_types;
   valid_types.insert(kBool);
-  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
+  CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
+  return elements[0]->BuildType();
 }
 }  // namespace
 AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -83,11 +84,7 @@ AbstractBasePtr AddNInfer(const abstract::AnalysisEnginePtr &, const PrimitivePt
   for (const auto &item : input_args) {
     MS_EXCEPTION_IF_NULL(item);
   }
-  auto abs_type = AddNInferType(primitive, input_args);
-  if (abs_type->type_id() == kObjectTypeUndeterminedType) {
-    return std::make_shared<abstract::AbstractUndetermined>();
-  }
-  return std::make_shared<abstract::AbstractTensor>(abs_type, AddNInferShape(primitive, input_args));
+  return abstract::MakeAbstract(AddNInferShape(primitive, input_args), AddNInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(AddN, prim::kPrimAddN, AddNInfer, nullptr, true);
 }  // namespace ops
@@ -172,8 +172,7 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP

 AbstractBasePtr AvgPool3DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args)->shape());
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }

 REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool3D, prim::kPrimAvgPool3D, AvgPool3DInfer, nullptr, true);
@@ -145,9 +145,8 @@ AbstractBasePtr BatchMatmulInfer(const abstract::AnalysisEnginePtr &, const Prim
                                  const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   CheckAndConvertUtils::CheckInteger("BatchMatmul infer", input_args.size(), kGreaterEqual, 2, primitive->name());
-  auto abs = std::make_shared<abstract::AbstractTensor>(BatchMatmulInferType(primitive, input_args),
-                                                        BatchMatmulInferShape(primitive, input_args));
-  return abs;
+  return abstract::MakeAbstract(BatchMatmulInferShape(primitive, input_args),
+                                BatchMatmulInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(BatchMatmul, prim::kPrimBatchMatMul, BatchMatmulInfer, nullptr, true);
 }  // namespace ops
@@ -84,8 +84,8 @@ std::vector<int64_t> BroadcastTo::get_shape() const {
 }
 AbstractBasePtr BroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(BroadcastToInferType(primitive, input_args),
-                                                    BroadcastToInferShape(primitive, input_args));
+  return abstract::MakeAbstract(BroadcastToInferShape(primitive, input_args),
+                                BroadcastToInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(BroadcastTo, prim::kPrimBroadcastTo, BroadcastToInfer, nullptr, true);
 }  // namespace ops
@@ -313,8 +313,7 @@ Format Conv2D::get_format() const {
 AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args) {
   CheckAndConvertUtils::CheckInteger("Conv2d infer", input_args.size(), kGreaterEqual, 2, primitive->name());
-  return std::make_shared<abstract::AbstractTensor>(Conv2dInferType(primitive, input_args),
-                                                    Conv2dInferShape(primitive, input_args));
+  return abstract::MakeAbstract(Conv2dInferShape(primitive, input_args), Conv2dInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer, nullptr, true);
 }  // namespace ops
@@ -56,8 +56,7 @@ void CheckCTCLossInputs(const std::vector<AbstractBasePtr> &input_args, const st
   }
 }

-std::vector<abstract::ShapePtr> InferShape(const PrimitivePtr &primitive,
-                                           const std::vector<AbstractBasePtr> &input_args) {
+abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   auto op_name = primitive->name();
   CheckCTCLossInputs(input_args, op_name);
   auto input_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
@@ -71,36 +70,33 @@ std::vector<abstract::ShapePtr> InferShape(const PrimitivePtr &primitive,
   if (min_shape.empty() || max_shape.empty()) {
     loss_shape = std::make_shared<abstract::Shape>(batch);
     gradient_shape = std::make_shared<abstract::Shape>(shape);
-    return std::vector<abstract::ShapePtr>{loss_shape, gradient_shape};
+    return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{loss_shape, gradient_shape});
   }

   ShapeVector batch_min = {min_shape[1]};
   ShapeVector batch_max = {max_shape[1]};
   loss_shape = std::make_shared<abstract::Shape>(batch, batch_min, batch_max);
   gradient_shape = std::make_shared<abstract::Shape>(shape, min_shape, max_shape);
-  return std::vector<abstract::ShapePtr>{loss_shape, gradient_shape};
+  return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{loss_shape, gradient_shape});
 }

-TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
+TuplePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   auto op_name = primitive->name();
   CheckAndConvertUtils::CheckTensorTypeValid("labels_indices", input_args[1]->BuildType(), {kInt64}, op_name);
   CheckAndConvertUtils::CheckTensorTypeValid("labels_values", input_args[2]->BuildType(), {kInt32}, op_name);
   CheckAndConvertUtils::CheckTensorTypeValid("sequence_length", input_args[3]->BuildType(), {kInt32}, op_name);
   const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
-  return CheckAndConvertUtils::CheckTensorTypeValid("inputs", input_args[0]->BuildType(), valid_types, op_name);
+  auto type = CheckAndConvertUtils::CheckTensorTypeValid("inputs", input_args[0]->BuildType(), valid_types, op_name);
+  return std::make_shared<Tuple>(std::vector<TypePtr>{type, type});
 }
-
 }  // namespace

 AbstractBasePtr CTCLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto type = InferType(primitive, input_args);
+  auto types = InferType(primitive, input_args);
   auto shapes = InferShape(primitive, input_args);
-  auto loss = std::make_shared<abstract::AbstractTensor>(type, shapes[0]);
-  auto gradient = std::make_shared<abstract::AbstractTensor>(type, shapes[1]);
-  AbstractBasePtrList outputs = {loss, gradient};
-  return std::make_shared<abstract::AbstractTuple>(outputs);
+  return abstract::MakeAbstract(shapes, types);
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(CTCLoss, prim::kPrimCTCLoss, CTCLossInfer, nullptr, true);
 }  // namespace ops
@@ -116,8 +116,7 @@ AbstractBasePtr DropoutDoMaskInfer(const abstract::AnalysisEnginePtr &, const Pr
                                    const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   CheckAndConvertUtils::CheckInteger("infer shape", input_args.size(), kGreaterEqual, 3, primitive->name());
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args));
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(DropoutDoMask, prim::kPrimDropoutDoMask, DropoutDoMaskInfer, nullptr, true);
 }  // namespace ops
@@ -147,8 +147,7 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
 AbstractBasePtr DropoutGenMaskInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                     const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args));
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(DropoutGenMask, prim::kPrimDropoutGenMask, DropoutGenMaskInfer, nullptr, true);
 }  // namespace ops
@@ -44,8 +44,7 @@ AbstractBasePtr DTypeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
   MS_EXCEPTION_IF_NULL(value);
   auto type = value->cast<TypePtr>();
   MS_EXCEPTION_IF_NULL(type);
-  auto abstract = std::make_shared<abstract::AbstractType>(type);
-  return abstract;
+  return abstract::MakeAbstract(std::make_shared<abstract::NoShape>(), type);
 }

 REGISTER_PRIMITIVE_EVAL_IMPL(DType, prim::kPrimDType, DTypeInfer, DTypeInferValue, false);
@@ -84,10 +84,10 @@ AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const Primitive
   if (ind_dyn || param_dyn) {
     ShapeVector min_shape = calc_shape(indices_shp_min, param_shp_min);
     ShapeVector max_shape = calc_shape(indices_shp_max, param_shp_max);
-    return std::make_shared<abstract::AbstractTensor>(
-      params->element(), std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape));
+    return abstract::MakeAbstract(std::make_shared<abstract::Shape>(out_shape, min_shape, max_shape),
+                                  params->BuildType());
   }
-  return std::make_shared<abstract::AbstractTensor>(params->element(), std::make_shared<abstract::Shape>(out_shape));
+  return abstract::MakeAbstract(std::make_shared<abstract::Shape>(out_shape), params->BuildType());
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(Gather, prim::kPrimGather, GatherInfer, nullptr, true);
 }  // namespace ops
@@ -68,9 +68,7 @@ AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const Primitiv
   std::set<TypePtr> valid_types = {kInt32, kInt64};
   CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_types, prim_name);
   CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_types, prim_name);
-  auto abs = std::make_shared<abstract::AbstractTensor>(GatherDInferType(primitive, input_args),
-                                                        GatherDInferShape(primitive, input_args));
-  return abs;
+  return abstract::MakeAbstract(GatherDInferShape(primitive, input_args), GatherDInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, true);
 }  // namespace ops
@@ -54,8 +54,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
 }  // namespace
 AbstractBasePtr GeLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                           const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args)->shape());
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(GeLU, prim::kPrimGeLU, GeLUInfer, nullptr, true);

@@ -72,8 +72,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
 }  // namespace
 AbstractBasePtr BiasAddGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args)->shape());
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(BiasAddGrad, prim::kPrimBiasAddGrad, BiasAddGradInfer, nullptr, true);

@@ -30,11 +30,12 @@ AbstractBasePtr LayerNormGradInfer(const abstract::AnalysisEnginePtr &, const Pr
   auto x_backprob = input_args[0]->Broaden();
   auto gamma_backprob = input_args[4]->Broaden();
   auto beta_backprob = input_args[4]->Broaden();
-
-  AbstractBasePtrList args_list({x_backprob, gamma_backprob, beta_backprob});
-  return std::make_shared<abstract::AbstractTuple>(args_list);
+  auto shapes = std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{
+    x_backprob->BuildShape(), gamma_backprob->BuildShape(), beta_backprob->BuildShape()});
+  auto types = std::make_shared<Tuple>(
+    std::vector<TypePtr>{x_backprob->BuildType(), gamma_backprob->BuildType(), beta_backprob->BuildType()});
+  return abstract::MakeAbstract(shapes, types);
 }
-
 void LayerNormGrad::Init(const int64_t begin_norm_axis, const int64_t begin_params_axis) {
   this->set_begin_norm_axis(begin_norm_axis);
   this->set_begin_params_axis(begin_params_axis);
@@ -62,8 +62,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
 }  // namespace
 AbstractBasePtr ReLUGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args));
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(ReLUGrad, prim::kPrimReluGrad, ReLUGradInfer, nullptr, true);

@@ -55,8 +55,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
 }  // namespace
 AbstractBasePtr ReLUGradV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args));
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(ReLUGradV2, prim::kPrimReluGradV2, ReLUGradV2Infer, nullptr, true);

@@ -93,23 +93,22 @@ AbstractBasePtr LayerNormInfer(const abstract::AnalysisEnginePtr &, const Primit
     }
   }

-  auto mean = input_x->Broaden();
-  auto var = input_x->Broaden();
-
+  std::vector<BaseShapePtr> shapes_list = {input_x->BuildShape()};
+  std::vector<TypePtr> types_list = {input_x->BuildType(), input_x->BuildType(), input_x->BuildType()};
   auto mean_var_shape = CalLayerNormMeanAndVarShape(begin_norm_axis, input_shape->shape());
   auto input_min_shape = input_shape->min_shape();
   auto input_max_shape = input_shape->max_shape();
   if (input_min_shape.empty() || input_max_shape.empty()) {
-    mean->set_shape(std::make_shared<abstract::Shape>(mean_var_shape));
-    var->set_shape(std::make_shared<abstract::Shape>(mean_var_shape));
+    shapes_list.emplace_back(std::make_shared<abstract::Shape>(mean_var_shape));
+    shapes_list.emplace_back(std::make_shared<abstract::Shape>(mean_var_shape));
   } else {
     auto mean_var_shape_min = CalLayerNormMeanAndVarShape(begin_norm_axis, input_min_shape);
     auto mean_var_shape_max = CalLayerNormMeanAndVarShape(begin_norm_axis, input_min_shape);
-    mean->set_shape(std::make_shared<abstract::Shape>(mean_var_shape, mean_var_shape_min, mean_var_shape_max));
-    var->set_shape(std::make_shared<abstract::Shape>(mean_var_shape, mean_var_shape_min, mean_var_shape_max));
+    shapes_list.emplace_back(std::make_shared<abstract::Shape>(mean_var_shape, mean_var_shape_min, mean_var_shape_max));
+    shapes_list.emplace_back(std::make_shared<abstract::Shape>(mean_var_shape, mean_var_shape_min, mean_var_shape_max));
   }
-  AbstractBasePtrList args_list({input_x->Broaden(), mean, var});
-  return std::make_shared<abstract::AbstractTuple>(args_list);
+  return abstract::MakeAbstract(std::make_shared<abstract::TupleShape>(shapes_list),
+                                std::make_shared<Tuple>(types_list));
 }

 void LayerNorm::Init(const int64_t begin_norm_axis, const int64_t begin_params_axis, const float epsilon) {
@@ -40,8 +40,8 @@ TypePtr Log1pInferType(const PrimitivePtr &prim, const std::vector<AbstractBaseP
   auto prim_name = prim->name();
   // check
   std::set<TypePtr> valid_index_types = {kFloat16, kFloat32};
-  auto x_type =
-    CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_index_types, prim_name);
+  auto x_type = input_args[0]->BuildType();
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("x", input_args[0]->BuildType(), valid_index_types, prim_name);
   return x_type;
 }
 }  // namespace
@@ -49,9 +49,7 @@ TypePtr Log1pInferType(const PrimitivePtr &prim, const std::vector<AbstractBaseP
 AbstractBasePtr Log1pInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                            const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto abs = std::make_shared<abstract::AbstractTensor>(Log1pInferType(primitive, input_args),
-                                                        Log1pInferShape(primitive, input_args));
-  return abs;
+  return abstract::MakeAbstract(Log1pInferShape(primitive, input_args), Log1pInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(Log1p, prim::kPrimLog1p, Log1pInfer, nullptr, true);
 }  // namespace ops
@@ -65,8 +65,8 @@ TypePtr LogSoftmaxInferType(const PrimitivePtr &prim, const std::vector<Abstract

 AbstractBasePtr LogSoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(LogSoftmaxInferType(primitive, input_args),
-                                                    LogSoftmaxInferShape(primitive, input_args)->shape());
+  return abstract::MakeAbstract(LogSoftmaxInferShape(primitive, input_args),
+                                LogSoftmaxInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(LogSoftmax, prim::kPrimLogSoftmax, LogSoftmaxInfer, nullptr, true);
 }  // namespace ops
@@ -126,8 +126,7 @@ bool MatMul::get_transpose_b() const {
 AbstractBasePtr MatMulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args) {
   CheckAndConvertUtils::CheckInteger("MatMul infer", input_args.size(), kGreaterEqual, 2, primitive->name());
-  return std::make_shared<abstract::AbstractTensor>(MatMulInferType(primitive, input_args),
-                                                    MatMulInferShape(primitive, input_args)->shape());
+  return abstract::MakeAbstract(MatMulInferShape(primitive, input_args), MatMulInferType(primitive, input_args));
 }
 // Add
 REGISTER_PRIMITIVE_C(kNameMatMul, MatMul);
@@ -52,8 +52,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &

 AbstractBasePtr MulInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args)->shape());
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_C(kNameMul, Mul);
 }  // namespace ops
@@ -73,8 +73,7 @@ TypePtr OneHotInferType(const PrimitivePtr &prim, const std::vector<AbstractBase
 }  // namespace
 AbstractBasePtr OneHotInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                             const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(OneHotInferType(primitive, input_args),
-                                                    OneHotInferShape(primitive, input_args));
+  return abstract::MakeAbstract(OneHotInferShape(primitive, input_args), OneHotInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(OneHot, prim::kPrimOneHot, OneHotInfer, nullptr, true);
 }  // namespace ops
@@ -45,8 +45,7 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
 AbstractBasePtr OnesLikeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args));
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(OnesLike, prim::kPrimOnesLike, OnesLikeInfer, nullptr, true);
 }  // namespace ops
@@ -52,8 +52,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &

 AbstractBasePtr RealDivInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args)->shape());
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_C(kNameRealDiv, RealDiv);
 }  // namespace ops
@@ -44,14 +44,15 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   CheckAndConvertUtils::CheckInteger("ReLU infer", input_args.size(), kEqual, 1, prim_name);
   MS_EXCEPTION_IF_NULL(input_args[0]);
   auto x_type = input_args[0]->BuildType();
-  return CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, common_valid_types, prim_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, common_valid_types, prim_name);
+  return x_type;
 }
 }  // namespace
 AbstractBasePtr ReLUInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                           const std::vector<AbstractBasePtr> &input_args) {
   auto type = InferType(primitive, input_args);
   auto shape = InferShape(primitive, input_args);
-  return std::make_shared<abstract::AbstractTensor>(type, shape);
+  return abstract::MakeAbstract(shape, type);
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(ReLU, prim::kPrimRelu, ReLUInfer, nullptr, true);

@@ -52,8 +52,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
 }  // namespace
 AbstractBasePtr ReLU6Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                            const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args));
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(ReLU6, prim::kPrimRelu6, ReLU6Infer, nullptr, true);
 }  // namespace ops
@@ -51,8 +51,7 @@ std::vector<int64_t> GetOutputMaskShape(const std::vector<int64_t> &input_shape,
   }
   return mask_shape;
 }
-std::vector<abstract::ShapePtr> InferShape(const PrimitivePtr &primitive,
-                                           const std::vector<AbstractBasePtr> &input_args) {
+abstract::TupleShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
   CheckAndConvertUtils::CheckInteger("input numbers", input_args.size(), kEqual, 1, prim_name);
@@ -74,13 +73,13 @@ std::vector<abstract::ShapePtr> InferShape(const PrimitivePtr &primitive,
   if (min_shape.empty() || max_shape.empty()) {
     inputs_shape = std::make_shared<abstract::Shape>(input_shape);
     masks_shape = std::make_shared<abstract::Shape>(mask_shape);
-    return std::vector<abstract::ShapePtr>{inputs_shape, masks_shape};
+    return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{inputs_shape, masks_shape});
   }
   auto min_mask_shape = GetOutputMaskShape(min_shape, x_dtype);
   auto max_mask_shape = GetOutputMaskShape(max_shape, x_dtype);
   inputs_shape = std::make_shared<abstract::Shape>(input_shape, min_shape, max_shape);
   masks_shape = std::make_shared<abstract::Shape>(mask_shape, min_mask_shape, max_mask_shape);
-  return std::vector<abstract::ShapePtr>{inputs_shape, masks_shape};
+  return std::make_shared<abstract::TupleShape>(std::vector<abstract::BaseShapePtr>{inputs_shape, masks_shape});
 }

 TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
@@ -92,7 +91,8 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &
   MS_EXCEPTION_IF_NULL(x_type);
   std::set<TypePtr> valid_x_type = {kTensorType};
   CheckAndConvertUtils::CheckSubClass("input_x", x_type, valid_x_type, prim_name);
-  return CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_x_type, prim_name);
+  auto type = CheckAndConvertUtils::CheckTensorTypeValid("input_x", x_type, valid_x_type, prim_name);
+  return std::make_shared<Tuple>(std::vector<TypePtr>{type, type});
 }
 }  // namespace
 AbstractBasePtr ReLUV2Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
@@ -100,10 +100,7 @@ AbstractBasePtr ReLUV2Infer(const abstract::AnalysisEnginePtr &, const Primitive
   MS_EXCEPTION_IF_NULL(primitive);
   auto types = InferType(primitive, input_args);
   auto shapes = InferShape(primitive, input_args);
-  auto input_dtype = std::make_shared<abstract::AbstractTensor>(types, shapes[0]);
-  auto mask_dtype = std::make_shared<abstract::AbstractTensor>(kUInt8, shapes[1]);
-  AbstractBasePtrList outputs = {input_dtype, mask_dtype};
-  return std::make_shared<abstract::AbstractTuple>(outputs);
+  return abstract::MakeAbstract(shapes, types);
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(ReLUV2, prim::kPrimReluV2, ReLUV2Infer, nullptr, true);

@@ -48,7 +48,7 @@ AbstractBasePtr ScalarSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr
   MS_EXCEPTION_IF_NULL(primitive);
   // check
   CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], primitive->name());
-  return std::make_shared<abstract::AbstractTensor>(kInt32, ScalarSummaryInferShape(primitive, input_args));
+  return abstract::MakeAbstract(ScalarSummaryInferShape(primitive, input_args), kInt32);
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer, nullptr, true);
 }  // namespace ops
@@ -79,8 +79,7 @@ TypePtr SoftMaxInferType(const PrimitivePtr &prim, const std::vector<AbstractBas

 AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(SoftMaxInferType(primitive, input_args),
-                                                    SoftMaxInferShape(primitive, input_args));
+  return abstract::MakeAbstract(SoftMaxInferShape(primitive, input_args), SoftMaxInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(Softmax, prim::kPrimSoftmax, SoftmaxInfer, nullptr, true);
 }  // namespace ops
@@ -49,9 +49,7 @@ TypePtr SoftplusInferType(const PrimitivePtr &prim, const std::vector<AbstractBa
 AbstractBasePtr SoftplusInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto abs = std::make_shared<abstract::AbstractTensor>(SoftplusInferType(primitive, input_args),
-                                                        SoftplusInferShape(primitive, input_args));
-  return abs;
+  return abstract::MakeAbstract(SoftplusInferShape(primitive, input_args), SoftplusInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(Softplus, prim::kPrimSoftplus, SoftplusInfer, nullptr, true);
 }  // namespace ops
@@ -52,8 +52,7 @@ TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &

 AbstractBasePtr SubInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args)->shape());
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_C(kNameSub, Sub);
 }  // namespace ops
@@ -48,7 +48,7 @@ AbstractBasePtr TensorSummaryInfer(const abstract::AnalysisEnginePtr &, const Pr
   MS_EXCEPTION_IF_NULL(primitive);
   // check
   CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], primitive->name());
-  return std::make_shared<abstract::AbstractTensor>(kInt32, TensorSummaryInferShape(primitive, input_args));
+  return abstract::MakeAbstract(TensorSummaryInferShape(primitive, input_args), kInt32);
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer, nullptr, true);
 }  // namespace ops
@@ -91,8 +91,7 @@ TypePtr TileInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePt

 AbstractBasePtr TileInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                           const std::vector<AbstractBasePtr> &input_args) {
-  return std::make_shared<abstract::AbstractTensor>(TileInferType(primitive, input_args),
-                                                    TileInferShape(primitive, input_args)->shape());
+  return abstract::MakeAbstract(TileInferShape(primitive, input_args), TileInferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_C(kNameTile, Tile);

@@ -52,9 +52,7 @@ TypePtr ZerosInferType(const PrimitivePtr &prim, const std::vector<AbstractBaseP
 AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                            const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  auto abs = std::make_shared<abstract::AbstractTensor>(ZerosInferType(primitive, input_args),
-                                                        ZerosInferShape(primitive, input_args));
-  return abs;
+  return abstract::MakeAbstract(ZerosInferShape(primitive, input_args), ZerosInferType(primitive, input_args));
 }

 ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
@@ -38,15 +38,15 @@ TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBaseP
   auto infer_type = input_args[0]->BuildType();
   auto valid_type = common_valid_types;
   valid_type.insert(kBool);
-  return CheckAndConvertUtils::CheckTensorTypeValid("infer_type", infer_type, valid_type, op_name);
+  (void)CheckAndConvertUtils::CheckTensorTypeValid("infer_type", infer_type, valid_type, op_name);
+  return infer_type;
 }

 }  // namespace
 AbstractBasePtr ZerosLikeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
-  return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
-                                                    InferShape(primitive, input_args));
+  return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
 }
 REGISTER_PRIMITIVE_EVAL_IMPL(ZerosLike, prim::kPrimZerosLike, ZerosLikeInfer, nullptr, true);
 }  // namespace ops