[Lite] add standard normal mapper and fix bugs

This commit is contained in:
xupan 2023-02-06 15:40:04 +08:00
parent 7cddb2c437
commit c583a82b00
5 changed files with 97 additions and 4 deletions

View File

@ -651,9 +651,9 @@ REG_ADPT_DESC(Erfinv, prim::kPrimErfinv->name(), ADPT_DESC(Erfinv))
// ArgMin
INPUT_MAP(ArgMin) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(dimension)}};
ATTR_INPUT_MAP(ArgMin) = {{"axis", "dimension"}};
ATTR_MAP(ArgMin) = {{"output_dtype", ATTR_DESC(dtype, AnyTraits<GEType>())}};
ATTR_MAP(ArgMin) = {{"output_type", ATTR_DESC(dtype, AnyTraits<GEType>())}};
OUTPUT_MAP(ArgMin) = {{0, OUTPUT_DESC(y)}};
REG_ADPT_DESC(ArgMin, kArgMinOpName, ADPT_DESC(ArgMin))
REG_ADPT_DESC(ArgMin, kArgminOpName, ADPT_DESC(ArgMin))
REG_ADPT_DESC(ArgMinD, kArgMinDOpName, ADPT_DESC(ArgMin))
// Threshold

View File

@ -868,7 +868,7 @@ int BenchmarkUnifiedApi::MarkPerformance() {
MS_LOG(INFO) << "Running benchmark loops...";
std::cout << "Running benchmark loops..." << std::endl;
uint64_t time_min = 1000000;
uint64_t time_min = UINT64_MAX;
uint64_t time_max = 0;
uint64_t time_avg = 0;

View File

@ -0,0 +1,55 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/adapter/acl/mapper/standard_normal_mapper.h"
#include <memory>
#include "tools/converter/adapter/acl/mapper/primitive_mapper_register.h"
#include "src/common/log_util.h"
#include "ops/op_utils.h"
#include "tools/converter/adapter/acl/common/utils.h"
namespace mindspore {
namespace lite {
namespace {
constexpr size_t kNameInputNum = 2;
constexpr size_t kNumFlagThree = 3;
} // namespace
STATUS StandardNormalMapper::Mapper(const CNodePtr &cnode) {
ValueNodePtr value_node = nullptr;
PrimitivePtr src_prim = nullptr;
if (GetValueNodeAndPrimFromCnode(cnode, &value_node, &src_prim) != lite::RET_OK) {
MS_LOG(ERROR) << "Get primitive from cnode failed.";
return lite::RET_ERROR;
}
if (cnode->size() != kNameInputNum) {
MS_LOG(ERROR) << "Input size of StandardNormal must be " << (kNameInputNum - 1)
<< " real size: " << (cnode->size() - 1);
return lite::RET_ERROR;
}
ops::StandardNormal std_norm;
auto dst_prim = std_norm.GetPrim();
MSLITE_CHECK_PTR(dst_prim);
dst_prim->AddAttr("dtype", TypeIdToType(acl::GetTypeFromNode(cnode)));
dst_prim->SetAttrs(src_prim->attrs());
value_node->set_value(dst_prim);
return lite::RET_OK;
}
REGISTER_PRIMITIVE_MAPPER(kNameStandardNormal, StandardNormalMapper)
} // namespace lite
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_STANDARD_NORMAL_MAPPER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_STANDARD_NORMAL_MAPPER_H_
#include "tools/converter/adapter/acl/mapper/primitive_mapper.h"
#include "ops/standard_normal.h"
namespace mindspore {
namespace lite {
using mindspore::ops::kNameStandardNormal;
class StandardNormalMapper : public PrimitiveMapper {
public:
StandardNormalMapper() : PrimitiveMapper(kNameStandardNormal) {}
~StandardNormalMapper() override = default;
STATUS Mapper(const CNodePtr &cnode) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_ADAPTER_ACL_MAPPER_STANDARD_NORMAL_MAPPER_H_

View File

@ -30,6 +30,7 @@
#include "ops/custom.h"
#include "ops/op_utils.h"
#include "ops/transpose.h"
#include "ops/standard_normal.h"
#include "ops/tuple_get_item.h"
#include "mindspore/core/ops/core_ops.h"
#include "cxx_api/model/acl/model_converter.h"
@ -274,7 +275,7 @@ STATUS AclPassImpl::MapperForOrgMindIR(const FuncGraphPtr &func_graph) {
std::set<FuncGraphPtr> all_func_graphs = {};
lite::GetAllFuncGraph(func_graph, &all_func_graphs);
std::set<std::string> mindir_mapper = {ops::kNameTranspose};
std::set<std::string> mindir_mapper = {ops::kNameTranspose, ops::kNameStandardNormal};
for (auto graph : all_func_graphs) {
auto node_list = TopoSort(graph->get_return());
for (auto &node : node_list) {