SWIG pickling improvements (and other cleanup) (#6133)

* wip

* - avoid leaking memory after instantiating UChar_Vect
- fix some indentation
- make it easier to read/write pickles as native byte arrays from Java and C#
- add tests

---------

Co-authored-by: Tosco, Paolo <paolo.tosco@novartis.com>
This commit is contained in:
Paolo Tosco
2023-03-01 05:00:56 +01:00
committed by GitHub
parent ec5d922eee
commit ab980bca44
6 changed files with 223 additions and 37 deletions

View File

@@ -61,37 +61,44 @@
#endif
%typemap(javacode) ExplicitBitVect %{
public static ExplicitBitVect fromByteArray(byte[] fp) {
UChar_Vect vec = new UChar_Vect();
vec.reserve(fp.length);
for (int size=0;size<fp.length;++size) {
vec.add((short)fp[size]);
}
return new ExplicitBitVect(vec);
}
public static ExplicitBitVect fromByteArray(byte[] fp) {
UChar_Vect vec = null;
try {
vec = new UChar_Vect();
vec.reserve(fp.length);
for (int size=0;size<fp.length;++size) {
vec.add((short)fp[size]);
}
return new ExplicitBitVect(vec);
} finally {
if (vec != null) {
vec.delete();
}
}
}
%}
%include <DataStructs/ExplicitBitVect.h>
%newobject ExplicitBitVect::getOnBits;
%extend ExplicitBitVect {
IntVect *getOnBits() {
IntVect* bits = new IntVect;
($self)->getOnBits(*bits);
return bits;
}
IntVect *getOnBits() {
IntVect* bits = new IntVect;
($self)->getOnBits(*bits);
return bits;
}
}
#ifdef SWIGJAVA
%extend ExplicitBitVect {
const std::string toByteArray() {
return ($self)->toString();
}
const std::string toByteArray() {
return ($self)->toString();
}
ExplicitBitVect(const std::vector<unsigned char> & data ) {
std::string str(data.begin(), data.end());
return new ExplicitBitVect(str);
}
ExplicitBitVect(const std::vector<unsigned char> & data ) {
std::string str(data.begin(), data.end());
return new ExplicitBitVect(str);
}
}
#endif

View File

@@ -64,14 +64,21 @@ typedef std::vector<std::string> STR_VECT;
%typemap(javacode) RDKit::FilterCatalog %{
public static FilterCatalog Deserialize(byte[] b) {
UChar_Vect vec = new UChar_Vect();
vec.reserve(b.length);
for (int size=0;size<b.length;++size) {
vec.add((short)b[size]);
}
return new FilterCatalog(vec);
}
public static FilterCatalog Deserialize(byte[] b) {
UChar_Vect vec = null;
try {
vec = new UChar_Vect();
vec.reserve(b.length);
for (int size=0;size<b.length;++size) {
vec.add((short)b[size]);
}
return new FilterCatalog(vec);
} finally {
if (vec != null) {
vec.delete();
}
}
}
%}
%extend RDKit::FilterMatch {

View File

@@ -72,6 +72,7 @@
%template(ROMol_Vect_Vect) std::vector< std::vector< boost::shared_ptr<RDKit::ROMol> > >;
%template(Atom_Vect) std::vector<RDKit::Atom*>;
%template(StereoGroup_Vect) std::vector<RDKit::StereoGroup>;
%template(UChar_Vect) std::vector<unsigned char>;
// These prevent duplicate definitions in Java code
%ignore RDKit::ROMol::hasProp(std::string const) const ;
@@ -97,6 +98,19 @@
%ignore RDKit::ROMol::getTopology() const ;
#ifdef SWIGJAVA
%typemap(jni) std::string RDKit::ROMol::toByteArray "jbyteArray"
%typemap(jtype) std::string RDKit::ROMol::toByteArray "byte[]"
%typemap(jstype) std::string RDKit::ROMol::toByteArray "byte[]"
%typemap(javaout) std::string RDKit::ROMol::toByteArray {
return $jnicall;
}
%typemap(out) std::string RDKit::ROMol::toByteArray {
$result = JCALL1(NewByteArray, jenv, $1.size());
JCALL4(SetByteArrayRegion, jenv, $result, 0, $1.size(), (const jbyte*)$1.c_str());
}
#endif
/*
* Special handling for Conformer objects which should not be GCed until the molecule is destroyed
* We want to modify the behavior of the Conformer coming into the addConformer method without
@@ -115,6 +129,51 @@
conf.setSwigCMemOwn(false);
return Conformer.getCPtr(conf);
}
public static ROMol fromByteArray(byte[] pkl) {
UChar_Vect vec = null;
try {
vec = new UChar_Vect();
vec.reserve(pkl.length);
for (int i = 0; i < pkl.length; ++i) {
vec.add((byte)pkl[i]);
}
return ROMol.fromUCharVect(vec);
} finally {
if (vec != null) {
vec.delete();
}
}
}
%}
%typemap(cscode) RDKit::ROMol %{
public static ROMol FromByteArray(byte[] pkl) {
UChar_Vect vec = null;
try {
vec = new UChar_Vect();
vec.Capacity = pkl.Length;
for (int i = 0; i < pkl.Length; ++i) {
vec.Add((byte)pkl[i]);
}
return ROMol.fromUCharVect(vec);
} finally {
if (vec != null) {
vec.Dispose();
}
}
}
public byte[] ToByteArray() {
UChar_Vect vec = null;
try {
vec = toUCharVect();
byte[] res = new byte[vec.Count];
vec.CopyTo(res);
return res;
} finally {
if (vec != null) {
vec.Dispose();
}
}
}
%}
%include <GraphMol/ROMol.h>
@@ -176,6 +235,14 @@ void setUseLegacyStereoPerception(bool);
bool getAllowNontetrahedralChirality();
void setAllowNontetrahedralChirality(bool);
#ifdef SWIGJAVA
%javamethodmodifiers RDKit::ROMol::fromUCharVect "private";
#endif
#ifdef SWIGCSHARP
%csmethodmodifiers RDKit::ROMol::fromUCharVect "private";
%csmethodmodifiers RDKit::ROMol::toUCharVect "private";
#endif
%extend RDKit::ROMol {
std::string getProp(const std::string key){
std::string res;
@@ -523,7 +590,7 @@ void setAllowNontetrahedralChirality(bool);
std::copy(sres.begin(),sres.end(),res.begin());
return res;
};
static RDKit::ROMOL_SPTR MolFromBinary(std::vector<int> pkl){
static RDKit::ROMOL_SPTR MolFromBinary(const std::vector<int> &pkl){
std::string sres;
sres.resize(pkl.size());
std::copy(pkl.begin(),pkl.end(),sres.begin());
@@ -536,6 +603,32 @@ void setAllowNontetrahedralChirality(bool);
}
return RDKit::ROMOL_SPTR(res);
}
#ifdef SWIGJAVA
const std::string toByteArray() {
std::string sres;
RDKit::MolPickler::pickleMol(*($self), sres);
return sres;
}
#endif
#ifdef SWIGCSHARP
const std::vector<unsigned char> toUCharVect() {
std::string sres;
RDKit::MolPickler::pickleMol(*($self), sres);
const std::vector<unsigned char> vec(sres.begin(), sres.end());
return vec;
}
#endif
static RDKit::ROMOL_SPTR fromUCharVect(const std::vector<unsigned char> &pkl) {
std::string sres(pkl.begin(), pkl.end());
RDKit::ROMol *res;
try {
res = new RDKit::ROMol(sres);
} catch (const RDKit::MolPicklerException &e) {
res = nullptr;
throw;
}
return RDKit::ROMOL_SPTR(res);
}
/* From AddHs.cpp */
RDKit::ROMol *addHs(bool explicitOnly,bool addCoords=false){

View File

@@ -52,14 +52,21 @@
%template(UChar_Vect) std::vector<unsigned char>;
%typemap(javacode) RDKit::SubstructLibrary %{
public static SubstructLibrary Deserialize(byte[] b) {
UChar_Vect vec = new UChar_Vect();
vec.reserve(b.length);
for (int size=0;size<b.length;++size) {
vec.add((short)b[size]);
}
return new SubstructLibrary(vec);
}
public static SubstructLibrary Deserialize(byte[] b) {
UChar_Vect vec = null;
try {
vec = new UChar_Vect();
vec.reserve(b.length);
for (int size=0;size<b.length;++size) {
vec.add((short)b[size]);
}
return new SubstructLibrary(vec);
} finally {
if (vec != null) {
vec.delete();
}
}
}
%}
%extend RDKit::SubstructLibrary {

View File

@@ -0,0 +1,30 @@
// Linux:
// compile with
// mcs -platform:x64 -r:../RDKit2DotNet.dll -out:MolToFromByteArray.exe MolToFromByteArray.cs
// and run with
// LD_LIBRARY_PATH=..:$RDBASE/lib:$LD_LIBRARY_PATH MONO_PATH=.. mono MolToFromByteArray.exe
using System.IO;
using System.Diagnostics;
using GraphMolWrap;
public class MolToFromByteArrayTest
{
static void Main(string[] args)
{
string smi = "CN(C)c1ccc2c(=O)cc[nH]c2c1";
string pklFileName = "quinolone.pkl";
{
ROMol mol = RWMol.MolFromSmiles(smi);
byte[] pkl = mol.ToByteArray();
File.WriteAllBytes(pklFileName, pkl);
mol.Dispose();
}
{
byte[] pkl = File.ReadAllBytes(pklFileName);
ROMol mol = ROMol.FromByteArray(pkl);
Debug.Assert(mol.MolToSmiles() == smi);
mol.Dispose();
}
}
}

View File

@@ -34,6 +34,11 @@ package org.RDKit;
import static org.junit.Assert.*;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import org.junit.*;
public class PicklingTests extends GraphMolTest {
@@ -134,6 +139,43 @@ public class PicklingTests extends GraphMolTest {
}
}
@Test
public void testToFromByteArray() throws IOException {
String smi = "CN(C)c1ccc2c(=O)cc[nH]c2c1";
String pklFileName = "quinolone.pkl";
{
ROMol mol = RWMol.MolFromSmiles(smi);
byte[] pkl = mol.toByteArray();
FileOutputStream pklOutStream = null;
try {
pklOutStream = new FileOutputStream(pklFileName);
pklOutStream.write(pkl);
} finally {
if (pklOutStream != null) {
pklOutStream.close();
}
}
mol.delete();
}
{
FileInputStream pklInStream = null;
byte[] pkl = null;
File pklInFile = new File(pklFileName);
try {
pklInStream = new FileInputStream(pklInFile);
pkl = new byte[(int)pklInFile.length()];
assertEquals(pklInStream.read(pkl), pkl.length);
} finally {
if (pklInStream != null) {
pklInStream.close();
}
}
ROMol mol = ROMol.fromByteArray(pkl);
assertEquals(mol.MolToSmiles(), smi);
mol.delete();
}
}
public static void main(String args[]) {
org.junit.runner.JUnitCore.main("org.RDKit.PicklingTests");
}