mirror of
https://github.com/google-deepmind/alphafold.git
synced 2026-06-04 14:58:05 +08:00
Add a context managed unfreeze() method in base_config
PiperOrigin-RevId: 812777327 Change-Id: I19480f4c103f65280ddf8f396c63351e80144978
This commit is contained in:
committed by
Copybara-Service
parent
dbaafbcdea
commit
25102bf040
@@ -15,11 +15,12 @@
|
||||
"""Config for the protein folding model and experiment."""
|
||||
|
||||
from collections.abc import Mapping
|
||||
import contextlib
|
||||
import copy
|
||||
import dataclasses
|
||||
import types
|
||||
import typing
|
||||
from typing import Any, ClassVar, TypeVar
|
||||
from typing import Any, ClassVar, Iterator, TypeVar
|
||||
|
||||
|
||||
_T = TypeVar('_T')
|
||||
@@ -164,8 +165,8 @@ class BaseConfig(metaclass=ConfigMeta):
|
||||
else {k: v for k, v in result.items() if v is not None}
|
||||
)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if getattr(self, '_is_frozen', False):
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if getattr(self, '_is_frozen', False) and name != '_is_frozen':
|
||||
# If we are frozen, raise an error
|
||||
raise dataclasses.FrozenInstanceError(
|
||||
f"Cannot assign to field '{name}'; instance is frozen."
|
||||
@@ -174,10 +175,25 @@ class BaseConfig(metaclass=ConfigMeta):
|
||||
# If not frozen, set the attribute normally
|
||||
super().__setattr__(name, value)
|
||||
|
||||
def freeze(self) -> None:
|
||||
"""Freezes the config and all subconfigs to prevent further changes."""
|
||||
self._is_frozen = True
|
||||
def _toggle_freeze(self, frozen: bool) -> None:
|
||||
"""Toggles the frozen state of the config and all subconfigs."""
|
||||
self._is_frozen = frozen
|
||||
for field_name in self._coercable_fields:
|
||||
field_value = getattr(self, field_name, None)
|
||||
if isinstance(field_value, BaseConfig):
|
||||
field_value.freeze()
|
||||
field_value._toggle_freeze(frozen)
|
||||
|
||||
def freeze(self) -> None:
|
||||
"""Freezes the config and all subconfigs to prevent further changes."""
|
||||
self._toggle_freeze(True)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def unfreeze(self: _ConfigT) -> Iterator[_ConfigT]:
|
||||
"""A context manager to temporarily unfreeze the config."""
|
||||
was_frozen = self._is_frozen
|
||||
self._toggle_freeze(False)
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
if was_frozen:
|
||||
self._toggle_freeze(True)
|
||||
|
||||
@@ -133,6 +133,46 @@ class ModelConfigTest(absltest.TestCase):
|
||||
with self.assertRaises(dataclasses.FrozenInstanceError):
|
||||
config.z.a = 1
|
||||
|
||||
def test_unfreeze(self):
|
||||
config = OuterConfig(
|
||||
x=5,
|
||||
z=InnerConfig(a=2),
|
||||
optional_z=None,
|
||||
z_requires_a=InnerConfig(a=3),
|
||||
z_default=None,
|
||||
)
|
||||
config.freeze()
|
||||
|
||||
# Check that we can modify the config within the unfrozen context.
|
||||
with config.unfreeze() as mutable_config:
|
||||
mutable_config.x = 1
|
||||
mutable_config.z.a = 1
|
||||
self.assertEqual(config.x, 1)
|
||||
self.assertEqual(config.z.a, 1)
|
||||
|
||||
# Check that the config and all subconfigs are frozen again.
|
||||
self.assertTrue(config._is_frozen)
|
||||
self.assertTrue(config.z._is_frozen)
|
||||
self.assertTrue(config.z_requires_a._is_frozen)
|
||||
with self.assertRaises(dataclasses.FrozenInstanceError):
|
||||
config.x = 2
|
||||
with self.assertRaises(dataclasses.FrozenInstanceError):
|
||||
config.z.a = 2
|
||||
|
||||
# Check that a config that was not frozen remains unfrozen.
|
||||
unfrozen_config = OuterConfig(
|
||||
x=5,
|
||||
z=InnerConfig(a=2),
|
||||
optional_z=None,
|
||||
z_requires_a=InnerConfig(a=3),
|
||||
z_default=None,
|
||||
)
|
||||
self.assertFalse(unfrozen_config._is_frozen)
|
||||
with unfrozen_config.unfreeze() as mutable_config:
|
||||
mutable_config.x = 1
|
||||
self.assertEqual(unfrozen_config.x, 1)
|
||||
self.assertFalse(unfrozen_config._is_frozen)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
||||
Reference in New Issue
Block a user