!5283 Support setting operator io format in the frontend
Merge pull request !5283 from liangchenghui/io_format
commit e94416be0c
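This change lets a frontend operator declare the data format of its GE input/output descriptors through an "io_format" primitive attribute, replacing the per-operator special cases (ExtractImagePatches, TopK) that were hard-coded in the adapter. As a minimal sketch of the new pattern, assuming a hypothetical operator name (not part of this diff), a primitive only registers the attribute; the adapter reads it with GetOpIOFormat and falls back to "NCHW" when it is absent:

    from mindspore.ops import PrimitiveWithInfer, prim_attr_register

    class MyNHWCOp(PrimitiveWithInfer):
        """Hypothetical primitive whose GE descriptors should use NHWC."""

        @prim_attr_register
        def __init__(self):
            # Read by GetOpIOFormat() in the GE adapter; operators without
            # this attribute keep the default "NCHW" format.
            self.add_prim_attr("io_format", "NHWC")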
@@ -276,12 +276,8 @@ OutHandler OpAdapterImpl::getNormalOutput(const OperatorPtr &op, int index) {
 }
 
 Status OpAdapterImpl::UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp,
-                                             const TypePtr &type) {
+                                             const TypePtr &type, const std::string &format) {
   MS_EXCEPTION_IF_NULL(type);
-  std::string format = "NCHW";
-  if (op->GetOpType() == kExtractImagePatchesOpName) {
-    format = "NHWC";
-  }
 
   auto desc = CreateOutputDesc(dyn_cast<abstract::Shape>(shp), type, format);
   if (desc == nullptr) {
@@ -340,7 +336,7 @@ std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateOutputDesc(const abstract::Sh
 }
 
 Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp,
-                                            const TypePtr &type) {
+                                            const TypePtr &type, const std::string &format) {
   auto tuple_shp = dyn_cast<abstract::TupleShape>(shp);
   MS_EXCEPTION_IF_NULL(tuple_shp);
 
@@ -361,10 +357,7 @@ Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstrac
     MS_LOG(ERROR) << "output_map is not equal tuple_shape size";
     return FAILED;
   }
-  std::string format = "NCHW";
-  if (op->GetOpType() == kTopKOpName) {
-    format = "NHWC";
-  }
 
   for (size_t i = 0; i < tuple_shp->shape().size(); ++i) {
     auto tuple_type = dyn_cast<Tuple>(type);
     MS_EXCEPTION_IF_NULL(tuple_type);
@@ -389,7 +382,7 @@ Status OpAdapterImpl::UpdateMultiOutputDesc(const OperatorPtr &op, const abstrac
   return SUCCESS;
 }
 
-std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &node) {
+std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &node, const std::string &format) {
   MS_EXCEPTION_IF_NULL(node);
   TypeId me_type = node->Type()->type_id();
   if (kObjectTypeTensorType == me_type) {
@@ -405,7 +398,7 @@ std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &no
     shape = shape_ptr->shape();
   }
 
-  auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, "NCHW");
+  auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, format);
   if (desc == nullptr) {
     MS_LOG(ERROR) << "Update output descriptor failed!";
     return nullptr;
@@ -413,7 +406,7 @@ std::shared_ptr<GeTensorDesc> OpAdapterImpl::CreateNodeDesc(const AnfNodePtr &no
   return desc;
 }
 
-void OpAdapterImpl::UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node) {
+void OpAdapterImpl::UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node, const std::string format) {
   if (op == nullptr) {
     MS_LOG(ERROR) << "op is nullptr";
     return;
@@ -424,19 +417,18 @@ void OpAdapterImpl::UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNode
   for (size_t i = 1; i < inputs.size(); ++i) {
     auto it = input_map_.find(i);
     if (it != input_map_.end()) {
-      auto desc = CreateNodeDesc(inputs[i]);
+      auto desc = CreateNodeDesc(inputs[i], format);
       if (desc == nullptr) {
         continue;
       }
-      if (op->GetOpType() == kExtractImagePatchesOpName) {
-        desc->SetFormat(ge::Format::FORMAT_NHWC);
-      }
 
       it->second.update_input_desc(op, *desc);
     }
   }
 }
 
-void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node) {
+void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node,
+                                            const std::string format) {
   if (op == nullptr) {
     MS_LOG(ERROR) << "op is nullptr";
     return;
@@ -452,7 +444,7 @@ void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfN
   auto inputs = node->cast<CNodePtr>()->inputs();
   for (size_t i = 1; i < inputs.size(); ++i) {
     if (input_map.find(i) != input_map.end()) {
-      auto desc = CreateNodeDesc(inputs[i]);
+      auto desc = CreateNodeDesc(inputs[i], format);
       if (desc == nullptr) {
         continue;
       }
@@ -464,11 +456,12 @@ void OpAdapterImpl::UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfN
 void OpAdapterImpl::updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) {
   MS_EXCEPTION_IF_NULL(op);
   MS_EXCEPTION_IF_NULL(node);
+  std::string format = GetOpIOFormat(node);
   if (IsCustomOp(op)) {
     auto cus_op = std::dynamic_pointer_cast<CustomOperator>(op);
-    UpdateCustomOpInputDesc(cus_op, node);
+    UpdateCustomOpInputDesc(cus_op, node, format);
   } else {
-    UpdateNormalOpInputDesc(op, node);
+    UpdateNormalOpInputDesc(op, node, format);
   }
 }
 
@@ -483,13 +476,14 @@ void OpAdapterImpl::updateOutputDesc(const OperatorPtr &op, const abstract::Base
 
   auto normal_shape_ptr = dyn_cast<abstract::Shape>(shp);
   auto no_shape_ptr = dyn_cast<abstract::NoShape>(shp);
+  std::string format = GetOpIOFormat(node);
 
   if ((nullptr != normal_shape_ptr) || (nullptr != no_shape_ptr)) {
-    if (UpdateSingleOutputDesc(op, shp, type) != SUCCESS) {
+    if (UpdateSingleOutputDesc(op, shp, type, format) != SUCCESS) {
       return;
     }
   } else if (nullptr != dyn_cast<abstract::TupleShape>(shp)) {
-    if (UpdateMultiOutputDesc(op, shp, type) != SUCCESS) {
+    if (UpdateMultiOutputDesc(op, shp, type, format) != SUCCESS) {
       return;
     }
   } else {
@@ -75,14 +75,16 @@ class OpAdapterImpl {
   OutHandler getOutput(const OperatorPtr &op, int index);
   OutHandler getCustomOutput(const OperatorPtr &op, int index);
   OutHandler getNormalOutput(const OperatorPtr &op, int index);
-  Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type);
+  Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
+                                const std::string &format);
   size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op);
   std::shared_ptr<GeTensorDesc> CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type,
                                                  const std::string &format);
-  Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type);
-  std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node);
-  void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node);
-  void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node);
+  Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
+                               const std::string &format);
+  std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node, const std::string &format);
+  void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr &node, const std::string format);
+  void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node, const std::string format);
   void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node);
   void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
                         const AnfNodePtr &node);
@@ -226,8 +228,9 @@ class OpAdapter : public BaseOpAdapter {
 
   OutHandler getNormalOutput(const OperatorPtr &op, int index) { return impl_->getNormalOutput(op, index); }
 
-  Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) {
-    return impl_->UpdateSingleOutputDesc(op, shp, type);
+  Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
+                                const std::string &format) {
+    return impl_->UpdateSingleOutputDesc(op, shp, type, format);
   }
 
   size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op) { return impl_->GetCustomOpOutputSize(cus_op); }
@@ -237,18 +240,21 @@ class OpAdapter : public BaseOpAdapter {
     return impl_->CreateOutputDesc(shape_ptr, type, format);
   }
 
-  Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) {
-    return impl_->UpdateMultiOutputDesc(op, shp, type);
+  Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type,
+                               const std::string &format) {
+    return impl_->UpdateMultiOutputDesc(op, shp, type, format);
   }
 
-  std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node) { return impl_->CreateNodeDesc(node); }
-
-  void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node) {
-    return impl_->UpdateNormalOpInputDesc(op, node);
+  std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr &node, const std::string &format) {
+    return impl_->CreateNodeDesc(node, format);
   }
 
-  void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node) {
-    return impl_->UpdateCustomOpInputDesc(op, node);
+  void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node, const std::string format) {
+    return impl_->UpdateNormalOpInputDesc(op, node, format);
   }
 
+  void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node, const std::string format) {
+    return impl_->UpdateCustomOpInputDesc(op, node, format);
+  }
+
   void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) { impl_->updateInputDesc(op, node); }
@@ -247,7 +247,7 @@ bool IsCustomCNode(const AnfNodePtr &anf) {
     return false;
   }
   if (node->inputs().empty()) {
-    MS_LOG(EXCEPTION) << "length of node inputs is empty";
+    MS_LOG(EXCEPTION) << "Length of node inputs is empty";
   }
   MS_EXCEPTION_IF_NULL(node->inputs()[0]);
   if (!node->inputs()[0]->isa<ValueNode>()) {
@@ -260,5 +260,37 @@ bool IsCustomCNode(const AnfNodePtr &anf) {
 
   return IsCustomPrim(cus_prim);
 }
+
+std::string GetOpIOFormat(const AnfNodePtr &anf) {
+  std::string ret;
+  if (anf == nullptr) {
+    MS_LOG(ERROR) << "The anf is nullptr";
+    return ret;
+  }
+  auto node = anf->cast<CNodePtr>();
+  if (node == nullptr) {
+    MS_LOG(ERROR) << "The anf is not a cnode.";
+    return ret;
+  }
+  if (node->inputs().empty()) {
+    MS_LOG(EXCEPTION) << "Length of node inputs is empty.";
+  }
+  MS_EXCEPTION_IF_NULL(node->inputs()[0]);
+  if (!node->inputs()[0]->isa<ValueNode>()) {
+    MS_LOG(ERROR) << "The anf is not a value node.";
+    return ret;
+  }
+  auto prim = GetValueNode<PrimitivePtr>(node->inputs()[0]);
+  if (prim == nullptr) {
+    MS_LOG(ERROR) << "The anf is not a Primitive.";
+    return ret;
+  }
+  ValuePtr format = prim->GetAttr("io_format");
+  if (format == nullptr) {
+    return "NCHW";
+  }
+  ret = GetValue<std::string>(format);
+  return ret;
+}
 } // namespace transform
 } // namespace mindspore
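Note: GetOpIOFormat returns "NCHW" when a primitive carries no "io_format" attribute, so operators that do not set the attribute in the frontend keep the previous default format.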
@@ -61,6 +61,7 @@ GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<AnyValue>);
 
 bool IsCustomPrim(const PrimitivePtr &prim);
 bool IsCustomCNode(const AnfNodePtr &node);
+std::string GetOpIOFormat(const AnfNodePtr &node);
 } // namespace transform
 } // namespace mindspore
 #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_ADAPTER_UTIL_H_
@@ -25,7 +25,7 @@ ATTR_MAP(BasicLSTMCell) = {{"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>()
                            {"state_is_tuple", ATTR_DESC(state_is_tuple, AnyTraits<bool>())},
                            {"activation", ATTR_DESC(activation, AnyTraits<std::string>())}};
 OUTPUT_MAP(BasicLSTMCell) = {{0, OUTPUT_DESC(ct)}, {1, OUTPUT_DESC(ht)}, {2, OUTPUT_DESC(it)}, {3, OUTPUT_DESC(jt)},
-                             {4, OUTPUT_DESC(ft)}, {5, OUTPUT_DESC(ot)}, {7, OUTPUT_DESC(tanhct)}};
+                             {4, OUTPUT_DESC(ft)}, {5, OUTPUT_DESC(ot)}, {6, OUTPUT_DESC(tanhct)}};
 REG_ADPT_DESC(BasicLSTMCell, kNameBasicLSTMCell, ADPT_DESC(BasicLSTMCell))
 
 // BasicLSTMCellInputGrad
@@ -35,7 +35,7 @@ OUTPUT_MAP(BasicLSTMCellInputGrad) = {{0, OUTPUT_DESC(dxt)}, {1, OUTPUT_DESC(dht
 REG_ADPT_DESC(BasicLSTMCellInputGrad, kNameBasicLSTMCellInputGrad, ADPT_DESC(BasicLSTMCellInputGrad))
 
 // BasicLSTMCellWeightGrad
-INPUT_MAP(BasicLSTMCellWeightGrad) = {{1, INPUT_DESC(h)}, {2, INPUT_DESC(x)}, {3, INPUT_DESC(dgate)}};
+INPUT_MAP(BasicLSTMCellWeightGrad) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(h)}, {3, INPUT_DESC(dgate)}};
 ATTR_MAP(BasicLSTMCellWeightGrad) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(BasicLSTMCellWeightGrad) = {{0, OUTPUT_DESC(dw)}, {1, OUTPUT_DESC(db)}};
 REG_ADPT_DESC(BasicLSTMCellWeightGrad, kNameBasicLSTMCellWeightGrad, ADPT_DESC(BasicLSTMCellWeightGrad))
@@ -87,7 +87,10 @@ GeFormat TransformUtil::ConvertFormat(const string &format) {
     return GeFormat::FORMAT_NHWC;
   } else if (format == kOpFormat_HWCN) {
     return GeFormat::FORMAT_HWCN;
+  } else if (format == kOpFormat_ND) {
+    return GeFormat::FORMAT_ND;
   } else {
+    MS_LOG(ERROR) << "Illegal tensor data format: (" << format << "). Use ND format instead.";
     return GeFormat::FORMAT_ND;
   }
 }
@@ -113,8 +116,7 @@ std::shared_ptr<GeTensorDesc> TransformUtil::GetGeTensorDesc(const ShapeVector &
   // convert me format to ge format
   GeFormat ge_format = ConvertFormat(format);
   if (ge_format == GeFormat::FORMAT_ND) {
-    MS_LOG(ERROR) << "undefined data format : " << static_cast<int>(ge_format);
-    return nullptr;
+    MS_LOG(INFO) << "Set ND data format";
   }
   // convert me datatype to ge datatype
   GeDataType data_type = ConvertDataType(me_type);
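Note: with this change an unrecognized format string no longer makes GetGeTensorDesc fail; ConvertFormat now maps it to FORMAT_ND with an error log, and GetGeTensorDesc only logs at INFO level that the ND format is used.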
@@ -1537,6 +1537,7 @@ class BasicLSTMCellCStateGrad(PrimitiveWithInfer):
     def __init__(self, forget_bias, activation):
         self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
         self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
+        self.add_prim_attr("io_format", "ND")
 
     def infer_shape(self, c_shape, dht_shape, dct_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape):
         # dhy and dcy should be same shape
@@ -1586,7 +1587,7 @@ class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
 
     @prim_attr_register
     def __init__(self):
-        pass
+        self.add_prim_attr("io_format", "HWCN")
 
     def infer_shape(self, x_shape, h_shape, dgate_shape):
         validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name)
@@ -1595,8 +1596,10 @@ class BasicLSTMCellWeightGrad(PrimitiveWithInfer):
         validator.check("h_shape[0]", h_shape[0], "x_shape[0]", x_shape[0], Rel.EQ, self.name)
         validator.check("dgate_shape[0]", dgate_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
         validator.check("dgate_shape[1]", dgate_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
-        dw_shape = (dgate_shape[1], x_shape[1] + h_shape[1], 1, 1)
-        db_shape = (dgate_shape[1], 1, 1, 1)
+        input_size = x_shape[1]
+        hidden_size = h_shape[1]
+        dw_shape = (input_size + hidden_size, 4 * hidden_size)
+        db_shape = (4 * hidden_size,)
         return (dw_shape, db_shape)
 
     def infer_dtype(self, x_dtype, h_dtype, dgate_dtype):
@@ -1616,13 +1619,17 @@ class BasicLSTMCellInputGrad(PrimitiveWithInfer):
     def __init__(self, keep_prob):
         self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
         self.keep_prob = validator.check_number_range("keep_prob", keep_prob, 0.0, 1.0, Rel.INC_BOTH, self.name)
+        self.add_prim_attr("io_format", "ND")
 
     def infer_shape(self, dgate_shape, w_shape):
         validator.check_integer("dgate rank", len(dgate_shape), 2, Rel.EQ, self.name)
-        validator.check_integer("w rank", len(w_shape), 4, Rel.EQ, self.name)
-        validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[0]", w_shape[0], Rel.EQ, self.name)
-        dxt_shape = (dgate_shape[0], w_shape[1] - w_shape[0] // 4)
-        dht_shape = (dgate_shape[0], dgate_shape[1] // 4)
+        validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
+        validator.check("dgate_shape[1]", dgate_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
+        batch_size = dgate_shape[0]
+        hidden_size = dgate_shape[1] // 4
+        input_size = w_shape[0] - hidden_size
+        dxt_shape = (batch_size, input_size)
+        dht_shape = (batch_size, hidden_size)
         return (dxt_shape, dht_shape)
 
     def infer_dtype(self, dgate_dtype, w_dtype):
@@ -198,6 +198,7 @@ class ExtractImagePatches(PrimitiveWithInfer):
         _check_tuple_or_list("rate", rates, self.name)
         self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name)
         self.add_prim_attr("padding", self.padding)
+        self.add_prim_attr("io_format", "NHWC")
 
     def infer_shape(self, input_x):
         """infer shape"""
@@ -5353,35 +5353,41 @@ class BasicLSTMCell(PrimitiveWithInfer):
         forget_bias (float): Add forget bias to forget gate biases in order to decrease former scale. Default to 1.0.
         state_is_tuple (bool): If true, state is tensor tuple, containing h and c; If false, one tensor,
           need split first. Default to True.
-        activation (str): Activation. Default to "tanh".
+        activation (str): Activation. Default to "tanh". Only "tanh" is currently supported.
 
     Inputs:
         - **x** (Tensor) - Current words. Tensor of shape (`batch_size`, `input_size`).
+          The data type must be float16 or float32.
         - **h** (Tensor) - Hidden state last moment. Tensor of shape (`batch_size`, `hidden_size`).
+          The data type must be float16 or float32.
         - **c** (Tensor) - Cell state last moment. Tensor of shape (`batch_size`, `hidden_size`).
-        - **w** (Tensor) - Weight. Tensor of shape (`4 x hidden_size`, `input_size + hidden_size`, 1, 1).
-        - **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`, 1, 1, 1).
+          The data type must be float16 or float32.
+        - **w** (Tensor) - Weight. Tensor of shape (`input_size + hidden_size`, `4 x hidden_size`).
+          The data type must be float16 or float32.
+        - **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`).
+          The data type must be same as `c`.
 
     Outputs:
         - **ct** (Tensor) - Forward :math:`c_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
-        - **ht** (Tensor) - Cell output. Tensor of shape (`batch_size`, `hidden_size`).
+          Has the same type with input `c`.
+        - **ht** (Tensor) - Cell output. Tensor of shape (`batch_size`, `hidden_size`). With data type of float16.
         - **it** (Tensor) - Forward :math:`i_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+          Has the same type with input `c`.
         - **jt** (Tensor) - Forward :math:`j_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+          Has the same type with input `c`.
         - **ft** (Tensor) - Forward :math:`f_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+          Has the same type with input `c`.
         - **ot** (Tensor) - Forward :math:`o_t` cache at moment `t`. Tensor of shape (`batch_size`, `hidden_size`).
+          Has the same type with input `c`.
         - **tanhct** (Tensor) - Forward :math:`tanh c_t` cache at moment `t`.
-          Tensor of shape (`batch_size`, `hidden_size`).
+          Tensor of shape (`batch_size`, `hidden_size`). Has the same type with input `c`.
 
     Examples:
-        'block': P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'),
-        'desc_inputs': [[128, 128], [128, 128], [128, 128], [512, 256, 1, 1],[512, 1, 1, 1]],
-        'desc_bprop': [[128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128]],
-
-        >>> x = Tensor(np.random.rand(128, 128).astype(np.float16))
-        >>> h = Tensor(np.random.rand(128, 128).astype(np.float16))
-        >>> c = Tensor(np.random.rand(128, 128).astype(np.float16))
-        >>> w = Tensor(np.random.rand(512, 256, 1, 1).astype(np.float16))
-        >>> b = Tensor(np.random.rand(512, 1, 1, 1).astype(np.float16))
+        >>> x = Tensor(np.random.rand(1, 32).astype(np.float16))
+        >>> h = Tensor(np.random.rand(1, 64).astype(np.float16))
+        >>> c = Tensor(np.random.rand(1, 64).astype(np.float16))
+        >>> w = Tensor(np.random.rand(96, 256).astype(np.float16))
+        >>> b = Tensor(np.random.rand(256, ).astype(np.float16))
         >>> lstm = P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh')
         >>> lstm(x, h, c, w, b)
     """
@@ -5393,42 +5399,38 @@ class BasicLSTMCell(PrimitiveWithInfer):
         self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
         self.state_is_tuple = validator.check_value_type("state_is_tuple", state_is_tuple, [bool], self.name)
         self.activation = validator.check_string("activation", activation, ['tanh'], self.name)
+        self.add_prim_attr("io_format", "ND")
 
     def infer_shape(self, x_shape, h_shape, c_shape, w_shape, b_shape):
-        # (batch_size, input_size)
-        validator.check_integer("x_shape", len(x_shape), 2, Rel.EQ, self.name)
-
-        # h and c should be same shape
-        validator.check_integer("h_shape", len(h_shape), 2, Rel.EQ, self.name)
-        validator.check("h rank", len(h_shape), "c rank", len(c_shape), Rel.EQ, self.name)
-        validator.check("h shape", h_shape, "c shape", c_shape, Rel.EQ, self.name)
-        validator.check_integer("w rank", len(w_shape), 4, Rel.EQ, self.name)
-        validator.check_integer("b rank", len(b_shape), 4, Rel.EQ, self.name)
-        validator.check("w_shape[0]", w_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
-        validator.check("w_shape[1]", w_shape[1], "x_shape[1]+h_shape[1]", x_shape[1] + h_shape[1], Rel.EQ, self.name)
+        validator.check_integer("x rank", len(x_shape), 2, Rel.EQ, self.name)
+        validator.check_integer("h rank", len(h_shape), 2, Rel.EQ, self.name)
+        validator.check_integer("c rank", len(c_shape), 2, Rel.EQ, self.name)
+        validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
+        validator.check_integer("b rank", len(b_shape), 1, Rel.EQ, self.name)
+        validator.check("x_shape[0]", x_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
+        validator.check("c_shape[0]", c_shape[0], "h_shape[0]", h_shape[0], Rel.EQ, self.name)
+        validator.check("c_shape[1]", c_shape[1], "h_shape[1]", h_shape[1], Rel.EQ, self.name)
+        validator.check("w_shape[1]", w_shape[1], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
+        validator.check("w_shape[0]", w_shape[0], "x_shape[1]+h_shape[1]", x_shape[1] + h_shape[1], Rel.EQ, self.name)
+        validator.check("b_shape[0]", b_shape[0], "4*h_shape[1]", 4 * h_shape[1], Rel.EQ, self.name)
         ct_shape = c_shape
-        ht_shape = h_shape
-        it_shape = h_shape
-        jt_shape = h_shape
-        ft_shape = h_shape
-        ot_shape = h_shape
-        tanhct_shape = h_shape
+        ht_shape = c_shape
+        it_shape = c_shape
+        jt_shape = c_shape
+        ft_shape = c_shape
+        ot_shape = c_shape
+        tanhct_shape = c_shape
 
         return (ct_shape, ht_shape, it_shape, jt_shape, ft_shape, ot_shape, tanhct_shape)
 
     def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype, b_dtype):
-        validator.check_subclass("x", x_dtype, [mstype.tensor], self.name)
-        validator.check_subclass("h", h_dtype, [mstype.tensor], self.name)
-        validator.check_subclass("c", c_dtype, [mstype.tensor], self.name)
-        validator.check_subclass("w", w_dtype, [mstype.tensor], self.name)
-        validator.check_subclass("b", b_dtype, [mstype.tensor], self.name)
-        validator.check_type_name("x", x_dtype, [mstype.float16, mstype.float32], self.name)
-        validator.check_type_name("h", h_dtype, [mstype.float16, mstype.float32], self.name)
-        validator.check_type_name("c", c_dtype, [mstype.float16, mstype.float32], self.name)
-        validator.check_type_name("w", w_dtype, [mstype.float16, mstype.float32], self.name)
-        validator.check_type_name("b", b_dtype, [mstype.float16, mstype.float32], self.name)
-        return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype)
+        validator.check_tensor_type_same({"x_dtype": x_dtype}, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensor_type_same({"h_dtype": h_dtype}, [mstype.float16, mstype.float32], self.name)
+        validator.check_tensor_type_same({"w_dtype": w_dtype}, [mstype.float16, mstype.float32], self.name)
+
+        args = {"c_dtype": c_dtype, "b_dtype": b_dtype}
+        validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name)
+        return (c_dtype, mstype.float16, c_dtype, c_dtype, c_dtype, c_dtype, c_dtype)
 
 
 class InTopK(PrimitiveWithInfer):
@@ -735,7 +735,7 @@ TEST_F(TestConvert, TestConvertTensorError) {
   std::vector<int> dims2{2, 3, 4};
   auto type_id_2 = kNumberTypeFloat32;
   auto me_tensor_ptr_2 = std::make_shared<MeTensor>(type_id_2, dims2);
-  ASSERT_EQ(TransformUtil::ConvertTensor(me_tensor_ptr_2, "xyz"), nullptr);
+  ASSERT_NE(TransformUtil::ConvertTensor(me_tensor_ptr_2, "xyz"), nullptr);
 }
 
 TEST_F(TestConvert, TestUtilsConvertDataType) {
@@ -701,6 +701,16 @@ class ParallelConcatNet(nn.Cell):
         return self.parallel_concat((x1, x2))
 
 
+class BasicLSTMCellNet(nn.Cell):
+    """ BasicLSTMCellNet definition """
+
+    def __init__(self):
+        super(BasicLSTMCellNet, self).__init__()
+        self.lstm = P.BasicLSTMCell()
+
+    def construct(self, x, h, c, w, b):
+        return self.lstm(x, h, c, w, b)
+
+
 class EditDistance(nn.Cell):
     def __init__(self, hypothesis_shape, truth_shape, normalize=True):
         super(EditDistance, self).__init__()
@@ -1402,11 +1412,6 @@ test_case_nn_ops = [
         'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]],
         'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
         'skip': ['backward']}),
-    ('BasicLSTMCell', {
-        'block': P.BasicLSTMCell(keep_prob=1.0, forget_bias=1.0, state_is_tuple=True, activation='tanh'),
-        'desc_inputs': [[128, 128], [128, 128], [128, 128], [512, 256, 1, 1], [512, 1, 1, 1]],
-        'desc_bprop': [[128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128], [128, 128]],
-        'skip': []}),
     ('TopK', {
         'block': P.TopK(),
         'desc_const': [5],
@@ -2346,6 +2351,18 @@ test_case_other_ops = [
         'block': P.PopulationCount(),
         'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.int16))],
         'skip': ['backward']}),
+    ('BasicLSTMCellNet', {
+        'block': BasicLSTMCellNet(),
+        'desc_inputs': [Tensor(np.random.rand(1, 32).astype(np.float16)),
+                        Tensor(np.random.rand(1, 64).astype(np.float16)),
+                        Tensor(np.random.rand(1, 64).astype(np.float16)),
+                        Tensor(np.random.rand(96, 256).astype(np.float16)),
+                        Tensor(np.random.rand(256, ).astype(np.float16))],
+        'desc_bprop': [Tensor(np.random.rand(1, 64).astype(np.float16)),
+                       Tensor(np.random.rand(1, 64).astype(np.float16)),
+                       Tensor(np.random.rand(1, 64).astype(np.float16)),
+                       Tensor(np.random.rand(1, 64).astype(np.float16)),
+                       Tensor(np.random.rand(1, 64).astype(np.float16))]}),
 ]
 
 test_case_quant_ops = [