mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Isolate code for coords to angles
This commit is contained in:
@@ -59,49 +59,60 @@ class CathConsecutiveAnglesDataset(Dataset):
|
||||
raise IndexError(index)
|
||||
|
||||
coords = self.structures[index]["coords"]
|
||||
first_valid_idx, last_valid_idx = 0, len(coords["N"])
|
||||
|
||||
# Walk through coordinates and trim trailing nan
|
||||
for k in ["N", "CA", "C"]:
|
||||
logging.debug(f"{k}:\t{coords[k][:5]}")
|
||||
arr = np.array(coords[k])
|
||||
# Get all valid indices
|
||||
valid_idx = np.where(~np.any(np.isnan(arr), axis=1))[0]
|
||||
first_valid_idx = max(first_valid_idx, np.min(valid_idx))
|
||||
last_valid_idx = min(last_valid_idx, np.max(valid_idx) + 1)
|
||||
logging.debug(f"Trimming nans keeps {first_valid_idx}:{last_valid_idx}")
|
||||
for k in ["N", "CA", "C"]:
|
||||
coords[k] = coords[k][first_valid_idx:last_valid_idx]
|
||||
arr = np.array(coords[k])
|
||||
assert not np.any(np.isnan(arr)), f"Found nan in {index} {k}: {arr}"
|
||||
angles = pdb_utils.process_coords(coords)
|
||||
# https://www.rosettacommons.org/docs/latest/application_documentation/trRosetta/trRosetta#application-purpose_a-note-on-nomenclature
|
||||
# omega = inter-residue dihedral angle between CA/CB of first and CB/CA of second
|
||||
# theta = inter-residue dihedral angle between N, CA, CB of first and CB of second
|
||||
# phi = inter-residue angle between CA and CB of first and CB of second
|
||||
dist, omega, theta, phi = angles
|
||||
assert dist.shape == omega.shape == theta.shape == phi.shape
|
||||
logging.debug(
|
||||
f"Pre slice shape: {dist.shape, omega.shape, theta.shape, phi.shape}"
|
||||
)
|
||||
# Slice out so that we have the angles and distances between the n and n+1 items
|
||||
n = dist.shape[0]
|
||||
indices_i = np.arange(n - 1)
|
||||
indices_j = indices_i + 1
|
||||
dist_slice = dist[indices_i, indices_j]
|
||||
omega_slice = omega[indices_i, indices_j]
|
||||
theta_slice = theta[indices_i, indices_j]
|
||||
phi_slice = phi[indices_i, indices_j]
|
||||
logging.debug(
|
||||
f"Post slice shape: {dist_slice.shape, omega_slice.shape, theta_slice.shape, phi_slice.shape}"
|
||||
)
|
||||
all_values = np.array([dist_slice, omega_slice, theta_slice, phi_slice])
|
||||
assert all_values.shape == (4, n - 1)
|
||||
all_values = coords_to_angles(coords)
|
||||
if all_values is None:
|
||||
return None
|
||||
assert not np.any(np.isnan(all_values))
|
||||
retval = torch.from_numpy(all_values)
|
||||
return retval
|
||||
|
||||
|
||||
def coords_to_angles(coords: Dict[str, List[List[float]]]) -> Union[np.ndarray, None]:
    """
    Sanitize the backbone coordinates to not have NaN and convert them into
    arrays of distances/angles between consecutive residues.

    Leading and trailing rows containing NaN are trimmed, using one common
    window for the "N", "CA" and "C" entries. The caller's ``coords`` dict is
    left unmodified; a shallow copy carrying the trimmed entries is what gets
    passed to ``pdb_utils.process_coords``.

    Args:
        coords: mapping that holds at least the keys "N", "CA" and "C", each
            a list of per-residue coordinate rows.

    Returns:
        Array of shape (4, n - 1) stacking (dist, omega, theta, phi) for each
        pair of residues (i, i + 1), where n is the number of residues kept
        after trimming. None if sanitization fails: an entry is empty or has
        no NaN-free row, the NaN-free ranges of the three entries do not
        overlap, or a NaN remains inside the trimmed window.
    """
    backbone_atoms = ("N", "CA", "C")
    first_valid_idx, last_valid_idx = 0, len(coords["N"])

    # Walk through coordinates and find the common window without
    # leading/trailing nan
    for k in backbone_atoms:
        logging.debug(f"{k}:\t{coords[k][:5]}")
        arr = np.array(coords[k])
        if arr.size == 0:
            # Nothing to trim; also avoids an axis error in np.any below
            logging.debug(f"No coordinates for {k}")
            return None
        # Get all valid indices
        valid_idx = np.where(~np.any(np.isnan(arr), axis=1))[0]
        if valid_idx.size == 0:
            # np.min/np.max raise on empty input, so bail out explicitly
            logging.debug(f"No valid coordinates for {k}")
            return None
        first_valid_idx = max(first_valid_idx, np.min(valid_idx))
        last_valid_idx = min(last_valid_idx, np.max(valid_idx) + 1)
    logging.debug(f"Trimming nans keeps {first_valid_idx}:{last_valid_idx}")

    if first_valid_idx >= last_valid_idx:
        # The valid ranges of the three atoms do not overlap
        logging.debug("Trimming nans leaves no residues")
        return None

    # Shallow copy so the trimming never leaks back into the caller's data;
    # any keys besides N/CA/C are forwarded untouched, as before.
    trimmed = dict(coords)
    for k in backbone_atoms:
        trimmed[k] = coords[k][first_valid_idx:last_valid_idx]
        if np.any(np.isnan(np.array(trimmed[k]))):
            logging.debug("Got nan in middle of array")
            return None

    angles = pdb_utils.process_coords(trimmed)
    # https://www.rosettacommons.org/docs/latest/application_documentation/trRosetta/trRosetta#application-purpose_a-note-on-nomenclature
    # omega = inter-residue dihedral angle between CA/CB of first and CB/CA of second
    # theta = inter-residue dihedral angle between N, CA, CB of first and CB of second
    # phi = inter-residue angle between CA and CB of first and CB of second
    dist, omega, theta, phi = angles
    assert dist.shape == omega.shape == theta.shape == phi.shape
    logging.debug(f"Pre slice shape: {dist.shape, omega.shape, theta.shape, phi.shape}")

    # Slice out so that we have the angles and distances between the n and n+1 items
    # (i.e. the entries at [i, i + 1] of each pairwise matrix)
    n = dist.shape[0]
    indices_i = np.arange(n - 1)
    indices_j = indices_i + 1
    dist_slice = dist[indices_i, indices_j]
    omega_slice = omega[indices_i, indices_j]
    theta_slice = theta[indices_i, indices_j]
    phi_slice = phi[indices_i, indices_j]
    logging.debug(
        f"Post slice shape: {dist_slice.shape, omega_slice.shape, theta_slice.shape, phi_slice.shape}"
    )

    all_values = np.array([dist_slice, omega_slice, theta_slice, phi_slice])
    assert all_values.shape == (4, n - 1)
    return all_values
|
||||
|
||||
|
||||
def main():
|
||||
dset = CathConsecutiveAnglesDataset()
|
||||
error_counter = 0
|
||||
|
||||
Reference in New Issue
Block a user