mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2026-06-02 11:54:36 +08:00
Explicitly set backend='triton' in pallas_call invocations.
This change is being made because the default backend for `pallas_call` is changing. Explicitly setting `backend='triton'` ensures that these calls continue to use the Triton backend, maintaining current behavior. PiperOrigin-RevId: 822019610 Change-Id: I7ba83894fcd960d424502335c8b904d76c88a733
This commit is contained in:
committed by
Copybara-Service
parent
c9121bf646
commit
cf404610b5
@@ -221,6 +221,7 @@ def _gated_linear_unit(
|
||||
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype) if dst is None else dst,
|
||||
input_output_aliases=input_output_aliases,
|
||||
compiler_params=compiler_params,
|
||||
backend='triton',
|
||||
)(x, weights_projection, weights_gate, dst, epilogue_args)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user