import os


class GPUManager:
    """Query NVIDIA GPUs via nvidia-smi and select devices by free memory."""

    # Fields requested from nvidia-smi for every GPU.
    queries = (
        "index",
        "gpu_name",
        "memory.free",
        "memory.used",
        "memory.total",
        "power.draw",
        "power.limit",
    )

    def __init__(self):
        # Probe the GPUs once at construction time; the methods below
        # re-query nvidia-smi on every call.
        self.all_gpus = self.query_gpu(False)

    def get_info(self, ctype):
        # Ask nvidia-smi for one field (e.g. "memory.free") of every GPU;
        # returns one line per device, without the CSV header.
        cmd = "nvidia-smi --query-gpu={} --format=csv,noheader".format(ctype)
        lines = os.popen(cmd).readlines()
        lines = [line.strip("\n") for line in lines]
        return lines

    def query_gpu(self, show=True):
        # Gather every queried field for every GPU into a list of dicts.
        num_gpus = len(self.get_info("index"))
        all_gpus = [{} for _ in range(num_gpus)]
        for query in self.queries:
            infos = self.get_info(query)
            for idx, info in enumerate(infos):
                all_gpus[idx][query] = info

        # If CUDA_VISIBLE_DEVICES is set, keep only the listed devices and
        # renumber their indexes to match the order in that variable.
        if "CUDA_VISIBLE_DEVICES" in os.environ:
            CUDA_VISIBLE_DEVICES = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
            selected_gpus = []
            for idx, CUDA_VISIBLE_DEVICE in enumerate(CUDA_VISIBLE_DEVICES):
                find = False
                for gpu in all_gpus:
                    if gpu["index"] == CUDA_VISIBLE_DEVICE:
                        assert not find, "Duplicate CUDA device index: {}".format(
                            CUDA_VISIBLE_DEVICE
                        )
                        find = True
                        selected_gpus.append(gpu.copy())
                        selected_gpus[-1]["index"] = "{}".format(idx)
                assert find, "Cannot find device: {}".format(CUDA_VISIBLE_DEVICE)
            all_gpus = selected_gpus

        if show:
            # Render one "| field : value | ... |" line per GPU.
            allstrings = ""
            for gpu in all_gpus:
                string = "| "
                for query in self.queries:
                    if query.startswith("memory"):
                        xinfo = "{:>9}".format(gpu[query])
                    else:
                        xinfo = gpu[query]
                    string = string + query + " : " + xinfo + " | "
                allstrings = allstrings + string + "\n"
            return allstrings
        else:
            return all_gpus

    def select_by_memory(self, numbers=1):
        # Return the indexes of the `numbers` GPUs with the most free memory.
        all_gpus = self.query_gpu(False)
        assert numbers <= len(all_gpus), "Require {} GPUs but only {} exist".format(
            numbers, len(all_gpus)
        )
        alls = []
        for gpu in all_gpus:
            # "memory.free" looks like "11019 MiB"; keep only the number.
            free_memory = gpu["memory.free"]
            free_memory = free_memory.split(" ")[0]
            free_memory = int(free_memory)
            index = gpu["index"]
            alls.append((free_memory, index))
        alls.sort(reverse=True)
        alls = [int(alls[i][1]) for i in range(numbers)]
        return sorted(alls)


"""
if __name__ == "__main__":
    manager = GPUManager()
    print(manager.query_gpu(True))
    indexes = manager.select_by_memory(3)
    print(indexes)
"""
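
# A minimal usage sketch (an assumption, not part of the original module):
# pick the GPUs with the most free memory and restrict the process to them.
# CUDA_VISIBLE_DEVICES is assumed to be unset beforehand and must be exported
# before any CUDA-using framework initializes its device context.
#
#   manager = GPUManager()
#   chosen = manager.select_by_memory(2)  # e.g. [0, 3]
#   os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in chosen)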