Fix bugs in TAS: missing ReLU at the end of each search block
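
In the post-activation ResNet design, the second ReLU comes after the residual addition. The search-mode forward paths below returned the sum from additive_func directly, so the searched blocks skipped that activation even though their basic_forward counterparts apply it. A minimal sketch of the restored pattern (a simplified stand-in, not the TAS classes themselves):

import torch
import torch.nn as nn

class TinyBasicBlock(nn.Module):
  # Simplified stand-in for ResNetBasicblock: two conv-BN stages, an
  # identity shortcut, and the post-addition ReLU this commit restores.
  def __init__(self, planes):
    super().__init__()
    self.conv_a = nn.Sequential(
      nn.Conv2d(planes, planes, 3, padding=1, bias=False),
      nn.BatchNorm2d(planes), nn.ReLU(inplace=True))
    self.conv_b = nn.Sequential(
      nn.Conv2d(planes, planes, 3, padding=1, bias=False),
      nn.BatchNorm2d(planes))

  def forward(self, inputs):
    out = self.conv_b(self.conv_a(inputs))
    out = inputs + out  # plays the role of additive_func(residual, out_b)
    return nn.functional.relu(out, inplace=True)  # the previously missing activation

inplace=True reuses the buffer produced by the addition instead of allocating a new tensor, which is safe here because the pre-activation sum is not used again.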
@@ -172,7 +172,7 @@ class ResNetBasicblock(nn.Module):
     else:
       residual, expected_flop_c = inputs, 0
     out = additive_func(residual, out_b)
-    return out, expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
+    return nn.functional.relu(out, inplace=True), expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
 
   def basic_forward(self, inputs):
     basicblock = self.conv_a(inputs)
@@ -244,8 +244,7 @@ class ResNetBottleneck(nn.Module):
     else:
       residual, expected_flop_c = inputs, 0
     out = additive_func(residual, out_1x4)
-    return out, expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
-
+    return nn.functional.relu(out, inplace=True), expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
 
 
 class SearchShapeCifarResNet(nn.Module):

@@ -156,7 +156,7 @@ class ResNetBasicblock(nn.Module):
     else:
       residual, expected_flop_c = inputs, 0
     out = additive_func(residual, out_b)
-    return out, expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
+    return nn.functional.relu(out, inplace=True), expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
 
   def basic_forward(self, inputs):
     basicblock = self.conv_a(inputs)
@@ -228,8 +228,7 @@ class ResNetBottleneck(nn.Module):
     else:
       residual, expected_flop_c = inputs, 0
     out = additive_func(residual, out_1x4)
-    return out, expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
-
+    return nn.functional.relu(out, inplace=True), expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
 
 
 class SearchWidthCifarResNet(nn.Module):

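Each search-mode forward above returns a triple: the (now activated) feature map, the expected input channel count for the next block, and a summed FLOP estimate. A sketch of how those estimates can be accumulated across a network (illustrative only; the real SearchWidthCifarResNet wiring passes extra search arguments into each block):

def accumulate_expected_flops(blocks, x):
  # Assumes each block returns the (feature, expected_inC, expected_flop)
  # triple seen in the hunks above; expected_inC is unused in this sketch.
  flops = []
  for block in blocks:
    x, expected_inC, expected_flop = block(x)
    flops.append(expected_flop)
  return x, sum(flops)
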
@@ -171,7 +171,7 @@ class ResNetBasicblock(nn.Module):
     else:
       residual, expected_flop_c = inputs, 0
     out = additive_func(residual, out_b)
-    return out, expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
+    return nn.functional.relu(out, inplace=True), expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
 
   def basic_forward(self, inputs):
     basicblock = self.conv_a(inputs)
@@ -243,8 +243,7 @@ class ResNetBottleneck(nn.Module):
     else:
       residual, expected_flop_c = inputs, 0
     out = additive_func(residual, out_1x4)
-    return out, expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
-
+    return nn.functional.relu(out, inplace=True), expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
 
 
 class SearchShapeImagenetResNet(nn.Module):

@@ -153,7 +153,7 @@ class SimBlock(nn.Module):
     else:
       residual, expected_flop_c = inputs, 0
     out = additive_func(residual, out)
-    return out, expected_next_inC, sum([expected_flop, expected_flop_c])
+    return nn.functional.relu(out, inplace=True), expected_next_inC, sum([expected_flop, expected_flop_c])
 
   def basic_forward(self, inputs):
     basicblock = self.conv(inputs)

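A quick way to confirm the trailing ReLU is in effect, reusing the TinyBasicBlock sketch above: every block output must now be non-negative.

block = TinyBasicBlock(8).eval()
with torch.no_grad():
  y = block(torch.randn(2, 8, 16, 16))
assert (y >= 0).all(), "post-addition ReLU should clamp negative values"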