tensorflow2
This commit is contained in:
@@ -1,22 +1,22 @@
|
||||
import cv2
|
||||
import copy
|
||||
import numpy as np
|
||||
from keras.utils import Sequence
|
||||
from tensorflow.keras.utils import Sequence
|
||||
from utils.bbox import BoundBox, bbox_iou
|
||||
from utils.image import apply_random_scale_and_crop, random_distort_image, random_flip, correct_bounding_boxes
|
||||
|
||||
class BatchGenerator(Sequence):
|
||||
def __init__(self,
|
||||
instances,
|
||||
anchors,
|
||||
labels,
|
||||
def __init__(self,
|
||||
instances,
|
||||
anchors,
|
||||
labels,
|
||||
downsample=32, # ratio between network input's size and network output's size, 32 for YOLOv3
|
||||
max_box_per_image=30,
|
||||
batch_size=1,
|
||||
min_net_size=320,
|
||||
max_net_size=608,
|
||||
shuffle=True,
|
||||
jitter=True,
|
||||
max_net_size=608,
|
||||
shuffle=True,
|
||||
jitter=True,
|
||||
norm=None
|
||||
):
|
||||
self.instances = instances
|
||||
@@ -30,13 +30,13 @@ class BatchGenerator(Sequence):
|
||||
self.jitter = jitter
|
||||
self.norm = norm
|
||||
self.anchors = [BoundBox(0, 0, anchors[2*i], anchors[2*i+1]) for i in range(len(anchors)//2)]
|
||||
self.net_h = 416
|
||||
self.net_h = 416
|
||||
self.net_w = 416
|
||||
|
||||
if shuffle: np.random.shuffle(self.instances)
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return int(np.ceil(float(len(self.instances))/self.batch_size))
|
||||
return int(np.ceil(float(len(self.instances))/self.batch_size))
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# get image input size, change every 10 batches
|
||||
@@ -63,7 +63,7 @@ class BatchGenerator(Sequence):
|
||||
dummy_yolo_1 = np.zeros((r_bound - l_bound, 1))
|
||||
dummy_yolo_2 = np.zeros((r_bound - l_bound, 1))
|
||||
dummy_yolo_3 = np.zeros((r_bound - l_bound, 1))
|
||||
|
||||
|
||||
instance_count = 0
|
||||
true_box_index = 0
|
||||
|
||||
@@ -71,18 +71,18 @@ class BatchGenerator(Sequence):
|
||||
for train_instance in self.instances[l_bound:r_bound]:
|
||||
# augment input image and fix object's position and size
|
||||
img, all_objs = self._aug_image(train_instance, net_h, net_w)
|
||||
|
||||
|
||||
for obj in all_objs:
|
||||
# find the best anchor box for this object
|
||||
max_anchor = None
|
||||
max_anchor = None
|
||||
max_index = -1
|
||||
max_iou = -1
|
||||
|
||||
shifted_box = BoundBox(0,
|
||||
shifted_box = BoundBox(0,
|
||||
0,
|
||||
obj['xmax']-obj['xmin'],
|
||||
obj['ymax']-obj['ymin'])
|
||||
|
||||
obj['xmax']-obj['xmin'],
|
||||
obj['ymax']-obj['ymin'])
|
||||
|
||||
for i in range(len(self.anchors)):
|
||||
anchor = self.anchors[i]
|
||||
iou = bbox_iou(shifted_box, anchor)
|
||||
@@ -90,18 +90,18 @@ class BatchGenerator(Sequence):
|
||||
if max_iou < iou:
|
||||
max_anchor = anchor
|
||||
max_index = i
|
||||
max_iou = iou
|
||||
|
||||
max_iou = iou
|
||||
|
||||
# determine the yolo to be responsible for this bounding box
|
||||
yolo = yolos[max_index//3]
|
||||
grid_h, grid_w = yolo.shape[1:3]
|
||||
|
||||
|
||||
# determine the position of the bounding box on the grid
|
||||
center_x = .5*(obj['xmin'] + obj['xmax'])
|
||||
center_x = center_x / float(net_w) * grid_w # sigma(t_x) + c_x
|
||||
center_y = .5*(obj['ymin'] + obj['ymax'])
|
||||
center_y = center_y / float(net_h) * grid_h # sigma(t_y) + c_y
|
||||
|
||||
|
||||
# determine the sizes of the bounding box
|
||||
w = np.log((obj['xmax'] - obj['xmin']) / float(max_anchor.xmax)) # t_w
|
||||
h = np.log((obj['ymax'] - obj['ymin']) / float(max_anchor.ymax)) # t_h
|
||||
@@ -109,7 +109,7 @@ class BatchGenerator(Sequence):
|
||||
box = [center_x, center_y, w, h]
|
||||
|
||||
# determine the index of the label
|
||||
obj_indx = self.labels.index(obj['name'])
|
||||
obj_indx = self.labels.index(obj['name'])
|
||||
|
||||
# determine the location of the cell responsible for this object
|
||||
grid_x = int(np.floor(center_x))
|
||||
@@ -126,25 +126,25 @@ class BatchGenerator(Sequence):
|
||||
t_batch[instance_count, 0, 0, 0, true_box_index] = true_box
|
||||
|
||||
true_box_index += 1
|
||||
true_box_index = true_box_index % self.max_box_per_image
|
||||
true_box_index = true_box_index % self.max_box_per_image
|
||||
|
||||
# assign input image to x_batch
|
||||
if self.norm != None:
|
||||
if self.norm != None:
|
||||
x_batch[instance_count] = self.norm(img)
|
||||
else:
|
||||
# plot image and bounding boxes for sanity check
|
||||
for obj in all_objs:
|
||||
cv2.rectangle(img, (obj['xmin'],obj['ymin']), (obj['xmax'],obj['ymax']), (255,0,0), 3)
|
||||
cv2.putText(img, obj['name'],
|
||||
(obj['xmin']+2, obj['ymin']+12),
|
||||
0, 1.2e-3 * img.shape[0],
|
||||
cv2.putText(img, obj['name'],
|
||||
(obj['xmin']+2, obj['ymin']+12),
|
||||
0, 1.2e-3 * img.shape[0],
|
||||
(0,255,0), 2)
|
||||
|
||||
|
||||
x_batch[instance_count] = img
|
||||
|
||||
# increase instance counter in the current batch
|
||||
instance_count += 1
|
||||
|
||||
instance_count += 1
|
||||
|
||||
return [x_batch, t_batch, yolo_1, yolo_2, yolo_3], [dummy_yolo_1, dummy_yolo_2, dummy_yolo_3]
|
||||
|
||||
def _get_net_size(self, idx):
|
||||
@@ -154,16 +154,16 @@ class BatchGenerator(Sequence):
|
||||
#print("resizing: ", net_size, net_size)
|
||||
self.net_h, self.net_w = net_size, net_size
|
||||
return self.net_h, self.net_w
|
||||
|
||||
|
||||
def _aug_image(self, instance, net_h, net_w):
|
||||
image_name = instance['filename']
|
||||
image = cv2.imread(image_name) # RGB image
|
||||
|
||||
if image is None: print('Cannot find ', image_name)
|
||||
image = image[:,:,::-1] # RGB image
|
||||
|
||||
|
||||
image_h, image_w, _ = image.shape
|
||||
|
||||
|
||||
# determine the amount of scaling and cropping
|
||||
dw = self.jitter * image_w;
|
||||
dh = self.jitter * image_h;
|
||||
@@ -177,33 +177,33 @@ class BatchGenerator(Sequence):
|
||||
else:
|
||||
new_w = int(scale * net_w);
|
||||
new_h = int(net_w / new_ar);
|
||||
|
||||
|
||||
dx = int(np.random.uniform(0, net_w - new_w));
|
||||
dy = int(np.random.uniform(0, net_h - new_h));
|
||||
|
||||
|
||||
# apply scaling and cropping
|
||||
im_sized = apply_random_scale_and_crop(image, new_w, new_h, net_w, net_h, dx, dy)
|
||||
|
||||
|
||||
# randomly distort hsv space
|
||||
im_sized = random_distort_image(im_sized)
|
||||
|
||||
|
||||
# randomly flip
|
||||
flip = np.random.randint(2)
|
||||
im_sized = random_flip(im_sized, flip)
|
||||
|
||||
|
||||
# correct the size and pos of bounding boxes
|
||||
all_objs = correct_bounding_boxes(instance['object'], new_w, new_h, net_w, net_h, dx, dy, flip, image_w, image_h)
|
||||
|
||||
return im_sized, all_objs
|
||||
|
||||
return im_sized, all_objs
|
||||
|
||||
def on_epoch_end(self):
    """Reshuffle the instances in place when shuffling is enabled.

    Reads ``self.shuffle``; when it is truthy, ``self.instances`` is
    reordered in place with ``np.random.shuffle``. Nothing is returned.
    """
    if self.shuffle:
        np.random.shuffle(self.instances)
|
||||
|
||||
|
||||
def num_classes(self):
    """Return the number of entries in ``self.labels``."""
    label_count = len(self.labels)
    return label_count
|
||||
|
||||
def size(self):
    """Return the number of entries in ``self.instances``."""
    instance_count = len(self.instances)
    return instance_count
|
||||
|
||||
def get_anchors(self):
|
||||
anchors = []
|
||||
@@ -225,4 +225,4 @@ class BatchGenerator(Sequence):
|
||||
return np.array(annots)
|
||||
|
||||
def load_image(self, i):
    """Read the image file of the ``i``-th instance with OpenCV.

    Args:
        i: Index into ``self.instances``; the selected entry must provide
            a ``'filename'`` key.

    Returns:
        The result of ``cv2.imread`` for that file, passed through
        unchanged. Per the OpenCV documentation, ``imread`` signals an
        unreadable file through its return value instead of raising, so
        callers should check the result before using it.
    """
    image_path = self.instances[i]['filename']
    return cv2.imread(image_path)
|
||||
|
||||
Reference in New Issue
Block a user