Fix: Batched conditional generation

Closes #25
This commit is contained in:
sweagent
2024-04-11 03:20:38 +00:00
parent 929407c605
commit ea6aef111c

View File

@@ -63,8 +63,21 @@ def test_chroma(chroma):
conditioners.SymmetryConditioner(G="C_3", num_chain_neighbors=1),
],
)
def test_sample(chroma, conditioner):
chroma.sample(steps=3, conditioner=conditioner, design_method=None)
@pytest.mark.parametrize(
"conditioner",
[
conditioners.Identity(),
conditioners.SymmetryConditioner(G="C_3", num_chain_neighbors=1),
],
)
@pytest.mark.parametrize("batch_size", [1, 2, 4])
def test_sample(chroma, conditioner, batch_size):
# Generate a batch of proteins with the specified batch size
proteins = [Protein.from_CIF(PROTEIN_SAMPLE) for _ in range(batch_size)]
# Stack proteins into a batch
protein_batch = Protein.stack(proteins)
# Sample with the specified conditioner and batch of proteins
chroma.sample(steps=3, conditioner=conditioner, protein_batch=protein_batch, design_method=None)
@pytest.mark.parametrize(