%reload_ext autoreload
%autoreload 2
#export
from fastai import *
from fastai.vision import *
PATH = Path('../data/coco')
ANNOT_PATH = PATH/'annotations'
train_ds = ObjectDetectDataset.from_json(PATH/'train2017', ANNOT_PATH/'train_sample.json')
tfms = get_transforms()
train_tds = DatasetTfm(train_ds, tfms[0], tfm_y=True, size=224)
x,y = train_tds[5]
x.show(y=y, classes=train_ds.classes, figsize=(6,4))
size = 224
tfms = ([flip_lr(p=0.5), crop_pad(size=size)], [crop_pad(size=size)])
train_tds = DatasetTfm(train_ds, tfms[0], tfm_y=True, size=size, padding_mode='zeros', do_crop=False)
x,y = train_tds[0]
x.show(y=y, classes=train_ds.classes, figsize=(6,4))
y.data
x.size
#export
def bb_pad_collate(samples:BatchSamples, pad_idx:int=0, pad_first:bool=True) -> Tuple[FloatTensor, Tuple[FloatTensor, LongTensor]]:
    "Function that collects `samples` and adds padding."
max_len = max([len(s[1].data[1]) for s in samples])
bboxes = torch.zeros(len(samples), max_len, 4)
labels = torch.zeros(len(samples), max_len).long() + pad_idx
imgs = []
for i,s in enumerate(samples):
imgs.append(s[0].data[None])
bbs, lbls = s[1].data
bboxes[i,-len(lbls):] = bbs
labels[i,-len(lbls):] = lbls
return torch.cat(imgs,0), (bboxes,labels)
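To see the padding behavior on its own, here is a minimal sketch with dummy objects that just expose a `.data` attribute the way the fastai items do (names and values are made up):
from collections import namedtuple
Dummy = namedtuple('Dummy', 'data')
samples = [(Dummy(torch.zeros(3,8,8)), Dummy((torch.rand(2,4)*2-1, torch.tensor([3,5])))),
           (Dummy(torch.zeros(3,8,8)), Dummy((torch.rand(1,4)*2-1, torch.tensor([2]))))]
x, (bbs, lbls) = bb_pad_collate(samples)
x.size(), lbls  # second row of lbls should be [0, 2]: the padding label first, then the real one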
train_dl = DataLoader(train_tds, 64, shuffle=False, collate_fn=bb_pad_collate)
def show_sample(dl, rows, start=0):
x,y = next(iter(dl))
x = x[start:start+rows*rows].cpu()
_,axs = plt.subplots(rows,rows,figsize=(9,9))
for i, ax in enumerate(axs.flatten()):
img = Image(x[i])
idxs = y[1][start+i].nonzero()[:,0]
if len(idxs) != 0:
bbs,lbls = y[0][start+i][idxs],y[1][start+i][idxs]
h,w = img.size
bbs = ((bbs+1) * torch.tensor([h/2,w/2, h/2, w/2])).long()
bbox = ImageBBox.create(bbs, *img.size, lbls)
img.show(ax=ax, y=bbox, classes=dl.dataset.classes)
else: img.show(ax=ax)
plt.tight_layout()
show_sample(train_dl, 3, 18)
train_ds, valid_ds = ObjectDetectDataset.from_json(PATH/'train2017', ANNOT_PATH/'train_sample.json', valid_pct=0.2)
data = DataBunch.create(train_ds, valid_ds, path=PATH, ds_tfms=tfms, tfms=imagenet_norm, collate_fn=bb_pad_collate,
num_workers=8, bs=16, size=128, tfm_y=True, padding_mode='zeros', do_crop=False)
def show_sample(dl, rows, denorm=None):
x,y = next(iter(dl))
x = x[:rows*rows].cpu()
if denorm: x = denorm(x)
_,axs = plt.subplots(rows,rows,figsize=(9,9))
for i, ax in enumerate(axs.flatten()):
img = Image(x[i])
idxs = y[1][i].nonzero()[:,0]
if len(idxs) != 0:
bbs,lbls = y[0][i][idxs],y[1][i][idxs]
h,w = img.size
bbs = ((bbs.cpu()+1) * torch.tensor([h/2,w/2, h/2, w/2])).long()
bbox = ImageBBox.create(bbs, *img.size, lbls)
img.show(ax=ax, y=bbox, classes=dl.dataset.classes)
else: img.show(ax=ax)
plt.tight_layout()
show_sample(data.train_dl, 3, denorm=imagenet_denorm)
#export
def _get_sfs_idxs(sizes:Sizes) -> List[int]:
"Get the indexes of the layers where the size of the activation changes."
feature_szs = [size[-1] for size in sizes]
sfs_idxs = list(np.where(np.array(feature_szs[:-1]) != np.array(feature_szs[1:]))[0])
if feature_szs[0] != feature_szs[1]: sfs_idxs = [0] + sfs_idxs
return sfs_idxs
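As a quick illustration with made-up activation sizes:
tst_szs = [[1,64,64,64], [1,64,64,64], [1,128,32,32], [1,256,16,16], [1,256,16,16], [1,512,8,8]]
_get_sfs_idxs(tst_szs)  # should be [1, 2, 4]: the last layers before each resolution drop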
encoder = create_body(tvm.resnet50(True), -2)
#export
class LateralUpsampleMerge(nn.Module):
def __init__(self, ch, ch_lat, hook):
super().__init__()
self.hook = hook
self.conv_lat = conv2d(ch_lat, ch, ks=1, bias=True)
def forward(self, x):
return self.conv_lat(self.hook.stored) + F.interpolate(x, scale_factor=2)
#export
class RetinaNet(nn.Module):
"Implements RetinaNet from https://arxiv.org/abs/1708.02002"
def __init__(self, encoder:Model, n_classes, final_bias=0., chs=256, n_anchors=9, flatten=True):
super().__init__()
self.n_classes,self.flatten = n_classes,flatten
imsize = (256,256)
sfs_szs,x,hooks = model_sizes(encoder, size=imsize)
sfs_idxs = _get_sfs_idxs(sfs_szs)
self.encoder = encoder
self.c5top5 = conv2d(sfs_szs[-1][1], chs, ks=1, bias=True)
self.c5top6 = conv2d(sfs_szs[-1][1], chs, stride=2, bias=True)
self.p6top7 = nn.Sequential(nn.ReLU(), conv2d(chs, chs, stride=2, bias=True))
self.merges = nn.ModuleList([LateralUpsampleMerge(chs, szs[1], hook)
for szs,hook in zip(sfs_szs[-2:-4:-1], hooks[-2:-4:-1])])
self.smoothers = nn.ModuleList([conv2d(chs, chs, 3, bias=True) for _ in range(3)])
self.classifier = self._head_subnet(n_classes, n_anchors, final_bias, chs=chs)
self.box_regressor = self._head_subnet(4, n_anchors, 0., chs=chs)
def _head_subnet(self, n_classes, n_anchors, final_bias=0., n_conv=4, chs=256):
layers = [conv2d_relu(chs, chs, bias=True) for _ in range(n_conv)]
layers += [conv2d(chs, n_classes * n_anchors, bias=True)]
layers[-1].bias.data.zero_().add_(final_bias)
layers[-1].weight.data.fill_(0)
return nn.Sequential(*layers)
def _apply_transpose(self, func, p_states, n_classes):
if not self.flatten:
sizes = [[p.size(0), p.size(2), p.size(3)] for p in p_states]
return [func(p).permute(0,2,3,1).view(*sz,-1,n_classes) for p,sz in zip(p_states,sizes)]
else:
return torch.cat([func(p).permute(0,2,3,1).contiguous().view(p.size(0),-1,n_classes) for p in p_states],1)
def forward(self, x):
c5 = self.encoder(x)
p_states = [self.c5top5(c5.clone()), self.c5top6(c5)]
p_states.append(self.p6top7(p_states[-1]))
for merge in self.merges: p_states = [merge(p_states[0])] + p_states
for i, smooth in enumerate(self.smoothers[:3]):
p_states[i] = smooth(p_states[i])
return [self._apply_transpose(self.classifier, p_states, self.n_classes),
self._apply_transpose(self.box_regressor, p_states, 4),
[[p.size(2), p.size(3)] for p in p_states]]
encoder = create_body(tvm.resnet50(True), -2)
model = RetinaNet(encoder, 6, -4)
model.eval()
x = torch.randn(2,3,256,256)
output = model(x)
[y.size() for y in output[:2]], output[2]
We need to create the corresponding anchors in this order:
torch.arange(1,17).long().view(4,4)
#export
def create_grid(size):
"Create a grid of a given `size`."
H, W = size if is_tuple(size) else (size,size)
grid = FloatTensor(H, W, 2)
linear_points = torch.linspace(-1+1/W, 1-1/W, W) if W > 1 else tensor([0.])
grid[:, :, 1] = torch.ger(torch.ones(H), linear_points).expand_as(grid[:, :, 0])
linear_points = torch.linspace(-1+1/H, 1-1/H, H) if H > 1 else tensor([0.])
grid[:, :, 0] = torch.ger(linear_points, torch.ones(W)).expand_as(grid[:, :, 1])
return grid.view(-1,2)
The convention is coordinates from (-1.,-1.) to (1.,1.), with y first and x second (as for the bounding boxes); -1 is top/left and 1 is bottom/right.
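As a quick check of this convention, on a 2 by 2 grid the points go row by row from the top-left:
create_grid((2,2))
# should give [[-0.5, -0.5], [-0.5, 0.5], [0.5, -0.5], [0.5, 0.5]]: (y,x) pairs, top row first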
#export
def show_anchors(ancs, size):
_,ax = plt.subplots(1,1, figsize=(5,5))
ax.set_xticks(np.linspace(-1,1, size[1]+1))
ax.set_yticks(np.linspace(-1,1, size[0]+1))
ax.grid()
ax.scatter(ancs[:,1], ancs[:,0]) #y is first
ax.set_yticklabels([])
ax.set_xticklabels([])
ax.set_xlim(-1,1)
ax.set_ylim(1,-1) #-1 is top, 1 is bottom
for i, (x, y) in enumerate(zip(ancs[:, 1], ancs[:, 0])): ax.annotate(i, xy = (x,y))
size = (4,4)
show_anchors(create_grid(size), size)
#export
def create_anchors(sizes, ratios, scales, flatten=True):
"Create anchor of `sizes`, `ratios` and `scales`."
aspects = [[[s*math.sqrt(r), s*math.sqrt(1/r)] for s in scales] for r in ratios]
aspects = torch.tensor(aspects).view(-1,2)
anchors = []
for h,w in sizes:
#4 here to have the anchors overlap.
sized_aspects = 4 * (aspects * torch.tensor([2/h,2/w])).unsqueeze(0)
base_grid = create_grid((h,w)).unsqueeze(1)
n,a = base_grid.size(0),aspects.size(0)
ancs = torch.cat([base_grid.expand(n,a,2), sized_aspects.expand(n,a,2)], 2)
anchors.append(ancs.view(h,w,a,4))
return torch.cat([anc.view(-1,4) for anc in anchors],0) if flatten else anchors
ratios = [1/2,1,2]
#scales = [1,2**(-1/3), 2**(-2/3)]
scales = [1,2**(1/3), 2**(2/3)]
sizes = [(2**i,2**i) for i in range(5)]
sizes.reverse()
anchors = create_anchors(sizes, ratios, scales)
anchors.size()
#[anc.size() for anc in anchors]
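With 3 ratios and 3 scales there are 9 anchors per grid location, and the feature map sizes above give 16²+8²+4²+2²+1² = 341 locations in total, so `anchors.size()` should be torch.Size([3069, 4]).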
import matplotlib.cm as cmx
import matplotlib.colors as mcolors
from cycler import cycler
def get_cmap(N):
color_norm = mcolors.Normalize(vmin=0, vmax=N-1)
return cmx.ScalarMappable(norm=color_norm, cmap='Set3').to_rgba
num_color = 12
cmap = get_cmap(num_color)
color_list = [cmap(float(x)) for x in range(num_color)]
def draw_outline(o, lw):
o.set_path_effects([patheffects.Stroke(
linewidth=lw, foreground='black'), patheffects.Normal()])
def draw_rect(ax, b, color='white'):
patch = ax.add_patch(patches.Rectangle(b[:2], *b[-2:], fill=False, edgecolor=color, lw=2))
draw_outline(patch, 4)
def draw_text(ax, xy, txt, sz=14, color='white'):
text = ax.text(*xy, txt,
verticalalignment='top', color=color, fontsize=sz, weight='bold')
draw_outline(text, 1)
def show_boxes(boxes):
"Show the `boxes` (size by 4)"
_, ax = plt.subplots(1,1, figsize=(5,5))
ax.set_xlim(-1,1)
ax.set_ylim(1,-1)
for i, bbox in enumerate(boxes):
bb = bbox.numpy()
rect = [bb[1]-bb[3]/2, bb[0]-bb[2]/2, bb[3], bb[2]]
draw_rect(ax, rect, color=color_list[i%num_color])
draw_text(ax, [bb[1]-bb[3]/2,bb[0]-bb[2]/2], str(i), color=color_list[i%num_color])
show_boxes(anchors[-9:])
#export
def activ_to_bbox(acts, anchors, flatten=True):
"Extrapolate bounding boxes on anchors from the model activations."
if flatten:
acts.mul_(acts.new_tensor([[0.1, 0.1, 0.2, 0.2]]))
centers = anchors[...,2:] * acts[...,:2] + anchors[...,:2]
        sizes = anchors[...,2:] * torch.exp(acts[...,2:])
return torch.cat([centers, sizes], -1)
    else: return [activ_to_bbox(act,anc) for act,anc in zip(acts, anchors)]
size=(3,4)
anchors = create_grid(size)
anchors = torch.cat([anchors, torch.tensor([2/size[0],2/size[1]]).expand_as(anchors)], 1)
activations = 0.1 * torch.randn(size[0]*size[1], 4)
bboxes = activ_to_bbox(activations, anchors)
show_boxes(bboxes)
#export
def cthw2tlbr(boxes):
"Convert center/size format `boxes` to top/left bottom/right corners."
top_left = boxes[:,:2] - boxes[:,2:]/2
bot_right = boxes[:,:2] + boxes[:,2:]/2
return torch.cat([top_left, bot_right], 1)
#export
def intersection(anchors, targets):
"Compute the sizes of the intersections of `anchors` by `targets`."
ancs, tgts = cthw2tlbr(anchors), cthw2tlbr(targets)
a, t = ancs.size(0), tgts.size(0)
ancs, tgts = ancs.unsqueeze(1).expand(a,t,4), tgts.unsqueeze(0).expand(a,t,4)
top_left_i = torch.max(ancs[...,:2], tgts[...,:2])
bot_right_i = torch.min(ancs[...,2:], tgts[...,2:])
sizes = torch.clamp(bot_right_i - top_left_i, min=0)
return sizes[...,0] * sizes[...,1]
show_boxes(anchors)
targets = torch.tensor([[0.,0.,2.,2.], [-0.5,-0.5,1.,1.], [1/3,0.5,0.5,0.5]])
show_boxes(targets)
intersection(anchors, targets)
#export
def IoU_values(anchors, targets):
"Compute the IoU values of `anchors` by `targets`."
inter = intersection(anchors, targets)
anc_sz, tgt_sz = anchors[:,2] * anchors[:,3], targets[:,2] * targets[:,3]
union = anc_sz.unsqueeze(1) + tgt_sz.unsqueeze(0) - inter
return inter/(union+1e-8)
IoU_values(anchors, targets)
Manually checked that those are right.
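For instance, each anchor here has height 2/3 and width 1/2 (area 1/3) and lies entirely inside the first target, the full image [0.,0.,2.,2.] (area 4), so the intersection is 1/3, the union is 4, and every entry of the first column above should be 1/12 ≈ 0.083.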
#export
def match_anchors(anchors, targets, match_thr=0.5, bkg_thr=0.4):
"Match `anchors` to targets. -1 is match to background, -2 is ignore."
ious = IoU_values(anchors, targets)
matches = anchors.new(anchors.size(0)).zero_().long() - 2
vals,idxs = torch.max(ious,1)
matches[vals < bkg_thr] = -1
matches[vals > match_thr] = idxs[vals > match_thr]
#Overwrite matches with each target getting the anchor that has the max IoU.
#vals,idxs = torch.max(ious,0)
#If idxs contains repetition, this doesn't bug and only the last is considered.
#matches[idxs] = targets.new_tensor(list(range(targets.size(0)))).long()
return matches
Matching on the last example (the grid anchors and the three targets from above):
match_anchors(anchors, targets)
With anchors very close to the targets.
size=(3,4)
anchors = create_grid(size)
anchors = torch.cat([anchors, torch.tensor([2/size[0],2/size[1]]).expand_as(anchors)], 1)
activations = 0.1 * torch.randn(size[0]*size[1], 4)
bboxes = activ_to_bbox(activations, anchors)
match_anchors(anchors,bboxes)
Now with an anchor in the grey area: its best IoU with the targets falls between `bkg_thr` and `match_thr`, so it is ignored (-2).
anchors = create_grid((2,2))
anchors = torch.cat([anchors, torch.tensor([1.,1.]).expand_as(anchors)], 1)
targets = anchors.clone()
anchors = torch.cat([anchors, torch.tensor([[-0.5,0.,1.,1.8]])], 0)
match_anchors(anchors,targets)
#export
def tlbr2cthw(boxes):
"Convert top/left bottom/right format `boxes` to center/size corners."
center = (boxes[:,:2] + boxes[:,2:])/2
sizes = boxes[:,2:] - boxes[:,:2]
return torch.cat([center, sizes], 1)
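A quick sanity check that `cthw2tlbr` and `tlbr2cthw` invert each other:
tst_boxes = torch.tensor([[0.,0.,0.5,0.5], [-0.2,0.1,0.4,0.6]])  # center/size format
torch.allclose(tlbr2cthw(cthw2tlbr(tst_boxes)), tst_boxes)  # should be True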
#export
def bbox_to_activ(bboxes, anchors, flatten=True):
"Return the target of the model on `anchors` for the `bboxes`."
if flatten:
t_centers = (bboxes[...,:2] - anchors[...,:2]) / anchors[...,2:]
t_sizes = torch.log(bboxes[...,2:] / anchors[...,2:] + 1e-8)
return torch.cat([t_centers, t_sizes], -1).div_(bboxes.new_tensor([[0.1, 0.1, 0.2, 0.2]]))
    else: return [bbox_to_activ(bb,anc) for bb,anc in zip(bboxes, anchors)]
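Likewise, `bbox_to_activ` should undo `activ_to_bbox` on the same anchors (the clone is needed because `activ_to_bbox` modifies its input in place):
tst_anchors = torch.cat([create_grid((3,4)), torch.tensor([2/3,0.5]).expand(12,2)], 1)
tst_acts = 0.1 * torch.randn(12,4)
tst_bboxes = activ_to_bbox(tst_acts.clone(), tst_anchors)
torch.allclose(bbox_to_activ(tst_bboxes, tst_anchors), tst_acts, atol=1e-5)  # should be True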
#export
def encode_class(idxs, n_classes):
target = idxs.new_zeros(len(idxs), n_classes).float()
mask = idxs != 0
i1s = LongTensor(list(range(len(idxs))))
target[i1s[mask],idxs[mask]-1] = 1
return target
encode_class(LongTensor([1,2,0,1,3]),3)
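Label 0 (the padding/background index) maps to an all-zero row and label k > 0 to a one-hot at position k-1, so the call above should return:
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 0.],
        [1., 0., 0.],
        [0., 0., 1.]])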
#export
class RetinaNetFocalLoss(nn.Module):
def __init__(self, gamma:float=2., alpha:float=0.25, pad_idx:int=0, scales:Collection[float]=None,
ratios:Collection[float]=None, reg_loss:LossFunction=F.smooth_l1_loss):
super().__init__()
self.gamma,self.alpha,self.pad_idx,self.reg_loss = gamma,alpha,pad_idx,reg_loss
self.scales = ifnone(scales, [1,2**(-1/3), 2**(-2/3)])
self.ratios = ifnone(ratios, [1/2,1,2])
def _change_anchors(self, sizes:Sizes) -> bool:
if not hasattr(self, 'sizes'): return True
for sz1, sz2 in zip(self.sizes, sizes):
if sz1[0] != sz2[0] or sz1[1] != sz2[1]: return True
return False
def _create_anchors(self, sizes:Sizes, device:torch.device):
self.sizes = sizes
self.anchors = create_anchors(sizes, self.ratios, self.scales).to(device)
def _unpad(self, bbox_tgt, clas_tgt):
i = torch.min(torch.nonzero(clas_tgt-self.pad_idx))
return tlbr2cthw(bbox_tgt[i:]), clas_tgt[i:]-1+self.pad_idx
def _focal_loss(self, clas_pred, clas_tgt):
encoded_tgt = encode_class(clas_tgt, clas_pred.size(1))
ps = torch.sigmoid(clas_pred)
weights = encoded_tgt * (1-ps) + (1-encoded_tgt) * ps
alphas = (1-encoded_tgt) * self.alpha + encoded_tgt * (1-self.alpha)
weights.pow_(self.gamma).mul_(alphas)
clas_loss = F.binary_cross_entropy_with_logits(clas_pred, encoded_tgt, weights, reduction='sum')
return clas_loss
def _one_loss(self, clas_pred, bbox_pred, clas_tgt, bbox_tgt):
bbox_tgt, clas_tgt = self._unpad(bbox_tgt, clas_tgt)
matches = match_anchors(self.anchors, bbox_tgt)
bbox_mask = matches>=0
if bbox_mask.sum() != 0:
bbox_pred = bbox_pred[bbox_mask]
bbox_tgt = bbox_tgt[matches[bbox_mask]]
bb_loss = self.reg_loss(bbox_pred, bbox_to_activ(bbox_tgt, self.anchors[bbox_mask]))
else: bb_loss = 0.
matches.add_(1)
clas_tgt = clas_tgt + 1
clas_mask = matches>=0
clas_pred = clas_pred[clas_mask]
clas_tgt = torch.cat([clas_tgt.new_zeros(1).long(), clas_tgt])
clas_tgt = clas_tgt[matches[clas_mask]]
return bb_loss + self._focal_loss(clas_pred, clas_tgt)/torch.clamp(bbox_mask.sum(), min=1.)
def forward(self, output, bbox_tgts, clas_tgts):
clas_preds, bbox_preds, sizes = output
if self._change_anchors(sizes): self._create_anchors(sizes, clas_preds.device)
n_classes = clas_preds.size(2)
return sum([self._one_loss(cp, bp, ct, bt)
for (cp, bp, ct, bt) in zip(clas_preds, bbox_preds, clas_tgts, bbox_tgts)])/clas_tgts.size(0)
An alternative to the plain smooth L1 loss that appears in some online implementations:
#export
class SigmaL1SmoothLoss(nn.Module):
def forward(self, output, target):
reg_diff = torch.abs(target - output)
reg_loss = torch.where(torch.le(reg_diff, 1/9), 4.5 * torch.pow(reg_diff, 2), reg_diff - 1/18)
return reg_loss.mean()
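This corresponds to the smooth L1 loss with sigma = 3: the quadratic branch 4.5·x² and the linear branch |x| − 1/18 both equal 1/18 at the breakpoint |x| = 1/9, so the loss is continuous, and compared to `F.smooth_l1_loss` it switches to the linear regime much earlier (at 1/9 instead of 1).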
Sketch to test the loss
LongTensor([[[0,0,64,128,0], [32,64,128,128,1]], [[128,96,256,192,2], [96,192,128,256,3]]]).float().cuda()
tgt_clas = LongTensor([[1,2], [3,4]])
tgt_bbox = FloatTensor([[[0,0,128,64], [64,32,128,128]], [[96,128,192,256], [192,96,256,128]]])
tgt_bbox = tgt_bbox / 128 - 1.
y = [tgt_bbox.cuda(), tgt_clas.cuda()]
clas = torch.load(PATH/'models'/'tst_clas.pth')
regr = torch.load(PATH/'models'/'tst_regr.pth')
sizes = [[32, 32], [16, 16], [8, 8], [4, 4], [2, 2]]
output = [logit(clas), regr, sizes]
crit(output, *y)
Checking the output
#export
def unpad(tgt_bbox, tgt_clas, pad_idx=0):
i = torch.min(torch.nonzero(tgt_clas-pad_idx))
return tlbr2cthw(tgt_bbox[i:]), tgt_clas[i:]-1+pad_idx
idx = 0
clas_pred,bbox_pred,sizes = output[0][idx].cpu(), output[1][idx].cpu(), output[2]
bbox_tgt, clas_tgt = y[0][idx].cpu(),y[1][idx].cpu()
bbox_tgt, clas_tgt = unpad(bbox_tgt, clas_tgt)
bbox_tgt
anchors = create_anchors(sizes, ratios, scales)
ious = IoU_values(anchors, bbox_tgt)
matches = match_anchors(anchors, bbox_tgt)
ious[-9:]
(matches==-2).sum(), (matches==-1).sum(), (matches>=0).sum()
bbox_mask = matches>=0
bbox_pred = bbox_pred[bbox_mask]
bbox_tgt = bbox_tgt[matches[bbox_mask]]
bb_loss = F.smooth_l1_loss(bbox_pred, bbox_to_activ(bbox_tgt, anchors[bbox_mask]))
F.smooth_l1_loss(bbox_pred, bbox_to_activ(bbox_tgt, anchors[bbox_mask]))
tst_loss = SigmaL1SmoothLoss()
tst_loss(bbox_pred, bbox_to_activ(bbox_tgt, anchors[bbox_mask]))
crit.reg_loss(bbox_pred, bbox_to_activ(bbox_tgt, anchors[bbox_mask]))
matches.add_(1)
clas_tgt += 1
clas_mask = matches>=0
clas_pred = clas_pred[clas_mask]
clas_tgt = torch.cat([clas_tgt.new_zeros(1).long(), clas_tgt])
clas_tgt = clas_tgt[matches[clas_mask]]
Focal loss
alpha, gamma, n_classes = 0.25, 2., 6
encoded_tgt = encode_class(clas_tgt, n_classes)
ps = torch.sigmoid(clas_pred)
weights = encoded_tgt * (1-ps) + (1-encoded_tgt) * ps
alphas = encoded_tgt * alpha + (1-encoded_tgt) * (1-alpha)
weights.pow_(gamma).mul_(alphas)
clas_loss = F.binary_cross_entropy_with_logits(clas_pred, encoded_tgt, weights, reduction='sum') / bbox_mask.sum()
clas_loss
Let's look at the misclassified objects.
clas_pred[clas_tgt.nonzero().squeeze()]
F.binary_cross_entropy_with_logits(clas_pred[clas_tgt.nonzero().squeeze()], encoded_tgt[clas_tgt.nonzero().squeeze()], weights[clas_tgt.nonzero().squeeze()], reduction='sum') / bbox_mask.sum()
They account for half the loss!
n_classes = 6
encoder = create_body(tvm.resnet50(True), -2)
model = RetinaNet(encoder, n_classes,final_bias=-4)
crit = RetinaNetFocalLoss(scales=scales, ratios=ratios)
learn = Learner(data, model, loss_fn=crit)
learn.split([model.encoder[6], model.c5top5])
learn.freeze()
learn.lr_find()
learn.recorder.plot()
learn.fit_one_cycle(1, 1e-4)
learn.save('sample')
learn.load('sample')
img,target = next(iter(data.valid_dl))
with torch.no_grad():
output = model(img)
torch.save(img, PATH/'models'/'tst_input.pth')
def _draw_outline(o:Patch, lw:int):
"Outline bounding box onto image `Patch`."
o.set_path_effects([patheffects.Stroke(
linewidth=lw, foreground='black'), patheffects.Normal()])
def draw_rect(ax:plt.Axes, b:Collection[int], color:str='white', text=None, text_size=14):
"Draw bounding box on `ax`."
patch = ax.add_patch(patches.Rectangle(b[:2], *b[-2:], fill=False, edgecolor=color, lw=2))
_draw_outline(patch, 4)
if text is not None:
patch = ax.text(*b[:2], text, verticalalignment='top', color=color, fontsize=text_size, weight='bold')
_draw_outline(patch,1)
def show_preds(img, output, idx, detect_thresh=0.3, classes=None):
clas_pred,bbox_pred,sizes = output[0][idx].cpu(), output[1][idx].cpu(), output[2]
anchors = create_anchors(sizes, ratios, scales)
bbox_pred = activ_to_bbox(bbox_pred, anchors)
clas_pred = torch.sigmoid(clas_pred)
detect_mask = clas_pred.max(1)[0] > detect_thresh
bbox_pred, clas_pred = bbox_pred[detect_mask], clas_pred[detect_mask]
t_sz = torch.Tensor([*img.size])[None].float()
bbox_pred[:,:2] = bbox_pred[:,:2] - bbox_pred[:,2:]/2
bbox_pred[:,:2] = (bbox_pred[:,:2] + 1) * t_sz/2
bbox_pred[:,2:] = bbox_pred[:,2:] * t_sz
bbox_pred = bbox_pred.long()
    _, ax = plt.subplots(1,1)
    img.show(ax=ax)  # show the image once, even if no box passes the threshold
    for bbox, c in zip(bbox_pred, clas_pred.argmax(1)):
        txt = str(c.item()) if classes is None else classes[c.item()+1]
        draw_rect(ax, [bbox[1],bbox[0],bbox[3],bbox[2]], text=txt)
idx = 0
img = data.valid_ds[idx][0]
classes = data.train_ds.classes
show_preds(img, output, idx, detect_thresh=0.2, classes=classes)
#export
def nms(boxes, scores, thresh=0.5):
idx_sort = scores.argsort(descending=True)
boxes, scores = boxes[idx_sort], scores[idx_sort]
to_keep, indexes = [], torch.LongTensor(range_of(scores))
while len(scores) > 0:
to_keep.append(idx_sort[indexes[0]])
iou_vals = IoU_values(boxes, boxes[:1]).squeeze()
mask_keep = iou_vals <= thresh
if len(mask_keep.nonzero()) == 0: break
boxes, scores, indexes = boxes[mask_keep], scores[mask_keep], indexes[mask_keep]
return LongTensor(to_keep)
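A small sanity check with made-up boxes (center/size format): the second box heavily overlaps the first and the fourth heavily overlaps the third, so only the two highest-scoring non-overlapping boxes should be kept:
tst_boxes = torch.tensor([[0.,0.,1.,1.], [0.05,0.05,1.,1.], [0.6,0.6,0.4,0.4], [0.62,0.62,0.4,0.4]])
tst_scores = torch.tensor([0.9, 0.8, 0.7, 0.6])
nms(tst_boxes, tst_scores)  # should keep indices 0 and 2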
#export
def process_output(output, i, detect_thresh=0.25):
clas_pred,bbox_pred,sizes = output[0][i], output[1][i], output[2]
anchors = create_anchors(sizes, ratios, scales).to(clas_pred.device)
bbox_pred = activ_to_bbox(bbox_pred, anchors)
clas_pred = torch.sigmoid(clas_pred)
detect_mask = clas_pred.max(1)[0] > detect_thresh
bbox_pred, clas_pred = bbox_pred[detect_mask], clas_pred[detect_mask]
bbox_pred = tlbr2cthw(torch.clamp(cthw2tlbr(bbox_pred), min=-1, max=1))
scores, preds = clas_pred.max(1)
return bbox_pred, scores, preds
def show_preds(img, output, idx, detect_thresh=0.25, classes=None):
bbox_pred, scores, preds = process_output(output, idx, detect_thresh)
to_keep = nms(bbox_pred, scores)
bbox_pred, preds, scores = bbox_pred[to_keep].cpu(), preds[to_keep].cpu(), scores[to_keep].cpu()
t_sz = torch.Tensor([*img.size])[None].float()
bbox_pred[:,:2] = bbox_pred[:,:2] - bbox_pred[:,2:]/2
bbox_pred[:,:2] = (bbox_pred[:,:2] + 1) * t_sz/2
bbox_pred[:,2:] = bbox_pred[:,2:] * t_sz
bbox_pred = bbox_pred.long()
    _, ax = plt.subplots(1,1)
    img.show(ax=ax)  # show the image once, even if no box passes the threshold
    for bbox, c, scr in zip(bbox_pred, preds, scores):
        txt = str(c.item()) if classes is None else classes[c.item()+1]
        draw_rect(ax, [bbox[1],bbox[0],bbox[3],bbox[2]], text=f'{txt} {scr:.2f}')
idx = 0
img = data.valid_ds[idx][0]
show_preds(img, output, idx, detect_thresh=0.2, classes=data.classes)
#export
def get_predictions(output, idx, detect_thresh=0.05):
bbox_pred, scores, preds = process_output(output, idx, detect_thresh)
to_keep = nms(bbox_pred, scores)
return bbox_pred[to_keep], preds[to_keep], scores[to_keep]
get_predictions(output, 0)
#export
def compute_ap(precision, recall):
"Compute the average precision for `precision` and `recall` curve."
recall = np.concatenate(([0.], list(recall), [1.]))
precision = np.concatenate(([0.], list(precision), [0.]))
for i in range(len(precision) - 1, 0, -1):
precision[i - 1] = np.maximum(precision[i - 1], precision[i])
idx = np.where(recall[1:] != recall[:-1])[0]
ap = np.sum((recall[idx + 1] - recall[idx]) * precision[idx + 1])
return ap
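A quick check on a toy example: three predictions (sorted by decreasing score) for a class with two ground-truth objects, where the first and third predictions are true positives:
tst_precision = np.array([1., 0.5, 2/3])  # precision after 1, 2 and 3 predictions
tst_recall    = np.array([0.5, 0.5, 1.])  # recall after 1, 2 and 3 predictions
compute_ap(tst_precision, tst_recall)     # should be 0.5 * 1 + 0.5 * 2/3 ≈ 0.83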
#export
def compute_class_AP(model, dl, n_classes, iou_thresh=0.5, detect_thresh=0.05, num_keep=100):
tps, clas, p_scores = [], [], []
classes, n_gts = LongTensor(range(n_classes)),torch.zeros(n_classes).long()
with torch.no_grad():
for input,target in progress_bar(dl):
output = model(input)
for i in range(target[0].size(0)):
bbox_pred, preds, scores = get_predictions(output, i, detect_thresh)
tgt_bbox, tgt_clas = unpad(target[0][i], target[1][i])
ious = IoU_values(bbox_pred, tgt_bbox)
max_iou, matches = ious.max(1)
detected = []
                for j in range_of(preds):  # use `j` to avoid shadowing the outer loop index `i`
                    if max_iou[j] >= iou_thresh and matches[j] not in detected and tgt_clas[matches[j]] == preds[j]:
                        detected.append(matches[j])
                        tps.append(1)
                    else: tps.append(0)
clas.append(preds.cpu())
p_scores.append(scores.cpu())
n_gts += (tgt_clas.cpu()[:,None] == classes[None,:]).sum(0)
tps, p_scores, clas = torch.tensor(tps), torch.cat(p_scores,0), torch.cat(clas,0)
fps = 1-tps
idx = p_scores.argsort(descending=True)
tps, fps, clas = tps[idx], fps[idx], clas[idx]
aps = []
for cls in range(n_classes):
tps_cls, fps_cls = tps[clas==cls].float().cumsum(0), fps[clas==cls].float().cumsum(0)
        if len(tps_cls) != 0 and tps_cls[-1] != 0:
precision = tps_cls / (tps_cls + fps_cls + 1e-8)
recall = tps_cls / (n_gts[cls] + 1e-8)
aps.append(compute_ap(precision, recall))
else: aps.append(0.)
return aps
L = compute_class_AP(learn.model, tst_dl, 6)
L[0]