diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs index 61552ea9dcc..85445584c14 100644 --- a/mindspore/lite/schema/model.fbs +++ b/mindspore/lite/schema/model.fbs @@ -273,7 +273,9 @@ union PrimitiveType { RandomStandardNormal, CropAndResize, Erf, - StridedSliceGrad + StridedSliceGrad, + IsFinite, + BatchMatMul, } enum QuantType: int { diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 837f49d18b5..3c972a3667c 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -1274,4 +1274,12 @@ table StridedSliceGrad { } table Erf { +} + +table IsFinite { +} + +table BatchMatMul { + adj_x : bool = false; + adj_y : bool = false; } \ No newline at end of file diff --git a/mindspore/lite/src/ops/batch_matmul.cc b/mindspore/lite/src/ops/batch_matmul.cc new file mode 100644 index 00000000000..0a646cac7b9 --- /dev/null +++ b/mindspore/lite/src/ops/batch_matmul.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/ops/batch_matmul.h" +#include +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +bool BatchMatMul::GetAdjX() const { return this->primitive_->value.AsBatchMatMul()->adj_x; } + +void BatchMatMul::SetAdjX(bool adj_x) { this->primitive_->value.AsBatchMatMul()->adj_x = adj_x; } + +bool BatchMatMul::GetAdjY() const { return this->primitive_->value.AsBatchMatMul()->adj_y; } + +void BatchMatMul::SetAdjY(bool adj_y) { this->primitive_->value.AsBatchMatMul()->adj_y = adj_y; } + +int BatchMatMul::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_BatchMatMul; + } + if (this->primitive_->value.type != schema::PrimitiveType_BatchMatMul) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::BatchMatMulT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new FusedBatchMatMulT failed"; + delete this->primitive_; + this->primitive_ = nullptr; + return RET_ERROR; + } + attr->adj_x = GetValue(prim.GetAttr("adj_x")); + attr->adj_y = GetValue(prim.GetAttr("adj_y")); + this->primitive_->value.value = attr; + } + return RET_OK; +} + +#else +int BatchMatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto val_offset = schema::CreateBatchMatMul(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchMatMul, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} +bool BatchMatMul::GetAdjX() const { return this->primitive_->value_as_BatchMatMul()->adj_x(); } + +bool BatchMatMul::GetAdjY() const { return this->primitive_->value_as_BatchMatMul()->adj_y(); } + +PrimitiveC *BatchMatMulCreator(const schema::Primitive *primitive) { + return PrimitiveC::NewPrimitiveC(primitive); +} +Registry BatchMatMulRegistry(schema::PrimitiveType_BatchMatMul, BatchMatMulCreator); +#endif + +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/batch_matmul.h b/mindspore/lite/src/ops/batch_matmul.h new file mode 100644 index 00000000000..ca6260ec476 --- /dev/null +++ b/mindspore/lite/src/ops/batch_matmul.h @@ -0,0 +1,46 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_ +#define LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_ + +#include +#include +#include +#include "src/ops/primitive_c.h" + +namespace mindspore { +namespace lite { +class BatchMatMul : public PrimitiveC { + public: + BatchMatMul() = default; + ~BatchMatMul() = default; +#ifdef PRIMITIVE_WRITEABLE + MS_DECLARE_PARENT(BatchMatMul, PrimitiveC); + explicit BatchMatMul(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; + void SetAdjX(bool adj_x); + void SetAdjY(bool adj_y); +#else + int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; +#endif + bool GetAdjX() const; + bool GetAdjY() const; +}; +} // namespace lite +} // namespace mindspore + +#endif // LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_ diff --git a/mindspore/lite/src/ops/is_finite.h b/mindspore/lite/src/ops/is_finite.h new file mode 100644 index 00000000000..9d18ebc757e --- /dev/null +++ b/mindspore/lite/src/ops/is_finite.h @@ -0,0 +1,33 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/primitive_c.h" + +#ifndef LITE_MINDSPORE_LITE_C_OPS_IS_FINITE_H_ +#define LITE_MINDSPORE_LITE_C_OPS_IS_FINITE_H_ + +namespace mindspore { +namespace lite { +class IsFinite : public PrimitiveC { + public: + MS_DECLARE_PARENT(IsFinite, PrimitiveC); + IsFinite() = default; + ~IsFinite() = default; + explicit IsFinite(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} +}; +} // namespace lite +} // namespace mindspore +#endif // LITE_MINDSPORE_LITE_C_OPS_IS_FINITE_H_ diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 8fa9fde0399..aaa87506403 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -170,6 +170,8 @@ #include "src/ops/crop_and_resize.h" #include "src/ops/nonzero.h" #include "src/ops/erf.h" +#include "src/ops/is_finite.h" +#include "src/ops/batch_matmul.h" #ifdef SUPPORT_TRAIN #include "src/ops/neg_grad.h" @@ -665,7 +667,6 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Gelu") { return NewPrimitiveC(prim, inputs, quantType); - #ifdef SUPPORT_TRAIN } else if (op_type == "SoftmaxCrossEntropyWithLogits") { return NewPrimitiveC(prim, inputs, quantType); @@ -1034,6 +1035,10 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) { return new (std::nothrow) NonZero(primitive); case schema::PrimitiveType_Erf: return new (std::nothrow) Erf(primitive); + case schema::PrimitiveType_IsFinite: + return new (std::nothrow) IsFinite(primitive); + case schema::PrimitiveType_BatchMatMul: + return new (std::nothrow) BatchMatMul(primitive); #ifdef SUPPORT_TRAIN case schema::PrimitiveType_ActivationGrad: return new (std::nothrow) ActivationGrad(primitive); diff --git a/mindspore/lite/tools/common/node_util.h b/mindspore/lite/tools/common/node_util.h index a0b247cd509..7f870019748 100644 --- a/mindspore/lite/tools/common/node_util.h +++ b/mindspore/lite/tools/common/node_util.h @@ -29,6 +29,18 @@ namespace mindspore { namespace lite { +template +int CreateOperator(const std::unique_ptr &primitive, schema::PrimitiveType type) { + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new attr failed"; + return RET_NULL_PTR; + } + primitive->value.type = type; + primitive->value.value = attr.release(); + return RET_OK; +} + using STATUS = int; STATUS BroadCastQuantParam(schema::MetaGraphT *graphT, const std::unique_ptr &node); @@ -91,6 +103,226 @@ STATUS GetFilterDim(const std::vector &oriDims, kTransFilterType type, STATUS SetFilterDim(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW); +template +static void TransKHWC2CHWK(int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, T *srcData, T *dstData) { + T *p1Buff = nullptr; + T *p2Buff = nullptr; + for (int k = 0; k < filterK; ++k) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int c = 0; c < filterC; ++c) { + p1Buff = srcData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + p2Buff = dstData + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k)); + *p2Buff = *p1Buff; + } + } + } + } +} + +template +static void TransKHWC2HWCK(int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, T *srcData, T *dstData) { + T *p1Buff = nullptr; + T *p2Buff = nullptr; + for (int k = 0; k < filterK; ++k) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int c = 0; c < filterC; ++c) { + p1Buff = srcData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + *p2Buff = *p1Buff; + } + } + } + } +} + +template +static void TransCKHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, + T *srcData, T *dstData) { + T *p1Buff = nullptr; + T *p2Buff = nullptr; + for (int c = 0; c < filterC; ++c) { + for (int k = 0; k < filterK; ++k) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + p1Buff = srcData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + if (type == kCKHW2HWCK) { + p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + } else if (type == kCKHW2KHWC) { + p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + } else { + p2Buff = dstData + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); + } + *p2Buff = *p1Buff; + } + } + } + } +} + +template +static void TransKCHW(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, + T *srcData, T *dstData) { + T *p1Buff = nullptr; + T *p2Buff = nullptr; + for (int k = 0; k < filterK; ++k) { + for (int c = 0; c < filterC; ++c) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + p1Buff = srcData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); + if (type == kKCHW2HWCK) { + p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + } else if (type == kKCHW2KHWC) { + p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + } else if (type == kKCHW2CKHW) { + p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + } else { + p2Buff = dstData + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); + } + *p2Buff = *p1Buff; + } + } + } + } +} + +template +static void TransCHWK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, + T *srcData, T *dstData) { + T *p1Buff = nullptr; + T *p2Buff = nullptr; + for (int c = 0; c < filterC; ++c) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int k = 0; k < filterK; ++k) { + p1Buff = srcData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k)); + if (type == kCHWK2HWCK) { + p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + } else { + p2Buff = dstData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); + } + *p2Buff = *p1Buff; + } + } + } + } +} + +template +static void TransHWCK(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, + T *srcData, T *dstData) { + T *p1Buff = nullptr; + T *p2Buff = nullptr; + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int c = 0; c < filterC; ++c) { + for (int k = 0; k < filterK; ++k) { + p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + if (type == kHWCK2KCHW) { + p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); + } else { + p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + } + *p2Buff = *p1Buff; + } + } + } + } +} + +template +static void TransHWKC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, + T *srcData, T *dstData) { + T *p1Buff = nullptr; + T *p2Buff = nullptr; + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int c = 0; c < filterC; ++c) { + for (int k = 0; k < filterK; ++k) { + p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); + if (type == kHWKC2KCHW) { + p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); + } else { + p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + } + *p2Buff = *p1Buff; + } + } + } + } +} + +template +static void TransNHWC(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, + T *srcData, T *dstData) { + T *p1Buff = nullptr; + T *p2Buff = nullptr; + for (int k = 0; k < filterK; ++k) { + for (int h = 0; h < filterH; ++h) { + for (int w = 0; w < filterW; ++w) { + for (int c = 0; c < filterC; ++c) { + p1Buff = srcData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); + if (type == kNHWC2HWCK) { + p2Buff = dstData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); + } else if (type == kNHWC2CKHW) { + p2Buff = dstData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); + } else { + p2Buff = dstData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); + } + *p2Buff = *p1Buff; + } + } + } + } +} + +template +static STATUS TransFilterData(kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW, + T *srcData, T *dstData) { + switch (type) { + case kCHWK2HWCK: + case kCHWK2KHWC: { + TransCHWK(type, filterK, filterC, filterH, filterW, srcData, dstData); + } break; + case kKHWC2HWCK: { + TransKHWC2HWCK(filterK, filterC, filterH, filterW, srcData, dstData); + } break; + case kKCHW2HWCK: + case kKCHW2CKHW: + case kKCHW2KHWC: + case kKCHW2HWKC: { + TransKCHW(type, filterK, filterC, filterH, filterW, srcData, dstData); + } break; + case kCKHW2HWCK: + case kCKHW2KHWC: + case kCKHW2HWKC: { + TransCKHW(type, filterK, filterC, filterH, filterW, srcData, dstData); + } break; + case kHWCK2KCHW: + case kHWCK2CKHW: { + TransHWCK(type, filterK, filterC, filterH, filterW, srcData, dstData); + } break; + case kHWKC2KCHW: + case kHWKC2CKHW: { + TransHWKC(type, filterK, filterC, filterH, filterW, srcData, dstData); + } break; + case kNHWC2HWCK: + case kNHWC2KCHW: + case kNHWC2CKHW: { + TransNHWC(type, filterK, filterC, filterH, filterW, srcData, dstData); + } break; + case kKHWC2CHWK: { + TransKHWC2CHWK(filterK, filterC, filterH, filterW, srcData, dstData); + } break; + default: { + MS_LOG(ERROR) << "Unsupported transFilterType: " << type; + return RET_ERROR; + } + } + return RET_OK; +} + template static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, int32_t filterK, int32_t filterC, int32_t filterH, int32_t filterW) { @@ -113,175 +345,10 @@ static STATUS TransFilterData(schema::TensorT *tensor, kTransFilterType type, in MS_LOG(ERROR) << "weightData is nullptr"; return RET_ERROR; } - T *p1Buff = nullptr; - T *p2Buff = nullptr; - switch (type) { - case kCHWK2HWCK: - case kCHWK2KHWC: { - for (int c = 0; c < filterC; ++c) { - for (int h = 0; h < filterH; ++h) { - for (int w = 0; w < filterW; ++w) { - for (int k = 0; k < filterK; ++k) { - p1Buff = weightData + ((c * filterH * filterW * filterK) + (h * filterW * filterK) + (w * filterK) + (k)); - if (type == kCHWK2HWCK) { - p2Buff = - buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); - } else { - p2Buff = - buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); - } - *p2Buff = *p1Buff; - } - } - } - } - } break; - case kKHWC2HWCK: { - for (int k = 0; k < filterK; ++k) { - for (int h = 0; h < filterH; ++h) { - for (int w = 0; w < filterW; ++w) { - for (int c = 0; c < filterC; ++c) { - p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); - p2Buff = buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); - *p2Buff = *p1Buff; - } - } - } - } - } break; - case kKCHW2HWCK: - case kKCHW2CKHW: - case kKCHW2KHWC: - case kKCHW2HWKC: { - for (int k = 0; k < filterK; ++k) { - for (int c = 0; c < filterC; ++c) { - for (int h = 0; h < filterH; ++h) { - for (int w = 0; w < filterW; ++w) { - p1Buff = weightData + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); - if (type == kKCHW2HWCK) { - p2Buff = - buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); - } else if (type == kKCHW2KHWC) { - p2Buff = - buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); - } else if (type == kKCHW2CKHW) { - p2Buff = - buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); - } else { - p2Buff = - buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); - } - *p2Buff = *p1Buff; - } - } - } - } - } break; - case kCKHW2HWCK: - case kCKHW2KHWC: - case kCKHW2HWKC: { - for (int c = 0; c < filterC; ++c) { - for (int k = 0; k < filterK; ++k) { - for (int h = 0; h < filterH; ++h) { - for (int w = 0; w < filterW; ++w) { - p1Buff = weightData + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); - if (type == kCKHW2HWCK) { - p2Buff = - buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); - } else if (type == kCKHW2KHWC) { - p2Buff = - buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); - } else { - p2Buff = - buf.get() + ((h * filterW * filterK * filterC) + (w * filterK * filterC) + (k * filterC) + (c)); - } - *p2Buff = *p1Buff; - } - } - } - } - } break; - case kHWCK2KCHW: - case kHWCK2CKHW: { - for (int h = 0; h < filterH; ++h) { - for (int w = 0; w < filterW; ++w) { - for (int c = 0; c < filterC; ++c) { - for (int k = 0; k < filterK; ++k) { - p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); - if (type == kHWCK2KCHW) { - p2Buff = - buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); - } else { - p2Buff = - buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); - } - *p2Buff = *p1Buff; - } - } - } - } - } break; - case kHWKC2KCHW: - case kHWKC2CKHW: { - for (int h = 0; h < filterH; ++h) { - for (int w = 0; w < filterW; ++w) { - for (int c = 0; c < filterC; ++c) { - for (int k = 0; k < filterK; ++k) { - p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); - if (type == kHWKC2KCHW) { - p2Buff = - buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); - } else { - p2Buff = - buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); - } - *p2Buff = *p1Buff; - } - } - } - } - } break; - case kNHWC2HWCK: - case kNHWC2KCHW: - case kNHWC2CKHW: { - for (int k = 0; k < filterK; ++k) { - for (int h = 0; h < filterH; ++h) { - for (int w = 0; w < filterW; ++w) { - for (int c = 0; c < filterC; ++c) { - p1Buff = weightData + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (k * filterC) + (c)); - if (type == kNHWC2HWCK) { - p2Buff = - buf.get() + ((h * filterW * filterC * filterK) + (w * filterC * filterK) + (c * filterK) + (k)); - } else if (type == kNHWC2CKHW) { - p2Buff = - buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w)); - } else { - p2Buff = - buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w)); - } - *p2Buff = *p1Buff; - } - } - } - } - } break; - case kKHWC2CHWK: { - for (int k = 0; k < filterK; ++k) { - for (int h = 0; h < filterH; ++h) { - for (int w = 0; w < filterW; ++w) { - for (int c = 0; c < filterC; ++c) { - p1Buff = weightData + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c)); - p2Buff = buf.get() + ((c * filterK * filterH * filterW) + (h * filterK * filterW) + (w * filterK) + (k)); - *p2Buff = *p1Buff; - } - } - } - } - } break; - default: { - MS_LOG(ERROR) << "Unsupported transFilterType: " << type; - return RET_ERROR; - } + + if (TransFilterData(type, filterK, filterC, filterH, filterW, weightData, buf.get()) != RET_OK) { + MS_LOG(ERROR) << "TransFilterData failed"; + return RET_ERROR; } auto ret = ::memcpy_s(tensor->data.data(), count * sizeof(T), buf.get(), count * sizeof(T)); diff --git a/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_self_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_self_parser.cc index b6939409eca..1de6c8854bc 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_self_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_arithmetic_self_parser.cc @@ -19,22 +19,11 @@ #include #include #include "tools/converter/parser/tf/tf_node_parser_registry.h" +#include "tools/common/node_util.h" namespace mindspore { namespace lite { -template -int CreateOperator(const std::unique_ptr &primitive, schema::PrimitiveType type) { - auto attr = std::make_unique(); - if (attr == nullptr) { - MS_LOG(ERROR) << "new attr failed"; - return RET_NULL_PTR; - } - primitive->value.type = type; - primitive->value.value = attr.release(); - return RET_OK; -} - STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { @@ -61,6 +50,12 @@ STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op, status = CreateOperator(primitive, schema::PrimitiveType_Log); } else if (tf_op.op() == "Sqrt") { status = CreateOperator(primitive, schema::PrimitiveType_Sqrt); + } else if (tf_op.op() == "Cos") { + status = CreateOperator(primitive, schema::PrimitiveType_Cos); + } else if (tf_op.op() == "Sin") { + status = CreateOperator(primitive, schema::PrimitiveType_Sin); + } else if (tf_op.op() == "Square") { + status = CreateOperator(primitive, schema::PrimitiveType_Square); } else if (tf_op.op() == "Pow") { status = CreateOperator(primitive, schema::PrimitiveType_Power); } @@ -81,6 +76,9 @@ STATUS TFArithmeticSelfParser::Parse(const tensorflow::NodeDef &tf_op, } return status; } +TFNodeRegistrar g_tfCosParser("Cos", new TFArithmeticSelfParser()); +TFNodeRegistrar g_tfSinParser("Sin", new TFArithmeticSelfParser()); +TFNodeRegistrar g_tfSquareParser("Square", new TFArithmeticSelfParser()); TFNodeRegistrar g_tfCeilParser("Ceil", new TFArithmeticSelfParser()); TFNodeRegistrar g_tfExpParser("Exp", new TFArithmeticSelfParser()); TFNodeRegistrar g_tfFloorParser("Floor", new TFArithmeticSelfParser()); diff --git a/mindspore/lite/tools/converter/parser/tf/tf_batch_matmul_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_batch_matmul_parser.cc new file mode 100644 index 00000000000..a6355e5f1c6 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_batch_matmul_parser.cc @@ -0,0 +1,64 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_batch_matmul_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFBatchMatmulParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is nullptr"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + tensorflow::AttrValue attr_value; + TensorFlowUtils::FindAttrValue(tf_op, "adj_x", &attr_value); + attr->adj_x = attr_value.b(); + attr->adj_y = attr_value.b(); + + primitive->value.type = schema::PrimitiveType_BatchMatMul; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + for (int i = 0; i < tf_op.input_size(); i++) { + inputs->emplace_back(tf_op.input(i)); + } + return RET_OK; +} +TFNodeRegistrar g_tfBatchMatMulParser("BatchMatMul", new TFBatchMatmulParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_batch_matmul_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_batch_matmul_parser.h new file mode 100644 index 00000000000..9f5240f608c --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_batch_matmul_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_MATMUL_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_MATMUL_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFBatchMatmulParser : public TFNodeParser { + public: + TFBatchMatmulParser() = default; + ~TFBatchMatmulParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_BATCH_MATMUL_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_is_finite_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_is_finite_parser.cc new file mode 100644 index 00000000000..81a2e52ed5a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_is_finite_parser.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_is_finite_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" +#include "tools/common/node_util.h" + +namespace mindspore { +namespace lite { +STATUS TFIsFiniteParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is nullptr"; + return RET_NULL_PTR; + } + + int status = CreateOperator(primitive, schema::PrimitiveType_IsFinite); + if (status != RET_OK) { + return status; + } + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = 1; + for (int i = 0; i < tf_op.input_size(); i++) { + inputs->emplace_back(tf_op.input(i)); + } + + return RET_OK; +} +TFNodeRegistrar g_tf_is_finite_parser("IsFinite", new TFIsFiniteParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_is_finite_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_is_finite_parser.h new file mode 100644 index 00000000000..22975d6225a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_is_finite_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_IS_FINITE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_IS_FINITE_PARSER_H_ + +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFIsFiniteParser : public TFNodeParser { + public: + TFIsFiniteParser() = default; + ~TFIsFiniteParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_IS_FINITE_PARSER_H_ diff --git a/mindspore/lite/tools/converter/parser/tf/tf_logical_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_logical_parser.cc index 3202d4d50f4..4dc87fc77fa 100644 --- a/mindspore/lite/tools/converter/parser/tf/tf_logical_parser.cc +++ b/mindspore/lite/tools/converter/parser/tf/tf_logical_parser.cc @@ -19,6 +19,7 @@ #include #include #include "tools/converter/parser/tf/tf_node_parser_registry.h" +#include "tools/common/node_util.h" namespace mindspore { namespace lite { @@ -36,37 +37,19 @@ STATUS TFLogicalParser::Parse(const tensorflow::NodeDef &tf_op, MS_LOG(ERROR) << "primitive is nullptr"; return RET_NULL_PTR; } + + int status = RET_ERROR; if (tf_op.op() == "LogicalAnd") { - auto attr = std::make_unique(); - if (attr == nullptr) { - MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; - } - primitive->value.type = schema::PrimitiveType_LogicalAnd; - primitive->value.value = attr.release(); - *primitiveC = PrimitiveC::Create(primitive.release()); + status = CreateOperator(primitive, schema::PrimitiveType_LogicalAnd); } else if (tf_op.op() == "LogicalOr") { - auto attr = std::make_unique(); - if (attr == nullptr) { - MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; - } - primitive->value.type = schema::PrimitiveType_LogicalOr; - primitive->value.value = attr.release(); - *primitiveC = PrimitiveC::Create(primitive.release()); + status = CreateOperator(primitive, schema::PrimitiveType_LogicalOr); } else if (tf_op.op() == "LogicalNot") { - auto attr = std::make_unique(); - if (attr == nullptr) { - MS_LOG(ERROR) << "new op failed"; - return RET_NULL_PTR; - } - primitive->value.type = schema::PrimitiveType_LogicalNot; - primitive->value.value = attr.release(); - *primitiveC = PrimitiveC::Create(primitive.release()); - } else { - MS_LOG(ERROR) << tf_op.op() << " is not supported."; - return RET_ERROR; + status = CreateOperator(primitive, schema::PrimitiveType_LogicalNot); } + if (status != RET_OK) { + return status; + } + *primitiveC = PrimitiveC::Create(primitive.release()); if (*primitiveC == nullptr) { MS_LOG(ERROR) << "primitiveC is nullptr"; return RET_ERROR; @@ -79,8 +62,8 @@ STATUS TFLogicalParser::Parse(const tensorflow::NodeDef &tf_op, return RET_OK; } -TFNodeRegistrar g_tfLogicalAndParser("LogicalAnd", new TFLogicalParser()); -TFNodeRegistrar g_tfLogicalOrParser("LogicalOr", new TFLogicalParser()); TFNodeRegistrar g_tfLogicalNotParser("LogicalNot", new TFLogicalParser()); +TFNodeRegistrar g_tfLogicalOrParser("LogicalOr", new TFLogicalParser()); +TFNodeRegistrar g_tfLogicalAndParser("LogicalAnd", new TFLogicalParser()); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_zeros_like_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_zeros_like_parser.cc new file mode 100644 index 00000000000..b72a6ecc925 --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_zeros_like_parser.cc @@ -0,0 +1,60 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tools/converter/parser/tf/tf_zeros_like_parser.h" +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser_registry.h" + +namespace mindspore { +namespace lite { +STATUS TFZerosLikeParser::Parse(const tensorflow::NodeDef &tf_op, + const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) { + if (primitiveC == nullptr || output_size == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_NULL_PTR; + } + + auto primitive = std::make_unique(); + if (primitive == nullptr) { + MS_LOG(ERROR) << "primitive is nullptr"; + return RET_NULL_PTR; + } + auto attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + + primitive->value.type = schema::PrimitiveType_ZerosLike; + primitive->value.value = attr.release(); + *primitiveC = PrimitiveC::Create(primitive.release()); + if (*primitiveC == nullptr) { + MS_LOG(ERROR) << "primitiveC is nullptr"; + return RET_ERROR; + } + + *output_size = tf_op.input_size(); + for (int i = 0; i < tf_op.input_size(); i++) { + inputs->emplace_back(tf_op.input(i)); + } + return RET_OK; +} +TFNodeRegistrar g_tfZerosLikeParser("ZerosLike", new TFZerosLikeParser()); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/tf/tf_zeros_like_parser.h b/mindspore/lite/tools/converter/parser/tf/tf_zeros_like_parser.h new file mode 100644 index 00000000000..5be3bfd272a --- /dev/null +++ b/mindspore/lite/tools/converter/parser/tf/tf_zeros_like_parser.h @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ZERO_LIKE_PARSER_H_ +#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ZERO_LIKE_PARSER_H_ +#include +#include +#include +#include +#include "tools/converter/parser/tf/tf_node_parser.h" + +namespace mindspore { +namespace lite { +class TFZerosLikeParser : public TFNodeParser { + public: + TFZerosLikeParser() = default; + ~TFZerosLikeParser() override = default; + + STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map &tf_node_map, + PrimitiveC **primitiveC, std::vector *inputs, int *output_size) override; +}; +} // namespace lite +} // namespace mindspore + +#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_ZERO_LIKE_PARSER_H_