Explicitly set backend='triton' in pallas_call invocations.

This change is being made because the default backend for `pallas_call` is changing. Explicitly setting `backend='triton'` ensures that these calls continue to use the Triton backend, maintaining current behavior.

PiperOrigin-RevId: 822019610
Change-Id: I7ba83894fcd960d424502335c8b904d76c88a733
This commit is contained in:
Benjamin Chetioui
2025-10-21 02:25:57 -07:00
committed by Copybara-Service
parent c9121bf646
commit cf404610b5

View File

@@ -221,6 +221,7 @@ def _gated_linear_unit(
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype) if dst is None else dst,
input_output_aliases=input_output_aliases,
compiler_params=compiler_params,
backend='triton',
)(x, weights_projection, weights_gate, dst, epilogue_args)