From 7f13385f28bc2e60a67ae935eba3172a0174b67a Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Fri, 25 Oct 2019 21:28:40 +1100 Subject: [PATCH] Update Auto-ReID --- README.md | 7 ++++ lib/models/cell_searchs/operations.py | 49 +++++++++++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/README.md b/README.md index b6e1606..6de354a 100644 --- a/README.md +++ b/README.md @@ -5,6 +5,7 @@ This project contains the following neural architecture search algorithms, imple - Network Pruning via Transformable Architecture Search, NeurIPS 2019 - One-Shot Neural Architecture Search via Self-Evaluated Template Network, ICCV 2019 - Searching for A Robust Neural Architecture in Four GPU Hours, CVPR 2019 +- Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification, ICCV 2019 ## Requirements and Preparation @@ -90,6 +91,12 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./scripts/nas-infer-train.sh imagenet-1k GDAS_ Searching codes come soon! A small example forward code segment for searching can be found in [this issue](https://github.com/D-X-Y/NAS-Projects/issues/12). +## [Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification](https://arxiv.org/abs/1903.09776) + +The part-aware module is defined at [here](https://github.com/D-X-Y/NAS-Projects/blob/master/lib/models/cell_searchs/operations.py#L85). + +For more questions, please contact Ruijie Quan (Ruijie.Quan@student.uts.edu.au). + # Citation If you find that this project helps your research, please consider citing some of the following papers: diff --git a/lib/models/cell_searchs/operations.py b/lib/models/cell_searchs/operations.py index 85c5253..76a6d06 100644 --- a/lib/models/cell_searchs/operations.py +++ b/lib/models/cell_searchs/operations.py @@ -111,3 +111,52 @@ class FactorizedReduce(nn.Module): def extra_repr(self): return 'C_in={C_in}, C_out={C_out}, stride={stride}'.format(**self.__dict__) + + +# Auto-ReID: Searching for a Part-Aware ConvNet for Person Re-Identification, ICCV 2019 +class PartAwareOp(nn.Module): + + def __init__(self, C_in, C_out, stride, part=4): + super().__init__() + self.part = 4 + self.hidden = C_in // 3 + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.local_conv_list = nn.ModuleList() + for i in range(self.part): + self.local_conv_list.append( + nn.Sequential(nn.ReLU(), nn.Conv2d(C_in, self.hidden, 1), nn.BatchNorm2d(self.hidden, affine=True)) + ) + self.W_K = nn.Linear(self.hidden, self.hidden) + self.W_Q = nn.Linear(self.hidden, self.hidden) + + if stride == 2 : self.last = FactorizedReduce(C_in + self.hidden, C_out, 2) + elif stride == 1: self.last = FactorizedReduce(C_in + self.hidden, C_out, 1) + else: raise ValueError('Invalid Stride : {:}'.format(stride)) + + def forward(self, x): + batch, C, H, W = x.size() + assert H >= self.part, 'input size too small : {:} vs {:}'.format(x.shape, self.part) + IHs = [0] + for i in range(self.part): IHs.append( min(H, int((i+1)*(float(H)/self.part))) ) + local_feat_list = [] + for i in range(self.part): + feature = x[:, :, IHs[i]:IHs[i+1], :] + xfeax = self.avg_pool(feature) + xfea = self.local_conv_list[i]( xfeax ) + local_feat_list.append( xfea ) + part_feature = torch.cat(local_feat_list, dim=2).view(batch, -1, self.part) + part_feature = part_feature.transpose(1,2).contiguous() + part_K = self.W_K(part_feature) + part_Q = self.W_Q(part_feature).transpose(1,2).contiguous() + weight_att = torch.bmm(part_K, part_Q) + attention = torch.softmax(weight_att, dim=2) + aggreateF = torch.bmm(attention, part_feature).transpose(1,2).contiguous() + features = [] + for i in range(self.part): + feature = aggreateF[:, :, i:i+1].expand(batch, self.hidden, IHs[i+1]-IHs[i]) + feature = feature.view(batch, self.hidden, IHs[i+1]-IHs[i], 1) + features.append( feature ) + features = torch.cat(features, dim=2).expand(batch, self.hidden, H, W) + final_fea = torch.cat((x,features), dim=1) + outputs = self.last( final_fea ) + return outputs