forked from mindspore-Ecosystem/mindspore
!14948 Add some infer in C++
From: @liangzhibo Reviewed-by: @zh_qh,@ginfung Signed-off-by: @zh_qh
This commit is contained in:
commit
19ab917aa5
|
@ -0,0 +1,54 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "ops/dtype.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
|
||||
const AbstractBasePtr &infer) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("dtype infer", input_args.size(), kEqual, 1, op_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
const std::set<TypePtr> valid_types = {kTensorType};
|
||||
auto type =
|
||||
CheckAndConvertUtils::CheckTensorTypeValid("infer type", input_args[0]->BuildType(), valid_types, op_name);
|
||||
return type;
|
||||
}
|
||||
|
||||
AbstractBasePtr DTypeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
auto value = DTypeInferValue(primitive, input_args, nullptr);
|
||||
MS_EXCEPTION_IF_NULL(value);
|
||||
auto type = value->cast<TypePtr>();
|
||||
MS_EXCEPTION_IF_NULL(type);
|
||||
auto abstract = std::make_shared<abstract::AbstractType>(type);
|
||||
return abstract;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(DType, prim::kPrimDType, DTypeInfer, DTypeInferValue, false);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_OPS_DTYPE_H_
|
||||
#define MINDSPORE_CORE_OPS_DTYPE_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
class DType : public PrimitiveC {
|
||||
public:
|
||||
DType() : PrimitiveC(prim::kPrimDType->name()) { InitIOName({"x"}, {"output"}); }
|
||||
~DType() = default;
|
||||
MS_DECLARE_PARENT(DType, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
using PrimDTypePtr = std::shared_ptr<DType>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_OPS_DTYPE_H_
|
|
@ -31,7 +31,10 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
|||
// infer shape
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name);
|
||||
CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto in_shape = shape_map[kShape];
|
||||
// infer type
|
||||
AbstractBasePtrList abs_list;
|
||||
std::transform(in_shape.begin(), in_shape.end(), std::back_inserter(abs_list),
|
||||
|
@ -39,9 +42,20 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
|||
return std::make_shared<abstract::AbstractScalar>(item);
|
||||
});
|
||||
auto abs = std::make_shared<abstract::AbstractTuple>(abs_list);
|
||||
abs->set_value(MakeValue(in_shape));
|
||||
return abs;
|
||||
}
|
||||
REGISTER_PRIMITIVE_C(kNameShape, Shape);
|
||||
|
||||
ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
|
||||
const AbstractBasePtr &infer) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto op_name = primitive->name();
|
||||
CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name);
|
||||
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||
auto inshape = shape_map[kShape];
|
||||
auto value = MakeValue(inshape);
|
||||
return value;
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Shape, prim::kPrimShape, ShapeInfer, ShapeInferValue, false);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -26,17 +26,13 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
constexpr auto kNameShape = "Shape";
|
||||
class Shape : public PrimitiveC {
|
||||
public:
|
||||
Shape() : PrimitiveC(kNameShape) {}
|
||||
Shape() : PrimitiveC(prim::kPrimShape->name()) {}
|
||||
~Shape() = default;
|
||||
MS_DECLARE_PARENT(Shape, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
|
||||
AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimShapePtr = std::shared_ptr<Shape>;
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -197,7 +197,7 @@ class ExpandDims(PrimitiveWithInfer):
|
|||
return out
|
||||
|
||||
|
||||
class DType(PrimitiveWithInfer):
|
||||
class DType(Primitive):
|
||||
"""
|
||||
Returns the data type of the input tensor as mindspore.dtype.
|
||||
|
||||
|
@ -224,14 +224,6 @@ class DType(PrimitiveWithInfer):
|
|||
def __init__(self):
|
||||
"""Initialize DType"""
|
||||
|
||||
def __infer__(self, x):
|
||||
addition_error_info = 'Perhaps you are using a mixture of tensors and scalars to operate.'
|
||||
validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name, addition_error_info)
|
||||
out = {'shape': (),
|
||||
'dtype': mstype.type_type,
|
||||
'value': x['dtype'].element_type()}
|
||||
return out
|
||||
|
||||
|
||||
class SameTypeShape(PrimitiveWithInfer):
|
||||
"""
|
||||
|
@ -549,7 +541,7 @@ class Reshape(PrimitiveWithInfer):
|
|||
return out
|
||||
|
||||
|
||||
class Shape(PrimitiveWithInfer):
|
||||
class Shape(Primitive):
|
||||
"""
|
||||
Returns the shape of the input tensor.
|
||||
|
||||
|
@ -578,13 +570,6 @@ class Shape(PrimitiveWithInfer):
|
|||
def __init__(self):
|
||||
"""Initialize Shape"""
|
||||
|
||||
def __infer__(self, x):
|
||||
validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
|
||||
out = {'shape': (),
|
||||
'dtype': mstype.tuple_,
|
||||
'value': tuple(x['shape'])}
|
||||
return out
|
||||
|
||||
|
||||
class DynamicShape(Primitive):
|
||||
"""
|
||||
|
|
|
@ -129,7 +129,7 @@ class Flatten(PrimitiveWithInfer):
|
|||
return input_x
|
||||
|
||||
|
||||
class Softmax(PrimitiveWithInfer):
|
||||
class Softmax(Primitive):
|
||||
r"""
|
||||
Softmax operation.
|
||||
|
||||
|
|
Loading…
Reference in New Issue