forked from mindspore-Ecosystem/mindspore
sync from mindspore to incubator
This commit is contained in:
commit
21c381b366
|
@ -10,9 +10,9 @@
|
||||||
[submodule "third_party/protobuf"]
|
[submodule "third_party/protobuf"]
|
||||||
path = third_party/protobuf
|
path = third_party/protobuf
|
||||||
url = https://github.com/protocolbuffers/protobuf.git
|
url = https://github.com/protocolbuffers/protobuf.git
|
||||||
[submodule "graphengine"]
|
|
||||||
path = graphengine
|
|
||||||
url = https://gitee.com/mindspore/graphengine.git
|
|
||||||
[submodule "akg"]
|
[submodule "akg"]
|
||||||
path = akg
|
path = akg
|
||||||
url = https://gitee.com/mindspore/akg.git
|
url = https://gitee.com/mindspore/akg.git
|
||||||
|
[submodule "graphengine"]
|
||||||
|
path = graphengine
|
||||||
|
url = https://gitee.com/ms-incubator/graphengine.git
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
Subproject commit dda72a48c7e0033389bd377c5804d485fdf3112d
|
Subproject commit 8891f0546c4a250095ff68e1262f58772b938fd9
|
|
@ -141,7 +141,7 @@ if (ENABLE_GE)
|
||||||
else ()
|
else ()
|
||||||
target_link_libraries(mindspore ge_client)
|
target_link_libraries(mindspore ge_client)
|
||||||
endif ()
|
endif ()
|
||||||
target_link_libraries(mindspore graph tsdclient)
|
target_link_libraries(mindspore graph tsdclient datatransfer)
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
if (ENABLE_D)
|
if (ENABLE_D)
|
||||||
|
|
|
@ -29,6 +29,7 @@ constexpr auto kInitData = "InitData";
|
||||||
constexpr auto kGetNext = "GetNext";
|
constexpr auto kGetNext = "GetNext";
|
||||||
constexpr auto kPrint = "Print";
|
constexpr auto kPrint = "Print";
|
||||||
constexpr auto kPack = "Pack";
|
constexpr auto kPack = "Pack";
|
||||||
|
|
||||||
constexpr auto kOutputTypes = "output_types";
|
constexpr auto kOutputTypes = "output_types";
|
||||||
constexpr auto kOutputShapes = "output_shapes";
|
constexpr auto kOutputShapes = "output_shapes";
|
||||||
constexpr auto kChannelName = "channel_name";
|
constexpr auto kChannelName = "channel_name";
|
||||||
|
|
|
@ -58,7 +58,6 @@ class GetMakeRefEliminater : public OptimizerCaller {
|
||||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x);
|
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefKey, PPrimitive(prim::kPrimMakeRef, x, y, z)), x);
|
||||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y);
|
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefValue, PPrimitive(prim::kPrimMakeRef, x, y, z)), y);
|
||||||
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z);
|
MATCH_REPLACE(node, PPrimitive(prim::kPrimGetRefOrigin, PPrimitive(prim::kPrimMakeRef, x, y, z)), z);
|
||||||
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -1168,7 +1168,7 @@ INPUT_MAP(SparseApplyAdagradD) = {
|
||||||
{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(grad)}, {4, INPUT_DESC(indices)}};
|
{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(grad)}, {4, INPUT_DESC(indices)}};
|
||||||
ATTR_MAP(SparseApplyAdagradD) = {{"lr", ATTR_DESC(lr, AnyTraits<float>())},
|
ATTR_MAP(SparseApplyAdagradD) = {{"lr", ATTR_DESC(lr, AnyTraits<float>())},
|
||||||
{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
{"use_locking", ATTR_DESC(use_locking, AnyTraits<bool>())}};
|
||||||
OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}};
|
OUTPUT_MAP(SparseApplyAdagradD) = {{0, OUTPUT_DESC(var)}, {1, OUTPUT_DESC(accum)}};
|
||||||
|
|
||||||
// ApplyProximalAdagradD
|
// ApplyProximalAdagradD
|
||||||
INPUT_MAP(ApplyProximalAdagradD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)},
|
INPUT_MAP(ApplyProximalAdagradD) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(accum)}, {3, INPUT_DESC(lr)},
|
||||||
|
|
|
@ -181,14 +181,21 @@ bool MsContext::OpenTsd() {
|
||||||
}
|
}
|
||||||
|
|
||||||
MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << ".";
|
MS_LOG(INFO) << "Device id = " << device_id << ", rank size = " << rank_size << ".";
|
||||||
|
#if (defined(ENABLE_TDTQUE) && defined(ENABLE_GE))
|
||||||
|
int32_t initStatus = tdt::TdtHostInit(device_id);
|
||||||
|
if (initStatus != TDT_OK_CODE) {
|
||||||
|
MS_LOG(EXCEPTION) << "Init tsd failed, status = " << initStatus << ".";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
tdt_print_ = std::thread(TensorPrint());
|
||||||
|
#endif
|
||||||
TDT_StatusT status = tdt::TsdClient::GetInstance()->Open(device_id, rank_size);
|
TDT_StatusT status = tdt::TsdClient::GetInstance()->Open(device_id, rank_size);
|
||||||
if (status != TDT_OK) {
|
if (status != TDT_OK) {
|
||||||
MS_LOG(EXCEPTION) << "Device " << device_id << " is occupied, open tsd failed, status = " << status << ".";
|
MS_LOG(EXCEPTION) << "Device " << device_id << " is occupied, open tsd failed, status = " << status << ".";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
tsd_ref_++;
|
tsd_ref_++;
|
||||||
#ifdef ENABLE_TDTQUE
|
#if (defined(ENABLE_TDTQUE) && !defined(ENABLE_GE))
|
||||||
int32_t initStatus = tdt::TdtHostInit(device_id);
|
int32_t initStatus = tdt::TdtHostInit(device_id);
|
||||||
if (initStatus != TDT_OK_CODE) {
|
if (initStatus != TDT_OK_CODE) {
|
||||||
MS_LOG(EXCEPTION) << "Init tsd failed, status = " << initStatus << ".";
|
MS_LOG(EXCEPTION) << "Init tsd failed, status = " << initStatus << ".";
|
||||||
|
|
|
@ -342,7 +342,6 @@ class Optimizer(Cell):
|
||||||
current_dynamic_lr = self.gather(self.learning_rate[i], self.global_step, 0)
|
current_dynamic_lr = self.gather(self.learning_rate[i], self.global_step, 0)
|
||||||
lr += (current_dynamic_lr,)
|
lr += (current_dynamic_lr,)
|
||||||
F.control_depend(lr, self.assignadd(self.global_step, 1))
|
F.control_depend(lr, self.assignadd(self.global_step, 1))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
lr = self.learning_rate
|
lr = self.learning_rate
|
||||||
if self.dynamic_lr:
|
if self.dynamic_lr:
|
||||||
|
|
|
@ -518,6 +518,18 @@ def get_bprop_l2_loss(self):
|
||||||
return bprop
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
|
@bprop_getters.register(P.RNNTLoss)
|
||||||
|
def get_bprop_rnnt_loss(self):
|
||||||
|
"""Grad definition for `RNNTLoss` operation."""
|
||||||
|
expand = P.ExpandDims()
|
||||||
|
|
||||||
|
def bprop(acts, labels, act_lens, label_lens, out, dout):
|
||||||
|
grad_loss = out[1]
|
||||||
|
grad = grad_loss * expand(expand(expand(dout[0], -1), -1), -1)
|
||||||
|
return grad, zeros_like(labels), zeros_like(act_lens), zeros_like(label_lens)
|
||||||
|
return bprop
|
||||||
|
|
||||||
|
|
||||||
@bprop_getters.register(P.PReLU)
|
@bprop_getters.register(P.PReLU)
|
||||||
def get_bprop_prelu(self):
|
def get_bprop_prelu(self):
|
||||||
"""Grad definition for `PReLU` operation."""
|
"""Grad definition for `PReLU` operation."""
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
|
|
||||||
"""aicpu ops"""
|
"""aicpu ops"""
|
||||||
from .init_data_set_queue import _init_data_set_queue_aicpu
|
from .init_data_set_queue import _init_data_set_queue_aicpu
|
||||||
|
from .embedding_lookup import _embedding_lookup_aicpu
|
||||||
from .dropout_genmask import _dropout_genmask_aicpu
|
from .dropout_genmask import _dropout_genmask_aicpu
|
||||||
from .get_next import _get_next_aicpu
|
from .get_next import _get_next_aicpu
|
||||||
from .print_tensor import _print_aicpu
|
from .print_tensor import _print_aicpu
|
||||||
|
@ -29,3 +30,6 @@ from .normal import _normal_aicpu
|
||||||
from .ctcloss import _ctcloss_aicpu
|
from .ctcloss import _ctcloss_aicpu
|
||||||
from .reverse_sequence import _reverse_sequence_aicpu
|
from .reverse_sequence import _reverse_sequence_aicpu
|
||||||
from .crop_and_resize import _crop_and_resize_aicpu
|
from .crop_and_resize import _crop_and_resize_aicpu
|
||||||
|
from .rnnt_loss import _rnnt_loss_aicpu
|
||||||
|
from .random_categorical import _random_categorical_aicpu
|
||||||
|
from .cast import _cast_aicpu
|
||||||
|
|
|
@ -0,0 +1,172 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""Cast op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
cast_op_info = AiCPURegOp("Cast") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "x", "required") \
|
||||||
|
.output(0, "y", "required") \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(cast_op_info)
|
||||||
|
def _cast_aicpu():
|
||||||
|
"""Cast AiCPU register"""
|
||||||
|
return
|
|
@ -0,0 +1,102 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""EmbeddingLookup op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
embeddingLookup_op_info = AiCPURegOp("EmbeddingLookup") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "params", "required") \
|
||||||
|
.input(1, "indices", "required") \
|
||||||
|
.input(2, "offset", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I32_Default, \
|
||||||
|
DataType.I32_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I64_Default, \
|
||||||
|
DataType.I64_Default, DataType.BOOL_Default) \
|
||||||
|
.dtype_format(DataType.I8_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.I8_Default) \
|
||||||
|
.dtype_format(DataType.I16_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.I32_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.I64_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.U8_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.U8_Default) \
|
||||||
|
.dtype_format(DataType.U16_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.U16_Default) \
|
||||||
|
.dtype_format(DataType.U32_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.U32_Default) \
|
||||||
|
.dtype_format(DataType.U64_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.U64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.F64_Default) \
|
||||||
|
.dtype_format(DataType.BOOL_Default, DataType.I64_Default, \
|
||||||
|
DataType.I32_Default, DataType.BOOL_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(embeddingLookup_op_info)
|
||||||
|
def _embedding_lookup_aicpu():
|
||||||
|
"""EmbeddingLookup AiCPU register"""
|
||||||
|
return
|
|
@ -0,0 +1,48 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""RandomCategorical op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
random_categorical_op_info = AiCPURegOp("RandomCategorical") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "logits", "required") \
|
||||||
|
.input(1, "num_sample", "required") \
|
||||||
|
.input(2, "seed", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I16_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I32_Default) \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(random_categorical_op_info)
|
||||||
|
def _random_categorical_aicpu():
|
||||||
|
"""RandomCategorical AiCPU register"""
|
||||||
|
return
|
|
@ -0,0 +1,37 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""RNNTLoss op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
rnnt_loss_op_info = AiCPURegOp("RNNTLoss") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.input(0, "acts", "required") \
|
||||||
|
.input(1, "labels", "required") \
|
||||||
|
.input(2, "input_lengths", "required") \
|
||||||
|
.input(3, "label_lengths", "required") \
|
||||||
|
.output(0, "costs", "required") \
|
||||||
|
.output(1, "grads", "required") \
|
||||||
|
.attr("blank_label", "int") \
|
||||||
|
.dtype_format(DataType.F32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.F32_NCHW,
|
||||||
|
DataType.F32_NCHW) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default,
|
||||||
|
DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
@op_info_register(rnnt_loss_op_info)
|
||||||
|
def _rnnt_loss_aicpu():
|
||||||
|
"""RNNTLoss AiCPU register"""
|
||||||
|
return
|
|
@ -54,7 +54,7 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
|
||||||
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e,
|
Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e,
|
||||||
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps)
|
Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps)
|
||||||
|
|
||||||
from .random_ops import (RandomChoiceWithMask, Normal)
|
from .random_ops import (RandomChoiceWithMask, Normal, RandomCategorical)
|
||||||
from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm,
|
from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm,
|
||||||
BiasAdd, Conv2D,
|
BiasAdd, Conv2D,
|
||||||
DepthwiseConv2dNative,
|
DepthwiseConv2dNative,
|
||||||
|
@ -69,6 +69,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
|
||||||
ResizeBilinear, Sigmoid,
|
ResizeBilinear, Sigmoid,
|
||||||
SigmoidCrossEntropyWithLogits,
|
SigmoidCrossEntropyWithLogits,
|
||||||
SmoothL1Loss, Softmax, Softplus,
|
SmoothL1Loss, Softmax, Softplus,
|
||||||
|
RNNTLoss,
|
||||||
SoftmaxCrossEntropyWithLogits, ROIAlign,
|
SoftmaxCrossEntropyWithLogits, ROIAlign,
|
||||||
SparseSoftmaxCrossEntropyWithLogits, Tanh,
|
SparseSoftmaxCrossEntropyWithLogits, Tanh,
|
||||||
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
|
TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
|
||||||
|
@ -77,6 +78,8 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
|
||||||
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK)
|
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK)
|
||||||
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
|
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
|
||||||
CheckValid, MakeRefKey, Partial, Depend, CheckBprop)
|
CheckValid, MakeRefKey, Partial, Depend, CheckBprop)
|
||||||
|
from . import _quant_ops
|
||||||
|
from ._quant_ops import *
|
||||||
from .thor_ops import *
|
from .thor_ops import *
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -168,6 +171,7 @@ __all__ = [
|
||||||
'Tanh',
|
'Tanh',
|
||||||
'RandomChoiceWithMask',
|
'RandomChoiceWithMask',
|
||||||
'Normal',
|
'Normal',
|
||||||
|
'RandomCategorical',
|
||||||
'ResizeBilinear',
|
'ResizeBilinear',
|
||||||
'ScalarSummary',
|
'ScalarSummary',
|
||||||
'ImageSummary',
|
'ImageSummary',
|
||||||
|
@ -198,6 +202,7 @@ __all__ = [
|
||||||
'SmoothL1Loss',
|
'SmoothL1Loss',
|
||||||
'L2Loss',
|
'L2Loss',
|
||||||
'CTCLoss',
|
'CTCLoss',
|
||||||
|
'RNNTLoss',
|
||||||
'ReduceAll',
|
'ReduceAll',
|
||||||
'ScalarToArray',
|
'ScalarToArray',
|
||||||
'ScalarToTensor',
|
'ScalarToTensor',
|
||||||
|
@ -302,6 +307,7 @@ __all__ = [
|
||||||
"ApplyCenteredRMSProp",
|
"ApplyCenteredRMSProp",
|
||||||
"SpaceToBatchND",
|
"SpaceToBatchND",
|
||||||
"BatchToSpaceND",
|
"BatchToSpaceND",
|
||||||
|
"ReverseSequence",
|
||||||
"SquareSumAll",
|
"SquareSumAll",
|
||||||
"BitwiseAnd",
|
"BitwiseAnd",
|
||||||
"BitwiseOr",
|
"BitwiseOr",
|
||||||
|
@ -315,7 +321,8 @@ __all__ = [
|
||||||
"DataFormatDimMap",
|
"DataFormatDimMap",
|
||||||
"ApproximateEqual",
|
"ApproximateEqual",
|
||||||
"InplaceUpdate",
|
"InplaceUpdate",
|
||||||
"InTopK"
|
"InTopK",
|
||||||
|
"CropAndResize"
|
||||||
]
|
]
|
||||||
|
|
||||||
__all__.sort()
|
__all__.sort()
|
||||||
|
|
|
@ -1093,8 +1093,18 @@ class StridedSliceGrad(PrimitiveWithInfer):
|
||||||
self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])
|
self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output'])
|
||||||
|
|
||||||
def __infer__(self, dy, shapex, begin, end, strides):
|
def __infer__(self, dy, shapex, begin, end, strides):
|
||||||
args = {"shapex": shapex['dtype'],"begin": begin['dtype'],"end": end['dtype'],"strides": strides['dtype']}
|
args = {"dy": dy['dtype']}
|
||||||
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
validator.check_tensor_type_same(args, mstype.number_type, self.name)
|
||||||
|
|
||||||
|
for idx, item in enumerate(shapex['value']):
|
||||||
|
validator.check_value_type("shapex[%d]" % idx, item, [int], self.name)
|
||||||
|
for idx, item in enumerate(begin['value']):
|
||||||
|
validator.check_value_type("begin[%d]" % idx, item, [int], self.name)
|
||||||
|
for idx, item in enumerate(end['value']):
|
||||||
|
validator.check_value_type("end[%d]" % idx, item, [int], self.name)
|
||||||
|
for idx, item in enumerate(strides['value']):
|
||||||
|
validator.check_value_type("strides[%d]" % idx, item, [int], self.name)
|
||||||
|
|
||||||
return {'shape': shapex['value'],
|
return {'shape': shapex['value'],
|
||||||
'dtype': dy['dtype'],
|
'dtype': dy['dtype'],
|
||||||
'value': None}
|
'value': None}
|
||||||
|
|
|
@ -1697,6 +1697,60 @@ class DataFormatDimMap(PrimitiveWithInfer):
|
||||||
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
|
validator.check_tensor_type_same({"x": x_type}, valid_types, self.name)
|
||||||
return x_type
|
return x_type
|
||||||
|
|
||||||
|
class RNNTLoss(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Computes the RNNTLoss and its gradient with respect to the softmax outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
blank_label (int): blank label. Default: 0.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **acts** (Tensor[float32]) - Tensor of shape :math:`(B, T, U, V)`.
|
||||||
|
- **labels** (Tensor[int32]) - Tensor of shape :math:`(B, N)`.
|
||||||
|
- **input_lengths** (Tensor[int32]) - Tensor of shape :math:`(B,)`.
|
||||||
|
- **label_lebgths** (Tensor[int32]) - Tensor of shape :math:`(B,)`.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
- **costs** (Tensor[int32]) - Tensor of shape :math:`(B,)`.
|
||||||
|
- **grads** (Tensor[int32]) - Has the same shape as `acts`.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> B, T, U, V = 1, 2, 3, 5
|
||||||
|
>>> acts = np.random.random((B, T, U, V)).astype(np.float32)
|
||||||
|
>>> labels = np.array([[1, 2]]).astype(np.int32)
|
||||||
|
>>> input_length = np.array([T] * B).astype(np.int32)
|
||||||
|
>>> label_length = np.array([len(l) for l in labels]).astype(np.int32)
|
||||||
|
>>> rnnt_loss = P.RNNTLoss(blank_label=blank)
|
||||||
|
>>> costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length))
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, blank_label=0):
|
||||||
|
validator.check_value_type('blank_label', blank_label, [int], self.name)
|
||||||
|
self.init_prim_io_names(inputs=['acts', 'labels', 'input_length', 'label_length'],
|
||||||
|
outputs=['costs', 'grads'])
|
||||||
|
|
||||||
|
def infer_shape(self, acts_shape, labels_shape, input_length_shape, label_length_shape):
|
||||||
|
validator.check_integer('acts_rank', len(acts_shape), 4, Rel.EQ, self.name)
|
||||||
|
validator.check_integer('labels_rank', len(labels_shape), 2, Rel.EQ, self.name)
|
||||||
|
validator.check_integer('input_length_rank', len(input_length_shape), 1, Rel.EQ, self.name)
|
||||||
|
validator.check_integer('label_length_rank', len(label_length_shape), 1, Rel.EQ, self.name)
|
||||||
|
validator.check('labels shape[0]', labels_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
|
||||||
|
validator.check('input_length size', input_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
|
||||||
|
validator.check('label_length size', label_length_shape[0], 'acts shape[0]', acts_shape[0], Rel.EQ, self.name)
|
||||||
|
costs_shape = (acts_shape[0],)
|
||||||
|
return (costs_shape, acts_shape)
|
||||||
|
|
||||||
|
def infer_dtype(self, acts_type, labels_type, input_length_type, label_length_type):
|
||||||
|
validator.check_subclass("acts_type", acts_type, mstype.tensor, self.name)
|
||||||
|
validator.check_subclass("labels_type", labels_type, mstype.tensor, self.name)
|
||||||
|
validator.check_subclass("input_length_type", input_length_type, mstype.tensor, self.name)
|
||||||
|
validator.check_subclass("label_length_type", label_length_type, mstype.tensor, self.name)
|
||||||
|
validator.check_tensor_type_same({"acts_type": acts_type}, [mstype.float32], self.name)
|
||||||
|
validator.check_tensor_type_same({"labels_type": labels_type}, [mstype.int32], self.name)
|
||||||
|
validator.check_tensor_type_same({"input_length_type": input_length_type}, [mstype.int32], self.name)
|
||||||
|
validator.check_tensor_type_same({"label_length_type": label_length_type}, [mstype.int32], self.name)
|
||||||
|
return (acts_type, acts_type)
|
||||||
|
|
||||||
|
|
||||||
class SGD(PrimitiveWithInfer):
|
class SGD(PrimitiveWithInfer):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -108,3 +108,60 @@ class Normal(PrimitiveWithInfer):
|
||||||
"dtype": mstype.float32,
|
"dtype": mstype.float32,
|
||||||
"value": None}
|
"value": None}
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class RandomCategorical(PrimitiveWithInfer):
|
||||||
|
"""
|
||||||
|
Generates random samples from a given categorical distribution tensor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dtype (mindspore.dtype): The type of output. Its value should be one of [mindspore.int16,
|
||||||
|
mindspore.int32, mindspore.int64]. Default: mindspore.int64.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
- **logits** (Tensor) - The input tensor. 2-D Tensor with shape [batch_size, num_classes].
|
||||||
|
- **num_sample** (int) - Number of sample to be drawn. Only constant values is allowed.
|
||||||
|
- **seed** (int) - Random seed. Default: 0.
|
||||||
|
|
||||||
|
Outputs:
|
||||||
|
- **output** (Tensor) - The output Tensor with shape [batch_size, num_samples].
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> class Net(nn.Cell):
|
||||||
|
>>> def __init__(self, num_sample):
|
||||||
|
>>> super(Net, self).__init__()
|
||||||
|
>>> self.random_categorical = P.RandomCategorical(mindspore.int64)
|
||||||
|
>>> self.num_sample = num_sample
|
||||||
|
>>> def construct(self, logits, seed=0):
|
||||||
|
>>> return self.random_categorical(logits, self.num_sample, seed)
|
||||||
|
>>>
|
||||||
|
>>> x = np.random.random((10, 5)).astype(np.float32)
|
||||||
|
>>> net = Net(8)
|
||||||
|
>>> output = net(Tensor(x))
|
||||||
|
"""
|
||||||
|
@prim_attr_register
|
||||||
|
def __init__(self, dtype=mstype.int64):
|
||||||
|
"""Init RandomCategorical"""
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
valid_values = (mstype.int32, mstype.int16, mstype.int64)
|
||||||
|
validator.check_type_name("dtype", dtype, valid_values, self.name)
|
||||||
|
self.init_prim_io_names(inputs=['logits', 'num_samples', 'seed'],
|
||||||
|
outputs=['output'])
|
||||||
|
|
||||||
|
def __infer__(self, logits, num_samples, seed):
|
||||||
|
logits_dtype = logits['dtype']
|
||||||
|
valid_types = (mstype.float32, mstype.float16, mstype.float64)
|
||||||
|
validator.check_tensor_type_same({'logits': logits_dtype}, valid_types, self.name)
|
||||||
|
num_samples_v = num_samples['value']
|
||||||
|
seed_v = seed['value']
|
||||||
|
validator.check_value_type('num_samples', num_samples_v, (int,), self.name)
|
||||||
|
validator.check_value_type('seed', seed_v, (int,), self.name)
|
||||||
|
validator.check_integer("num_samples", num_samples_v, 0, Rel.GT, self.name)
|
||||||
|
x_shape = list(logits['shape'])
|
||||||
|
if len(x_shape) != 2:
|
||||||
|
raise ValueError("RandomCategorical shape should be 2-dimension.")
|
||||||
|
ndim = len(x_shape) - 1
|
||||||
|
x_shape[ndim] = num_samples_v
|
||||||
|
return {'shape': (x_shape),
|
||||||
|
'dtype': (self.dtype),
|
||||||
|
'value': None}
|
||||||
|
|
|
@ -0,0 +1,75 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, x, dtype):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.cast = P.Cast()
|
||||||
|
self.x = x
|
||||||
|
self.dtype = dtype
|
||||||
|
|
||||||
|
def construct(self):
|
||||||
|
return self.cast(self.x, self.dtype)
|
||||||
|
|
||||||
|
def test_net_f32_bool():
|
||||||
|
x = np.random.randn(3,4).astype(np.float32)
|
||||||
|
x[:,1] = 0
|
||||||
|
net = Net(Tensor(x), mstype.bool_)
|
||||||
|
output = net()
|
||||||
|
print(output.asnumpy())
|
||||||
|
print(Tensor(x).dtype)
|
||||||
|
print(output.dtype)
|
||||||
|
|
||||||
|
def test_net_f16_bool():
|
||||||
|
x = np.random.randn(3,4).astype(np.float16)
|
||||||
|
x[:,1] = 0
|
||||||
|
net = Net(Tensor(x), mstype.bool_)
|
||||||
|
output = net()
|
||||||
|
print(output.asnumpy())
|
||||||
|
print(Tensor(x).dtype)
|
||||||
|
print(output.dtype)
|
||||||
|
|
||||||
|
def test_net_f64_bool():
|
||||||
|
x = np.random.randn(3,4).astype(np.float64)
|
||||||
|
x[:,1] = 0
|
||||||
|
net = Net(Tensor(x), mstype.bool_)
|
||||||
|
output = net()
|
||||||
|
print(output.asnumpy())
|
||||||
|
print(Tensor(x).dtype)
|
||||||
|
print(output.dtype)
|
||||||
|
|
||||||
|
def test_net_int16_float16():
|
||||||
|
x = np.random.randint(-512, 512, size=(3,4)).astype(np.int16)
|
||||||
|
net = Net(Tensor(x), mstype.float16)
|
||||||
|
output = net()
|
||||||
|
print(output.asnumpy())
|
||||||
|
print(Tensor(x).dtype)
|
||||||
|
print(output.dtype)
|
||||||
|
|
||||||
|
def test_net_int64_float16():
|
||||||
|
x = np.random.randint(-512, 512, size=(3,4)).astype(np.int64)
|
||||||
|
net = Net(Tensor(x), mstype.float16)
|
||||||
|
output = net()
|
||||||
|
print(output.asnumpy())
|
||||||
|
print(Tensor(x).dtype)
|
||||||
|
print(output.dtype)
|
|
@ -127,7 +127,6 @@ def test_net_int64():
|
||||||
print(output.asnumpy())
|
print(output.asnumpy())
|
||||||
assert np.array_equal(output.asnumpy(), np.stack([x, y], axis))
|
assert np.array_equal(output.asnumpy(), np.stack([x, y], axis))
|
||||||
|
|
||||||
|
|
||||||
def test_net_uint64():
|
def test_net_uint64():
|
||||||
x = np.random.randn(3, 5, 4).astype(np.uint64)
|
x = np.random.randn(3, 5, 4).astype(np.uint64)
|
||||||
y = np.random.randn(3, 5, 4).astype(np.uint64)
|
y = np.random.randn(3, 5, 4).astype(np.uint64)
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import mindspore
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.common.api import ms_function
|
||||||
|
import numpy as np
|
||||||
|
import mindspore.context as context
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, num_sample):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.random_categorical = P.RandomCategorical(mindspore.int64)
|
||||||
|
self.num_sample = num_sample
|
||||||
|
|
||||||
|
def construct(self, logits, seed=0):
|
||||||
|
return self.random_categorical(logits, self.num_sample, seed)
|
||||||
|
|
||||||
|
def test_net():
|
||||||
|
x = np.random.random((10, 5)).astype(np.float32)
|
||||||
|
net = Net(8)
|
||||||
|
output = net(Tensor(x))
|
||||||
|
print(x)
|
||||||
|
print(output.asnumpy())
|
||||||
|
print(output.dtype())
|
|
@ -0,0 +1,43 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import mindspore as ms
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore.common.api import ms_function
|
||||||
|
import numpy as np
|
||||||
|
import mindspore.context as context
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.rnnt_loss = P.RNNTLoss(blank_label=0)
|
||||||
|
|
||||||
|
def construct(self, acts, labels, act_lens, label_lens):
|
||||||
|
return self.rnnt_loss(acts, labels, act_lens, label_lens)
|
||||||
|
|
||||||
|
|
||||||
|
def test_net():
|
||||||
|
B, T, U, V = 1, 2, 3, 5
|
||||||
|
acts = np.random.random((B, T, U, V)).astype(np.float32)
|
||||||
|
labels = np.array([[np.random.randint(1, V-1) for _ in range(U-1)]]).astype(np.int32)
|
||||||
|
input_length = np.array([T] * B).astype(np.int32)
|
||||||
|
label_length = np.array([len(l) for l in labels]).astype(np.int32)
|
||||||
|
|
||||||
|
rnnt_loss = Net()
|
||||||
|
costs, grads = rnnt_loss(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length))
|
||||||
|
print(Tensor(acts), Tensor(labels), Tensor(input_length), Tensor(label_length))
|
||||||
|
print(costs.asnumpy())
|
||||||
|
print(grads.asnumpy())
|
|
@ -0,0 +1,42 @@
|
||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import operations as P
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE,
|
||||||
|
device_target="Ascend")
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, offset):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.embedding = P.EmbeddingLookup()
|
||||||
|
self.offset = offset
|
||||||
|
|
||||||
|
def construct(self, param, index):
|
||||||
|
return self.embedding(param, index, self.offset)
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_lookup_sparse():
|
||||||
|
params = Tensor(np.array([[8, 9], [10, 11], [12, 13], [14, 15]]), mstype.int32)
|
||||||
|
indices = Tensor(np.array([[5, 2], [8, 5]]), mstype.int32)
|
||||||
|
offset = 4
|
||||||
|
embedding = Net(offset)
|
||||||
|
out = embedding(params, indices)
|
||||||
|
assert(out.asnumpy() == [[[10, 11], [0, 0]], [[0, 0], [10, 11]]]).all()
|
|
@ -29,7 +29,6 @@ context.set_context(mode=context.GRAPH_MODE)
|
||||||
|
|
||||||
class LeNet5(nn.Cell):
|
class LeNet5(nn.Cell):
|
||||||
""" LeNet5 definition """
|
""" LeNet5 definition """
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super(LeNet5, self).__init__()
|
super(LeNet5, self).__init__()
|
||||||
self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
|
self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
|
||||||
|
|
Loading…
Reference in New Issue