autodl-projects/xautodl/models/shape_searchs/test.py

21 lines
678 B
Python
Raw Normal View History

2019-11-15 07:15:07 +01:00
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
2019-09-28 10:24:47 +02:00
import torch
import torch.nn as nn
from SoftSelect import ChannelWiseInter
2021-05-12 10:28:05 +02:00
def assert_modes_agree(tensors, out_channels, inter_fn=None):
    """Check that the "v1" and "v2" modes of a channel interpolation agree.

    For every target channel count ``oc`` in ``out_channels``, ``inter_fn`` is
    called as ``inter_fn(tensors, oc, "v1")`` and ``inter_fn(tensors, oc, "v2")``.
    Both outputs must have the same shape, ``oc`` entries along dim 1, and
    element-wise values equal up to ``torch.allclose`` default tolerances.

    Args:
        tensors: input passed unchanged to ``inter_fn`` (the script below uses
            a 4-D tensor with 128 entries along dim 1).
        out_channels: iterable of target channel counts to check.
        inter_fn: function under test; defaults to ``ChannelWiseInter``.
            Injectable so the check can be exercised without ``SoftSelect``.

    Raises:
        AssertionError: if the two modes disagree for any ``oc``.
    """
    if inter_fn is None:
        inter_fn = ChannelWiseInter
    for oc in out_channels:
        out_v1 = inter_fn(tensors, oc, "v1")
        out_v2 = inter_fn(tensors, oc, "v2")
        # torch.allclose broadcasts its arguments, so a shape mismatch could
        # otherwise slip through; compare shapes explicitly first.
        assert out_v1.shape == out_v2.shape, "oc={:}: v1 shape {:} != v2 shape {:}".format(
            oc, tuple(out_v1.shape), tuple(out_v2.shape)
        )
        assert out_v1.size(1) == oc, "oc={:}: got {:} channels".format(
            oc, out_v1.size(1)
        )
        # The modes are presumably separate implementations, so require
        # closeness rather than bit-exact equality. (The previous `.any()`
        # check passed as soon as a single element matched.)
        assert torch.allclose(out_v1, out_v2), "oc={:}: v1 and v2 differ (max abs diff {:.3g})".format(
            oc, (out_v1 - out_v2).abs().max().item()
        )


if __name__ == "__main__":
    tensors = torch.rand((16, 128, 7, 7))
    # Target channel counts above the input's 128.
    assert_modes_agree(tensors, range(200, 210))
    # Target channel counts below, equal to, and above 128.
    assert_modes_agree(tensors, range(48, 160))