Merge pull request !18971 from 陈志鹏/czp_midas
i-robot 2021-07-16 01:43:01 +00:00 committed by Gitee
commit ac3c031f22
19 changed files with 2620 additions and 0 deletions


@@ -0,0 +1,270 @@
# Contents
<!-- TOC -->
- [MiDaS Description](#midas-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Evaluation Process](#evaluation-process)
- [Model Description](#model-description)
- [Performance](#performance)
- [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
<!-- /TOC -->
# MiDaS Description
## Overview
MiDaS, short for "Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer", estimates the depth of a single image. The original work trains on five different datasets, including a 3D-movie dataset built by the authors, mixes them with a multi-objective optimization strategy, and validates on six test sets that are completely disjoint from the training data. This implementation trains on the ReDWeb dataset only.
For details of the MiDaS network, see [Towards Robust Monocular Depth Estimation: Mixing Datasets for
Zero-shot Cross-dataset Transfer](https://arxiv.org/pdf/1907.01341v3.pdf); a PyTorch implementation is available at (<https://github.com/intel-isl/MiDaS>).
## Paper
1. [Paper](https://arxiv.org/pdf/1907.01341v3.pdf): René Ranftl*, Katrin Lasinger*, David Hafner, Konrad Schindler, and Vladlen Koltun. "Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer."
# Model Architecture
The overall network architecture of MiDaS is described in the paper:
[Link](https://arxiv.org/pdf/1907.01341v3.pdf)
# Dataset
Dataset used: [ReDWeb](<https://www.paperswithcode.com/dataset/redweb>)
- Dataset size:
    - Training set: 292 MB, 3600 images
- Data format:
    - Images (`Imgs`): JPG
    - Depth maps (`RDs`): PNG
# Features
## Mixed Precision
The [mixed precision](https://www.mindspore.cn/tutorial/training/en/master/advanced_use/enable_mixed_precision.html) training method accelerates the training of deep neural networks by using both single-precision and half-precision data, while preserving the accuracy that pure single-precision training achieves. Mixed precision increases computing speed and reduces memory usage, which also makes it possible to train larger models, or with larger batch sizes, on specific hardware.
Taking the FP16 operator as an example: if the input data type is FP32, the MindSpore backend automatically lowers the precision to process the data. You can enable the INFO log and search for "reduce precision" to view the operators that ran at reduced precision.
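For reference, a minimal sketch of enabling automatic mixed precision through MindSpore's `Model` wrapper. This repository instead builds its `TrainOneStepCell` manually and sets `enable_auto_mixed_precision=True` in the context, so the snippet below is an illustration, not the project's exact training path:

```python
from mindspore import nn
from mindspore.train.model import Model
from src.midas_net import MidasNet, Loss, NetwithCell

# Wrap the network and the loss the same way midas_train.py does.
net_with_loss = NetwithCell(MidasNet(), Loss())
optim = nn.Adam(net_with_loss.trainable_params())
# amp_level="O2" casts the network to float16 while keeping BatchNorm in float32.
model = Model(net_with_loss, optimizer=optim, amp_level="O2")
```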
# Environment Requirements
- Hardware (Ascend)
    - Prepare an Ascend processor to set up the hardware environment.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/zh-CN/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/zh-CN/master/index.html)
# Quick Start
After installing MindSpore via the official website, you can start training and evaluation as follows:
- Pretrained model
    Before training, obtain the pretrained backbone for MindSpore; this model uses the weights trained on ResNeXt-101, [resnext101_32x8d_wsl](<https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth>).
- Dataset preparation
    The MiDaS network is trained on the ReDWeb dataset and evaluated on the DIW, ETH3D, Sintel, KITTI, NYU, and TUM datasets, which can be downloaded from the official sites: [ReDWeb](<https://www.paperswithcode.com/dataset/redweb>), [DIW](https://github.com/princeton-vl/relative_depth), [ETH3D](https://www.eth3d.net/), [Sintel](http://sintel.is.tue.mpg.de/), [Kitti](http://www.cvlibs.net/datasets/kitti/raw_data.php), [NYU](https://cs.nyu.edu/~silberman/datasets/), [TUM](https://vision.in.tum.de/data/datasets/rgbd-dataset). A sketch of the expected directory layout is shown after this list.
- Running on Ascend
```text
# Distributed training
Usage: bash run_distribute_train.sh 8
# Standalone training
Usage: bash run_standalone_train.sh [DEVICE_ID]
# Run the evaluation example
Usage: bash run_eval.sh [DEVICE_ID] [DATA_NAME]
```
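The training script looks for the data and the converted pretrained checkpoint under `train_data_dir`. A plausible on-disk layout, inferred from `mixdata.json` and the defaults in `config.yaml` (the exact paths are assumptions and should be adapted to your setup):

```text
/midas/
├── data/
│   ├── mixdata.json                  # maps dataset names to glob patterns
│   └── ReDWeb_V1/
│       ├── Imgs/                     # RGB images (JPG)
│       └── RDs/                      # relative depth maps (PNG)
└── midas/ckpt/
    └── midas_resnext_101_WSL.ckpt    # converted pretrained backbone
```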
# Script Description
## Script and Sample Code
```shell
└── midas
    ├── README.md
    ├── scripts
    │   ├── run_distribute_train.sh    # launch distributed training on Ascend (8 devices)
    │   ├── run_eval.sh                # launch evaluation on Ascend
    │   └── run_standalone_train.sh    # launch standalone training on Ascend (single device)
    ├── src
    │   ├── utils
    │   │   ├── loadImgDepth.py        # dataset loading
    │   │   └── transforms.py          # image preprocessing transforms
    │   ├── config.py                  # training configuration parser
    │   ├── custom_op.py               # custom network operations
    │   ├── blocks_ms.py               # network building blocks
    │   ├── loss.py                    # loss function definition
    │   ├── util.py                    # image read/write utilities
    │   └── midas_net.py               # backbone network definition
    ├── config.yaml                    # training parameter configuration file
    ├── midas_eval.py                  # evaluation
    ├── midas_export.py                # model export
    ├── midas_run.py                   # model inference on images
    └── midas_train.py                 # training
```
## Script Parameters
Configure the parameters in config.yaml. Each scalar parameter can also be overridden from the command line, as shown after the parameter lists below.
- Training parameters:
```yaml
device_target: 'Ascend'                 # device type: CPU, GPU, or Ascend
device_id: 7                            # ID of the device (card) to use
run_distribute: False                   # whether to run distributed training
is_modelarts: False                     # whether to train on ModelArts (cloud)
no_backbone_params_lr: 0.00001          # 1e-5
no_backbone_params_end_lr: 0.00000001   # 1e-8
backbone_params_lr: 0.0001              # 1e-4
backbone_params_end_lr: 0.0000001       # 1e-7
power: 0.5                              # decay power of PolynomialDecayLR
epoch_size: 400                         # total number of epochs
batch_size: 8                           # batch size
lr_decay: False                         # whether to use a dynamic learning rate
train_data_dir: '/midas/'               # root path of the training data
width_per_group: 8                      # network parameters
groups: 32
in_channels: 64
features: 256
layers: [3, 4, 23, 3]
img_width: 384                          # width of the network input
img_height: 384                         # height of the network input
nm_img_mean: [0.485, 0.456, 0.406]      # normalization mean for image preprocessing
nm_img_std: [0.229, 0.224, 0.225]       # normalization std for image preprocessing
resize_target: True                     # if True, resize image, mask and target; otherwise resize the image only
keep_aspect_ratio: False                # keep the aspect ratio
ensure_multiple_of: 32                  # force image sizes to be multiples of 32
resize_method: "upper_bound"            # resize mode
```
- Evaluation parameters:
```yaml
datapath_TUM: '/data/TUM'                                   # path to the TUM dataset
datapath_Sintel: '/data/sintel/sintel-data'                 # path to the Sintel dataset
datapath_ETH3D: '/data/ETH3D/ETH3D-data'                    # path to the ETH3D dataset
datapath_Kitti: '/data/Kitti_raw_data'                      # path to the Kitti dataset
datapath_DIW: '/data/DIW'                                   # path to the DIW dataset
datapath_NYU: ['/data/NYU/nyu.mat','/data/NYU/splits.mat']  # paths to the NYU data files
ann_file: 'val.json'                                        # file that stores the evaluation results
ckpt_path: '/midas/ckpt/Midas_0-600_56_1.ckpt'              # checkpoint used for evaluation
data_name: 'all'                                            # dataset(s) to evaluate: Sintel, Kitti, TUM, DIW, ETH3D, or all
```
- Inference and export parameters:
```yaml
input_path: '/midas/input'                    # path of the input images
output_path: '/midas/output'                  # path where the output images are written
model_weights: '/ckpt/Midas_0-600_56_1.ckpt'  # path of the model weights
file_format: "AIR"                            # AIR or MINDIR
```
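`src/config.py` turns every scalar entry of `config.yaml` into a command-line flag, so any of the values above can be overridden at launch time without editing the file; for example (the flag values are illustrative):

```bash
python midas_train.py --device_id=0 --lr_decay=True
python midas_eval.py --device_id=0 --data_name=Kitti
```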
## Training Process
### Usage
#### Running on Ascend
```text
# Distributed training
Usage: bash run_distribute_train.sh 8
# Standalone training
Usage: bash run_standalone_train.sh [DEVICE_ID]
# Run the evaluation example
Usage: bash run_eval.sh [DEVICE_ID] [DATA_NAME]
```
### Results
- Training MiDaS on the ReDWeb dataset
```text
Distributed training results (8P)
epoch: 1 step: 56, loss is 579.5216
epoch time: 1497998.993 ms, per step time: 26749.982 ms
epoch: 2 step: 56, loss is 773.3644
epoch time: 74565.443 ms, per step time: 1331.526 ms
epoch: 3 step: 56, loss is 270.76688
epoch time: 63373.872 ms, per step time: 1131.676 ms
epoch: 4 step: 56, loss is 319.71643
epoch time: 61290.421 ms, per step time: 1094.472 ms
...
epoch time: 58586.128 ms, per step time: 1046.181 ms
epoch: 396 step: 56, loss is 8.707727
epoch time: 63755.860 ms, per step time: 1138.498 ms
epoch: 397 step: 56, loss is 8.139318
epoch time: 47222.517 ms, per step time: 843.259 ms
epoch: 398 step: 56, loss is 10.746628
epoch time: 23364.224 ms, per step time: 417.218 ms
epoch: 399 step: 56, loss is 7.4859796
epoch time: 24304.195 ms, per step time: 434.003 ms
epoch: 400 step: 56, loss is 8.2024975
epoch time: 23696.833 ms, per step time: 423.158 ms
```
## Evaluation Process
### Usage
#### Running on Ascend
Set "data_name" in config.yaml to choose which dataset to evaluate; the default is all datasets.
```bash
# Evaluation
bash run_eval.sh [DEVICE_ID] [DATA_NAME]
```
### Results
Open val.json to view the evaluation results, for example:
```text
{"Kitti": 24.222 "Sintel":0.323 "TUM":15.08 "ETH3D":0.158 "NYU":20.499 }
```
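Since `midas_eval.py` writes the results as plain JSON, the scores can also be read programmatically; a small sketch:

```python
import json

with open("val.json") as f:
    results = json.load(f)  # e.g. {"Kitti": 24.222, "Sintel": 0.323, ...}
for name, score in sorted(results.items()):
    print(f"{name}: {score:.3f}")
```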
# Model Description
## Performance
### Evaluation Performance
#### Performance on ReDWeb
| Parameters | Ascend 910 |
| ------------------- | --------------------------- |
| Model version       | MiDaS                       |
| Resources           | Ascend 910; CPU 2.60 GHz, 192 cores; memory 755 GB |
| Upload date         | 2021-06-24                  |
| MindSpore version   | 1.2.0                       |
| Dataset             | ReDWeb                      |
| Pretrained model    | ResNeXt_101_WSL             |
| Training parameters | epoch=400, batch_size=8, no_backbone_lr=1e-4, backbone_lr=1e-5 |
| Optimizer           | Adam                        |
| Loss function       | custom scale- and shift-invariant loss |
| Speed               | 8p: 423.4 ms/step           |
| Evaluation results  | "Kitti": 24.222, "Sintel": 0.323, "TUM": 15.08, "ETH3D": 0.158, "NYU": 20.499 |
# ModelZoo Homepage
Please visit the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).


@@ -0,0 +1,47 @@
device_target: 'Ascend' # device type used for inference: CPU, GPU, or Ascend
device_id: 7            # ID of the device (card) used for inference
#train
run_distribute: False
is_modelarts: False
no_backbone_params_lr: 0.0001
no_backbone_params_end_lr: 0.0000001
backbone_params_lr: 0.00001
backbone_params_end_lr: 0.00000001
power: 0.5
epoch_size: 400
batch_size: 8
lr_decay: False
train_data_dir: '/midas/'
width_per_group: 8
groups: 32
in_channels: 64
features: 256
layers: [3, 4, 23, 3]
img_width: 384
img_height: 384
nm_img_mean: [0.485, 0.456, 0.406]
nm_img_std: [0.229, 0.224, 0.225]
resize_target: True
keep_aspect_ratio: False
ensure_multiple_of: 32
resize_method: "upper_bound"
#run&export
input_path: '/input'
output_path: '/output'
model_weights: '/ckpt/Midas_0-600_56_1.ckpt'
file_format: "AIR" #AIR/MIDIR
#eval
datapath_TUM: '/data/TUM' # path to the TUM dataset
datapath_Sintel: '/data/sintel/sintel-data' # path to the Sintel dataset
datapath_ETH3D: '/data/ETH3D/ETH3D-data' # path to the ETH3D dataset
datapath_Kitti: '/data/Kitti_raw_data' # path to the Kitti dataset
datapath_DIW: '/data/DIW' # path to the DIW dataset
datapath_NYU: ['/data/NYU/nyu.mat','/data/NYU/splits.mat'] # paths to the NYU data files
ann_file: 'val.json' # file that stores the evaluation results
ckpt_path: '/midas/ckpt/Midas_0-600_56_1.ckpt' # checkpoint used for evaluation
data_name: 'all' # dataset(s) to evaluate: Sintel, Kitti, TUM, DIW, ETH3D, or all


@@ -0,0 +1,435 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""eval midas."""
import glob
import csv
import os
import struct
import json
import numpy as np
from mindspore import Tensor
from mindspore import context
from mindspore import dtype as mstype
from mindspore.train import serialization
import mindspore.ops as ops
from src.util import depth_read_kitti, depth_read_sintel, BadPixelMetric
from src.midas_net import MidasNet
from src.config import config
from src.utils import transforms
from scipy.io import loadmat
import cv2
from PIL import Image
import h5py
def eval_Kitti(data_path, net):
"""eval Kitti."""
img_input_1 = transforms.Resize(config.img_width,
config.img_height,
resize_target=None,
keep_aspect_ratio=True,
ensure_multiple_of=32,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_CUBIC)
img_input_2 = transforms.NormalizeImage(mean=config.nm_img_mean, std=config.nm_img_std)
img_input_3 = transforms.PrepareForNet()
metric = BadPixelMetric(1.25, 80, 'KITTI')
loss_sum = 0
sample = {}
image_path = glob.glob(os.path.join(data_path, '*', 'image', '*.png'))
num = 0
for file_name in image_path:
num += 1
print(f"processing: {num} / {len(image_path)}")
image = np.array(Image.open(file_name)).astype(float) # (436,1024,3)
image = image / 255
print(file_name)
all_path = file_name.split('/')
depth_path_name = all_path[-1].split('.')[0]
depth = depth_read_kitti(os.path.join(data_path, all_path[-3], 'depth', depth_path_name + '.png')) # (436,1024)
mask = (depth > 0) & (depth < 80)
sample['image'] = image
sample["depth"] = depth
sample["mask"] = mask
sample = img_input_1(sample)
sample = img_input_2(sample)
sample = img_input_3(sample)
# print('transform later', sample['image'].shape)
sample['image'] = Tensor([sample["image"]], mstype.float32)
sample['depth'] = Tensor([sample["depth"]], mstype.float32)
sample['mask'] = Tensor([sample["mask"]], mstype.int32)
print(sample['image'].shape, sample['depth'].shape)
prediction = net(sample['image'])
mask = sample['mask'].asnumpy()
depth = sample['depth'].asnumpy()
expand_dims = ops.ExpandDims()
prediction = expand_dims(prediction, 0)
resize_bilinear = ops.ResizeBilinear(mask.shape[1:])
prediction = resize_bilinear(prediction)
prediction = np.squeeze(prediction.asnumpy())
loss = metric(prediction, depth, mask)
print('loss is ', loss)
loss_sum += loss
print(f"Kitti bad pixel: {loss_sum / num:.3f}")
return loss_sum / num
def eval_TUM(datapath, net):
"""eval TUM."""
img_input_1 = transforms.Resize(config.img_width,
config.img_height,
resize_target=None,
keep_aspect_ratio=True,
ensure_multiple_of=32,
resize_method="upper_bound",
image_interpolation_method=cv2.INTER_CUBIC)
img_input_2 = transforms.NormalizeImage(mean=config.nm_img_mean, std=config.nm_img_std)
img_input_3 = transforms.PrepareForNet()
# get data
metric = BadPixelMetric(1.25, 10, 'TUM')
loss_sum = 0
sample = {}
file_path = glob.glob(os.path.join(datapath, '*_person', 'associate.txt'))
num = 0
for ind in file_path:
all_path = ind.split('/')
for line in open(ind):
num += 1
print(f"processing: {num}")
data = line.split('\n')[0].split(' ')
image_path = os.path.join(datapath, all_path[-2], data[0]) # (480,640,3)
depth_path = os.path.join(datapath, all_path[-2], data[1]) # (480,640,3)
image = cv2.imread(image_path) / 255
depth = cv2.imread(depth_path)[:, :, 0] / 5000
mask = (depth > 0) & (depth < 10)
print('mask is ', np.unique(mask))
sample['image'] = image
sample["depth"] = depth
sample["mask"] = mask
sample = img_input_1(sample)
sample = img_input_2(sample)
sample = img_input_3(sample)
sample['image'] = Tensor([sample["image"]], mstype.float32)
sample['depth'] = Tensor([sample["depth"]], mstype.float32)
sample['mask'] = Tensor([sample["mask"]], mstype.int32)
print(sample['image'].shape, sample['depth'].shape)
prediction = net(sample['image'])
mask = sample['mask'].asnumpy()
depth = sample['depth'].asnumpy()
expand_dims = ops.ExpandDims()
prediction = expand_dims(prediction, 0)
print(prediction.shape, mask.shape)
resize_bilinear = ops.ResizeBilinear(mask.shape[1:])
prediction = resize_bilinear(prediction)
prediction = np.squeeze(prediction.asnumpy())
loss = metric(prediction, depth, mask)
print('loss is ', loss)
loss_sum += loss
print(f"TUM bad pixel: {loss_sum / num:.2f}")
return loss_sum / num
def eval_Sintel(datapath, net):
"""eval Sintel."""
img_input_1 = transforms.Resize(config.img_width,
config.img_height,
resize_target=None,
keep_aspect_ratio=True,
ensure_multiple_of=32,
resize_method="upper_bound",
image_interpolation_method=cv2.INTER_CUBIC)
img_input_2 = transforms.NormalizeImage(mean=config.nm_img_mean, std=config.nm_img_std)
img_input_3 = transforms.PrepareForNet()
# get data
metric = BadPixelMetric(1.25, 72, 'sintel')
loss_sum = 0
sample = {}
image_path = glob.glob(os.path.join(datapath, 'final_left', '*', '*.png'))
num = 0
for file_name in image_path:
num += 1
print(f"processing: {num} / {len(image_path)}")
image = np.array(Image.open(file_name)).astype(float) # (436,1024,3)
image = image / 255
print(file_name)
all_path = file_name.split('/')
depth_path_name = all_path[-1].split('.')[0]
depth = depth_read_sintel(os.path.join(datapath, 'depth', all_path[-2], depth_path_name + '.dpt')) # (436,1024)
mask1 = np.array(Image.open(os.path.join(datapath, 'occlusions', all_path[-2], all_path[-1]))).astype(int)
mask1 = mask1 / 255
mask = (mask1 == 1) & (depth > 0) & (depth < 72)
sample['image'] = image
sample["depth"] = depth
sample["mask"] = mask
sample = img_input_1(sample)
sample = img_input_2(sample)
sample = img_input_3(sample)
sample['image'] = Tensor([sample["image"]], mstype.float32)
sample['depth'] = Tensor([sample["depth"]], mstype.float32)
sample['mask'] = Tensor([sample["mask"]], mstype.int32)
print(sample['image'].shape, sample['depth'].shape)
prediction = net(sample['image'])
mask = sample['mask'].asnumpy()
depth = sample['depth'].asnumpy()
expand_dims = ops.ExpandDims()
prediction = expand_dims(prediction, 0)
resize_bilinear = ops.ResizeBilinear(mask.shape[1:])
prediction = resize_bilinear(prediction)
prediction = np.squeeze(prediction.asnumpy())
loss = metric(prediction, depth, mask)
print('loss is ', loss)
loss_sum += loss
print(f"sintel bad pixel: {loss_sum / len(image_path):.3f}")
return loss_sum / len(image_path)
def eval_ETH3D(datapath, net):
"""eval ETH3D."""
img_input_1 = transforms.Resize(config.img_width,
config.img_height,
resize_target=True,
keep_aspect_ratio=True,
ensure_multiple_of=32,
resize_method="upper_bound",
image_interpolation_method=cv2.INTER_CUBIC)
img_input_2 = transforms.NormalizeImage(mean=config.nm_img_mean, std=config.nm_img_std)
img_input_3 = transforms.PrepareForNet()
metric = BadPixelMetric(1.25, 72, 'ETH3D')
loss_sum = 0
sample = {}
image_path = glob.glob(os.path.join(datapath, '*', 'images', 'dslr_images', '*.JPG'))
num = 0
for file_name in image_path:
num += 1
print(f"processing: {num} / {len(image_path)}")
image = cv2.imread(file_name) / 255
all_path = file_name.split('/')
depth_path = os.path.join(datapath, all_path[-4], "ground_truth_depth", 'dslr_images', all_path[-1])
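        # The ETH3D ground truth is a raw binary stream of 4-byte floats, one depth
        # value per pixel; it is read back below and reshaped into a 4032-row image.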
depth = []
with open(depth_path, 'rb') as f:
data = f.read(4)
while data:
depth.append(struct.unpack('f', data))
data = f.read(4)
depth = np.reshape(np.array(depth), (4032, -1))
mask = (depth > 0) & (depth < 72)
sample['image'] = image
sample["depth"] = depth
sample["mask"] = mask
sample = img_input_1(sample)
sample = img_input_2(sample)
sample = img_input_3(sample)
sample['image'] = Tensor([sample["image"]], mstype.float32)
sample['depth'] = Tensor([sample["depth"]], mstype.float32)
sample['mask'] = Tensor([sample["mask"]], mstype.int32)
prediction = net(sample['image'])
mask = sample['mask'].asnumpy()
depth = sample['depth'].asnumpy()
expand_dims = ops.ExpandDims()
prediction = expand_dims(prediction, 0)
resize_bilinear = ops.ResizeBilinear(mask.shape[1:])
prediction = resize_bilinear(prediction)
prediction = np.squeeze(prediction.asnumpy())
loss = metric(prediction, depth, mask)
print('loss is ', loss)
loss_sum += loss
print(f"ETH3D bad pixel: {loss_sum / num:.3f}")
return loss_sum / num
def eval_DIW(datapath, net):
"""eval DIW."""
img_input_1 = transforms.Resize(config.img_width,
config.img_height,
resize_target=True,
keep_aspect_ratio=True,
ensure_multiple_of=32,
resize_method="upper_bound",
image_interpolation_method=cv2.INTER_CUBIC)
img_input_2 = transforms.NormalizeImage(mean=config.nm_img_mean, std=config.nm_img_std)
img_input_3 = transforms.PrepareForNet()
loss_sum = 0
num = 0
sample = {}
file_path = os.path.join(datapath, 'DIW_Annotations', 'DIW_test.csv')
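    # DIW_test.csv alternates two kinds of rows: an image path, then the point-pair
    # annotation for that image (y_A, x_A, y_B, x_B, the ordinal relation '>' or '<',
    # followed by the image width and height).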
with open(file_path) as f:
reader = list(csv.reader(f))
for (i, row) in enumerate(reader):
if i % 2 == 0:
path = row[0].split('/')
sample['file_name'] = os.path.join(datapath, path[-2], path[-1])
sample['image'] = cv2.imread(sample['file_name']) / 255
else:
sample['depths'] = row
if not os.path.exists(sample['file_name']):
continue
                num += 1  # one more annotated image processed
print(f"processing: {num}")
sample = img_input_1(sample)
sample = img_input_2(sample)
sample = img_input_3(sample)
sample['image'] = Tensor([sample["image"]], mstype.float32)
prediction = net(sample['image'])
shape_w, shape_h = [int(sample['depths'][-2]), int(sample['depths'][-1])]
expand_dims = ops.ExpandDims()
prediction = expand_dims(prediction, 0)
resize_bilinear = ops.ResizeBilinear((shape_h, shape_w))
prediction = resize_bilinear(prediction)
prediction = np.squeeze(prediction.asnumpy())
pixtel_a = prediction[int(sample['depths'][0]) - 1][int(sample['depths'][1]) - 1]
pixtel_b = prediction[int(sample['depths'][2]) - 1][int(sample['depths'][3]) - 1]
if pixtel_a > pixtel_b:
if sample['depths'][4] == '>':
loss_sum += 1
if pixtel_a < pixtel_b:
if sample['depths'][4] == '<':
loss_sum += 1
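    # loss_sum counts correctly ordered point pairs, so the value printed below is
    # the ordinal disagreement rate on DIW (WHDR, lower is better).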
print(f"bad pixel: {(num - loss_sum) / num:.4f}")
return (num - loss_sum) / num
def eval_NYU(datamat, splitmat, net):
"""eval NYU."""
    img_input_1 = transforms.Resize(config.img_width,
                                    config.img_height,
                                    resize_target=None,
                                    keep_aspect_ratio=True,
                                    ensure_multiple_of=32,
                                    resize_method="upper_bound",
                                    image_interpolation_method=cv2.INTER_CUBIC)
    img_input_2 = transforms.NormalizeImage(mean=config.nm_img_mean, std=config.nm_img_std)
    img_input_3 = transforms.PrepareForNet()
# get data
metric = BadPixelMetric(1.25, 10, 'NYU')
loss_sum = 0
sample = {}
mat = loadmat(splitmat)
indices = [ind[0] - 1 for ind in mat["testNdxs"]]
num = 0
with h5py.File(datamat, "r") as f:
for ind in indices:
num += 1
print(num)
image = np.swapaxes(f["images"][ind], 0, 2)
image = image / 255
depth = np.swapaxes(f["rawDepths"][ind], 0, 1)
mask = (depth > 0) & (depth < 10)
# mask = mask1
sample['image'] = image
sample["depth"] = depth
sample["mask"] = mask
sample = img_input_1(sample)
sample = img_input_2(sample)
sample = img_input_3(sample)
sample['image'] = Tensor([sample["image"]], mstype.float32)
sample['depth'] = Tensor([sample["depth"]], mstype.float32)
sample['mask'] = Tensor([sample["mask"]], mstype.int32)
print(sample['image'].shape, sample['depth'].shape)
prediction = net(sample['image'])
mask = sample['mask'].asnumpy()
depth = sample['depth'].asnumpy()
expand_dims = ops.ExpandDims()
prediction = expand_dims(prediction, 0)
resize_bilinear = ops.ResizeBilinear(mask.shape[1:])
prediction = resize_bilinear(prediction)
prediction = np.squeeze(prediction.asnumpy())
loss = metric(prediction, depth, mask)
print('loss is ', loss)
loss_sum += loss
print(f"bad pixel: {loss_sum / num:.3f}")
return loss_sum / num
def run_eval():
"""run."""
datapath_TUM = config.train_data_dir+config.datapath_TUM
datapath_Sintel = config.train_data_dir+config.datapath_Sintel
datapath_ETH3D = config.train_data_dir+config.datapath_ETH3D
datapath_Kitti = config.train_data_dir+config.datapath_Kitti
datapath_DIW = config.train_data_dir+config.datapath_DIW
datamat = config.train_data_dir+config.datapath_NYU[0]
splitmat = config.train_data_dir+config.datapath_NYU[1]
net = MidasNet()
param_dict = serialization.load_checkpoint(config.train_data_dir+config.ckpt_path)
serialization.load_param_into_net(net, param_dict)
results = {}
if config.data_name == 'Sintel' or config.data_name == "all":
result_sintel = eval_Sintel(datapath_Sintel, net)
results['Sintel'] = result_sintel
if config.data_name == 'Kitti' or config.data_name == "all":
result_kitti = eval_Kitti(datapath_Kitti, net)
results['Kitti'] = result_kitti
if config.data_name == 'TUM' or config.data_name == "all":
result_tum = eval_TUM(datapath_TUM, net)
results['TUM'] = result_tum
if config.data_name == 'DIW' or config.data_name == "all":
result_DIW = eval_DIW(datapath_DIW, net)
results['DIW'] = result_DIW
if config.data_name == 'ETH3D' or config.data_name == "all":
result_ETH3D = eval_ETH3D(datapath_ETH3D, net)
results['ETH3D'] = result_ETH3D
if config.data_name == 'NYU' or config.data_name == "all":
result_NYU = eval_NYU(datamat, splitmat, net)
results['NYU'] = result_NYU
print(results)
json.dump(results, open(config.ann_file, 'w'))
if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target, device_id=config.device_id)
run_eval()


@@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export midas."""
import numpy as np
from src.midas_net import MidasNet
from src.config import config
from mindspore import Tensor, export, context
from mindspore import dtype as mstype
from mindspore.train.serialization import load_checkpoint
def midas_export():
"""export midas."""
context.set_context(
mode=context.GRAPH_MODE,
device_target="Ascend",
save_graphs=False,
device_id=config.device_id)
net = MidasNet()
load_checkpoint(config.model_weights, net=net)
net.set_train(False)
    input_data = Tensor(np.zeros([1, 3, config.img_height, config.img_width]), mstype.float32)  # NCHW: (N, C, H, W)
export(net, input_data, file_name='midas', file_format=config.file_format)
if __name__ == '__main__':
midas_export()


@@ -0,0 +1,80 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""export midas."""
import glob
import os
import cv2
from mindspore import context
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.train import serialization
import mindspore.ops as ops
from src.utils import transforms
import src.util as util
from src.config import config
from src.midas_net import MidasNet
def run():
    """Run depth inference on the images under config.input_path."""
print("initialize")
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=2, save_graphs=False)
net = MidasNet()
param_dict = serialization.load_checkpoint(config.model_weights)
serialization.load_param_into_net(net, param_dict)
img_input_1 = transforms.Resize(config.img_width,
config.img_height,
resize_target=config.resize_target,
keep_aspect_ratio=config.keep_aspect_ratio,
ensure_multiple_of=config.ensure_multiple_of,
resize_method=config.resize_method,
image_interpolation_method=cv2.INTER_CUBIC)
img_input_2 = transforms.NormalizeImage(mean=config.nm_img_mean, std=config.nm_img_std)
img_input_3 = transforms.PrepareForNet()
# get input
img_names = glob.glob(os.path.join(config.input_path, "*"))
num_images = len(img_names)
# create output folder
os.makedirs(config.output_path, exist_ok=True)
print("start processing")
expand_dims = ops.ExpandDims()
resize_bilinear = ops.ResizeBilinear
squeeze = ops.Squeeze()
for ind, img_name in enumerate(img_names):
print(" processing {} ({}/{})".format(img_name, ind + 1, num_images))
# input
img = util.read_image(img_name)
img_input = img_input_1({"image": img})
img_input = img_input_2(img_input)
img_input = img_input_3(img_input)["image"]
sample = Tensor(img_input, mstype.float32)
sample = expand_dims(sample, 0)
prediction = net(sample)
prediction = expand_dims(prediction, 1)
prediction = resize_bilinear((img.shape[:2]))(prediction)
prediction = squeeze(prediction).asnumpy()
# output
filename = os.path.join(
config.output_path, os.path.splitext(os.path.basename(img_name))[0]
)
util.write_depth(filename, prediction, bits=2)
print("finished")
if __name__ == "__main__":
    run()


@@ -0,0 +1,156 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train midas."""
import os
import json
from mindspore import dtype as mstype
from mindspore import context
from mindspore import nn
from mindspore import Tensor
from mindspore.context import ParallelMode
import mindspore.dataset as ds
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.model import Model
from mindspore.train.callback import LossMonitor, TimeMonitor, ModelCheckpoint, CheckpointConfig
from mindspore.communication.management import init
from src.midas_net import MidasNet, Loss, NetwithCell
from src.utils import loadImgDepth
from src.config import config
def dynamic_lr(num_epoch_per_decay, total_epochs, steps_per_epoch, lr, end_lr):
"""dynamic learning rate generator"""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
decay_steps = steps_per_epoch * num_epoch_per_decay
lr = nn.PolynomialDecayLR(lr, end_lr, decay_steps, 0.5)
for i in range(total_steps):
if i < decay_steps:
i = Tensor(i, mstype.int32)
lr_each_step.append(lr(i).asnumpy())
else:
lr_each_step.append(end_lr)
return lr_each_step
def train(mixdata_path):
"""train"""
epoch_number_total = config.epoch_size
batch_size = config.batch_size
if config.is_modelarts:
import moxing as mox
device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))
local_data_path = '/cache/data'
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, max_call_depth=10000)
context.set_context(device_id=device_id)
# define distributed local data path
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
local_data_path = os.path.join(local_data_path, str(device_id))
mixdata_path = os.path.join(local_data_path, mixdata_path)
load_path = os.path.join(local_data_path, 'midas_resnext_101_WSL.ckpt')
output_path = config.train_url
print('local_data_path:', local_data_path)
print('mixdata_path:', mixdata_path)
print('load_path:', load_path)
print('output_path:', output_path)
# data download
print('Download data.')
mox.file.copy_parallel(src_url=config.data_url, dst_url=local_data_path)
elif config.run_distribute:
device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))
context.set_context(device_id=device_id, mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
max_call_depth=10000)
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True,
device_num=device_num,
parameter_broadcast=True
)
init()
local_data_path = config.train_data_dir + '/data'
mixdata_path = config.train_data_dir + '/data/mixdata.json'
load_path = config.train_data_dir + '/midas/ckpt/midas_resnext_101_WSL.ckpt'
else:
local_data_path = config.train_data_dir + '/data'
mixdata_path = config.train_data_dir + '/data/mixdata.json'
load_path = config.train_data_dir + '/midas/ckpt/midas_resnext_101_WSL.ckpt'
device_id = config.device_id
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id,
enable_auto_mixed_precision=True, max_call_depth=10000)
# load data
f = open(mixdata_path)
data_config = json.load(f)
img_paths = data_config['img']
# depth_paths = data_config['depth']
f.close()
mix_dataset = loadImgDepth.LoadImagesDepth(local_path=local_data_path, img_paths=img_paths)
if config.is_modelarts or config.run_distribute:
mix_dataset = ds.GeneratorDataset(mix_dataset, ['img', 'mask', 'depth'], shuffle=True, num_parallel_workers=8,
num_shards=device_num, shard_id=device_id)
else:
mix_dataset = ds.GeneratorDataset(mix_dataset, ['img', 'mask', 'depth'], shuffle=True)
    mix_dataset = mix_dataset.batch(batch_size, drop_remainder=True)
per_step_size = mix_dataset.get_dataset_size()
# define net_loss_opt
net = MidasNet()
net = net.set_train()
loss = Loss()
param_dict = load_checkpoint(load_path)
load_param_into_net(net, param_dict)
backbone_params = list(filter(lambda x: 'backbone' in x.name, net.trainable_params()))
no_backbone_params = list(filter(lambda x: 'backbone' not in x.name, net.trainable_params()))
# no_backbone_params_lr = Tensor(dynamic_lr(5, epoch_number_total, per_step_size, 1e-4, 1e-6), mstype.float32)
# backbone_params_lr = Tensor(dynamic_lr(5, epoch_number_total, per_step_size, 1e-5, 1e-7), mstype.float32)
if config.lr_decay:
group_params = [{'params': backbone_params,
'lr': nn.PolynomialDecayLR(config.backbone_params_lr
, config.backbone_params_end_lr,
epoch_number_total * per_step_size, config.power)},
{'params': no_backbone_params,
'lr': nn.PolynomialDecayLR(config.no_backbone_params_lr,
config.no_backbone_params_end_lr,
epoch_number_total * per_step_size, config.power)},
{'order_params': net.trainable_params()}]
else:
group_params = [{'params': backbone_params, 'lr': 1e-5},
{'params': no_backbone_params, 'lr': 1e-4},
{'order_params': net.trainable_params()}]
optim = nn.Adam(group_params)
netwithLoss = NetwithCell(net, loss)
midas_net = nn.TrainOneStepCell(netwithLoss, optim)
model = Model(midas_net)
# define callback
loss_cb = LossMonitor()
time_cb = TimeMonitor()
checkpointconfig = CheckpointConfig(saved_network=net)
if config.is_modelarts:
ckpoint_cb = ModelCheckpoint(prefix='Midas_{}'.format(device_id), directory=local_data_path + '/output/ckpt',
config=checkpointconfig)
else:
ckpoint_cb = ModelCheckpoint(prefix='Midas_{}'.format(device_id), directory='./ckpt/', config=checkpointconfig)
callbacks = [loss_cb, time_cb, ckpoint_cb]
# train
print("Starting Training:per_step_size={},batchsize={},epoch={}".format(per_step_size, batch_size,
epoch_number_total))
model.train(epoch_number_total, mix_dataset, callbacks=callbacks)
if config.is_modelarts:
mox.file.copy_parallel(local_data_path + "/output", output_path)
if __name__ == '__main__':
train(mixdata_path="mixdata.json")


@@ -0,0 +1,10 @@
{
"img":
{
"RW":"/ReDWeb_V1/Imgs/*"
},
"depth":
{
"RW":"/ReDWeb_V1/RDs/*"
}
}


@@ -0,0 +1,80 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "========================================================================"
echo "Please run the script as: "
echo "bash run.sh RANK_SIZE"
echo "For example: bash run_distribute.sh 8"
echo "It is better to use the absolute path."
echo "========================================================================"
set -e
RANK_SIZE=$1
export RANK_SIZE
EXEC_PATH=$(pwd)
echo "$EXEC_PATH"
test_dist_8pcs()
{
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_8pcs.json
export RANK_SIZE=8
}
test_dist_2pcs()
{
export RANK_TABLE_FILE=${EXEC_PATH}/rank_table_2pcs.json
export RANK_SIZE=2
}
test_dist_${RANK_SIZE}pcs
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
cd ../
rm -rf distribute_train
mkdir distribute_train
cd distribute_train
for((i=0;i<${RANK_SIZE};i++))
do
rm -rf device$i
mkdir device$i
cd ./device$i
mkdir src
cd src
mkdir utils
cd ../../../
cp ./midas_train.py config.yaml ./distribute_train/device$i
cp ./src/*.py ./distribute_train/device$i/src
cp ./src/utils/*.py ./distribute_train/device$i/src/utils
cd ./distribute_train/device$i
export DEVICE_ID=$i
export RANK_ID=$i
echo "start training for device $i"
env > env$i.log
python midas_train.py --run_distribute True --is_modelarts False > train$i.log 2>&1 &
echo "$i finish"
cd ../
done
if [ $? -eq 0 ];then
echo "training success"
else
echo "training failed"
exit 2
fi
echo "finish"
cd ../


@@ -0,0 +1,18 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
export DATA_NAME=$2
python -u ../midas_eval.py --device_id=$DEVICE_ID --data_name=$DATA_NAME > eval.txt 2>&1 &


@@ -0,0 +1,18 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
python -u ../midas_train.py \
--device_id=$DEVICE_ID > train_$DEVICE_ID.log 2>&1 &


@@ -0,0 +1,75 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""blocks net."""
import mindspore.nn as nn
import mindspore.ops as ops
class FeatureFusionBlock(nn.Cell):
"""FeatureFusionBlock."""
def __init__(self, features):
super(FeatureFusionBlock, self).__init__()
self.resConfUnit1 = ResidualConvUnit(features)
self.resConfUnit2 = ResidualConvUnit(features)
self.resize_bilinear = ops.ResizeBilinear
self.shape = ops.Shape()
def construct(self, *xs):
output = xs[0]
if len(xs) == 2:
output += self.resConfUnit1(xs[1])
output = self.resConfUnit2(output)
size_x = self.shape(output)[2] * 2
size_y = self.shape(output)[3] * 2
output = self.resize_bilinear((size_x, size_y))(output)
return output
class ResidualConvUnit(nn.Cell):
"""ResidualConvUnit."""
def __init__(self, features):
super().__init__()
self.conv1 = nn.Conv2d(
features, features, kernel_size=3, stride=1, has_bias=True,
padding=1, pad_mode="pad"
)
self.conv2 = nn.Conv2d(
features, features, kernel_size=3, stride=1, has_bias=True,
padding=1, pad_mode="pad"
)
self.relu = nn.ReLU()
def construct(self, x):
out = self.relu(x)
out = self.conv1(out)
out = self.relu(out)
out = self.conv2(out)
return out + x
class Interpolate(nn.Cell):
"""Interpolate."""
def __init__(self, scale_factor):
super(Interpolate, self).__init__()
self.resize_bilinear = ops.ResizeBilinear
self.scale_factor = scale_factor
self.shape = ops.Shape()
def construct(self, x):
size_x = self.shape(x)[2] * self.scale_factor
size_y = self.shape(x)[3] * self.scale_factor
x = self.resize_bilinear((size_x, size_y))(x)
return x


@@ -0,0 +1,130 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
print(cfg_helper)
except:
raise ValueError("Failed to parse yaml")
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
pprint(default)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
return Config(final_config)
config = get_config()


@@ -0,0 +1,108 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""custom op."""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
class GlobalAvgPooling(nn.Cell):
"""
    Global average pooling over the spatial dimensions (H, W) of a feature map.
"""
def __init__(self):
super(GlobalAvgPooling, self).__init__()
self.mean = P.ReduceMean(False)
def construct(self, x):
x = self.mean(x, (2, 3))
return x
class SEBlock(nn.Cell):
"""
squeeze and excitation block.
Args:
channel (int): number of feature maps.
reduction (int): weight.
"""
def __init__(self, channel, reduction=16):
super(SEBlock, self).__init__()
self.avg_pool = GlobalAvgPooling()
self.fc1 = nn.Dense(channel, channel // reduction)
self.relu = P.ReLU()
self.fc2 = nn.Dense(channel // reduction, channel)
self.sigmoid = P.Sigmoid()
self.reshape = P.Reshape()
self.shape = P.Shape()
self.sum = P.Sum()
self.cast = P.Cast()
def construct(self, x):
"""construct"""
        b, c, _, _ = self.shape(x)
y = self.avg_pool(x)
y = self.reshape(y, (b, c))
y = self.fc1(y)
y = self.relu(y)
y = self.fc2(y)
y = self.sigmoid(y)
y = self.reshape(y, (b, c, 1, 1))
return x * y
class GroupConv(nn.Cell):
"""
group convolution operation.
Args:
in_channels (int): Input channels of feature map.
out_channels (int): Output channels of feature map.
kernel_size (int): Size of convolution kernel.
stride (int): Stride size for the group convolution layer.
Returns:
tensor, output tensor.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, pad_mode="pad", pad=0, groups=1, has_bias=False):
super(GroupConv, self).__init__()
assert in_channels % groups == 0 and out_channels % groups == 0
self.groups = groups
self.convs = nn.CellList()
self.op_split = P.Split(axis=1, output_num=self.groups)
self.op_concat = P.Concat(axis=1)
self.cast = P.Cast()
for _ in range(groups):
self.convs.append(nn.Conv2d(in_channels // groups, out_channels // groups,
kernel_size=kernel_size, stride=stride, has_bias=has_bias,
padding=pad, pad_mode=pad_mode, group=1))
def construct(self, x):
"""construct."""
features = self.op_split(x)
outputs = ()
for i in range(self.groups):
outputs = outputs + (self.convs[i](self.cast(features[i], mstype.float32)),)
out = self.op_concat(outputs)
return out


@@ -0,0 +1,93 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""loss."""
import mindspore.nn as nn
import mindspore.ops as ops
class Gradient_loss(nn.Cell):
"""Gradient_loss"""
def __init__(self):
super(Gradient_loss, self).__init__()
self.ms_sum = ops.ReduceSum(keep_dims=False)
self.abs = ops.Abs()
def construct(self, prediction, target, mask):
"""Gradient_loss construct"""
M = self.ms_sum(mask, (1, 2))
diff = prediction - target
diff = mask * diff
grad_x = self.abs(diff[:, :, 1:] - diff[:, :, :-1])
mask_x = mask[:, :, 1:] * mask[:, :, :-1]
grad_x = mask_x * grad_x
grad_y = self.abs(diff[:, 1:, :] - diff[:, :-1, :])
mask_y = mask[:, 1:, :] * mask[:, :-1, :]
grad_y = mask_y * grad_y
image_loss = self.ms_sum(grad_x, (1, 2)) + self.ms_sum(grad_y, (1, 2))
divisor = self.ms_sum(M)
total = self.ms_sum(image_loss) / divisor
return total
class ScaleAndShiftInvariantLoss(nn.Cell):
"""ScaleAndShiftInvariantLoss"""
def __init__(self, alpha=0.5, scales=4):
super(ScaleAndShiftInvariantLoss, self).__init__()
self.ms_sum = ops.ReduceSum(keep_dims=False)
self.zeroslike = ops.ZerosLike()
self.select = ops.Select()
self.ones = ops.OnesLike()
self.reshape = ops.Reshape()
self.alpha = alpha
self.scales = scales
self.loss = Gradient_loss()
def construct(self, prediction, mask, target):
"""construct"""
a_00 = self.ms_sum(mask * prediction * prediction, (1, 2))
a_01 = self.ms_sum(mask * prediction, (1, 2))
a_11 = self.ms_sum(mask, (1, 2))
b_0 = self.ms_sum(mask * prediction * target, (1, 2))
b_1 = self.ms_sum(mask * target, (1, 2))
det = a_00 * a_11 - a_01 * a_01
mask_det = det != 0
input_y = self.zeroslike(det)
input_z = self.ones(det)
a_11 = self.select(mask_det, a_11, input_y)
b_0 = self.select(mask_det, b_0, input_y)
a_01 = self.select(mask_det, a_01, input_y)
b_1 = self.select(mask_det, b_1, input_y)
a_00 = self.select(mask_det, a_00, input_y)
det = self.select(mask_det, det, input_z)
x_0 = (a_11 * b_0 - a_01 * b_1) / det
x_1 = (-a_01 * b_0 + a_00 * b_1) / det
scale = self.reshape(x_0, (-1, 1, 1))
shift = self.reshape(x_1, (-1, 1, 1))
prediction_ssi = scale * prediction + shift
M = self.ms_sum(mask, (1, 2))
res = prediction_ssi - target
image_loss = self.ms_sum(mask * res * res, (1, 2))
divisor = self.ms_sum(M)
total = self.ms_sum(image_loss) / divisor
for scale in range(self.scales):
step = pow(2, scale)
total += self.loss(prediction_ssi[:, ::step, ::step], target[:, ::step, ::step],
mask[:, ::step, ::step])
return total * self.alpha


@@ -0,0 +1,393 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""net."""
import numpy as np
from mindspore import ops
from mindspore import ParameterTuple
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops.operations import Add, Split, Concat
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
from mindspore.ops import functional as F
from src.custom_op import SEBlock, GroupConv
from src.blocks_ms import Interpolate, FeatureFusionBlock
from src.loss import ScaleAndShiftInvariantLoss
from src.config import config
def conv7x7(in_channels, out_channels, stride=1, padding=3, has_bias=False, groups=1):
return nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=stride, has_bias=has_bias,
padding=padding, pad_mode="pad", group=groups)
def conv3x3(in_channels, out_channels, stride=1, padding=1, has_bias=False, groups=1):
return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, has_bias=has_bias,
padding=padding, pad_mode="pad", group=groups)
def conv1x1(in_channels, out_channels, stride=1, padding=0, has_bias=False, groups=1):
return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, has_bias=has_bias,
padding=padding, pad_mode="pad", group=groups)
class _DownSample(nn.Cell):
"""
Downsample for ResNext-ResNet.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
stride (int): Stride size for the 1*1 convolutional layer.
Returns:
Tensor, output tensor.
Examples:
>>>DownSample(32, 64, 2)
"""
def __init__(self, in_channels, out_channels, stride):
super(_DownSample, self).__init__()
self.conv = conv1x1(in_channels, out_channels, stride=stride, padding=0)
self.bn = nn.BatchNorm2d(out_channels)
def construct(self, x):
out = self.conv(x)
out = self.bn(out)
return out
class BasicBlock(nn.Cell):
"""
ResNet basic block definition.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
stride (int): Stride size for the first convolutional layer. Default: 1.
Returns:
Tensor, output tensor.
Examples:
>>>BasicBlock(32, 256, stride=2)
"""
expansion = 1
def __init__(self, in_channels, out_channels, stride=1, down_sample=None, use_se=False, **kwargs):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(in_channels, out_channels, stride=stride)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = P.ReLU()
self.conv2 = conv3x3(out_channels, out_channels, stride=1)
self.bn2 = nn.BatchNorm2d(out_channels)
self.use_se = use_se
if self.use_se:
self.se = SEBlock(out_channels)
self.down_sample_flag = False
if down_sample is not None:
self.down_sample = down_sample
self.down_sample_flag = True
self.add = Add()
def construct(self, x):
"""construct."""
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.use_se:
out = self.se(out)
if self.down_sample_flag:
identity = self.down_sample(x)
out = self.add(out, identity)
out = self.relu(out)
return out
class Bottleneck(nn.Cell):
"""
ResNet Bottleneck block definition.
Args:
in_channels (int): Input channels.
out_channels (int): Output channels.
stride (int): Stride size for the initial convolutional layer. Default: 1.
Returns:
Tensor, the ResNet unit's output.
Examples:
>>>Bottleneck(3, 256, stride=2)
"""
expansion = 4
def __init__(self, in_channels, out_channels, stride=1, down_sample=None,
base_width=64, groups=1, use_se=False, **kwargs):
super(Bottleneck, self).__init__()
width = int(out_channels * (base_width / 64.0)) * groups
self.groups = groups
self.conv1 = conv1x1(in_channels, width, stride=1)
self.bn1 = nn.BatchNorm2d(width)
self.relu = P.ReLU()
self.conv3x3s = nn.CellList()
self.conv2 = GroupConv(width, width, 3, stride, pad=1, groups=groups)
self.op_split = Split(axis=1, output_num=self.groups)
self.op_concat = Concat(axis=1)
self.bn2 = nn.BatchNorm2d(width)
self.conv3 = conv1x1(width, out_channels * self.expansion, stride=1)
self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
self.use_se = use_se
if self.use_se:
self.se = SEBlock(out_channels * self.expansion)
self.down_sample_flag = False
if down_sample is not None:
self.down_sample = down_sample
self.down_sample_flag = True
self.cast = P.Cast()
self.add = Add()
def construct(self, x):
"""construct."""
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.use_se:
out = self.se(out)
if self.down_sample_flag:
identity = self.down_sample(x)
out = self.add(out, identity)
out = self.relu(out)
return out
class MidasNet(nn.Cell):
"""Network for monocular depth estimation.
"""
def __init__(self, block=Bottleneck, width_per_group=config.width_per_group,
groups=config.groups, use_se=False,
features=config.features, non_negative=True,
expand=False):
super(MidasNet, self).__init__()
self.in_channels = config.in_channels
self.groups = groups
self.layers = config.layers
self.base_width = width_per_group
self.backbone_conv = conv7x7(3, self.in_channels, stride=2, padding=3)
self.backbone_bn = nn.BatchNorm2d(self.in_channels)
self.backbone_relu = P.ReLU()
self.backbone_maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
self.backbone_layer1 = self._make_layer(block, 64, self.layers[0], use_se=use_se)
self.backbone_layer2 = self._make_layer(block, 128, self.layers[1], stride=2, use_se=use_se)
self.backbone_layer3 = self._make_layer(block, 256, self.layers[2], stride=2, use_se=use_se)
self.backbone_layer4 = self._make_layer(block, 512, self.layers[3], stride=2, use_se=use_se)
self.out_channels = 512 * block.expansion
out_shape1 = features
out_shape2 = features
out_shape3 = features
out_shape4 = features
self.non_negative = non_negative
if expand:
out_shape1 = features
out_shape2 = features * 2
out_shape3 = features * 4
out_shape4 = features * 8
self.layer1_rn_scratch = nn.Conv2d(
256, out_shape1, kernel_size=3, stride=1, has_bias=False,
padding=1, pad_mode="pad", group=1
)
self.layer2_rn_scratch = nn.Conv2d(
512, out_shape2, kernel_size=3, stride=1, has_bias=False,
padding=1, pad_mode="pad", group=1
)
self.layer3_rn_scratch = nn.Conv2d(
1024, out_shape3, kernel_size=3, stride=1, has_bias=False,
padding=1, pad_mode="pad", group=1
)
self.layer4_rn_scratch = nn.Conv2d(
2048, out_shape4, kernel_size=3, stride=1, has_bias=False,
padding=1, pad_mode="pad", group=1
)
self.refinenet4_scratch = FeatureFusionBlock(features)
self.refinenet3_scratch = FeatureFusionBlock(features)
self.refinenet2_scratch = FeatureFusionBlock(features)
self.refinenet1_scratch = FeatureFusionBlock(features)
self.output_conv_scratch = nn.SequentialCell([
nn.Conv2d(
features, 128, kernel_size=3, stride=1, has_bias=True,
padding=1, pad_mode="pad"
),
Interpolate(scale_factor=2),
nn.Conv2d(
128, 32, kernel_size=3, stride=1, has_bias=True,
padding=1, pad_mode="pad"
),
nn.ReLU(),
nn.Conv2d(
32, 1, kernel_size=1, stride=1, has_bias=True,
padding=0, pad_mode="pad"
),
nn.ReLU() if non_negative else ops.Identity(),
])
self.squeeze = ops.Squeeze(1)
def construct(self, x):
"""construct pass.
Args:
x (tensor): input data (image)
Returns:
tensor: depth
"""
x = self.backbone_conv(x)
x = self.backbone_bn(x)
x = self.backbone_relu(x)
x = self.backbone_maxpool(x)
layer1 = self.backbone_layer1(x)
layer2 = self.backbone_layer2(layer1)
layer3 = self.backbone_layer3(layer2)
layer4 = self.backbone_layer4(layer3)
layer_1_rn = self.layer1_rn_scratch(layer1)
layer_2_rn = self.layer2_rn_scratch(layer2)
layer_3_rn = self.layer3_rn_scratch(layer3)
layer_4_rn = self.layer4_rn_scratch(layer4)
path_4 = self.refinenet4_scratch(layer_4_rn)
path_3 = self.refinenet3_scratch(path_4, layer_3_rn)
path_2 = self.refinenet2_scratch(path_3, layer_2_rn)
path_1 = self.refinenet1_scratch(path_2, layer_1_rn)
out = self.output_conv_scratch(path_1)
result = self.squeeze(out)
return result
def _make_layer(self, block, out_channels, blocks_num, stride=1, use_se=False):
"""_make_layer"""
down_sample = None
if stride != 1 or self.in_channels != out_channels * block.expansion:
down_sample = _DownSample(self.in_channels,
out_channels * block.expansion,
stride=stride)
layers = [block(self.in_channels,
out_channels,
stride=stride,
down_sample=down_sample,
base_width=self.base_width,
groups=self.groups,
use_se=use_se)]
self.in_channels = out_channels * block.expansion
for _ in range(1, blocks_num):
layers.append(block(self.in_channels, out_channels,
base_width=self.base_width, groups=self.groups, use_se=use_se))
return nn.SequentialCell(layers)
def get_out_channels(self):
return self.out_channels
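# Hedged usage sketch (not part of the original file): push a dummy batch
# through MidasNet to check that the output is a squeezed depth map. The
# 384x384 input size is illustrative; any size the backbone accepts works.
def _demo_forward_shape():
    """Expected output shape for a (1, 3, 384, 384) input: (1, 384, 384)."""
    net = MidasNet()
    x = Tensor(np.zeros((1, 3, 384, 384), np.float32))
    print(net(x).shape)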
class Loss(nn.Cell):
    """Wraps the scale- and shift-invariant training loss."""
    def __init__(self):
super(Loss, self).__init__()
self.lossvalue = ScaleAndShiftInvariantLoss()
def construct(self, prediction, mask, target):
loss_value = self.lossvalue(prediction, mask, target)
return loss_value
class NetwithCell(nn.Cell):
"""NetwithCell."""
def __init__(self, net, loss):
super(NetwithCell, self).__init__(auto_prefix=False)
self._net = net
self._loss = loss
def construct(self, image, mask, depth):
prediction = self._net(image)
return self._loss(prediction, mask, depth)
@property
def backbone_network(self):
return self._net
class TrainOneStepCell(nn.Cell):
"""
Network training package class.
Append an optimizer to the training network after that the construct function
can be called to create the backward graph.
Args:
network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The adjust parameter. Default value is 1.0.
reduce_flag (bool): The reduce flag. Default value is False.
mean (bool): Allreduce method. Default value is False.
degree (int): Device number. Default value is None.
"""
def __init__(self, network, optimizer, sens=1.0, reduce_flag=False, mean=True, degree=None):
super(TrainOneStepCell, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.weights = ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True,
sens_param=True)
self.sens = Tensor((np.ones((1,)) * sens).astype(np.float32))
self.reduce_flag = reduce_flag
self.hyper_map = C.HyperMap()
if reduce_flag:
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, *ds):
weights = self.weights
loss = self.network(*ds)
grads = self.grad(self.network, weights)(*ds, self.sens)
if self.reduce_flag:
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))
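# Hedged composition sketch (not part of the original file): how the cells
# above are wired together for training. The Adam learning rate is
# illustrative; the real script builds the optimizer from config.yaml.
def _demo_build_train_cell():
    """MidasNet -> Loss -> NetwithCell -> TrainOneStepCell."""
    net = MidasNet()
    net_with_loss = NetwithCell(net, Loss())
    optimizer = nn.Adam(net.trainable_params(), learning_rate=1e-4)
    return TrainOneStepCell(net_with_loss, optimizer)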

View File

@ -0,0 +1,211 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""util."""
import sys
import numpy as np
import cv2
from PIL import Image
TAG_FLOAT = 202021.25
TAG_CHAR = 'PIEH'
def read_image(path):
"""Read image and output RGB image (0-1).
Args:
path (str): path to file
Returns:
array: RGB image (0-1)
"""
    img = cv2.imread(path)
    if img is None:
        raise ValueError('File corrupt {}'.format(path))
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
return img
def write_depth(path, depth, bits=1):
"""Write depth map to pfm and png file.
    Args:
        path (str): filepath without extension
        depth (array): depth map
        bits (int, optional): png bit depth; 1 writes uint8, 2 writes uint16.
            Defaults to 1.
    """
write_pfm(path + ".pfm", depth.astype(np.float32))
depth_min = depth.min()
depth_max = depth.max()
max_val = (2 ** (8 * bits)) - 1
if depth_max - depth_min > np.finfo("float").eps:
out = max_val * (depth - depth_min) / (depth_max - depth_min)
else:
        out = np.zeros(depth.shape, dtype=depth.dtype)
if bits == 1:
cv2.imwrite(path + ".png", out.astype("uint8"))
elif bits == 2:
cv2.imwrite(path + ".png", out.astype("uint16"))
return out
def write_pfm(path, image, scale=1):
"""Write pfm file.
Args:
path (str): path file
image (array): data
scale (int, optional): Scale. Defaults to 1.
"""
with open(path, "wb") as file:
color = None
if image.dtype.name != "float32":
raise Exception("Image dtype must be float32.")
image = np.flipud(image)
if len(image.shape) == 3 and image.shape[2] == 3: # color image
color = True
elif (
len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
): # greyscale
color = False
else:
raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
file.write("PF\n" if color else "Pf\n".encode())
file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
endian = image.dtype.byteorder
if endian == "<" or endian == "=" and sys.byteorder == "little":
scale = -scale
file.write("%f\n".encode() % scale)
image.tofile(file)
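# Hedged usage sketch (not part of the original file): dump a synthetic depth
# map with the helpers above. The "demo_depth" prefix is illustrative.
def _demo_write_depth():
    """Writes demo_depth.pfm and a 16-bit demo_depth.png from random data."""
    depth = np.random.rand(480, 640).astype(np.float32)
    write_depth("demo_depth", depth, bits=2)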
def depth_read_kitti(filename):
    """Read a KITTI 16-bit depth png and return depth in meters.
    Invalid pixels (value 0) are set to -1.
    """
    depth_png = np.array(Image.open(filename), dtype=int)
    # make sure we have a proper 16-bit depth map, not an 8-bit one
    assert np.max(depth_png) > 255
    depth = depth_png.astype(np.float64) / 256.
    depth[depth_png == 0] = -1.
    return depth
def depth_read_sintel(filename):
    """ Read depth data from file, return as numpy array. """
    with open(filename, 'rb') as f:
        check = np.fromfile(f, dtype=np.float32, count=1)[0]
        assert check == TAG_FLOAT, ' depth_read:: Wrong tag in flow file ' \
                                   '(should be: {0}, is: {1}). Big-endian machine? ' \
            .format(TAG_FLOAT, check)
        width = np.fromfile(f, dtype=np.int32, count=1)[0]
        height = np.fromfile(f, dtype=np.int32, count=1)[0]
        size = width * height
        assert width > 0 and height > 0 and 1 < size < 100000000, \
            ' depth_read:: Wrong input size (width = {0}, height = {1}).'.format(width, height)
        depth = np.fromfile(f, dtype=np.float32, count=-1).reshape((height, width))
    return depth
class BadPixelMetric:
    """Depth evaluation metric (bad-pixel ratio or relative error) with per-image scale/shift alignment."""
def __init__(self, threshold=1.25, depth_cap=10, model='NYU'):
self.__threshold = threshold
self.__depth_cap = depth_cap
self.__model = model
def compute_scale_and_shift(self, prediction, target, mask):
""" compute_scale_and_shift. """
# system matrix: A = [[a_00, a_01], [a_10, a_11]]
a_00 = np.sum(mask * prediction * prediction, (1, 2))
a_01 = np.sum(mask * prediction, (1, 2))
a_11 = np.sum(mask, (1, 2))
# right hand side: b = [b_0, b_1]
b_0 = np.sum(mask * prediction * target, (1, 2))
b_1 = np.sum(mask * target, (1, 2))
# solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . b
x_0 = np.zeros_like(b_0)
x_1 = np.zeros_like(b_1)
det = a_00 * a_11 - a_01 * a_01
# A needs to be a positive definite matrix.
valid = det > 0
x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]
x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]
return x_0, x_1
def __call__(self, prediction, target, mask):
# transform predicted disparity to aligned depth
target_disparity = np.zeros_like(target)
target_disparity[mask == 1] = 1.0 / target[mask == 1]
scale, shift = self.compute_scale_and_shift(prediction, target_disparity, mask)
prediction_aligned = scale.reshape((-1, 1, 1)) * prediction + shift.reshape((-1, 1, 1))
        disparity_cap = 1.0 / self.__depth_cap
        prediction_aligned[prediction_aligned < disparity_cap] = disparity_cap
        prediction_depth = 1.0 / prediction_aligned
        # bad pixel
        err = np.zeros_like(prediction_depth, dtype=np.float64)
        if self.__model in ('NYU', 'TUM', 'KITTI'):
            err[mask == 1] = np.maximum(
                prediction_depth[mask == 1] / target[mask == 1],
                target[mask == 1] / prediction_depth[mask == 1],
            )
            err[mask == 1] = (err[mask == 1] > self.__threshold)
            p = np.sum(err, (1, 2)) / np.sum(mask, (1, 2))
        if self.__model in ('sintel', 'ETH3D'):
            err[mask == 1] = np.abs((prediction_depth[mask == 1] - target[mask == 1]) / target[mask == 1])
            err_sum = np.sum(err, (1, 2))
            mask_sum = np.sum(mask, (1, 2))
            print('err_sum is ', err_sum)
            print('mask_sum is ', mask_sum)
            # mask_sum is a per-image array; guard each image against an empty mask
            p = np.where(mask_sum > 0, err_sum / np.maximum(mask_sum, 1), 0.0)
            return np.mean(p)
        return 100 * np.mean(p)
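# Hedged usage sketch (not part of the original file): evaluate a random
# prediction against a synthetic NYU-style depth map. Shapes and values are
# illustrative; real evaluation feeds aligned batches from the test loaders.
def _demo_bad_pixel_metric():
    """Returns the bad-pixel percentage for a (1, 32, 32) dummy batch."""
    metric = BadPixelMetric(threshold=1.25, depth_cap=10, model='NYU')
    prediction = np.random.rand(1, 32, 32).astype(np.float32)    # disparity
    target = np.random.rand(1, 32, 32).astype(np.float32) + 0.5  # depth > 0
    mask = np.ones_like(target)
    return metric(prediction, target, mask)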

View File

@ -0,0 +1,120 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""loadImg."""
import glob
from collections import OrderedDict
import h5py
import cv2
import numpy as np
from src.utils.transforms import Resize, NormalizeImage, PrepareForNet
from src.config import config
class LoadImagesDepth:
"""LoadImagesDepth."""
def __init__(self, local_path=None, img_paths=None):
self.img_files = OrderedDict()
self.depth_files = OrderedDict()
self.nF = 0
for ds, path in img_paths.items():
self.img_files[ds] = sorted(glob.glob(local_path + path))
self.depth_files[ds] = [x.replace('Imgs', 'RDs').replace('jpg', 'png') for x in
self.img_files[ds]]
self.ds = ds
self.nds = [len(x) for x in self.img_files.values()]
self.cds = [sum(self.nds[:i]) for i in range(len(self.nds))]
self.nF = sum(self.nds)
print(self.nds)
print(self.cds)
print(self.nF)
self.img_input_1 = Resize(config.img_width,
config.img_height,
resize_target=config.resize_target,
keep_aspect_ratio=config.keep_aspect_ratio,
ensure_multiple_of=config.ensure_multiple_of,
resize_method=config.resize_method,
image_interpolation_method=cv2.INTER_CUBIC)
self.img_input_2 = NormalizeImage(mean=config.nm_img_mean, std=config.nm_img_std)
self.img_input_3 = PrepareForNet()
    def __getitem__(self, files_index):
        ds, start_index = self.ds, 0
        for i, c in enumerate(self.cds):
            if files_index >= c:
                ds = list(self.depth_files.keys())[i]
                start_index = c
        img_path = self.img_files[ds][files_index - start_index]
        depth_path = self.depth_files[ds][files_index - start_index]
        return self.get_data(ds, img_path, depth_path)
def get_data(self, ds, img_path, label_path):
"""get_data."""
sample = {}
img = read_image2RGB(img_path)
if ds == 'Mega':
depth = read_h5(label_path)
else:
depth = read_image2gray(label_path)
mask = np.ones(depth.shape)
sample["image"] = img
sample["mask"] = mask
sample["depth"] = depth
sample = self.img_input_1(sample)
sample = self.img_input_2(sample)
sample = self.img_input_3(sample)
return sample["image"], sample["mask"], sample["depth"]
def __len__(self):
return self.nF
def read_image2gray(path):
    """Read an image unchanged and return its first channel as the depth map.
    Args:
        path (str): path to file
    Returns:
        array: single-channel depth map
    """
    img_ori = cv2.imread(path, -1)
    if img_ori is None:
        raise ValueError('File corrupt {}'.format(path))
    depth = cv2.split(img_ori)[0]
    return depth
def read_image2RGB(path):
"""Read image and output RGB image (0-1).
Args:
path (str): path to file
Returns:
array: RGB image (0-1)
"""
img = cv2.imread(path)
if img is None:
raise ValueError('File corrupt {}'.format(path))
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
return img
def read_h5(path):
    """Read the depth map stored under '/depth' in an HDF5 file."""
    with h5py.File(path, 'r') as f:
        gt = f.get('/depth')
        if gt is None:
            raise ValueError('File corrupt {}'.format(path))
        return np.array(gt)
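# Hedged usage sketch (not part of the original file): wrap LoadImagesDepth in
# a MindSpore GeneratorDataset. The 'ReDWeb' key and glob pattern are
# illustrative; adjust them to the actual directory layout.
def _demo_build_dataset(local_path):
    """Builds a shuffled dataset yielding (image, mask, depth) columns."""
    import mindspore.dataset as ds
    loader = LoadImagesDepth(local_path=local_path,
                             img_paths={'ReDWeb': '/ReDWeb_V1/Imgs/*.jpg'})
    return ds.GeneratorDataset(loader, column_names=["image", "mask", "depth"],
                               shuffle=True)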

View File

@ -0,0 +1,85 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""pth2ckpt."""
from mindspore import Tensor
from mindspore.train.serialization import save_checkpoint
import torch
import numpy as np
def pytorch2mindspore():
"""pth to ckpt."""
par_dict = torch.load('/opt_data/xidian_wks/czp/midas/ckpt/ig_resnext101_32x8-c38310e5.pth', map_location='cpu')
new_params_list = []
for name in par_dict:
print(name)
param_dict = {}
parameter = par_dict[name]
name = name.replace('layer', 'backbone_layer', 1)
name = name.replace('running_mean', 'moving_mean', 1)
name = name.replace('running_var', 'moving_variance', 1)
temp = name
if name.endswith('conv2.weight'):
x = parameter.numpy()
y = np.split(x, 32)
for i in range(32):
name = temp[:temp.rfind('weight')] + 'convs.' + str(i) + '.weight'
data = Tensor(y[i])
new_params_list.append({"name": name, 'data': data})
continue
if name.startswith('bn1'):
name = name.replace('bn1', 'backbone_bn', 1)
name = name.replace('bias', 'beta', 1)
name = name.replace('weight', 'gamma', 1)
if name.startswith('conv1.weight'):
name = 'backbone_conv.weight'
if name.endswith('layer1.0.weight'):
name = 'backbone_conv.weight'
if name.endswith('layer1.1.weight'):
name = 'backbone_bn.gamma'
if name.endswith('layer1.1.bias'):
name = 'backbone_bn.beta'
if name.endswith('bn1.weight'):
name = name[:name.rfind('weight')]
name = name + 'gamma'
if name.endswith('bn1.bias'):
name = name[:name.rfind('bias')]
name = name + 'beta'
if name.endswith('bn2.weight'):
name = name[:name.rfind('weight')]
name = name + 'gamma'
if name.endswith('bn2.bias'):
name = name[:name.rfind('bias')]
name = name + 'beta'
if name.endswith('bn3.weight'):
name = name[:name.rfind('weight')]
name = name + 'gamma'
if name.endswith('bn3.bias'):
name = name[:name.rfind('bias')]
name = name + 'beta'
if name.find('downsample') != -1:
name = name.replace("downsample.1", 'down_sample.bn')
name = name.replace("bn.weight", 'bn.gamma')
name = name.replace("bias", 'beta')
name = name.replace("downsample.0.weight", 'down_sample.conv.weight')
print("----------------", name)
param_dict['name'] = name
param_dict['data'] = Tensor(parameter.numpy())
new_params_list.append(param_dict)
save_checkpoint(new_params_list, 'midas_pth.ckpt')
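# Hedged verification sketch (not part of the original script): reload the
# converted checkpoint and list parameter names to spot-check the renaming
# above before loading it into the MindSpore graph.
def verify_checkpoint(ckpt_path='midas_pth.ckpt'):
    from mindspore.train.serialization import load_checkpoint
    for name in load_checkpoint(ckpt_path):
        print(name)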
if __name__ == '__main__':
pytorch2mindspore()

View File

@ -0,0 +1,252 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""transforms."""
import math
import numpy as np
import cv2
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
Args:
sample (dict): sample
size (tuple): image size
Returns:
tuple: new size
"""
shape = list(sample["disparity"].shape)
if shape[0] >= size[0] and shape[1] >= size[1]:
return sample
scale = [0, 0]
scale[0] = size[0] / shape[0]
scale[1] = size[1] / shape[1]
scale = max(scale)
shape[0] = math.ceil(scale * shape[0])
shape[1] = math.ceil(scale * shape[1])
# resize
sample["image"] = cv2.resize(
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
)
sample["disparity"] = cv2.resize(
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
)
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
tuple(shape[::-1]),
interpolation=cv2.INTER_NEAREST,
)
sample["mask"] = sample["mask"].astype(bool)
return tuple(shape)
class Resize:
"""Resize sample to given size (width, height).
"""
def __init__(
self,
width,
height,
resize_target=True,
keep_aspect_ratio=False,
ensure_multiple_of=1,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_AREA,
):
"""Init.
Args:
width (int): desired output width
height (int): desired output height
resize_target (bool, optional):
True: Resize the full sample (image, mask, target).
False: Resize image only.
Defaults to True.
keep_aspect_ratio (bool, optional):
True: Keep the aspect ratio of the input sample.
Output sample might not have the given width and height, and
resize behaviour depends on the parameter 'resize_method'.
Defaults to False.
ensure_multiple_of (int, optional):
Output width and height is constrained to be multiple of this parameter.
Defaults to 1.
resize_method (str, optional):
"lower_bound": Output will be at least as large as the given size.
"upper_bound": Output will be at max as large as the given size.
(Output size might be smaller than given size.)
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
Defaults to "lower_bound".
"""
self.__width = width
self.__height = height
self.__resize_target = resize_target
self.__keep_aspect_ratio = keep_aspect_ratio
self.__multiple_of = ensure_multiple_of
self.__resize_method = resize_method
self.__image_interpolation_method = image_interpolation_method
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
if max_val is not None and y > max_val:
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
if y < min_val:
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
return y
def get_size(self, width, height):
"""get_size."""
# determine new height and width
scale_height = self.__height / height
scale_width = self.__width / width
if self.__keep_aspect_ratio:
if self.__resize_method == "lower_bound":
# scale such that output size is lower bound
if scale_width > scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "upper_bound":
# scale such that output size is upper bound
if scale_width < scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "minimal":
                # scale as little as possible
if abs(1 - scale_width) < abs(1 - scale_height):
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
else:
raise ValueError(
f"resize_method {self.__resize_method} not implemented"
)
if self.__resize_method == "lower_bound":
new_height = self.constrain_to_multiple_of(
scale_height * height, min_val=self.__height
)
new_width = self.constrain_to_multiple_of(
scale_width * width, min_val=self.__width
)
elif self.__resize_method == "upper_bound":
new_height = self.constrain_to_multiple_of(
scale_height * height, max_val=self.__height
)
new_width = self.constrain_to_multiple_of(
scale_width * width, max_val=self.__width
)
elif self.__resize_method == "minimal":
new_height = self.constrain_to_multiple_of(scale_height * height)
new_width = self.constrain_to_multiple_of(scale_width * width)
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
return (new_width, new_height)
def __call__(self, sample):
width, height = self.get_size(
sample["image"].shape[1], sample["image"].shape[0]
)
# resize sample
sample["image"] = cv2.resize(
sample["image"],
(width, height),
interpolation=self.__image_interpolation_method,
)
if self.__resize_target:
if "disparity" in sample:
sample["disparity"] = cv2.resize(
sample["disparity"],
(width, height),
interpolation=cv2.INTER_NEAREST,
)
if "depth" in sample:
sample["depth"] = cv2.resize(
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
)
if "mask" in sample:
sample["mask"] = cv2.resize(
sample["mask"].astype(np.float32),
(width, height),
interpolation=cv2.INTER_NEAREST,
)
sample["mask"] = sample["mask"].astype(bool)
return sample
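# Hedged worked example (not part of the original file): with "lower_bound"
# and keep_aspect_ratio=True, a 400x300 (WxH) input targeted at 384x384 with
# ensure_multiple_of=32 resolves to 512x384. Numbers are illustrative.
def _demo_resize_lower_bound():
    """get_size(400, 300) -> (512, 384) as (width, height)."""
    r = Resize(384, 384, keep_aspect_ratio=True, ensure_multiple_of=32,
               resize_method="lower_bound")
    print(r.get_size(400, 300))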
class NormalizeImage:
"""Normlize image by given mean and std.
"""
def __init__(self, mean, std):
self.__mean = mean
self.__std = std
def __call__(self, sample):
sample["image"] = (sample["image"] - self.__mean) / self.__std
return sample
class PrepareForNet:
"""Prepare sample for usage as network input.
"""
def __init__(self):
pass
def __call__(self, sample):
image = np.transpose(sample["image"], (2, 0, 1))
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
if "mask" in sample:
sample["mask"] = sample["mask"].astype(np.float32)
sample["mask"] = np.ascontiguousarray(sample["mask"])
if "disparity" in sample:
disparity = sample["disparity"].astype(np.float32)
sample["disparity"] = np.ascontiguousarray(disparity)
if "depth" in sample:
depth = sample["depth"].astype(np.float32)
sample["depth"] = np.ascontiguousarray(depth)
return sample
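# Hedged end-to-end sketch (not part of the original file): chain the three
# transforms exactly as loadImgDepth.py does. The size and the ImageNet-style
# normalization constants are illustrative; the real values come from config.
def _demo_transform_pipeline():
    """Returns CHW float32 image plus mask and depth for a synthetic sample."""
    sample = {
        "image": np.random.rand(300, 400, 3).astype(np.float32),
        "depth": np.random.rand(300, 400).astype(np.float32),
        "mask": np.ones((300, 400), dtype=np.float32),
    }
    transforms = [
        Resize(384, 384, resize_target=True, keep_aspect_ratio=True,
               ensure_multiple_of=32, resize_method="lower_bound",
               image_interpolation_method=cv2.INTER_CUBIC),
        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        PrepareForNet(),
    ]
    for t in transforms:
        sample = t(sample)
    return sample["image"], sample["mask"], sample["depth"]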