forked from mrharicot/monodepth
monodepth_simple.py
# Copyright UCL Business plc 2017. Patent Pending. All rights reserved.
#
# The MonoDepth Software is licensed under the terms of the UCLB ACP-A licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.
#
# For any other use of the software not covered by the UCLB ACP-A Licence,
# please contact [email protected]

from __future__ import absolute_import, division, print_function

# only keep warnings and errors
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
import numpy as np
import argparse
import re
import time
import tensorflow as tf
import tensorflow.contrib.slim as slim
import scipy.misc
import matplotlib.pyplot as plt
from monodepth_model import *
from monodepth_dataloader import *
from average_gradients import *
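
# Note: this script relies on TensorFlow 1.x APIs (tf.contrib.slim, tf.placeholder,
# tf.Session) and on scipy.misc.imread / scipy.misc.imresize, which were removed
# from newer SciPy releases, so it is expected to run under TensorFlow 1.x with an
# older SciPy installed.
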
parser = argparse.ArgumentParser(description='Monodepth TensorFlow implementation.')
parser.add_argument('--encoder', type=str, help='type of encoder, vgg or resnet50', default='vgg')
parser.add_argument('--image_path', type=str, help='path to the image', required=True)
parser.add_argument('--checkpoint_path', type=str, help='path to a specific checkpoint to load', required=True)
parser.add_argument('--input_height', type=int, help='input height', default=256)
parser.add_argument('--input_width', type=int, help='input width', default=512)
args = parser.parse_args()
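
# Example invocation (paths are illustrative; point them at your own image and a
# downloaded monodepth checkpoint prefix, e.g. model_kitti):
#   python monodepth_simple.py --image_path /path/to/image.jpg \
#                              --checkpoint_path /path/to/model_kitti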


def post_process_disparity(disp):
    """Blend the disparities predicted for an image and its horizontal flip."""
    _, h, w = disp.shape
    l_disp = disp[0, :, :]             # prediction for the original image
    r_disp = np.fliplr(disp[1, :, :])  # prediction for the flipped image, flipped back
    m_disp = 0.5 * (l_disp + r_disp)
    l, _ = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h))
    # Ramp masks: use r_disp near the left border, l_disp near the right border,
    # and the mean of the two in the centre (each ramp spans roughly 5-10% of the width).
    l_mask = 1.0 - np.clip(20 * (l - 0.05), 0, 1)
    r_mask = np.fliplr(l_mask)
    return r_mask * l_disp + l_mask * r_disp + (1.0 - l_mask - r_mask) * m_disp


def test_simple(params):
    """Run the network on a single image and save the post-processed disparity."""
    left = tf.placeholder(tf.float32, [2, args.input_height, args.input_width, 3])
    model = MonodepthModel(params, "test", left, None)

    input_image = scipy.misc.imread(args.image_path, mode="RGB")
    original_height, original_width, num_channels = input_image.shape
    input_image = scipy.misc.imresize(input_image, [args.input_height, args.input_width], interp='lanczos')
    input_image = input_image.astype(np.float32) / 255
    # Batch of two: the image and its horizontal flip, for post-processing.
    input_images = np.stack((input_image, np.fliplr(input_image)), 0)

    # SESSION
    config = tf.ConfigProto(allow_soft_placement=True)
    sess = tf.Session(config=config)

    # SAVER
    train_saver = tf.train.Saver()

    # INIT
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    coordinator = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coordinator)

    # RESTORE
    restore_path = args.checkpoint_path.split(".")[0]  # strip any checkpoint file extension
    train_saver.restore(sess, restore_path)

    disp = sess.run(model.disp_left_est[0], feed_dict={left: input_images})
    disp_pp = post_process_disparity(disp.squeeze()).astype(np.float32)

    output_directory = os.path.dirname(args.image_path)
    output_name = os.path.splitext(os.path.basename(args.image_path))[0]

    np.save(os.path.join(output_directory, "{}_disp.npy".format(output_name)), disp_pp)
    disp_to_img = scipy.misc.imresize(disp_pp.squeeze(), [original_height, original_width])
    plt.imsave(os.path.join(output_directory, "{}_disp.png".format(output_name)), disp_to_img, cmap='plasma')

    print('done!')
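
# To inspect the saved outputs later (illustrative snippet, not part of the script):
#   disp = np.load('/path/to/<image_name>_disp.npy')
#   plt.imshow(disp, cmap='plasma'); plt.show()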


def main(_):
    # batch_size=2: the batch holds the input image and its horizontal flip
    # (see test_simple); the loss-related weights are unused in test mode.
    params = monodepth_parameters(
        encoder=args.encoder,
        height=args.input_height,
        width=args.input_width,
        batch_size=2,
        num_threads=1,
        num_epochs=1,
        do_stereo=False,
        wrap_mode="border",
        use_deconv=False,
        alpha_image_loss=0,
        disp_gradient_loss_weight=0,
        lr_loss_weight=0,
        full_summary=False)

    test_simple(params)


if __name__ == '__main__':
    tf.app.run()