!25059 dynamic shape ops fix

Merge pull request !25059 from wangnan39/dynamic_shape_ops_fix
This commit is contained in:
i-robot 2021-10-28 11:14:37 +00:00 committed by Gitee
commit 6c9bd5b3cb
40 changed files with 551 additions and 144 deletions

View File

@ -20,14 +20,30 @@
namespace mindspore {
namespace kernel {
namespace {
const size_t kInputNum = 2;
const int kInputNum = 2;
const size_t one = 1;
void UpdatePreIsOne(std::vector<bool> *prev_is_one, std::vector<bool> current_is_one) {
for (size_t i = 0; i < kInputNum; ++i) {
(*prev_is_one)[i] = current_is_one[i];
}
}
void AddElementToGradReduceIdx(std::vector<std::vector<int64_t>> *grad_reduce_idx, std::vector<bool> current_is_one,
bool none_is_one, const size_t largest_rank, size_t j) {
MS_EXCEPTION_IF_NULL(grad_reduce_idx);
for (size_t i = 0; i < kInputNum; ++i) {
if (current_is_one[i] && !none_is_one) {
(void)(*grad_reduce_idx)[i].emplace_back(SizeToLong(largest_rank - one - j));
}
}
}
std::vector<std::vector<int64_t>> GetGradientIndices(const std::vector<std::vector<int64_t>> &reverse_shape,
const size_t largest_rank) {
std::vector<std::vector<int64_t>> grad_reduce_idx(kInputNum);
// indices of j-th component of each input.
bool prev_is_one[kInputNum];
bool current_is_one[kInputNum];
std::vector<bool> prev_is_one(kInputNum);
std::vector<bool> current_is_one(kInputNum);
for (size_t i = 0; i < kInputNum; ++i) {
prev_is_one[i] = false;
current_is_one[i] = false;
@ -46,37 +62,26 @@ std::vector<std::vector<int64_t>> GetGradientIndices(const std::vector<std::vect
} else {
current_is_one[i] = false;
if (!output_dim_set || reverse_shape[i][j] == static_cast<int64_t>(output_dim)) {
output_dim = static_cast<int>(reverse_shape[i][j]);
output_dim = LongToInt(reverse_shape[i][j]);
output_dim_set = true;
} else {
MS_LOG(EXCEPTION) << "Input[0] and input[1] Cannot broadcast!";
}
}
}
// All dimensions are 1.
if (!output_dim_set) {
for (size_t i = 0; i < kInputNum; ++i) {
(void)grad_reduce_idx[i].emplace_back(largest_rank - 1 - j);
(void)grad_reduce_idx[i].emplace_back(SizeToLong(largest_rank - one - j));
}
continue;
} else if (std::equal(current_is_one, current_is_one + kInputNum, prev_is_one) && set_one) {
for (size_t i = 0; i < kInputNum; ++i) {
if (current_is_one[i] && !none_is_one) {
(void)grad_reduce_idx[i].emplace_back(largest_rank - 1 - j);
}
}
} else if (std::equal(current_is_one.begin(), current_is_one.end(), prev_is_one.begin()) && set_one) {
AddElementToGradReduceIdx(&grad_reduce_idx, current_is_one, none_is_one, largest_rank, j);
} else {
for (size_t i = 0; i < kInputNum; ++i) {
if (current_is_one[i] && !none_is_one) {
(void)grad_reduce_idx[i].emplace_back(largest_rank - 1 - j);
}
}
AddElementToGradReduceIdx(&grad_reduce_idx, current_is_one, none_is_one, largest_rank, j);
}
set_one = true;
for (size_t i = 0; i < kInputNum; ++i) {
prev_is_one[i] = current_is_one[i];
}
UpdatePreIsOne(&prev_is_one, current_is_one);
}
return grad_reduce_idx;
}
@ -172,9 +177,10 @@ size_t SetOutputValue(const CNodePtr &cnode, const std::vector<std::vector<int64
*(data_ptr + i) = output[i];
}
(void)out_addr->SyncHostToDevice(out_shape, LongToSize(tensor_for_sync->data().nbytes()),
tensor_for_sync->data_type(), tensor_for_sync->data_c(),
tensor_for_sync->device_info().host_format_);
if (!out_addr->SyncHostToDevice(out_shape, LongToSize(tensor_for_sync->data().nbytes()), tensor_for_sync->data_type(),
tensor_for_sync->data_c(), tensor_for_sync->device_info().host_format_)) {
MS_LOG(EXCEPTION) << "Output Value SyncHostToDevice failed.";
}
return out_size;
}
} // namespace

View File

@ -30,7 +30,7 @@ namespace kernel {
void GetRealInputSize(const nlohmann::json &input_json, std::vector<size_t> *input_size_list, size_t *size_i) {
if (input_json[kJShape].size() == 1 && input_json[kJShape][0] == -2) {
auto input_max_shape = input_json[kJShape];
auto input_max_shape = input_json[kJRange];
for (auto &max_shape : input_max_shape) {
(*size_i) *= LongToSize(max_shape[1]);
}
@ -77,7 +77,7 @@ void GetInputSizeList(const nlohmann::json &input_json, std::vector<size_t> *inp
void GetRealOutputSize(const nlohmann::json &output_json, std::vector<size_t> *output_size_list, size_t *size_i) {
if (output_json[kJShape].size() == 1 && output_json[kJShape][0] == -2) {
auto output_max_shape = output_json[kJShape];
auto output_max_shape = output_json[kJRange];
for (auto &max_shape : output_max_shape) {
(*size_i) *= LongToSize(max_shape[1]);
}

View File

@ -122,10 +122,9 @@ bool CheckIndexOutput(const CNodePtr &node, const std::shared_ptr<kernel::Kernel
void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size_t cast_index) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(cast);
using Shape = std::vector<size_t>;
auto cast_dtype = AnfAlgo::GetOutputInferDataType(cast, 0);
auto cast_shape = AnfAlgo::GetOutputInferShape(cast, 0);
std::vector<Shape> shapes;
auto cast_shape = AnfAlgo::GetOutputDetailShape(cast, 0);
std::vector<abstract::BaseShapePtr> shapes;
std::vector<TypeId> types;
size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
for (size_t index = 0; index < output_num; ++index) {
@ -134,10 +133,10 @@ void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size
(void)types.emplace_back(cast_dtype);
continue;
}
(void)shapes.emplace_back(AnfAlgo::GetOutputInferShape(cnode, index));
(void)shapes.emplace_back(AnfAlgo::GetOutputDetailShape(cnode, index));
(void)types.emplace_back(AnfAlgo::GetOutputInferDataType(cnode, index));
}
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, cnode.get());
AnfAlgo::SetOutputTypeAndDetailShape(types, shapes, cnode.get());
auto prim_op = AnfAlgo::GetCNodePrimitive(cnode);
if (prim_op != nullptr) {
(void)prim_op->AddAttr("cast_type", TypeIdToType(cast_dtype));

View File

@ -48,6 +48,7 @@ std::string OpTilingCalculateAdapter::GetRealOpType(const std::string &op_type)
{"Softmax", "SoftmaxV2"},
{"DropoutDoMask", "DropOutDoMask"},
{"IOU", "Iou"},
{"DynamicBroadcastTo", "BroadcastTo"},
};
auto iter = kOpTypeMap.find(op_type);
if (iter == kOpTypeMap.end()) {

View File

@ -110,18 +110,34 @@ TypePtr CheckScalarType(const AbstractScalarPtr &scalar, const TypePtrList &acce
return CheckType(type, accepts, error_message_prefix);
}
ShapePtr CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) {
void CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) {
MS_EXCEPTION_IF_NULL(tensor_base);
ShapePtr shape_base = tensor_base->shape();
MS_EXCEPTION_IF_NULL(shape_base);
MS_EXCEPTION_IF_NULL(tensor);
ShapePtr shape = tensor->shape();
MS_EXCEPTION_IF_NULL(shape);
if (*shape != *shape_base) {
if (shape_base->IsDimUnknown() || shape->IsDimUnknown()) {
return;
}
auto shape_vector = shape->shape();
auto shape_base_vector = shape_base->shape();
if (shape_vector.size() != shape_base_vector.size()) {
MS_LOG(EXCEPTION) << op << " evaluator first arg shape " << shape->ToString()
<< " are not consistent with second arg shape " << shape_base->ToString();
}
return shape_base;
for (size_t i = 0; i < shape_vector.size(); i++) {
if (shape_vector[i] == Shape::SHP_ANY || shape_base_vector[i] == Shape::SHP_ANY) {
continue;
}
if (shape_vector[i] != shape_base_vector[i]) {
MS_LOG(EXCEPTION) << op << " evaluator first arg shape " << shape->ToString()
<< " are not consistent with second arg shape " << shape_base->ToString();
}
}
return;
}
TypePtr CheckDtypeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor) {

View File

@ -41,7 +41,7 @@ TypePtr CheckTensorsDTypeSame(const AbstractTensorPtrList &tensor_list, const Ty
TypePtr CheckScalarType(const AbstractScalarPtr &scalar, const TypePtrList &accepts,
const std::string &error_message_prefix);
ShapePtr CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor);
void CheckShapeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor);
TypePtr CheckDtypeSame(const std::string &op, const AbstractTensorPtr &tensor_base, const AbstractTensorPtr &tensor);

View File

@ -114,7 +114,7 @@ AbstractBasePtr InferImplStack(const AnalysisEnginePtr &, const PrimitivePtr &pr
for (size_t i = 1; i < tuple_len; ++i) {
AbstractTensorPtr tensor = CheckArg<AbstractTensor>(op_name, arg->elements(), i);
(void)CheckDtypeSame(op_name, tensor_base, tensor);
(void)CheckShapeSame(op_name, tensor_base, tensor);
CheckShapeSame(op_name, tensor_base, tensor);
}
auto element = tensor_base->element();
MS_EXCEPTION_IF_NULL(element);
@ -1241,6 +1241,7 @@ AbstractBasePtr InferImplMaskedSelect(const AnalysisEnginePtr &, const Primitive
AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
MS_EXCEPTION_IF_NULL(primitive);
bool output_shape_unknow = false;
auto prim_name = primitive->name();
constexpr int64_t args_size = 2;
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(args_spec_list.size()), kEqual, args_size,
@ -1253,6 +1254,22 @@ AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const Primitiv
auto input_tuple = args_spec_list[0]->cast<abstract::AbstractSequeuePtr>();
MS_EXCEPTION_IF_NULL(input_tuple);
auto indices = input_tuple->elements();
auto input_indice_size = input_tuple->size();
int64_t first_dim_size = 0;
for (size_t i = 0; i < input_indice_size; i++) {
auto indicei = indices[i]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(indicei);
auto valuei = indicei->BuildValue();
MS_EXCEPTION_IF_NULL(valuei);
if (!valuei->isa<tensor::Tensor>()) {
output_shape_unknow = true;
continue;
}
auto indicei_value = CheckAndConvertUtils::CheckTensorIntValue("indices", valuei, prim_name);
auto indicei_max = std::max_element(indicei_value.begin(), indicei_value.end());
first_dim_size = *indicei_max > first_dim_size ? *indicei_max : first_dim_size;
}
auto indices0 = indices[0]->cast<abstract::AbstractTensorPtr>();
MS_EXCEPTION_IF_NULL(indices0);
auto indices0_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(indices0->BuildShape())[kShape];
@ -1282,7 +1299,12 @@ AbstractBasePtr InferImplDynamicStitch(const AnalysisEnginePtr &, const Primitiv
std::set<TypePtr> valid_types = ops::common_valid_types;
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim_name);
ShapeVector out_shape = {abstract::Shape::SHP_ANY};
ShapeVector out_shape;
if (output_shape_unknow) {
out_shape.push_back(abstract::Shape::SHP_ANY);
} else {
out_shape.push_back(first_dim_size + 1);
}
for (size_t i = indices0_shape.size(); i < data0_shape.size(); ++i) {
out_shape.push_back(data0_shape[i]);
}

View File

@ -60,7 +60,7 @@ AbstractBasePtr InferImplSqrtGrad(const AnalysisEnginePtr &, const PrimitivePtr
auto out = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto dout = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
(void)CheckDtypeSame(op_name, out, dout);
(void)CheckShapeSame(op_name, out, dout);
CheckShapeSame(op_name, out, dout);
return out->Broaden();
}

View File

@ -47,6 +47,7 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
const auto kStridedSlice = prim::kPrimStridedSlice->name();
const auto kStridedSliceGrad = prim::kPrimStridedSliceGrad->name();
const auto kReduceSum = prim::kPrimReduceSum->name();
const auto kDynamicBroadcastTo = prim::kPrimDynamicBroadcastTo->name();
const auto kUnsortedSegmentSum = prim::kPrimUnsortedSegmentSum->name();
const auto kUnsortedSegmentMin = prim::kPrimUnsortedSegmentMin->name();
const auto kUnsortedSegmentMax = prim::kPrimUnsortedSegmentMax->name();
@ -78,7 +79,8 @@ std::vector<int64_t> GetDependsFormMap(const CNodePtr &cnode) {
{kTile, {1}},
{kReshape, {1}},
{kSlice, {1, 2}},
{kSliceGrad, {2, 3}}};
{kSliceGrad, {2, 3}},
{kDynamicBroadcastTo, {1}}};
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);

View File

@ -95,6 +95,7 @@ constexpr auto kDiagPart = "DiagPart";
constexpr auto kDynamicBroadcastGradientArgs = "DynamicBroadcastGradientArgs";
constexpr auto kTranspose = "Transpose";
constexpr auto kSplitV = "SplitV";
constexpr auto kDynamicBroadcastTo = "DynamicBroadcastTo";
// NN
constexpr auto kCTCLoss = "CTCLoss";
@ -170,6 +171,7 @@ inline const PrimitivePtr kPrimStackPush = std::make_shared<Primitive>("StackPus
inline const PrimitivePtr kPrimStackPop = std::make_shared<Primitive>("StackPop");
// Arrays
inline const PrimitivePtr kPrimDynamicBroadcastTo = std::make_shared<Primitive>(kDynamicBroadcastTo);
inline const PrimitivePtr kPrimBroadcastTo = std::make_shared<Primitive>("BroadcastTo");
inline const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array");
inline const PrimitivePtr kPrimTopK = std::make_shared<Primitive>("TopK");

View File

@ -73,8 +73,8 @@ abstract::ShapePtr AudioSpectrogramInferShape(const PrimitivePtr &primitive,
}
TypePtr AudioSpectrogramInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
const int64_t x_index = 0;
return CheckAndConvertUtils::GetInputTensorType(input_args, x_index, prim->name());
const size_t x_index = 0;
return CheckAndConvertUtils::GetTensorInputType(prim->name(), input_args, x_index);
}
} // namespace

View File

@ -115,7 +115,20 @@ TypePtr BatchMatmulInferType(const PrimitivePtr &prim, const std::vector<Abstrac
std::map<std::string, TypePtr> types;
(void)types.emplace("x", input_args[0]->BuildType());
(void)types.emplace("w", input_args[1]->BuildType());
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
(void)CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
TypePtr x_type = input_args[0]->BuildType();
if (x_type->type_id() == TypeId::kNumberTypeInt8) {
x_type = kInt32;
}
if (prim->HasAttr("cast_type")) {
auto out_type = prim->GetAttr("cast_type");
MS_EXCEPTION_IF_NULL(out_type);
if (!out_type->isa<Type>()) {
MS_EXCEPTION(ValueError) << "MatMul cast_type must be a `Type`";
}
x_type = out_type->cast<TypePtr>();
}
return x_type;
}
} // namespace

View File

@ -48,17 +48,19 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto x_shape_vector = x_shape->shape();
auto mask_shape_vector = mask_shape->shape();
int64_t x_size = 1;
for (size_t i = 0; i < x_shape_vector.size(); i++) {
x_size *= x_shape_vector[i];
}
if (mask_shape_vector.size() != 1) {
MS_EXCEPTION(ValueError) << "DropoutDoMask input mask must be 1-dimension.";
}
auto mask_size = mask_shape_vector[0] * 8;
if (x_size > mask_size) {
MS_EXCEPTION(ValueError) << "DropoutDoMask input mask do not match input, input_x shape: " << x_shape->ToString()
<< ", mask shape: " << mask_shape->ToString();
if (!x_shape->IsDynamic() && !mask_shape->IsDynamic()) {
int64_t x_size = 1;
for (size_t i = 0; i < x_shape_vector.size(); i++) {
x_size *= x_shape_vector[i];
}
if (mask_shape_vector.size() != 1) {
MS_EXCEPTION(ValueError) << "DropoutDoMask input mask must be 1-dimension.";
}
auto mask_size = mask_shape_vector[0] * 8;
if (x_size > mask_size) {
MS_EXCEPTION(ValueError) << "DropoutDoMask input mask do not match input, input_x shape: " << x_shape->ToString()
<< ", mask shape: " << mask_shape->ToString();
}
}
auto keep_prop = input_args[kInputIndex2];
if (keep_prop->isa<abstract::AbstractTensor>()) {

View File

@ -86,7 +86,6 @@ ShapeVector CalOutputShape(const AbstractBasePtrList shape_list) {
}
count = count * value;
}
// convert to bytes(8 bits) mask, using round up
int64_t n128s = count / mask_convert_len;
if ((count % mask_convert_len) != 0) {
@ -106,7 +105,18 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
AbstractBasePtr shape_args = input_args[0];
MS_EXCEPTION_IF_NULL(shape_args);
ShapeVector out_shape;
if (shape_args->isa<abstract::AbstractTensor>()) {
auto shape_value = shape_args->BuildValue();
MS_EXCEPTION_IF_NULL(shape_value);
if (shape_value->isa<tensor::Tensor>()) {
auto mask_shape = CheckAndConvertUtils::CheckTensorIntValue("shape", shape_value, op_name);
std::vector<ValuePtr> value_elements;
std::transform(mask_shape.begin(), mask_shape.end(), std::back_inserter(value_elements),
[](int64_t elem) { return MakeValue(elem); });
out_shape = CalDynamicOutputShape(value_elements);
return std::make_shared<abstract::Shape>(out_shape);
}
auto shape_abstract = dyn_cast<abstract::AbstractTensor>(shape_args);
MS_EXCEPTION_IF_NULL(shape_abstract);
auto shape_base = shape_abstract->BuildShape();
@ -139,7 +149,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto x_shape = dyn_cast<abstract::AbstractTuple>(shape_args);
auto x_shape_data = x_shape->elements();
ShapeVector out_shape = CalOutputShape(x_shape_data);
out_shape = CalOutputShape(x_shape_data);
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {

View File

@ -0,0 +1,91 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/dynamic_broadcast_to.h"
#include <set>
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 2, prim_name);
auto input_y = input_args[1];
MS_EXCEPTION_IF_NULL(input_y);
abstract::ShapePtr y_shape;
auto y_value = input_y->BuildValue();
MS_EXCEPTION_IF_NULL(y_value);
if (input_y->isa<abstract::AbstractTensor>()) {
if (y_value->isa<tensor::Tensor>()) {
auto shape_value = CheckAndConvertUtils::CheckTensorIntValue("shape", y_value, prim_name);
return std::make_shared<abstract::Shape>(shape_value);
}
y_shape = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 1);
auto shape_value = y_shape->shape();
if (shape_value.size() != 1) {
MS_EXCEPTION(TypeError) << "shape size error: " << shape_value.size();
}
std::vector<int64_t> output_shape;
std::vector<int64_t> max_shape;
std::vector<int64_t> min_shape;
if (y_shape->IsDynamic()) {
// max shape unknown
output_shape.push_back(-2);
} else {
auto out_dims = LongToSize(y_shape->shape()[0]);
for (size_t i = 0; i < out_dims; i++) {
output_shape.push_back(-1);
}
auto min_value = input_y->cast<abstract::AbstractTensorPtr>()->get_min_value();
auto max_value = input_y->cast<abstract::AbstractTensorPtr>()->get_max_value();
if (!min_value || !max_value) {
MS_EXCEPTION(ValueError) << "For BroadcastTo, inputs['shape'] min or max value is empty.";
}
min_shape = GetValue<std::vector<int64_t>>(min_value);
max_shape = GetValue<std::vector<int64_t>>(max_value);
if (min_shape.size() != out_dims || max_shape.size() != out_dims) {
MS_EXCEPTION(ValueError) << "For BroadcastTo, inputs['shape'] min or max value not match with out dims.";
}
}
return std::make_shared<abstract::Shape>(output_shape, min_shape, max_shape);
} else if (input_y->isa<abstract::AbstractTuple>()) {
auto out_shape = GetValue<std::vector<int64_t>>(y_value);
return std::make_shared<abstract::Shape>(out_shape);
}
MS_EXCEPTION(TypeError) << "For BroadcastTo, input args must be tensor or tuple.";
}
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto x_dtype = input_args[0]->BuildType()->cast<TensorTypePtr>();
std::set<TypePtr> template_types = {kTensorType};
CheckAndConvertUtils::CheckSubClass("x_dtype", x_dtype, template_types, prim->name());
return x_dtype->element();
}
} // namespace
AbstractBasePtr DynamicBroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return abstract::MakeAbstract(InferShape(primitive, input_args), InferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(DynamicBroadcastTo, prim::kPrimDynamicBroadcastTo, DynamicBroadcastToInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_DYNAMIC_BROADCAST_TO_H_
#define MINDSPORE_CORE_OPS_DYNAMIC_BROADCAST_TO_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/op_utils.h"
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
namespace mindspore {
namespace ops {
class DynamicBroadcastTo : public PrimitiveC {
public:
DynamicBroadcastTo() : PrimitiveC(prim::kPrimDynamicBroadcastTo->name()) { InitIOName({"x", "shape"}, {"y"}); }
~DynamicBroadcastTo() = default;
MS_DECLARE_PARENT(DynamicBroadcastTo, PrimitiveC);
void Init() {}
};
AbstractBasePtr DynamicBroadcastToInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimDynamicBroadcastToPtr = std::shared_ptr<DynamicBroadcastTo>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_DYNAMIC_BROADCAST_TO_H_

View File

@ -48,8 +48,8 @@ AbstractBasePtr ExpandDimsInfer(const abstract::AnalysisEnginePtr &, const Primi
(void)out_shape.insert(out_shape.begin() + dim_val, 1, 1);
// Infer type
const int64_t x_index = 0;
auto x_type = CheckAndConvertUtils::GetInputTensorType(input_args, x_index, prim_name);
const size_t x_index = 0;
auto x_type = CheckAndConvertUtils::GetTensorInputType(prim_name, input_args, x_index);
std::set<TypePtr> valid_x_type = {kTensorType};
(void)CheckAndConvertUtils::CheckSubClass("x_type", x_type, valid_x_type, prim_name);
return std::make_shared<abstract::AbstractTensor>(x_type, out_shape);

View File

@ -24,9 +24,9 @@
namespace mindspore {
namespace ops {
namespace {
constexpr int64_t kDoutIndex = 0;
constexpr int64_t kInputIndex = 1;
constexpr int64_t kFilterSizeIdex = 2;
constexpr size_t kDoutIndex = 0;
constexpr size_t kInputIndex = 1;
constexpr size_t kFilterSizeIdex = 2;
constexpr size_t kStride2dSize = 2;
constexpr size_t kStride4dSize = 4;

View File

@ -27,7 +27,7 @@ namespace ops {
namespace {
constexpr size_t kDoutIndex = 0;
constexpr size_t kInputIndex = 1;
constexpr int64_t kSizeIndex = 2;
constexpr size_t kSizeIndex = 2;
void SetPadList(const PrimitivePtr &primitive, const std::vector<int64_t> &dout_shape_norm,
const std::vector<int64_t> &x_size_v) {

View File

@ -44,8 +44,8 @@ AbstractBasePtr DropoutGradInfer(const abstract::AnalysisEnginePtr &, const Prim
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, op_name);
const int64_t dy_index = 0;
const int64_t mask_index = 1;
const size_t dy_index = 0;
const size_t mask_index = 1;
auto dy_type = input_args[dy_index]->BuildType();
auto mask_type = input_args[mask_index]->BuildType();

View File

@ -37,7 +37,7 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
}
auto dout = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
auto out = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 1);
(void)abstract::CheckShapeSame(prim_name, out, dout);
abstract::CheckShapeSame(prim_name, out, dout);
auto x = input_args[0]->BuildShape();
MS_EXCEPTION_IF_NULL(x);
auto shape_element = x->cast<abstract::ShapePtr>();

View File

@ -124,7 +124,7 @@ AbstractBasePtr StridedSliceGradInfer(const abstract::AnalysisEnginePtr &, const
StridedSliceGradInferType(primitive, input_args));
}
void StridedSliceGrad::set_begin_mask(const int64_t begin_mask) {
void StridedSliceGrad::set_begin_mask(int64_t begin_mask) {
(void)CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name());
(void)this->AddAttr(kBeginMask, MakeValue(begin_mask));
}
@ -133,7 +133,7 @@ int64_t StridedSliceGrad::get_begin_mask() const {
MS_EXCEPTION_IF_NULL(value_ptr);
return GetValue<int64_t>(value_ptr);
}
void StridedSliceGrad::set_end_mask(const int64_t end_mask) {
void StridedSliceGrad::set_end_mask(int64_t end_mask) {
(void)CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name());
(void)this->AddAttr(kEndMask, MakeValue(end_mask));
}
@ -141,7 +141,7 @@ int64_t StridedSliceGrad::get_end_mask() const {
auto value_ptr = GetAttr(kEndMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSliceGrad::set_ellipsis_mask(const int64_t ellipsis_mask) {
void StridedSliceGrad::set_ellipsis_mask(int64_t ellipsis_mask) {
(void)CheckAndConvertUtils::CheckInteger(kEllipsisMask, ellipsis_mask, kGreaterEqual, 0, this->name());
std::bitset<sizeof(int64_t) * 8> bs(ellipsis_mask);
std::ostringstream buffer;
@ -155,7 +155,7 @@ int64_t StridedSliceGrad::get_ellipsis_mask() const {
auto value_ptr = GetAttr(kEllipsisMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSliceGrad::set_new_axis_mask(const int64_t new_axis_mask) {
void StridedSliceGrad::set_new_axis_mask(int64_t new_axis_mask) {
(void)CheckAndConvertUtils::CheckInteger(kNewAxisMask, new_axis_mask, kGreaterEqual, 0, this->name());
(void)this->AddAttr(kNewAxisMask, MakeValue(new_axis_mask));
}
@ -163,7 +163,7 @@ int64_t StridedSliceGrad::get_new_axis_mask() const {
auto value_ptr = GetAttr(kNewAxisMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSliceGrad::set_shrink_axis_mask(const int64_t shrink_axis_mask) {
void StridedSliceGrad::set_shrink_axis_mask(int64_t shrink_axis_mask) {
(void)CheckAndConvertUtils::CheckInteger(kShrinkAxisMask, shrink_axis_mask, kGreaterEqual, 0, this->name());
(void)this->AddAttr(kShrinkAxisMask, MakeValue(shrink_axis_mask));
}
@ -171,8 +171,8 @@ int64_t StridedSliceGrad::get_shrink_axis_mask() const {
auto value_ptr = GetAttr(kShrinkAxisMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSliceGrad::Init(const int64_t begin_mask, const int64_t end_mask, const int64_t ellipsis_mask,
const int64_t new_axis_mask, const int64_t shrink_axis_mask) {
void StridedSliceGrad::Init(int64_t begin_mask, int64_t end_mask, int64_t ellipsis_mask, int64_t new_axis_mask,
int64_t shrink_axis_mask) {
this->set_begin_mask(begin_mask);
this->set_end_mask(end_mask);
this->set_ellipsis_mask(ellipsis_mask);

View File

@ -34,13 +34,13 @@ class MS_CORE_API StridedSliceGrad : public PrimitiveC {
~StridedSliceGrad() = default;
MS_DECLARE_PARENT(StridedSliceGrad, PrimitiveC);
void Init(const int64_t begin_mask = 0, const int64_t end_mask = 0, const int64_t ellipsis_mask = 0,
const int64_t new_axis_mask = 0, const int64_t shrink_axis_mask = 0);
void set_begin_mask(const int64_t begin_mask);
void set_end_mask(const int64_t end_mask);
void set_ellipsis_mask(const int64_t ellipsis_mask);
void set_new_axis_mask(const int64_t new_axis_mask);
void set_shrink_axis_mask(const int64_t shrink_axis_mask);
void Init(int64_t begin_mask = 0, int64_t end_mask = 0, int64_t ellipsis_mask = 0, int64_t new_axis_mask = 0,
int64_t shrink_axis_mask = 0);
void set_begin_mask(int64_t begin_mask);
void set_end_mask(int64_t end_mask);
void set_ellipsis_mask(int64_t ellipsis_mask);
void set_new_axis_mask(int64_t new_axis_mask);
void set_shrink_axis_mask(int64_t shrink_axis_mask);
int64_t get_begin_mask() const;
int64_t get_end_mask() const;
int64_t get_ellipsis_mask() const;

View File

@ -29,7 +29,7 @@ namespace ops {
namespace {
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name();
(void)CheckAndConvertUtils::CheckInteger("infer_shape", int64_t(input_args.size()), kGreaterEqual, 1, op_name);
(void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kGreaterEqual, 1, op_name);
return CheckAndConvertUtils::GetTensorInputShape(op_name, input_args, 0);
}

View File

@ -33,9 +33,9 @@ AbstractBasePtr QuantDTypeCastInfer(const abstract::AnalysisEnginePtr &, const P
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 1;
const int64_t x_index = 0;
const size_t x_index = 0;
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, input_num, primitive->name());
auto input_type = CheckAndConvertUtils::GetInputTensorType(input_args, x_index, primitive->name());
auto input_type = CheckAndConvertUtils::GetTensorInputType(primitive->name(), input_args, x_index);
auto dst_type = TypeIdToType(TypeId(GetValue<int64_t>(primitive->GetAttr(kDstT))));
MS_EXCEPTION_IF_NULL(dst_type);
if (input_type != dst_type) {

View File

@ -48,7 +48,7 @@ void InferImplReduceFuncCalShape(ShapeVector *shape, const ShapeVector &x_shape,
ValuePtrList::iterator it;
if (keep_dims_value) {
for (it = axis_items.begin(); it != axis_items.end(); ++it) {
auto axis_value = GetValue<int64_t>(*it);
auto axis_value = InferImplReduceFuncCheckAxis(GetValue<int64_t>(*it), x_shape.size());
shape->at(LongToSize(axis_value)) = 1;
}
} else {
@ -108,10 +108,8 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto axis_shape = axis_tensor->shape()->shape();
if (axis_shape.size() == 1 && axis_shape[0] == -1 && !keep_dims) {
out_shape.push_back(-2);
for (size_t i = 0; i < input_shape.size(); ++i) {
out_min_shape.push_back(1);
out_max_shape.push_back(max_v);
}
out_min_shape = input_min_shape;
out_max_shape = input_max_shape;
} else if (!keep_dims) {
for (size_t i = 0; i < input_shape.size() - axis_shape.size(); ++i) {
out_shape.push_back(-1);
@ -136,7 +134,6 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
}
MS_EXCEPTION_IF_NULL(axis_ptr);
if (axis_ptr->isa<tensor::Tensor>()) {
MS_LOG(ERROR) << "Tensor with value";
auto axis_type = input_args[1]->BuildType();
MS_EXCEPTION_IF_NULL(axis_type);
auto axis_type_id = axis_type->cast<TensorTypePtr>();
@ -178,8 +175,9 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
return CheckAndConvertUtils::CheckTensorTypeValid("x dtype", input_args[0]->BuildType(), common_valid_types,
"ReduceSum");
auto x_type = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x dtype", x_type, common_valid_types, prim->name());
return x_type;
}
} // namespace

View File

@ -321,7 +321,7 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t x_index = 0;
const size_t x_index = 0;
auto x_shape = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, x_index);
if (x_shape->IsDynamic()) {
MS_EXCEPTION(ValueError) << "input x dynamic shape is currently not supported.";
@ -363,12 +363,12 @@ abstract::ShapePtr StridedSliceInferShape(const PrimitivePtr &primitive,
}
TypePtr StridedSliceInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
const int64_t x_index = 0;
return CheckAndConvertUtils::GetInputTensorType(input_args, x_index, primitive->name());
const size_t x_index = 0;
return CheckAndConvertUtils::GetTensorInputType(primitive->name(), input_args, x_index);
}
} // namespace
void StridedSlice::set_begin_mask(const int64_t begin_mask) {
void StridedSlice::set_begin_mask(int64_t begin_mask) {
(void)CheckAndConvertUtils::CheckInteger(kBeginMask, begin_mask, kGreaterEqual, 0, this->name());
(void)this->AddAttr(kBeginMask, MakeValue(begin_mask));
}
@ -376,7 +376,7 @@ int64_t StridedSlice::get_begin_mask() const {
auto value_ptr = GetAttr(kBeginMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSlice::set_end_mask(const int64_t end_mask) {
void StridedSlice::set_end_mask(int64_t end_mask) {
(void)CheckAndConvertUtils::CheckInteger(kEndMask, end_mask, kGreaterEqual, 0, this->name());
(void)this->AddAttr(kEndMask, MakeValue(end_mask));
}
@ -384,7 +384,7 @@ int64_t StridedSlice::get_end_mask() const {
auto value_ptr = GetAttr(kEndMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSlice::set_ellipsis_mask(const int64_t ellipsis_mask) {
void StridedSlice::set_ellipsis_mask(int64_t ellipsis_mask) {
(void)CheckAndConvertUtils::CheckInteger(kEllipsisMask, ellipsis_mask, kGreaterEqual, 0, this->name());
std::bitset<sizeof(int64_t) * 8> bs(ellipsis_mask);
std::ostringstream buffer;
@ -398,7 +398,7 @@ int64_t StridedSlice::get_ellipsis_mask() const {
auto value_ptr = GetAttr(kEllipsisMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSlice::set_new_axis_mask(const int64_t new_axis_mask) {
void StridedSlice::set_new_axis_mask(int64_t new_axis_mask) {
(void)CheckAndConvertUtils::CheckInteger(kNewAxisMask, new_axis_mask, kGreaterEqual, 0, this->name());
(void)this->AddAttr(kNewAxisMask, MakeValue(new_axis_mask));
}
@ -406,7 +406,7 @@ int64_t StridedSlice::get_new_axis_mask() const {
auto value_ptr = GetAttr(kNewAxisMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSlice::set_shrink_axis_mask(const int64_t shrink_axis_mask) {
void StridedSlice::set_shrink_axis_mask(int64_t shrink_axis_mask) {
(void)CheckAndConvertUtils::CheckInteger(kShrinkAxisMask, shrink_axis_mask, kGreaterEqual, 0, this->name());
(void)this->AddAttr(kShrinkAxisMask, MakeValue(shrink_axis_mask));
}
@ -414,8 +414,8 @@ int64_t StridedSlice::get_shrink_axis_mask() const {
auto value_ptr = GetAttr(kShrinkAxisMask);
return GetValue<int64_t>(value_ptr);
}
void StridedSlice::Init(const int64_t begin_mask, const int64_t end_mask, const int64_t ellipsis_mask,
const int64_t new_axis_mask, const int64_t shrink_axis_mask) {
void StridedSlice::Init(int64_t begin_mask, int64_t end_mask, int64_t ellipsis_mask, int64_t new_axis_mask,
int64_t shrink_axis_mask) {
this->set_begin_mask(begin_mask);
this->set_end_mask(end_mask);
this->set_ellipsis_mask(ellipsis_mask);

View File

@ -38,18 +38,18 @@ class MS_CORE_API StridedSlice : public PrimitiveC {
~StridedSlice() = default;
MS_DECLARE_PARENT(StridedSlice, PrimitiveC);
/// \brief Init. Refer to the parameters of python API @ref mindspore.ops.StridedSlice for the inputs.
void Init(const int64_t begin_mask = 0, const int64_t end_mask = 0, const int64_t ellipsis_mask = 0,
const int64_t new_axis_mask = 0, const int64_t shrink_axis_mask = 0);
void Init(int64_t begin_mask = 0, int64_t end_mask = 0, int64_t ellipsis_mask = 0, int64_t new_axis_mask = 0,
int64_t shrink_axis_mask = 0);
/// \brief Set begin_mask.
void set_begin_mask(const int64_t begin_mask);
void set_begin_mask(int64_t begin_mask);
/// \brief Set end_mask.
void set_end_mask(const int64_t end_mask);
void set_end_mask(int64_t end_mask);
/// \brief Set ellipsis_mask.
void set_ellipsis_mask(const int64_t ellipsis_mask);
void set_ellipsis_mask(int64_t ellipsis_mask);
/// \brief Set new_axis_mask.
void set_new_axis_mask(const int64_t new_axis_mask);
void set_new_axis_mask(int64_t new_axis_mask);
/// \brief Set shrink_axis_mask.
void set_shrink_axis_mask(const int64_t shrink_axis_mask);
void set_shrink_axis_mask(int64_t shrink_axis_mask);
/// \brief Get begin_mask.
///
/// \return begin_mask.

View File

@ -378,28 +378,6 @@ void CheckAndConvertUtils::CheckInputArgs(const std::vector<AbstractBasePtr> &in
}
}
TypePtr CheckAndConvertUtils::GetInputTensorType(const std::vector<AbstractBasePtr> &input_args, const size_t index,
const std::string &prim_name) {
if (input_args.size() <= index) {
MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "]'s input index[" << index
<< "] is out of the input number " << input_args.size();
}
auto input_arg = input_args[index];
if (input_arg == nullptr) {
MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "]'s input index[" << index << "] is nullptr.";
}
auto base_type = input_arg->BuildType();
MS_EXCEPTION_IF_NULL(base_type);
if (!base_type->isa<TensorType>()) {
MS_EXCEPTION(ValueError) << "The primitive[" << prim_name << "]'s input index[" << index << "] is not a tensor.";
}
auto tensor_type = base_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto type = tensor_type->element();
MS_EXCEPTION_IF_NULL(type);
return type;
}
ShapeMap CheckAndConvertUtils::ConvertShapePtrToShapeMap(const BaseShapePtr &shape) {
MS_EXCEPTION_IF_NULL(shape);
if (!shape->isa<abstract::Shape>()) {
@ -416,8 +394,8 @@ ShapeMap CheckAndConvertUtils::ConvertShapePtrToShapeMap(const BaseShapePtr &sha
abstract::ShapePtr CheckAndConvertUtils::GetTensorInputShape(const std::string &prim_name,
const std::vector<AbstractBasePtr> &input_args,
int64_t index) {
auto abstract = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, LongToSize(index));
size_t index) {
auto abstract = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, index);
MS_EXCEPTION_IF_NULL(abstract);
auto base_shape = abstract->BuildShape();
MS_EXCEPTION_IF_NULL(base_shape);
@ -429,6 +407,28 @@ abstract::ShapePtr CheckAndConvertUtils::GetTensorInputShape(const std::string &
return shape;
}
TypePtr CheckAndConvertUtils::GetTensorInputType(const std::string &prim_name,
const std::vector<AbstractBasePtr> &input_args, size_t index) {
if (input_args.size() <= index) {
MS_EXCEPTION(ValueError) << "For " << prim_name << ", the index " << index << " is out of the input number "
<< input_args.size();
}
auto input_arg = input_args[index];
if (input_arg == nullptr) {
MS_EXCEPTION(ValueError) << "The " << index << "'s input of " << prim_name << " is nullptr.";
}
auto base_type = input_arg->BuildType();
MS_EXCEPTION_IF_NULL(base_type);
if (!base_type->isa<TensorType>()) {
MS_EXCEPTION(ValueError) << "The " << index << "'s input type of " << prim_name << " is not Tensor.";
}
auto tensor_type = base_type->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto type = tensor_type->element();
MS_EXCEPTION_IF_NULL(type);
return type;
}
void CheckAndConvertUtils::Check(const string &arg_name, int64_t arg_value, CompareEnum compare_type, const string &,
int64_t value, const string &prim_name, ExceptionType) {
auto iter = kCompareMap<float>.find(compare_type);

View File

@ -219,7 +219,9 @@ class CheckAndConvertUtils {
static ShapeMap ConvertShapePtrToShapeMap(const BaseShapePtr &shape);
static abstract::ShapePtr GetTensorInputShape(const std::string &prim_name,
const std::vector<AbstractBasePtr> &input_args, int64_t index);
const std::vector<AbstractBasePtr> &input_args, size_t index);
static TypePtr GetTensorInputType(const std::string &prim_name, const std::vector<AbstractBasePtr> &input_args,
size_t index);
static void Check(const std::string &arg_name, int64_t arg_value, CompareEnum compare_type,
const std::string &value_name, int64_t value, const std::string &prim_name = "",
ExceptionType exception_type = ValueError);
@ -313,8 +315,6 @@ class CheckAndConvertUtils {
static size_t GetRemoveMonadAbsNum(const AbstractBasePtrList &abs_list);
static void CheckInputArgs(const std::vector<AbstractBasePtr> &input_args, const CompareEnum compare_operator,
const int64_t match_value, const std::string &prim_name);
static TypePtr GetInputTensorType(const std::vector<AbstractBasePtr> &input_args, const size_t index,
const std::string &prim_name);
static bool HasDynamicShapeInput(const AbstractBasePtrList &abs_list);
private:

View File

@ -55,7 +55,9 @@ from .batch_matmul_ds import _batch_matmul_ds_tbe
from .batchnorm import _batch_norm_tbe
from .batchnorm_grad import _batch_norm_grad_tbe
from .bias_add import _bias_add_tbe
from .bias_add_ds import _bias_add_ds_tbe
from .bias_add_grad import _bias_add_grad_tbe
from .bias_add_grad_ds import _bias_add_grad_ds_tbe
from .cast import _cast_tbe
from .cast_ds import _cast_ds_tbe
from .conv2d import _conv2d_tbe
@ -113,6 +115,7 @@ from .scatter_nd_sub import _scatter_nd_sub_tbe
from .scatter_non_aliasing_add import _scatter_non_aliasing_add_tbe
from .reduce_mean import _reduce_mean_tbe
from .tile import _tile_tbe
from .tile_ds import _tile_ds_tbe
from .atomic_addr_clean import _atomic_addr_clean_tbe
from .gather_v2 import _gather_v2_tbe
from .gather_v2_ds import _gather_v2_ds_tbe
@ -186,6 +189,7 @@ from .sparse_apply_proximal_adagrad_ds import _sparse_apply_proximal_adagrad_ds
from .apply_proximal_adagrad import _apply_proximal_adagrad
from .transpose import _transpose_tbe
from .transpose_d import _transpose_d_tbe
from .transpose_ds import _transpose_ds_tbe
from .truncate_div import _truncate_div_tbe
from .truncate_mod import _truncate_mod_tbe
from .unsorted_segment_sum import _unsorted_segment_sum_tbe
@ -352,6 +356,7 @@ from .gru_v2_hidden_grad_cell import _gru_v2_hidden_grad_cell_tbe
from .lstm_input_grad import _lstm_input_grad_tbe
from .confusion_matrix import _confusion_matrix_tbe
from .broadcast_to import _broadcast_to_tbe
from .broadcast_to_ds import _broadcast_to_ds_tbe
from .strided_read import _strided_read_tbe
from .strided_write import _strided_write_tbe
from .range import _range_tbe

View File

@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BiasAdd op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
bias_add_grad_op_info = TBERegOp("BiasAdd") \
.fusion_type("COMMREDUCE") \
.async_flag(False) \
.binfile_name("bias_add.so") \
.compute_cost(10) \
.kernel_name("bias_add") \
.partial_flag(True) \
.dynamic_shape(True) \
.attr("format", "required", "str", "all") \
.input(0, "x", False, "required", "all") \
.input(1, "bias", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.is_dynamic_format(True) \
.dtype_format(DataType.None_None, DataType.None_None, DataType.None_None) \
.get_op_info()
@op_info_register(bias_add_grad_op_info)
def _bias_add_ds_tbe():
"""BiasAdd TBE register"""
return

View File

@ -0,0 +1,52 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BiasAddGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
bias_add_grad_op_info = TBERegOp("BiasAddGrad") \
.fusion_type("COMMREDUCE") \
.async_flag(False) \
.binfile_name("bias_add_grad.so") \
.compute_cost(10) \
.kernel_name("bias_add_grad") \
.partial_flag(True) \
.dynamic_shape(True) \
.attr("format", "required", "str", "all") \
.input(0, "output_backprop", False, "required", "all") \
.output(0, "output", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_NHWC) \
.dtype_format(DataType.F32_FracNZ, DataType.F32_NHWC) \
.dtype_format(DataType.F16_Default, DataType.F16_NHWC) \
.dtype_format(DataType.F32_Default, DataType.F32_NHWC) \
.dtype_format(DataType.F16_NDC1HWC0, DataType.F16_Default) \
.dtype_format(DataType.F32_NDC1HWC0, DataType.F32_Default) \
.dtype_format(DataType.F16_NDC1HWC0, DataType.F16_NHWC) \
.dtype_format(DataType.F32_NDC1HWC0, DataType.F32_NHWC) \
.dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_Default) \
.dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_Default) \
.dtype_format(DataType.F16_FRACTAL_Z_3D, DataType.F16_NHWC) \
.dtype_format(DataType.F32_FRACTAL_Z_3D, DataType.F32_NHWC) \
.get_op_info()
@op_info_register(bias_add_grad_op_info)
def _bias_add_grad_ds_tbe():
"""BiasAddGrad TBE register"""
return

View File

@ -0,0 +1,42 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""BroadcastTo op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
broadcast_to_op_info = TBERegOp("DynamicBroadcastTo") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("broadcast_to.so") \
.compute_cost(10) \
.kernel_name("broadcast_to") \
.partial_flag(True) \
.dynamic_shape(True) \
.input(0, "x", False, "required", "all") \
.input(1, "shape", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
.get_op_info()
@op_info_register(broadcast_to_op_info)
def _broadcast_to_ds_tbe():
"""BroadcastTo TBE register"""
return

View File

@ -34,12 +34,8 @@ matmul_op_info = TBERegOp("MatMul") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.I8_Default,
DataType.F16_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_Default, DataType.I8_Default,
DataType.F32_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_Default, DataType.I8_Default,
DataType.F16_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F32_Default, DataType.I8_Default,
DataType.F32_FracNZ) \
.get_op_info()

View File

@ -0,0 +1,42 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dynamic Tile op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
tile_op_info = TBERegOp("Tile") \
.fusion_type("ELEMWISE") \
.async_flag(False) \
.binfile_name("tile.so") \
.compute_cost(10) \
.kernel_name("tile") \
.partial_flag(True) \
.dynamic_shape(True) \
.input(0, "x1", False, "required", "all") \
.input(1, "multiples", False, "required", "all") \
.output(0, "y", False, "required", "all") \
.dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.F16_Default) \
.dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default) \
.dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.F16_Default) \
.get_op_info()
@op_info_register(tile_op_info)
def _tile_ds_tbe():
"""Tile TBE register"""
return

View File

@ -1403,3 +1403,29 @@ class SliceGetItem(Primitive):
if value == "step":
return slice_value.step
raise AttributeError("\'slice\' object has no attribute {}".format(value))
class DynamicBroadcastTo(Primitive):
"""
Broadcasts input tensor to a given shape.
Inputs:
- **input_x** (Tensor) - The input tensor. The data type should be one of the following types:
float16, float32, int32, int8, uint8.
The shape is :math:`(N,*)` where :math:`*` means,any number of additional dimensions.
- **shape** (Tensor): The target shape to broadcast.
Outputs:
Tensor, with the given `shape` and the same data type as `input_x`.
Raises:
ValueError: if the target and input shapes are incompatible.
Supported Platforms:
``Ascend``
"""
@prim_attr_register
def __init__(self):
"""Initialize DynamicBroadcastTo"""
self.init_prim_io_names(inputs=['x', 'shape'], outputs=['y'])

View File

@ -521,7 +521,7 @@ class Reshape(PrimitiveWithInfer):
neg_index = i
else:
dim_prod *= shp_i
arr_prod = np.prod(x_shp)
if -1 in x_shp:
if 'max_shape' in x:
x_max_shape = x['max_shape']
@ -545,6 +545,7 @@ class Reshape(PrimitiveWithInfer):
'max_shape': tuple(max_shape),
'min_shape': tuple(min_shape)}
else:
arr_prod = np.prod(x_shp)
if dim_prod <= 0:
raise ValueError(f"For '{self.name}', the shape of 'input_x' is {x_shp}, "
f"the value of 'input_shape' is {shape_v}. "

View File

@ -451,9 +451,8 @@ class _Reduce(PrimitiveWithInfer):
axis_shape = axis_shape_list[0]
if axis_shape == -1 and not self.keep_dims:
out_shape = np.array([-2]).tolist()
output_min_shape = np.ones_like(input_shp).tolist()
output_max_shape = max_v * np.ones_like(input_shp)
output_max_shape = output_max_shape.tolist()
output_min_shape = input_x['min_shape']
output_max_shape = input_x['max_shape']
elif not self.keep_dims:
out_shape = -1 * np.ones_like(input_shp[:-axis_shape])
out_shape = out_shape.tolist()
@ -467,12 +466,12 @@ class _Reduce(PrimitiveWithInfer):
output_max_shape = max_v * np.ones_like(input_shp)
output_max_shape = output_max_shape.tolist()
else:
out_shape = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name)
output_max_shape = _infer_shape_reduce(input_x['max_shape'], axis_v, self.keep_dims, self.name)
output_min_shape = _infer_shape_reduce(input_x['min_shape'], axis_v, self.keep_dims, self.name)
out_shape = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name)
else:
if axis_v is None:
raise ValueError(f"For {self.name}, the 'axis' cannot be None.")
raise ValueError(f"For {self.name}, axis must be const, its value cannot be None.")
out_shape = _infer_shape_reduce(input_shp, axis_v, self.keep_dims, self.name)
output_max_shape = out_shape
output_min_shape = out_shape

View File

@ -160,8 +160,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', boost_leve
validator.check_value_type('network', network, nn.Cell)
validator.check_value_type('optimizer', optimizer, (nn.Optimizer, boost.FreezeOpt))
if not isinstance(level, str):
raise TypeError("The argument `level` must be a string in ['O0', 'O2', 'O3', 'auto'], \
but got type {}.".format(type(level)))
raise TypeError(f"The argument `level` must be a string in ['O0', 'O2', 'O3', 'auto'], "
f"but got type {str(type(level))}.")
validator.check('level', level, "", ['O0', 'O2', 'O3', 'auto'], Rel.IN)
validator.check('boost_level', boost_level, "", ['O0', 'O1', 'O2'], Rel.IN)