!14948 Add some infer in C++

From: @liangzhibo
Reviewed-by: @zh_qh,@ginfung
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-04-15 17:39:30 +08:00 committed by Gitee
commit 19ab917aa5
6 changed files with 115 additions and 26 deletions

View File

@ -0,0 +1,54 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/dtype.h"
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/abstract_value.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
const AbstractBasePtr &infer) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
CheckAndConvertUtils::CheckInteger("dtype infer", input_args.size(), kEqual, 1, op_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
const std::set<TypePtr> valid_types = {kTensorType};
auto type =
CheckAndConvertUtils::CheckTensorTypeValid("infer type", input_args[0]->BuildType(), valid_types, op_name);
return type;
}
AbstractBasePtr DTypeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto value = DTypeInferValue(primitive, input_args, nullptr);
MS_EXCEPTION_IF_NULL(value);
auto type = value->cast<TypePtr>();
MS_EXCEPTION_IF_NULL(type);
auto abstract = std::make_shared<abstract::AbstractType>(type);
return abstract;
}
REGISTER_PRIMITIVE_EVAL_IMPL(DType, prim::kPrimDType, DTypeInfer, DTypeInferValue, false);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,40 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_DTYPE_H_
#define MINDSPORE_CORE_OPS_DTYPE_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
class DType : public PrimitiveC {
public:
DType() : PrimitiveC(prim::kPrimDType->name()) { InitIOName({"x"}, {"output"}); }
~DType() = default;
MS_DECLARE_PARENT(DType, PrimitiveC);
void Init() {}
};
using PrimDTypePtr = std::shared_ptr<DType>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_DTYPE_H_

View File

@ -31,7 +31,10 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
// infer shape
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name);
CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto in_shape = shape_map[kShape];
// infer type
AbstractBasePtrList abs_list;
std::transform(in_shape.begin(), in_shape.end(), std::back_inserter(abs_list),
@ -39,9 +42,20 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
return std::make_shared<abstract::AbstractScalar>(item);
});
auto abs = std::make_shared<abstract::AbstractTuple>(abs_list);
abs->set_value(MakeValue(in_shape));
return abs;
}
REGISTER_PRIMITIVE_C(kNameShape, Shape);
ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
const AbstractBasePtr &infer) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name);
MS_EXCEPTION_IF_NULL(input_args[0]);
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
auto inshape = shape_map[kShape];
auto value = MakeValue(inshape);
return value;
}
REGISTER_PRIMITIVE_EVAL_IMPL(Shape, prim::kPrimShape, ShapeInfer, ShapeInferValue, false);
} // namespace ops
} // namespace mindspore

View File

@ -26,17 +26,13 @@
namespace mindspore {
namespace ops {
constexpr auto kNameShape = "Shape";
class Shape : public PrimitiveC {
public:
Shape() : PrimitiveC(kNameShape) {}
Shape() : PrimitiveC(prim::kPrimShape->name()) {}
~Shape() = default;
MS_DECLARE_PARENT(Shape, PrimitiveC);
void Init() {}
};
AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimShapePtr = std::shared_ptr<Shape>;
} // namespace ops
} // namespace mindspore

View File

@ -197,7 +197,7 @@ class ExpandDims(PrimitiveWithInfer):
return out
class DType(PrimitiveWithInfer):
class DType(Primitive):
"""
Returns the data type of the input tensor as mindspore.dtype.
@ -224,14 +224,6 @@ class DType(PrimitiveWithInfer):
def __init__(self):
"""Initialize DType"""
def __infer__(self, x):
addition_error_info = 'Perhaps you are using a mixture of tensors and scalars to operate.'
validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name, addition_error_info)
out = {'shape': (),
'dtype': mstype.type_type,
'value': x['dtype'].element_type()}
return out
class SameTypeShape(PrimitiveWithInfer):
"""
@ -549,7 +541,7 @@ class Reshape(PrimitiveWithInfer):
return out
class Shape(PrimitiveWithInfer):
class Shape(Primitive):
"""
Returns the shape of the input tensor.
@ -578,13 +570,6 @@ class Shape(PrimitiveWithInfer):
def __init__(self):
"""Initialize Shape"""
def __infer__(self, x):
validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
out = {'shape': (),
'dtype': mstype.tuple_,
'value': tuple(x['shape'])}
return out
class DynamicShape(Primitive):
"""

View File

@ -129,7 +129,7 @@ class Flatten(PrimitiveWithInfer):
return input_x
class Softmax(PrimitiveWithInfer):
class Softmax(Primitive):
r"""
Softmax operation.