Files
foldingdiff/protdiff/utils.py
2022-07-12 18:57:26 +00:00

40 lines
796 B
Python

"""
Misc shared utility functions
"""
from typing import *
import torch
def extract(a, t, x_shape):
    """
    Gather the entries of ``a`` selected by ``t`` and reshape for broadcasting.

    The values are gathered along the last dimension of ``a``; for a 1-D
    ``a`` and 1-D ``t`` this is ``a[t[i]]`` for each ``i``.

    Args:
        a: tensor of values to look up from.
        t: index tensor handed to ``gather``; its first dimension is taken
            as the batch size.
        x_shape: shape of the tensor the result is meant to broadcast
            against; only its length (number of dimensions) is used.

    Returns:
        Tensor of shape ``(batch_size, 1, ..., 1)`` with ``len(x_shape)``
        dimensions, placed on the same device as ``t``.
    """
    batch_size = t.shape[0]
    # gather requires the index to live on the same device as ``a``. Follow
    # ``a.device`` instead of forcing CPU so this also works when ``a`` has
    # been moved to an accelerator; for a CPU ``a`` this is unchanged.
    out = a.gather(-1, t.to(a.device))
    # One leading batch axis plus singleton axes for every remaining
    # dimension of x_shape, then move back to the device of the indices.
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
def num_to_groups(num: int, divisor: int) -> List[int]:
    """
    Split num into a list of group sizes that add up to num.

    For a non-negative num and a positive divisor, the list holds as many
    full groups of size divisor as fit, followed by one smaller trailing
    group for whatever is left over (omitted when nothing is left).

    >>> num_to_groups(18, 16)
    [16, 2]
    >>> num_to_groups(33, 8)
    [8, 8, 8, 8, 1]
    """
    full_groups, leftover = divmod(num, divisor)
    # Only add the trailing partial group when something is left over.
    sizes = [divisor] * full_groups + ([leftover] if leftover > 0 else [])
    # Sanity check: the groups must account for every item.
    assert sum(sizes) == num
    return sizes
if __name__ == "__main__":
    # When executed directly as a script, run the doctest examples embedded
    # in this module's docstrings.
    from doctest import testmod

    testmod()