forked from mindspore-Ecosystem/mindspore
Add some infer
This commit is contained in:
parent
12e60a7a27
commit
f4a0ccc924
|
@ -0,0 +1,54 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ops/dtype.h"
|
||||||
|
#include <string>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <set>
|
||||||
|
#include <vector>
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "abstract/abstract_value.h"
|
||||||
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
ValuePtr DTypeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
|
||||||
|
const AbstractBasePtr &infer) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto op_name = primitive->name();
|
||||||
|
CheckAndConvertUtils::CheckInteger("dtype infer", input_args.size(), kEqual, 1, op_name);
|
||||||
|
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||||
|
const std::set<TypePtr> valid_types = {kTensorType};
|
||||||
|
auto type =
|
||||||
|
CheckAndConvertUtils::CheckTensorTypeValid("infer type", input_args[0]->BuildType(), valid_types, op_name);
|
||||||
|
return type;
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr DTypeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
auto value = DTypeInferValue(primitive, input_args, nullptr);
|
||||||
|
MS_EXCEPTION_IF_NULL(value);
|
||||||
|
auto type = value->cast<TypePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(type);
|
||||||
|
auto abstract = std::make_shared<abstract::AbstractType>(type);
|
||||||
|
return abstract;
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_PRIMITIVE_EVAL_IMPL(DType, prim::kPrimDType, DTypeInfer, DTypeInferValue, false);
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,40 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_DTYPE_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_DTYPE_H_
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include "ops/primitive_c.h"
|
||||||
|
#include "abstract/abstract_value.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
class DType : public PrimitiveC {
|
||||||
|
public:
|
||||||
|
DType() : PrimitiveC(prim::kPrimDType->name()) { InitIOName({"x"}, {"output"}); }
|
||||||
|
~DType() = default;
|
||||||
|
MS_DECLARE_PARENT(DType, PrimitiveC);
|
||||||
|
void Init() {}
|
||||||
|
};
|
||||||
|
using PrimDTypePtr = std::shared_ptr<DType>;
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_DTYPE_H_
|
|
@ -31,7 +31,10 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
||||||
// infer shape
|
// infer shape
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto op_name = primitive->name();
|
auto op_name = primitive->name();
|
||||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), op_name);
|
CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name);
|
||||||
|
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||||
|
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||||
|
auto in_shape = shape_map[kShape];
|
||||||
// infer type
|
// infer type
|
||||||
AbstractBasePtrList abs_list;
|
AbstractBasePtrList abs_list;
|
||||||
std::transform(in_shape.begin(), in_shape.end(), std::back_inserter(abs_list),
|
std::transform(in_shape.begin(), in_shape.end(), std::back_inserter(abs_list),
|
||||||
|
@ -39,9 +42,20 @@ AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP
|
||||||
return std::make_shared<abstract::AbstractScalar>(item);
|
return std::make_shared<abstract::AbstractScalar>(item);
|
||||||
});
|
});
|
||||||
auto abs = std::make_shared<abstract::AbstractTuple>(abs_list);
|
auto abs = std::make_shared<abstract::AbstractTuple>(abs_list);
|
||||||
abs->set_value(MakeValue(in_shape));
|
|
||||||
return abs;
|
return abs;
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_C(kNameShape, Shape);
|
|
||||||
|
ValuePtr ShapeInferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args,
|
||||||
|
const AbstractBasePtr &infer) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto op_name = primitive->name();
|
||||||
|
CheckAndConvertUtils::CheckInteger("shape infer", input_args.size(), kEqual, 1, op_name);
|
||||||
|
MS_EXCEPTION_IF_NULL(input_args[0]);
|
||||||
|
auto shape_map = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape());
|
||||||
|
auto inshape = shape_map[kShape];
|
||||||
|
auto value = MakeValue(inshape);
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
REGISTER_PRIMITIVE_EVAL_IMPL(Shape, prim::kPrimShape, ShapeInfer, ShapeInferValue, false);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -26,17 +26,13 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
constexpr auto kNameShape = "Shape";
|
|
||||||
class Shape : public PrimitiveC {
|
class Shape : public PrimitiveC {
|
||||||
public:
|
public:
|
||||||
Shape() : PrimitiveC(kNameShape) {}
|
Shape() : PrimitiveC(prim::kPrimShape->name()) {}
|
||||||
~Shape() = default;
|
~Shape() = default;
|
||||||
MS_DECLARE_PARENT(Shape, PrimitiveC);
|
MS_DECLARE_PARENT(Shape, PrimitiveC);
|
||||||
void Init() {}
|
void Init() {}
|
||||||
};
|
};
|
||||||
|
|
||||||
AbstractBasePtr ShapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
|
||||||
using PrimShapePtr = std::shared_ptr<Shape>;
|
using PrimShapePtr = std::shared_ptr<Shape>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -197,7 +197,7 @@ class ExpandDims(PrimitiveWithInfer):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class DType(PrimitiveWithInfer):
|
class DType(Primitive):
|
||||||
"""
|
"""
|
||||||
Returns the data type of the input tensor as mindspore.dtype.
|
Returns the data type of the input tensor as mindspore.dtype.
|
||||||
|
|
||||||
|
@ -224,14 +224,6 @@ class DType(PrimitiveWithInfer):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize DType"""
|
"""Initialize DType"""
|
||||||
|
|
||||||
def __infer__(self, x):
|
|
||||||
addition_error_info = 'Perhaps you are using a mixture of tensors and scalars to operate.'
|
|
||||||
validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name, addition_error_info)
|
|
||||||
out = {'shape': (),
|
|
||||||
'dtype': mstype.type_type,
|
|
||||||
'value': x['dtype'].element_type()}
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class SameTypeShape(PrimitiveWithInfer):
|
class SameTypeShape(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
|
@ -549,7 +541,7 @@ class Reshape(PrimitiveWithInfer):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Shape(PrimitiveWithInfer):
|
class Shape(Primitive):
|
||||||
"""
|
"""
|
||||||
Returns the shape of the input tensor.
|
Returns the shape of the input tensor.
|
||||||
|
|
||||||
|
@ -578,13 +570,6 @@ class Shape(PrimitiveWithInfer):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize Shape"""
|
"""Initialize Shape"""
|
||||||
|
|
||||||
def __infer__(self, x):
|
|
||||||
validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
|
|
||||||
out = {'shape': (),
|
|
||||||
'dtype': mstype.tuple_,
|
|
||||||
'value': tuple(x['shape'])}
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class DynamicShape(Primitive):
|
class DynamicShape(Primitive):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -129,7 +129,7 @@ class Flatten(PrimitiveWithInfer):
|
||||||
return input_x
|
return input_x
|
||||||
|
|
||||||
|
|
||||||
class Softmax(PrimitiveWithInfer):
|
class Softmax(Primitive):
|
||||||
r"""
|
r"""
|
||||||
Softmax operation.
|
Softmax operation.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue