Isolate code for coords to angles

This commit is contained in:
Kevin Wu
2022-07-05 19:11:59 +00:00
parent a90e8bb507
commit 242632b765

View File

@@ -59,49 +59,60 @@ class CathConsecutiveAnglesDataset(Dataset):
raise IndexError(index)
coords = self.structures[index]["coords"]
first_valid_idx, last_valid_idx = 0, len(coords["N"])
# Walk through coordinates and trim trailing nan
for k in ["N", "CA", "C"]:
logging.debug(f"{k}:\t{coords[k][:5]}")
arr = np.array(coords[k])
# Get all valid indices
valid_idx = np.where(~np.any(np.isnan(arr), axis=1))[0]
first_valid_idx = max(first_valid_idx, np.min(valid_idx))
last_valid_idx = min(last_valid_idx, np.max(valid_idx) + 1)
logging.debug(f"Trimming nans keeps {first_valid_idx}:{last_valid_idx}")
for k in ["N", "CA", "C"]:
coords[k] = coords[k][first_valid_idx:last_valid_idx]
arr = np.array(coords[k])
assert not np.any(np.isnan(arr)), f"Found nan in {index} {k}: {arr}"
angles = pdb_utils.process_coords(coords)
# https://www.rosettacommons.org/docs/latest/application_documentation/trRosetta/trRosetta#application-purpose_a-note-on-nomenclature
# omega = inter-residue dihedral angle between CA/CB of first and CB/CA of second
# theta = inter-residue dihedral angle between N, CA, CB of first and CB of second
# phi = inter-residue angle between CA and CB of first and CB of second
dist, omega, theta, phi = angles
assert dist.shape == omega.shape == theta.shape == phi.shape
logging.debug(
f"Pre slice shape: {dist.shape, omega.shape, theta.shape, phi.shape}"
)
# Slice out so that we have the angles and distances between the n and n+1 items
n = dist.shape[0]
indices_i = np.arange(n - 1)
indices_j = indices_i + 1
dist_slice = dist[indices_i, indices_j]
omega_slice = omega[indices_i, indices_j]
theta_slice = theta[indices_i, indices_j]
phi_slice = phi[indices_i, indices_j]
logging.debug(
f"Post slice shape: {dist_slice.shape, omega_slice.shape, theta_slice.shape, phi_slice.shape}"
)
all_values = np.array([dist_slice, omega_slice, theta_slice, phi_slice])
assert all_values.shape == (4, n - 1)
all_values = coords_to_angles(coords)
if all_values is None:
return None
assert not np.any(np.isnan(all_values))
retval = torch.from_numpy(all_values)
return retval
def coords_to_angles(coords: Dict[str, List[List[float]]]) -> Union[np.ndarray, None]:
    """
    Sanitize the coordinates to not have NaN and convert them into
    arrays of angles. If sanitization fails, return None

    The "N", "CA" and "C" entries are trimmed to the common window that
    excludes leading and trailing rows containing NaN. The caller's ``coords``
    dict is NOT modified; a shallow copy carrying the trimmed entries is what
    gets handed to ``pdb_utils.process_coords``.

    Args:
        coords: Mapping with at least the keys "N", "CA" and "C", each a
            per-residue list of coordinate rows. Rows with missing data are
            expected to contain NaN.

    Returns:
        An array of shape (4, n - 1) stacking distance, omega, theta and phi
        between each residue i and residue i + 1 of the trimmed window, or
        None when sanitization fails: an atom has no coordinates at all, an
        atom has no NaN-free row, the per-atom valid windows do not overlap,
        or a NaN remains inside the trimmed window.
    """
    backbone_atoms = ("N", "CA", "C")
    first_valid_idx, last_valid_idx = 0, len(coords["N"])
    # Walk through coordinates and find the window without leading/trailing nan
    for k in backbone_atoms:
        logging.debug(f"{k}:\t{coords[k][:5]}")
        arr = np.array(coords[k])
        if arr.size == 0:
            # Nothing to sanitize; the axis=1 reduction below would raise
            logging.debug(f"No coordinates for {k}")
            return None
        # Get all valid indices (np.where returns them in ascending order)
        valid_idx = np.where(~np.any(np.isnan(arr), axis=1))[0]
        if valid_idx.size == 0:
            # Every row has a NaN; min/max of an empty array would raise
            logging.debug(f"No nan-free rows for {k}")
            return None
        first_valid_idx = max(first_valid_idx, int(valid_idx[0]))
        last_valid_idx = min(last_valid_idx, int(valid_idx[-1]) + 1)
    logging.debug(f"Trimming nans keeps {first_valid_idx}:{last_valid_idx}")
    if first_valid_idx >= last_valid_idx:
        # The valid ranges of the three atoms do not overlap
        logging.debug("Trimming nans leaves no residues")
        return None

    # Shallow copy so the caller's dict (e.g. a cached dataset entry) keeps
    # its original, untrimmed coordinates; any extra keys are passed through.
    trimmed = dict(coords)
    for k in backbone_atoms:
        trimmed[k] = coords[k][first_valid_idx:last_valid_idx]
        if np.any(np.isnan(np.array(trimmed[k]))):
            logging.debug("Got nan in middle of array")
            return None

    angles = pdb_utils.process_coords(trimmed)
    # https://www.rosettacommons.org/docs/latest/application_documentation/trRosetta/trRosetta#application-purpose_a-note-on-nomenclature
    # omega = inter-residue dihedral angle between CA/CB of first and CB/CA of second
    # theta = inter-residue dihedral angle between N, CA, CB of first and CB of second
    # phi = inter-residue angle between CA and CB of first and CB of second
    dist, omega, theta, phi = angles
    assert dist.shape == omega.shape == theta.shape == phi.shape
    logging.debug(f"Pre slice shape: {dist.shape, omega.shape, theta.shape, phi.shape}")
    # Slice out so that we have the angles and distances between the n and n+1 items
    # (i.e. the first superdiagonal of each pairwise matrix)
    n = dist.shape[0]
    indices_i = np.arange(n - 1)
    indices_j = indices_i + 1
    dist_slice = dist[indices_i, indices_j]
    omega_slice = omega[indices_i, indices_j]
    theta_slice = theta[indices_i, indices_j]
    phi_slice = phi[indices_i, indices_j]
    logging.debug(
        f"Post slice shape: {dist_slice.shape, omega_slice.shape, theta_slice.shape, phi_slice.shape}"
    )
    all_values = np.array([dist_slice, omega_slice, theta_slice, phi_slice])
    assert all_values.shape == (4, n - 1)
    return all_values
def main():
dset = CathConsecutiveAnglesDataset()
error_counter = 0