forked from mindspore-Ecosystem/mindspore
!14985 fix code review alarms
From: @shibeiji Reviewed-by: @oacjiewen,@liangchenghui Signed-off-by: @liangchenghui
This commit is contained in:
commit
12fe7eccc4
|
@ -28,7 +28,7 @@ size_t get_element_num(const std::vector<size_t> &shape) {
|
|||
}
|
||||
|
||||
template <typename T, typename I>
|
||||
void CopyTask(size_t cur, std::vector<size_t> *pos, T *input, I *index, const int &dim, T *output,
|
||||
void CopyTask(size_t cur, std::vector<size_t> *pos, T *input, const I *index, const int &dim, T *output,
|
||||
const std::vector<size_t> &output_shape, const std::vector<size_t> &out_cargo_size,
|
||||
const std::vector<size_t> &input_cargo_size, bool reverse) {
|
||||
for (size_t i = 0; i < output_shape[cur]; ++i) {
|
||||
|
@ -65,7 +65,6 @@ template <typename T, typename I>
|
|||
void GatherDCPUKernel<T, I>::InitKernel(const CNodePtr &kernel_node) {
|
||||
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
|
||||
index_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 2);
|
||||
|
||||
if (input_shape_.size() != index_shape_.size()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid shape size, shape size of input: " << input_shape_.size()
|
||||
<< ", and index: " << index_shape_.size() << " should be equal";
|
||||
|
@ -81,7 +80,6 @@ bool GatherDCPUKernel<T, I>::Launch(const std::vector<kernel::AddressPtr> &input
|
|||
size_t index_size = get_element_num(index_shape_) * sizeof(I);
|
||||
size_t dim_size = sizeof(int);
|
||||
size_t output_size = get_element_num(output_shape_) * sizeof(T);
|
||||
|
||||
if (inputs[0]->size != input_size || inputs[1]->size != dim_size || inputs[2]->size != index_size ||
|
||||
outputs[0]->size != output_size) {
|
||||
MS_LOG(EXCEPTION) << "invalid input or output data size!";
|
||||
|
@ -92,7 +90,6 @@ bool GatherDCPUKernel<T, I>::Launch(const std::vector<kernel::AddressPtr> &input
|
|||
auto index = reinterpret_cast<I *>(inputs[2]->addr);
|
||||
auto output = reinterpret_cast<T *>(outputs[0]->addr);
|
||||
int32_t input_rank = SizeToInt(input_shape_.size());
|
||||
|
||||
if (dim[0] >= input_rank || dim[0] < -input_rank) {
|
||||
MS_LOG(EXCEPTION) << "The value of 'dim' should be in [" << -input_rank << ", " << input_rank
|
||||
<< "], but got: " << dim[0];
|
||||
|
|
|
@ -37,7 +37,6 @@ class GatherDCPUKernel : public CPUKernel {
|
|||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> index_shape_;
|
||||
std::vector<size_t> output_shape_;
|
||||
int32_t axis_;
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL_T_S(GatherD,
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
template <typename T>
|
||||
void MinimumCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
|
||||
CheckParam(kernel_node);
|
||||
|
@ -147,7 +146,7 @@ void MinimumCPUKernel<T>::InitTensorBroadcastShape() {
|
|||
}
|
||||
}
|
||||
|
||||
// Broadcast comparation
|
||||
// Broadcast comparison
|
||||
template <typename T>
|
||||
size_t MinimumCPUKernel<T>::Index(const size_t &index, const size_t &dim) {
|
||||
return dim == 1 ? 0 : index;
|
||||
|
@ -216,6 +215,5 @@ void MinimumCPUKernel<T>::BroadcastArithTensors(const T *input_x, const T *input
|
|||
output[i] = MinimumFunc(input_x[i], input_y[i]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -115,8 +115,8 @@ void MinimumGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, c
|
|||
|
||||
size_t x_tensor_len = GetTensorLen(x_shape_);
|
||||
size_t y_tensor_len = GetTensorLen(y_shape_);
|
||||
memset(dx_addr, 0, x_tensor_len * sizeof(T));
|
||||
memset(dy_addr, 0, y_tensor_len * sizeof(T));
|
||||
memset_s(dx_addr, x_tensor_len * sizeof(T), 0x00, x_tensor_len * sizeof(T));
|
||||
memset_s(dy_addr, y_tensor_len * sizeof(T), 0x00, y_tensor_len * sizeof(T));
|
||||
|
||||
std::vector<size_t> x_shape(dout_shape.size(), 1);
|
||||
std::vector<size_t> y_shape(dout_shape.size(), 1);
|
||||
|
|
|
@ -276,10 +276,6 @@ class COCOHP(ds.Dataset):
|
|||
inp = cv2.warpAffine(img, trans_input, (self.data_opt.input_res[0], self.data_opt.input_res[1]),
|
||||
flags=cv2.INTER_LINEAR)
|
||||
|
||||
# caution: image normalization and transpose to nchw will both be done on device
|
||||
# inp = (inp.astype(np.float32) / 255. - self.data_opt.mean) / self.data_opt.std
|
||||
# inp = inp.transpose(2, 0, 1)
|
||||
|
||||
assert self.data_opt.output_res[0] == self.data_opt.output_res[1]
|
||||
output_res = self.data_opt.output_res[0]
|
||||
num_joints = self.data_opt.num_joints
|
||||
|
|
|
@ -252,7 +252,6 @@ class DeformConv2d(nn.Cell):
|
|||
self.expand_dims(g_rt, 1) * x_q_rt)
|
||||
|
||||
if self.modulation:
|
||||
# modulation (b, 1, h, w, N)
|
||||
m = self.sigmoid(self.m_conv(x))
|
||||
m = self.transpose(m, self.perm_list)
|
||||
m = self.expand_dims(m, 1)
|
||||
|
|
|
@ -56,7 +56,6 @@ def merge_outputs(detections, soft_nms=True):
|
|||
def convert_eval_format(detections, img_id):
|
||||
"""convert detection to annotation json format"""
|
||||
# detections. scores: (b, K); bboxes: (b, K, 4); kps: (b, K, J * 2); clses: (b, K)
|
||||
# only batch_size = 1 is supported
|
||||
detections = np.array(detections).reshape((-1, 39))
|
||||
pred_anno = {"images": [], "annotations": []}
|
||||
num_objs, _ = detections.shape
|
||||
|
|
|
@ -124,13 +124,26 @@ def visual_image(img, annos, save_path, ratio=None, height=None, width=None, nam
|
|||
h, w = img.shape[0], img.shape[1]
|
||||
num_objects = len(annos)
|
||||
num = 0
|
||||
|
||||
def define_color(pair):
|
||||
"""define line color"""
|
||||
left_part = [0, 1, 3, 5, 7, 9, 11, 13, 15]
|
||||
right_part = [0, 2, 4, 6, 8, 10, 12, 14, 16]
|
||||
if pair[0] in left_part and pair[1] in left_part:
|
||||
color = (255, 0, 0)
|
||||
elif pair[0] in right_part and pair[1] in right_part:
|
||||
color = (0, 0, 255)
|
||||
else:
|
||||
color = (139, 0, 255)
|
||||
return color
|
||||
|
||||
def visible(a, w, h):
|
||||
return a[0] >= 0 and a[0] < w and a[1] >= 0 and a[1] < h
|
||||
|
||||
for i in range(num_objects):
|
||||
ann = annos[i]
|
||||
bbox = coco_box_to_bbox(ann['bbox'])
|
||||
if "score" in ann:
|
||||
score = ann["score"]
|
||||
if score < score_threshold and num != 0:
|
||||
continue
|
||||
if "score" in ann and (ann["score"] >= score_threshold or num == 0):
|
||||
num += 1
|
||||
txt = ("p" + "{:.2f}".format(ann["score"]))
|
||||
cv2.putText(img, txt, (bbox[0], bbox[1]), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
|
||||
|
@ -143,33 +156,19 @@ def visual_image(img, annos, save_path, ratio=None, height=None, width=None, nam
|
|||
keypoints = ann["keypoints"]
|
||||
keypoints = np.array(keypoints, dtype=np.int32).reshape(_NUM_JOINTS, 3).tolist()
|
||||
|
||||
left_part = [0, 1, 3, 5, 7, 9, 11, 13, 15]
|
||||
right_part = [0, 2, 4, 6, 8, 10, 12, 14, 16]
|
||||
for pair in data_cfg.edges:
|
||||
partA = pair[0]
|
||||
partB = pair[1]
|
||||
if partA in left_part and partB in left_part:
|
||||
color = (255, 0, 0)
|
||||
elif partA in right_part and partB in right_part:
|
||||
color = (0, 0, 255)
|
||||
else:
|
||||
color = (139, 0, 255)
|
||||
color = define_color(pair)
|
||||
p_a = tuple(keypoints[partA][:2])
|
||||
p_b = tuple(keypoints[partB][:2])
|
||||
mask_a = keypoints[partA][2]
|
||||
mask_b = keypoints[partB][2]
|
||||
if (p_a[0] >= 0 and p_a[0] < w and p_a[1] >= 0 and p_a[1] < h and
|
||||
p_b[0] >= 0 and p_b[0] < w and p_b[1] >= 0 and p_b[1] < h and
|
||||
mask_a * mask_b > 0):
|
||||
if (visible(p_a, w, h) and visible(p_b, w, h) and mask_a * mask_b > 0):
|
||||
cv2.line(img, p_a, p_b, color, 2)
|
||||
cv2.circle(img, p_a, 3, color, thickness=-1, lineType=cv2.FILLED)
|
||||
cv2.circle(img, p_b, 3, color, thickness=-1, lineType=cv2.FILLED)
|
||||
if annos and "image_id" in annos[0]:
|
||||
img_id = annos[0]["image_id"]
|
||||
else:
|
||||
img_id = random.randint(0, 9999999)
|
||||
if name is None:
|
||||
image_name = "cv_image_" + str(img_id) + ".png"
|
||||
else:
|
||||
image_name = "cv_image_" + str(img_id) + name + ".png"
|
||||
|
||||
img_id = annos[0]["image_id"] if annos and "image_id" in annos[0] else random.randint(0, 9999999)
|
||||
image_name = "cv_image_" + str(img_id) + ".png" if name is None else "cv_image_" + str(img_id) + name + ".png"
|
||||
cv2.imwrite("{}/{}".format(save_path, image_name), img)
|
||||
|
|
Loading…
Reference in New Issue