This changelist begins the migration of the AlphaFold model configuration from
`ml_collections.ConfigDict` to a more robust and type-safe system based on Python's `dataclasses`.
This is the **first step** in a larger effort. Future changes will update all call sites to use the new `get_model_config()` function.
---
## Key Changes
### `base_config.py`
- Introduced a new **`BaseConfig`** class, serving as a base for all configuration dataclasses.
- Features:
- **Automatic coercion** of nested dictionaries into the appropriate dataclass instances during initialization.
- **`as_dict()`** method for converting the configuration object back into a dictionary.
- **`freeze()`** method to make the configuration and all its sub-configs immutable.
### `config.py`
- Entire configuration schema redefined using a hierarchy of `BaseConfig` subclasses
(e.g., `AlphaFoldConfig`, `Model`, `Heads`) — making the structure explicit and statically verifiable.
- Introduced **`get_model_config(model_name)`** as the primary entrypoint to access model configurations.
- Returns a fully initialized, type-safe `AlphaFoldConfig` object.
- **Model-specific variations**:
- Previously handled by updating a dictionary (`CONFIG_DIFFS`).
- Now handled via a **mapping of functions (`CONFIG_DIFF_OPS`)**, improving clarity and safety.
### Testing
- Added **`base_config_test.py`** to validate the functionality of the new `BaseConfig` class.
- Added **`config_test.py`** to verify that:
- `get_model_config` produces equivalent configurations to the legacy `model_config`.
- Refactoring does not alter behavior.
---
PiperOrigin-RevId: 810873050
Change-Id: I7c96c23c1d168d722af41efcf98d44cc11ca7707
jax.util was deprecated in JAX v0.6.0, and will be removed in JAX v0.7.0.
PiperOrigin-RevId: 759626266
Change-Id: If3dcb9a8151a99ecab1f8ec670bd99bbb31bd5de
- a is now positional-only
- a_min is now min
- a_max is now max
The old argument names have been deprecated since JAX v0.4.27.
PiperOrigin-RevId: 715343439
Change-Id: I50b086b249360c142f42ed4d8e50d48692c11dfa
This change migrates users of APIs removed in NumPy 2.0 to their recommended replacements (https://numpy.org/devdocs/numpy_2_0_migration_guide.html).
PiperOrigin-RevId: 655904142
Change-Id: Idcf2384de70cb27cb92e3f421dc8e0a9a6466507
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.
PiperOrigin-RevId: 633878846
Change-Id: I282e177b3c9026b805d6eb5e5d6e6b1287f87e55
- Ubuntu to 20.04
- Cuda to 12.2.2
- Tensorflow to 2.16.1
- Jax to 0.4.26
- Openmm to 8.0.0
PiperOrigin-RevId: 624094170
Change-Id: I49391c3f721e93ac8ccd5a4483cdb6f2ec61cd3b
An upcoming change to JAX will include non-local (addressable) CPU devices in jax.devices() when JAX is used multicontroller-style, where there are multiple Python processes.
This change preserves the current behavior by replacing uses of jax.devices("cpu"), which previously only returned local devices, with jax.local_devices("cpu"), which will return local devices both now and in the future.
This change is always be safe (i.e., it should always preserve the previous behavior) but it may sometimes be unnecessary if code is never used in a multicontroller setting.
PiperOrigin-RevId: 582680953
Change-Id: I614739de052185c42932ddb91d1d464751442f65