!5283 Support setting operator io format in the frontend

Merge pull request !5283 from liangchenghui/io_format
This commit is contained in:
mindspore-ci-bot 2020-08-27 15:48:13 +08:00 committed by Gitee
commit e94416be0c
11 changed files with 160 additions and 98 deletions

View File

@ -276,12 +276,8 @@ OutHandler OpAdapterImpl::getNormalOutput(const OperatorPtr &op, int index) {
}
Status OpAdapterImpl::UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp,
const TypePtr &type) {
const TypePtr &type, const std::string &format) {
MS_EXCEPTION_IF_NULL(type);
std::string format = "NCHW";
if (op->GetOpType() == kExtractImagePatchesOpName) {
format = "NHWC";
}
auto desc = CreateOutputDesc(dyn_cast<abstract::Shape>(shp), type, format);
if (desc == nullptr) {
@ -340,7 +336,7 @@ std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateOutputDesc(const abstract::Sh
}
Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp,
const TypePtr &type) {
const TypePtr &type, const std::string &format) {
auto tuple_shp = dyn_cast<abstract::TupleShape>(shp);
MS_EXCEPTION_IF_NULL(tuple_shp);
@ -361,10 +357,7 @@ Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstrac
MS_LOG(ERROR) << "output_map is not equal tuple_shape size";
return FAILED;
}
std::string format = "NCHW";
if (op->GetOpType() == kTopKOpName) {
format = "NHWC";
}
for (size_t i = 0; i < tuple_shp->shape().size(); ++i) {
auto tuple_type = dyn_cast<Tuple>(type);
MS_EXCEPTION_IF_NULL(tuple_type);
@ -389,7 +382,7 @@ Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstrac
return SUCCESS;
}
std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &node) {
std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &node, const std::string &format) {
MS_EXCEPTION_IF_NULL(node);
TypeId me_type = node->Type()->type_id();
if (kObjectTypeTensorType == me_type) {
@ -405,7 +398,7 @@ std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &no
shape = shape_ptr->shape();
}
auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, "NCHW");
auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, format);
if (desc == nullptr) {
MS_LOG(ERROR) << "Update output descriptor failed!";
return nullptr;
@ -413,7 +406,7 @@ std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &no
return desc;
}
void OpAdapterImpl::UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node) {
void OpAdapterImpl::UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node, const std::string format) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is nullptr";
return;
@ -424,19 +417,18 @@ void OpAdapterImpl::UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNode
for (size_t i = 1; i < inputs.size(); ++i) {
auto it = input_map_.find(i);
if (it != input_map_.end()) {
auto desc = CreateNodeDesc(inputs[i]);
auto desc = CreateNodeDesc(inputs[i], format);
if (desc == nullptr) {
continue;
}
if (op->GetOpType() == kExtractImagePatchesOpName) {
desc->SetFormat(ge::Format::FORMAT_NHWC);
}
it->second.update_input_desc(op, *desc);
}
}
}
void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node) {
void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node,
const std::string format) {
if (op == nullptr) {
MS_LOG(ERROR) << "op is nullptr";
return;
@ -452,7 +444,7 @@ void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfN
auto inputs = node->cast<CNodePtr>()->inputs();
for (size_t i = 1; i < inputs.size(); ++i) {
if (input_map.find(i) != input_map.end()) {
auto desc = CreateNodeDesc(inputs[i]);
auto desc = CreateNodeDesc(inputs[i], format);
if (desc == nullptr) {
continue;
}
@ -464,11 +456,12 @@ void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfN
void OpAdapterImpl::updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(op);
MS_EXCEPTION_IF_NULL(node);
std::string format = GetOpIOFormat(node);
if (IsCustomOp(op)) {
auto cus_op = std::dynamic_pointer_cast<CustomOperator>(op);
UpdateCustomOpInputDesc(cus_op, node);
UpdateCustomOpInputDesc(cus_op, node, format);
} else {
UpdateNormalOpInputDesc(op, node);
UpdateNormalOpInputDesc(op, node, format);
}
}
@ -483,13 +476,14 @@ void OpAdapterImpl::updateOutputDesc(const OperatorPtr &op, const abstract::Base
auto normal_shape_ptr = dyn_cast<abstract::Shape>(shp);
auto no_shape_ptr = dyn_cast<abstract::NoShape>(shp);
std::string format = GetOpIOFormat(node);
if ((nullptr != normal_shape_ptr) || (nullptr != no_shape_ptr)) {
if (UpdateSingleOutputDesc(op, shp, type) != SUCCESS) {
if (UpdateSingleOutputDesc(op, shp, type, format) != SUCCESS) {
return;
}
} else if (nullptr != dyn_cast<abstract::TupleShape>(shp)) {
if (UpdateMultiOutputDesc(op, shp, type) != SUCCESS) {
if (UpdateMultiOutputDesc(op, shp, type, format) != SUCCESS) {
return;
}
} else {

View File

@ -75,14 +75,16 @@ class OpAdapterImpl {
OutHandler getOutput(const OperatorPtr &op, int index);
OutHandler getCustomOutput(const OperatorPtr &op, int index);
OutHandler getNormalOutput(const OperatorPtr &op, int index);
Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type);
Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
const std::string &format);
size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op);
std::shared_ptr<GeTensorDesc> CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type,
const std::string &format);
Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type);
std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node);
void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node);
void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node);
Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
const std::string &format);
std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node, const std::string &format);
void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node, const std::string format);
void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node, const std::string format);
void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node);
void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
const AnfNodePtr &node);
@ -226,8 +228,9 @@ class OpAdapter : public BaseOpAdapter {
OutHandler getNormalOutput(const OperatorPtr &op, int index) { return impl_->getNormalOutput(op, index); }
Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) {
return impl_->UpdateSingleOutputDesc(op, shp, type);
Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
const std::string &format) {
return impl_->UpdateSingleOutputDesc(op, shp, type, format);
}
size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op) { return impl_->GetCustomOpOutputSize(cus_op); }
@ -237,18 +240,21 @@ class OpAdapter : public BaseOpAdapter {
return impl_->CreateOutputDesc(shape_ptr, type, format);
}
Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) {
return impl_->UpdateMultiOutputDesc(op, shp, type);
Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
const std::string &format) {
return impl_->UpdateMultiOutputDesc(op, shp, type, format);
}
std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node) { return impl_->CreateNodeDesc(node); }
void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node) {
return impl_->UpdateNormalOpInputDesc(op, node);
std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node, const std::string &format) {
return impl_->CreateNodeDesc(node, format);
}
void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node) {
return impl_->UpdateCustomOpInputDesc(op, node);
void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node, const std::string format) {
return impl_->UpdateNormalOpInputDesc(op, node, format);
}
void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node, const std::string format) {
return impl_->UpdateCustomOpInputDesc(op, node, format);
}
void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) { impl_->updateInputDesc(op, node); }

View File

@ -247,7 +247,7 @@ bool IsCustomCNode(const AnfNodePtr &anf) {
return false;
}
if (node->inputs().empty()) {
MS_LOG(EXCEPTION) << "length of node inputs is empty";
MS_LOG(EXCEPTION) << "Length of node inputs is empty";
}
MS_EXCEPTION_IF_NULL(node->inputs()[0]);
if (!node->inputs()[0]->isa<ValueNode>()) {
@ -260,5 +260,37 @@ bool IsCustomCNode(const AnfNodePtr &anf) {
return IsCustomPrim(cus_prim);
}
std::string GetOpIOFormat(const AnfNodePtr &anf) {
std::string ret;
if (anf == nullptr) {
MS_LOG(ERROR) << "The anf is nullptr";
return ret;
}
auto node = anf->cast<CNodePtr>();
if (node == nullptr) {
MS_LOG(ERROR) << "The anf is not a cnode.";
return ret;
}
if (node->inputs().empty()) {
MS_LOG(EXCEPTION) << "Length of node inputs is empty.";
}
MS_EXCEPTION_IF_NULL(node->inputs()[0]);
if (!node->inputs()[0]->isa<ValueNode>()) {
MS_LOG(ERROR) << "The anf is not a value node.";
return ret;
}
auto prim = GetValueNode<PrimitivePtr>(node->inputs()[0]);
if (prim == nullptr) {
MS_LOG(ERROR) << "The anf is not a Primitive.";
return ret;
}
ValuePtr format = prim->GetAttr("io_format");
if (format == nullptr) {
return "NCHW";
}
ret = GetValue<std::string>(format);
return ret;
}
} // namespace transform
} // namespace mindspore

View File

@ -61,6 +61,7 @@ GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<AnyValue>);
bool IsCustomPrim(const PrimitivePtr &prim);
bool IsCustomCNode(const AnfNodePtr &node);
std::string GetOpIOFormat(const AnfNodePtr &node);
} // namespace transform
} // namespace mindspore
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_UTIL_H_

View File

@ -25,7 +25,7 @@ ATTR_MAP(BasicLSTMCell) = {{"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>()
{"state_is_tuple", ATTR_DESC(state_is_tuple, AnyTraits<bool>())},
{"activation", ATTR_DESC(activation, AnyTraits<std::string>())}};
OUTPUT_MAP(BasicLSTMCell) = {{0, OUTPUT_DESC(ct)}, {1, OUTPUT_DESC(ht)}, {2, OUTPUT_DESC(it)}, {3, OUTPUT_DESC(jt)},
{4, OUTPUT_DESC(ft)}, {5, OUTPUT_DESC(ot)}, {7, OUTPUT_DESC(tanhct)}};
{4, OUTPUT_DESC(ft)}, {5, OUTPUT_DESC(ot)}, {6, OUTPUT_DESC(tanhct)}};
REG_ADPT_DESC(BasicLSTMCell, kNameBasicLSTMCell, ADPT_DESC(BasicLSTMCell))
// BasicLSTMCellInputGrad
@ -35,7 +35,7 @@ OUTPUT_MAP(BasicLSTMCellInputGrad) = {{0, OUTPUT_DESC(dxt)}, {1, OUTPUT_DESC(dht
REG_ADPT_DESC(BasicLSTMCellInputGrad, kNameBasicLSTMCellInputGrad, ADPT_DESC(BasicLSTMCellInputGrad))
// BasicLSTMCellWeightGrad
INPUT_MAP(BasicLSTMCellWeightGrad) = {{1, INPUT_DESC(h)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(dgate)}};
INPUT_MAP(BasicLSTMCellWeightGrad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(h)}, {3, INPUT_DESC(dgate)}};
ATTR_MAP(BasicLSTMCellWeightGrad) = EMPTY_ATTR_MAP;
OUTPUT_MAP(BasicLSTMCellWeightGrad) = {{0, OUTPUT_DESC(dw)}, {1, OUTPUT_DESC(db)}};
REG_ADPT_DESC(BasicLSTMCellWeightGrad, kNameBasicLSTMCellWeightGrad, ADPT_DESC(BasicLSTMCellWeightGrad))

View File

@ -87,7 +87,10 @@ GeFormat TransformUtil::ConvertFormat(const string &format) {
return GeFormat::FORMAT_NHWC;
} else if (format == kOpFormat_HWCN) {
return GeFormat::FORMAT_HWCN;
} else if (format == kOpFormat_ND) {
return GeFormat::FORMAT_ND;
} else {
MS_LOG(ERROR) << "Illegal tensor data format: (" << format << "). Use ND format instead.";
return GeFormat::FORMAT_ND;
}
}
@ -113,8 +116,7 @@ std::shared_ptr<GeTensorDesc> TransformUtil::GetGeTensorDesc(const ShapeVector &
// convert me format to ge format
GeFormat ge_format = ConvertFormat(format);
if (ge_format == GeFormat::FORMAT_ND) {
MS_LOG(ERROR) << "undefined data format : " << static_cast<int>(ge_format);
return nullptr;
MS_LOG(INFO) << "Set ND data format";
}
// convert me datatype to ge datatype
GeDataType data_type = ConvertDataType(me_type);

View File

@ -1537,6 +1537,7 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
def __init__(self, forget_bias, activation):
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
self.add_prim_attr("io_format", "ND")
def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
# dhy and dcy should be same shape
@ -1586,7 +1587,7 @@ class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
pass
self.add_prim_attr("io_format", "HWCN")
def infer_shape(self, x_shape, h_shape, dgate_shape):
validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name)
@ -1595,8 +1596,10 @@ class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name)
validator.check("dgate_shape[0]", dgate_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
validator.check("dgate_shape[1]", dgate_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
dw_shape = (dgate_shape[1], x_shape[1] + h_shape[1], 1, 1)
db_shape = (dgate_shape[1], 1, 1, 1)
input_size = x_shape[1]
hidden_size = h_shape[1]
dw_shape = (input_size + hidden_size, 4 * hidden_size)
db_shape = (4 * hidden_size,)
return (dw_shape, db_shape)
def infer_dtype(self, x_dtype, h_dtype, dgate_dtype):
@ -1616,13 +1619,17 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer):
def __init__(self, keep_prob):
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)
self.add_prim_attr("io_format", "ND")
def infer_shape(self, dgate_shape, w_shape):
validator.check_integer("dgate rank", len(dgate_shape), 2, Rel.EQ, self.name)
validator.check_integer("w rank", len(w_shape), 4, Rel.EQ, self.name)
validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[0]", w_shape[0], Rel.EQ, self.name)
dxt_shape = (dgate_shape[0], w_shape[1] - w_shape[0] // 4)
dht_shape = (dgate_shape[0], dgate_shape[1] // 4)
validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
batch_size = dgate_shape[0]
hidden_size = dgate_shape[1] // 4
input_size = w_shape[0] - hidden_size
dxt_shape = (batch_size, input_size)
dht_shape = (batch_size, hidden_size)
return (dxt_shape, dht_shape)
def infer_dtype(self, dgate_dtype, w_dtype):

View File

@ -198,6 +198,7 @@ class ExtractImagePatches(PrimitiveWithInfer):
_check_tuple_or_list("rate", rates, self.name)
self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
self.add_prim_attr("padding", self.padding)
self.add_prim_attr("io_format", "NHWC")
def infer_shape(self, input_x):
"""infer shape"""

View File

@ -5353,35 +5353,41 @@ class BasicLSTMCell(PrimitiveWithInfer):
forget_bias (float): Add forget bias to forget gate biases in order to decrease former scale. Default to 1.0.
state_is_tuple (bool): If true, state is tensor tuple, containing h and c; If false, one tensor,
need split first. Default to True.
activation (str): Activation. Default to "tanh".
activation (str): Activation. Default to "tanh". Only "tanh" is currently supported.
Inputs:
- **x** (Tensor) - Current words. Tensor of shape (`batch_size`, `input_size`).
The data type must be float16 or float32.
- **h** (Tensor) - Hidden state last moment. Tensor of shape (`batch_size`, `hidden_size`).
The data type must be float16 or float32.
- **c** (Tensor) - Cell state last moment. Tensor of shape (`batch_size`, `hidden_size`).
- **w** (Tensor) - Weight. Tensor of shape (`4 x hidden_size`, `input_size + hidden_size`, 1, 1).
- **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`, 1, 1, 1).
The data type must be float16 or float32.
- **w** (Tensor) - Weight. Tensor of shape (`input_size + hidden_size`, `4 x hidden_size`).
The data type must be float16 or float32.
- **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`).
The data type must be same as `c`.
Outputs:
- **ct** (Tensor) - Forward :math:`c_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
- **ht** (Tensor) - Cell output. Tensor of shape (`batch_size`, `hidden_size`).
Has the same type with input `c`.
- **ht** (Tensor) - Cell output. Tensor of shape (`batch_size`, `hidden_size`). With data type of float16.
- **it** (Tensor) - Forward :math:`i_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
Has the same type with input `c`.
- **jt** (Tensor) - Forward :math:`j_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
Has the same type with input `c`.
- **ft** (Tensor) - Forward :math:`f_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
Has the same type with input `c`.
- **ot** (Tensor) - Forward :math:`o_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
Has the same type with input `c`.
- **tanhct** (Tensor) - Forward :math:`tanh c_t` cache at moment `t`.
Tensor of shape (`batch_size`, `hidden_size`).
Tensor of shape (`batch_size`, `hidden_size`). Has the same type with input `c`.
Examples:
'block': P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'),
'desc_inputs': [[128, 128], [128, 128], [128, 128], [512, 256, 1, 1],[512, 1, 1, 1]],
'desc_bprop': [[128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128]],
>>> x = Tensor(np.random.rand(128, 128).astype(np.float16))
>>> h = Tensor(np.random.rand(128, 128).astype(np.float16))
>>> c = Tensor(np.random.rand(128, 128).astype(np.float16))
>>> w = Tensor(np.random.rand(512, 256, 1, 1).astype(np.float16))
>>> b = Tensor(np.random.rand(512, 1, 1, 1).astype(np.float16))
>>> x = Tensor(np.random.rand(1, 32).astype(np.float16))
>>> h = Tensor(np.random.rand(1, 64).astype(np.float16))
>>> c = Tensor(np.random.rand(1, 64).astype(np.float16))
>>> w = Tensor(np.random.rand(96, 256).astype(np.float16))
>>> b = Tensor(np.random.rand(256, ).astype(np.float16))
>>> lstm = P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh')
>>> lstm(x, h, c, w, b)
"""
@ -5393,42 +5399,38 @@ class BasicLSTMCell(PrimitiveWithInfer):
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
self.add_prim_attr("io_format", "ND")
def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape):
# (batch_size, input_size)
validator.check_integer("x_shape", len(x_shape), 2, Rel.EQ, self.name)
# h and c should be same shape
validator.check_integer("h_shape", len(h_shape), 2, Rel.EQ, self.name)
validator.check("h rank", len(h_shape), "c rank", len(c_shape), Rel.EQ, self.name)
validator.check("h shape", h_shape, "c shape", c_shape, Rel.EQ, self.name)
validator.check_integer("w rank", len(w_shape), 4, Rel.EQ, self.name)
validator.check_integer("b rank", len(b_shape), 4, Rel.EQ, self.name)
validator.check("w_shape[0]", w_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
validator.check("w_shape[1]", w_shape[1], "x_shape[1]+h_shape[1]", x_shape[1] + h_shape[1], Rel.EQ, self.name)
validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name)
validator.check_integer("h rank", len(h_shape), 2, Rel.EQ, self.name)
validator.check_integer("c rank", len(c_shape), 2, Rel.EQ, self.name)
validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
validator.check_integer("b rank", len(b_shape), 1, Rel.EQ, self.name)
validator.check("x_shape[0]", x_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
validator.check("c_shape[0]", c_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
validator.check("c_shape[1]", c_shape[1], "h_shape[1]", h_shape[1], Rel.EQ, self.name)
validator.check("w_shape[1]", w_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
validator.check("w_shape[0]", w_shape[0], "x_shape[1]+h_shape[1]", x_shape[1] + h_shape[1], Rel.EQ, self.name)
validator.check("b_shape[0]", b_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
ct_shape = c_shape
ht_shape = h_shape
it_shape = h_shape
jt_shape = h_shape
ft_shape = h_shape
ot_shape = h_shape
tanhct_shape = h_shape
ht_shape = c_shape
it_shape = c_shape
jt_shape = c_shape
ft_shape = c_shape
ot_shape = c_shape
tanhct_shape = c_shape
return (ct_shape, ht_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape)
def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype, b_dtype):
validator.check_subclass("x", x_dtype, [mstype.tensor], self.name)
validator.check_subclass("h", h_dtype, [mstype.tensor], self.name)
validator.check_subclass("c", c_dtype, [mstype.tensor], self.name)
validator.check_subclass("w", w_dtype, [mstype.tensor], self.name)
validator.check_subclass("b", b_dtype, [mstype.tensor], self.name)
validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name)
validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name)
validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name)
validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name)
validator.check_type_name("b", b_dtype, [mstype.float16, mstype.float32], self.name)
return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype)
validator.check_tensor_type_same({"x_dtype": x_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"h_dtype": h_dtype}, [mstype.float16, mstype.float32], self.name)
validator.check_tensor_type_same({"w_dtype": w_dtype}, [mstype.float16, mstype.float32], self.name)
args = {"c_dtype": c_dtype, "b_dtype": b_dtype}
validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
return (c_dtype, mstype.float16, c_dtype, c_dtype, c_dtype, c_dtype, c_dtype)
class InTopK(PrimitiveWithInfer):

View File

@ -735,7 +735,7 @@ TEST_F(TestConvert, TestConvertTensorError) {
std::vector<int> dims2{2, 3, 4};
auto type_id_2 = kNumberTypeFloat32;
auto me_tensor_ptr_2 = std::make_shared<MeTensor>(type_id_2, dims2);
ASSERT_EQ(TransformUtil::ConvertTensor(me_tensor_ptr_2, "xyz"), nullptr);
ASSERT_NE(TransformUtil::ConvertTensor(me_tensor_ptr_2, "xyz"), nullptr);
}
TEST_F(TestConvert, TestUtilsConvertDataType) {

View File

@ -701,6 +701,16 @@ class ParallelConcatNet(nn.Cell):
return self.parallel_concat((x1, x2))
class BasicLSTMCellNet(nn.Cell):
""" BasicLSTMCellNet definition """
def __init__(self):
super(BasicLSTMCellNet, self).__init__()
self.lstm = P.BasicLSTMCell()
def construct(self, x, h, c, w, b):
return self.lstm(x, h, c, w, b)
class EditDistance(nn.Cell):
def __init__(self, hypothesis_shape, truth_shape, normalize=True):
super(EditDistance, self).__init__()
@ -1402,11 +1412,6 @@ test_case_nn_ops = [
'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]],
'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
'skip': ['backward']}),
('BasicLSTMCell', {
'block': P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'),
'desc_inputs': [[128, 128], [128, 128], [128, 128], [512, 256, 1, 1], [512, 1, 1, 1]],
'desc_bprop': [[128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128]],
'skip': []}),
('TopK', {
'block': P.TopK(),
'desc_const': [5],
@ -2346,6 +2351,18 @@ test_case_other_ops = [
'block': P.PopulationCount(),
'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.int16))],
'skip': ['backward']}),
('BasicLSTMCellNet', {
'block': BasicLSTMCellNet(),
'desc_inputs': [Tensor(np.random.rand(1, 32).astype(np.float16)),
Tensor(np.random.rand(1, 64).astype(np.float16)),
Tensor(np.random.rand(1, 64).astype(np.float16)),
Tensor(np.random.rand(96, 256).astype(np.float16)),
Tensor(np.random.rand(256, ).astype(np.float16))],
'desc_bprop': [Tensor(np.random.rand(1, 64).astype(np.float16)),
Tensor(np.random.rand(1, 64).astype(np.float16)),
Tensor(np.random.rand(1, 64).astype(np.float16)),
Tensor(np.random.rand(1, 64).astype(np.float16)),
Tensor(np.random.rand(1, 64).astype(np.float16))]}),
]
test_case_quant_ops = [