MeCo/correlation/models/SharedUtils.py
HamsterMimi 3f6d16e791 update
2024-01-23 10:08:45 +08:00

38 lines
905 B
Python

#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
#####################################################
import torch
import torch.nn as nn
def additive_func(A, B):
    """Add two tensors that may differ in size along dimension 1.

    Both tensors must have the same number of dimensions and the same size
    along dimension 0 (checked with an ``assert``). Dimension 1 is presumably
    the channel axis -- confirm against callers.

    * If both sizes along dimension 1 match, ``A + B`` is returned.
    * Otherwise the wider tensor is cloned and the narrower tensor is added
      into the leading slice ``[:, :narrow_width]`` of that clone; the
      remaining trailing slice of the wider tensor is passed through
      unchanged. Neither input is modified.

    Args:
        A: first tensor.
        B: second tensor.

    Returns:
        A new tensor with the larger of the two sizes along dimension 1.

    Raises:
        AssertionError: if the ranks or the dimension-0 sizes differ (only
            when assertions are enabled).
    """
    # Single short-circuiting expression: size(0) is only queried when the
    # ranks already agree.
    assert A.dim() == B.dim() and A.size(0) == B.size(0), "{:} vs {:}".format(
        A.size(), B.size()
    )
    width_a, width_b = A.size(1), B.size(1)
    if width_a == width_b:
        return A + B
    # Pick which operand supplies the full-width base and which is folded in.
    wide, narrow = (A, B) if width_a > width_b else (B, A)
    result = wide.clone()  # clone so the caller's tensor is left untouched
    result[:, : narrow.size(1)] += narrow
    return result
def change_key(key, value):
    """Build a callback that overwrites attribute ``key`` with ``value``.

    The returned callable takes one object. It assigns ``value`` to attribute
    ``key`` only if the object already has that attribute (per ``hasattr``),
    so no new attributes are ever created. It returns ``None``.
    Presumably intended for ``torch.nn.Module.apply`` -- confirm against callers.

    Args:
        key: attribute name to overwrite.
        value: value to assign.

    Returns:
        A one-argument callable.
    """

    def _maybe_set(m):
        # Skip objects that do not already define the attribute.
        if not hasattr(m, key):
            return None
        setattr(m, key, value)
        return None

    return _maybe_set
def parse_channel_info(xstring):
    """Parse a channel-specification string into nested lists of ints.

    The string is a whitespace-separated sequence of groups. Each group is a
    run of integers joined by ``-``; e.g. ``"3-16 16-32"`` becomes
    ``[[3, 16], [16, 32]]``.

    Args:
        xstring: the specification string.

    Returns:
        A list containing one list of ints per group. An empty or
        all-whitespace string yields ``[]``.

    Raises:
        ValueError: if a group contains a token that is not an integer,
            including the empty token produced by ``"--"`` or by a leading
            or trailing ``-``.
    """
    # split() with no argument collapses runs of whitespace and drops
    # leading/trailing whitespace. split(" ") produced empty groups there,
    # which then failed in int("").
    return [
        [int(token) for token in group.split("-")]
        for group in xstring.split()
    ]