diff --git a/src/alphafold3/structure/cpp/string_array_pybind.cc b/src/alphafold3/structure/cpp/string_array_pybind.cc index 29fac72..a4fc4ec 100644 --- a/src/alphafold3/structure/cpp/string_array_pybind.cc +++ b/src/alphafold3/structure/cpp/string_array_pybind.cc @@ -174,6 +174,9 @@ py::array_t IsIn( py::array RemapMultipleArrays( const std::vector>& arrays, const py::dict& mapping) { + if (arrays.empty()) { + return py::array_t(0); + } size_t array_size = arrays[0].size(); for (const auto& array : arrays) { if (array.size() != array_size) { @@ -184,10 +187,6 @@ py::array RemapMultipleArrays( // Create a result buffer. auto result = py::array_t(array_size); absl::Span result_buffer(result.mutable_data(), array_size); - PyObject* entry = PyTuple_New(arrays.size()); - if (entry == nullptr) { - throw py::error_already_set(); - } std::vector> array_spans; array_spans.reserve(arrays.size()); for (const auto& array : arrays) { @@ -197,34 +196,38 @@ py::array RemapMultipleArrays( // Iterate over arrays and look up elements in the `py_dict`. bool fail = false; for (size_t i = 0; i < array_size; ++i) { + PyObject* entry = PyTuple_New(arrays.size()); + if (entry == nullptr) { + fail = true; + break; + } for (size_t j = 0; j < array_spans.size(); ++j) { - PyTuple_SET_ITEM(entry, j, array_spans[j][i]); + PyTuple_SET_ITEM(entry, j, Py_NewRef(array_spans[j][i])); } PyObject* result = PyDict_GetItem(mapping.ptr(), entry); if (result != nullptr) { int64_t result_value = PyLong_AsLongLong(result); if (result_value == -1 && PyErr_Occurred()) { fail = true; + Py_DECREF(entry); break; } if (result_value > std::numeric_limits::max() || result_value < std::numeric_limits::lowest()) { PyErr_SetString(PyExc_OverflowError, "Result value too large."); fail = true; + Py_DECREF(entry); break; } result_buffer[i] = result_value; } else { PyErr_Format(PyExc_KeyError, "%R", entry); fail = true; + Py_DECREF(entry); break; } + Py_DECREF(entry); } - - for (size_t j = 0; j < array_spans.size(); ++j) { - PyTuple_SET_ITEM(entry, j, nullptr); - } - Py_XDECREF(entry); if (fail) { throw py::error_already_set(); }