forked from mindspore-Ecosystem/mindspore
!14530 ops infer
From: @simson_wu Reviewed-by: @chujinjin,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
64d22b4c77
|
@ -363,7 +363,7 @@ set_target_properties(_c_expression PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH})
|
||||||
if(CMAKE_SYSTEM_NAME MATCHES "Windows")
|
if(CMAKE_SYSTEM_NAME MATCHES "Windows")
|
||||||
target_link_libraries(mindspore mindspore::pybind11_module)
|
target_link_libraries(mindspore mindspore::pybind11_module)
|
||||||
target_link_libraries(mindspore mindspore_gvar)
|
target_link_libraries(mindspore mindspore_gvar)
|
||||||
target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore -Wl,--no-whole-archive)
|
target_link_libraries(_c_expression PRIVATE -Wl,--whole-archive mindspore mindspore_core -Wl,--no-whole-archive)
|
||||||
elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
elseif(CMAKE_SYSTEM_NAME MATCHES "Darwin")
|
||||||
target_link_libraries(mindspore mindspore::pybind11_module)
|
target_link_libraries(mindspore mindspore::pybind11_module)
|
||||||
target_link_libraries(mindspore mindspore_gvar)
|
target_link_libraries(mindspore mindspore_gvar)
|
||||||
|
|
|
@ -459,7 +459,7 @@ AnfNodePtr CreateValueNode(const FuncGraphPtr &func_graph, const CNodePtr &dynam
|
||||||
std::vector<size_t> shape = {t_size, IntToSize(1), n_size};
|
std::vector<size_t> shape = {t_size, IntToSize(1), n_size};
|
||||||
std::vector<int64_t> output_shape = {SizeToLong(t_size), SizeToLong(1), SizeToLong(n_size)};
|
std::vector<int64_t> output_shape = {SizeToLong(t_size), SizeToLong(1), SizeToLong(n_size)};
|
||||||
std::vector<int64_t> output_tensor = {SizeToLong(t_size) * SizeToLong(n_size)};
|
std::vector<int64_t> output_tensor = {SizeToLong(t_size) * SizeToLong(n_size)};
|
||||||
auto tensor = TensorConstructUtils::CreateOnesTensor(kNumberTypeFloat32, output_tensor);
|
auto tensor = TensorConstructUtils::CreateOnesTensor(kFloat32, output_tensor);
|
||||||
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, output_shape);
|
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, output_shape);
|
||||||
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
|
||||||
auto value_node = kernel_graph->NewValueNode(x_abstract, tensor);
|
auto value_node = kernel_graph->NewValueNode(x_abstract, tensor);
|
||||||
|
|
|
@ -287,6 +287,7 @@ inline const PrimitivePtr kPrimElu = std::make_shared<Primitive>("Elu");
|
||||||
inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6");
|
inline const PrimitivePtr kPrimRelu6 = std::make_shared<Primitive>("ReLU6");
|
||||||
inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
|
inline const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
|
||||||
inline const PrimitivePtr kPrimPRelu = std::make_shared<Primitive>("PReLU");
|
inline const PrimitivePtr kPrimPRelu = std::make_shared<Primitive>("PReLU");
|
||||||
|
inline const PrimitivePtr kPrimZeros = std::make_shared<Primitive>("Zeros");
|
||||||
inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
|
inline const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
|
||||||
inline const PrimitivePtr kPrimOnesLike = std::make_shared<Primitive>("OnesLike");
|
inline const PrimitivePtr kPrimOnesLike = std::make_shared<Primitive>("OnesLike");
|
||||||
inline const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
|
inline const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
|
||||||
|
|
|
@ -29,7 +29,7 @@ abstract::ShapePtr BiasAddInferShape(const PrimitivePtr &primitive, const std::v
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto prim_name = primitive->name();
|
auto prim_name = primitive->name();
|
||||||
// check
|
// check
|
||||||
CheckAndConvertUtils::CheckInteger("biasadd_infer", input_args.size(), kEqual, 2, prim_name);
|
CheckAndConvertUtils::CheckInteger("arg size", input_args.size(), kEqual, 2, prim_name);
|
||||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||||
auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShape("b_shape", input_args[1]->BuildShape(), prim_name);
|
auto b_shape = CheckAndConvertUtils::ConvertShapePtrToShape("b_shape", input_args[1]->BuildShape(), prim_name);
|
||||||
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kGreaterEqual, 2, prim_name);
|
CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kGreaterEqual, 2, prim_name);
|
||||||
|
@ -55,7 +55,7 @@ TypePtr BiasAddInferType(const PrimitivePtr &prim, const std::vector<AbstractBas
|
||||||
std::map<std::string, TypePtr> types;
|
std::map<std::string, TypePtr> types;
|
||||||
types.emplace("input_x", input_args[0]->BuildType());
|
types.emplace("input_x", input_args[0]->BuildType());
|
||||||
types.emplace("bias", input_args[1]->BuildType());
|
types.emplace("bias", input_args[1]->BuildType());
|
||||||
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
return CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim_name);
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
void BiasAdd::set_format(const Format &format) {
|
void BiasAdd::set_format(const Format &format) {
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -20,28 +20,6 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
|
||||||
auto prim_name = primitive->name();
|
|
||||||
CheckAndConvertUtils::CheckInteger("gather_infer", input_args.size(), kEqual, 3, prim_name);
|
|
||||||
|
|
||||||
// Infer type
|
|
||||||
std::set<TypePtr> valid_x_type = {kTensorType};
|
|
||||||
auto x_type =
|
|
||||||
CheckAndConvertUtils::CheckTensorTypeValid("x_type", input_args[0]->BuildType(), valid_x_type, prim_name);
|
|
||||||
std::set<TypePtr> valid_index_types = {kInt32, kInt64};
|
|
||||||
CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_index_types, prim_name);
|
|
||||||
std::set<TypePtr> valid_dim_type = {kInt32, kInt64};
|
|
||||||
CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_dim_type, prim_name);
|
|
||||||
|
|
||||||
// Infer shape
|
|
||||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
|
||||||
auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dim_shape", input_args[2]->BuildShape(), prim_name);
|
|
||||||
CheckAndConvertUtils::Check("x_rank", x_shape.size(), kEqual, "index_rank", index_shape.size(), prim_name);
|
|
||||||
|
|
||||||
return std::make_shared<abstract::AbstractTensor>(x_type, index_shape);
|
|
||||||
}
|
|
||||||
REGISTER_PRIMITIVE_C(kNameGather, Gather);
|
REGISTER_PRIMITIVE_C(kNameGather, Gather);
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -34,8 +34,6 @@ class Gather : public PrimitiveC {
|
||||||
MS_DECLARE_PARENT(Gather, PrimitiveC);
|
MS_DECLARE_PARENT(Gather, PrimitiveC);
|
||||||
void Init() {}
|
void Init() {}
|
||||||
};
|
};
|
||||||
AbstractBasePtr GatherInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
||||||
const std::vector<AbstractBasePtr> &input_args);
|
|
||||||
using PrimGatherPtr = std::shared_ptr<Gather>;
|
using PrimGatherPtr = std::shared_ptr<Gather>;
|
||||||
} // namespace ops
|
} // namespace ops
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ops/gather_d.h"
|
||||||
|
#include <memory>
|
||||||
|
#include <set>
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
// gather_d
|
||||||
|
namespace {
|
||||||
|
abstract::ShapePtr GatherDInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto prim_name = primitive->name();
|
||||||
|
// check
|
||||||
|
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->BuildShape(), prim_name);
|
||||||
|
auto index_shape = CheckAndConvertUtils::ConvertShapePtrToShape("dim_shape", input_args[2]->BuildShape(), prim_name);
|
||||||
|
int64_t x_rank = x_shape.size();
|
||||||
|
CheckAndConvertUtils::Check("x_rank", x_rank, kEqual, "index_rank", index_shape.size(), prim_name);
|
||||||
|
auto dim_v = GetValue<int64_t>(input_args[1]->BuildValue());
|
||||||
|
CheckAndConvertUtils::Check("dim value", dim_v, kGreaterEqual, "negative index_rank", -x_rank, prim_name);
|
||||||
|
CheckAndConvertUtils::Check("dim value", dim_v, kLessThan, "index_rank", x_rank, prim_name);
|
||||||
|
|
||||||
|
if (dim_v < 0) {
|
||||||
|
dim_v = dim_v + x_rank;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < x_rank; ++i) {
|
||||||
|
if (i == dim_v) continue;
|
||||||
|
MS_LOG(INFO) << "Check " << i << "th x shape";
|
||||||
|
CheckAndConvertUtils::Check("x shape", x_shape[i], kEqual, "index_rank", index_shape[i], prim_name);
|
||||||
|
}
|
||||||
|
return std::make_shared<abstract::Shape>(index_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
TypePtr GatherDInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
auto prim_name = prim->name();
|
||||||
|
// check
|
||||||
|
std::set<TypePtr> valid_x_type = {kTensorType};
|
||||||
|
auto x_type =
|
||||||
|
CheckAndConvertUtils::CheckTensorTypeValid("x_type", input_args[0]->BuildType(), valid_x_type, prim_name);
|
||||||
|
std::set<TypePtr> valid_index_types = {kInt32, kInt64};
|
||||||
|
CheckAndConvertUtils::CheckTensorTypeValid("index_type", input_args[2]->BuildType(), valid_index_types, prim_name);
|
||||||
|
std::set<TypePtr> valid_dim_type = {kInt32, kInt64};
|
||||||
|
CheckAndConvertUtils::CheckSubClass("dim_type", input_args[1]->BuildType(), valid_dim_type, prim_name);
|
||||||
|
return x_type;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
AbstractBasePtr GatherDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto abs = std::make_shared<abstract::AbstractTensor>(GatherDInferType(primitive, input_args),
|
||||||
|
GatherDInferShape(primitive, input_args));
|
||||||
|
return abs;
|
||||||
|
}
|
||||||
|
REGISTER_PRIMITIVE_EVAL_IMPL(GatherD, prim::kPrimGatherD, GatherDInfer, nullptr, false);
|
||||||
|
REGISTER_PRIMITIVE_C(kNameGatherD, GatherD);
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,41 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_GATHER_D_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_GATHER_D_H_
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include "ops/primitive_c.h"
|
||||||
|
#include "abstract/abstract_value.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
constexpr auto kNameGatherD = "GatherD";
|
||||||
|
class GatherD : public PrimitiveC {
|
||||||
|
public:
|
||||||
|
GatherD() : PrimitiveC(kNameGatherD) { InitIOName({"x", "dim", "index"}, {"output"}); }
|
||||||
|
~GatherD() = default;
|
||||||
|
MS_DECLARE_PARENT(GatherD, PrimitiveC);
|
||||||
|
void Init() {}
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_GATHER_D_H_
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -22,7 +22,18 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
// scalar_summary
|
||||||
|
namespace {
|
||||||
|
abstract::ShapePtr ScalarSummaryInferShape(const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto prim_name = primitive->name();
|
||||||
|
// check
|
||||||
|
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name);
|
||||||
|
CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name);
|
||||||
|
return std::make_shared<abstract::Shape>(ShapeVector(1));
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
void ScalarSummary::set_side_effect_io() { this->AddAttr(kSideEffectIO, MakeValue(true)); }
|
void ScalarSummary::set_side_effect_io() { this->AddAttr(kSideEffectIO, MakeValue(true)); }
|
||||||
|
|
||||||
bool ScalarSummary::get_side_effect_io() const {
|
bool ScalarSummary::get_side_effect_io() const {
|
||||||
|
@ -35,12 +46,9 @@ void ScalarSummary::Init() { this->set_side_effect_io(); }
|
||||||
AbstractBasePtr ScalarSummaryInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr ScalarSummaryInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto prim_name = primitive->name();
|
|
||||||
// check
|
// check
|
||||||
CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], prim_name);
|
CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], primitive->name());
|
||||||
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name);
|
return std::make_shared<abstract::AbstractTensor>(kInt32, ScalarSummaryInferShape(primitive, input_args));
|
||||||
CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kLessEqual, 1, prim_name);
|
|
||||||
return std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(ShapeVector(1)));
|
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer, nullptr, true);
|
REGISTER_PRIMITIVE_EVAL_IMPL(ScalarSummary, prim::kPrimScalarSummary, ScalarSummaryInfer, nullptr, true);
|
||||||
REGISTER_PRIMITIVE_C(kNameScalarSummary, ScalarSummary);
|
REGISTER_PRIMITIVE_C(kNameScalarSummary, ScalarSummary);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
@ -22,7 +22,18 @@
|
||||||
|
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
// scalar_summary
|
||||||
|
namespace {
|
||||||
|
abstract::ShapePtr TensorSummaryInferShape(const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto prim_name = primitive->name();
|
||||||
|
// check
|
||||||
|
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name);
|
||||||
|
CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name);
|
||||||
|
return std::make_shared<abstract::Shape>(ShapeVector(1));
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
void TensorSummary::set_side_effect_io() { this->AddAttr(kSideEffectIO, MakeValue(true)); }
|
void TensorSummary::set_side_effect_io() { this->AddAttr(kSideEffectIO, MakeValue(true)); }
|
||||||
|
|
||||||
bool TensorSummary::get_side_effect_io() const {
|
bool TensorSummary::get_side_effect_io() const {
|
||||||
|
@ -35,12 +46,9 @@ void TensorSummary::Init() { this->set_side_effect_io(); }
|
||||||
AbstractBasePtr TensorSummaryInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr TensorSummaryInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
auto prim_name = primitive->name();
|
|
||||||
// check
|
// check
|
||||||
CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], prim_name);
|
CheckAndConvertUtils::CheckSummaryParam(input_args[0], input_args[1], primitive->name());
|
||||||
auto v_shape = CheckAndConvertUtils::ConvertShapePtrToShape("v_shape", input_args[1]->BuildShape(), prim_name);
|
return std::make_shared<abstract::AbstractTensor>(kInt32, TensorSummaryInferShape(primitive, input_args));
|
||||||
CheckAndConvertUtils::CheckInteger("v rank", v_shape.size(), kGreaterEqual, 1, prim_name);
|
|
||||||
return std::make_shared<abstract::AbstractTensor>(kInt32, std::make_shared<abstract::Shape>(ShapeVector(1)));
|
|
||||||
}
|
}
|
||||||
REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer, nullptr, true);
|
REGISTER_PRIMITIVE_EVAL_IMPL(TensorSummary, prim::kPrimTensorSummary, TensorSummaryInfer, nullptr, true);
|
||||||
REGISTER_PRIMITIVE_C(kNameTensorSummary, TensorSummary);
|
REGISTER_PRIMITIVE_C(kNameTensorSummary, TensorSummary);
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/**
|
/**
|
||||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
*
|
*
|
||||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
* you may not use this file except in compliance with the License.
|
* you may not use this file except in compliance with the License.
|
||||||
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "ops/zeros.h"
|
||||||
|
#include <memory>
|
||||||
|
#include <set>
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "utils/tensor_construct_utils.h"
|
||||||
|
#include "abstract/primitive_infer_map.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
// zeros
|
||||||
|
namespace {
|
||||||
|
abstract::ShapePtr ZerosInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto prim_name = primitive->name();
|
||||||
|
// check
|
||||||
|
auto shape_value = input_args[0]->BuildValue();
|
||||||
|
std::vector<int64_t> out_shape = CheckAndConvertUtils::CheckAttrIntOrTupleInt("shape", shape_value, prim_name);
|
||||||
|
CheckAndConvertUtils::CheckPositiveVector("shape", out_shape, prim_name);
|
||||||
|
return std::make_shared<abstract::Shape>(out_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
TypePtr ZerosInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
auto prim_name = prim->name();
|
||||||
|
// check
|
||||||
|
auto dtype_value = input_args[1]->BuildValue();
|
||||||
|
if (!dtype_value->isa<Type>()) {
|
||||||
|
MS_EXCEPTION(TypeError) << "The dtype of Zeros is invalid!";
|
||||||
|
}
|
||||||
|
auto output_type = dtype_value->cast<TypePtr>();
|
||||||
|
const std::set<TypePtr> valid_types = {kBool, kInt8, kInt16, kInt32, kInt64, kUInt8,
|
||||||
|
kUInt16, kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};
|
||||||
|
return CheckAndConvertUtils::CheckSubClass("dtype", output_type, valid_types, prim_name);
|
||||||
|
}
|
||||||
|
ValuePtr ZerosInferValue(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args,
|
||||||
|
const abstract::AbstractBasePtr &abs) {
|
||||||
|
MS_EXCEPTION_IF_NULL(prim);
|
||||||
|
auto prim_name = prim->name();
|
||||||
|
// check
|
||||||
|
auto out_shape = CheckAndConvertUtils::ConvertShapePtrToShape("output shape", abs->BuildShape(), prim_name);
|
||||||
|
auto out_type = abs->BuildType();
|
||||||
|
MS_EXCEPTION_IF_NULL(out_type);
|
||||||
|
return TensorConstructUtils::CreateZerosTensor(out_type, out_shape);
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
AbstractBasePtr ZerosInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto abs = std::make_shared<abstract::AbstractTensor>(ZerosInferType(primitive, input_args),
|
||||||
|
ZerosInferShape(primitive, input_args));
|
||||||
|
abs->set_value(ZerosInferValue(primitive, input_args, abs));
|
||||||
|
return abs;
|
||||||
|
}
|
||||||
|
REGISTER_PRIMITIVE_EVAL_IMPL(Zeros, prim::kPrimZeros, ZerosInfer, ZerosInferValue, false);
|
||||||
|
REGISTER_PRIMITIVE_C(kNameZeros, Zeros);
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,41 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CORE_OPS_ZEROS_H_
|
||||||
|
#define MINDSPORE_CORE_OPS_ZEROS_H_
|
||||||
|
#include <map>
|
||||||
|
#include <vector>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include "ops/primitive_c.h"
|
||||||
|
#include "abstract/abstract_value.h"
|
||||||
|
#include "utils/check_convert_utils.h"
|
||||||
|
#include "ops/op_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace ops {
|
||||||
|
constexpr auto kNameZeros = "Zeros";
|
||||||
|
class Zeros : public PrimitiveC {
|
||||||
|
public:
|
||||||
|
Zeros() : PrimitiveC(kNameZeros) {}
|
||||||
|
~Zeros() = default;
|
||||||
|
MS_DECLARE_PARENT(Zeros, PrimitiveC);
|
||||||
|
void Init() {}
|
||||||
|
};
|
||||||
|
} // namespace ops
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // MINDSPORE_CORE_OPS_ZEROS_H_
|
|
@ -442,8 +442,8 @@ TypePtr CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, Ty
|
||||||
auto type = types.begin()->second;
|
auto type = types.begin()->second;
|
||||||
MS_EXCEPTION_IF_NULL(type);
|
MS_EXCEPTION_IF_NULL(type);
|
||||||
if (!type->isa<TensorType>()) {
|
if (!type->isa<TensorType>()) {
|
||||||
MS_EXCEPTION(TypeError) << "The " << prim_name << "'s" << types.begin()->first << " input must be a tensor but got "
|
MS_EXCEPTION(TypeError) << "The " << prim_name << "'s " << types.begin()->first
|
||||||
<< type->ToString();
|
<< " input must be a tensor but got " << type->ToString();
|
||||||
}
|
}
|
||||||
TypePtr check_type = _CheckTypeSame(types, prim_name, false);
|
TypePtr check_type = _CheckTypeSame(types, prim_name, false);
|
||||||
return CheckTypeValid(types.begin()->first, check_type, check_list, prim_name);
|
return CheckTypeValid(types.begin()->first, check_type, check_list, prim_name);
|
||||||
|
@ -599,4 +599,27 @@ void CheckAndConvertUtils::CheckMode(const std::string &class_name) {
|
||||||
MS_EXCEPTION(NotSupportError) << class_name << "operator does not support PyNative mode.";
|
MS_EXCEPTION(NotSupportError) << class_name << "operator does not support PyNative mode.";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> CheckAndConvertUtils::CheckAttrIntOrTupleInt(const std::string &arg_name, const ValuePtr &attr,
|
||||||
|
const std::string &prim_name) {
|
||||||
|
std::vector<int64_t> result;
|
||||||
|
MS_EXCEPTION_IF_NULL(attr);
|
||||||
|
if (attr->isa<ValueTuple>()) {
|
||||||
|
std::vector<ValuePtr> attr_vec = attr->cast<ValueTuplePtr>()->value();
|
||||||
|
(void)std::transform(
|
||||||
|
attr_vec.begin(), attr_vec.end(), std::back_inserter(result), [=](const ValuePtr &e) -> int64_t {
|
||||||
|
if (!e->isa<Int64Imm>()) {
|
||||||
|
MS_EXCEPTION(TypeError) << "For " << prim_name << ", the type of" << arg_name << " must be Int64";
|
||||||
|
}
|
||||||
|
return GetValue<int64_t>(e);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
if (!attr->isa<Int64Imm>()) {
|
||||||
|
MS_EXCEPTION(TypeError) << "For " << prim_name << ", the type of" << arg_name << " must be Int64";
|
||||||
|
}
|
||||||
|
int64_t attr_val = attr->cast<Int64ImmPtr>()->value();
|
||||||
|
result.push_back(attr_val);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -321,6 +321,8 @@ class CheckAndConvertUtils {
|
||||||
static void CheckSummaryParam(const AbstractBasePtr &name, const AbstractBasePtr &value,
|
static void CheckSummaryParam(const AbstractBasePtr &name, const AbstractBasePtr &value,
|
||||||
const std::string &class_name);
|
const std::string &class_name);
|
||||||
static void CheckMode(const std::string &class_name);
|
static void CheckMode(const std::string &class_name);
|
||||||
|
static std::vector<int64_t> CheckAttrIntOrTupleInt(const std::string &prim_name, const ValuePtr &attr,
|
||||||
|
const std::string &arg_name);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static bool IsEqualVector(const std::vector<int64_t> &vec_1, const std::vector<int64_t> &vec_2);
|
static bool IsEqualVector(const std::vector<int64_t> &vec_1, const std::vector<int64_t> &vec_2);
|
||||||
|
|
|
@ -17,8 +17,10 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std::vector<int64_t> &shape) {
|
tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape) {
|
||||||
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape);
|
MS_EXCEPTION_IF_NULL(type_ptr);
|
||||||
|
auto type_id = ExtractTypeId(type_ptr);
|
||||||
|
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape);
|
||||||
size_t mem_size = IntToSize(tensor->ElementsNum());
|
size_t mem_size = IntToSize(tensor->ElementsNum());
|
||||||
auto tensor_data = tensor->data_c();
|
auto tensor_data = tensor->data_c();
|
||||||
char *data = reinterpret_cast<char *>(tensor_data);
|
char *data = reinterpret_cast<char *>(tensor_data);
|
||||||
|
@ -28,8 +30,10 @@ tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std
|
||||||
return tensor;
|
return tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std::vector<int64_t> &shape) {
|
tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape) {
|
||||||
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape);
|
MS_EXCEPTION_IF_NULL(type_ptr);
|
||||||
|
auto type_id = ExtractTypeId(type_ptr);
|
||||||
|
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape);
|
||||||
size_t mem_size = IntToSize(tensor->ElementsNum());
|
size_t mem_size = IntToSize(tensor->ElementsNum());
|
||||||
if (tensor->data_type() == kNumberTypeFloat32) {
|
if (tensor->data_type() == kNumberTypeFloat32) {
|
||||||
SetTensorData<float>(tensor->data_c(), 1.0, mem_size);
|
SetTensorData<float>(tensor->data_c(), 1.0, mem_size);
|
||||||
|
@ -39,8 +43,18 @@ tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std:
|
||||||
return tensor;
|
return tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
tensor::TensorPtr TensorConstructUtils::CreateTensor(TypeId type, const std::vector<int64_t> &shape, void *data) {
|
tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr type_ptr, const std::vector<int64_t> &shape,
|
||||||
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape, data, type);
|
void *data) {
|
||||||
|
MS_EXCEPTION_IF_NULL(type_ptr);
|
||||||
|
auto type_id = ExtractTypeId(type_ptr);
|
||||||
|
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, shape, data, type_id);
|
||||||
return tensor;
|
return tensor;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TypeId TensorConstructUtils::ExtractTypeId(const TypePtr type_ptr) {
|
||||||
|
MS_EXCEPTION_IF_NULL(type_ptr);
|
||||||
|
auto tensor_type = type_ptr->cast<TensorTypePtr>();
|
||||||
|
auto type_id = tensor_type->element()->type_id();
|
||||||
|
return type_id;
|
||||||
|
}
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -30,9 +30,10 @@ void SetTensorData(void *data, T num, size_t data_length) {
|
||||||
}
|
}
|
||||||
class TensorConstructUtils {
|
class TensorConstructUtils {
|
||||||
public:
|
public:
|
||||||
static tensor::TensorPtr CreateZerosTensor(TypeId type, const std::vector<int64_t> &shape);
|
static tensor::TensorPtr CreateZerosTensor(const TypePtr type, const std::vector<int64_t> &shape);
|
||||||
static tensor::TensorPtr CreateOnesTensor(TypeId type, const std::vector<int64_t> &shape);
|
static tensor::TensorPtr CreateOnesTensor(const TypePtr type, const std::vector<int64_t> &shape);
|
||||||
static tensor::TensorPtr CreateTensor(TypeId type, const std::vector<int64_t> &shape, void *data);
|
static tensor::TensorPtr CreateTensor(const TypePtr type, const std::vector<int64_t> &shape, void *data);
|
||||||
|
static TypeId ExtractTypeId(const TypePtr type);
|
||||||
};
|
};
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
#endif // MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_
|
#endif // MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_
|
||||||
|
|
|
@ -1342,27 +1342,6 @@ class Zeros(PrimitiveWithInfer):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""Initialize Zeros"""
|
"""Initialize Zeros"""
|
||||||
|
|
||||||
def __infer__(self, dims, dtype):
|
|
||||||
if isinstance(dims['value'], int):
|
|
||||||
shape = (dims['value'],)
|
|
||||||
else:
|
|
||||||
shape = dims['value']
|
|
||||||
validator.check_value_type("shape", shape, [tuple], self.name)
|
|
||||||
for i, item in enumerate(shape):
|
|
||||||
validator.check_non_negative_int(item, shape[i], self.name)
|
|
||||||
valid_types = [mstype.bool_, mstype.int8, mstype.int16, mstype.int32, mstype.int64,
|
|
||||||
mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64,
|
|
||||||
mstype.float16, mstype.float32, mstype.float64]
|
|
||||||
validator.check_types_same_and_valid({"value": dtype['value']}, valid_types, self.name)
|
|
||||||
x_nptype = mstype.dtype_to_nptype(dtype['value'])
|
|
||||||
ret = np.zeros(shape, x_nptype)
|
|
||||||
out = {
|
|
||||||
'value': Tensor(ret),
|
|
||||||
'shape': shape,
|
|
||||||
'dtype': x_nptype,
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class OnesLike(PrimitiveWithInfer):
|
class OnesLike(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
|
@ -5193,30 +5172,6 @@ class GatherD(PrimitiveWithInfer):
|
||||||
"""Initialize GatherD"""
|
"""Initialize GatherD"""
|
||||||
self.init_prim_io_names(inputs=['x', 'dim', 'index'], outputs=['output'])
|
self.init_prim_io_names(inputs=['x', 'dim', 'index'], outputs=['output'])
|
||||||
|
|
||||||
def __infer__(self, x, dim, index):
|
|
||||||
validator.check_subclass("x", x['dtype'], mstype.tensor, self.name)
|
|
||||||
validator.check_tensor_dtype_valid("index", index['dtype'], [mstype.int32, mstype.int64], self.name)
|
|
||||||
validator.check_subclass("dim", dim['dtype'], [mstype.int32, mstype.int64], self.name)
|
|
||||||
x_shp = x['shape']
|
|
||||||
idx_shp = index['shape']
|
|
||||||
x_rank = len(x_shp)
|
|
||||||
idx_rank = len(idx_shp)
|
|
||||||
validator.check("x_rank, idx_rank", x_rank, "expected", idx_rank, Rel.EQ, self.name)
|
|
||||||
dim_v = dim['value']
|
|
||||||
validator.check("dim value", dim_v, "expected", -x_rank, Rel.GE, self.name)
|
|
||||||
validator.check("dim value", dim_v, "expected", x_rank, Rel.LT, self.name)
|
|
||||||
if dim_v < 0:
|
|
||||||
dim['value'] = dim_v + x_rank
|
|
||||||
for i in range(x_rank):
|
|
||||||
if i == dim['value']:
|
|
||||||
continue
|
|
||||||
validator.check("x_shp[{0}], idx_shp[{0}]".format(i), x_shp[i], "expected", idx_shp[i], Rel.EQ, self.name)
|
|
||||||
|
|
||||||
out = {'shape': index['shape'],
|
|
||||||
'dtype': x['dtype'],
|
|
||||||
'value': None}
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class Identity(PrimitiveWithInfer):
|
class Identity(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -89,17 +89,6 @@ class ScalarSummary(PrimitiveWithInfer):
|
||||||
"""init"""
|
"""init"""
|
||||||
self.add_prim_attr("side_effect_io", True)
|
self.add_prim_attr("side_effect_io", True)
|
||||||
|
|
||||||
def __infer__(self, name, value):
|
|
||||||
_check_summary_param(name, value, self.__class__.__name__)
|
|
||||||
|
|
||||||
v_shape = value['shape']
|
|
||||||
# In the summary, the value whose shape is [1] is also considered as a scalar.
|
|
||||||
if v_shape and v_shape != [1]:
|
|
||||||
raise ValueError(f"For 'value' the type should be scalar, "
|
|
||||||
f"shape should be [] or [1] in {self.__class__.__name__}, but got {v_shape}.")
|
|
||||||
|
|
||||||
return SUMMARY_RETURN_VALUE
|
|
||||||
|
|
||||||
|
|
||||||
class ImageSummary(PrimitiveWithInfer):
|
class ImageSummary(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
|
@ -191,17 +180,6 @@ class TensorSummary(PrimitiveWithInfer):
|
||||||
"""init"""
|
"""init"""
|
||||||
self.add_prim_attr("side_effect_io", True)
|
self.add_prim_attr("side_effect_io", True)
|
||||||
|
|
||||||
def __infer__(self, name, value):
|
|
||||||
_check_summary_param(name, value, self.__class__.__name__)
|
|
||||||
|
|
||||||
v_shape = value['shape']
|
|
||||||
# In the summary, the value whose shape is [] is not considered as a tensor.
|
|
||||||
if not v_shape:
|
|
||||||
raise ValueError(f"For 'value' the type should be tensor in {self.__class__.__name__}, "
|
|
||||||
f"shape should not be [].")
|
|
||||||
|
|
||||||
return SUMMARY_RETURN_VALUE
|
|
||||||
|
|
||||||
|
|
||||||
class HistogramSummary(PrimitiveWithInfer):
|
class HistogramSummary(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue