Refine lib -> xautodl
This commit is contained in:
		
							
								
								
									
										5
									
								
								.github/workflows/super_model_test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								.github/workflows/super_model_test.yml
									
									
									
									
										vendored
									
									
								
							| @@ -35,3 +35,8 @@ jobs: | ||||
|           python -m pip install torch torchvision | ||||
|           python -m pytest ./tests/test_super_*.py | ||||
|         shell: bash | ||||
|  | ||||
|       - name: Test TAS (NeurIPS 2019) | ||||
|         run: | | ||||
|           python -m pytest ./tests/test_tas.py | ||||
|         shell: bash | ||||
|   | ||||
| @@ -25,7 +25,7 @@ def main(args): | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.benchmark = True | ||||
|     # torch.backends.cudnn.deterministic = True | ||||
|     torch.set_num_threads(args.workers) | ||||
|     # torch.set_num_threads(args.workers) | ||||
|  | ||||
|     prepare_seed(args.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|   | ||||
| @@ -470,7 +470,7 @@ if __name__ == "__main__": | ||||
|     assert torch.cuda.is_available(), "CUDA is not available." | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.deterministic = True | ||||
|     torch.set_num_threads(args.workers) | ||||
|     # torch.set_num_threads(args.workers) | ||||
|  | ||||
|     main( | ||||
|         save_dir, | ||||
|   | ||||
| @@ -340,7 +340,7 @@ def train_single_model( | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.deterministic = True | ||||
|     # torch.backends.cudnn.benchmark = True | ||||
|     torch.set_num_threads(workers) | ||||
|     # torch.set_num_threads(workers) | ||||
|  | ||||
|     save_dir = ( | ||||
|         Path(save_dir) | ||||
| @@ -675,7 +675,7 @@ if __name__ == "__main__": | ||||
|         assert torch.cuda.is_available(), "CUDA is not available." | ||||
|         torch.backends.cudnn.enabled = True | ||||
|         torch.backends.cudnn.deterministic = True | ||||
|         torch.set_num_threads(args.workers if args.workers > 0 else 1) | ||||
|         # torch.set_num_threads(args.workers if args.workers > 0 else 1) | ||||
|  | ||||
|         main( | ||||
|             save_dir, | ||||
|   | ||||
| @@ -132,7 +132,7 @@ def select_action(policy): | ||||
|  | ||||
|  | ||||
| def main(xargs, api): | ||||
|     torch.set_num_threads(4) | ||||
|     # torch.set_num_threads(4) | ||||
|     prepare_seed(xargs.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|   | ||||
| @@ -204,7 +204,7 @@ def main(xargs): | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.benchmark = False | ||||
|     torch.backends.cudnn.deterministic = True | ||||
|     torch.set_num_threads(xargs.workers) | ||||
|     # torch.set_num_threads(xargs.workers) | ||||
|     prepare_seed(xargs.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|  | ||||
|   | ||||
| @@ -8,17 +8,14 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / "lib").resolve() | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import load_config, obtain_basic_args as obtain_args | ||||
| from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | ||||
| from procedures import get_optim_scheduler, get_procedures | ||||
| from datasets import get_datasets | ||||
| from models import obtain_model | ||||
| from nas_infer_model import obtain_nas_infer_model | ||||
| from utils import get_model_infos | ||||
| from log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.datasets import get_datasets | ||||
| from xautodl.config_utils import load_config, obtain_basic_args as obtain_args | ||||
| from xautodl.procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | ||||
| from xautodl.procedures import get_optim_scheduler, get_procedures | ||||
| from xautodl.models import obtain_model | ||||
| from xautodl.nas_infer_model import obtain_nas_infer_model | ||||
| from xautodl.utils import get_model_infos | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
| @@ -26,7 +23,7 @@ def main(args): | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.benchmark = True | ||||
|     # torch.backends.cudnn.deterministic = True | ||||
|     torch.set_num_threads(args.workers) | ||||
|     # torch.set_num_threads(args.workers) | ||||
|  | ||||
|     prepare_seed(args.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|   | ||||
| @@ -10,21 +10,17 @@ import numpy as np | ||||
| from copy import deepcopy | ||||
| from pathlib import Path | ||||
|  | ||||
| lib_dir = (Path(__file__).parent / ".." / "lib").resolve() | ||||
| print("lib_dir : {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
| from config_utils import ( | ||||
| from xautodl.config_utils import ( | ||||
|     load_config, | ||||
|     configure2str, | ||||
|     obtain_search_single_args as obtain_args, | ||||
| ) | ||||
| from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | ||||
| from procedures import get_optim_scheduler, get_procedures | ||||
| from datasets import get_datasets, SearchDataset | ||||
| from models import obtain_search_model, obtain_model, change_key | ||||
| from utils import get_model_infos | ||||
| from log_utils import AverageMeter, time_string, convert_secs2time | ||||
| from xautodl.procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint | ||||
| from xautodl.procedures import get_optim_scheduler, get_procedures | ||||
| from xautodl.datasets import get_datasets, SearchDataset | ||||
| from xautodl.models import obtain_search_model, obtain_model, change_key | ||||
| from xautodl.utils import get_model_infos | ||||
| from xautodl.log_utils import AverageMeter, time_string, convert_secs2time | ||||
|  | ||||
|  | ||||
| def main(args): | ||||
| @@ -32,7 +28,7 @@ def main(args): | ||||
|     torch.backends.cudnn.enabled = True | ||||
|     torch.backends.cudnn.benchmark = True | ||||
|     # torch.backends.cudnn.deterministic = True | ||||
|     torch.set_num_threads(args.workers) | ||||
|     # torch.set_num_threads(args.workers) | ||||
|  | ||||
|     prepare_seed(args.rand_seed) | ||||
|     logger = prepare_logger(args) | ||||
|   | ||||
							
								
								
									
										29
									
								
								tests/test_loader.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								tests/test_loader.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| # pytest tests/test_loader.py -s                    # | ||||
| ##################################################### | ||||
| import unittest | ||||
| import tempfile | ||||
| import torch | ||||
|  | ||||
| from xautodl.datasets import get_datasets | ||||
|  | ||||
|  | ||||
| def test_simple(): | ||||
|     xdir = tempfile.mkdtemp() | ||||
|     train_data, valid_data, xshape, class_num = get_datasets("cifar10", xdir, -1) | ||||
|     print(train_data) | ||||
|     print(valid_data) | ||||
|  | ||||
|     xloader = torch.utils.data.DataLoader( | ||||
|         train_data, batch_size=256, shuffle=True, num_workers=4, pin_memory=True | ||||
|     ) | ||||
|     print(xloader) | ||||
|     print(next(iter(xloader))) | ||||
|  | ||||
|     for i, data in enumerate(xloader): | ||||
|         print(i) | ||||
|  | ||||
|  | ||||
| test_simple() | ||||
							
								
								
									
										23
									
								
								tests/test_tas.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								tests/test_tas.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,23 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| from xautodl.models.shape_searchs.SoftSelect import ChannelWiseInter | ||||
|  | ||||
|  | ||||
| class TestTASFunc(unittest.TestCase): | ||||
|     """Test the TAS function.""" | ||||
|  | ||||
|     def test_channel_interplation(self): | ||||
|         tensors = torch.rand((16, 128, 7, 7)) | ||||
|  | ||||
|         for oc in range(200, 210): | ||||
|             out_v1 = ChannelWiseInter(tensors, oc, "v1") | ||||
|             out_v2 = ChannelWiseInter(tensors, oc, "v2") | ||||
|             assert (out_v1 == out_v2).any().item() == 1 | ||||
|         for oc in range(48, 160): | ||||
|             out_v1 = ChannelWiseInter(tensors, oc, "v1") | ||||
|             out_v2 = ChannelWiseInter(tensors, oc, "v2") | ||||
|             assert (out_v1 == out_v2).any().item() == 1 | ||||
| @@ -9,9 +9,9 @@ from typing import List, Text, Any | ||||
| import random, torch | ||||
| import torch.nn as nn | ||||
|  | ||||
| from models.cell_operations import ResNetBasicblock | ||||
| from models.cell_infers.cells import InferCell | ||||
| from models.shape_searchs.SoftSelect import select2withP, ChannelWiseInter | ||||
| from ..cell_operations import ResNetBasicblock | ||||
| from ..cell_infers.cells import InferCell | ||||
| from .shape_searchs.SoftSelect import select2withP, ChannelWiseInter | ||||
|  | ||||
|  | ||||
| class GenericNAS301Model(nn.Module): | ||||
|   | ||||
| @@ -1,20 +0,0 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from SoftSelect import ChannelWiseInter | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|  | ||||
|     tensors = torch.rand((16, 128, 7, 7)) | ||||
|  | ||||
|     for oc in range(200, 210): | ||||
|         out_v1 = ChannelWiseInter(tensors, oc, "v1") | ||||
|         out_v2 = ChannelWiseInter(tensors, oc, "v2") | ||||
|         assert (out_v1 == out_v2).any().item() == 1 | ||||
|     for oc in range(48, 160): | ||||
|         out_v1 = ChannelWiseInter(tensors, oc, "v1") | ||||
|         out_v2 = ChannelWiseInter(tensors, oc, "v2") | ||||
|         assert (out_v1 == out_v2).any().item() == 1 | ||||
		Reference in New Issue
	
	Block a user