Add a context managed unfreeze() method in base_config

PiperOrigin-RevId: 812777327
Change-Id: I19480f4c103f65280ddf8f396c63351e80144978
This commit is contained in:
Harsh Tiku
2025-09-29 07:20:40 -07:00
committed by Copybara-Service
parent dbaafbcdea
commit 25102bf040
2 changed files with 63 additions and 7 deletions

View File

@@ -15,11 +15,12 @@
"""Config for the protein folding model and experiment."""
from collections.abc import Mapping
import contextlib
import copy
import dataclasses
import types
import typing
from typing import Any, ClassVar, TypeVar
from typing import Any, ClassVar, Iterator, TypeVar
_T = TypeVar('_T')
@@ -164,8 +165,8 @@ class BaseConfig(metaclass=ConfigMeta):
else {k: v for k, v in result.items() if v is not None}
)
def __setattr__(self, name, value):
if getattr(self, '_is_frozen', False):
def __setattr__(self, name: str, value: Any) -> None:
if getattr(self, '_is_frozen', False) and name != '_is_frozen':
# If we are frozen, raise an error
raise dataclasses.FrozenInstanceError(
f"Cannot assign to field '{name}'; instance is frozen."
@@ -174,10 +175,25 @@ class BaseConfig(metaclass=ConfigMeta):
# If not frozen, set the attribute normally
super().__setattr__(name, value)
def freeze(self) -> None:
"""Freezes the config and all subconfigs to prevent further changes."""
self._is_frozen = True
def _toggle_freeze(self, frozen: bool) -> None:
"""Toggles the frozen state of the config and all subconfigs."""
self._is_frozen = frozen
for field_name in self._coercable_fields:
field_value = getattr(self, field_name, None)
if isinstance(field_value, BaseConfig):
field_value.freeze()
field_value._toggle_freeze(frozen)
def freeze(self) -> None:
"""Freezes the config and all subconfigs to prevent further changes."""
self._toggle_freeze(True)
@contextlib.contextmanager
def unfreeze(self: _ConfigT) -> Iterator[_ConfigT]:
"""A context manager to temporarily unfreeze the config."""
was_frozen = self._is_frozen
self._toggle_freeze(False)
try:
yield self
finally:
if was_frozen:
self._toggle_freeze(True)

View File

@@ -133,6 +133,46 @@ class ModelConfigTest(absltest.TestCase):
with self.assertRaises(dataclasses.FrozenInstanceError):
config.z.a = 1
def test_unfreeze(self):
config = OuterConfig(
x=5,
z=InnerConfig(a=2),
optional_z=None,
z_requires_a=InnerConfig(a=3),
z_default=None,
)
config.freeze()
# Check that we can modify the config within the unfrozen context.
with config.unfreeze() as mutable_config:
mutable_config.x = 1
mutable_config.z.a = 1
self.assertEqual(config.x, 1)
self.assertEqual(config.z.a, 1)
# Check that the config and all subconfigs are frozen again.
self.assertTrue(config._is_frozen)
self.assertTrue(config.z._is_frozen)
self.assertTrue(config.z_requires_a._is_frozen)
with self.assertRaises(dataclasses.FrozenInstanceError):
config.x = 2
with self.assertRaises(dataclasses.FrozenInstanceError):
config.z.a = 2
# Check that a config that was not frozen remains unfrozen.
unfrozen_config = OuterConfig(
x=5,
z=InnerConfig(a=2),
optional_z=None,
z_requires_a=InnerConfig(a=3),
z_default=None,
)
self.assertFalse(unfrozen_config._is_frozen)
with unfrozen_config.unfreeze() as mutable_config:
mutable_config.x = 1
self.assertEqual(unfrozen_config.x, 1)
self.assertFalse(unfrozen_config._is_frozen)
if __name__ == '__main__':
absltest.main()