Update Auto-ReID

This commit is contained in:
D-X-Y 2019-10-25 21:28:40 +11:00
parent d28826793d
commit 7f13385f28
2 changed files with 56 additions and 0 deletions

View File

@ -5,6 +5,7 @@ This project contains the following neural architecture search algorithms, imple
- Network Pruning via Transformable Architecture Search, NeurIPS 2019
- One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019
- Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019
- Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification, ICCV 2019
## Requirements and Preparation
@ -90,6 +91,12 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k GDAS_
Searching codes come soon! A small example forward code segment for searching can be found in [this issue](https://github.com/D-X-Y/NAS-Projects/issues/12).
## [Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification](https://arxiv.org/abs/1903.09776)
The part-aware module is defined at [here](https://github.com/D-X-Y/NAS-Projects/blob/master/lib/models/cell_searchs/operations.py#L85).
For more questions, please contact Ruijie Quan (Ruijie.Quan@student.uts.edu.au).
# Citation
If you find that this project helps your research, please consider citing some of the following papers:

View File

@ -111,3 +111,52 @@ class FactorizedReduce(nn.Module):
def extra_repr(self):
return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__)
# Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification, ICCV 2019
class PartAwareOp(nn.Module):
def __init__(self, C_in, C_out, stride, part=4):
super().__init__()
self.part = 4
self.hidden = C_in // 3
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.local_conv_list = nn.ModuleList()
for i in range(self.part):
self.local_conv_list.append(
nn.Sequential(nn.ReLU(), nn.Conv2d(C_in, self.hidden, 1), nn.BatchNorm2d(self.hidden, affine=True))
)
self.W_K = nn.Linear(self.hidden, self.hidden)
self.W_Q = nn.Linear(self.hidden, self.hidden)
if stride == 2 : self.last = FactorizedReduce(C_in + self.hidden, C_out, 2)
elif stride == 1: self.last = FactorizedReduce(C_in + self.hidden, C_out, 1)
else: raise ValueError('Invalid Stride : {:}'.format(stride))
def forward(self, x):
batch, C, H, W = x.size()
assert H >= self.part, 'input size too small : {:} vs {:}'.format(x.shape, self.part)
IHs = [0]
for i in range(self.part): IHs.append( min(H, int((i+1)*(float(H)/self.part))) )
local_feat_list = []
for i in range(self.part):
feature = x[:, :, IHs[i]:IHs[i+1], :]
xfeax = self.avg_pool(feature)
xfea = self.local_conv_list[i]( xfeax )
local_feat_list.append( xfea )
part_feature = torch.cat(local_feat_list, dim=2).view(batch, -1, self.part)
part_feature = part_feature.transpose(1,2).contiguous()
part_K = self.W_K(part_feature)
part_Q = self.W_Q(part_feature).transpose(1,2).contiguous()
weight_att = torch.bmm(part_K, part_Q)
attention = torch.softmax(weight_att, dim=2)
aggreateF = torch.bmm(attention, part_feature).transpose(1,2).contiguous()
features = []
for i in range(self.part):
feature = aggreateF[:, :, i:i+1].expand(batch, self.hidden, IHs[i+1]-IHs[i])
feature = feature.view(batch, self.hidden, IHs[i+1]-IHs[i], 1)
features.append( feature )
features = torch.cat(features, dim=2).expand(batch, self.hidden, H, W)
final_fea = torch.cat((x,features), dim=1)
outputs = self.last( final_fea )
return outputs