[feat][assistant][I48OCS] add new dynamic shape operator ApplyFtrl

zhangjie 2021-10-26 16:13:27 +08:00
parent 90cbad6baf
commit e9c9088984
5 changed files with 222 additions and 19 deletions

@@ -0,0 +1,106 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/apply_ftrl.h"
#include <algorithm>
#include <set>
#include "abstract/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr ApplyFtrlInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  const int64_t kInputNum = 8;
  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
  auto var_shape = input_args[kInputIndex0]->BuildShape();
  auto accum_shape = input_args[kInputIndex1]->BuildShape();
  auto linear_shape = input_args[kInputIndex2]->BuildShape();
  // For dynamic-shape inputs, skip the consistency check and propagate var's shape.
  if (var_shape->IsDynamic() || accum_shape->IsDynamic() || linear_shape->IsDynamic()) {
    return var_shape->cast<abstract::ShapePtr>();
  }
  std::map<std::string, abstract::BaseShapePtr> same_shape_args_map;
  (void)same_shape_args_map.insert({"accum", accum_shape});
  (void)same_shape_args_map.insert({"linear", linear_shape});
  for (auto &elem : same_shape_args_map) {
    if (*elem.second != *var_shape) {
      MS_EXCEPTION(ValueError) << prim_name << " evaluator arg " << elem.first << " shape " << elem.second->ToString()
                               << " is not consistent with var shape " << var_shape->ToString();
    }
  }
  auto shape_ptr = var_shape->cast<abstract::ShapePtr>();
  MS_EXCEPTION_IF_NULL(shape_ptr);
  return shape_ptr;
}
TypePtr ApplyFtrlInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  auto prim_name = prim->name();
  const int64_t kInputNum = 8;
  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
  auto var_type = input_args[kInputIndex0]->BuildType();
  auto accum_type = input_args[kInputIndex1]->BuildType();
  auto linear_type = input_args[kInputIndex2]->BuildType();
  auto grad_type = input_args[kInputIndex3]->BuildType();
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  std::map<std::string, TypePtr> args;
  (void)args.insert({"var_type", var_type});
  (void)args.insert({"accum_type", accum_type});
  (void)args.insert({"linear_type", linear_type});
  (void)args.insert({"grad_type", grad_type});
  // var, accum, linear and grad must share the same dtype
  (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
  auto lr_type = input_args[kInputIndex4]->BuildType();
  auto l1_type = input_args[kInputIndex5]->BuildType();
  auto l2_type = input_args[kInputIndex6]->BuildType();
  auto lr_power_type = input_args[kInputIndex7]->BuildType();
  std::map<std::string, TypePtr> args_lr;
  std::map<std::string, TypePtr> args_l1;
  std::map<std::string, TypePtr> args_l2;
  std::map<std::string, TypePtr> args_lr_power;
  (void)args_lr.insert({"lr_type", lr_type});
  (void)args_l1.insert({"l1_type", l1_type});
  (void)args_l2.insert({"l2_type", l2_type});
  (void)args_lr_power.insert({"lr_power_type", lr_power_type});
  // lr, l1, l2 and lr_power must be float scalars or scalar tensors of float
  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_lr, valid_types, prim_name);
  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_l1, valid_types, prim_name);
  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_l2, valid_types, prim_name);
  (void)CheckAndConvertUtils::CheckScalarOrTensorTypesSame(args_lr_power, valid_types, prim_name);
  return var_type;
}
} // namespace
AbstractBasePtr ApplyFtrlInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  const int64_t kInputNum = 8;
  (void)CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputNum, prim_name);
  auto infer_type = ApplyFtrlInferType(primitive, input_args);
  auto infer_shape = ApplyFtrlInferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(ApplyFtrl, prim::kPrimApplyFtrl, ApplyFtrlInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
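For reference, the shape and dtype rules that ApplyFtrlInferShape and ApplyFtrlInferType enforce can be summarized in plain Python: var, accum, linear and grad must share a float16/float32 dtype, accum and linear must match var's shape, and when any of the three shapes is dynamic the shape check is skipped and var's shape is propagated. The sketch below is illustrative only (a NumPy stand-in covering the static-shape path, not MindSpore API):

# Illustrative sketch of the checks above; NumPy stand-in, not MindSpore code.
import numpy as np

def check_apply_ftrl(var, accum, linear, grad):
    valid = (np.dtype(np.float16), np.dtype(np.float32))
    if var.dtype not in valid:
        raise TypeError("var must be float16 or float32, got " + str(var.dtype))
    for name, t in {"accum": accum, "linear": linear, "grad": grad}.items():
        if t.dtype != var.dtype:
            raise TypeError(f"{name} dtype {t.dtype} must match var dtype {var.dtype}")
    for name, t in {"accum": accum, "linear": linear}.items():
        if t.shape != var.shape:
            raise ValueError(f"{name} shape {t.shape} is not consistent with var shape {var.shape}")
    # On success, the inferred output keeps var's shape and dtype.
    return var.shape, var.dtype

# Example: four matching float32 tensors pass the checks.
check_apply_ftrl(*(np.zeros((2, 2), np.float32) for _ in range(4)))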

@@ -0,0 +1,46 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_APPLY_FTRL_H_
#define MINDSPORE_CORE_OPS_APPLY_FTRL_H_
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "abstract/abstract_value.h"
#include "ops/primitive_c.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameApplyFtrl = "ApplyFtrl";
class ApplyFtrl : public PrimitiveC {
 public:
  ApplyFtrl() : PrimitiveC(kNameApplyFtrl) {
    InitIOName({"var", "accum", "linear", "grad", "lr", "l1", "l2", "lr_power"}, {"var"});
  }
  ~ApplyFtrl() = default;
  MS_DECLARE_PARENT(ApplyFtrl, PrimitiveC);
};
AbstractBasePtr ApplyFtrlInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const std::vector<AbstractBasePtr> &input_args);
using kPrimApplyFtrlPtr = std::shared_ptr<ApplyFtrl>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_APPLY_FTRL_H_

@@ -38,6 +38,7 @@ from .add_n_ds import _add_n_ds_tbe
from .accumulate_n_v2 import _accumulate_n_v2_tbe
from .accumulate_n_v2_ds import _accumulate_n_v2_ds_tbe
from .apply_ftrl import _apply_ftrl_tbe
from .apply_ftrl_ds import _apply_ftrl_ds_tbe
from .apply_keras_momentum import _apply_keras_momentum_tbe
from .apply_momentum import _apply_momentum_tbe
from .apply_momentum_ds import _apply_momentum_ds_tbe

@@ -0,0 +1,68 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ApplyFtrl op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
apply_ftrl_ds_op_info = TBERegOp("ApplyFtrl") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("apply_ftrl.so") \
    .compute_cost(10) \
    .kernel_name("apply_ftrl_d") \
    .partial_flag(True) \
    .dynamic_shape(True) \
    .input(0, "var", False, "required", "all") \
    .input(1, "accum", False, "required", "all") \
    .input(2, "linear", False, "required", "all") \
    .input(3, "grad", False, "required", "all") \
    .input(4, "lr", False, "required", "all") \
    .input(5, "l1", False, "required", "all") \
    .input(6, "l2", False, "required", "all") \
    .input(7, "lr_power", False, "required", "all") \
    .output(0, "var", False, "required", "all") \
    .output(1, "accum", False, "required", "all") \
    .output(2, "linear", False, "required", "all") \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
                  DataType.F16_5HD, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ,
                  DataType.F16_FracZ, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0,
                  DataType.F16_C1HWNCoC0, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ,
                  DataType.F32_FracZ, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0,
                  DataType.F32_C1HWNCoC0, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()

@op_info_register(apply_ftrl_ds_op_info)
def _apply_ftrl_ds_tbe():
    """ApplyFtrl TBE register"""
    return
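The .dynamic_shape(True) flag is what distinguishes this registration from the existing static apply_ftrl registration imported alongside it in the TBE __init__.py. For context, the primitive is normally driven from Python through ops.ApplyFtrl; below is a minimal usage sketch in the style of the MindSpore docstring examples, where the shapes and hyper-parameter values are illustrative assumptions rather than anything taken from this commit:

# Minimal usage sketch; shapes and hyper-parameter values are illustrative assumptions.
import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter

class ApplyFtrlNet(nn.Cell):
    def __init__(self):
        super(ApplyFtrlNet, self).__init__()
        self.apply_ftrl = ops.ApplyFtrl()
        # var, accum and linear must share shape and dtype, as the new infer functions check.
        self.var = Parameter(Tensor(np.random.rand(2, 2).astype(np.float32)), name="var")
        self.accum = Parameter(Tensor(np.random.rand(2, 2).astype(np.float32)), name="accum")
        self.linear = Parameter(Tensor(np.random.rand(2, 2).astype(np.float32)), name="linear")

    def construct(self, grad):
        # lr, l1, l2 and lr_power may be float scalars or scalar float tensors.
        return self.apply_ftrl(self.var, self.accum, self.linear, grad, 0.001, 0.0, 0.0, -0.5)

grad = Tensor(np.random.rand(2, 2).astype(np.float32))
output = ApplyFtrlNet()(grad)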

@@ -6569,7 +6569,7 @@ class LARSUpdate(PrimitiveWithInfer):
validator.check_value_type("use_clip", use_clip, [bool], self.name)
class ApplyFtrl(PrimitiveWithInfer):
class ApplyFtrl(Primitive):
"""
Updates relevant entries according to the FTRL scheme.
@@ -6644,24 +6644,6 @@ class ApplyFtrl(PrimitiveWithInfer):
        self.add_prim_attr('side_effect_mem', True)
        self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name)

    def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, lr_shape, l1_shape, l2_shape,
                    lr_power_shape):
        validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name)
        validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name)
        return var_shape

    def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type,
                    lr_power_type):
        valid_dtypes = [mstype.float16, mstype.float32]
        args = {'var': var_type, 'accum': accum_type, 'linear': linear_type, 'grad': grad_type}
        validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
        validator.check_scalar_or_tensor_types_same({"lr": lr_type}, valid_dtypes, self.name)
        validator.check_scalar_or_tensor_types_same({"l1": l1_type}, valid_dtypes, self.name)
        validator.check_scalar_or_tensor_types_same({"l2": l2_type}, valid_dtypes, self.name)
        validator.check_scalar_or_tensor_types_same({"lr_power": lr_power_type}, valid_dtypes, self.name)
        return var_type
class SparseApplyFtrl(PrimitiveWithCheck):
"""