tensorflow2
This commit is contained in:
@@ -3,6 +3,7 @@ import numpy as np
|
||||
import os
|
||||
from .bbox import BoundBox, bbox_iou
|
||||
from scipy.special import expit
|
||||
import tensorflow as tf
|
||||
|
||||
def _sigmoid(x):
|
||||
return expit(x)
|
||||
@@ -166,18 +167,30 @@ def do_nms(boxes, nms_thresh):
|
||||
if bbox_iou(boxes[index_i], boxes[index_j]) >= nms_thresh:
|
||||
boxes[index_j].classes[c] = 0
|
||||
|
||||
def decode_netout(netout, anchors, obj_thresh, net_h, net_w):
|
||||
grid_h, grid_w = netout.shape[:2]
|
||||
def decode_netout(netout_old, anchors, obj_thresh, net_h, net_w):
|
||||
grid_h, grid_w = netout_old.shape[:2]
|
||||
nb_box = 3
|
||||
netout = netout.reshape((grid_h, grid_w, nb_box, -1))
|
||||
nb_class = netout.shape[-1] - 5
|
||||
#netout = netout.reshape((grid_h, grid_w, nb_box, -1))
|
||||
netout_old = tf.reshape(netout_old, (grid_h, grid_w, nb_box, -1))
|
||||
nb_class = netout_old.shape[-1] - 5
|
||||
|
||||
boxes = []
|
||||
## Tensorflow v.2
|
||||
#print(tf.shape(netout))
|
||||
aux_1 = _sigmoid(netout_old[..., :2])
|
||||
#print(tf.shape(aux_1))
|
||||
aux_2 = _sigmoid(netout_old[..., 4])
|
||||
#print(tf.shape(aux_2[..., np.newaxis]))
|
||||
aux_3 = aux_2[..., np.newaxis] * _softmax(netout_old[..., 5:])
|
||||
aux_4 = aux_3 * (aux_3 > obj_thresh)
|
||||
#print(tf.shape(aux_4))
|
||||
netout = tf.concat([aux_1,netout_old[..., 2:4] ,aux_2[..., np.newaxis], aux_4], 3)
|
||||
#print(tf.shape(new_netout))
|
||||
|
||||
netout[..., :2] = _sigmoid(netout[..., :2])
|
||||
netout[..., 4] = _sigmoid(netout[..., 4])
|
||||
netout[..., 5:] = netout[..., 4][..., np.newaxis] * _softmax(netout[..., 5:])
|
||||
netout[..., 5:] *= netout[..., 5:] > obj_thresh
|
||||
#netout[..., :2] = _sigmoid(netout[..., :2])
|
||||
#netout[..., 4] = _sigmoid(netout[..., 4])
|
||||
#netout[..., 5:] = netout[..., 4][..., np.newaxis] * _softmax(netout[..., 5:])
|
||||
#netout[..., 5:] *= netout[..., 5:] > obj_thresh
|
||||
|
||||
for i in range(grid_h*grid_w):
|
||||
row = i // grid_w
|
||||
@@ -198,7 +211,7 @@ def decode_netout(netout, anchors, obj_thresh, net_h, net_w):
|
||||
h = anchors[2 * b + 1] * np.exp(h) / net_h # unit: image height
|
||||
|
||||
# last elements are class probabilities
|
||||
classes = netout[row,col,b,5:]
|
||||
classes = np.array(netout[row,col,b,5:])
|
||||
|
||||
box = BoundBox(x-w/2, y-h/2, x+w/2, y+h/2, objectness, classes)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user