autodl-projects/xautodl/procedures/advanced_main.py

101 lines
2.2 KiB
Python
Raw Normal View History

2021-04-25 15:02:43 +02:00
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
#####################################################
2021-04-26 15:44:03 +02:00
# To be finished.
#
2021-04-25 15:02:43 +02:00
import os, sys, time, torch
2021-04-26 15:44:03 +02:00
from typing import Optional, Text, Callable
2021-04-25 15:02:43 +02:00
# modules in AutoDL
from log_utils import AverageMeter
from log_utils import time_string
from .eval_funcs import obtain_accuracy
2021-04-29 10:30:47 +02:00
def get_device(tensors):
    """Return the ``.device`` of the first leaf in a (possibly nested) container.

    Lists and tuples are searched through their first element, and dicts through
    their first value (insertion order). Any other object is treated as a leaf and
    must expose a ``.device`` attribute (e.g. a ``torch.Tensor``).

    Raises:
        ValueError: if a list, tuple, or dict on the search path is empty, since no
            device can be inferred from it.
    """
    if isinstance(tensors, (list, tuple)):
        if not tensors:
            raise ValueError("Cannot infer a device from an empty list/tuple.")
        return get_device(tensors[0])
    if isinstance(tensors, dict):
        if not tensors:
            # Previously this fell through and returned None implicitly.
            raise ValueError("Cannot infer a device from an empty dict.")
        return get_device(next(iter(tensors.values())))
    return tensors.device
def basic_train_fn(
    xloader,
    network,
    criterion,
    optimizer,
    metric,
    logger,
):
    """Run one training pass over ``xloader`` and return the metric summary.

    Thin wrapper around ``procedure`` in ``"train"`` mode. ``logger`` is forwarded
    as ``procedure``'s ``logger_fn``, and ``procedure``'s return value is handed
    back unchanged.
    """
    return procedure(
        xloader,
        network,
        criterion,
        optimizer,
        metric,
        mode="train",
        logger_fn=logger,
    )
2021-04-25 15:02:43 +02:00
2021-04-29 10:30:47 +02:00
def basic_eval_fn(xloader, network, metric, logger):
    """Evaluate ``network`` on ``xloader`` with autograd disabled.

    Delegates to ``procedure`` in ``"valid"`` mode. No criterion or optimizer is
    passed, because evaluation never computes a loss or updates weights. Returns
    whatever ``procedure`` returns.
    """
    no_criterion, no_optimizer = None, None
    with torch.no_grad():
        summary = procedure(
            xloader,
            network,
            no_criterion,
            no_optimizer,
            metric,
            mode="valid",
            logger_fn=logger,
        )
    return summary
2021-04-25 15:02:43 +02:00
def procedure(
    xloader,
    network,
    criterion,
    optimizer,
    metric,
    mode: Text,
    logger_fn: Optional[Callable] = None,
):
    """Run one pass of ``network`` over ``xloader`` and return ``metric.get_info()``.

    Args:
        xloader: iterable yielding ``(inputs, targets)`` pairs.
        network: module called as ``network(inputs)``; switched to train/eval
            mode here.
        criterion: loss callable ``criterion(outputs, targets)``; used only in
            train mode.
        optimizer: optimizer zeroed and stepped once per batch; used only in
            train mode.
        metric: callable ``metric(outputs, targets)`` invoked once per batch,
            exposing ``get_info()`` whose result is returned.
        mode: ``"train"`` or ``"valid"``, matched case-insensitively.
        logger_fn: currently unused; kept for interface compatibility.

    Returns:
        The value of ``metric.get_info()`` after the pass.

    Raises:
        ValueError: if ``mode`` is neither ``"train"`` nor ``"valid"``.
    """
    # Timing meters are collected but not yet reported anywhere.
    data_time, batch_time = AverageMeter(), AverageMeter()
    # Normalize once so validation and branching agree. Previously "Train"
    # passed validation but skipped the backward/step logic.
    normalized_mode = mode.lower()
    if normalized_mode == "train":
        network.train()
    elif normalized_mode == "valid":
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))
    is_train = normalized_mode == "train"
    # Never build an autograd graph while validating. In train mode, keep the
    # caller's current grad state.
    grad_enabled = is_train and torch.is_grad_enabled()
    end = time.time()
    with torch.set_grad_enabled(grad_enabled):
        for inputs, targets in xloader:
            # measure data loading time
            data_time.update(time.time() - end)
            # calculate prediction and loss
            if is_train:
                optimizer.zero_grad()
            outputs = network(inputs)
            # Targets follow the outputs' device, which may differ from the loader's.
            targets = targets.to(get_device(outputs))
            if is_train:
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
            # record: the metric accumulates its own state; its per-batch
            # return value is not needed here.
            with torch.no_grad():
                metric(outputs, targets)
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
    return metric.get_info()