mirror of
https://github.com/samsledje/D-SCRIPT.git
synced 2026-06-04 15:04:24 +08:00
15 lines
487 B
Python
15 lines
487 B
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
class ProteinMaxPool(nn.Module):
    """Max-pool a tensor over sliding windows along dimension 1.

    Windows of length ``window`` are taken along dim 1 every ``stride``
    positions (via ``Tensor.unfold``) and each window is reduced with ``max``.
    All other dimensions are left unchanged, so an input of shape
    ``(N, H, *rest)`` yields ``(N, (H - window) // stride + 1, *rest)``.

    NOTE(review): dim 1 is presumably the residue/sequence axis of a
    per-protein feature tensor — confirm against callers.

    Args:
        window: Number of consecutive dim-1 positions pooled into one output.
        stride: Step between consecutive windows; defaults to ``window``
            (non-overlapping windows).
    """

    def __init__(self, window, stride=None):
        super().__init__()
        self.window = window
        # Non-overlapping pooling unless a stride is given explicitly.
        self.stride = stride if stride is not None else window

    def forward(self, input):
        """Return the windowed maximum of ``input`` along dim 1.

        Args:
            input: Tensor with at least 2 dimensions; pooling runs over dim 1.

        Returns:
            Tensor of shape ``(N, (H - window) // stride + 1, *rest)``.

        Raises:
            ValueError: if ``input`` has fewer than 2 dimensions, or if its
                dim-1 length is shorter than ``window``.
        """
        if input.dim() < 2:
            raise ValueError(
                "ProteinMaxPool expects an input with at least 2 dims, "
                f"got shape {tuple(input.shape)}"
            )
        length = input.shape[1]
        # Explicit check rather than `assert` so it survives `python -O`.
        # A length equal to the window is valid and yields one output position.
        if length < self.window:
            raise ValueError(
                f"ProteinMaxPool window ({self.window}) is larger than the "
                f"input length along dim 1 ({length})"
            )
        # unfold appends the window as a new trailing dim; reduce over it.
        windows = input.unfold(1, self.window, self.stride)
        maxfold, _ = torch.max(windows, dim=-1)
        return maxfold