forked from OSSInnovation/mindspore
!3198 synchronize latest Ascend software suite 18 Jul 2020, and merging branches
Merge pull request !3198 from yanghaoran/code_sync_0718
This commit is contained in:
commit
6f8863b65d
|
@ -10,9 +10,9 @@
|
||||||
[submodule "third_party/protobuf"]
|
[submodule "third_party/protobuf"]
|
||||||
path = third_party/protobuf
|
path = third_party/protobuf
|
||||||
url = https://github.com/protocolbuffers/protobuf.git
|
url = https://github.com/protocolbuffers/protobuf.git
|
||||||
[submodule "graphengine"]
|
|
||||||
path = graphengine
|
|
||||||
url = https://gitee.com/mindspore/graphengine.git
|
|
||||||
[submodule "akg"]
|
[submodule "akg"]
|
||||||
path = akg
|
path = akg
|
||||||
url = https://gitee.com/mindspore/akg.git
|
url = https://gitee.com/mindspore/akg.git
|
||||||
|
[submodule "graphengine"]
|
||||||
|
path = graphengine
|
||||||
|
url = https://gitee.com/mindspore/graphengine.git
|
||||||
|
|
|
@ -15,6 +15,7 @@ include(${GE_SOURCE_DIR}/cmake/external_libs/securec.cmake)
|
||||||
if (NOT ENABLE_D)
|
if (NOT ENABLE_D)
|
||||||
set(GE_PREBUILD_PATH ${GE_SOURCE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR})
|
set(GE_PREBUILD_PATH ${GE_SOURCE_DIR}/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR})
|
||||||
find_library(slog libslog.so ${GE_PREBUILD_PATH})
|
find_library(slog libslog.so ${GE_PREBUILD_PATH})
|
||||||
|
find_library(error_manager liberror_manager.so ${GE_PREBUILD_PATH})
|
||||||
elseif (DEFINED ENV{D_LINK_PATH})
|
elseif (DEFINED ENV{D_LINK_PATH})
|
||||||
set(GE_LIB_PATH $ENV{D_LINK_PATH})
|
set(GE_LIB_PATH $ENV{D_LINK_PATH})
|
||||||
set(GE_SYS_ARCH "")
|
set(GE_SYS_ARCH "")
|
||||||
|
|
|
@ -156,6 +156,7 @@ if (NOT ENABLE_GE)
|
||||||
set(ASCEND_PATH /usr/local/Ascend)
|
set(ASCEND_PATH /usr/local/Ascend)
|
||||||
endif ()
|
endif ()
|
||||||
set(ASCEND_DRIVER_PATH ${ASCEND_PATH}/driver/lib64/common)
|
set(ASCEND_DRIVER_PATH ${ASCEND_PATH}/driver/lib64/common)
|
||||||
|
set(ASCEND_FWK_PATH ${ASCEND_PATH}/fwkacllib/lib64)
|
||||||
|
|
||||||
install(
|
install(
|
||||||
FILES
|
FILES
|
||||||
|
@ -164,6 +165,7 @@ if (NOT ENABLE_GE)
|
||||||
${CMAKE_BINARY_DIR}/graphengine/src/ge/ge_runtime/libge_runtime.so
|
${CMAKE_BINARY_DIR}/graphengine/src/ge/ge_runtime/libge_runtime.so
|
||||||
${ASCEND_DRIVER_PATH}/libslog.so
|
${ASCEND_DRIVER_PATH}/libslog.so
|
||||||
${ASCEND_DRIVER_PATH}/libc_sec.so
|
${ASCEND_DRIVER_PATH}/libc_sec.so
|
||||||
|
${ASCEND_FWK_PATH}/liberror_manager.so
|
||||||
DESTINATION ${INSTALL_LIB_DIR}
|
DESTINATION ${INSTALL_LIB_DIR}
|
||||||
COMPONENT mindspore
|
COMPONENT mindspore
|
||||||
)
|
)
|
||||||
|
@ -172,6 +174,7 @@ if (NOT ENABLE_GE)
|
||||||
FILES
|
FILES
|
||||||
${CMAKE_BINARY_DIR}/graphengine/src/common/graph/libgraph.so
|
${CMAKE_BINARY_DIR}/graphengine/src/common/graph/libgraph.so
|
||||||
${CMAKE_SOURCE_DIR}/graphengine/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}/libslog.so
|
${CMAKE_SOURCE_DIR}/graphengine/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}/libslog.so
|
||||||
|
${CMAKE_SOURCE_DIR}/graphengine/third_party/prebuild/${CMAKE_HOST_SYSTEM_PROCESSOR}/liberror_manager.so
|
||||||
${CMAKE_SOURCE_DIR}/build/graphengine/libc_sec.so
|
${CMAKE_SOURCE_DIR}/build/graphengine/libc_sec.so
|
||||||
DESTINATION ${INSTALL_LIB_DIR}
|
DESTINATION ${INSTALL_LIB_DIR}
|
||||||
COMPONENT mindspore
|
COMPONENT mindspore
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
Subproject commit 18cf690152add623ffbddfbbb4674d1b34484ca7
|
Subproject commit 103f2d1019dc50d781d7a964551d9f1f50b3b009
|
|
@ -40,7 +40,7 @@ def get_ddk_version():
|
||||||
with open(backup_ddk_info_file, "r") as fp:
|
with open(backup_ddk_info_file, "r") as fp:
|
||||||
ddk_version = json.load(fp)["VERSION"]
|
ddk_version = json.load(fp)["VERSION"]
|
||||||
else:
|
else:
|
||||||
ddk_version = "1.60.T17.B830"
|
ddk_version = "Ascend910"
|
||||||
return ddk_version
|
return ddk_version
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -185,7 +185,7 @@ if (ENABLE_GE)
|
||||||
else ()
|
else ()
|
||||||
target_link_libraries(mindspore ge_client)
|
target_link_libraries(mindspore ge_client)
|
||||||
endif ()
|
endif ()
|
||||||
target_link_libraries(mindspore graph tsdclient)
|
target_link_libraries(mindspore graph tsdclient datatransfer)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (ENABLE_D)
|
if (ENABLE_D)
|
||||||
|
@ -216,8 +216,9 @@ if (ENABLE_D)
|
||||||
find_library(CCE_LIB cce ${ASCEND_RUNTIME_PATH})
|
find_library(CCE_LIB cce ${ASCEND_RUNTIME_PATH})
|
||||||
find_library(RUNTIME_LIB runtime ${ASCEND_RUNTIME_PATH})
|
find_library(RUNTIME_LIB runtime ${ASCEND_RUNTIME_PATH})
|
||||||
find_library(TSDCLIENT tsdclient HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_DRIVER_BACK_PATH})
|
find_library(TSDCLIENT tsdclient HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_DRIVER_BACK_PATH})
|
||||||
|
find_library(DATATRANSFER datatransfer HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_DRIVER_BACK_PATH})
|
||||||
find_library(PROFILING msprof ${ASCEND_DRIVER_PATH})
|
find_library(PROFILING msprof ${ASCEND_DRIVER_PATH})
|
||||||
target_link_libraries(mindspore ge_runtime ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${PROFILING} ${HCCL} ${TSDCLIENT})
|
target_link_libraries(mindspore ge_runtime ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${PROFILING} ${HCCL} ${DATATRANSFER})
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# link protobuf
|
# link protobuf
|
||||||
|
|
|
@ -292,7 +292,6 @@ bool TbeKernelSelect::TbeCheckSupported(
|
||||||
parallel::TOPK,
|
parallel::TOPK,
|
||||||
parallel::IN_TOPK,
|
parallel::IN_TOPK,
|
||||||
parallel::PACK,
|
parallel::PACK,
|
||||||
parallel::GATHER_ND,
|
|
||||||
parallel::UNSORTEF_SEGMENT_MIND,
|
parallel::UNSORTEF_SEGMENT_MIND,
|
||||||
parallel::UNSORTEF_SEGMENT_PRODD,
|
parallel::UNSORTEF_SEGMENT_PRODD,
|
||||||
parallel::CAST};
|
parallel::CAST};
|
||||||
|
|
|
@ -23,6 +23,7 @@
|
||||||
#include "backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h"
|
#include "backend/optimizer/ascend/ir_fission/batch_norm_grad_split.h"
|
||||||
#include "backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h"
|
#include "backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.h"
|
||||||
#include "backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h"
|
#include "backend/optimizer/ascend/ir_fission/single_batch_norm_fission.h"
|
||||||
|
#include "backend/optimizer/ascend/ir_fission/tensor_scatter_update_fission.h"
|
||||||
#include "backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h"
|
#include "backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h"
|
||||||
#include "backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h"
|
#include "backend/optimizer/ascend/ir_fission/layer_norm_grad_split.h"
|
||||||
#include "backend/optimizer/pass/communication_op_fusion.h"
|
#include "backend/optimizer/pass/communication_op_fusion.h"
|
||||||
|
@ -154,6 +155,7 @@ void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) {
|
||||||
ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BNInferGrad>());
|
ir_fusion_pm->AddPass(std::make_shared<BatchNormGrad2BNInferGrad>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<BatchNormGradInferFission>());
|
ir_fusion_pm->AddPass(std::make_shared<BatchNormGradInferFission>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<SplitFission>());
|
ir_fusion_pm->AddPass(std::make_shared<SplitFission>());
|
||||||
|
ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
|
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<PackFission>());
|
ir_fusion_pm->AddPass(std::make_shared<PackFission>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<ConcatFission>());
|
ir_fusion_pm->AddPass(std::make_shared<ConcatFission>());
|
||||||
|
@ -303,6 +305,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
|
||||||
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
|
ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
|
ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
|
||||||
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
|
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
|
||||||
|
ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>());
|
||||||
|
|
||||||
optimizer->AddPassManager(ir_fusion_pm);
|
optimizer->AddPassManager(ir_fusion_pm);
|
||||||
(void)optimizer->Optimize(kernel_graph);
|
(void)optimizer->Optimize(kernel_graph);
|
||||||
|
|
|
@ -94,7 +94,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
|
||||||
origin_pair = FindRefOriginNode(input_node);
|
origin_pair = FindRefOriginNode(input_node);
|
||||||
MS_EXCEPTION_IF_NULL(origin_pair.first);
|
MS_EXCEPTION_IF_NULL(origin_pair.first);
|
||||||
if (!origin_pair.first->isa<Parameter>()) {
|
if (!origin_pair.first->isa<Parameter>()) {
|
||||||
MS_LOG(EXCEPTION) << "ref op origin node is not parameter";
|
MS_LOG(WARNING) << "ref op origin node is not parameter";
|
||||||
}
|
}
|
||||||
MS_LOG(DEBUG) << "DealRefTransAndCast the node input index " << input_index << ", find origin op is "
|
MS_LOG(DEBUG) << "DealRefTransAndCast the node input index " << input_index << ", find origin op is "
|
||||||
<< origin_pair.first->DebugString() << ", index is " << origin_pair.second;
|
<< origin_pair.first->DebugString() << ", index is " << origin_pair.second;
|
||||||
|
|
|
@ -0,0 +1,71 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#include "backend/optimizer/ascend/ir_fission/tensor_scatter_update_fission.h"
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
#include "backend/optimizer/common/helper.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace {
|
||||||
|
CNodePtr CreateTensorMove(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update) {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor_scatter_update);
|
||||||
|
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kTensorMoveOpName)),
|
||||||
|
tensor_scatter_update->input(1)};
|
||||||
|
auto tensor_move = graph->NewCNode(inputs);
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor_move);
|
||||||
|
tensor_move->set_scope(tensor_scatter_update->scope());
|
||||||
|
tensor_move->set_abstract(tensor_scatter_update->abstract());
|
||||||
|
AnfAlgo::SetNodeAttr(kAttrUseLocking, MakeValue(false), tensor_move);
|
||||||
|
return tensor_move;
|
||||||
|
}
|
||||||
|
|
||||||
|
CNodePtr CreateScatterNdUpdate(const FuncGraphPtr &graph, const CNodePtr &tensor_scatter_update,
|
||||||
|
const CNodePtr &tensor_move) {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor_scatter_update);
|
||||||
|
MS_EXCEPTION_IF_NULL(tensor_move);
|
||||||
|
std::vector<AnfNodePtr> inputs = {NewValueNode(std::make_shared<Primitive>(kScatterNdUpdateOpName)), tensor_move,
|
||||||
|
tensor_scatter_update->input(2), tensor_scatter_update->input(3)};
|
||||||
|
auto scatter_nd_update = graph->NewCNode(inputs);
|
||||||
|
MS_EXCEPTION_IF_NULL(scatter_nd_update);
|
||||||
|
scatter_nd_update->set_scope(tensor_scatter_update->scope());
|
||||||
|
scatter_nd_update->set_abstract(tensor_scatter_update->abstract());
|
||||||
|
return scatter_nd_update;
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
const BaseRef TensorScatterUpdateFission::DefinePattern() const {
|
||||||
|
VarPtr Xs = std::make_shared<SeqVar>();
|
||||||
|
auto prim = std::make_shared<Primitive>(kTensorScatterUpdateOpName);
|
||||||
|
return VectorRef({prim, Xs});
|
||||||
|
}
|
||||||
|
|
||||||
|
const AnfNodePtr TensorScatterUpdateFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||||
|
const EquivPtr &) const {
|
||||||
|
MS_EXCEPTION_IF_NULL(func_graph);
|
||||||
|
MS_EXCEPTION_IF_NULL(node);
|
||||||
|
auto tensor_scatter_update = node->cast<CNodePtr>();
|
||||||
|
if (tensor_scatter_update == nullptr || tensor_scatter_update->size() != 4) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto tensor_move = CreateTensorMove(func_graph, tensor_scatter_update);
|
||||||
|
return CreateScatterNdUpdate(func_graph, tensor_scatter_update, tensor_move);
|
||||||
|
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,33 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_
|
||||||
|
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_
|
||||||
|
|
||||||
|
#include "backend/optimizer/common/optimizer.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class TensorScatterUpdateFission : public PatternProcessPass {
|
||||||
|
public:
|
||||||
|
explicit TensorScatterUpdateFission(bool multigraph = true)
|
||||||
|
: PatternProcessPass("tensor_scatter_update_fission", multigraph) {}
|
||||||
|
~TensorScatterUpdateFission() override = default;
|
||||||
|
const BaseRef DefinePattern() const override;
|
||||||
|
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FISSION_TENSOR_SCATTER_UPDATE_FISSION_H_
|
|
@ -0,0 +1,707 @@
|
||||||
|
/**
|
||||||
|
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||||
|
*
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "pipeline/static_analysis/prim.h"
|
||||||
|
#include "pipeline/static_analysis/utils.h"
|
||||||
|
#include "pipeline/static_analysis/param_validator.h"
|
||||||
|
#include "operator/ops.h"
|
||||||
|
#include "utils/convert_utils.h"
|
||||||
|
#include "ir/tensor_py.h"
|
||||||
|
|
||||||
|
using mindspore::tensor::TensorPy;
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace abstract {
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplStringEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: two scalars whose value is a string.
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 2);
|
||||||
|
AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
|
||||||
|
AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||||
|
|
||||||
|
ValuePtr value_x = scalar_x->BuildValue();
|
||||||
|
ValuePtr value_y = scalar_y->BuildValue();
|
||||||
|
if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
|
||||||
|
MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
|
||||||
|
<< ", param1: " << value_y->ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ret = (value_x->cast<StringImmPtr>()->value() == value_y->cast<StringImmPtr>()->value());
|
||||||
|
return std::make_shared<AbstractScalar>(ret);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplStringConcat(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: two scalars whose value is a string.
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 2);
|
||||||
|
AbstractScalarPtr scalar_x = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
|
||||||
|
AbstractScalarPtr scalar_y = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||||
|
|
||||||
|
ValuePtr value_x = scalar_x->BuildValue();
|
||||||
|
ValuePtr value_y = scalar_y->BuildValue();
|
||||||
|
if (!value_x->isa<StringImm>() || !value_y->isa<StringImm>()) {
|
||||||
|
MS_LOG(EXCEPTION) << op_name << " requires 2 parameters are string, but got param0: " << value_x->ToString()
|
||||||
|
<< ", param1: " << value_y->ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string ret = (value_x->cast<StringImmPtr>()->value() + value_y->cast<StringImmPtr>()->value());
|
||||||
|
return std::make_shared<AbstractScalar>(ret);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplMakeTuple(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
return std::make_shared<AbstractTuple>(args_spec_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplMakeList(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
return std::make_shared<AbstractList>(args_spec_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplMakeDict(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: two tuples.
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 2);
|
||||||
|
AbstractTuplePtr keys = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||||
|
AbstractTuplePtr values = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
|
||||||
|
|
||||||
|
size_t keys_size = keys->size();
|
||||||
|
if (values->size() != keys_size) {
|
||||||
|
MS_LOG(EXCEPTION) << op_name << " evaluator keys' size is not equal with values' size";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<AbstractAttribute> key_value;
|
||||||
|
AbstractScalarPtr key;
|
||||||
|
AbstractBasePtrList key_list = keys->elements();
|
||||||
|
AbstractBasePtrList value_list = values->elements();
|
||||||
|
for (size_t index = 0; index < keys_size; index++) {
|
||||||
|
key = CheckArg<AbstractScalar>(op_name + "key", key_list, index);
|
||||||
|
ValuePtr keyPtr = key->BuildValue();
|
||||||
|
MS_EXCEPTION_IF_NULL(keyPtr);
|
||||||
|
if (!keyPtr->isa<StringImm>()) {
|
||||||
|
MS_LOG(EXCEPTION) << op_name << " evaluator keys should be string, but got " << keyPtr->ToString();
|
||||||
|
}
|
||||||
|
std::string key_string = GetValue<std::string>(keyPtr);
|
||||||
|
key_value.emplace_back(key_string, value_list[index]);
|
||||||
|
}
|
||||||
|
return std::make_shared<AbstractDictionary>(key_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplMakeKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: a string and an object of a subclass of AbstractBase.
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 2);
|
||||||
|
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
|
||||||
|
|
||||||
|
ValuePtr keyPtr = key->BuildValue();
|
||||||
|
if (!keyPtr->isa<StringImm>()) {
|
||||||
|
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << keyPtr->ToString();
|
||||||
|
}
|
||||||
|
std::string key_string = GetValue<std::string>(keyPtr);
|
||||||
|
return std::make_shared<AbstractKeywordArg>(key_string, args_spec_list[1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplExtractKwarg(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: a string and a keyword.
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 2);
|
||||||
|
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 0);
|
||||||
|
AbstractKeywordArgPtr kwarg = CheckArg<AbstractKeywordArg>(op_name, args_spec_list, 1);
|
||||||
|
|
||||||
|
ValuePtr key_value = key->BuildValue();
|
||||||
|
if (!key_value->isa<StringImm>()) {
|
||||||
|
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
|
||||||
|
}
|
||||||
|
std::string key_input = GetValue<std::string>(key_value);
|
||||||
|
std::string key_actual = kwarg->get_key();
|
||||||
|
if (key_actual != key_input) {
|
||||||
|
MS_LOG(EXCEPTION) << op_name << " evaluator input key should be same as AbstractKeywordArg' key, but input is "
|
||||||
|
<< key_input << ", AbstractKeywordArg' key is " << key_actual;
|
||||||
|
}
|
||||||
|
return kwarg->get_arg();
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: three scalars whose value is an int32 number.
|
||||||
|
CheckArgsSize(primitive->name(), args_spec_list, 3);
|
||||||
|
size_t args_size = args_spec_list.size();
|
||||||
|
for (size_t index = 0; index < args_size; index++) {
|
||||||
|
MS_EXCEPTION_IF_NULL(args_spec_list[index]);
|
||||||
|
if (!args_spec_list[index]->isa<AbstractScalar>() && !args_spec_list[index]->isa<AbstractNone>()) {
|
||||||
|
MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is neither AbstractScalar nor AbstractNone.";
|
||||||
|
}
|
||||||
|
if (args_spec_list[index]->isa<AbstractScalar>() &&
|
||||||
|
!dyn_cast<AbstractScalar>(args_spec_list[index])->BuildValue()->isa<Int32Imm>()) {
|
||||||
|
MS_LOG(EXCEPTION) << "MakeSlice eval " << index << " parameter is an AbstractScalar, but is not an int32 number.";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Slice: start, end, step
|
||||||
|
return std::make_shared<AbstractSlice>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Eval the return type of make_record
|
||||||
|
AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: at lease two objects of a subclass of AbstractBase.
|
||||||
|
if (args_spec_list.size() < 2) {
|
||||||
|
MS_LOG(EXCEPTION) << "Typeof evaluator requires more than 1 parameter, while the input size is "
|
||||||
|
<< args_spec_list.size() << ".";
|
||||||
|
}
|
||||||
|
|
||||||
|
// args_spec_list[0] maybe AbstractScalarPtr or AbstractTypePtr
|
||||||
|
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||||
|
TypePtr type = args_spec_list[0]->GetTypeTrack();
|
||||||
|
MS_EXCEPTION_IF_NULL(type);
|
||||||
|
if (type->type_id() != kMetaTypeTypeType) {
|
||||||
|
MS_LOG(EXCEPTION) << "Can not make type(" << type->ToString() << ")not TypeType";
|
||||||
|
}
|
||||||
|
|
||||||
|
ValuePtr value_track = args_spec_list[0]->GetValueTrack();
|
||||||
|
MS_EXCEPTION_IF_NULL(value_track);
|
||||||
|
TypePtr type_ptr = value_track->cast<TypePtr>();
|
||||||
|
if (type_ptr == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "Value type error, not Me type:" << value_track->ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto cls = dyn_cast<Class>(type_ptr);
|
||||||
|
MS_EXCEPTION_IF_NULL(cls);
|
||||||
|
ClassAttrVector attributes = cls->GetAttributes();
|
||||||
|
CheckArgsSize(primitive->name(), args_spec_list, attributes.size() + 1);
|
||||||
|
|
||||||
|
std::vector<AbstractAttribute> abs_attributes;
|
||||||
|
for (size_t i = 0; i < attributes.size(); i++) {
|
||||||
|
AbstractAttribute elem(attributes[i].first, args_spec_list[i + 1]);
|
||||||
|
abs_attributes.push_back(elem);
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_shared<AbstractClass>(cls->tag(), abs_attributes, cls->methods());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
AbstractBasePtr InferTupleOrListGetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: a tuple or list and a scalar whose value is an int32 number.
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 2);
|
||||||
|
auto queue = CheckArg<T>(op_name, args_spec_list, 0);
|
||||||
|
AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||||
|
|
||||||
|
ValuePtr index_value = index->BuildValue();
|
||||||
|
if (!index_value->isa<Int32Imm>()) {
|
||||||
|
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
|
||||||
|
<< index_value->ToString();
|
||||||
|
}
|
||||||
|
int idx_v = GetValue<int>(index_value);
|
||||||
|
std::size_t nelems = queue->elements().size();
|
||||||
|
if (idx_v >= SizeToInt(nelems) || idx_v < -SizeToInt(nelems)) {
|
||||||
|
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be in range[-" << SizeToInt(nelems) << ", "
|
||||||
|
<< SizeToInt(nelems) << "), but got " << idx_v << ".";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::size_t uidx_v = 0;
|
||||||
|
if (idx_v >= 0) {
|
||||||
|
uidx_v = IntToSize(idx_v);
|
||||||
|
} else {
|
||||||
|
uidx_v = IntToSize(idx_v + SizeToInt(nelems));
|
||||||
|
}
|
||||||
|
return queue->elements()[uidx_v];
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
AbstractBasePtr InferTupleOrListSetItem(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: a tuple or list, a scalar whose value is an int32 number and an object of a subclass of AbstractBase.
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 3);
|
||||||
|
auto queue = CheckArg<T>(op_name, args_spec_list, 0);
|
||||||
|
AbstractScalarPtr index = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||||
|
|
||||||
|
ValuePtr index_value = index->BuildValue();
|
||||||
|
if (!index_value->isa<Int32Imm>()) {
|
||||||
|
MS_EXCEPTION(IndexError) << op_name << " evaluator index should be an int32 number, but got "
|
||||||
|
<< index_value->ToString();
|
||||||
|
}
|
||||||
|
int idx_v = GetValue<int>(index_value);
|
||||||
|
if (idx_v < 0) {
|
||||||
|
MS_EXCEPTION(IndexError) << "The index of " << typeid(T).name() << " should be positive number, but got " << idx_v
|
||||||
|
<< ".";
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t uidx_v = IntToSize(idx_v);
|
||||||
|
AbstractBasePtrList elements = queue->elements();
|
||||||
|
std::size_t nelems = elements.size();
|
||||||
|
if (uidx_v >= nelems) {
|
||||||
|
MS_EXCEPTION(IndexError) << op_name << " evaluator the index: " << uidx_v << " to set out of range: " << nelems - 1
|
||||||
|
<< ".";
|
||||||
|
}
|
||||||
|
elements[uidx_v] = args_spec_list[2];
|
||||||
|
return std::make_shared<T>(elements);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplTupleGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
return InferTupleOrListGetItem<AbstractTuple>(primitive->name(), args_spec_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplListGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
return InferTupleOrListGetItem<AbstractList>(primitive->name(), args_spec_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplTupleSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
return InferTupleOrListSetItem<AbstractTuple>(primitive->name(), args_spec_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplListSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
return InferTupleOrListSetItem<AbstractList>(primitive->name(), args_spec_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplDictGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: a dict and a scalar whose value is a string.
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 2);
|
||||||
|
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
|
||||||
|
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||||
|
|
||||||
|
ValuePtr key_value = key->BuildValue();
|
||||||
|
if (!key_value->isa<StringImm>()) {
|
||||||
|
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
|
||||||
|
}
|
||||||
|
auto key_str = GetValue<std::string>(key_value);
|
||||||
|
std::vector<AbstractAttribute> dict_elems = dict->elements();
|
||||||
|
auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
|
||||||
|
[key_str](const AbstractAttribute &item) { return item.first == key_str; });
|
||||||
|
|
||||||
|
if (it == dict_elems.end()) {
|
||||||
|
MS_LOG(EXCEPTION) << "The key " << key_str << " does not exist in the dict:" << args_spec_list[0]->ToString();
|
||||||
|
}
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplDictSetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: a dict and a scalar whose value is a string and an object of a subclass of AbstractBase.
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 3);
|
||||||
|
AbstractDictionaryPtr dict = CheckArg<AbstractDictionary>(op_name, args_spec_list, 0);
|
||||||
|
AbstractScalarPtr key = CheckArg<AbstractScalar>(op_name, args_spec_list, 1);
|
||||||
|
|
||||||
|
ValuePtr key_value = key->BuildValue();
|
||||||
|
if (!key_value->isa<StringImm>()) {
|
||||||
|
MS_LOG(EXCEPTION) << op_name << " evaluator key should be string, but got " << key_value->ToString();
|
||||||
|
}
|
||||||
|
std::string key_str = GetValue<std::string>(key_value);
|
||||||
|
std::vector<AbstractAttribute> dict_elems = dict->elements();
|
||||||
|
auto it = std::find_if(dict_elems.begin(), dict_elems.end(),
|
||||||
|
[key_str](AbstractAttribute &item) { return item.first == key_str; });
|
||||||
|
|
||||||
|
MS_EXCEPTION_IF_NULL(args_spec_list[2]);
|
||||||
|
auto new_ele = std::make_pair(key_str, args_spec_list[2]);
|
||||||
|
if (it != dict_elems.end()) {
|
||||||
|
int index = it - dict_elems.begin();
|
||||||
|
dict_elems[IntToSize(index)] = new_ele;
|
||||||
|
} else {
|
||||||
|
dict_elems.push_back(new_ele);
|
||||||
|
}
|
||||||
|
return std::make_shared<AbstractDictionary>(dict_elems);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplListAppend(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: a list and an object of a subclass of AbstractBase.
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 2);
|
||||||
|
AbstractListPtr list = CheckArg<AbstractList>(op_name, args_spec_list, 0);
|
||||||
|
(void)AbstractJoin(list->elements());
|
||||||
|
return list;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
AbstractBasePtr InferTupleOrListOrDictLen(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: a tuple or list or dict.
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 1);
|
||||||
|
auto arg = CheckArg<T>(op_name, args_spec_list, 0);
|
||||||
|
return std::make_shared<AbstractScalar>(SizeToInt(arg->size()));
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplTupleLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
return InferTupleOrListOrDictLen<AbstractTuple>(primitive->name(), args_spec_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplListLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
return InferTupleOrListOrDictLen<AbstractList>(primitive->name(), args_spec_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
return InferTupleOrListOrDictLen<AbstractDictionary>(primitive->name(), args_spec_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplArrayLen(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
return std::make_shared<AbstractScalar>(kAnyValue, kInt32);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplListMap(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: fn, list1, list2, ...
|
||||||
|
MS_EXCEPTION_IF_NULL(engine);
|
||||||
|
if (args_spec_list.size() <= 1) {
|
||||||
|
MS_LOG(EXCEPTION) << "List_map requires at least 1 list. while the input size is " << args_spec_list.size() << ".";
|
||||||
|
}
|
||||||
|
AbstractFunctionPtr fn = CheckArg<AbstractFunction>(primitive->name(), args_spec_list, 0);
|
||||||
|
// check args from 1.
|
||||||
|
CheckArgsSpec<AbstractList>(AbstractBasePtrList(args_spec_list.begin() + 1, args_spec_list.end()));
|
||||||
|
|
||||||
|
AbstractBasePtrList subargs;
|
||||||
|
for (std::size_t i = 1; i < args_spec_list.size(); i++) {
|
||||||
|
AbstractListPtr l_ptr = dyn_cast<AbstractList>(args_spec_list[i]);
|
||||||
|
if (l_ptr == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "Argument[" << i << "] of list_map should be a list.";
|
||||||
|
}
|
||||||
|
subargs.push_back(AbstractJoin(l_ptr->elements()));
|
||||||
|
}
|
||||||
|
EvalResultPtr engin_exc = engine->Execute(fn, subargs);
|
||||||
|
AbstractBasePtrList result;
|
||||||
|
for (std::size_t i = 1; i < args_spec_list.size(); i++) {
|
||||||
|
result.push_back(engin_exc->abstract());
|
||||||
|
}
|
||||||
|
return std::make_shared<AbstractList>(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: a fn, a list and an object of a subclass of a AbstractBase.
|
||||||
|
MS_EXCEPTION_IF_NULL(engine);
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 3);
|
||||||
|
AbstractFunctionPtr fn = CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
|
||||||
|
AbstractListPtr lst = CheckArg<AbstractList>(op_name, args_spec_list, 1);
|
||||||
|
AbstractBasePtr dflt = args_spec_list[2];
|
||||||
|
|
||||||
|
AbstractBasePtr list_type = AbstractJoin(lst->elements());
|
||||||
|
auto result1 = engine->Execute(fn, lst->elements());
|
||||||
|
auto result2 = engine->Execute(fn, {dflt, list_type});
|
||||||
|
MS_EXCEPTION_IF_NULL(result1->abstract());
|
||||||
|
MS_EXCEPTION_IF_NULL(result2->abstract());
|
||||||
|
return result1->abstract()->Join(result2->abstract());
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: a tuple
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 1);
|
||||||
|
AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||||
|
|
||||||
|
auto tuple_elements = input->elements();
|
||||||
|
AbstractBasePtrList elem_list;
|
||||||
|
(void)std::transform(tuple_elements.rbegin(), tuple_elements.rend(), std::back_inserter(elem_list),
|
||||||
|
[](const AbstractBasePtr &elem) { return elem->Clone(); });
|
||||||
|
return std::make_shared<AbstractTuple>(elem_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr DoInferReduceShape(const AbstractTuplePtr &x_shape, const ValuePtr &x_shp_value,
|
||||||
|
const ValueTuplePtr &axis_value_ptr, const PrimitivePtr &primitive) {
|
||||||
|
size_t x_rank = x_shape->size();
|
||||||
|
std::set<int> axis_set;
|
||||||
|
auto axis_data = axis_value_ptr->value();
|
||||||
|
if (axis_data.empty()) {
|
||||||
|
int size = 1;
|
||||||
|
AbstractBasePtrList values(x_rank, std::make_shared<AbstractScalar>(size));
|
||||||
|
return std::make_shared<AbstractTuple>(values);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto &elem : axis_data) {
|
||||||
|
int e_value = CheckAxis(primitive->name(), elem, -SizeToInt(x_rank), SizeToInt(x_rank) - 1);
|
||||||
|
(void)axis_set.insert(e_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto x_shp_data = x_shp_value->cast<ValueTuplePtr>()->value();
|
||||||
|
if (x_shp_data.size() < x_rank) {
|
||||||
|
MS_LOG(EXCEPTION) << "x_shape_data.size() " << x_shp_data.size() << " less than x_shape.size() " << x_rank;
|
||||||
|
}
|
||||||
|
AbstractBasePtrList values;
|
||||||
|
for (size_t i = 0; i < x_rank; i++) {
|
||||||
|
if (axis_set.count(SizeToInt(i)) || axis_set.count(SizeToInt(i) - SizeToInt(x_rank))) {
|
||||||
|
auto axis_v = MakeValue(1);
|
||||||
|
values.push_back(std::make_shared<AbstractScalar>(axis_v, axis_v->type()));
|
||||||
|
} else {
|
||||||
|
int dim_value = x_shp_data[i]->cast<Int32ImmPtr>()->value();
|
||||||
|
auto dim = MakeValue(dim_value);
|
||||||
|
values.push_back(std::make_shared<AbstractScalar>(dim, dim->type()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_shared<AbstractTuple>(values);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: x_shape, axis
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 2);
|
||||||
|
AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||||
|
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
|
||||||
|
|
||||||
|
auto x_shp_value = shape_x->BuildValue();
|
||||||
|
if (x_shp_value->isa<AnyValue>()) {
|
||||||
|
MS_LOG(EXCEPTION) << op_name
|
||||||
|
<< " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Axis can be scalar, tuple or None
|
||||||
|
AbstractTuplePtr axis = nullptr;
|
||||||
|
if (args_spec_list[1]->isa<AbstractScalar>()) {
|
||||||
|
MS_LOG(DEBUG) << op_name << " evaluator second parameter is scalar";
|
||||||
|
AbstractBasePtrList axis_list = {dyn_cast<AbstractScalar>(args_spec_list[1])};
|
||||||
|
axis = std::make_shared<AbstractTuple>(axis_list);
|
||||||
|
} else if (args_spec_list[1]->isa<AbstractTuple>()) {
|
||||||
|
MS_LOG(DEBUG) << op_name << " evaluator second parameter is tuple";
|
||||||
|
axis = args_spec_list[1]->cast<AbstractTuplePtr>();
|
||||||
|
} else {
|
||||||
|
MS_LOG(EXCEPTION) << op_name << " evaluator second parameter should be a scalar or tuple, but got "
|
||||||
|
<< args_spec_list[1]->ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto axis_value = axis->BuildValue();
|
||||||
|
if (axis_value->isa<AnyValue>()) {
|
||||||
|
MS_LOG(EXCEPTION) << op_name
|
||||||
|
<< " evaluator shape's data field can't be anything: " << args_spec_list[1]->ToString();
|
||||||
|
}
|
||||||
|
auto axis_value_ptr = axis_value->cast<ValueTuplePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(axis_value_ptr);
|
||||||
|
|
||||||
|
return DoInferReduceShape(shape_x, x_shp_value, axis_value_ptr, primitive);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: two tuples.
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 2);
|
||||||
|
AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||||
|
AbstractTuplePtr div_shp = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
|
||||||
|
MS_LOG(INFO) << "DivShape input:" << shape_x->ToString() << ", div:" << div_shp->ToString();
|
||||||
|
|
||||||
|
auto div_shp_value = div_shp->BuildValue();
|
||||||
|
if (div_shp_value->isa<AnyValue>()) {
|
||||||
|
MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << args_spec_list[0]->ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto shpx_value = shape_x->BuildValue();
|
||||||
|
if (shpx_value->isa<AnyValue>()) {
|
||||||
|
MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << args_spec_list[1]->ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (div_shp->size() != shape_x->size()) {
|
||||||
|
MS_LOG(EXCEPTION) << "tileshape elems shape must the same div_shp: " << div_shp->size()
|
||||||
|
<< ", shapex: " << shape_x->size() << ".";
|
||||||
|
}
|
||||||
|
|
||||||
|
auto shpx_data = shpx_value->cast<ValueTuplePtr>()->value();
|
||||||
|
auto div_shp_data = div_shp_value->cast<ValueTuplePtr>()->value();
|
||||||
|
AbstractBasePtrList values;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < div_shp_data.size(); i++) {
|
||||||
|
if (div_shp_data[i]->cast<Int32ImmPtr>() == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "div_shp_shape data should be an int32 number, but it's " << args_spec_list[1]->ToString();
|
||||||
|
}
|
||||||
|
int shapex_value = GetValue<int>(shpx_data[i]);
|
||||||
|
int div_value = GetValue<int>(div_shp_data[i]);
|
||||||
|
MS_LOG(DEBUG) << "div_shp_shape data shapex_value :" << shapex_value << " div_value: " << div_value;
|
||||||
|
if (div_value == 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "error: division value should not be 0!";
|
||||||
|
}
|
||||||
|
if ((shapex_value % div_value) != 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "div_shp_shape data shapex must div int:" << shapex_value << " div_value: " << div_value;
|
||||||
|
}
|
||||||
|
|
||||||
|
int result = shapex_value / div_value;
|
||||||
|
auto result_v = MakeValue(result);
|
||||||
|
values.push_back(std::make_shared<AbstractScalar>(result_v, result_v->type()));
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_shared<AbstractTuple>(values);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: a tuple
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 1);
|
||||||
|
AbstractTuplePtr input = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||||
|
|
||||||
|
py::tuple data_tuple = ValuePtrToPyData(input->BuildValue());
|
||||||
|
py::array data = py::array(data_tuple);
|
||||||
|
auto tensor = TensorPy::MakeTensor(data);
|
||||||
|
auto ret = tensor->ToAbstract();
|
||||||
|
ret->set_value(tensor);
|
||||||
|
MS_LOG(DEBUG) << "Tuple2arry result AbstractTensor: " << ret->ToString();
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplShapeMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: a tuple
|
||||||
|
// example: tuple = (1, 2, 3), shape_mul(tuple) = 1*2*3 = 6
|
||||||
|
const std::string op_name = primitive->name();
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 1);
|
||||||
|
AbstractTuplePtr shape_x = CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
|
||||||
|
|
||||||
|
auto shpx_value = shape_x->BuildValue();
|
||||||
|
if (shpx_value->isa<AnyValue>()) {
|
||||||
|
MS_LOG(EXCEPTION) << "shape's data field can't be anythin: " << shape_x->ToString();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto shpx_data = shpx_value->cast<ValueTuplePtr>()->value();
|
||||||
|
|
||||||
|
int result = 1;
|
||||||
|
for (size_t i = 0; i < shpx_data.size(); i++) {
|
||||||
|
int value = GetValue<int>(shpx_data[i]);
|
||||||
|
result = IntMulWithOverflowCheck(result, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto result_v = MakeValue(result);
|
||||||
|
MS_LOG(DEBUG) << "shape mul result:" << result_v->ToString();
|
||||||
|
return std::make_shared<AbstractScalar>(result_v, result_v->type());
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
AbstractBasePtr InferImplTupleOrListEqual(const std::string &op_name, const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: two tuples or two lists.
|
||||||
|
CheckArgsSize(op_name, args_spec_list, 2);
|
||||||
|
auto input_x = CheckArg<T>(op_name, args_spec_list, 0);
|
||||||
|
auto input_y = CheckArg<T>(op_name, args_spec_list, 1);
|
||||||
|
|
||||||
|
ValuePtr x_value = input_x->BuildValue();
|
||||||
|
ValuePtr y_value = input_y->BuildValue();
|
||||||
|
return std::make_shared<AbstractScalar>(*x_value == *y_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplTupleEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
return InferImplTupleOrListEqual<AbstractTuple>(primitive->name(), args_spec_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplListEqual(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
return InferImplTupleOrListEqual<AbstractList>(primitive->name(), args_spec_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SlideInfo {
|
||||||
|
int start;
|
||||||
|
int step;
|
||||||
|
int stop;
|
||||||
|
};
|
||||||
|
|
||||||
|
void CalcSlidePara(const AbstractBasePtrList &args_spec_list, SlideInfo *slide) {
|
||||||
|
int arg1 = 0;
|
||||||
|
int arg2 = 0;
|
||||||
|
if (!args_spec_list.empty()) {
|
||||||
|
MS_EXCEPTION_IF_NULL(args_spec_list[0]);
|
||||||
|
auto arg_value = args_spec_list[0]->BuildValue();
|
||||||
|
if (!arg_value->isa<Int32Imm>()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
|
||||||
|
}
|
||||||
|
arg1 = GetValue<int>(arg_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (args_spec_list.size() >= 2) {
|
||||||
|
MS_EXCEPTION_IF_NULL(args_spec_list[1]);
|
||||||
|
auto arg_value = args_spec_list[1]->BuildValue();
|
||||||
|
if (!arg_value->isa<Int32Imm>()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
|
||||||
|
}
|
||||||
|
arg2 = GetValue<int>(arg_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (args_spec_list.size() == 3) {
|
||||||
|
MS_EXCEPTION_IF_NULL(args_spec_list[2]);
|
||||||
|
auto arg_value = args_spec_list[2]->BuildValue();
|
||||||
|
if (!arg_value->isa<Int32Imm>()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Only supported input an int32 number.";
|
||||||
|
}
|
||||||
|
slide->step = GetValue<int>(arg_value);
|
||||||
|
slide->start = arg1;
|
||||||
|
slide->stop = arg2;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (args_spec_list.size() == 2) {
|
||||||
|
slide->start = arg1;
|
||||||
|
slide->stop = arg2;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (args_spec_list.size() == 1) {
|
||||||
|
slide->stop = arg1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr &,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
if (args_spec_list.empty()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Cannot make range from empty input.";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (args_spec_list.size() > 3) {
|
||||||
|
MS_LOG(EXCEPTION) << "Error args size of make range operational.";
|
||||||
|
}
|
||||||
|
|
||||||
|
SlideInfo slide = {0, 1, 0};
|
||||||
|
CalcSlidePara(args_spec_list, &slide);
|
||||||
|
|
||||||
|
if (slide.step == 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "Error, step value is 0.";
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtrList args;
|
||||||
|
if (slide.start <= slide.stop) {
|
||||||
|
if (slide.step <= 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
|
||||||
|
}
|
||||||
|
for (int i = slide.start; i < slide.stop; i += slide.step) {
|
||||||
|
args.push_back(abstract::FromValue(i));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (slide.step >= 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "Error slice[" << slide.start << ", " << slide.stop << ", " << slide.step << "]";
|
||||||
|
}
|
||||||
|
for (int i = slide.start; i > slide.stop; i += slide.step) {
|
||||||
|
args.push_back(abstract::FromValue(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return std::make_shared<AbstractTuple>(args);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
|
const AbstractBasePtrList &args_spec_list) {
|
||||||
|
// Inputs: a tensor
|
||||||
|
CheckArgsSize(primitive->name(), args_spec_list, 1);
|
||||||
|
return args_spec_list[0]->Clone();
|
||||||
|
}
|
||||||
|
} // namespace abstract
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,93 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_
|
||||||
|
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "ir/pattern_matcher.h"
|
||||||
|
#include "optimizer/irpass.h"
|
||||||
|
#include "optimizer/optimizer.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
namespace irpass {
|
||||||
|
// {prim::kPrimMakeRef, X, Y, Z} -> Y
|
||||||
|
class MakeRefEliminater : public OptimizerCaller {
|
||||||
|
public:
|
||||||
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||||
|
PatternNode<AnfNodePtr> x, y, z;
|
||||||
|
MATCH_REPLACE(node, PPrimitive(prim::kPrimMakeRef, x, y, z), y);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// {prim::kPrimGetRefValue, Parameter} -> Parameter
|
||||||
|
// {prim::kPrimGetRefOrigin, Parameter} -> Parameter
|
||||||
|
class GetRefParamEliminater : public OptimizerCaller {
|
||||||
|
public:
|
||||||
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||||
|
PatternNode<AnfNodePtr> x;
|
||||||
|
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefValue, x), x, x.CheckFunc(IsParam, node));
|
||||||
|
MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimGetRefOrigin, x), x, x.CheckFunc(IsParam, node));
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// {prim::kPrimGetRefKey, {prim::kPrimMakeRef, X, Y, Z}} -> X
|
||||||
|
// {prim::kPrimGetRefValue, {prim::kPrimMakeRef, X, Y, Z}} -> Y
|
||||||
|
// {prim::kPrimGetRefOrigin, {prim::kPrimMakeRef, X, Y, Z}} -> Z
|
||||||
|
class GetMakeRefEliminater : public OptimizerCaller {
|
||||||
|
public:
|
||||||
|
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||||
|
PatternNode<AnfNodePtr> x, y, z;
|
||||||
|
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x);
|
||||||
|
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y);
|
||||||
|
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// IsValueNode<RefKey>
|
||||||
|
class ReplaceRefkeyByParam : public OptimizerCaller {
|
||||||
|
public:
|
||||||
|
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
|
||||||
|
auto RefKeyLambda = [&node, &optimizer]() -> AnfNodePtr {
|
||||||
|
auto refkey = GetValueNode<RefKeyPtr>(node);
|
||||||
|
auto resource = std::dynamic_pointer_cast<pipeline::Resource>(optimizer->resource());
|
||||||
|
MS_EXCEPTION_IF_NULL(resource);
|
||||||
|
|
||||||
|
auto top_graph = resource->func_graph();
|
||||||
|
MS_EXCEPTION_IF_NULL(top_graph);
|
||||||
|
|
||||||
|
for (const auto &tnode : top_graph->parameters()) {
|
||||||
|
auto para = tnode->cast<ParameterPtr>();
|
||||||
|
if (para != nullptr && para->name() == refkey->tag()) {
|
||||||
|
return para;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
};
|
||||||
|
PatternNode<AnfNodePtr> x;
|
||||||
|
MATCH_REPLACE_LAMBDA_IF(node, x, RefKeyLambda, x.CheckFunc(IsValueNode<RefKey>, node));
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace irpass
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_REF_ELIMINATE_H_
|
|
@ -0,0 +1,175 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "parallel/graph_util/generate_graph.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
using mindspore::tensor::Tensor;
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace parallel {
|
||||||
|
std::string GetOpPythonPath(const OperatorName &op_name) {
|
||||||
|
// almost all ops are defined in two main paths
|
||||||
|
const std::string ops_module = OP_PATH;
|
||||||
|
const std::string inner_ops_module = INNER_OP_PATH;
|
||||||
|
py::module mod = py::module::import(common::SafeCStr(ops_module));
|
||||||
|
py::module inner_mod = py::module::import(common::SafeCStr(inner_ops_module));
|
||||||
|
if (!py::hasattr(inner_mod, common::SafeCStr(op_name))) {
|
||||||
|
if (!py::hasattr(mod, common::SafeCStr(op_name))) {
|
||||||
|
MS_LOG(EXCEPTION) << ops_module << " or " << inner_ops_module << " don't have op:" << op_name;
|
||||||
|
}
|
||||||
|
return ops_module;
|
||||||
|
}
|
||||||
|
return inner_ops_module;
|
||||||
|
}
|
||||||
|
|
||||||
|
ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) {
|
||||||
|
std::string op_path = GetOpPythonPath(op_name);
|
||||||
|
py::module mod = py::module::import(common::SafeCStr(op_path));
|
||||||
|
if (!py::hasattr(mod, common::SafeCStr(op_name))) {
|
||||||
|
MS_LOG(ERROR) << "Failure: op_path:" << op_path << " don't have attr " << op_name;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
std::vector<py::object> arg_list;
|
||||||
|
(void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list),
|
||||||
|
[](const Attr &attr) { return ValuePtrToPyData(attr.second); });
|
||||||
|
py::object obj =
|
||||||
|
parse::python_adapter::CallPyFn(GET_OP_FUNCTION_PATH, GET_OP_FUNCTION, op_name, op_path, instance_name, arg_list);
|
||||||
|
ValuePtr op_instance = nullptr;
|
||||||
|
bool succ = parse::ConvertData(obj, &op_instance);
|
||||||
|
if (!succ) {
|
||||||
|
MS_LOG(ERROR) << "Failure:get Python op " << op_path << " from " << op_name << " fail";
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
return op_instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr) {
|
||||||
|
auto value_node = NewValueNode(value_ptr);
|
||||||
|
MS_EXCEPTION_IF_NULL(value_node);
|
||||||
|
return value_node->cast<AnfNodePtr>();
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::unordered_map<int32_t, AnfNodePtr> int_tensor_map = {};
|
||||||
|
AnfNodePtr CreateInt32Tensor(int32_t value) {
|
||||||
|
auto it = int_tensor_map.find(value);
|
||||||
|
if (it != int_tensor_map.end()) {
|
||||||
|
return it->second;
|
||||||
|
}
|
||||||
|
mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<tensor::Tensor>(py::int_(value), kInt32);
|
||||||
|
ValuePtr value_ptr = MakeValue(tensor_ptr);
|
||||||
|
auto anf_node_ptr = ValuePtrToAnfNodePtr(value_ptr);
|
||||||
|
int_tensor_map[value] = anf_node_ptr;
|
||||||
|
return anf_node_ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr CreatTypeInt(int32_t value) {
|
||||||
|
ValuePtr value_ptr = MakeValue(std::make_shared<Int>(value));
|
||||||
|
return ValuePtrToAnfNodePtr(value_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr CreatInt32Imm(int32_t value) {
|
||||||
|
ValuePtr value_ptr = MakeValue(std::make_shared<Int32Imm>(value));
|
||||||
|
return ValuePtrToAnfNodePtr(value_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string GetInstanceNameByCNode(const CNodePtr &cnode) {
|
||||||
|
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
||||||
|
if (!prim) {
|
||||||
|
MS_LOG(EXCEPTION) << "The first input of the cnode is not a PrimitivePtr.";
|
||||||
|
}
|
||||||
|
std::string instance_name = prim->instance_name();
|
||||||
|
return HashInstanceName(instance_name);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string HashInstanceName(const std::string &name) {
|
||||||
|
auto using_hash_name = common::GetEnv(USING_HASH_NAME);
|
||||||
|
std::string instance_name;
|
||||||
|
if ((using_hash_name.empty()) || (using_hash_name == "on")) {
|
||||||
|
instance_name = HashName(name);
|
||||||
|
} else {
|
||||||
|
instance_name = name;
|
||||||
|
}
|
||||||
|
return instance_name;
|
||||||
|
}
|
||||||
|
|
||||||
|
Status GenerateGraph::Init(const CNodePtr &cnode) {
|
||||||
|
if (!cnode) {
|
||||||
|
MS_LOG(ERROR) << "Init:cnode is nullptr";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
cnode_ = cnode;
|
||||||
|
func_graph_ = cnode->func_graph();
|
||||||
|
if (!func_graph_) {
|
||||||
|
MS_LOG(ERROR) << "Init:func_graph_ is nullptr";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
manager_ = func_graph_->manager();
|
||||||
|
if (!manager_) {
|
||||||
|
MS_LOG(ERROR) << "Init:manager_ is nullptr";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
scope_ = cnode_->scope();
|
||||||
|
if (!scope_) {
|
||||||
|
MS_LOG(ERROR) << "Init:scope_ is nullptr";
|
||||||
|
return FAILED;
|
||||||
|
}
|
||||||
|
virtual_input_node_ = std::make_shared<AnfNode>(nullptr);
|
||||||
|
virtual_input_node_->set_scope(scope_);
|
||||||
|
instance_name_base_ = GetInstanceNameByCNode(cnode_);
|
||||||
|
name_idx_ = 0;
|
||||||
|
return SUCCESS;
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr GenerateGraph::PushBack(const std::vector<AnfNodePtr> &inputs) {
|
||||||
|
CNodePtr cnode = func_graph_->NewCNode(inputs); // using NewCNode to creat anfnode
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
cnode->set_scope(scope_);
|
||||||
|
if (inputs.size() < 2) {
|
||||||
|
MS_LOG(EXCEPTION) << "inputs.size() must be more than 1";
|
||||||
|
}
|
||||||
|
(void)manager_->Replace(inputs.at(1), cnode); // using Replace function to insert cnode after inputs[0]
|
||||||
|
auto new_anf_node_ptr = cnode->cast<AnfNodePtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(new_anf_node_ptr);
|
||||||
|
return new_anf_node_ptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs) {
|
||||||
|
name_idx_++;
|
||||||
|
ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + op_name + std::to_string(name_idx_));
|
||||||
|
if (pyop_instance == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreatOpInstance failed";
|
||||||
|
}
|
||||||
|
auto value_node = NewValueNode(pyop_instance);
|
||||||
|
return value_node->cast<AnfNodePtr>();
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name) {
|
||||||
|
name_idx_++;
|
||||||
|
OperatorAttrs attrs;
|
||||||
|
ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + std::to_string(name_idx_));
|
||||||
|
if (pyop_instance == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "Failure:" << op_name << " CreatOpInstance failed";
|
||||||
|
}
|
||||||
|
auto value_node = NewValueNode(pyop_instance);
|
||||||
|
return value_node->cast<AnfNodePtr>();
|
||||||
|
}
|
||||||
|
} // namespace parallel
|
||||||
|
} // namespace mindspore
|
File diff suppressed because it is too large
Load Diff
|
@ -192,21 +192,18 @@ bool MsContext::OpenTsd() {
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << ".";
|
MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << ".";
|
||||||
|
|
||||||
TDT_StatusT status = tdt::TsdClient::GetInstance()->Open(device_id, rank_size);
|
|
||||||
if (status != TDT_OK) {
|
|
||||||
MS_LOG(EXCEPTION) << "Device " << device_id << " is occupied, open tsd failed, status = " << status << ".";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
tsd_ref_++;
|
|
||||||
#ifdef ENABLE_TDTQUE
|
|
||||||
int32_t initStatus = tdt::TdtHostInit(device_id);
|
int32_t initStatus = tdt::TdtHostInit(device_id);
|
||||||
if (initStatus != TDT_OK_CODE) {
|
if (initStatus != TDT_OK_CODE) {
|
||||||
MS_LOG(EXCEPTION) << "Init tsd failed, status = " << initStatus << ".";
|
MS_LOG(EXCEPTION) << "Init tsd failed, status = " << initStatus << ".";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
tdt_print_ = std::thread(TensorPrint());
|
tdt_print_ = std::thread(TensorPrint());
|
||||||
#endif
|
TDT_StatusT status = tdt::TsdClient::GetInstance()->Open(device_id, rank_size);
|
||||||
|
if (status != TDT_OK) {
|
||||||
|
MS_LOG(EXCEPTION) << "Device " << device_id << " is occupied, open tsd failed, status = " << status << ".";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
tsd_ref_++;
|
||||||
MS_LOG(INFO) << "Open and init tsd successful, tsd reference = " << tsd_ref_ << ".";
|
MS_LOG(INFO) << "Open and init tsd successful, tsd reference = " << tsd_ref_ << ".";
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -173,6 +173,9 @@ constexpr auto kSparseApplyProximalAdagradOpName = "SparseApplyProximalAdagrad";
|
||||||
constexpr auto kSparseApplyRMSPropOpName = "SparseApplyRMSProp";
|
constexpr auto kSparseApplyRMSPropOpName = "SparseApplyRMSProp";
|
||||||
constexpr auto kSparseApplyAdadeltaOpName = "SparseApplyAdadelta";
|
constexpr auto kSparseApplyAdadeltaOpName = "SparseApplyAdadelta";
|
||||||
constexpr auto kApplyAdamWithAmsgradOpName = "ApplyAdamWithAmsgrad";
|
constexpr auto kApplyAdamWithAmsgradOpName = "ApplyAdamWithAmsgrad";
|
||||||
|
constexpr auto kTensorMoveOpName = "TensorMove";
|
||||||
|
constexpr auto kTensorScatterUpdateOpName = "TensorScatterUpdate";
|
||||||
|
constexpr auto kScatterNdUpdateOpName = "ScatterNdUpdate";
|
||||||
constexpr auto kPushOpName = "Push";
|
constexpr auto kPushOpName = "Push";
|
||||||
constexpr auto kPullOpName = "Pull";
|
constexpr auto kPullOpName = "Pull";
|
||||||
constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup";
|
constexpr auto kEmbeddingLookupOpName = "EmbeddingLookup";
|
||||||
|
@ -236,6 +239,8 @@ constexpr auto kAttrNumSplit = "num_split";
|
||||||
constexpr auto kAttrOutputNum = "output_num";
|
constexpr auto kAttrOutputNum = "output_num";
|
||||||
constexpr auto kAttrSizeSplits = "size_splits";
|
constexpr auto kAttrSizeSplits = "size_splits";
|
||||||
constexpr auto kAttrOutputDefault = "output_default";
|
constexpr auto kAttrOutputDefault = "output_default";
|
||||||
|
constexpr auto kAttrPrimitiveTarget = "primitive_target";
|
||||||
|
constexpr auto kAttrUseLocking = "use_locking";
|
||||||
constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag";
|
constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag";
|
||||||
constexpr auto kAttrOffset = "offset";
|
constexpr auto kAttrOffset = "offset";
|
||||||
constexpr auto kAttrPsKey = "ps_key";
|
constexpr auto kAttrPsKey = "ps_key";
|
||||||
|
|
|
@ -18,11 +18,12 @@ from mindspore.ops import operations as P
|
||||||
from mindspore.common.parameter import Parameter
|
from mindspore.common.parameter import Parameter
|
||||||
from mindspore.common.initializer import initializer
|
from mindspore.common.initializer import initializer
|
||||||
from mindspore._checkparam import ParamValidator as validator, Rel
|
from mindspore._checkparam import ParamValidator as validator, Rel
|
||||||
from mindspore._checkparam import check_bool, twice, check_int_positive, check_int_non_negative
|
from mindspore._checkparam import Validator
|
||||||
|
from mindspore._checkparam import check_bool, twice, check_int_positive
|
||||||
from mindspore._extends import cell_attr_register
|
from mindspore._extends import cell_attr_register
|
||||||
from ..cell import Cell
|
from ..cell import Cell
|
||||||
|
|
||||||
__all__ = ['Conv2d', 'Conv2dTranspose', 'DepthwiseConv2d']
|
__all__ = ['Conv2d', 'Conv2dTranspose', 'DepthwiseConv2d', 'Conv1d', 'Conv1dTranspose']
|
||||||
|
|
||||||
class _Conv(Cell):
|
class _Conv(Cell):
|
||||||
"""
|
"""
|
||||||
|
@ -47,7 +48,16 @@ class _Conv(Cell):
|
||||||
self.kernel_size = kernel_size
|
self.kernel_size = kernel_size
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.pad_mode = pad_mode
|
self.pad_mode = pad_mode
|
||||||
self.padding = check_int_non_negative(padding)
|
if isinstance(padding, int):
|
||||||
|
Validator.check_integer('padding', padding, 0, Rel.GE, self.cls_name)
|
||||||
|
self.padding = padding
|
||||||
|
elif isinstance(padding, tuple):
|
||||||
|
for pad in padding:
|
||||||
|
Validator.check_integer('padding item', pad, 0, Rel.GE, self.cls_name)
|
||||||
|
self.padding = padding
|
||||||
|
else:
|
||||||
|
raise TypeError("padding type must be int/tuple(int) cannot be {}!".format(type(padding)))
|
||||||
|
|
||||||
self.dilation = dilation
|
self.dilation = dilation
|
||||||
self.group = check_int_positive(group)
|
self.group = check_int_positive(group)
|
||||||
self.has_bias = has_bias
|
self.has_bias = has_bias
|
||||||
|
@ -141,7 +151,10 @@ class Conv2d(_Conv):
|
||||||
- pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input
|
- pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input
|
||||||
Tensor borders. `padding` should be greater than or equal to 0.
|
Tensor borders. `padding` should be greater than or equal to 0.
|
||||||
|
|
||||||
padding (int): Implicit paddings on both sides of the input. Default: 0.
|
padding (Union[int, tuple[int]]): Implicit paddings on both sides of the input. If `padding` is one integer,
|
||||||
|
the padding of top, bottom, left and right is same, equal to padding. If `padding` is tuple with
|
||||||
|
four integer, the padding of top, bottom, left and right equal to padding[0], padding[1],
|
||||||
|
padding[2], padding[3] with corresponding. Default: 0.
|
||||||
dilation (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the dilation rate
|
dilation (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the dilation rate
|
||||||
to use for dilated convolution. If set to be :math:`k > 1`, there will
|
to use for dilated convolution. If set to be :math:`k > 1`, there will
|
||||||
be :math:`k - 1` pixels skipped for each sampling location. Its value should
|
be :math:`k - 1` pixels skipped for each sampling location. Its value should
|
||||||
|
@ -241,6 +254,174 @@ class Conv2d(_Conv):
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
class Conv1d(_Conv):
|
||||||
|
r"""
|
||||||
|
1D convolution layer.
|
||||||
|
|
||||||
|
Applies a 1D convolution over an input tensor which is typically of shape :math:`(N, C_{in}, W_{in})`,
|
||||||
|
where :math:`N` is batch size and :math:`C_{in}` is channel number. For each batch of shape
|
||||||
|
:math:`(C_{in}, W_{in})`, the formula is defined as:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
|
||||||
|
out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j,
|
||||||
|
|
||||||
|
where :math:`ccor` is cross correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges
|
||||||
|
from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to :math:`i`-th channel of the :math:`j`-th
|
||||||
|
filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice
|
||||||
|
of kernel and it has shape :math:`(\text{ks_w})`, where :math:`\text{ks_w}` are width of the convolution kernel.
|
||||||
|
The full kernel has shape :math:`(C_{out}, C_{in} // \text{group}, \text{ks_w})`, where group is the group number
|
||||||
|
to split the input in the channel dimension.
|
||||||
|
|
||||||
|
If the 'pad_mode' is set to be "valid", the output width will be
|
||||||
|
:math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} -
|
||||||
|
(\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively.
|
||||||
|
|
||||||
|
The first introduction can be found in paper `Gradient Based Learning Applied to Document Recognition
|
||||||
|
<http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): The number of input channel :math:`C_{in}`.
|
||||||
|
out_channels (int): The number of output channel :math:`C_{out}`.
|
||||||
|
kernel_size (int): The data type is int. Specifies the
|
||||||
|
width of the 1D convolution window.
|
||||||
|
stride (int): The distance of kernel moving, an int number that represents
|
||||||
|
the width of movement. Default: 1.
|
||||||
|
pad_mode (str): Specifies padding mode. The optional values are
|
||||||
|
"same", "valid", "pad". Default: "same".
|
||||||
|
|
||||||
|
- same: Adopts the way of completion. Output width will be the same as the input.
|
||||||
|
Total number of padding will be calculated for horizontal
|
||||||
|
direction and evenly distributed to left and right if possible. Otherwise, the
|
||||||
|
last extra padding will be done from the bottom and the right side. If this mode is set, `padding`
|
||||||
|
must be 0.
|
||||||
|
|
||||||
|
- valid: Adopts the way of discarding. The possibly largest width of output will be return
|
||||||
|
without padding. Extra pixels will be discarded. If this mode is set, `padding`
|
||||||
|
must be 0.
|
||||||
|
|
||||||
|
- pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input
|
||||||
|
Tensor borders. `padding` should be greater than or equal to 0.
|
||||||
|
|
||||||
|
padding (int): Implicit paddings on both sides of the input. Default: 0.
|
||||||
|
dilation (int): The data type is int. Specifies the dilation rate
|
||||||
|
to use for dilated convolution. If set to be :math:`k > 1`, there will
|
||||||
|
be :math:`k - 1` pixels skipped for each sampling location. Its value should
|
||||||
|
be greater or equal to 1 and bounded by the height and width of the
|
||||||
|
input. Default: 1.
|
||||||
|
group (int): Split filter into groups, `in_ channels` and `out_channels` should be
|
||||||
|
divisible by the number of groups. Default: 1.
|
||||||
|
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
|
||||||
|
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
|
||||||
|
It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
|
||||||
|
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
|
||||||
|
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
|
||||||
|
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
|
||||||
|
Initializer for more details. Default: 'normal'.
|
||||||
|
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
|
||||||
|
Initializer and string are the same as 'weight_init'. Refer to the values of
|
||||||
|
Initializer for more details. Default: 'zeros'.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, W_{in})`.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor of shape :math:`(N, C_{out}, W_{out})`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> net = nn.Conv1d(120, 240, 4, has_bias=False, weight_init='normal')
|
||||||
|
>>> input = Tensor(np.ones([1, 120, 640]), mindspore.float32)
|
||||||
|
>>> net(input).shape
|
||||||
|
(1, 240, 640)
|
||||||
|
"""
|
||||||
|
@cell_attr_register
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
pad_mode='same',
|
||||||
|
padding=0,
|
||||||
|
dilation=1,
|
||||||
|
group=1,
|
||||||
|
has_bias=False,
|
||||||
|
weight_init='normal',
|
||||||
|
bias_init='zeros'):
|
||||||
|
|
||||||
|
Validator.check_value_type("kernel_size", kernel_size, [int], self.cls_name)
|
||||||
|
Validator.check_value_type("stride", stride, [int], self.cls_name)
|
||||||
|
Validator.check_value_type("padding", padding, [int], self.cls_name)
|
||||||
|
Validator.check_value_type("dilation", dilation, [int], self.cls_name)
|
||||||
|
Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name)
|
||||||
|
Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name)
|
||||||
|
Validator.check_integer('padding', padding, 0, Rel.GE, self.cls_name)
|
||||||
|
Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name)
|
||||||
|
kernel_size = (1, kernel_size)
|
||||||
|
stride = (1, stride)
|
||||||
|
dilation = (1, dilation)
|
||||||
|
|
||||||
|
super(Conv1d, self).__init__(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
pad_mode,
|
||||||
|
padding,
|
||||||
|
dilation,
|
||||||
|
group,
|
||||||
|
has_bias,
|
||||||
|
weight_init,
|
||||||
|
bias_init)
|
||||||
|
self.padding = (0, 0, padding, padding)
|
||||||
|
self.conv2d = P.Conv2D(out_channel=self.out_channels,
|
||||||
|
kernel_size=self.kernel_size,
|
||||||
|
mode=1,
|
||||||
|
pad_mode=self.pad_mode,
|
||||||
|
pad=self.padding,
|
||||||
|
stride=self.stride,
|
||||||
|
dilation=self.dilation,
|
||||||
|
group=self.group)
|
||||||
|
self.bias_add = P.BiasAdd()
|
||||||
|
if pad_mode not in ('valid', 'same', 'pad'):
|
||||||
|
raise ValueError('Attr \'pad_mode\' of \'Conv1d\' Op passed '
|
||||||
|
+ str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
|
||||||
|
self.expand_dims = P.ExpandDims()
|
||||||
|
self.squeeze = P.Squeeze(2)
|
||||||
|
self.shape = P.Shape()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x_shape = self.shape(x)
|
||||||
|
if len(x_shape) == 3:
|
||||||
|
x = self.expand_dims(x, 2)
|
||||||
|
output = self.conv2d(x, self.weight)
|
||||||
|
if self.has_bias:
|
||||||
|
output = self.bias_add(output, self.bias)
|
||||||
|
if len(x_shape) == 3:
|
||||||
|
output = self.squeeze(output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def extend_repr(self):
|
||||||
|
s = 'input_channels={}, output_channels={}, kernel_size={},' \
|
||||||
|
'stride={}, pad_mode={}, padding={}, dilation={}, ' \
|
||||||
|
'group={}, has_bias={},' \
|
||||||
|
'weight_init={}, bias_init={}'.format(
|
||||||
|
self.in_channels,
|
||||||
|
self.out_channels,
|
||||||
|
self.kernel_size,
|
||||||
|
self.stride,
|
||||||
|
self.pad_mode,
|
||||||
|
self.padding,
|
||||||
|
self.dilation,
|
||||||
|
self.group,
|
||||||
|
self.has_bias,
|
||||||
|
self.weight,
|
||||||
|
self.bias)
|
||||||
|
|
||||||
|
if self.has_bias:
|
||||||
|
s += ', bias={}'.format(self.bias)
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
class Conv2dTranspose(_Conv):
|
class Conv2dTranspose(_Conv):
|
||||||
r"""
|
r"""
|
||||||
2D transposed convolution layer.
|
2D transposed convolution layer.
|
||||||
|
@ -268,7 +449,10 @@ class Conv2dTranspose(_Conv):
|
||||||
- same: Adopted the way of completion.
|
- same: Adopted the way of completion.
|
||||||
|
|
||||||
- valid: Adopted the way of discarding.
|
- valid: Adopted the way of discarding.
|
||||||
padding (int): Implicit paddings on both sides of the input. Default: 0.
|
padding (Union[int, tuple[int]]): Implicit paddings on both sides of the input. If `padding` is one integer,
|
||||||
|
the padding of top, bottom, left and right is same, equal to padding. If `padding` is tuple with
|
||||||
|
four integer, the padding of top, bottom, left and right equal to padding[0], padding[1],
|
||||||
|
padding[2], padding[3] with corresponding. Default: 0.
|
||||||
dilation (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the dilation rate
|
dilation (Union[int, tuple[int]]): The data type is int or tuple with 2 integers. Specifies the dilation rate
|
||||||
to use for dilated convolution. If set to be :math:`k > 1`, there will
|
to use for dilated convolution. If set to be :math:`k > 1`, there will
|
||||||
be :math:`k - 1` pixels skipped for each sampling location. Its value should
|
be :math:`k - 1` pixels skipped for each sampling location. Its value should
|
||||||
|
@ -313,6 +497,9 @@ class Conv2dTranspose(_Conv):
|
||||||
kernel_size = twice(kernel_size)
|
kernel_size = twice(kernel_size)
|
||||||
stride = twice(stride)
|
stride = twice(stride)
|
||||||
dilation = twice(dilation)
|
dilation = twice(dilation)
|
||||||
|
Validator.check_value_type('padding', padding, (int, tuple), self.cls_name)
|
||||||
|
if isinstance(padding, tuple):
|
||||||
|
Validator.check_integer('padding size', len(padding), 4, Rel.EQ, self.cls_name)
|
||||||
# out_channels and in_channels swap.
|
# out_channels and in_channels swap.
|
||||||
# cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel,
|
# cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel,
|
||||||
# then Conv2dTranspose's out_channel refers to Conv2DBackpropInput's in_channel.
|
# then Conv2dTranspose's out_channel refers to Conv2DBackpropInput's in_channel.
|
||||||
|
@ -352,12 +539,16 @@ class Conv2dTranspose(_Conv):
|
||||||
dilation=dilation,
|
dilation=dilation,
|
||||||
group=group)
|
group=group)
|
||||||
self.bias_add = P.BiasAdd()
|
self.bias_add = P.BiasAdd()
|
||||||
|
if isinstance(self.padding, int):
|
||||||
|
self.padding_top, self.padding_bottom, self.padding_left, self.padding_right = (self.padding,) * 4
|
||||||
|
else:
|
||||||
|
self.padding_top, self.padding_bottom, self.padding_left, self.padding_right = self.padding
|
||||||
|
|
||||||
def set_strategy(self, strategy):
|
def set_strategy(self, strategy):
|
||||||
self.conv2d_transpose.set_strategy(strategy)
|
self.conv2d_transpose.set_strategy(strategy)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size):
|
def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size, padding):
|
||||||
"""Calculate the width and height of output."""
|
"""Calculate the width and height of output."""
|
||||||
length = 0
|
length = 0
|
||||||
filter_size = filter_size + (filter_size - 1) * (dilation_size - 1)
|
filter_size = filter_size + (filter_size - 1) * (dilation_size - 1)
|
||||||
|
@ -369,14 +560,16 @@ class Conv2dTranspose(_Conv):
|
||||||
elif self.is_same:
|
elif self.is_same:
|
||||||
length = input_length * stride_size
|
length = input_length * stride_size
|
||||||
elif self.is_pad:
|
elif self.is_pad:
|
||||||
length = input_length * stride_size - 2 * self.padding + filter_size - stride_size
|
length = input_length * stride_size - padding + filter_size - stride_size
|
||||||
|
|
||||||
return length
|
return length
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
n, _, h, w = self.shape(x)
|
n, _, h, w = self.shape(x)
|
||||||
h_out = self._deconv_output_length(h, self.kernel_size[0], self.stride[0], self.dilation[0])
|
h_out = self._deconv_output_length(h, self.kernel_size[0], self.stride[0], self.dilation[0],
|
||||||
w_out = self._deconv_output_length(w, self.kernel_size[1], self.stride[1], self.dilation[1])
|
self.padding_top + self.padding_bottom)
|
||||||
|
w_out = self._deconv_output_length(w, self.kernel_size[1], self.stride[1], self.dilation[1],
|
||||||
|
self.padding_left + self.padding_right)
|
||||||
if self.has_bias:
|
if self.has_bias:
|
||||||
return self.bias_add(self.conv2d_transpose(x, self.weight, (n, self.out_channels, h_out, w_out)),
|
return self.bias_add(self.conv2d_transpose(x, self.weight, (n, self.out_channels, h_out, w_out)),
|
||||||
self.bias)
|
self.bias)
|
||||||
|
@ -400,6 +593,181 @@ class Conv2dTranspose(_Conv):
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
class Conv1dTranspose(_Conv):
|
||||||
|
r"""
|
||||||
|
1D transposed convolution layer.
|
||||||
|
|
||||||
|
Compute a 1D transposed convolution, which is also know as a deconvolution
|
||||||
|
(although it is not actual deconvolution).
|
||||||
|
|
||||||
|
Input is typically of shape :math:`(N, C, W)`, where :math:`N` is batch size and :math:`C` is channel number.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): The number of channels in the input space.
|
||||||
|
out_channels (int): The number of channels in the output space.
|
||||||
|
kernel_size (int): int, which specifies the width of the 1D convolution window.
|
||||||
|
stride (int): The distance of kernel moving, an int number that represents
|
||||||
|
the width of movement. Default: 1.
|
||||||
|
pad_mode (str): Select the mode of the pad. The optional values are
|
||||||
|
"pad", "same", "valid". Default: "same".
|
||||||
|
|
||||||
|
- pad: Implicit paddings on both sides of the input.
|
||||||
|
|
||||||
|
- same: Adopted the way of completion.
|
||||||
|
|
||||||
|
- valid: Adopted the way of discarding.
|
||||||
|
padding (int): Implicit paddings on both sides of the input. Default: 0.
|
||||||
|
dilation (int): The data type is int. Specifies the dilation rate
|
||||||
|
to use for dilated convolution. If set to be :math:`k > 1`, there will
|
||||||
|
be :math:`k - 1` pixels skipped for each sampling location. Its value should
|
||||||
|
be greater or equal to 1 and bounded by the width of the
|
||||||
|
input. Default: 1.
|
||||||
|
group (int): Split filter into groups, `in_channels` and `out_channels` should be
|
||||||
|
divisible by the number of groups. This is not support for Davinci devices when group > 1. Default: 1.
|
||||||
|
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
|
||||||
|
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
|
||||||
|
It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
|
||||||
|
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
|
||||||
|
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
|
||||||
|
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
|
||||||
|
Initializer for more details. Default: 'normal'.
|
||||||
|
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
|
||||||
|
Initializer and string are the same as 'weight_init'. Refer to the values of
|
||||||
|
Initializer for more details. Default: 'zeros'.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, W_{in})`.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor of shape :math:`(N, C_{out}, W_{out})`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> net = nn.Conv1dTranspose(3, 64, 4, has_bias=False, weight_init='normal')
|
||||||
|
>>> input = Tensor(np.ones([1, 3, 50]), mindspore.float32)
|
||||||
|
>>> net(input)
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
pad_mode='same',
|
||||||
|
padding=0,
|
||||||
|
dilation=1,
|
||||||
|
group=1,
|
||||||
|
has_bias=False,
|
||||||
|
weight_init='normal',
|
||||||
|
bias_init='zeros'):
|
||||||
|
Validator.check_value_type("kernel_size", kernel_size, [int], self.cls_name)
|
||||||
|
Validator.check_value_type("stride", stride, [int], self.cls_name)
|
||||||
|
Validator.check_value_type("padding", padding, [int], self.cls_name)
|
||||||
|
Validator.check_value_type("dilation", dilation, [int], self.cls_name)
|
||||||
|
Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name)
|
||||||
|
Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name)
|
||||||
|
Validator.check_integer('padding', padding, 0, Rel.GE, self.cls_name)
|
||||||
|
Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name)
|
||||||
|
kernel_size = (1, kernel_size)
|
||||||
|
stride = (1, stride)
|
||||||
|
dilation = (1, dilation)
|
||||||
|
# out_channels and in_channels swap.
|
||||||
|
# cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel,
|
||||||
|
# then Conv1dTranspose's out_channel refers to Conv2DBackpropInput's in_channel.
|
||||||
|
super(Conv1dTranspose, self).__init__(
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
pad_mode,
|
||||||
|
padding,
|
||||||
|
dilation,
|
||||||
|
group,
|
||||||
|
has_bias,
|
||||||
|
weight_init,
|
||||||
|
bias_init,
|
||||||
|
transposed=True)
|
||||||
|
self.padding = (0, 0, padding, padding)
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = out_channels
|
||||||
|
self.shape = P.Shape()
|
||||||
|
if pad_mode not in ('valid', 'same', 'pad'):
|
||||||
|
raise ValueError('Attr \'pad_mode\' of \'Conv1dTranspose\' Op passed '
|
||||||
|
+ str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.')
|
||||||
|
self.is_valid = self.pad_mode == 'valid'
|
||||||
|
self.is_same = self.pad_mode == 'same'
|
||||||
|
self.is_pad = self.pad_mode == 'pad'
|
||||||
|
if check_bool(has_bias):
|
||||||
|
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
|
||||||
|
|
||||||
|
# cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel.
|
||||||
|
self.conv2d_transpose = P.Conv2DBackpropInput(out_channel=in_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
mode=1,
|
||||||
|
pad_mode=pad_mode,
|
||||||
|
pad=self.padding,
|
||||||
|
stride=stride,
|
||||||
|
dilation=dilation,
|
||||||
|
group=group)
|
||||||
|
self.bias_add = P.BiasAdd()
|
||||||
|
self.expand_dims = P.ExpandDims()
|
||||||
|
self.squeeze = P.Squeeze(2)
|
||||||
|
|
||||||
|
def set_strategy(self, strategy):
|
||||||
|
self.conv2d_transpose.set_strategy(strategy)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size, padding):
|
||||||
|
"""Calculate the width and height of output."""
|
||||||
|
length = 0
|
||||||
|
filter_size = filter_size + (filter_size - 1) * (dilation_size - 1)
|
||||||
|
if self.is_valid:
|
||||||
|
if filter_size - stride_size > 0:
|
||||||
|
length = input_length * stride_size + filter_size - stride_size
|
||||||
|
else:
|
||||||
|
length = input_length * stride_size
|
||||||
|
elif self.is_same:
|
||||||
|
length = input_length * stride_size
|
||||||
|
elif self.is_pad:
|
||||||
|
length = input_length * stride_size - padding + filter_size - stride_size
|
||||||
|
|
||||||
|
return length
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x_shape = self.shape(x)
|
||||||
|
if len(x_shape) == 3:
|
||||||
|
x = self.expand_dims(x, 2)
|
||||||
|
|
||||||
|
n, _, h, w = self.shape(x)
|
||||||
|
|
||||||
|
h_out = self._deconv_output_length(h, self.kernel_size[0], self.stride[0], self.dilation[0],
|
||||||
|
self.padding[0] + self.padding[1])
|
||||||
|
w_out = self._deconv_output_length(w, self.kernel_size[1], self.stride[1], self.dilation[1],
|
||||||
|
self.padding[2] + self.padding[3])
|
||||||
|
output = self.conv2d_transpose(x, self.weight, (n, self.out_channels, h_out, w_out))
|
||||||
|
if self.has_bias:
|
||||||
|
output = self.bias_add(output, self.bias)
|
||||||
|
|
||||||
|
if len(x_shape) == 3:
|
||||||
|
output = self.squeeze(output)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def extend_repr(self):
|
||||||
|
s = 'input_channels={}, output_channels={}, kernel_size={},' \
|
||||||
|
'stride={}, pad_mode={}, padding={}, dilation={}, ' \
|
||||||
|
'group={}, has_bias={},' \
|
||||||
|
'weight_init={}, bias_init={}'.format(self.in_channels,
|
||||||
|
self.out_channels,
|
||||||
|
self.kernel_size,
|
||||||
|
self.stride,
|
||||||
|
self.pad_mode,
|
||||||
|
self.padding,
|
||||||
|
self.dilation,
|
||||||
|
self.group,
|
||||||
|
self.has_bias,
|
||||||
|
self.weight,
|
||||||
|
self.bias)
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
class DepthwiseConv2d(Cell):
|
class DepthwiseConv2d(Cell):
|
||||||
r"""
|
r"""
|
||||||
2D depthwise convolution layer.
|
2D depthwise convolution layer.
|
||||||
|
|
|
@ -283,6 +283,7 @@ class AvgPool1d(_PoolNd):
|
||||||
self.reduce_mean = P.ReduceMean(keep_dims=True)
|
self.reduce_mean = P.ReduceMean(keep_dims=True)
|
||||||
self.slice = P.Slice()
|
self.slice = P.Slice()
|
||||||
self.expand = P.ExpandDims()
|
self.expand = P.ExpandDims()
|
||||||
|
self.squeeze = P.Squeeze(2)
|
||||||
|
|
||||||
def construct(self, x):
|
def construct(self, x):
|
||||||
_shape_check(self.shape(x))
|
_shape_check(self.shape(x))
|
||||||
|
@ -295,4 +296,5 @@ class AvgPool1d(_PoolNd):
|
||||||
else:
|
else:
|
||||||
x = self.expand(x, 2)
|
x = self.expand(x, 2)
|
||||||
x = self.avg_pool(x)
|
x = self.avg_pool(x)
|
||||||
|
x = self.squeeze(x)
|
||||||
return x
|
return x
|
||||||
|
|
|
@ -393,7 +393,6 @@ class Optimizer(Cell):
|
||||||
current_dynamic_lr = self.gather(self.learning_rate[i], self.global_step, 0)
|
current_dynamic_lr = self.gather(self.learning_rate[i], self.global_step, 0)
|
||||||
lr += (current_dynamic_lr,)
|
lr += (current_dynamic_lr,)
|
||||||
F.control_depend(lr, self.assignadd(self.global_step, 1))
|
F.control_depend(lr, self.assignadd(self.global_step, 1))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
lr = self.learning_rate
|
lr = self.learning_rate
|
||||||
if self.dynamic_lr:
|
if self.dynamic_lr:
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
|
|
||||||
"""grad impl."""
|
"""grad impl."""
|
||||||
from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \
|
from . import grad_array_ops, grad_comm_ops, grad_debug_ops, grad_implementations, \
|
||||||
grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops
|
grad_inner_ops, grad_math_ops, grad_nn_ops, grad_other_ops, grad_quant_ops
|
||||||
from .grad_base import get_bprop_fn
|
from .grad_base import get_bprop_fn
|
||||||
|
|
||||||
__all__ = ['get_bprop_fn']
|
__all__ = ['get_bprop_fn']
|
||||||
|
|
|
@ -211,6 +211,25 @@ def get_bprop_embedding_lookup(self):
|
||||||
return bprop_sparse
|
return bprop_sparse
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.EmbeddingLookup)
|
||||||
|
def get_bprop_embedding_look_up(self):
|
||||||
|
"""Generate bprop for EmbeddingLookup"""
|
||||||
|
sub_op = P.Sub()
|
||||||
|
reshape_op = P.Reshape()
|
||||||
|
def bprop(x, indices, offset, out, dout):
|
||||||
|
x_shp = shape_op(x)
|
||||||
|
new_indices = sub_op(indices, offset)
|
||||||
|
# Reshape the 'new_indices'
|
||||||
|
new_indices_shape_changed = (size_op(new_indices),)
|
||||||
|
new_indices = reshape_op(new_indices, new_indices_shape_changed)
|
||||||
|
actual_dout_shape_changed = new_indices_shape_changed
|
||||||
|
if len(x_shp) > 1:
|
||||||
|
actual_dout_shape_changed += x_shp[1:]
|
||||||
|
actual_dout = reshape_op(dout, actual_dout_shape_changed)
|
||||||
|
return (new_indices, actual_dout, x_shp), zeros_like(indices), zeros_like(offset)
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(P.Transpose)
|
@bprop_getters.register(P.Transpose)
|
||||||
def get_bprop_transpose(self):
|
def get_bprop_transpose(self):
|
||||||
"""Generate bprop for Transpose"""
|
"""Generate bprop for Transpose"""
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""array_ops"""
|
||||||
|
|
||||||
|
from .. import operations as P
|
||||||
|
from ..operations import _grad_ops as G
|
||||||
|
from ..operations import _inner_ops as inner
|
||||||
|
from ..composite.multitype_ops.zeros_like_impl import zeros_like
|
||||||
|
from .grad_base import bprop_getters
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(inner.StridedSliceAICPU)
|
||||||
|
def get_bprop_strided_slice_aicpu(self):
|
||||||
|
"""Generate bprop for StridedSlice"""
|
||||||
|
shape_op = P.Shape()
|
||||||
|
input_grad = G.StridedSliceGradAICPU(self.begin_mask,
|
||||||
|
self.end_mask,
|
||||||
|
self.ellipsis_mask,
|
||||||
|
self.new_axis_mask,
|
||||||
|
self.shrink_axis_mask)
|
||||||
|
|
||||||
|
def bprop(x, begin, end, strides, out, dout):
|
||||||
|
dx = input_grad(dout, shape_op(x), begin, end, strides)
|
||||||
|
return dx, zeros_like(begin), zeros_like(end), zeros_like(strides)
|
||||||
|
|
||||||
|
return bprop
|
|
@ -673,7 +673,7 @@ def get_bprop_mirror_pad(self):
|
||||||
mirror_pad_grad = G.MirrorPadGrad(self.mode)
|
mirror_pad_grad = G.MirrorPadGrad(self.mode)
|
||||||
|
|
||||||
def bprop(x, paddings, out, dout):
|
def bprop(x, paddings, out, dout):
|
||||||
dx = mirror_pad_grad(dout, paddings, x)
|
dx = mirror_pad_grad(dout, paddings)
|
||||||
return (dx, zeros_like(paddings))
|
return (dx, zeros_like(paddings))
|
||||||
|
|
||||||
return bprop
|
return bprop
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
|
|
||||||
"""aicpu ops"""
|
"""aicpu ops"""
|
||||||
from .init_data_set_queue import _init_data_set_queue_aicpu
|
from .init_data_set_queue import _init_data_set_queue_aicpu
|
||||||
|
from .embedding_lookup import _embedding_lookup_aicpu
|
||||||
from .dropout_genmask import _dropout_genmask_aicpu
|
from .dropout_genmask import _dropout_genmask_aicpu
|
||||||
from .get_next import _get_next_aicpu
|
from .get_next import _get_next_aicpu
|
||||||
from .print_tensor import _print_aicpu
|
from .print_tensor import _print_aicpu
|
||||||
|
@ -25,10 +26,20 @@ from .squeeze import _squeeze_aicpu
|
||||||
from .expand_dims import _expand_dims_aicpu
|
from .expand_dims import _expand_dims_aicpu
|
||||||
from .random_choice_with_mask import _random_choice_with_mask_aicpu
|
from .random_choice_with_mask import _random_choice_with_mask_aicpu
|
||||||
from .pack import _pack_aicpu
|
from .pack import _pack_aicpu
|
||||||
from .normal import _normal_aicpu
|
|
||||||
from .ctcloss import _ctcloss_aicpu
|
from .ctcloss import _ctcloss_aicpu
|
||||||
from .reverse_sequence import _reverse_sequence_aicpu
|
from .reverse_sequence import _reverse_sequence_aicpu
|
||||||
from .crop_and_resize import _crop_and_resize_aicpu
|
from .crop_and_resize import _crop_and_resize_aicpu
|
||||||
from .end_of_sequence import _end_of_sequence_aicpu
|
|
||||||
from .rnnt_loss import _rnnt_loss_aicpu
|
from .rnnt_loss import _rnnt_loss_aicpu
|
||||||
from .random_categorical import _random_categorical_aicpu
|
from .random_categorical import _random_categorical_aicpu
|
||||||
|
from .cast import _cast_aicpu
|
||||||
|
from .mirror_pad import _mirror_pad_aicpu
|
||||||
|
from .mirror_pad_grad import _mirror_pad_grad_aicpu
|
||||||
|
from .standard_normal import _standard_normal_aicpu
|
||||||
|
from .gamma import _gamma_aicpu
|
||||||
|
from .poisson import _poisson_aicpu
|
||||||
|
from .uniform_int import _uniform_int_aicpu
|
||||||
|
from .uniform_real import _uniform_real_aicpu
|
||||||
|
from .laplace import _laplace_aicpu
|
||||||
|
from .strided_slice import _strided_slice_aicpu
|
||||||
|
from .strided_slice_grad import _strided_slice_grad_aicpu
|
||||||
|
from .end_of_sequence import _end_of_sequence_aicpu
|
||||||
|
|
|
@ -0,0 +1,172 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Cast op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
cast_op_info = AiCPURegOp("Cast") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(cast_op_info)
|
||||||
|
def _cast_aicpu():
|
||||||
|
"""Cast AiCPU register"""
|
||||||
|
return
|
|
@ -0,0 +1,102 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""EmbeddingLookup op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
embeddingLookup_op_info = AiCPURegOp("EmbeddingLookup") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "params", "required") \
|
||||||
|
.input(1, "indices", "required") \
|
||||||
|
.input(2, "offset", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.BOOL_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(embeddingLookup_op_info)
|
||||||
|
def _embedding_lookup_aicpu():
|
||||||
|
"""EmbeddingLookup AiCPU register"""
|
||||||
|
return
|
|
@ -0,0 +1,33 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""RandomGamma op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
gamma_op_info = AiCPURegOp("Gamma") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "shape", "required") \
|
||||||
|
.input(1, "alpha", "required") \
|
||||||
|
.input(2, "beta", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.attr("seed", "int") \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(gamma_op_info)
|
||||||
|
def _gamma_aicpu():
|
||||||
|
"""RandomGamma AiCPU register"""
|
||||||
|
return
|
|
@ -13,21 +13,21 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
"""Normal op"""
|
"""RandomLaplace op"""
|
||||||
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
normal_op_info = AiCPURegOp("Normal") \
|
laplace_op_info = AiCPURegOp("Laplace") \
|
||||||
.fusion_type("OPAQUE") \
|
.fusion_type("OPAQUE") \
|
||||||
.input(0, "shape", "required") \
|
.input(0, "shape", "required") \
|
||||||
.input(1, "mean", "required") \
|
.input(1, "mean", "required") \
|
||||||
.input(2, "stddev", "required") \
|
.input(2, "lambda_param", "required") \
|
||||||
.output(0, "y", "required") \
|
.output(0, "output", "required") \
|
||||||
.attr("seed", "int") \
|
.attr("seed", "int") \
|
||||||
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
.dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
|
.dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
@op_info_register(normal_op_info)
|
@op_info_register(laplace_op_info)
|
||||||
def _normal_aicpu():
|
def _laplace_aicpu():
|
||||||
"""Normal AiCPU register"""
|
"""RandomLaplace AiCPU register"""
|
||||||
return
|
return
|
|
@ -0,0 +1,52 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""MirrorPad op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
mirror_pad_op_info = AiCPURegOp("MirrorPad") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.input(1, "paddings", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.attr("mode", "str") \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(mirror_pad_op_info)
|
||||||
|
def _mirror_pad_aicpu():
|
||||||
|
"""MirrorPad AiCPU register"""
|
||||||
|
return
|
|
@ -0,0 +1,52 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""MirrorPadGrad op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
mirror_pad_grad_op_info = AiCPURegOp("MirrorPadGrad") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.input(1, "paddings", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.attr("mode", "str") \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.F64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(mirror_pad_grad_op_info)
|
||||||
|
def _mirror_pad_grad_aicpu():
|
||||||
|
"""MirrorPadGrad AiCPU register"""
|
||||||
|
return
|
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""RandomPoisson op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
poisson_op_info = AiCPURegOp("Poisson") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "shape", "required") \
|
||||||
|
.input(1, "mean", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.attr("seed", "int") \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(poisson_op_info)
|
||||||
|
def _poisson_aicpu():
|
||||||
|
"""RandomPoisson AiCPU register"""
|
||||||
|
return
|
|
@ -0,0 +1,32 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""RandomNormal op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
normal_op_info = AiCPURegOp("StandardNormal") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "shape", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.attr("seed", "int") \
|
||||||
|
.attr("seed2", "int") \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I32_NCHW, DataType.F32_NCHW) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(normal_op_info)
|
||||||
|
def _standard_normal_aicpu():
|
||||||
|
"""RandomNormal AiCPU register"""
|
||||||
|
return
|
|
@ -0,0 +1,41 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""StridedSlice op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
strided_slice_op_info = AiCPURegOp("StridedSliceAICPU") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "input", "required") \
|
||||||
|
.input(1, "begin", "required") \
|
||||||
|
.input(2, "end", "required") \
|
||||||
|
.input(3, "stride", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.attr("begin_mask", "int") \
|
||||||
|
.attr("end_mask", "int") \
|
||||||
|
.attr("ellipsis_mask", "int") \
|
||||||
|
.attr("new_axis_mask", "int") \
|
||||||
|
.attr("shrink_axis_mask", "int") \
|
||||||
|
.dtype_format(DataType.F32_Default,
|
||||||
|
DataType.I32_Default,
|
||||||
|
DataType.I32_Default,
|
||||||
|
DataType.I32_Default,
|
||||||
|
DataType.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(strided_slice_op_info)
|
||||||
|
def _strided_slice_aicpu():
|
||||||
|
"""StridedSlice AiCPU register"""
|
||||||
|
return
|
|
@ -0,0 +1,43 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""StridedSliceGrad op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
strided_slice_grad_op_info = AiCPURegOp("StridedSliceGradAICPU") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "dy", "required") \
|
||||||
|
.input(1, "shape", "required") \
|
||||||
|
.input(2, "begin", "required") \
|
||||||
|
.input(3, "end", "required") \
|
||||||
|
.input(4, "stride", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.attr("begin_mask", "int") \
|
||||||
|
.attr("end_mask", "int") \
|
||||||
|
.attr("ellipsis_mask", "int") \
|
||||||
|
.attr("new_axis_mask", "int") \
|
||||||
|
.attr("shrink_axis_mask", "int") \
|
||||||
|
.dtype_format(DataType.F32_Default,
|
||||||
|
DataType.I32_Default,
|
||||||
|
DataType.I32_Default,
|
||||||
|
DataType.I32_Default,
|
||||||
|
DataType.I32_Default,
|
||||||
|
DataType.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(strided_slice_grad_op_info)
|
||||||
|
def _strided_slice_grad_aicpu():
|
||||||
|
"""StridedSliceGrad AiCPU register"""
|
||||||
|
return
|
|
@ -0,0 +1,33 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""RandomUniformInt op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
uniform_int_op_info = AiCPURegOp("UniformInt") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "shape", "required") \
|
||||||
|
.input(1, "a", "required") \
|
||||||
|
.input(2, "b", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.attr("seed", "int") \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(uniform_int_op_info)
|
||||||
|
def _uniform_int_aicpu():
|
||||||
|
"""RandomUniformInt AiCPU register"""
|
||||||
|
return
|
|
@ -0,0 +1,33 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""RandomUniformReal op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
uniform_real_op_info = AiCPURegOp("UniformReal") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "shape", "required") \
|
||||||
|
.input(1, "a", "required") \
|
||||||
|
.input(2, "b", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.attr("seed", "int") \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(uniform_real_op_info)
|
||||||
|
def _uniform_real_aicpu():
|
||||||
|
"""RandomUniformReal AiCPU register"""
|
||||||
|
return
|
|
@ -288,5 +288,6 @@ from .scatter_div import _scatter_div_tbe
|
||||||
from .mod import _mod_tbe
|
from .mod import _mod_tbe
|
||||||
from .max_pool_grad_grad import _max_pool_grad_grad_tbe
|
from .max_pool_grad_grad import _max_pool_grad_grad_tbe
|
||||||
from .max_pool_grad_grad_with_argmax import _max_pool_grad_grad_with_argmax_tbe
|
from .max_pool_grad_grad_with_argmax import _max_pool_grad_grad_with_argmax_tbe
|
||||||
|
from .tensor_move import _tensor_move_tbe
|
||||||
from .population_count import _population_count_tbe
|
from .population_count import _population_count_tbe
|
||||||
from .parallel_concat import _parallel_concat_tbe
|
from .parallel_concat import _parallel_concat_tbe
|
||||||
|
|
|
@ -28,8 +28,11 @@ avg_pool_op_info = TBERegOp("AvgPool") \
|
||||||
.attr("padding", "required", "str", "all") \
|
.attr("padding", "required", "str", "all") \
|
||||||
.attr("data_format", "optional", "str", "all") \
|
.attr("data_format", "optional", "str", "all") \
|
||||||
.input(0, "x", False, "required", "all") \
|
.input(0, "x", False, "required", "all") \
|
||||||
|
.input(1, "filter", False, "optional", "all") \
|
||||||
|
.input(2, "bias", False, "optional", "all") \
|
||||||
.output(0, "y", False, "required", "all") \
|
.output(0, "y", False, "required", "all") \
|
||||||
.dtype_format(DataType.F16_5HD, DataType.F16_5HD) \
|
.dtype_format(DataType.F16_5HD, DataType.F16_FracZ, DataType.F16_Default, DataType.F16_5HD) \
|
||||||
|
.dtype_format(DataType.I8_5HD, DataType.I8_C1HWNCoC0, DataType.I32_Default, DataType.I32_5HD) \
|
||||||
.get_op_info()
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,41 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""TensorMove op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||||
|
|
||||||
|
tensor_move_op_info = TBERegOp("TensorMove") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.async_flag(False) \
|
||||||
|
.binfile_name("tensor_move.so") \
|
||||||
|
.compute_cost(10) \
|
||||||
|
.kernel_name("tensor_move") \
|
||||||
|
.partial_flag(True) \
|
||||||
|
.input(0, "x", False, "required", "all") \
|
||||||
|
.output(0, "y", False, "required", "all") \
|
||||||
|
.op_pattern("formatAgnostic") \
|
||||||
|
.dtype_format(DataType.I32_None, DataType.I32_None) \
|
||||||
|
.dtype_format(DataType.F16_None, DataType.F16_None) \
|
||||||
|
.dtype_format(DataType.F32_None, DataType.F32_None) \
|
||||||
|
.dtype_format(DataType.I8_None, DataType.I8_None) \
|
||||||
|
.dtype_format(DataType.U8_None, DataType.U8_None) \
|
||||||
|
.dtype_format(DataType.BOOL_None, DataType.BOOL_None) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(tensor_move_op_info)
|
||||||
|
def _tensor_move_tbe():
|
||||||
|
"""TensorMove TBE register"""
|
||||||
|
return
|
|
@ -27,8 +27,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
|
||||||
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
|
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
|
||||||
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
|
SameTypeShape, ScatterAdd, ScatterSub, ScatterMul, ScatterDiv, ScatterMax, ScatterMin,
|
||||||
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
ScatterUpdate, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
|
||||||
Shape, Size, Slice, Split, TransShape,
|
Shape, Size, Slice, Split, TransShape, ParallelConcat,
|
||||||
ParallelConcat,
|
|
||||||
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
|
Squeeze, StridedSlice, Tile, TensorScatterUpdate,
|
||||||
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
|
Transpose, TruncatedNormal, TupleToArray, UnsortedSegmentMin, UnsortedSegmentProd,
|
||||||
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
UnsortedSegmentSum, SpaceToDepth, DepthToSpace, SpaceToBatch, BatchToSpace,
|
||||||
|
@ -55,7 +54,8 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
|
||||||
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
|
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e, TruncateDiv, TruncateMod,
|
||||||
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan)
|
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps, Tan)
|
||||||
|
|
||||||
from .random_ops import (RandomChoiceWithMask, Normal, RandomCategorical)
|
from .random_ops import (RandomChoiceWithMask, Normal, Gamma, Poisson, UniformInt, UniformReal,
|
||||||
|
RandomCategorical, Laplace)
|
||||||
from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm,
|
from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm,
|
||||||
BiasAdd, Conv2D,
|
BiasAdd, Conv2D,
|
||||||
DepthwiseConv2dNative,
|
DepthwiseConv2dNative,
|
||||||
|
@ -69,8 +69,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
|
||||||
MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
|
MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
|
||||||
ResizeBilinear, Sigmoid,
|
ResizeBilinear, Sigmoid,
|
||||||
SigmoidCrossEntropyWithLogits,
|
SigmoidCrossEntropyWithLogits,
|
||||||
SmoothL1Loss, Softmax, Softsign, Softplus, LRN,
|
SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss,
|
||||||
RNNTLoss,
|
|
||||||
SoftmaxCrossEntropyWithLogits, ROIAlign,
|
SoftmaxCrossEntropyWithLogits, ROIAlign,
|
||||||
SparseSoftmaxCrossEntropyWithLogits, Tanh,
|
SparseSoftmaxCrossEntropyWithLogits, Tanh,
|
||||||
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
|
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
|
||||||
|
@ -78,6 +77,8 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
|
||||||
ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2,
|
ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2,
|
||||||
ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
|
ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
|
||||||
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK)
|
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK)
|
||||||
|
from . import _quant_ops
|
||||||
|
from ._quant_ops import *
|
||||||
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
|
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, PopulationCount,
|
||||||
CheckValid, MakeRefKey, Partial, Depend, CheckBprop, Push, Pull)
|
CheckValid, MakeRefKey, Partial, Depend, CheckBprop, Push, Pull)
|
||||||
from .thor_ops import *
|
from .thor_ops import *
|
||||||
|
@ -135,6 +136,7 @@ __all__ = [
|
||||||
'OneHot',
|
'OneHot',
|
||||||
'GatherV2',
|
'GatherV2',
|
||||||
'SparseGatherV2',
|
'SparseGatherV2',
|
||||||
|
'EmbeddingLookup',
|
||||||
'Concat',
|
'Concat',
|
||||||
'Pack',
|
'Pack',
|
||||||
'Unpack',
|
'Unpack',
|
||||||
|
@ -172,6 +174,11 @@ __all__ = [
|
||||||
'Tanh',
|
'Tanh',
|
||||||
'RandomChoiceWithMask',
|
'RandomChoiceWithMask',
|
||||||
'Normal',
|
'Normal',
|
||||||
|
'Gamma',
|
||||||
|
'Poisson',
|
||||||
|
'UniformInt',
|
||||||
|
'UniformReal',
|
||||||
|
'Laplace',
|
||||||
'RandomCategorical',
|
'RandomCategorical',
|
||||||
'ResizeBilinear',
|
'ResizeBilinear',
|
||||||
'ScalarSummary',
|
'ScalarSummary',
|
||||||
|
@ -320,6 +327,7 @@ __all__ = [
|
||||||
"ApplyCenteredRMSProp",
|
"ApplyCenteredRMSProp",
|
||||||
"SpaceToBatchND",
|
"SpaceToBatchND",
|
||||||
"BatchToSpaceND",
|
"BatchToSpaceND",
|
||||||
|
"ReverseSequence",
|
||||||
"SquareSumAll",
|
"SquareSumAll",
|
||||||
"BitwiseAnd",
|
"BitwiseAnd",
|
||||||
"BitwiseOr",
|
"BitwiseOr",
|
||||||
|
@ -335,6 +343,7 @@ __all__ = [
|
||||||
"ApproximateEqual",
|
"ApproximateEqual",
|
||||||
"InplaceUpdate",
|
"InplaceUpdate",
|
||||||
"InTopK",
|
"InTopK",
|
||||||
|
"CropAndResize",
|
||||||
"LRN",
|
"LRN",
|
||||||
"Mod",
|
"Mod",
|
||||||
"PopulationCount",
|
"PopulationCount",
|
||||||
|
|
|
@ -1204,6 +1204,54 @@ class StridedSliceGrad(PrimitiveWithInfer):
|
||||||
'value': None}
|
'value': None}
|
||||||
|
|
||||||
|
|
||||||
|
class StridedSliceGradAICPU(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Performs grad of StridedSlice operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
begin_mask (int): Start indexing the slice. Default: 0.
|
||||||
|
end_mask (int): End indexing the slice. Default: 0.
|
||||||
|
ellipsis_mask (int): An int32 mask. Default: 0.
|
||||||
|
new_axis_mask (int): An int32 mask. Default: 0.
|
||||||
|
shrink_axis_mask (int): An int32 mask. Default: 0.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor, has the same shape of input.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self,
|
||||||
|
begin_mask=0,
|
||||||
|
end_mask=0,
|
||||||
|
ellipsis_mask=0,
|
||||||
|
new_axis_mask=0,
|
||||||
|
shrink_axis_mask=0):
|
||||||
|
"""init StrideSliceGrad"""
|
||||||
|
validator.check_value_type('begin_mask', begin_mask, [int], self.name)
|
||||||
|
validator.check_value_type('end_mask', end_mask, [int], self.name)
|
||||||
|
validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
|
||||||
|
validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name)
|
||||||
|
validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name)
|
||||||
|
self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])
|
||||||
|
|
||||||
|
def __infer__(self, dy, shapex, begin, end, strides):
|
||||||
|
args = {"dy": dy['dtype']}
|
||||||
|
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
||||||
|
|
||||||
|
for idx, item in enumerate(shapex['value']):
|
||||||
|
validator.check_value_type("shapex[%d]" % idx, item, [int], self.name)
|
||||||
|
for idx, item in enumerate(begin['value']):
|
||||||
|
validator.check_value_type("begin[%d]" % idx, item, [int], self.name)
|
||||||
|
for idx, item in enumerate(end['value']):
|
||||||
|
validator.check_value_type("end[%d]" % idx, item, [int], self.name)
|
||||||
|
for idx, item in enumerate(strides['value']):
|
||||||
|
validator.check_value_type("strides[%d]" % idx, item, [int], self.name)
|
||||||
|
|
||||||
|
return {'shape': shapex['value'],
|
||||||
|
'dtype': dy['dtype'],
|
||||||
|
'value': None}
|
||||||
|
|
||||||
|
|
||||||
class SoftplusGrad(PrimitiveWithInfer):
|
class SoftplusGrad(PrimitiveWithInfer):
|
||||||
"""Computes gradient for the Log Softmax activation."""
|
"""Computes gradient for the Log Softmax activation."""
|
||||||
|
|
||||||
|
@ -1246,11 +1294,20 @@ class MirrorPadGrad(PrimitiveWithInfer):
|
||||||
validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name)
|
validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name)
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
|
|
||||||
def __infer__(self, dout, paddings, x):
|
def __infer__(self, dout, paddings):
|
||||||
validator.check_subclass("dout", dout['dtype'], mstype.tensor, self.name)
|
validator.check_subclass("dout", dout['dtype'], mstype.tensor, self.name)
|
||||||
validator.check_subclass("paddings", paddings['dtype'], mstype.tensor, self.name)
|
validator.check_subclass("paddings", paddings['dtype'], mstype.tensor, self.name)
|
||||||
validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name)
|
validator.check("paddings rank", len(paddings['shape']), "expected", 2, Rel.EQ, self.name)
|
||||||
return {'shape': x['shape'],
|
validator.check("paddings dim_1", paddings['shape'][1], "expected", 2, Rel.EQ, self.name)
|
||||||
|
|
||||||
|
if paddings['value'] is None:
|
||||||
|
raise ValueError(f"For {self.name}, paddings must be const.")
|
||||||
|
paddings_value = paddings['value'].asnumpy()
|
||||||
|
y_shape = ()
|
||||||
|
dout_shape = dout['shape']
|
||||||
|
for i, val in enumerate(dout_shape):
|
||||||
|
y_shape += (val - paddings_value[i][0] - paddings_value[i][1],)
|
||||||
|
return {'shape': y_shape,
|
||||||
'dtype': dout['dtype'],
|
'dtype': dout['dtype'],
|
||||||
'value': None}
|
'value': None}
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,137 @@ from ..._c_expression import signature_dtype as sig_dtype
|
||||||
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
||||||
|
|
||||||
|
|
||||||
|
class StridedSliceAICPU(PrimitiveWithInfer):
|
||||||
|
r"""
|
||||||
|
|
||||||
|
Extracts a strided slice of a tensor.
|
||||||
|
|
||||||
|
Given an input tensor, this operation inserts a dimension of length 1 at the dimension.
|
||||||
|
This operation extracts a fragment of size (end-begin)/stride from the given
|
||||||
|
'input_tensor'. Starting from the position specified by the begin, the fragment
|
||||||
|
continues adding stride to the index until all dimensions are not less than end.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The stride may be negative value, which causes reverse slicing.
|
||||||
|
The shape of `begin`, `end` and `strides` should be the same.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
begin_mask (int): Starting index of the slice. Default: 0.
|
||||||
|
end_mask (int): Ending index of the slice. Default: 0.
|
||||||
|
ellipsis_mask (int): An int mask. Default: 0.
|
||||||
|
new_axis_mask (int): An int mask. Default: 0.
|
||||||
|
shrink_axis_mask (int): An int mask. Default: 0.
|
||||||
|
Currently all the masks are not in used. Use default 0 only.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **input_x** (Tensor) - The input Tensor.
|
||||||
|
- **begin** (tuple[int]) - A tuple which represents the location where to start. Only
|
||||||
|
constant value is allowed.
|
||||||
|
- **end** (tuple[int]) - A tuple or which represents the maximum location where to stop.
|
||||||
|
Only constant value is allowed.
|
||||||
|
- **strides** (tuple[int]) - A tuple which represents the stride continuously added
|
||||||
|
before reach the maximum location. Only constant value is allowed.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor.
|
||||||
|
Explain with the following example.
|
||||||
|
- In the 0th dim, begin is 1, end is 2, and strides is 1,
|
||||||
|
because :math:`1+1=2\geq2`, the interval is :math:`[1,2)`.
|
||||||
|
Thus, return the element with :math:`index = 1` in 0th dim, i.e., [[3, 3, 3], [4, 4, 4]].
|
||||||
|
- In the 1st dim, similarly, the interval is :math:`[0,1)`.
|
||||||
|
Based on the return value of the 0th dim, return the element with :math:`index = 0`,
|
||||||
|
i.e., [3, 3, 3].
|
||||||
|
- In the 2nd dim, similarly, the interval is :math:`[0,3)`.
|
||||||
|
Based on the return value of the 1st dim, return the element with :math:`index = 0,1,2`,
|
||||||
|
i.e., [3, 3, 3].
|
||||||
|
- Finally, the output is [3, 3, 3].
|
||||||
|
|
||||||
|
Examples
|
||||||
|
>>> input_x = Tensor([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]],
|
||||||
|
>>> [[5, 5, 5], [6, 6, 6]]], mindspore.float32)
|
||||||
|
>>> slice = P.StridedSliceAICPU()
|
||||||
|
>>> output = slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 2))
|
||||||
|
>>> output.shape
|
||||||
|
(1, 1, 2)
|
||||||
|
>>> output
|
||||||
|
[[[3, 3]]]
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self,
|
||||||
|
begin_mask=0,
|
||||||
|
end_mask=0,
|
||||||
|
ellipsis_mask=0,
|
||||||
|
new_axis_mask=0,
|
||||||
|
shrink_axis_mask=0):
|
||||||
|
"""init StrideSlice"""
|
||||||
|
self.init_prim_io_names(inputs=['x', 'begin', 'end', 'strides'], outputs=['output'])
|
||||||
|
validator.check_value_type('begin_mask', begin_mask, [int], self.name)
|
||||||
|
validator.check_value_type('end_mask', end_mask, [int], self.name)
|
||||||
|
validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name)
|
||||||
|
validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name)
|
||||||
|
validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name)
|
||||||
|
|
||||||
|
def __infer__(self, x, begin, end, strides):
|
||||||
|
begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
|
||||||
|
validator.check_value_type("begin", begin_v, [tuple], self.name)
|
||||||
|
validator.check_value_type("end", end_v, [tuple], self.name)
|
||||||
|
validator.check_value_type("strides", strides_v, [tuple], self.name)
|
||||||
|
|
||||||
|
x_shape = x['shape']
|
||||||
|
x_shp_len = len(x_shape)
|
||||||
|
if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len:
|
||||||
|
raise ValueError(f"For \'{self.name}\' the length of begin index{begin_v}, end index{end_v} and "
|
||||||
|
f"strides{strides_v} must be equal to the dims({x_shp_len}) of input.")
|
||||||
|
|
||||||
|
ret_shape = []
|
||||||
|
append_dimensions = []
|
||||||
|
shrink_pos = bin(self.shrink_axis_mask)[::-1]
|
||||||
|
new_pos = bin(self.new_axis_mask)[::-1]
|
||||||
|
for i in range(x_shp_len):
|
||||||
|
# After the integer is converted to binary, it is a str and the first two chars are the flag char '0b'
|
||||||
|
if i < (len(new_pos) - 2) and new_pos[i] == '1':
|
||||||
|
ret_shape.append(1)
|
||||||
|
append_dimensions.append(x_shape[x_shp_len - 1 - len(append_dimensions)])
|
||||||
|
continue
|
||||||
|
if i < (len(shrink_pos) - 2) and shrink_pos[i] == '1':
|
||||||
|
validator.check_integer(f'begin[{i}]', begin_v[i], -x_shape[i], Rel.GE, self.name)
|
||||||
|
validator.check_integer(f'begin[{i}]', begin_v[i], x_shape[i], Rel.LT, self.name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
begin_idx = begin_v[i]
|
||||||
|
end_idx = end_v[i]
|
||||||
|
strides_idx = strides_v[i]
|
||||||
|
if self.begin_mask:
|
||||||
|
begin_idx = 0
|
||||||
|
if self.end_mask:
|
||||||
|
end_idx = x_shape[i]
|
||||||
|
validator.check_integer(f'begin[{i}]', begin_idx, x_shape[i], Rel.LE, self.name)
|
||||||
|
validator.check_integer(f'end[{i}]', end_idx, x_shape[i], Rel.LE, self.name)
|
||||||
|
validator.check_integer(f'strides[{i}]', strides_idx, 0, Rel.NE, self.name)
|
||||||
|
if strides_idx > 0:
|
||||||
|
# If sliced forward , end_idx >= begin_idx
|
||||||
|
validator.check(f'begin[{i}]', begin_idx, f'end[{i}]', end_idx, Rel.LE)
|
||||||
|
if begin_idx < 0 < end_idx:
|
||||||
|
# Turn negative begin_idx into positive values
|
||||||
|
begin_idx = x_shape[i] + begin_idx
|
||||||
|
num_elems = (end_idx - begin_idx + strides_idx - 1) // strides_idx
|
||||||
|
else:
|
||||||
|
# If sliced backwards, end_idx <= begin_idx
|
||||||
|
validator.check(f'begin[{i}]', begin_idx, f'end[{i}]', end_idx, Rel.GE)
|
||||||
|
if end_idx < 0 < begin_idx:
|
||||||
|
# Turn negative end_idx into positive values
|
||||||
|
end_idx = x_shape[i] + end_idx
|
||||||
|
num_elems = (end_idx - begin_idx + strides_idx + 1) // strides_idx
|
||||||
|
|
||||||
|
ret_shape.append(num_elems)
|
||||||
|
if append_dimensions:
|
||||||
|
ret_shape += append_dimensions[::-1]
|
||||||
|
return {'shape': ret_shape,
|
||||||
|
'dtype': x['dtype'],
|
||||||
|
'value': None}
|
||||||
|
|
||||||
|
|
||||||
class ExtractImagePatches(PrimitiveWithInfer):
|
class ExtractImagePatches(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Extract patches from images.
|
Extract patches from images.
|
||||||
|
|
|
@ -780,7 +780,9 @@ class Conv2D(PrimitiveWithInfer):
|
||||||
mode (int): 0 Math convolutiuon, 1 cross-correlation convolution ,
|
mode (int): 0 Math convolutiuon, 1 cross-correlation convolution ,
|
||||||
2 deconvolution, 3 depthwise convolution. Default: 1.
|
2 deconvolution, 3 depthwise convolution. Default: 1.
|
||||||
pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid".
|
pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid".
|
||||||
pad (int): The pad value to fill. Default: 0.
|
pad (Union(int, tuple[int])): The pad value to fill. Default: 0. If `pad` is one integer, the padding of
|
||||||
|
top, bottom, left and right is same, equal to pad. If `pad` is tuple with four integer, the padding
|
||||||
|
of top, bottom, left and right equal to pad[0], pad[1], pad[2], pad[3] with corresponding.
|
||||||
stride (Union(int, tuple[int])): The stride to apply conv filter. Default: 1.
|
stride (Union(int, tuple[int])): The stride to apply conv filter. Default: 1.
|
||||||
dilation (Union(int, tuple[int])): Specify the space to use between kernel elements. Default: 1.
|
dilation (Union(int, tuple[int])): Specify the space to use between kernel elements. Default: 1.
|
||||||
group (int): Split input into groups. Default: 1.
|
group (int): Split input into groups. Default: 1.
|
||||||
|
@ -820,11 +822,19 @@ class Conv2D(PrimitiveWithInfer):
|
||||||
self.add_prim_attr('stride', self.stride)
|
self.add_prim_attr('stride', self.stride)
|
||||||
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
|
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
|
||||||
self.add_prim_attr('dilation', self.dilation)
|
self.add_prim_attr('dilation', self.dilation)
|
||||||
validator.check_value_type('pad', pad, (int,), self.name)
|
validator.check_value_type('pad', pad, (int, tuple), self.name)
|
||||||
|
if isinstance(pad, int):
|
||||||
|
pad = (pad,) * 4
|
||||||
|
else:
|
||||||
|
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
|
||||||
|
self.padding = pad
|
||||||
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
|
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
|
||||||
self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name)
|
|
||||||
|
if pad_mode != 'pad' and pad != (0, 0, 0, 0):
|
||||||
|
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
|
||||||
if self.pad_mode == 'pad':
|
if self.pad_mode == 'pad':
|
||||||
validator.check_integer('pad', self.pad, 0, Rel.GE, self.name)
|
for item in pad:
|
||||||
|
validator.check_integer('pad item', item, 0, Rel.GE, self.name)
|
||||||
|
|
||||||
self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
|
self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
|
||||||
self.add_prim_attr('data_format', "NCHW")
|
self.add_prim_attr('data_format', "NCHW")
|
||||||
|
@ -862,11 +872,11 @@ class Conv2D(PrimitiveWithInfer):
|
||||||
pad_left = math.floor(pad_needed_w / 2)
|
pad_left = math.floor(pad_needed_w / 2)
|
||||||
pad_right = pad_needed_w - pad_left
|
pad_right = pad_needed_w - pad_left
|
||||||
elif self.pad_mode == 'pad':
|
elif self.pad_mode == 'pad':
|
||||||
pad_top, pad_bottom, pad_left, pad_right = self.pad, self.pad, self.pad, self.pad
|
pad_top, pad_bottom, pad_left, pad_right = self.padding
|
||||||
|
|
||||||
h_out = 1 + (x_shape[2] + 2 * self.pad - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) \
|
h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) \
|
||||||
/ stride_h
|
/ stride_h
|
||||||
w_out = 1 + (x_shape[3] + 2 * self.pad - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) \
|
w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) \
|
||||||
/ stride_w
|
/ stride_w
|
||||||
h_out = math.floor(h_out)
|
h_out = math.floor(h_out)
|
||||||
w_out = math.floor(w_out)
|
w_out = math.floor(w_out)
|
||||||
|
@ -1279,7 +1289,9 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
|
||||||
out_channel (int): The dimensionality of the output space.
|
out_channel (int): The dimensionality of the output space.
|
||||||
kernel_size (Union[int, tuple[int]]): The size of the convolution window.
|
kernel_size (Union[int, tuple[int]]): The size of the convolution window.
|
||||||
pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid".
|
pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid".
|
||||||
pad (int): The pad value to fill. Default: 0.
|
pad (Union[int, tuple[int]]): The pad value to fill. Default: 0. If `pad` is one integer, the padding of
|
||||||
|
top, bottom, left and right is same, equal to pad. If `pad` is tuple with four integer, the padding
|
||||||
|
of top, bottom, left and right equal to pad[0], pad[1], pad[2], pad[3] with corresponding.
|
||||||
mode (int): 0 Math convolutiuon, 1 cross-correlation convolution ,
|
mode (int): 0 Math convolutiuon, 1 cross-correlation convolution ,
|
||||||
2 deconvolution, 3 depthwise convolution. Default: 1.
|
2 deconvolution, 3 depthwise convolution. Default: 1.
|
||||||
stride (Union[int. tuple[int]]): The stride to apply conv filter. Default: 1.
|
stride (Union[int. tuple[int]]): The stride to apply conv filter. Default: 1.
|
||||||
|
@ -1316,9 +1328,21 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
|
||||||
self.add_prim_attr('stride', self.stride)
|
self.add_prim_attr('stride', self.stride)
|
||||||
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
|
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
|
||||||
self.add_prim_attr('dilation', self.dilation)
|
self.add_prim_attr('dilation', self.dilation)
|
||||||
validator.check_value_type('pad', pad, (int,), self.name)
|
|
||||||
|
validator.check_value_type('pad', pad, (int, tuple), self.name)
|
||||||
|
if isinstance(pad, int):
|
||||||
|
pad = (pad,) * 4
|
||||||
|
self.padding = pad
|
||||||
|
else:
|
||||||
|
validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name)
|
||||||
|
|
||||||
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
|
self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name)
|
||||||
self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name)
|
if pad_mode != 'pad' and pad != (0, 0, 0, 0):
|
||||||
|
raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.")
|
||||||
|
if self.pad_mode == 'pad':
|
||||||
|
for item in pad:
|
||||||
|
validator.check_integer('pad item', item, 0, Rel.GE, self.name)
|
||||||
|
|
||||||
pad_mode = pad_mode.upper()
|
pad_mode = pad_mode.upper()
|
||||||
self.add_prim_attr('pad_mode', pad_mode)
|
self.add_prim_attr('pad_mode', pad_mode)
|
||||||
self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
|
self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name)
|
||||||
|
@ -1360,7 +1384,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
|
||||||
pad_right = pad_needed_w - pad_left
|
pad_right = pad_needed_w - pad_left
|
||||||
pad_list = (pad_top, pad_bottom, pad_left, pad_right)
|
pad_list = (pad_top, pad_bottom, pad_left, pad_right)
|
||||||
elif self.pad_mode == 'PAD':
|
elif self.pad_mode == 'PAD':
|
||||||
pad_list = (self.pad,) * 4
|
pad_list = self.padding
|
||||||
self.add_prim_attr('pad_list', pad_list)
|
self.add_prim_attr('pad_list', pad_list)
|
||||||
out = {
|
out = {
|
||||||
'value': None,
|
'value': None,
|
||||||
|
@ -1735,7 +1759,6 @@ class DataFormatDimMap(PrimitiveWithInfer):
|
||||||
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
|
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
|
||||||
return x_type
|
return x_type
|
||||||
|
|
||||||
|
|
||||||
class RNNTLoss(PrimitiveWithInfer):
|
class RNNTLoss(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Computes the RNNTLoss and its gradient with respect to the softmax outputs.
|
Computes the RNNTLoss and its gradient with respect to the softmax outputs.
|
||||||
|
|
|
@ -19,6 +19,277 @@ from ..._checkparam import Validator as validator
|
||||||
from ..._checkparam import Rel
|
from ..._checkparam import Rel
|
||||||
from ...common import dtype as mstype
|
from ...common import dtype as mstype
|
||||||
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
from ..primitive import PrimitiveWithInfer, prim_attr_register
|
||||||
|
from .._utils import get_broadcast_shape
|
||||||
|
|
||||||
|
|
||||||
|
class Laplace(PrimitiveWithInfer):
|
||||||
|
r"""
|
||||||
|
Generates random numbers according to the Laplace random number distribution.
|
||||||
|
It is defined as:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
\text{f}(x;μ,λ) = \frac{1}{2λ}\exp(-\frac{|x-μ|}{λ}),
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
|
||||||
|
Default: 0.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
|
||||||
|
- **mean** (Tensor) - The mean μ distribution parameter, which specifies the location of the peak.
|
||||||
|
With float32 data type.
|
||||||
|
- **lambda_param** (Tensor) - The parameter used for controling the variance of this random distribution. The
|
||||||
|
variance of Laplace distribution is equal to twice the square of lambda_param. With float32 data type.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, has the shape 'shape' input and dtype as float32.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> shape = (4, 16)
|
||||||
|
>>> mean = Tensor(1.0, mstype.float32)
|
||||||
|
>>> lambda_param = Tensor(1.0, mstype.float32)
|
||||||
|
>>> laplace = P.Laplace(seed=2)
|
||||||
|
>>> output = laplace(shape, mean, lambda_param)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, seed=0):
|
||||||
|
"""Init Laplace"""
|
||||||
|
self.init_prim_io_names(inputs=['shape', 'mean', 'lambda_param'], outputs=['output'])
|
||||||
|
validator.check_value_type('seed', seed, [int], self.name)
|
||||||
|
|
||||||
|
def __infer__(self, shape, mean, lambda_param):
|
||||||
|
shape_v = shape["value"]
|
||||||
|
if shape_v is None:
|
||||||
|
raise ValueError(f"For {self.name}, shape must be const.")
|
||||||
|
validator.check_value_type("shape", shape_v, [tuple], self.name)
|
||||||
|
for i, shape_i in enumerate(shape_v):
|
||||||
|
validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
|
||||||
|
validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name)
|
||||||
|
validator.check_tensor_type_same({"lambda_param": lambda_param["dtype"]}, [mstype.float32], self.name)
|
||||||
|
broadcast_shape = get_broadcast_shape(mean['shape'], lambda_param['shape'], self.name)
|
||||||
|
broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name)
|
||||||
|
out = {
|
||||||
|
'shape': broadcast_shape,
|
||||||
|
'dtype': mstype.float32,
|
||||||
|
'value': None}
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Gamma(PrimitiveWithInfer):
|
||||||
|
r"""
|
||||||
|
Produces random positive floating-point values x, distributed according to probability density function:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
\text{P}(x|α,β) = \frac{\exp(-x/β)}{{β^α}\cdot{\Gamma(α)}}\cdot{x^{α-1}},
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
|
||||||
|
Default: 0.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
|
||||||
|
- **alpha** (Tensor) - The α distribution parameter.
|
||||||
|
It is also known as the shape parameter. With float32 data type.
|
||||||
|
- **beta** (Tensor) - The β distribution parameter.
|
||||||
|
It is also known as the scale parameter. With float32 data type.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of alpha and beta.
|
||||||
|
The dtype is float32.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> shape = (4, 16)
|
||||||
|
>>> alpha = Tensor(1.0, mstype.float32)
|
||||||
|
>>> beta = Tensor(1.0, mstype.float32)
|
||||||
|
>>> gamma = P.Gamma(seed=3)
|
||||||
|
>>> output = Gamma(shape, alpha, beta)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, seed=0):
|
||||||
|
"""Init Gamma"""
|
||||||
|
self.init_prim_io_names(inputs=['shape', 'alpha', 'beta'], outputs=['output'])
|
||||||
|
validator.check_value_type('seed', seed, [int], self.name)
|
||||||
|
|
||||||
|
def __infer__(self, shape, alpha, beta):
|
||||||
|
shape_v = shape["value"]
|
||||||
|
if shape_v is None:
|
||||||
|
raise ValueError(f"For {self.name}, shape must be const.")
|
||||||
|
validator.check_value_type("shape", shape_v, [tuple], self.name)
|
||||||
|
for i, shape_i in enumerate(shape_v):
|
||||||
|
validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
|
||||||
|
validator.check_tensor_type_same({"alpha": alpha["dtype"]}, [mstype.float32], self.name)
|
||||||
|
validator.check_tensor_type_same({"beta": beta["dtype"]}, [mstype.float32], self.name)
|
||||||
|
broadcast_shape = get_broadcast_shape(alpha['shape'], beta['shape'], self.name)
|
||||||
|
broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name)
|
||||||
|
out = {
|
||||||
|
'shape': broadcast_shape,
|
||||||
|
'dtype': mstype.float32,
|
||||||
|
'value': None}
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class Poisson(PrimitiveWithInfer):
|
||||||
|
r"""
|
||||||
|
Produces random non-negative integer values i, distributed according to discrete probability function:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
\text{P}(i|μ) = \frac{\exp(-μ)μ^{i}}{i!},
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
|
||||||
|
Default: 0.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
|
||||||
|
- **mean** (Tensor) - μ parameter the distribution was constructed with.
|
||||||
|
The parameter defines mean number of occurrences of the event. With float32 data type.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor. The shape should be the broadcasted shape of Input "shape" and shape of mean.
|
||||||
|
The dtype is int32.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> shape = (4, 16)
|
||||||
|
>>> mean = Tensor(5.0, mstype.float32)
|
||||||
|
>>> poisson = P.Poisson(seed=5)
|
||||||
|
>>> output = poisson(shape, mean)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, seed=0):
|
||||||
|
"""Init Poisson"""
|
||||||
|
self.init_prim_io_names(inputs=['shape', 'mean'], outputs=['output'])
|
||||||
|
validator.check_value_type('seed', seed, [int], self.name)
|
||||||
|
|
||||||
|
def __infer__(self, shape, mean):
|
||||||
|
shape_v = shape["value"]
|
||||||
|
if shape_v is None:
|
||||||
|
raise ValueError(f"For {self.name}, shape must be const.")
|
||||||
|
validator.check_value_type("shape", shape_v, [tuple], self.name)
|
||||||
|
for i, shape_i in enumerate(shape_v):
|
||||||
|
validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
|
||||||
|
validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name)
|
||||||
|
broadcast_shape = get_broadcast_shape(mean['shape'], shape_v, self.name)
|
||||||
|
out = {
|
||||||
|
'shape': broadcast_shape,
|
||||||
|
'dtype': mstype.int32,
|
||||||
|
'value': None}
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class UniformInt(PrimitiveWithInfer):
|
||||||
|
r"""
|
||||||
|
Produces random integer values i, uniformly distributed on the closed interval [a, b], that is,
|
||||||
|
distributed according to the discrete probability function:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
\text{P}(i|a,b) = \frac{1}{b-a+1},
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The number in tensor a should be strictly less than b at any position after broadcasting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
|
||||||
|
Default: 0.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
|
||||||
|
- **a** (Tensor) - The a distribution parameter.
|
||||||
|
It defines the minimum possibly generated value. With int32 data type.
|
||||||
|
- **b** (Tensor) - The b distribution parameter.
|
||||||
|
It defines the maximum possibly generated value. With int32 data type.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of a and b.
|
||||||
|
The dtype is int32.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> shape = (4, 16)
|
||||||
|
>>> a = Tensor(1, mstype.int32)
|
||||||
|
>>> b = Tensor(5, mstype.int32)
|
||||||
|
>>> uniform_int = P.UniformInt(seed=10)
|
||||||
|
>>> output = uniform_int(shape, a, b)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, seed=0):
|
||||||
|
"""Init UniformInt"""
|
||||||
|
self.init_prim_io_names(inputs=['shape', 'a', 'b'], outputs=['output'])
|
||||||
|
validator.check_value_type('seed', seed, [int], self.name)
|
||||||
|
|
||||||
|
def __infer__(self, shape, a, b):
|
||||||
|
shape_v = shape["value"]
|
||||||
|
if shape_v is None:
|
||||||
|
raise ValueError(f"For {self.name}, shape must be const.")
|
||||||
|
validator.check_value_type("shape", shape_v, [tuple], self.name)
|
||||||
|
for i, shape_i in enumerate(shape_v):
|
||||||
|
validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
|
||||||
|
validator.check_tensor_type_same({"a": a["dtype"]}, [mstype.int32], self.name)
|
||||||
|
validator.check_tensor_type_same({"b": b["dtype"]}, [mstype.int32], self.name)
|
||||||
|
broadcast_shape = get_broadcast_shape(a['shape'], b['shape'], self.name)
|
||||||
|
broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name)
|
||||||
|
out = {
|
||||||
|
'shape': broadcast_shape,
|
||||||
|
'dtype': mstype.int32,
|
||||||
|
'value': None}
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class UniformReal(PrimitiveWithInfer):
|
||||||
|
r"""
|
||||||
|
Produces random floating-point values i, uniformly distributed on the interval [min(a, b), max(a, b)), that is,\
|
||||||
|
distributed according to the probability density function:
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
\text{P}(i|a,b) = \frac{1}{b-a},
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed (int): Seed data is used as entropy source for Random number engines generating pseudo-random numbers.
|
||||||
|
Default: 0.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **shape** (tuple) - The shape of random tensor to be generated. Only constant value is allowed.
|
||||||
|
- **a** (Tensor) - The a distribution parameter.
|
||||||
|
It defines the minimum possibly generated value. With float32 data type.
|
||||||
|
- **b** (Tensor) - The b distribution parameter.
|
||||||
|
It defines the maximum possibly generated value. With float32 data type.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor. The shape should be the broadcasted shape of Input "shape" and shapes of a and b.
|
||||||
|
The dtype is float32.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> shape = (4, 16)
|
||||||
|
>>> a = Tensor(1.0, mstype.float32)
|
||||||
|
>>> b = Tensor(5.0, mstype.float32)
|
||||||
|
>>> uniform_real = P.UniformReal(seed=10)
|
||||||
|
>>> output = uniform_real(shape, a, b)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, seed=0):
|
||||||
|
"""Init UniformReal"""
|
||||||
|
self.init_prim_io_names(inputs=['shape', 'a', 'b'], outputs=['output'])
|
||||||
|
validator.check_value_type('seed', seed, [int], self.name)
|
||||||
|
|
||||||
|
def __infer__(self, shape, a, b):
|
||||||
|
shape_v = shape["value"]
|
||||||
|
if shape_v is None:
|
||||||
|
raise ValueError(f"For {self.name}, shape must be const.")
|
||||||
|
validator.check_value_type("shape", shape_v, [tuple], self.name)
|
||||||
|
for i, shape_i in enumerate(shape_v):
|
||||||
|
validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GT, self.name)
|
||||||
|
validator.check_tensor_type_same({"a": a["dtype"]}, [mstype.float32], self.name)
|
||||||
|
validator.check_tensor_type_same({"b": b["dtype"]}, [mstype.float32], self.name)
|
||||||
|
broadcast_shape = get_broadcast_shape(a['shape'], b['shape'], self.name)
|
||||||
|
broadcast_shape = get_broadcast_shape(broadcast_shape, shape_v, self.name)
|
||||||
|
out = {
|
||||||
|
'shape': broadcast_shape,
|
||||||
|
'dtype': mstype.float32,
|
||||||
|
'value': None}
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class RandomChoiceWithMask(PrimitiveWithInfer):
|
class RandomChoiceWithMask(PrimitiveWithInfer):
|
||||||
|
@ -66,50 +337,6 @@ class RandomChoiceWithMask(PrimitiveWithInfer):
|
||||||
return (mstype.int32, mstype.bool_)
|
return (mstype.int32, mstype.bool_)
|
||||||
|
|
||||||
|
|
||||||
class Normal(PrimitiveWithInfer):
|
|
||||||
"""
|
|
||||||
Generates random samples from a normal(Gaussian) distribution.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seed (int): Random seed. Default: 0.
|
|
||||||
|
|
||||||
Inputs:
|
|
||||||
- **shape** (tuple[int]) - The shape of output tensor. Only constant value is allowed.
|
|
||||||
- **mean** (Tensor) - The mean of the distribution, with float32 data type.
|
|
||||||
- **stddev** (Tensor) - The standard deviation of the distribution, with float32 data type.
|
|
||||||
|
|
||||||
Outputs:
|
|
||||||
Tensor, with the given shape from the specific distribution and float32 data type.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> normal = P.Normal()
|
|
||||||
>>> mean = Tensor(0., mstype.float32)
|
|
||||||
>>> stddev = Tensor(1., mstype.float32)
|
|
||||||
>>> out = normal((32, 3, 3), mean, stddev)
|
|
||||||
"""
|
|
||||||
|
|
||||||
@prim_attr_register
|
|
||||||
def __init__(self, seed=0):
|
|
||||||
"""Init Normal"""
|
|
||||||
validator.check_value_type("seed", seed, [int], self.name)
|
|
||||||
|
|
||||||
def __infer__(self, shape, mean, stddev):
|
|
||||||
shape_value = shape["value"]
|
|
||||||
if shape_value is None:
|
|
||||||
raise ValueError(f"For {self.name}, shape must be const.")
|
|
||||||
validator.check_value_type("shape", shape_value, [tuple], self.name)
|
|
||||||
for i, shape_i in enumerate(shape_value):
|
|
||||||
validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GE, self.name)
|
|
||||||
|
|
||||||
validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name)
|
|
||||||
validator.check_tensor_type_same({"stddev": stddev["dtype"]}, [mstype.float32], self.name)
|
|
||||||
|
|
||||||
out = {"shape": shape_value,
|
|
||||||
"dtype": mstype.float32,
|
|
||||||
"value": None}
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
class RandomCategorical(PrimitiveWithInfer):
|
class RandomCategorical(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
Generates random samples from a given categorical distribution tensor.
|
Generates random samples from a given categorical distribution tensor.
|
||||||
|
@ -166,3 +393,46 @@ class RandomCategorical(PrimitiveWithInfer):
|
||||||
return {'shape': (x_shape),
|
return {'shape': (x_shape),
|
||||||
'dtype': (self.dtype),
|
'dtype': (self.dtype),
|
||||||
'value': None}
|
'value': None}
|
||||||
|
|
||||||
|
class Normal(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Generates random samples from a normal(Gaussian) distribution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seed (int): Random seed. Default: 0.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **shape** (tuple[int]) - The shape of output tensor. Only constant value is allowed.
|
||||||
|
- **mean** (Tensor) - The mean of the distribution, with float32 data type.
|
||||||
|
- **stddev** (Tensor) - The standard deviation of the distribution, with float32 data type.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
Tensor, with the given shape from the specific distribution and float32 data type.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> normal = P.Normal()
|
||||||
|
>>> mean = Tensor(0., mstype.float32)
|
||||||
|
>>> stddev = Tensor(1., mstype.float32)
|
||||||
|
>>> out = normal((32, 3, 3), mean, stddev)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, seed=0):
|
||||||
|
"""Init Normal"""
|
||||||
|
validator.check_value_type("seed", seed, [int], self.name)
|
||||||
|
|
||||||
|
def __infer__(self, shape, mean, stddev):
|
||||||
|
shape_value = shape["value"]
|
||||||
|
if shape_value is None:
|
||||||
|
raise ValueError(f"For {self.name}, shape must be const.")
|
||||||
|
validator.check_value_type("shape", shape_value, [tuple], self.name)
|
||||||
|
for i, shape_i in enumerate(shape_value):
|
||||||
|
validator.check_integer("shape[%d]" % i, shape_i, 0, Rel.GE, self.name)
|
||||||
|
|
||||||
|
validator.check_tensor_type_same({"mean": mean["dtype"]}, [mstype.float32], self.name)
|
||||||
|
validator.check_tensor_type_same({"stddev": stddev["dtype"]}, [mstype.float32], self.name)
|
||||||
|
|
||||||
|
out = {"shape": shape_value,
|
||||||
|
"dtype": mstype.float32,
|
||||||
|
"value": None}
|
||||||
|
return out
|
||||||
|
|
|
@ -126,10 +126,6 @@ class ModelCallback(Callback):
|
||||||
print("epoch: {}, outputs are: {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs)))
|
print("epoch: {}, outputs are: {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs)))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.level0
|
|
||||||
@pytest.mark.platform_arm_ascend_training
|
|
||||||
@pytest.mark.platform_x86_ascend_training
|
|
||||||
@pytest.mark.env_onecard
|
|
||||||
def test_bert_tdt():
|
def test_bert_tdt():
|
||||||
"""test bert tdt"""
|
"""test bert tdt"""
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
|
|
|
@ -154,10 +154,6 @@ class TimeMonitor(Callback):
|
||||||
self.epoch_mseconds_list.append(epoch_mseconds)
|
self.epoch_mseconds_list.append(epoch_mseconds)
|
||||||
self.per_step_mseconds_list.append(epoch_mseconds / self.data_size)
|
self.per_step_mseconds_list.append(epoch_mseconds / self.data_size)
|
||||||
|
|
||||||
@pytest.mark.level0
|
|
||||||
@pytest.mark.platform_arm_ascend_training
|
|
||||||
@pytest.mark.platform_x86_ascend_training
|
|
||||||
@pytest.mark.env_onecard
|
|
||||||
def test_bert_percision():
|
def test_bert_percision():
|
||||||
"""test bert percision"""
|
"""test bert percision"""
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
|
||||||
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, x, dtype):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.cast = P.Cast()
|
||||||
|
self.x = x
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
return self.cast(self.x, self.dtype)
|
||||||
|
|
||||||
|
def test_net_f32_bool():
|
||||||
|
x = np.random.randn(3, 4).astype(np.float32)
|
||||||
|
x[:, 1] = 0
|
||||||
|
net = Net(Tensor(x), mstype.bool_)
|
||||||
|
output = net()
|
||||||
|
print(output.asnumpy())
|
||||||
|
print(Tensor(x).dtype)
|
||||||
|
print(output.dtype)
|
||||||
|
|
||||||
|
def test_net_f16_bool():
|
||||||
|
x = np.random.randn(3, 4).astype(np.float16)
|
||||||
|
x[:, 1] = 0
|
||||||
|
net = Net(Tensor(x), mstype.bool_)
|
||||||
|
output = net()
|
||||||
|
print(output.asnumpy())
|
||||||
|
print(Tensor(x).dtype)
|
||||||
|
print(output.dtype)
|
||||||
|
|
||||||
|
def test_net_f64_bool():
|
||||||
|
x = np.random.randn(3, 4).astype(np.float64)
|
||||||
|
x[:, 1] = 0
|
||||||
|
net = Net(Tensor(x), mstype.bool_)
|
||||||
|
output = net()
|
||||||
|
print(output.asnumpy())
|
||||||
|
print(Tensor(x).dtype)
|
||||||
|
print(output.dtype)
|
||||||
|
|
||||||
|
def test_net_int16_float16():
|
||||||
|
x = np.random.randint(-512, 512, size=(3, 4)).astype(np.int16)
|
||||||
|
net = Net(Tensor(x), mstype.float16)
|
||||||
|
output = net()
|
||||||
|
print(output.asnumpy())
|
||||||
|
print(Tensor(x).dtype)
|
||||||
|
print(output.dtype)
|
||||||
|
|
||||||
|
def test_net_int64_float16():
|
||||||
|
x = np.random.randint(-512, 512, size=(3, 4)).astype(np.int64)
|
||||||
|
net = Net(Tensor(x), mstype.float16)
|
||||||
|
output = net()
|
||||||
|
print(output.asnumpy())
|
||||||
|
print(Tensor(x).dtype)
|
||||||
|
print(output.dtype)
|
|
@ -0,0 +1,57 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, shape, seed=0):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.gamma = P.Gamma(seed=seed)
|
||||||
|
self.shape = shape
|
||||||
|
|
||||||
|
def construct(self, alpha, beta):
|
||||||
|
return self.gamma(self.shape, alpha, beta)
|
||||||
|
|
||||||
|
|
||||||
|
def test_net_1D():
|
||||||
|
seed = 10
|
||||||
|
shape = (3, 2, 4)
|
||||||
|
alpha = 1.0
|
||||||
|
beta = 1.0
|
||||||
|
net = Net(shape, seed)
|
||||||
|
talpha, tbeta = Tensor(alpha, mstype.float32), Tensor(beta, mstype.float32)
|
||||||
|
output = net(talpha, tbeta)
|
||||||
|
print(output.asnumpy())
|
||||||
|
assert output.shape == (3, 2, 4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_net_ND():
|
||||||
|
seed = 10
|
||||||
|
shape = (3, 1, 2)
|
||||||
|
alpha = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32)
|
||||||
|
beta = np.array([1.0]).astype(np.float32)
|
||||||
|
net = Net(shape, seed)
|
||||||
|
talpha, tbeta = Tensor(alpha), Tensor(beta)
|
||||||
|
output = net(talpha, tbeta)
|
||||||
|
print(output.asnumpy())
|
||||||
|
assert output.shape == (3, 2, 2)
|
|
@ -0,0 +1,57 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, shape, seed=0):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.laplace = P.Laplace(seed=seed)
|
||||||
|
self.shape = shape
|
||||||
|
|
||||||
|
def construct(self, mean, lambda_param):
|
||||||
|
return self.laplace(self.shape, mean, lambda_param)
|
||||||
|
|
||||||
|
|
||||||
|
def test_net_1D():
|
||||||
|
seed = 10
|
||||||
|
shape = (3, 2, 4)
|
||||||
|
mean = 1.0
|
||||||
|
lambda_param = 1.0
|
||||||
|
net = Net(shape, seed)
|
||||||
|
tmean, tlambda_param = Tensor(mean, mstype.float32), Tensor(lambda_param, mstype.float32)
|
||||||
|
output = net(tmean, tlambda_param)
|
||||||
|
print(output.asnumpy())
|
||||||
|
assert output.shape == (3, 2, 4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_net_ND():
|
||||||
|
seed = 10
|
||||||
|
shape = (3, 1, 2)
|
||||||
|
mean = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32)
|
||||||
|
lambda_param = np.array([1.0]).astype(np.float32)
|
||||||
|
net = Net(shape, seed)
|
||||||
|
tmean, tlambda_param = Tensor(mean), Tensor(lambda_param)
|
||||||
|
output = net(tmean, tlambda_param)
|
||||||
|
print(output.asnumpy())
|
||||||
|
assert output.shape == (3, 2, 2)
|
|
@ -12,32 +12,48 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
import mindspore.context as context
|
import mindspore.context as context
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
from mindspore.ops import operations as P
|
from mindspore import Tensor
|
||||||
from mindspore.common import Tensor
|
|
||||||
from mindspore.common import dtype as mstype
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
|
||||||
|
|
||||||
|
|
||||||
class Net(nn.Cell):
|
class Net(nn.Cell):
|
||||||
def __init__(self, shape=None, mean=0.0, stddev=1.0, seed=0):
|
def __init__(self, shape, seed=0):
|
||||||
super(Net, self).__init__()
|
super(Net, self).__init__()
|
||||||
self._mean = Tensor(mean, mstype.float32)
|
self.shape = shape
|
||||||
self._stddev = Tensor(stddev, mstype.float32)
|
self.seed = seed
|
||||||
self._normal = P.Normal(seed=seed)
|
|
||||||
self._shape = shape
|
|
||||||
|
|
||||||
def construct(self):
|
def construct(self, mean, stddev):
|
||||||
return self._normal(self._shape, self._mean, self._stddev)
|
return C.normal(self.shape, mean, stddev, self.seed)
|
||||||
|
|
||||||
|
|
||||||
def test_net_3x2x4():
|
def test_net_1D():
|
||||||
mean = 0.0
|
seed = 10
|
||||||
|
shape = (3, 2, 4)
|
||||||
|
mean = 1.0
|
||||||
stddev = 1.0
|
stddev = 1.0
|
||||||
seed = 0
|
net = Net(shape, seed)
|
||||||
net = Net((3, 2, 4), mean, stddev, seed)
|
tmean, tstddev = Tensor(mean, mstype.float32), Tensor(stddev, mstype.float32)
|
||||||
out = net()
|
output = net(tmean, tstddev)
|
||||||
assert out.shape == (3, 2, 4)
|
print(output.asnumpy())
|
||||||
|
assert output.shape == (3, 2, 4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_net_ND():
|
||||||
|
seed = 10
|
||||||
|
shape = (3, 1, 2)
|
||||||
|
mean = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32)
|
||||||
|
stddev = np.array([1.0]).astype(np.float32)
|
||||||
|
net = Net(shape, seed)
|
||||||
|
tmean, tstddev = Tensor(mean, mstype.float32), Tensor(stddev, mstype.float32)
|
||||||
|
output = net(tmean, tstddev)
|
||||||
|
print(output.asnumpy())
|
||||||
|
assert output.shape == (3, 2, 2)
|
||||||
|
|
|
@ -127,7 +127,6 @@ def test_net_int64():
|
||||||
print(output.asnumpy())
|
print(output.asnumpy())
|
||||||
assert np.array_equal(output.asnumpy(), np.stack([x, y], axis))
|
assert np.array_equal(output.asnumpy(), np.stack([x, y], axis))
|
||||||
|
|
||||||
|
|
||||||
def test_net_uint64():
|
def test_net_uint64():
|
||||||
x = np.random.randn(3, 5, 4).astype(np.uint64)
|
x = np.random.randn(3, 5, 4).astype(np.uint64)
|
||||||
y = np.random.randn(3, 5, 4).astype(np.uint64)
|
y = np.random.randn(3, 5, 4).astype(np.uint64)
|
||||||
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, shape):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.poisson = P.Poisson()
|
||||||
|
self.shape = shape
|
||||||
|
|
||||||
|
def construct(self, mean):
|
||||||
|
return self.poisson(self.shape, mean)
|
||||||
|
|
||||||
|
|
||||||
|
def test_net_1():
|
||||||
|
shape = (2, 16)
|
||||||
|
mean = np.array([5.0]).astype(np.float32)
|
||||||
|
net = Net(shape)
|
||||||
|
tmean = Tensor(mean)
|
||||||
|
output = net(tmean)
|
||||||
|
print(output.asnumpy())
|
||||||
|
assert output.shape == (2, 16)
|
||||||
|
|
||||||
|
|
||||||
|
def test_net_2():
|
||||||
|
shape = (4, 1)
|
||||||
|
mean = np.array([5.0, 10.0]).astype(np.float32)
|
||||||
|
net = Net(shape)
|
||||||
|
tmean = Tensor(mean)
|
||||||
|
output = net(tmean)
|
||||||
|
print(output.asnumpy())
|
||||||
|
assert output.shape == (4, 2)
|
|
@ -0,0 +1,47 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, shape, seed=0, seed2=0):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.shape = shape
|
||||||
|
self.seed = seed
|
||||||
|
self.seed2 = seed2
|
||||||
|
self.stdnormal = P.StandardNormal(seed, seed2)
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
return self.stdnormal(self.shape, self.seed, self.seed2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_net():
|
||||||
|
seed = 10
|
||||||
|
seed2 = 10
|
||||||
|
shape = (3, 2, 4)
|
||||||
|
net = Net(shape, seed, seed2)
|
||||||
|
output = net()
|
||||||
|
print(output.asnumpy())
|
||||||
|
assert output.shape == (3, 2, 4)
|
|
@ -0,0 +1,51 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops.operations import _inner_ops as inner
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, begin, end, strides):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.strided_slice = inner.StridedSliceAICPU()
|
||||||
|
self.begin = begin
|
||||||
|
self.end = end
|
||||||
|
self.strides = strides
|
||||||
|
|
||||||
|
def construct(self, input):
|
||||||
|
return self.strided_slice(input, self.begin, self.end, self.strides)
|
||||||
|
|
||||||
|
|
||||||
|
input_x = np.array([[[0, 1, 2], [3, 4, 5]],
|
||||||
|
[[6, 7, 8], [9, 10, 11]],
|
||||||
|
[[12, 13, 14], [15, 16, 17]]
|
||||||
|
]).astype(np.float32)
|
||||||
|
begin = (1, 0, 0)
|
||||||
|
end = (2, 2, 3)
|
||||||
|
strides = (1, 1, 2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_net():
|
||||||
|
net = Net(begin, end, strides)
|
||||||
|
tinput = Tensor(input_x)
|
||||||
|
output = net(tinput)
|
||||||
|
print(output.asnumpy())
|
||||||
|
assert np.all([[[6, 8], [9, 11]]] == output.asnumpy())
|
|
@ -0,0 +1,53 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops.operations import _grad_ops as G
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, shape_x, begin, end, strides):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.strided_slice_grad = G.StridedSliceGradAICPU()
|
||||||
|
self.shape_x = shape_x
|
||||||
|
self.begin = begin
|
||||||
|
self.end = end
|
||||||
|
self.strides = strides
|
||||||
|
|
||||||
|
def construct(self, dy):
|
||||||
|
return self.strided_slice_grad(dy, self.shape_x, self.begin, self.end, self.strides)
|
||||||
|
|
||||||
|
|
||||||
|
dy = np.array([[[6, 8], [9, 11]]]).astype(np.float32)
|
||||||
|
shape_x = (3, 2, 3)
|
||||||
|
begin = (1, 0, 0)
|
||||||
|
end = (2, 2, 3)
|
||||||
|
strides = (1, 1, 2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_net():
|
||||||
|
net = Net(shape_x, begin, end, strides)
|
||||||
|
tdy = Tensor(dy)
|
||||||
|
output = net(tdy)
|
||||||
|
print(output.asnumpy())
|
||||||
|
assert np.all([[[0, 0, 0], [0, 0, 0]],
|
||||||
|
[[6, 0, 8], [9, 0, 11]],
|
||||||
|
[[0, 0, 0], [0, 0, 0]]
|
||||||
|
] == output.asnumpy())
|
|
@ -0,0 +1,57 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, shape, seed=0):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.uniformint = P.UniformInt(seed=seed)
|
||||||
|
self.shape = shape
|
||||||
|
|
||||||
|
def construct(self, a, b):
|
||||||
|
return self.uniformint(self.shape, a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def test_net_1D():
|
||||||
|
seed = 10
|
||||||
|
shape = (3, 2, 4)
|
||||||
|
a = 1
|
||||||
|
b = 5
|
||||||
|
net = Net(shape, seed)
|
||||||
|
ta, tb = Tensor(a, mstype.int32), Tensor(b, mstype.int32)
|
||||||
|
output = net(ta, tb)
|
||||||
|
print(output.asnumpy())
|
||||||
|
assert output.shape == (3, 2, 4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_net_ND():
|
||||||
|
seed = 10
|
||||||
|
shape = (3, 2, 1)
|
||||||
|
a = np.array([[[1, 2]], [[3, 4]], [[5, 6]]]).astype(np.int32)
|
||||||
|
b = np.array([10]).astype(np.int32)
|
||||||
|
net = Net(shape, seed)
|
||||||
|
ta, tb = Tensor(a), Tensor(b)
|
||||||
|
output = net(ta, tb)
|
||||||
|
print(output.asnumpy())
|
||||||
|
assert output.shape == (3, 2, 2)
|
|
@ -0,0 +1,57 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
from mindspore.common import dtype as mstype
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, shape, seed=0):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.uniformreal = P.UniformReal(seed=seed)
|
||||||
|
self.shape = shape
|
||||||
|
|
||||||
|
def construct(self, a, b):
|
||||||
|
return self.uniformreal(self.shape, a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def test_net_1D():
|
||||||
|
seed = 10
|
||||||
|
shape = (3, 2, 4)
|
||||||
|
a = 1.0
|
||||||
|
b = 5.0
|
||||||
|
net = Net(shape, seed)
|
||||||
|
ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32)
|
||||||
|
output = net(ta, tb)
|
||||||
|
print(output.asnumpy())
|
||||||
|
assert output.shape == (3, 2, 4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_net_ND():
|
||||||
|
seed = 10
|
||||||
|
shape = (3, 2, 1)
|
||||||
|
a = np.array([[[1, 2]], [[3, 4]], [[5, 6]]]).astype(np.float32)
|
||||||
|
b = np.array([10]).astype(np.float32)
|
||||||
|
net = Net(shape, seed)
|
||||||
|
ta, tb = Tensor(a), Tensor(b)
|
||||||
|
output = net(ta, tb)
|
||||||
|
print(output.asnumpy())
|
||||||
|
assert output.shape == (3, 2, 2)
|
|
@ -0,0 +1,43 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
|
device_target="Ascend")
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, offset):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.embedding = P.EmbeddingLookup()
|
||||||
|
self.offset = offset
|
||||||
|
|
||||||
|
def construct(self, param, index):
|
||||||
|
return self.embedding(param, index, self.offset)
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_lookup_sparse():
|
||||||
|
params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mstype.int32)
|
||||||
|
indices = Tensor(np.array([[5, 2], [8, 5]]), mstype.int32)
|
||||||
|
offset = 4
|
||||||
|
embedding = Net(offset)
|
||||||
|
out = embedding(params, indices)
|
||||||
|
assert(out.asnumpy() == [[[10, 11], [0, 0]], [[0, 0], [10, 11]]]).all()
|
|
@ -0,0 +1,56 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "common/backend_common_test.h"
|
||||||
|
#include "common/py_func_graph_fetcher.h"
|
||||||
|
#include "backend/optimizer/ascend/ir_fission/tensor_scatter_update_fission.h"
|
||||||
|
#include "debug/anf_ir_dump.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class TestHWOptTensorScatterUpdateFission : public BackendCommon {
|
||||||
|
public:
|
||||||
|
TestHWOptTensorScatterUpdateFission()
|
||||||
|
: get_py_fun_("gtest_input.pre_activate.tensor_scatter_update_fission_test", true) {}
|
||||||
|
~TestHWOptTensorScatterUpdateFission() override = default;
|
||||||
|
|
||||||
|
UT::PyFuncGraphFetcher get_py_fun_;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TestHWOptTensorScatterUpdateFission, test_fission) {
|
||||||
|
FuncGraphPtr g = get_py_fun_.CallAndParseRet("test_tensor_scatter_update_fission", "before");
|
||||||
|
EXPECT_NE(g, nullptr);
|
||||||
|
std::vector<int> shp1{2, 3};
|
||||||
|
std::vector<int> shp2{2, 2};
|
||||||
|
std::vector<int> shp3{2};
|
||||||
|
auto inputx = std::make_shared<abstract::AbstractTensor>(kFloat32, shp1);
|
||||||
|
auto indices = std::make_shared<abstract::AbstractTensor>(kInt32, shp2);
|
||||||
|
auto update = std::make_shared<abstract::AbstractTensor>(kFloat32, shp3);
|
||||||
|
AbstractBasePtrList args_spec_list{inputx, indices, update};
|
||||||
|
auto fg = GetKernelGraph(g, args_spec_list);
|
||||||
|
|
||||||
|
auto optimizer = std::make_shared<opt::GraphOptimizer>();
|
||||||
|
auto pm = std::make_shared<opt::PassManager>();
|
||||||
|
pm->AddPass(std::make_shared<opt::TensorScatterUpdateFission>());
|
||||||
|
optimizer->AddPassManager(pm);
|
||||||
|
FuncGraphPtr new_graph = optimizer->Optimize(fg);
|
||||||
|
|
||||||
|
FuncGraphPtr g_after = get_py_fun_.CallAndParseRet("test_tensor_scatter_update_fission", "after");
|
||||||
|
EXPECT_NE(g_after, nullptr);
|
||||||
|
EXPECT_TRUE(CheckEqualGraph(g_after, new_graph));
|
||||||
|
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,50 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
from mindspore.ops import Primitive
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
tensor_scatter_update = P.TensorScatterUpdate()
|
||||||
|
tensor_move = Primitive('TensorMove')
|
||||||
|
scatter_nd_update = Primitive('ScatterNdUpdate')
|
||||||
|
make_tuple = Primitive('make_tuple')
|
||||||
|
tuple_getitem = Primitive('tuple_getitem')
|
||||||
|
|
||||||
|
|
||||||
|
class FnDict:
|
||||||
|
def __init__(self):
|
||||||
|
self.fnDict = {}
|
||||||
|
|
||||||
|
def __call__(self, fn):
|
||||||
|
self.fnDict[fn.__name__] = fn
|
||||||
|
|
||||||
|
def __getitem__(self, name):
|
||||||
|
return self.fnDict[name]
|
||||||
|
|
||||||
|
|
||||||
|
def test_tensor_scatter_update_fission(tag):
|
||||||
|
fns = FnDict()
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def before(x, indices, updates):
|
||||||
|
res = tensor_scatter_update(x, indices, updates)
|
||||||
|
return res
|
||||||
|
|
||||||
|
@fns
|
||||||
|
def after(x, indices, updates):
|
||||||
|
res = tensor_move(x)
|
||||||
|
res = scatter_nd_update(res, indices, updates)
|
||||||
|
return make_tuple(res)
|
||||||
|
|
||||||
|
return fns[tag]
|
|
@ -22,11 +22,22 @@ from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters
|
||||||
|
|
||||||
def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1):
|
def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1):
|
||||||
"""Rearranges an image to row vector"""
|
"""Rearranges an image to row vector"""
|
||||||
batch_num, channel, height, width = img.shape
|
if isinstance(pad, int):
|
||||||
out_h = (height + 2 * pad - filter_h - (filter_h - 1) * (dilation[2] - 1)) // stride[2] + 1
|
pad_top = pad
|
||||||
out_w = (width + 2 * pad - filter_w - (filter_w - 1) * (dilation[3] - 1)) // stride[3] + 1
|
pad_bottom = pad
|
||||||
|
pad_left = pad
|
||||||
|
pad_right = pad
|
||||||
|
elif isinstance(pad, tuple) and len(pad) == 4:
|
||||||
|
pad_top, pad_bottom, pad_left, pad_right = pad
|
||||||
|
else:
|
||||||
|
raise ValueError(f"The \'pad\' should be an int number or "
|
||||||
|
f"a tuple of two or four int numbers, but got {pad}")
|
||||||
|
|
||||||
img = np.pad(img, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
|
batch_num, channel, height, width = img.shape
|
||||||
|
out_h = (height + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation[2] - 1)) // stride[2] + 1
|
||||||
|
out_w = (width + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation[3] - 1)) // stride[3] + 1
|
||||||
|
|
||||||
|
img = np.pad(img, [(0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right)], 'constant')
|
||||||
col = np.zeros((batch_num, channel, filter_h, filter_w, out_h, out_w)).astype(img.dtype)
|
col = np.zeros((batch_num, channel, filter_h, filter_w, out_h, out_w)).astype(img.dtype)
|
||||||
|
|
||||||
for y in range(filter_h):
|
for y in range(filter_h):
|
||||||
|
@ -43,10 +54,21 @@ def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1):
|
||||||
def conv2d(x, weight, bias=None, stride=1, pad=0,
|
def conv2d(x, weight, bias=None, stride=1, pad=0,
|
||||||
dilation=1, groups=1, padding_mode='zeros'):
|
dilation=1, groups=1, padding_mode='zeros'):
|
||||||
"""Convolution 2D"""
|
"""Convolution 2D"""
|
||||||
|
if isinstance(pad, int):
|
||||||
|
pad_top = pad
|
||||||
|
pad_bottom = pad
|
||||||
|
pad_left = pad
|
||||||
|
pad_right = pad
|
||||||
|
elif isinstance(pad, tuple) and len(pad) == 4:
|
||||||
|
pad_top, pad_bottom, pad_left, pad_right = pad
|
||||||
|
else:
|
||||||
|
raise ValueError(f"The \'pad\' should be an int number or "
|
||||||
|
f"a tuple of two or four int numbers, but got {pad}")
|
||||||
|
|
||||||
batch_num, _, x_h, x_w = x.shape
|
batch_num, _, x_h, x_w = x.shape
|
||||||
filter_num, _, filter_h, filter_w = weight.shape
|
filter_num, _, filter_h, filter_w = weight.shape
|
||||||
out_h = 1 + int((x_h + 2 * pad - filter_h - (filter_h - 1) * (dilation[2] - 1)) / stride[2])
|
out_h = 1 + int((x_h + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation[2] - 1)) / stride[2])
|
||||||
out_w = 1 + int((x_w + 2 * pad - filter_w - (filter_w - 1) * (dilation[3] - 1)) / stride[3])
|
out_w = 1 + int((x_w + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation[3] - 1)) / stride[3])
|
||||||
col = im2col(x, filter_h, filter_w, stride, pad, dilation)
|
col = im2col(x, filter_h, filter_w, stride, pad, dilation)
|
||||||
col_w = np.reshape(weight, (filter_num, -1)).T
|
col_w = np.reshape(weight, (filter_num, -1)).T
|
||||||
out = np.dot(col, col_w)
|
out = np.dot(col, col_w)
|
||||||
|
|
|
@ -103,7 +103,8 @@ hcclResult_t hcom_receive(const char *tag, void *outputPtr, u64 count, hcclDataT
|
||||||
|
|
||||||
/* 获取梯度参数切分方案 */
|
/* 获取梯度参数切分方案 */
|
||||||
hcclResult_t hcom_get_split_strategy(const char *group, const struct model_feature *feature, u32 maxSegmentNum,
|
hcclResult_t hcom_get_split_strategy(const char *group, const struct model_feature *feature, u32 maxSegmentNum,
|
||||||
u32 *segmentNum, u32 *segmentIdx, GradSplitForceMode force) {
|
u32 *segmentNum, u32 *segmentIdx, GradSplitForceMode force,
|
||||||
|
OriginalGraphShapeType shapeType) {
|
||||||
return HCCL_SUCCESS;
|
return HCCL_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -540,6 +540,61 @@ class NormalNet(nn.Cell):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class LaplaceNet(nn.Cell):
|
||||||
|
def __init__(self, shape=None, seed=0):
|
||||||
|
super(LaplaceNet, self).__init__()
|
||||||
|
self.laplace = P.Laplace(seed=seed)
|
||||||
|
self.shape = shape
|
||||||
|
|
||||||
|
def construct(self, mean, lambda_param):
|
||||||
|
out = self.laplace(self.shape, mean, lambda_param)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class GammaNet(nn.Cell):
|
||||||
|
def __init__(self, shape=None, seed=0):
|
||||||
|
super(GammaNet, self).__init__()
|
||||||
|
self.gamma = P.Gamma(seed=seed)
|
||||||
|
self.shape = shape
|
||||||
|
|
||||||
|
def construct(self, alpha, beta):
|
||||||
|
out = self.gamma(self.shape, alpha, beta)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class PoissonNet(nn.Cell):
|
||||||
|
def __init__(self, shape=None, seed=0):
|
||||||
|
super(PoissonNet, self).__init__()
|
||||||
|
self.poisson = P.Poisson(seed=seed)
|
||||||
|
self.shape = shape
|
||||||
|
|
||||||
|
def construct(self, mean):
|
||||||
|
out = self.poisson(self.shape, mean)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class UniformIntNet(nn.Cell):
|
||||||
|
def __init__(self, shape=None, seed=0):
|
||||||
|
super(UniformIntNet, self).__init__()
|
||||||
|
self.uniformint = P.UniformInt(seed=seed)
|
||||||
|
self.shape = shape
|
||||||
|
|
||||||
|
def construct(self, a, b):
|
||||||
|
out = self.uniformint(self.shape, a, b)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class UniformRealNet(nn.Cell):
|
||||||
|
def __init__(self, shape=None, seed=0):
|
||||||
|
super(UniformRealNet, self).__init__()
|
||||||
|
self.uniformreal = P.UniformReal(seed=seed)
|
||||||
|
self.shape = shape
|
||||||
|
|
||||||
|
def construct(self, a, b):
|
||||||
|
out = self.uniformreal(self.shape, a, b)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class StridedSliceNet(nn.Cell):
|
class StridedSliceNet(nn.Cell):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(StridedSliceNet, self).__init__()
|
super(StridedSliceNet, self).__init__()
|
||||||
|
@ -819,6 +874,26 @@ test_case_math_ops = [
|
||||||
'block': NormalNet((3, 2, 4), 0),
|
'block': NormalNet((3, 2, 4), 0),
|
||||||
'desc_inputs': [Tensor(0.0, mstype.float32), Tensor(1.0, mstype.float32)],
|
'desc_inputs': [Tensor(0.0, mstype.float32), Tensor(1.0, mstype.float32)],
|
||||||
'skip': ['backward']}),
|
'skip': ['backward']}),
|
||||||
|
('Laplace', {
|
||||||
|
'block': LaplaceNet((3, 2, 4), 0),
|
||||||
|
'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(1.0, mstype.float32)],
|
||||||
|
'skip': ['backward']}),
|
||||||
|
('Gamma', {
|
||||||
|
'block': GammaNet((3, 2, 4), 0),
|
||||||
|
'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(1.0, mstype.float32)],
|
||||||
|
'skip': ['backward']}),
|
||||||
|
('Poisson', {
|
||||||
|
'block': PoissonNet((3, 2, 4), 0),
|
||||||
|
'desc_inputs': [Tensor(2.0, mstype.float32)],
|
||||||
|
'skip': ['backward']}),
|
||||||
|
('UniformInt', {
|
||||||
|
'block': UniformIntNet((3, 2, 4), 0),
|
||||||
|
'desc_inputs': [Tensor(1, mstype.int32), Tensor(15, mstype.int32)],
|
||||||
|
'skip': ['backward']}),
|
||||||
|
('UniformReal', {
|
||||||
|
'block': UniformRealNet((3, 2, 4), 0),
|
||||||
|
'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(5.0, mstype.float32)],
|
||||||
|
'skip': ['backward']}),
|
||||||
('RandomChoiceWithMask', {
|
('RandomChoiceWithMask', {
|
||||||
'block': P.RandomChoiceWithMask(256),
|
'block': P.RandomChoiceWithMask(256),
|
||||||
'desc_inputs': [Tensor(np.random.rand(24000, 4).astype(np.bool_))],
|
'desc_inputs': [Tensor(np.random.rand(24000, 4).astype(np.bool_))],
|
||||||
|
|
|
@ -29,7 +29,6 @@ context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
class LeNet5(nn.Cell):
|
class LeNet5(nn.Cell):
|
||||||
""" LeNet5 definition """
|
""" LeNet5 definition """
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(LeNet5, self).__init__()
|
super(LeNet5, self).__init__()
|
||||||
self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
|
self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
|
||||||
|
|
|
@ -169,16 +169,32 @@ def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
|
||||||
raise ValueError(f"The \'stride\' should be an int number or "
|
raise ValueError(f"The \'stride\' should be an int number or "
|
||||||
f"a tuple of two or four int numbers, but got {stride}")
|
f"a tuple of two or four int numbers, but got {stride}")
|
||||||
|
|
||||||
|
if isinstance(pad, int):
|
||||||
|
pad_top = pad
|
||||||
|
pad_bottom = pad
|
||||||
|
pad_left = pad
|
||||||
|
pad_right = pad
|
||||||
|
elif isinstance(pad, tuple) and len(pad) == 2:
|
||||||
|
pad_top = pad[0]
|
||||||
|
pad_bottom = pad[0]
|
||||||
|
pad_left = pad[1]
|
||||||
|
pad_right = pad[1]
|
||||||
|
elif isinstance(pad, tuple) and len(pad) == 4:
|
||||||
|
pad_top, pad_bottom, pad_left, pad_right = pad
|
||||||
|
else:
|
||||||
|
raise ValueError(f"The \'pad\' should be an int number or "
|
||||||
|
f"a tuple of two or four int numbers, but got {pad}")
|
||||||
|
|
||||||
batch_num, channel, height, width = input_shape
|
batch_num, channel, height, width = input_shape
|
||||||
out_h = (height + 2 * pad - filter_h) // stride_h + 1
|
out_h = (height + pad_top + pad_bottom - filter_h) // stride_h + 1
|
||||||
out_w = (width + 2 * pad - filter_w) // stride_w + 1
|
out_w = (width + pad_left + pad_right - filter_w) // stride_w + 1
|
||||||
col = col.reshape(batch_num, out_h, out_w, channel, filter_h, filter_w) \
|
col = col.reshape(batch_num, out_h, out_w, channel, filter_h, filter_w) \
|
||||||
.transpose(0, 3, 4, 5, 1, 2)
|
.transpose(0, 3, 4, 5, 1, 2)
|
||||||
|
|
||||||
img = np.zeros((batch_num,
|
img = np.zeros((batch_num,
|
||||||
channel,
|
channel,
|
||||||
height + 2 * pad + stride_h - 1,
|
height + pad_top + pad_bottom + stride_h - 1,
|
||||||
width + 2 * pad + stride_w - 1)) \
|
width + pad_left + pad_right + stride_w - 1)) \
|
||||||
.astype(col.dtype)
|
.astype(col.dtype)
|
||||||
for y in range(filter_h):
|
for y in range(filter_h):
|
||||||
y_max = y + stride_h * out_h
|
y_max = y + stride_h * out_h
|
||||||
|
@ -186,7 +202,7 @@ def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
|
||||||
x_max = x + stride_h * out_w
|
x_max = x + stride_h * out_w
|
||||||
img[:, :, y:y_max:stride_h, x:x_max:stride_h] += col[:, :, y, x, :, :]
|
img[:, :, y:y_max:stride_h, x:x_max:stride_h] += col[:, :, y, x, :, :]
|
||||||
|
|
||||||
return img[:, :, pad:height + pad, pad:width + pad]
|
return img[:, :, pad_top:height + pad_bottom, pad_left:width + pad_right]
|
||||||
|
|
||||||
|
|
||||||
def convolve(x, w, b=None, pad_mode="valid"):
|
def convolve(x, w, b=None, pad_mode="valid"):
|
||||||
|
@ -243,10 +259,21 @@ def conv2d(x, weight, bias=None, stride=1, pad=0,
|
||||||
dilation_h = dilation[0]
|
dilation_h = dilation[0]
|
||||||
dilation_w = dilation[1]
|
dilation_w = dilation[1]
|
||||||
|
|
||||||
|
if isinstance(pad, int):
|
||||||
|
pad_top = pad
|
||||||
|
pad_bottom = pad
|
||||||
|
pad_left = pad
|
||||||
|
pad_right = pad
|
||||||
|
elif isinstance(pad, tuple) and len(pad) == 4:
|
||||||
|
pad_top, pad_bottom, pad_left, pad_right = pad
|
||||||
|
else:
|
||||||
|
raise ValueError(f"The \'pad\' should be an int number or "
|
||||||
|
f"a tuple of two or four int numbers, but got {pad}")
|
||||||
|
|
||||||
batch_num, _, x_h, x_w = x.shape
|
batch_num, _, x_h, x_w = x.shape
|
||||||
filter_num, _, filter_h, filter_w = weight.shape
|
filter_num, _, filter_h, filter_w = weight.shape
|
||||||
out_h = 1 + int((x_h + 2 * pad - filter_h - (filter_h - 1) * (dilation_h - 1)) / stride_h)
|
out_h = 1 + int((x_h + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation_h - 1)) / stride_h)
|
||||||
out_w = 1 + int((x_w + 2 * pad - filter_w - (filter_w - 1) * (dilation_w - 1)) / stride_w)
|
out_w = 1 + int((x_w + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation_w - 1)) / stride_w)
|
||||||
col = im2col(x, filter_h, filter_w, stride, pad, dilation)
|
col = im2col(x, filter_h, filter_w, stride, pad, dilation)
|
||||||
col_w = np.reshape(weight, (filter_num, -1)).T
|
col_w = np.reshape(weight, (filter_num, -1)).T
|
||||||
out = np.dot(col, col_w)
|
out = np.dot(col, col_w)
|
||||||
|
@ -348,11 +375,22 @@ def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1):
|
||||||
raise ValueError(f"The \'dilation\' should be an int number or "
|
raise ValueError(f"The \'dilation\' should be an int number or "
|
||||||
f"a tuple of two or four int numbers, but got {dilation}")
|
f"a tuple of two or four int numbers, but got {dilation}")
|
||||||
|
|
||||||
batch_num, channel, height, width = img.shape
|
if isinstance(pad, int):
|
||||||
out_h = (height + 2 * pad - filter_h - (filter_h - 1) * (dilation_h - 1)) // stride_h + 1
|
pad_top = pad
|
||||||
out_w = (width + 2 * pad - filter_w - (filter_w - 1) * (dilation_w - 1)) // stride_w + 1
|
pad_bottom = pad
|
||||||
|
pad_left = pad
|
||||||
|
pad_right = pad
|
||||||
|
elif isinstance(pad, tuple) and len(pad) == 4:
|
||||||
|
pad_top, pad_bottom, pad_left, pad_right = pad
|
||||||
|
else:
|
||||||
|
raise ValueError(f"The \'pad\' should be an int number or "
|
||||||
|
f"a tuple of two or four int numbers, but got {pad}")
|
||||||
|
|
||||||
img = np.pad(img, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
|
batch_num, channel, height, width = img.shape
|
||||||
|
out_h = (height + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation_h - 1)) // stride_h + 1
|
||||||
|
out_w = (width + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation_w - 1)) // stride_w + 1
|
||||||
|
|
||||||
|
img = np.pad(img, [(0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right)], 'constant')
|
||||||
col = np.zeros((batch_num, channel, filter_h, filter_w, out_h, out_w)).astype(img.dtype)
|
col = np.zeros((batch_num, channel, filter_h, filter_w, out_h, out_w)).astype(img.dtype)
|
||||||
|
|
||||||
for y in range(filter_h):
|
for y in range(filter_h):
|
||||||
|
|
Loading…
Reference in New Issue