upload
This commit is contained in:
		
							
								
								
									
										47
									
								
								Scorers/scorer.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								Scorers/scorer.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,47 @@ | ||||
| import torch | ||||
| import numpy as np  | ||||
|  | ||||
| class Jocab_Scorer: | ||||
|     def __init__(self, gpu): | ||||
|         self.gpu = gpu | ||||
|         print('Jacob score init') | ||||
|  | ||||
|     def score(self, model, input, target): | ||||
|         batch_size = input.shape[0] | ||||
|         model.K = torch.zeros(batch_size, batch_size).cuda() | ||||
|  | ||||
|         input = input.cuda() | ||||
|         with torch.no_grad(): | ||||
|             model(input) | ||||
|         score = self.hooklogdet(model.K.cpu().numpy()) | ||||
|  | ||||
|         #print(score) | ||||
|         return score | ||||
|  | ||||
|     def setup_hooks(self, model, batch_size): | ||||
|         #initalize score  | ||||
|         model = model.to(torch.device('cuda', self.gpu)) | ||||
|         model.eval() | ||||
|         model.K = torch.zeros(batch_size, batch_size).cuda() | ||||
|         def counting_forward_hook(module, inp, out): | ||||
|             try: | ||||
|                 # if not module.visited_backwards: | ||||
|                 #     return | ||||
|                 if isinstance(inp, tuple): | ||||
|                     inp = inp[0] | ||||
|                 inp = inp.view(inp.size(0), -1) | ||||
|                 x = (inp > 0).float() | ||||
|                 K = x @ x.t() | ||||
|                 K2 = (1.-x) @ (1.-x.t()) | ||||
|                 model.K = model.K + K + K2 | ||||
|             except: | ||||
|                 pass | ||||
|  | ||||
|         for name, module in model.named_modules(): | ||||
|             if 'ReLU' in str(type(module)): | ||||
|                 module.register_forward_hook(counting_forward_hook) | ||||
|                 #module.register_backward_hook(counting_backward_hook) | ||||
|  | ||||
|     def hooklogdet(self, K, labels=None): | ||||
|         s, ld = np.linalg.slogdet(K) | ||||
|         return ld | ||||
		Reference in New Issue
	
	Block a user