mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-05 14:05:01 +08:00
40 lines
796 B
Python
40 lines
796 B
Python
"""
|
|
Misc shared utility functions
|
|
"""
|
|
from typing import *
|
|
|
|
import torch
|
|
|
|
|
|
def extract(a, t, x_shape):
    """
    Return the t-th item in a for each item in t, reshaped for broadcasting.

    Values are gathered from ``a`` along its last dimension at the indices in
    ``t``. The result is reshaped to ``(batch_size, 1, ..., 1)``, giving it
    ``len(x_shape)`` dimensions so it broadcasts against a tensor of shape
    ``x_shape``.

    :param a: tensor of values indexed along its last dimension (presumably a
        1-D per-timestep schedule — TODO confirm against callers).
    :param t: integer index tensor whose first dimension is the batch size.
    :param x_shape: shape of the tensor the result should broadcast against.
    :return: gathered values of shape ``(batch_size, 1, ..., 1)``, placed on
        ``t``'s device.
    """
    batch_size = t.shape[0]
    # torch.gather requires the index to be on the same device as the source
    # tensor. Moving t to a's device, rather than unconditionally to CPU,
    # keeps this working when a has been moved to an accelerator.
    out = a.gather(-1, t.to(a.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
|
|
|
|
|
|
def num_to_groups(num: int, divisor: int) -> List[int]:
    """
    Generates a list of ints of value at most divisor that sums to num.

    The list holds as many full groups of size ``divisor`` as fit in ``num``.
    When ``num`` is not an exact multiple of ``divisor``, a single remainder
    group follows them. ``num == 0`` yields an empty list.

    :param num: total to split; must be non-negative.
    :param divisor: maximum group size; must be positive.
    :raises ValueError: if ``num`` is negative or ``divisor`` is not positive.

    >>> num_to_groups(18, 16)
    [16, 2]
    >>> num_to_groups(33, 8)
    [8, 8, 8, 8, 1]
    >>> num_to_groups(32, 8)
    [8, 8, 8, 8]
    """
    # Validate explicitly. Otherwise a negative num fails the sanity assert
    # below with an opaque AssertionError, and divisor == 0 raises
    # ZeroDivisionError.
    if divisor <= 0:
        raise ValueError(f"divisor must be positive, got {divisor}")
    if num < 0:
        raise ValueError(f"num must be non-negative, got {num}")
    groups, remainder = divmod(num, divisor)
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    # Internal invariant: the groups always partition num exactly.
    assert sum(arr) == num
    return arr
|
|
|
|
|
|
if __name__ == "__main__":
    import doctest

    # Run the doctests embedded in this module's docstrings. testmod() only
    # reports failures and returns normally, so exit non-zero on failure to
    # let scripted/CI invocations detect it.
    if doctest.testmod().failed:
        raise SystemExit(1)
|