forked from mindspore-Ecosystem/mindspore
[Lite] add standard normal mapper and fix bugs
This commit is contained in:
parent
7cddb2c437
commit
c583a82b00
|
@ -651,9 +651,9 @@ REG_ADPT_DESC(Erfinv, prim::kPrimErfinv->name(), ADPT_DESC(Erfinv))
|
|||
// ArgMin
|
||||
INPUT_MAP(ArgMin) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(dimension)}};
|
||||
ATTR_INPUT_MAP(ArgMin) = {{"axis", "dimension"}};
|
||||
ATTR_MAP(ArgMin) = {{"output_dtype", ATTR_DESC(dtype, AnyTraits<GEType>())}};
|
||||
ATTR_MAP(ArgMin) = {{"output_type", ATTR_DESC(dtype, AnyTraits<GEType>())}};
|
||||
OUTPUT_MAP(ArgMin) = {{0, OUTPUT_DESC(y)}};
|
||||
REG_ADPT_DESC(ArgMin, kArgMinOpName, ADPT_DESC(ArgMin))
|
||||
REG_ADPT_DESC(ArgMin, kArgminOpName, ADPT_DESC(ArgMin))
|
||||
REG_ADPT_DESC(ArgMinD, kArgMinDOpName, ADPT_DESC(ArgMin))
|
||||
|
||||
// Threshold
|
||||
|
|
|
@ -868,7 +868,7 @@ int BenchmarkUnifiedApi::MarkPerformance() {
|
|||
|
||||
MS_LOG(INFO) << "Running benchmark loops...";
|
||||
std::cout << "Running benchmark loops..." << std::endl;
|
||||
uint64_t time_min = 1000000;
|
||||
uint64_t time_min = UINT64_MAX;
|
||||
uint64_t time_max = 0;
|
||||
uint64_t time_avg = 0;
|
||||
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/adapter/acl/mapper/standard_normal_mapper.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
|
||||
#include "src/common/log_util.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "tools/converter/adapter/acl/common/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
namespace {
|
||||
constexpr size_t kNameInputNum = 2;
|
||||
constexpr size_t kNumFlagThree = 3;
|
||||
} // namespace
|
||||
|
||||
STATUS StandardNormalMapper::Mapper(const CNodePtr &cnode) {
|
||||
ValueNodePtr value_node = nullptr;
|
||||
PrimitivePtr src_prim = nullptr;
|
||||
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
|
||||
MS_LOG(ERROR) << "Get primitive from cnode failed.";
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
if (cnode->size() != kNameInputNum) {
|
||||
MS_LOG(ERROR) << "Input size of StandardNormal must be " << (kNameInputNum - 1)
|
||||
<< " real size: " << (cnode->size() - 1);
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
ops::StandardNormal std_norm;
|
||||
auto dst_prim = std_norm.GetPrim();
|
||||
MSLITE_CHECK_PTR(dst_prim);
|
||||
dst_prim->AddAttr("dtype", TypeIdToType(acl::GetTypeFromNode(cnode)));
|
||||
dst_prim->SetAttrs(src_prim->attrs());
|
||||
value_node->set_value(dst_prim);
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_MAPPER(kNameStandardNormal, StandardNormalMapper)
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2023 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_STANDARD_NORMAL_MAPPER_H_
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_STANDARD_NORMAL_MAPPER_H_
|
||||
|
||||
#include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
|
||||
#include "ops/standard_normal.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
using mindspore::ops::kNameStandardNormal;
|
||||
|
||||
class StandardNormalMapper : public PrimitiveMapper {
|
||||
public:
|
||||
StandardNormalMapper() : PrimitiveMapper(kNameStandardNormal) {}
|
||||
|
||||
~StandardNormalMapper() override = default;
|
||||
|
||||
STATUS Mapper(const CNodePtr &cnode) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_STANDARD_NORMAL_MAPPER_H_
|
|
@ -30,6 +30,7 @@
|
|||
#include "ops/custom.h"
|
||||
#include "ops/op_utils.h"
|
||||
#include "ops/transpose.h"
|
||||
#include "ops/standard_normal.h"
|
||||
#include "ops/tuple_get_item.h"
|
||||
#include "mindspore/core/ops/core_ops.h"
|
||||
#include "cxx_api/model/acl/model_converter.h"
|
||||
|
@ -274,7 +275,7 @@ STATUS AclPassImpl::MapperForOrgMindIR(const FuncGraphPtr &func_graph) {
|
|||
std::set<FuncGraphPtr> all_func_graphs = {};
|
||||
lite::GetAllFuncGraph(func_graph, &all_func_graphs);
|
||||
|
||||
std::set<std::string> mindir_mapper = {ops::kNameTranspose};
|
||||
std::set<std::string> mindir_mapper = {ops::kNameTranspose, ops::kNameStandardNormal};
|
||||
for (auto graph : all_func_graphs) {
|
||||
auto node_list = TopoSort(graph->get_return());
|
||||
for (auto &node : node_list) {
|
||||
|
|
Loading…
Reference in New Issue